| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import time
- from collections import defaultdict
- from functools import wraps
- from flask import request, jsonify
- from app import Config
- RATE_LIMIT_DEFAULT_LIMIT = Config.RATE_LIMIT_DEFAULT_LIMIT
- RATE_LIMIT_DEFAULT_PERIOD = Config.RATE_LIMIT_DEFAULT_PERIOD
- class RateLimiter:
- """
- API 限流器类,用于限制 API 请求频率。
- 支持基于 IP、用户 ID 或 API 端点的限流策略。
-
- Attributes:
- storage (dict): 存储请求记录的字典
- default_limit (int): 默认的请求限制次数
- default_period (int): 默认的时间窗口(s)
- """
- def __init__(self):
- # 使用内存存储请求记录
- # 格式: {key: [(timestamp1, count1), (timestamp2, count2), ...]}
- self.storage = defaultdict(list)
- self.default_limit = RATE_LIMIT_DEFAULT_LIMIT # 默认每分钟 60 次请求
- self.default_period = RATE_LIMIT_DEFAULT_PERIOD # 默认时间窗口为 60 s(1 min)
- def _generate_key(self, key_func):
- """
- 生成限流的键
-
- Args:
- key_func: 生成键的函数或字符串
-
- Returns:
- str: 限流键
- """
- if callable(key_func):
- return key_func()
- elif key_func == 'ip':
- return request.remote_addr
- elif key_func == 'endpoint':
- return request.endpoint
- else:
- return f"{request.remote_addr}:{request.endpoint}"
- def _clean_old_requests(self, key, period):
- """
- 清理过期的请求记录
-
- Args:
- key (str): 限流键
- period (int): 时间窗口(s)
- """
- current_time = time.time()
- self.storage[key] = [(ts, count) for ts, count in self.storage[key]
- if current_time - ts < period]
- def is_allowed(self, key_func='ip', limit=None, period=None):
- """
- 检查请求是否被允许
-
- Args:
- key_func: 生成键的函数或字符串
- limit (int): 时间窗口内允许的最大请求次数
- period (int): 时间窗口(s)
-
- Returns:
- tuple: (是否允许, 剩余可用请求数, 重置时间)
- """
- limit = limit or self.default_limit
- period = period or self.default_period
- key = self._generate_key(key_func)
- self._clean_old_requests(key, period)
- current_time = time.time()
- request_count = sum(count for _, count in self.storage[key])
- # 如果没有记录或者请求数量未达到限制
- if not self.storage[key] or request_count < limit:
- self.storage[key].append((current_time, 1))
- return True, limit - request_count - 1, period
- # 计算重置时间
- oldest_timestamp = min(ts for ts, _ in self.storage[key])
- reset_time = oldest_timestamp + period - current_time
- return False, 0, max(0, reset_time)
- # 创建全局限流器实例
- rate_limiter = RateLimiter()
- def rate_limit(key_func='ip', limit=None, period=None):
- """
- API限流装饰器
-
- Args:
- key_func: 生成限流键的函数或预定义字符串('ip'或'endpoint')
- limit (int): 时间窗口内允许的最大请求次数
- period (int): 时间窗口(s)
-
- Returns:
- function: 装饰器函数
- """
- def decorator(f):
- @wraps(f)
- def decorated_function(*args, **kwargs):
- allowed, remaining, reset_time = rate_limiter.is_allowed(key_func, limit, period)
- # 设置响应头部,包含限流信息
- response_headers = {
- 'X-RateLimit-Limit': str(limit or rate_limiter.default_limit),
- 'X-RateLimit-Remaining': str(remaining),
- 'X-RateLimit-Reset': str(int(reset_time))
- }
- if not allowed:
- error_response = jsonify({
- 'error': 'Too Many Requests',
- 'failure_message': '请求频率超过限制,请稍后再试',
- 'retry_after': int(reset_time)
- })
- # 添加响应头部
- for header, value in response_headers.items():
- error_response.headers[header] = value
- error_response.status_code = 429
- return error_response
- # 执行原始函数
- response = f(*args, **kwargs)
- # 如果响应是元组 (response, status_code),则只修改 response 部分
- if isinstance(response, tuple) and len(response) >= 1:
- response_obj = response[0]
- if hasattr(response_obj, 'headers'):
- for header, value in response_headers.items():
- response_obj.headers[header] = value
- return response
- # 如果响应是直接的响应对象
- if hasattr(response, 'headers'):
- for header, value in response_headers.items():
- response.headers[header] = value
- return response
- return decorated_function
- return decorator
- def user_rate_limit(limit=None, period=None):
- """
- 基于用户ID的限流装饰器,需要在JWT认证之后使用
-
- Args:
- limit (int): 时间窗口内允许的最大请求次数
- period (int): 时间窗口(秒)
-
- Returns:
- function: 装饰器函数
- """
- from flask_jwt_extended import get_jwt_identity
- def get_user_key():
- try:
- user_id = get_jwt_identity()
- return f"user:{user_id}"
- except Exception:
- # 如果无法获取用户ID,则回退到IP限流
- return request.remote_addr
- # 将函数对象传递给rate_limit,而不是立即调用它
- return rate_limit(get_user_key, limit, period)
- def configure_rate_limiting(app):
- """
- 配置应用的全局限流设置
-
- Args:
- app: Flask 应用实例
- """
- # 从配置中读取限流设置
- rate_limiter.default_limit = app.config.get('RATE_LIMIT_DEFAULT_LIMIT', 60)
- rate_limiter.default_period = app.config.get('RATE_LIMIT_DEFAULT_PERIOD', 60)
- app.logger.info(f"已配置 API 限流: {rate_limiter.default_limit}次/{rate_limiter.default_period}秒")
|