deps.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from fastapi import Depends, HTTPException, status, Request
  2. from fastapi.security import OAuth2PasswordBearer
  3. from jose import JWTError, jwt
  4. from sqlalchemy.orm import Session
  5. from flowsint_core.core.auth import ALGORITHM, AUTH_SECRET
  6. from flowsint_core.core.postgre_db import get_db
  7. from flowsint_core.core.models import Profile
  8. from typing import Optional
  # Bearer-token dependency consumed by get_current_user (via Depends).
  # tokenUrl is the relative URL clients use to obtain a token, per
  # FastAPI's OAuth2PasswordBearer documentation.
  oauth2_scheme = OAuth2PasswordBearer(
      tokenUrl="token",
  )
  def _credentials_exception() -> HTTPException:
      """Build the 401 error raised for every authentication failure.

      Includes ``WWW-Authenticate: Bearer``, as RFC 6750 expects on a 401
      for bearer-token auth, so both dependencies below respond identically.
      """
      return HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail="Could not validate credentials",
          headers={"WWW-Authenticate": "Bearer"},
      )


  def _bearer_token_from_header(auth_header: Optional[str]) -> Optional[str]:
      """Extract the credential from an ``Authorization: Bearer <token>`` value.

      Returns None when the header is absent, uses another scheme, or has
      no credential. The scheme is compared case-insensitively and only the
      leading scheme is removed (``str.replace`` would also strip any later
      "Bearer " substring inside the value).
      """
      if not auth_header:
          return None
      scheme, _, credential = auth_header.partition(" ")
      if scheme.lower() != "bearer":
          return None
      return credential.strip() or None


  def _user_from_token(token: Optional[str], db: Session) -> Profile:
      """Verify a JWT and return the Profile whose email is its ``sub`` claim.

      Raises:
          HTTPException: 401 if the token is missing or empty, if
              ``jwt.decode`` raises ``JWTError``, if the ``sub`` claim is
              not a non-empty string, or if no Profile has that email.
      """
      if not token:
          raise _credentials_exception()
      try:
          payload = jwt.decode(token, AUTH_SECRET, algorithms=[ALGORITHM])
      except JWTError:
          raise _credentials_exception() from None
      email = payload.get("sub")
      # Reject non-string / empty subjects instead of passing them to the query.
      if not isinstance(email, str) or not email:
          raise _credentials_exception()
      user = db.query(Profile).filter(Profile.email == email).first()
      if user is None:
          raise _credentials_exception()
      return user


  def get_current_user(
      token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
  ) -> Profile:
      """FastAPI dependency: resolve the calling Profile from a bearer token.

      The token is supplied by ``oauth2_scheme``; see ``_user_from_token``
      for the conditions that produce a 401.
      """
      return _user_from_token(token, db)


  def get_current_user_sse(
      request: Request, db: Session = Depends(get_db)
  ) -> Profile:
      """
      Alternative authentication for SSE endpoints that accepts token via query parameter.
      EventSource API doesn't support custom headers, so we need to pass the token in the URL.

      The ``token`` query parameter is tried first; if it is absent or
      empty, a ``Bearer`` Authorization header is used instead. Failures
      raise the same 401 as ``get_current_user``.
      """
      # NOTE(review): tokens in URLs can be recorded in access/proxy logs and
      # browser history (RFC 6750 section 2.3); prefer short-lived tokens here.
      token: Optional[str] = request.query_params.get("token")
      if not token:
          token = _bearer_token_from_header(request.headers.get("Authorization"))
      return _user_from_token(token, db)