| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- from fastapi import Depends, HTTPException, status, Request
- from fastapi.security import OAuth2PasswordBearer
- from jose import JWTError, jwt
- from sqlalchemy.orm import Session
- from flowsint_core.core.auth import ALGORITHM, AUTH_SECRET
- from flowsint_core.core.postgre_db import get_db
- from flowsint_core.core.models import Profile
- from typing import Optional
# FastAPI security dependency that supplies the raw Bearer token to
# `get_current_user`. `tokenUrl` is the token endpoint path advertised in the
# OpenAPI schema; presumably it matches this app's login route — confirm.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_current_user(
    token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> Profile:
    """Resolve the authenticated ``Profile`` from a Bearer JWT.

    Intended as a FastAPI dependency: ``oauth2_scheme`` supplies the raw
    token and ``get_db`` supplies the database session.

    The token is decoded with ``AUTH_SECRET``, accepting only ``ALGORITHM``.
    Its ``sub`` claim is treated as the user's email, and the ``Profile``
    with that email is returned.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` if the token
            cannot be decoded/validated, has no ``sub`` claim, or no profile
            matches the claimed email.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Pinning the accepted algorithm list stops a token from choosing its own.
    try:
        claims = jwt.decode(token, AUTH_SECRET, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject: Optional[str] = claims.get("sub")
    if subject is None:
        raise unauthorized

    profile = db.query(Profile).filter(Profile.email == subject).first()
    if profile is None:
        raise unauthorized
    return profile
def get_current_user_sse(
    request: Request, db: Session = Depends(get_db)
) -> Profile:
    """Authenticate an SSE request, accepting the JWT via query parameter.

    The browser ``EventSource`` API cannot set custom headers, so SSE clients
    pass the token as ``?token=...``. When that query parameter is absent, an
    ``Authorization: Bearer <token>`` header is used as a fallback. The scheme
    is matched case-insensitively, the same way ``OAuth2PasswordBearer``
    parses it for ``get_current_user``.

    Args:
        request: Incoming request. Its ``token`` query parameter and
            ``Authorization`` header are inspected.
        db: Database session used to look up the user.

    Returns:
        The ``Profile`` resolved by ``get_current_user`` from the token.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` if no token is
            supplied, or if ``get_current_user`` rejects the token.
    """
    # NOTE(review): tokens in URLs can leak into access/proxy logs and browser
    # history. Prefer short-lived tokens for SSE where possible.
    token: Optional[str] = request.query_params.get("token")

    if not token:
        # Split "<scheme> <credentials>" on the first space and compare the
        # scheme case-insensitively (auth schemes are case-insensitive). The
        # old `.replace("Bearer ", "")` rejected "bearer ..." and stripped
        # every occurrence of the prefix, not just the leading one.
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() == "bearer":
            token = credentials

    if not token:
        # RFC 6750 requires 401 responses to carry WWW-Authenticate; this
        # also matches the header-based dependency's error response.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Delegate decode + lookup so both entry points validate tokens the same
    # way instead of maintaining two copies of that logic.
    return get_current_user(token=token, db=db)
|