auth.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from datetime import datetime, timedelta, timezone
  2. import jwt
  3. from dotenv import load_dotenv
  4. from fastapi import HTTPException, status
  5. from pydantic import BaseModel
  6. from ..utils import logger
  7. from .config import DEFAULT_TOKEN_SECRET, global_args
  8. from .passwords import verify_password
  9. # use the .env that is inside the current folder
  10. # allows to use different .env file for each lightrag instance
  11. # the OS environment variables take precedence over the .env file
  12. load_dotenv(dotenv_path=".env", override=False)
  13. class TokenPayload(BaseModel):
  14. sub: str # Username
  15. exp: datetime # Expiration time
  16. role: str = "user" # User role, default is regular user
  17. metadata: dict = {} # Additional metadata
  18. class AuthHandler:
  19. def __init__(self):
  20. auth_accounts = global_args.auth_accounts
  21. self.secret = global_args.token_secret
  22. if not self.secret:
  23. if auth_accounts:
  24. raise ValueError(
  25. "TOKEN_SECRET must be explicitly set to a non-default value when AUTH_ACCOUNTS is configured."
  26. )
  27. self.secret = DEFAULT_TOKEN_SECRET
  28. logger.warning(
  29. "TOKEN_SECRET not set and AUTH_ACCOUNTS is not configured. "
  30. "Falling back to the default guest-mode JWT secret. "
  31. )
  32. algorithm = global_args.jwt_algorithm
  33. if not algorithm or algorithm.lower() == "none":
  34. raise ValueError(
  35. "JWT_ALGORITHM must be set to a secure algorithm (e.g. HS256). "
  36. "The 'none' algorithm is not permitted."
  37. )
  38. self.algorithm = algorithm
  39. self.expire_hours = global_args.token_expire_hours
  40. self.guest_expire_hours = global_args.guest_token_expire_hours
  41. self.accounts = {}
  42. invalid_accounts = []
  43. if auth_accounts:
  44. for account in auth_accounts.split(","):
  45. try:
  46. username, password = account.split(":", 1)
  47. if not username or not password:
  48. raise ValueError
  49. self.accounts[username] = password
  50. except ValueError:
  51. invalid_accounts.append(account)
  52. if invalid_accounts:
  53. invalid_entries = ", ".join(invalid_accounts)
  54. logger.error(f"Invalid account format in AUTH_ACCOUNTS: {invalid_entries}")
  55. raise ValueError(
  56. "AUTH_ACCOUNTS must use comma-separated user:password pairs."
  57. )
  58. def verify_password(self, username: str, plain_password: str) -> bool:
  59. """
  60. Verify password for a user. Supports explicit bcrypt values and plaintext.
  61. Args:
  62. username: Username to verify
  63. plain_password: Plaintext password to check
  64. Returns:
  65. bool: True if password is correct, False otherwise
  66. """
  67. if username not in self.accounts:
  68. return False
  69. stored_password = self.accounts[username]
  70. return verify_password(plain_password, stored_password)
  71. def create_token(
  72. self,
  73. username: str,
  74. role: str = "user",
  75. custom_expire_hours: int = None,
  76. metadata: dict = None,
  77. ) -> str:
  78. """
  79. Create JWT token
  80. Args:
  81. username: Username
  82. role: User role, default is "user", guest is "guest"
  83. custom_expire_hours: Custom expiration time (hours), if None use default value
  84. metadata: Additional metadata
  85. Returns:
  86. str: Encoded JWT token
  87. """
  88. # Choose default expiration time based on role
  89. if custom_expire_hours is None:
  90. if role == "guest":
  91. expire_hours = self.guest_expire_hours
  92. else:
  93. expire_hours = self.expire_hours
  94. else:
  95. expire_hours = custom_expire_hours
  96. expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
  97. # Create payload
  98. payload = TokenPayload(
  99. sub=username, exp=expire, role=role, metadata=metadata or {}
  100. )
  101. return jwt.encode(payload.model_dump(), self.secret, algorithm=self.algorithm)
  102. def validate_token(self, token: str) -> dict:
  103. """
  104. Validate JWT token
  105. Args:
  106. token: JWT token
  107. Returns:
  108. dict: Dictionary containing user information
  109. Raises:
  110. HTTPException: If token is invalid or expired
  111. """
  112. try:
  113. # Explicitly exclude 'none' to prevent algorithm confusion attacks
  114. allowed_algorithms = [self.algorithm]
  115. if "none" in (a.lower() for a in allowed_algorithms):
  116. raise HTTPException(
  117. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  118. detail="Insecure JWT algorithm configuration",
  119. )
  120. payload = jwt.decode(token, self.secret, algorithms=allowed_algorithms)
  121. expire_timestamp = payload["exp"]
  122. expire_time = datetime.fromtimestamp(expire_timestamp, timezone.utc)
  123. if datetime.now(timezone.utc) > expire_time:
  124. raise HTTPException(
  125. status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
  126. )
  127. # Return complete payload instead of just username
  128. return {
  129. "username": payload["sub"],
  130. "role": payload.get("role", "user"),
  131. "metadata": payload.get("metadata", {}),
  132. "exp": expire_time,
  133. }
  134. except jwt.PyJWTError:
  135. raise HTTPException(
  136. status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
  137. )
# Module-level handler instance, constructed when this module is imported.
# Any configuration error raised by AuthHandler.__init__ (ValueError) therefore
# surfaces at import time.
auth_handler = AuthHandler()