前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >springboot 实现限流控制

springboot 实现限流控制

作者头像
用户9131103
发布2023-09-16 08:10:48
2600
发布2023-09-16 08:10:48
举报
文章被收录于专栏:工作经验
代码语言:java
复制
package com.jinw.cms.config;

import com.jinw.cms.aspectj.RateLimiterAspect;
import lombok.RequiredArgsConstructor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

/**
 * Spring configuration that wires the Redis-backed rate limiter.
 *
 * <p>The {@link RateLimiterAspect} bean is only registered when the property
 * {@code jw.rate-limiter.enable} is set to {@code true}.
 */
@RequiredArgsConstructor
@Configuration
public class RateLimiterConfig {

    /**
     * Lua source executed by Redis for every rate-limited call.
     *
     * <p>{@code KEYS[1]} is the counter key, {@code ARGV[1]} the limit and
     * {@code ARGV[2]} the expiry in seconds. The script returns {@code false}
     * without incrementing when one more request would exceed the limit;
     * otherwise it increments the counter, sets the expiry when the counter
     * has just become 1, and returns {@code true}.
     */
    private static final String LIMIT_SCRIPT_TEXT =
            " local key = KEYS[1] --限流KEY\n" +
            "                local limit = tonumber(ARGV[1]) --限流大小\n" +
            "                local expireTime = tonumber(ARGV[2]) --过期时间 单位/s\n" +
            "\n" +
            "                local current = tonumber(redis.call('get', key) or \"0\")\n" +
            "                if current + 1 > limit then\n" +
            "                    return false --当前值超过限流大小阈值\n" +
            "                end\n" +
            "                current = tonumber(redis.call('incr', key)) --请求数+1\n" +
            "                if current == 1 then\n" +
            "                    redis.call('expire', key, expireTime) --设置过期时间\n" +
            "                end\n" +
            "                return true;";

    /** Redis template handed to the aspect for script execution. */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Creates the rate-limiting aspect.
     *
     * @return an aspect built from the injected template and a fresh limit script
     */
    @Bean
    @ConditionalOnProperty(name = "jw.rate-limiter.enable", havingValue = "true")
    public RateLimiterAspect rateLimitAspect() {
        return new RateLimiterAspect(redisTemplate, limitScript());
    }

    /**
     * Builds the Lua rate-limit script.
     *
     * <p>This method is not annotated with {@code @Bean}; each call creates a
     * new script object.
     *
     * @return a script whose result type is {@link Boolean}
     */
    public DefaultRedisScript<Boolean> limitScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(LIMIT_SCRIPT_TEXT);
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}

代码语言:java
复制
package com.jinw.cms.aspectj;

/**
 * Strategies for choosing the scope of a rate limit.
 *
 * @author ruoyi
 */
public enum LimitType {

    /** Default strategy: one global limit. */
    DEFAULT,

    /** Limit is applied per requester IP address. */
    IP;
}
代码语言:java
复制
package com.jinw.cms.aspectj.annotation;

import com.jinw.cms.aspectj.LimitType;
import com.jinw.cms.constants.ExtendConstants;

import java.lang.annotation.*;

/**
 * Marks a method as rate limited.
 *
 * <p>Retained at runtime and applicable to methods only.
 *
 * @author ruoyi
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /**
     * Prefix of the cache key used for the limit counter.
     *
     * @return the key prefix; defaults to {@code ExtendConstants.RATE_LIMIT_KEY}
     */
    String prefix() default ExtendConstants.RATE_LIMIT_KEY;

    /**
     * Length of the limiting period, in seconds.
     *
     * @return the period in seconds; defaults to 60
     */
    int expire() default 60;

    /**
     * Maximum number of requests allowed within one period.
     *
     * @return the request ceiling; defaults to 100
     */
    int limit() default 100;

    /**
     * Strategy used to scope the limit.
     *
     * @return the limit type; defaults to {@link LimitType#DEFAULT}
     */
    LimitType limitType() default LimitType.DEFAULT;
}
代码语言:java
复制
package com.jinw.cms.config;

import com.jinw.cms.aspectj.RateLimiterAspect;
import lombok.RequiredArgsConstructor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

/**
 * Spring configuration that wires the Redis-backed rate limiter.
 *
 * <p>The {@link RateLimiterAspect} bean is only registered when the property
 * {@code jw.rate-limiter.enable} is set to {@code true}.
 */
@RequiredArgsConstructor
@Configuration
public class RateLimiterConfig {

