rate_limiter.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import time
  2. from collections import defaultdict
  3. from functools import wraps
  4. from flask import request, jsonify
  5. from app import Config
  6. RATE_LIMIT_DEFAULT_LIMIT = Config.RATE_LIMIT_DEFAULT_LIMIT
  7. RATE_LIMIT_DEFAULT_PERIOD = Config.RATE_LIMIT_DEFAULT_PERIOD
  8. class RateLimiter:
  9. """
  10. API 限流器类,用于限制 API 请求频率。
  11. 支持基于 IP、用户 ID 或 API 端点的限流策略。
  12. Attributes:
  13. storage (dict): 存储请求记录的字典
  14. default_limit (int): 默认的请求限制次数
  15. default_period (int): 默认的时间窗口(s)
  16. """
  17. def __init__(self):
  18. # 使用内存存储请求记录
  19. # 格式: {key: [(timestamp1, count1), (timestamp2, count2), ...]}
  20. self.storage = defaultdict(list)
  21. self.default_limit = RATE_LIMIT_DEFAULT_LIMIT # 默认每分钟 60 次请求
  22. self.default_period = RATE_LIMIT_DEFAULT_PERIOD # 默认时间窗口为 60 s(1 min)
  23. def _generate_key(self, key_func):
  24. """
  25. 生成限流的键
  26. Args:
  27. key_func: 生成键的函数或字符串
  28. Returns:
  29. str: 限流键
  30. """
  31. if callable(key_func):
  32. return key_func()
  33. elif key_func == 'ip':
  34. return request.remote_addr
  35. elif key_func == 'endpoint':
  36. return request.endpoint
  37. else:
  38. return f"{request.remote_addr}:{request.endpoint}"
  39. def _clean_old_requests(self, key, period):
  40. """
  41. 清理过期的请求记录
  42. Args:
  43. key (str): 限流键
  44. period (int): 时间窗口(s)
  45. """
  46. current_time = time.time()
  47. self.storage[key] = [(ts, count) for ts, count in self.storage[key]
  48. if current_time - ts < period]
  49. def is_allowed(self, key_func='ip', limit=None, period=None):
  50. """
  51. 检查请求是否被允许
  52. Args:
  53. key_func: 生成键的函数或字符串
  54. limit (int): 时间窗口内允许的最大请求次数
  55. period (int): 时间窗口(s)
  56. Returns:
  57. tuple: (是否允许, 剩余可用请求数, 重置时间)
  58. """
  59. limit = limit or self.default_limit
  60. period = period or self.default_period
  61. key = self._generate_key(key_func)
  62. self._clean_old_requests(key, period)
  63. current_time = time.time()
  64. request_count = sum(count for _, count in self.storage[key])
  65. # 如果没有记录或者请求数量未达到限制
  66. if not self.storage[key] or request_count < limit:
  67. self.storage[key].append((current_time, 1))
  68. return True, limit - request_count - 1, period
  69. # 计算重置时间
  70. oldest_timestamp = min(ts for ts, _ in self.storage[key])
  71. reset_time = oldest_timestamp + period - current_time
  72. return False, 0, max(0, reset_time)
  73. # 创建全局限流器实例
  74. rate_limiter = RateLimiter()
  75. def rate_limit(key_func='ip', limit=None, period=None):
  76. """
  77. API限流装饰器
  78. Args:
  79. key_func: 生成限流键的函数或预定义字符串('ip'或'endpoint')
  80. limit (int): 时间窗口内允许的最大请求次数
  81. period (int): 时间窗口(s)
  82. Returns:
  83. function: 装饰器函数
  84. """
  85. def decorator(f):
  86. @wraps(f)
  87. def decorated_function(*args, **kwargs):
  88. allowed, remaining, reset_time = rate_limiter.is_allowed(key_func, limit, period)
  89. # 设置响应头部,包含限流信息
  90. response_headers = {
  91. 'X-RateLimit-Limit': str(limit or rate_limiter.default_limit),
  92. 'X-RateLimit-Remaining': str(remaining),
  93. 'X-RateLimit-Reset': str(int(reset_time))
  94. }
  95. if not allowed:
  96. error_response = jsonify({
  97. 'error': 'Too Many Requests',
  98. 'failure_message': '请求频率超过限制,请稍后再试',
  99. 'retry_after': int(reset_time)
  100. })
  101. # 添加响应头部
  102. for header, value in response_headers.items():
  103. error_response.headers[header] = value
  104. error_response.status_code = 429
  105. return error_response
  106. # 执行原始函数
  107. response = f(*args, **kwargs)
  108. # 如果响应是元组 (response, status_code),则只修改 response 部分
  109. if isinstance(response, tuple) and len(response) >= 1:
  110. response_obj = response[0]
  111. if hasattr(response_obj, 'headers'):
  112. for header, value in response_headers.items():
  113. response_obj.headers[header] = value
  114. return response
  115. # 如果响应是直接的响应对象
  116. if hasattr(response, 'headers'):
  117. for header, value in response_headers.items():
  118. response.headers[header] = value
  119. return response
  120. return decorated_function
  121. return decorator
  122. def user_rate_limit(limit=None, period=None):
  123. """
  124. 基于用户ID的限流装饰器,需要在JWT认证之后使用
  125. Args:
  126. limit (int): 时间窗口内允许的最大请求次数
  127. period (int): 时间窗口(秒)
  128. Returns:
  129. function: 装饰器函数
  130. """
  131. from flask_jwt_extended import get_jwt_identity
  132. def get_user_key():
  133. try:
  134. user_id = get_jwt_identity()
  135. return f"user:{user_id}"
  136. except Exception:
  137. # 如果无法获取用户ID,则回退到IP限流
  138. return request.remote_addr
  139. # 将函数对象传递给rate_limit,而不是立即调用它
  140. return rate_limit(get_user_key, limit, period)
  141. def configure_rate_limiting(app):
  142. """
  143. 配置应用的全局限流设置
  144. Args:
  145. app: Flask 应用实例
  146. """
  147. # 从配置中读取限流设置
  148. rate_limiter.default_limit = app.config.get('RATE_LIMIT_DEFAULT_LIMIT', 60)
  149. rate_limiter.default_period = app.config.get('RATE_LIMIT_DEFAULT_PERIOD', 60)
  150. app.logger.info(f"已配置 API 限流: {rate_limiter.default_limit}次/{rate_limiter.default_period}秒")