| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- from datetime import datetime, timedelta, timezone
- import jwt
- from dotenv import load_dotenv
- from fastapi import HTTPException, status
- from pydantic import BaseModel
- from ..utils import logger
- from .config import DEFAULT_TOKEN_SECRET, global_args
- from .passwords import verify_password
- # Load settings from the .env file located in the current folder.
- # This allows a different .env file to be used for each LightRAG instance.
- # OS environment variables take precedence over values from the .env file.
- load_dotenv(dotenv_path=".env", override=False)
class TokenPayload(BaseModel):
    """Claims placed in a JWT before it is encoded."""

    # Username the token belongs to
    sub: str
    # Moment at which the token stops being valid
    exp: datetime
    # Role of the user; a regular user unless stated otherwise
    role: str = "user"
    # Free-form additional metadata
    metadata: dict = {}
class AuthHandler:
    """Check account passwords and issue/validate JWT tokens.

    All settings are read from ``global_args`` when the handler is built:
    the signing secret and algorithm, the lifetimes of regular and guest
    tokens, and the ``AUTH_ACCOUNTS`` string of ``user:password`` pairs.
    """

    def __init__(self):
        """
        Load and validate the authentication configuration.
        Raises:
            ValueError: If TOKEN_SECRET is empty while AUTH_ACCOUNTS is
                configured, if JWT_ALGORITHM is empty or "none", or if
                AUTH_ACCOUNTS contains a malformed entry.
        """
        auth_accounts = global_args.auth_accounts

        self.secret = global_args.token_secret
        # NOTE(review): only an empty secret is rejected here. A TOKEN_SECRET
        # equal to DEFAULT_TOKEN_SECRET passes this check unless it is filtered
        # out where global_args is built -- confirm.
        if not self.secret:
            if auth_accounts:
                raise ValueError(
                    "TOKEN_SECRET must be explicitly set to a non-default value when AUTH_ACCOUNTS is configured."
                )
            self.secret = DEFAULT_TOKEN_SECRET
            logger.warning(
                "TOKEN_SECRET not set and AUTH_ACCOUNTS is not configured. "
                "Falling back to the default guest-mode JWT secret. "
            )

        algorithm = global_args.jwt_algorithm
        if not algorithm or algorithm.lower() == "none":
            raise ValueError(
                "JWT_ALGORITHM must be set to a secure algorithm (e.g. HS256). "
                "The 'none' algorithm is not permitted."
            )
        self.algorithm = algorithm

        self.expire_hours = global_args.token_expire_hours
        self.guest_expire_hours = global_args.guest_token_expire_hours

        self.accounts = {}
        if auth_accounts:
            self.accounts, invalid_positions = self._parse_accounts(auth_accounts)
            if invalid_positions:
                # Log positions only: a malformed entry may still contain a
                # password (e.g. ":secret"), so its text must not reach the log.
                positions = ", ".join(str(pos) for pos in invalid_positions)
                logger.error(
                    f"Invalid account format in AUTH_ACCOUNTS at entry position(s): {positions}"
                )
                raise ValueError(
                    "AUTH_ACCOUNTS must use comma-separated user:password pairs."
                )

    @staticmethod
    def _parse_accounts(auth_accounts: str) -> tuple:
        """
        Split an AUTH_ACCOUNTS string into credentials and malformed positions.
        Args:
            auth_accounts: Comma-separated ``user:password`` pairs
        Returns:
            tuple: ``(accounts, invalid_positions)`` where ``accounts`` maps
                username to stored password and ``invalid_positions`` lists the
                1-based positions of entries that lack a ":" or have an empty
                username or password
        """
        accounts = {}
        invalid_positions = []
        for position, entry in enumerate(auth_accounts.split(","), start=1):
            # Split on the first ":" only, so passwords may contain colons.
            username, separator, password = entry.partition(":")
            if separator and username and password:
                accounts[username] = password
            else:
                invalid_positions.append(position)
        return accounts, invalid_positions

    def verify_password(self, username: str, plain_password: str) -> bool:
        """
        Verify password for a user. Supports explicit bcrypt values and plaintext.
        Args:
            username: Username to verify
            plain_password: Plaintext password to check
        Returns:
            bool: True if password is correct, False otherwise
                (including when the username is not configured)
        """
        stored_password = self.accounts.get(username)
        if stored_password is None:
            return False
        # Inside the method this name resolves to the module-level helper
        # imported from .passwords, not to this method.
        return verify_password(plain_password, stored_password)

    def create_token(
        self,
        username: str,
        role: str = "user",
        custom_expire_hours: int = None,
        metadata: dict = None,
    ) -> str:
        """
        Create JWT token
        Args:
            username: Username
            role: User role, default is "user", guest is "guest"
            custom_expire_hours: Custom expiration time (hours), if None use default value
            metadata: Additional metadata
        Returns:
            str: Encoded JWT token
        """
        # An explicit lifetime wins; otherwise the default depends on the role
        if custom_expire_hours is not None:
            expire_hours = custom_expire_hours
        elif role == "guest":
            expire_hours = self.guest_expire_hours
        else:
            expire_hours = self.expire_hours

        expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)

        payload = TokenPayload(
            sub=username, exp=expire, role=role, metadata=metadata or {}
        )
        return jwt.encode(payload.model_dump(), self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> dict:
        """
        Validate JWT token
        Args:
            token: JWT token
        Returns:
            dict: Dictionary with "username", "role", "metadata" and "exp"
                (expiration as a timezone-aware UTC datetime)
        Raises:
            HTTPException: 401 "Token expired" if the token has expired,
                401 "Invalid token" if it cannot be decoded or lacks the
                "sub"/"exp" claims, 500 if the configured algorithm is "none"
        """
        # Explicitly exclude 'none' to prevent algorithm confusion attacks
        allowed_algorithms = [self.algorithm]
        if any(name.lower() == "none" for name in allowed_algorithms):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Insecure JWT algorithm configuration",
            )

        try:
            payload = jwt.decode(token, self.secret, algorithms=allowed_algorithms)
        except jwt.ExpiredSignatureError as exc:
            # Must be caught before PyJWTError (its base class); otherwise an
            # expired token is reported as merely "Invalid token".
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
            ) from exc
        except jwt.PyJWTError as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            ) from exc

        # A correctly signed token can still lack the claims read below;
        # reject it as invalid instead of failing with an unhandled error.
        try:
            username = payload["sub"]
            expire_time = datetime.fromtimestamp(payload["exp"], timezone.utc)
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            ) from exc

        if datetime.now(timezone.utc) > expire_time:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
            )

        # Return complete payload instead of just username
        return {
            "username": username,
            "role": payload.get("role", "user"),
            "metadata": payload.get("metadata", {}),
            "exp": expire_time,
        }
- auth_handler = AuthHandler()  # module-level handler instance; constructed on import, so AuthHandler's configuration errors (ValueError) surface at import time
|