| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
"""
Pytest unit tests for token auto-renewal functionality.

These tests drive a local simulation of the renewal flow (a module-level
renewal cache plus mocked auth-handler / configuration objects); they do not
import or call the backend implementation directly.

Covers:
1. Threshold-based token renewal decisions (users and guests)
2. Per-user rate limiting of token renewals
3. Token renewal state tracking
4. JWT expiration extraction and remaining-lifetime arithmetic
"""
- import pytest
- from datetime import datetime, timedelta, timezone
- from unittest.mock import Mock
- from fastapi import Response
- import time
# Local stand-in for a renewal cache, shared by the tests in this module:
# maps username -> time.time() timestamp of that user's most recent renewal.
_token_renewal_cache = dict()
# Minimum number of seconds that must pass between two renewals for one user.
_RENEWAL_MIN_INTERVAL = 60
@pytest.mark.offline
class TestTokenRenewal:
    """Tests for token auto-renewal logic.

    The helpers below simulate the renewal flow end to end:

    1. skip entirely when ``token_auto_renew`` is off;
    2. renew only when less than ``token_renew_threshold`` of the token's
       full lifetime remains (``expire_hours`` for users,
       ``guest_expire_hours`` otherwise);
    3. enforce the per-user ``_RENEWAL_MIN_INTERVAL`` rate limit;
    4. put the new token in the ``X-New-Token`` header and record the
       renewal time in ``_token_renewal_cache``.
    """

    @pytest.fixture(autouse=True)
    def _isolated_renewal_cache(self):
        """Start every test with an empty renewal cache and leave it empty."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.guest_expire_hours = 24
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @pytest.fixture
    def mock_global_args(self):
        """Mock global configuration"""
        args = Mock()
        args.token_auto_renew = True
        args.token_renew_threshold = 0.5
        return args

    @pytest.fixture
    def mock_token_info_guest(self):
        """Mock token info for guest user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "guest",
            "role": "guest",
            "exp": exp_time,
            "metadata": {"auth_mode": "disabled"},
        }

    @pytest.fixture
    def mock_token_info_user(self):
        """Mock token info for regular user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    @pytest.fixture
    def mock_token_info_above_threshold(self):
        """Mock token info with time above renewal threshold"""
        # Token with 20 hours remaining (above 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    # ------------------------------------------------------------------
    # Simulation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_response():
        """Return a ``Response`` mock whose ``headers`` is a plain dict."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _lifetime_seconds(auth_handler, role):
        """Full token lifetime in seconds for *role*.

        Regular users (``role == "user"``) use ``expire_hours``; every other
        role uses ``guest_expire_hours``.
        """
        hours = (
            auth_handler.expire_hours
            if role == "user"
            else auth_handler.guest_expire_hours
        )
        return hours * 3600

    @classmethod
    def _should_renew(cls, token_info, auth_handler, global_args):
        """Return True when the remaining lifetime is below the threshold.

        The comparison is strict: a token sitting exactly at the threshold
        is not renewed.
        """
        now = datetime.now(timezone.utc)
        remaining_seconds = (token_info["exp"] - now).total_seconds()
        lifetime = cls._lifetime_seconds(auth_handler, token_info["role"])
        return remaining_seconds < lifetime * global_args.token_renew_threshold

    @classmethod
    def _maybe_renew(cls, response, token_info, auth_handler, global_args):
        """Run the simulated renewal flow for *token_info*.

        Returns:
            True if a new token was issued (header set, cache updated),
            False if renewal was disabled, not needed, or rate limited.
        """
        if not global_args.token_auto_renew:
            return False
        if not cls._should_renew(token_info, auth_handler, global_args):
            return False

        username = token_info["username"]
        current_time = time.time()
        # Users with no recorded renewal default to 0, i.e. never limited.
        last_renewal = _token_renewal_cache.get(username, 0)
        if current_time - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False

        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username,
            role=token_info["role"],
            metadata=token_info["metadata"],
        )
        _token_renewal_cache[username] = current_time
        return True

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_token_renewal_when_below_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that token is renewed when remaining time < threshold"""
        response = self._make_response()

        # 10h left of a 24h token is below the 50% threshold.
        assert (
            self._should_renew(
                mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is True
        )
        assert (
            self._maybe_renew(
                response, mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is True
        )

        assert response.headers["X-New-Token"] == "new-token-12345"
        mock_auth_handler.create_token.assert_called_once_with(
            username="testuser", role="user", metadata={"auth_mode": "enabled"}
        )
        assert "testuser" in _token_renewal_cache

    def test_token_no_renewal_when_above_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
    ):
        """Test that token is NOT renewed when remaining time > threshold"""
        response = self._make_response()

        # 20h left of a 24h token is above the 50% threshold.
        assert (
            self._should_renew(
                mock_token_info_above_threshold, mock_auth_handler, mock_global_args
            )
            is False
        )
        assert (
            self._maybe_renew(
                response,
                mock_token_info_above_threshold,
                mock_auth_handler,
                mock_global_args,
            )
            is False
        )

        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()
        assert not _token_renewal_cache

    def test_token_renewal_disabled(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
        mock_global_args.token_auto_renew = False
        response = self._make_response()

        # On the threshold alone this token would qualify for renewal...
        assert (
            self._should_renew(
                mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is True
        )
        # ...but with auto-renewal disabled nothing may be issued.
        assert (
            self._maybe_renew(
                response, mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is False
        )

        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()
        assert not _token_renewal_cache

    def test_token_renewal_for_guest_mode(
        self, mock_auth_handler, mock_global_args, mock_token_info_guest
    ):
        """Test that guest tokens are renewed correctly"""
        # Give regular users a much shorter lifetime. If the user lifetime were
        # wrongly applied to guests, 10h remaining would exceed 50% of 1h and
        # no renewal would occur, so this test pins the guest lifetime.
        mock_auth_handler.expire_hours = 1
        response = self._make_response()

        assert (
            self._should_renew(
                mock_token_info_guest, mock_auth_handler, mock_global_args
            )
            is True
        )
        assert (
            self._maybe_renew(
                response, mock_token_info_guest, mock_auth_handler, mock_global_args
            )
            is True
        )

        assert response.headers["X-New-Token"] == "new-token-12345"
        mock_auth_handler.create_token.assert_called_once_with(
            username="guest", role="guest", metadata={"auth_mode": "disabled"}
        )
        assert "guest" in _token_renewal_cache
@pytest.mark.offline
class TestRateLimiting:
    """Tests for token renewal rate limiting.

    Renewals go through ``_try_renew``. It allows a renewal when at least
    ``_RENEWAL_MIN_INTERVAL`` seconds have passed since the user's last
    recorded renewal, or when none is recorded. Timestamps are passed in
    explicitly so the tests do not depend on the wall clock.
    """

    # Integer-valued float base time: interval arithmetic around it is exact.
    _T0 = 1_000_000.0

    @pytest.fixture(autouse=True)
    def _isolated_renewal_cache(self):
        """Start every test with an empty renewal cache and leave it empty."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @staticmethod
    def _make_response():
        """Return a ``Response`` mock whose ``headers`` is a plain dict."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _try_renew(response, auth_handler, username, now):
        """Attempt a renewal for *username* at time *now*.

        Returns:
            True if the renewal passed the rate limit. The ``X-New-Token``
            header is then set on *response* and the cache records *now*.
            False if the user is still inside the interval; *response* and
            the cache are then left untouched.
        """
        last_renewal = _token_renewal_cache.get(username, 0)
        if now - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False
        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username, role="user", metadata={}
        )
        _token_renewal_cache[username] = now
        return True

    def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
        """Test that second renewal within 60s is blocked"""
        username = "testuser"

        first = self._make_response()
        assert self._try_renew(first, mock_auth_handler, username, self._T0) is True
        assert first.headers["X-New-Token"] == "new-token-12345"

        # One second later is well inside the minimum interval.
        second = self._make_response()
        assert (
            self._try_renew(second, mock_auth_handler, username, self._T0 + 1)
            is False
        )
        assert "X-New-Token" not in second.headers
        assert mock_auth_handler.create_token.call_count == 1
        # The blocked attempt must not reset the user's renewal timestamp.
        assert _token_renewal_cache[username] == self._T0

    def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
        """Test that renewal succeeds after 60s interval"""
        username = "testuser"
        _token_renewal_cache[username] = self._T0 - 61  # last renewal 61s ago

        response = self._make_response()
        assert (
            self._try_renew(response, mock_auth_handler, username, self._T0) is True
        )
        assert response.headers["X-New-Token"] == "new-token-12345"
        assert _token_renewal_cache[username] == self._T0

        # Exactly one interval after the last renewal is also allowed
        # (only elapsed < _RENEWAL_MIN_INTERVAL is rate limited).
        _token_renewal_cache["boundary"] = self._T0 - _RENEWAL_MIN_INTERVAL
        assert (
            self._try_renew(
                self._make_response(), mock_auth_handler, "boundary", self._T0
            )
            is True
        )

    def test_rate_limit_per_user(self, mock_auth_handler):
        """Test that different users have independent rate limits"""
        later = self._T0 + 1

        assert (
            self._try_renew(self._make_response(), mock_auth_handler, "user1", self._T0)
            is True
        )
        # user1 is now inside its interval...
        assert (
            self._try_renew(self._make_response(), mock_auth_handler, "user1", later)
            is False
        )
        # ...but user2 has no renewal history, so user1's limit does not apply.
        user2_response = self._make_response()
        assert (
            self._try_renew(user2_response, mock_auth_handler, "user2", later) is True
        )
        assert user2_response.headers["X-New-Token"] == "new-token-12345"

        assert _token_renewal_cache == {"user1": self._T0, "user2": later}
@pytest.mark.offline
class TestTokenExpirationCalculation:
    """Tests for token expiration time calculation"""

    def test_expiration_extraction_from_jwt(self):
        """Test extracting expiration time from JWT token"""
        import base64
        import json

        exp_timestamp = int(
            (datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
        )
        payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}

        # JWT segments (header.payload.signature) are base64url-encoded with
        # the trailing "=" padding stripped, unlike plain padded base64.
        payload_b64 = (
            base64.urlsafe_b64encode(json.dumps(payload).encode())
            .rstrip(b"=")
            .decode()
        )
        mock_token = f"header.{payload_b64}.signature"

        # Simulate extraction
        parts = mock_token.split(".")
        assert len(parts) == 3
        segment = parts[1]
        # Restore the stripped padding so the decoder accepts the segment.
        padded = segment + "=" * (-len(segment) % 4)
        decoded_payload = json.loads(base64.urlsafe_b64decode(padded))

        assert decoded_payload["exp"] == exp_timestamp
        assert decoded_payload["sub"] == "testuser"

    def test_remaining_time_calculation(self):
        """Test calculation of remaining token time"""
        # Token expires in 10 hours
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        now = datetime.now(timezone.utc)
        remaining_seconds = (exp_time - now).total_seconds()

        # Approximately 10 hours (36000 s); the tolerance absorbs the time
        # elapsed between the two now() calls.
        assert 35990 < remaining_seconds < 36010

        # Fraction of a 24-hour token still remaining: 10/24 ~= 41.67%
        percentage_remaining = remaining_seconds / (24 * 3600)
        assert 0.41 < percentage_remaining < 0.42

    def test_threshold_comparison(self):
        """Test threshold-based renewal decision"""
        threshold = 0.5
        total_seconds = 24 * 3600

        # (hours remaining, expected decision). The comparison is strict, so
        # exactly 12h (the 50% mark) does NOT trigger a renewal.
        scenarios = ((10, True), (20, False), (12, False))
        for hours_remaining, expected in scenarios:
            should_renew = hours_remaining * 3600 < total_seconds * threshold
            assert should_renew is expected, f"{hours_remaining}h remaining"
@pytest.mark.offline
def test_renewal_cache_cleanup():
    """The module-level renewal cache can be emptied with ``clear()``."""
    # Begin from a known-empty cache, then record two users' renewals.
    _token_renewal_cache.clear()
    _token_renewal_cache.update(user1=time.time(), user2=time.time())
    assert len(_token_renewal_cache) == 2

    # Clearing must drop every entry.
    _token_renewal_cache.clear()
    assert not _token_renewal_cache
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; propagate it so that
    # running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|