    /**
     * Lua source executed by Redis for every rate-limited call.
     *
     * <p>{@code KEYS[1]} is the counter key, {@code ARGV[1]} the limit and
     * {@code ARGV[2]} the expiry in seconds. The script returns {@code false}
     * without incrementing when one more request would exceed the limit;
     * otherwise it increments the counter, sets the expiry when the counter
     * has just become 1, and returns {@code true}.
     */
    private static final String LIMIT_SCRIPT_TEXT =
            " local key = KEYS[1] --限流KEY\n" +
            "                local limit = tonumber(ARGV[1]) --限流大小\n" +
            "                local expireTime = tonumber(ARGV[2]) --过期时间 单位/s\n" +
            "\n" +
            "                local current = tonumber(redis.call('get', key) or \"0\")\n" +
            "                if current + 1 > limit then\n" +
            "                    return false --当前值超过限流大小阈值\n" +
            "                end\n" +
            "                current = tonumber(redis.call('incr', key)) --请求数+1\n" +
            "                if current == 1 then\n" +
            "                    redis.call('expire', key, expireTime) --设置过期时间\n" +
            "                end\n" +
            "                return true;";

    /** Redis template handed to the aspect for script execution. */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Creates the rate-limiting aspect.
     *
     * @return an aspect built from the injected template and a fresh limit script
     */
    @Bean
    @ConditionalOnProperty(name = "jw.rate-limiter.enable", havingValue = "true")
    public RateLimiterAspect rateLimitAspect() {
        return new RateLimiterAspect(redisTemplate, limitScript());
    }

    /**
     * Builds the Lua rate-limit script.
     *
     * <p>This method is not annotated with {@code @Bean}; each call creates a
     * new script object.
     *
     * @return a script whose result type is {@link Boolean}
     */
    public DefaultRedisScript<Boolean> limitScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptText(LIMIT_SCRIPT_TEXT);
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}
代码语言:java
复制
package com.ruoyi.common.extend.aspectj;

import java.lang.reflect.Method;
import java.util.List;

import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;

import com.ruoyi.common.exception.GlobalException;
import com.ruoyi.common.extend.annotation.RateLimiter;
import com.ruoyi.common.extend.enums.LimitType;
import com.ruoyi.common.extend.exception.RateLimiterErrorCode;
import com.ruoyi.common.utils.ServletUtils;

import lombok.RequiredArgsConstructor;

/**
 * Aspect that applies rate limiting to methods annotated with {@link RateLimiter}.
 *
 * <p>Before an annotated method runs, the injected Redis script is executed with
 * a key derived from the annotation and the target method. A {@code false}
 * script result rejects the call.
 */
@Aspect
@RequiredArgsConstructor
public class RateLimiterAspect {

    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    /** Template used to execute the limit script. */
    private final RedisTemplate<String, Object> redisTemplate;

    /** Script that decides whether the current request is allowed. */
    private final RedisScript<Boolean> limitScript;

    /**
     * Checks the rate limit before the annotated method is invoked.
     *
     * <p>The script receives the combined key plus the annotation's
     * {@code limit} and {@code expire} values as arguments.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation present on the intercepted method
     * @throws Throwable the exception produced by
     *                   {@code RateLimiterErrorCode.RATE_LIMIT} when the script
     *                   returns {@code false}, or by
     *                   {@code RateLimiterErrorCode.RATE_LIMIT_ERR} when the
     *                   check itself fails
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        int limit = rateLimiter.limit();
        int expire = rateLimiter.expire();

        // Declared outside the try block so the failure log can include it when available.
        String combineKey = null;
        try {
            combineKey = this.getCombineKey(rateLimiter, point);
            List<String> keys = List.of(combineKey);
            // A null script result fails on unboxing and is handled by the generic catch below.
            if (!redisTemplate.execute(this.limitScript, keys, limit, expire)) {
                log.warn("限制请求'{}',缓存key'{}'", limit, combineKey);
                throw RateLimiterErrorCode.RATE_LIMIT.exception();
            }
        } catch (GlobalException e) {
            // Business exceptions propagate unchanged.
            throw e;
        } catch (Exception e) {
            // Record the underlying failure before translating it; otherwise the cause is lost.
            log.error("Rate limiter check failed, key='{}', limit={}, expire={}", combineKey, limit, expire, e);
            throw RateLimiterErrorCode.RATE_LIMIT_ERR.exception();
        }
    }

    /**
     * Builds the cache key for a rate-limited method.
     *
     * <p>The key is the annotation prefix, followed by the requester IP and a
     * dot when the limit type is {@link LimitType#IP}, followed by the
     * declaring class name, a dot and the method name.
     *
     * @param rateLimiter the annotation supplying the prefix and limit type
     * @param point       the join point whose method signature is used
     * @return the combined cache key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        // The builder never leaves this method, so the unsynchronized variant is sufficient.
        StringBuilder key = new StringBuilder(rateLimiter.prefix());
        if (rateLimiter.limitType() == LimitType.IP) {
            key.append(ServletUtils.getIpAddr(ServletUtils.getRequest())).append(".");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        key.append(targetClass.getName()).append(".").append(method.getName());
        return key.toString();
    }
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档