test_token_auto_renewal.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. """
  2. Pytest unit tests for token auto-renewal functionality
  3. Tests:
  4. 1. Backend token renewal logic
  5. 2. Rate limiting for token renewals
  6. 3. Token renewal state tracking
  7. """
  8. import pytest
  9. from datetime import datetime, timedelta, timezone
  10. from unittest.mock import Mock
  11. from fastapi import Response
  12. import time
# Simple token renewal cache for testing.
# Maps a username to the time.time() timestamp of that user's latest renewal.
_token_renewal_cache = {}
# Minimum number of seconds required between two renewals for the same user.
_RENEWAL_MIN_INTERVAL = 60
  16. @pytest.mark.offline
  17. class TestTokenRenewal:
  18. """Tests for token auto-renewal logic"""
  19. @pytest.fixture
  20. def mock_auth_handler(self):
  21. """Mock authentication handler"""
  22. handler = Mock()
  23. handler.guest_expire_hours = 24
  24. handler.expire_hours = 24
  25. handler.create_token = Mock(return_value="new-token-12345")
  26. return handler
  27. @pytest.fixture
  28. def mock_global_args(self):
  29. """Mock global configuration"""
  30. args = Mock()
  31. args.token_auto_renew = True
  32. args.token_renew_threshold = 0.5
  33. return args
  34. @pytest.fixture
  35. def mock_token_info_guest(self):
  36. """Mock token info for guest user"""
  37. # Token with 10 hours remaining (below 50% of 24 hours)
  38. exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
  39. return {
  40. "username": "guest",
  41. "role": "guest",
  42. "exp": exp_time,
  43. "metadata": {"auth_mode": "disabled"},
  44. }
  45. @pytest.fixture
  46. def mock_token_info_user(self):
  47. """Mock token info for regular user"""
  48. # Token with 10 hours remaining (below 50% of 24 hours)
  49. exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
  50. return {
  51. "username": "testuser",
  52. "role": "user",
  53. "exp": exp_time,
  54. "metadata": {"auth_mode": "enabled"},
  55. }
  56. @pytest.fixture
  57. def mock_token_info_above_threshold(self):
  58. """Mock token info with time above renewal threshold"""
  59. # Token with 20 hours remaining (above 50% of 24 hours)
  60. exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
  61. return {
  62. "username": "testuser",
  63. "role": "user",
  64. "exp": exp_time,
  65. "metadata": {"auth_mode": "enabled"},
  66. }
  67. def test_token_renewal_when_below_threshold(
  68. self, mock_auth_handler, mock_global_args, mock_token_info_user
  69. ):
  70. """Test that token is renewed when remaining time < threshold"""
  71. # Use global cache
  72. global _token_renewal_cache
  73. # Clear cache
  74. _token_renewal_cache.clear()
  75. response = Mock(spec=Response)
  76. response.headers = {}
  77. # Simulate the renewal logic
  78. expire_time = mock_token_info_user["exp"]
  79. now = datetime.now(timezone.utc)
  80. remaining_seconds = (expire_time - now).total_seconds()
  81. role = mock_token_info_user["role"]
  82. total_hours = (
  83. mock_auth_handler.expire_hours
  84. if role == "user"
  85. else mock_auth_handler.guest_expire_hours
  86. )
  87. total_seconds = total_hours * 3600
  88. # Should renew because remaining_seconds < total_seconds * 0.5
  89. should_renew = (
  90. remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
  91. )
  92. assert should_renew is True
  93. # Simulate renewal
  94. username = mock_token_info_user["username"]
  95. current_time = time.time()
  96. last_renewal = _token_renewal_cache.get(username, 0)
  97. time_since_last_renewal = current_time - last_renewal
  98. # Should pass rate limit (first renewal)
  99. assert time_since_last_renewal >= 60 or last_renewal == 0
  100. # Perform renewal
  101. new_token = mock_auth_handler.create_token(
  102. username=username, role=role, metadata=mock_token_info_user["metadata"]
  103. )
  104. response.headers["X-New-Token"] = new_token
  105. _token_renewal_cache[username] = current_time
  106. # Verify
  107. assert "X-New-Token" in response.headers
  108. assert response.headers["X-New-Token"] == "new-token-12345"
  109. assert username in _token_renewal_cache
  110. def test_token_no_renewal_when_above_threshold(
  111. self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
  112. ):
  113. """Test that token is NOT renewed when remaining time > threshold"""
  114. response = Mock(spec=Response)
  115. response.headers = {}
  116. expire_time = mock_token_info_above_threshold["exp"]
  117. now = datetime.now(timezone.utc)
  118. remaining_seconds = (expire_time - now).total_seconds()
  119. mock_token_info_above_threshold["role"]
  120. total_hours = mock_auth_handler.expire_hours
  121. total_seconds = total_hours * 3600
  122. # Should NOT renew because remaining_seconds > total_seconds * 0.5
  123. should_renew = (
  124. remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
  125. )
  126. assert should_renew is False
  127. # No renewal should happen
  128. assert "X-New-Token" not in response.headers
  129. def test_token_renewal_disabled(
  130. self, mock_auth_handler, mock_global_args, mock_token_info_user
  131. ):
  132. """Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
  133. mock_global_args.token_auto_renew = False
  134. response = Mock(spec=Response)
  135. response.headers = {}
  136. # Auto-renewal is disabled, so even if below threshold, no renewal
  137. if not mock_global_args.token_auto_renew:
  138. # Skip renewal logic
  139. pass
  140. assert "X-New-Token" not in response.headers
  141. def test_token_renewal_for_guest_mode(
  142. self, mock_auth_handler, mock_global_args, mock_token_info_guest
  143. ):
  144. """Test that guest tokens are renewed correctly"""
  145. # Use global cache
  146. global _token_renewal_cache
  147. _token_renewal_cache.clear()
  148. response = Mock(spec=Response)
  149. response.headers = {}
  150. expire_time = mock_token_info_guest["exp"]
  151. now = datetime.now(timezone.utc)
  152. remaining_seconds = (expire_time - now).total_seconds()
  153. role = mock_token_info_guest["role"]
  154. total_hours = mock_auth_handler.guest_expire_hours
  155. total_seconds = total_hours * 3600
  156. should_renew = (
  157. remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
  158. )
  159. assert should_renew is True
  160. # Renewal for guest
  161. username = mock_token_info_guest["username"]
  162. new_token = mock_auth_handler.create_token(
  163. username=username, role=role, metadata=mock_token_info_guest["metadata"]
  164. )
  165. response.headers["X-New-Token"] = new_token
  166. _token_renewal_cache[username] = time.time()
  167. assert "X-New-Token" in response.headers
  168. assert username in _token_renewal_cache
  169. @pytest.mark.offline
  170. class TestRateLimiting:
  171. """Tests for token renewal rate limiting"""
  172. @pytest.fixture
  173. def mock_auth_handler(self):
  174. """Mock authentication handler"""
  175. handler = Mock()
  176. handler.expire_hours = 24
  177. handler.create_token = Mock(return_value="new-token-12345")
  178. return handler
  179. def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
  180. """Test that second renewal within 60s is blocked"""
  181. # Use global cache and constant
  182. global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
  183. username = "testuser"
  184. _token_renewal_cache.clear()
  185. # First renewal
  186. current_time_1 = time.time()
  187. _token_renewal_cache[username] = current_time_1
  188. response_1 = Mock(spec=Response)
  189. response_1.headers = {}
  190. response_1.headers["X-New-Token"] = "new-token-12345"
  191. # Immediate second renewal attempt (within 60s)
  192. current_time_2 = time.time() # Almost same time
  193. last_renewal = _token_renewal_cache.get(username, 0)
  194. time_since_last_renewal = current_time_2 - last_renewal
  195. # Should be blocked by rate limit
  196. assert time_since_last_renewal < _RENEWAL_MIN_INTERVAL
  197. response_2 = Mock(spec=Response)
  198. response_2.headers = {}
  199. # No new token should be issued
  200. if time_since_last_renewal < _RENEWAL_MIN_INTERVAL:
  201. # Rate limited, skip renewal
  202. pass
  203. assert "X-New-Token" not in response_2.headers
  204. def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
  205. """Test that renewal succeeds after 60s interval"""
  206. # Use global cache and constant
  207. global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
  208. username = "testuser"
  209. _token_renewal_cache.clear()
  210. # First renewal at time T
  211. first_renewal_time = time.time() - 61 # 61 seconds ago
  212. _token_renewal_cache[username] = first_renewal_time
  213. # Second renewal attempt now
  214. current_time = time.time()
  215. last_renewal = _token_renewal_cache.get(username, 0)
  216. time_since_last_renewal = current_time - last_renewal
  217. # Should pass rate limit (>60s elapsed)
  218. assert time_since_last_renewal >= _RENEWAL_MIN_INTERVAL
  219. response = Mock(spec=Response)
  220. response.headers = {}
  221. if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
  222. new_token = mock_auth_handler.create_token(
  223. username=username, role="user", metadata={}
  224. )
  225. response.headers["X-New-Token"] = new_token
  226. _token_renewal_cache[username] = current_time
  227. assert "X-New-Token" in response.headers
  228. assert response.headers["X-New-Token"] == "new-token-12345"
  229. def test_rate_limit_per_user(self, mock_auth_handler):
  230. """Test that different users have independent rate limits"""
  231. # Use global cache
  232. global _token_renewal_cache
  233. _token_renewal_cache.clear()
  234. user1 = "user1"
  235. user2 = "user2"
  236. current_time = time.time()
  237. # User1 gets renewal
  238. _token_renewal_cache[user1] = current_time
  239. # User2 should still be able to get renewal (independent cache)
  240. last_renewal_user2 = _token_renewal_cache.get(user2, 0)
  241. assert last_renewal_user2 == 0 # No previous renewal
  242. # User2 can renew
  243. _token_renewal_cache[user2] = current_time
  244. # Both users should have entries
  245. assert user1 in _token_renewal_cache
  246. assert user2 in _token_renewal_cache
  247. assert _token_renewal_cache[user1] == _token_renewal_cache[user2]
  248. @pytest.mark.offline
  249. class TestTokenExpirationCalculation:
  250. """Tests for token expiration time calculation"""
  251. def test_expiration_extraction_from_jwt(self):
  252. """Test extracting expiration time from JWT token"""
  253. import base64
  254. import json
  255. # Create a mock JWT payload
  256. exp_timestamp = int(
  257. (datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
  258. )
  259. payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}
  260. # Encode as base64 (simulating JWT structure: header.payload.signature)
  261. payload_b64 = base64.b64encode(json.dumps(payload).encode()).decode()
  262. mock_token = f"header.{payload_b64}.signature"
  263. # Simulate extraction
  264. parts = mock_token.split(".")
  265. assert len(parts) == 3
  266. decoded_payload = json.loads(base64.b64decode(parts[1]))
  267. assert decoded_payload["exp"] == exp_timestamp
  268. assert decoded_payload["sub"] == "testuser"
  269. def test_remaining_time_calculation(self):
  270. """Test calculation of remaining token time"""
  271. # Token expires in 10 hours
  272. exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
  273. now = datetime.now(timezone.utc)
  274. remaining_seconds = (exp_time - now).total_seconds()
  275. # Should be approximately 10 hours (36000 seconds)
  276. assert 35990 < remaining_seconds < 36010
  277. # Calculate percentage remaining (for 24-hour token)
  278. total_seconds = 24 * 3600
  279. percentage_remaining = remaining_seconds / total_seconds
  280. # Should be approximately 41.67% remaining
  281. assert 0.41 < percentage_remaining < 0.42
  282. def test_threshold_comparison(self):
  283. """Test threshold-based renewal decision"""
  284. threshold = 0.5
  285. total_hours = 24
  286. total_seconds = total_hours * 3600
  287. # Scenario 1: 10 hours remaining -> should renew
  288. remaining_seconds_1 = 10 * 3600
  289. should_renew_1 = remaining_seconds_1 < total_seconds * threshold
  290. assert should_renew_1 is True
  291. # Scenario 2: 20 hours remaining -> should NOT renew
  292. remaining_seconds_2 = 20 * 3600
  293. should_renew_2 = remaining_seconds_2 < total_seconds * threshold
  294. assert should_renew_2 is False
  295. # Scenario 3: Exactly 12 hours remaining (at threshold) -> should NOT renew
  296. remaining_seconds_3 = 12 * 3600
  297. should_renew_3 = remaining_seconds_3 < total_seconds * threshold
  298. assert should_renew_3 is False
  299. @pytest.mark.offline
  300. def test_renewal_cache_cleanup():
  301. """Test that renewal cache can be cleared"""
  302. # Use global cache
  303. global _token_renewal_cache
  304. # Clear first
  305. _token_renewal_cache.clear()
  306. # Add some entries
  307. _token_renewal_cache["user1"] = time.time()
  308. _token_renewal_cache["user2"] = time.time()
  309. assert len(_token_renewal_cache) == 2
  310. # Clear cache
  311. _token_renewal_cache.clear()
  312. assert len(_token_renewal_cache) == 0
  313. if __name__ == "__main__":
  314. pytest.main([__file__, "-v", "--tb=short"])