utils_api.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. """
  2. Utility functions for the LightRAG API.
  3. """
  4. import os
  5. import argparse
  6. from typing import Optional, List, Tuple
  7. import sys
  8. import time
  9. import logging
  10. from ascii_colors import ASCIIColors
  11. from .._version import __api_version__ as api_version
  12. from .._version import __version__ as core_version
  13. from lightrag.constants import (
  14. DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
  15. )
  16. from lightrag.api.runtime_validation import validate_runtime_target_from_env_file
  17. from fastapi import HTTPException, Security, Request, Response, status
  18. from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
  19. from starlette.status import HTTP_403_FORBIDDEN
  20. from .auth import auth_handler
  21. from .config import ollama_server_infos, global_args, get_env_value
  22. logger = logging.getLogger("lightrag")
  23. # ========== Token Renewal Rate Limiting ==========
  24. # Cache to track last renewal time per user (username as key)
  25. # Format: {username: last_renewal_timestamp}
  26. _token_renewal_cache: dict[str, float] = {}
  27. _RENEWAL_MIN_INTERVAL = 60 # Minimum 60 seconds between renewals for same user
  28. # ========== Token Renewal Path Exclusions ==========
  29. # Paths that should NOT trigger token auto-renewal
  30. # - /health: Health check endpoint, no login required
  31. # - /documents/paginated: Client polls this frequently (5-30s), renewal not needed
  32. # - /documents/pipeline_status: Client polls this very frequently (2s), renewal not needed
  33. _TOKEN_RENEWAL_SKIP_PATHS = [
  34. "/health",
  35. "/documents/paginated",
  36. "/documents/pipeline_status",
  37. ]
  38. def check_env_file():
  39. """
  40. Check if .env file exists and handle user confirmation if needed.
  41. Returns True if should continue, False if should exit.
  42. """
  43. env_path = ".env"
  44. if not os.path.exists(env_path):
  45. warning_msg = "Warning: Startup directory must contain .env file for multi-instance support."
  46. ASCIIColors.yellow(warning_msg)
  47. # Check if running in interactive terminal
  48. if sys.stdin.isatty():
  49. response = input("Do you want to continue? (yes/NO): ")
  50. if response.lower() != "yes":
  51. ASCIIColors.red("Server startup cancelled")
  52. return False
  53. return True
  54. is_valid, error_message = validate_runtime_target_from_env_file(env_path)
  55. if not is_valid:
  56. for line in error_message.splitlines():
  57. ASCIIColors.red(line)
  58. return False
  59. return True
  60. # Get whitelist paths from global_args, only once during initialization
  61. whitelist_paths = global_args.whitelist_paths.split(",")
  62. # Pre-compile path matching patterns
  63. whitelist_patterns: List[Tuple[str, bool]] = []
  64. for path in whitelist_paths:
  65. path = path.strip()
  66. if path:
  67. # If path ends with /*, match all paths with that prefix
  68. if path.endswith("/*"):
  69. prefix = path[:-2]
  70. whitelist_patterns.append((prefix, True)) # (prefix, is_prefix_match)
  71. else:
  72. whitelist_patterns.append((path, False)) # (exact_path, is_prefix_match)
  73. # Global authentication configuration
  74. auth_configured = bool(auth_handler.accounts)
  75. def get_combined_auth_dependency(api_key: Optional[str] = None):
  76. """
  77. Create a combined authentication dependency that implements authentication logic
  78. based on API key, OAuth2 token, and whitelist paths.
  79. Args:
  80. api_key (Optional[str]): API key for validation
  81. Returns:
  82. Callable: A dependency function that implements the authentication logic
  83. """
  84. # Use global whitelist_patterns and auth_configured variables
  85. # whitelist_patterns and auth_configured are already initialized at module level
  86. # Only calculate api_key_configured as it depends on the function parameter
  87. api_key_configured = bool(api_key)
  88. # Create security dependencies with proper descriptions for Swagger UI
  89. oauth2_scheme = OAuth2PasswordBearer(
  90. tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
  91. )
  92. # If API key is configured, create an API key header security
  93. api_key_header = None
  94. if api_key_configured:
  95. api_key_header = APIKeyHeader(
  96. name="X-API-Key", auto_error=False, description="API Key Authentication"
  97. )
  98. async def combined_dependency(
  99. request: Request,
  100. response: Response, # Added: needed to return new token via response header
  101. token: str = Security(oauth2_scheme),
  102. api_key_header_value: Optional[str] = None
  103. if api_key_header is None
  104. else Security(api_key_header),
  105. ):
  106. # 1. Check if path is in whitelist
  107. path = request.url.path
  108. for pattern, is_prefix in whitelist_patterns:
  109. if (is_prefix and path.startswith(pattern)) or (
  110. not is_prefix and path == pattern
  111. ):
  112. return # Whitelist path, allow access
  113. # 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
  114. if token:
  115. try:
  116. token_info = auth_handler.validate_token(token)
  117. # ========== Token Auto-Renewal Logic ==========
  118. from lightrag.api.config import global_args
  119. from datetime import datetime, timezone
  120. if global_args.token_auto_renew:
  121. # Check if current path should skip token renewal
  122. skip_renewal = any(
  123. path == skip_path or path.startswith(skip_path + "/")
  124. for skip_path in _TOKEN_RENEWAL_SKIP_PATHS
  125. )
  126. if skip_renewal:
  127. logger.debug(f"Token auto-renewal skipped for path: {path}")
  128. else:
  129. try:
  130. expire_time = token_info.get("exp")
  131. if expire_time:
  132. # Calculate remaining time ratio
  133. now = datetime.now(timezone.utc)
  134. remaining_seconds = (expire_time - now).total_seconds()
  135. # Get original token expiration duration
  136. role = token_info.get("role", "user")
  137. total_hours = (
  138. auth_handler.guest_expire_hours
  139. if role == "guest"
  140. else auth_handler.expire_hours
  141. )
  142. total_seconds = total_hours * 3600
  143. # Issue new token if remaining time < threshold
  144. if (
  145. remaining_seconds
  146. < total_seconds * global_args.token_renew_threshold
  147. ):
  148. # ========== Rate Limiting Check ==========
  149. username = token_info["username"]
  150. current_time = time.time()
  151. last_renewal = _token_renewal_cache.get(username, 0)
  152. time_since_last_renewal = (
  153. current_time - last_renewal
  154. )
  155. # Only renew if enough time has passed since last renewal
  156. if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
  157. new_token = auth_handler.create_token(
  158. username=username,
  159. role=role,
  160. metadata=token_info.get("metadata", {}),
  161. )
  162. # Return new token via response header
  163. response.headers["X-New-Token"] = new_token
  164. # Update renewal cache
  165. _token_renewal_cache[username] = current_time
  166. # Optional: log renewal
  167. logger.info(
  168. f"Token auto-renewed for user {username} "
  169. f"(role: {role}, remaining: {remaining_seconds:.0f}s)"
  170. )
  171. else:
  172. # Log skip due to rate limit
  173. logger.debug(
  174. f"Token renewal skipped for {username} "
  175. f"(rate limit: last renewal {time_since_last_renewal:.0f}s ago)"
  176. )
  177. # ========== End of Rate Limiting Check ==========
  178. except Exception as e:
  179. # Renewal failure should not affect normal request, just log
  180. logger.warning(f"Token auto-renew failed: {e}")
  181. # ========== End of Token Auto-Renewal Logic ==========
  182. # Accept guest token if no auth is configured
  183. if not auth_configured and token_info.get("role") == "guest":
  184. return
  185. # Accept non-guest token if auth is configured
  186. if auth_configured and token_info.get("role") != "guest":
  187. return
  188. # Token validation failed, immediately return 401 error
  189. raise HTTPException(
  190. status_code=status.HTTP_401_UNAUTHORIZED,
  191. detail="Invalid token. Please login again.",
  192. )
  193. except HTTPException as e:
  194. # If already a 401 error, re-raise it
  195. if e.status_code == status.HTTP_401_UNAUTHORIZED:
  196. raise
  197. # For other exceptions, continue processing
  198. # 3. Acept all request if no API protection needed
  199. if not auth_configured and not api_key_configured:
  200. return
  201. # 4. Validate API key if provided and API-Key authentication is configured
  202. if (
  203. api_key_configured
  204. and api_key_header_value
  205. and api_key_header_value == api_key
  206. ):
  207. return # API key validation successful
  208. ### Authentication failed ####
  209. # if password authentication is configured but not provided, ensure 401 error if auth_configured
  210. if auth_configured and not token:
  211. raise HTTPException(
  212. status_code=status.HTTP_401_UNAUTHORIZED,
  213. detail="No credentials provided. Please login.",
  214. )
  215. # if api key is provided but validation failed
  216. if api_key_header_value:
  217. raise HTTPException(
  218. status_code=HTTP_403_FORBIDDEN,
  219. detail="Invalid API Key",
  220. )
  221. # if api_key_configured but not provided
  222. if api_key_configured and not api_key_header_value:
  223. raise HTTPException(
  224. status_code=HTTP_403_FORBIDDEN,
  225. detail="API Key required",
  226. )
  227. # Otherwise: refuse access and return 403 error
  228. raise HTTPException(
  229. status_code=HTTP_403_FORBIDDEN,
  230. detail="API Key required or login authentication required.",
  231. )
  232. return combined_dependency
  233. def display_splash_screen(args: argparse.Namespace) -> None:
  234. """
  235. Display a colorful splash screen showing LightRAG server configuration
  236. Args:
  237. args: Parsed command line arguments
  238. """
  239. # Banner
  240. # Banner
  241. top_border = "╔══════════════════════════════════════════════════════════════╗"
  242. bottom_border = "╚══════════════════════════════════════════════════════════════╝"
  243. width = len(top_border) - 4 # width inside the borders
  244. line1_text = f"LightRAG Server v{core_version}/{api_version}"
  245. line2_text = "Fast, Lightweight RAG Server Implementation"
  246. line1 = f"║ {line1_text.center(width)} ║"
  247. line2 = f"║ {line2_text.center(width)} ║"
  248. banner = f"""
  249. {top_border}
  250. {line1}
  251. {line2}
  252. {bottom_border}
  253. """
  254. ASCIIColors.cyan(banner)
  255. # Server Configuration
  256. ASCIIColors.magenta("\n📡 Server Configuration:")
  257. ASCIIColors.white(" ├─ Host: ", end="")
  258. ASCIIColors.yellow(f"{args.host}")
  259. ASCIIColors.white(" ├─ Port: ", end="")
  260. ASCIIColors.yellow(f"{args.port}")
  261. ASCIIColors.white(" ├─ Workers: ", end="")
  262. ASCIIColors.yellow(f"{args.workers}")
  263. ASCIIColors.white(" ├─ Timeout: ", end="")
  264. ASCIIColors.yellow(f"{args.timeout}")
  265. ASCIIColors.white(" ├─ CORS Origins: ", end="")
  266. ASCIIColors.yellow(f"{args.cors_origins}")
  267. ASCIIColors.white(" ├─ SSL Enabled: ", end="")
  268. ASCIIColors.yellow(f"{args.ssl}")
  269. if args.ssl:
  270. ASCIIColors.white(" ├─ SSL Cert: ", end="")
  271. ASCIIColors.yellow(f"{args.ssl_certfile}")
  272. ASCIIColors.white(" ├─ SSL Key: ", end="")
  273. ASCIIColors.yellow(f"{args.ssl_keyfile}")
  274. ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
  275. ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
  276. ASCIIColors.white(" ├─ Log Level: ", end="")
  277. ASCIIColors.yellow(f"{args.log_level}")
  278. ASCIIColors.white(" ├─ Verbose Debug: ", end="")
  279. ASCIIColors.yellow(f"{args.verbose}")
  280. ASCIIColors.white(" ├─ API Key: ", end="")
  281. ASCIIColors.yellow("Set" if args.key else "Not Set")
  282. ASCIIColors.white(" └─ JWT Auth: ", end="")
  283. ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
  284. # Directory Configuration
  285. ASCIIColors.magenta("\n📂 Directory Configuration:")
  286. ASCIIColors.white(" ├─ Working Directory: ", end="")
  287. ASCIIColors.yellow(f"{args.working_dir}")
  288. ASCIIColors.white(" └─ Input Directory: ", end="")
  289. ASCIIColors.yellow(f"{args.input_dir}")
  290. # Embedding Configuration
  291. ASCIIColors.magenta("\n📊 Embedding Configuration:")
  292. ASCIIColors.white(" ├─ Binding: ", end="")
  293. ASCIIColors.yellow(f"{args.embedding_binding}")
  294. ASCIIColors.white(" ├─ Host: ", end="")
  295. ASCIIColors.yellow(f"{args.embedding_binding_host}")
  296. ASCIIColors.white(" ├─ Model: ", end="")
  297. ASCIIColors.yellow(f"{args.embedding_model}")
  298. ASCIIColors.white(" ├─ Dimensions: ", end="")
  299. ASCIIColors.yellow(f"{args.embedding_dim}")
  300. ASCIIColors.white(" └─ Asymmetric: ", end="")
  301. ASCIIColors.yellow(f"{args.embedding_asymmetric}")
  302. # RAG Configuration
  303. ASCIIColors.magenta("\n⚙️ RAG Configuration:")
  304. ASCIIColors.white(" ├─ Summary Language: ", end="")
  305. ASCIIColors.yellow(f"{args.summary_language}")
  306. ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
  307. ASCIIColors.yellow(f"{args.max_parallel_insert}")
  308. ASCIIColors.white(" ├─ Chunk Size: ", end="")
  309. ASCIIColors.yellow(f"{args.chunk_size}")
  310. ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="")
  311. ASCIIColors.yellow(f"{args.chunk_overlap_size}")
  312. ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
  313. ASCIIColors.yellow(f"{args.cosine_threshold}")
  314. ASCIIColors.white(" ├─ Top-K: ", end="")
  315. ASCIIColors.yellow(f"{args.top_k}")
  316. ASCIIColors.white(" └─ Force LLM Summary on Merge: ", end="")
  317. ASCIIColors.yellow(
  318. f"{get_env_value('FORCE_LLM_SUMMARY_ON_MERGE', DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int)}"
  319. )
  320. # System Configuration
  321. ASCIIColors.magenta("\n💾 Storage Configuration:")
  322. ASCIIColors.white(" ├─ KV Storage: ", end="")
  323. ASCIIColors.yellow(f"{args.kv_storage}")
  324. ASCIIColors.white(" ├─ Vector Storage: ", end="")
  325. ASCIIColors.yellow(f"{args.vector_storage}")
  326. ASCIIColors.white(" ├─ Graph Storage: ", end="")
  327. ASCIIColors.yellow(f"{args.graph_storage}")
  328. ASCIIColors.white(" ├─ Document Status Storage: ", end="")
  329. ASCIIColors.yellow(f"{args.doc_status_storage}")
  330. ASCIIColors.white(" └─ Workspace: ", end="")
  331. ASCIIColors.yellow(f"{args.workspace if args.workspace else '-'}")
  332. # Server Status
  333. ASCIIColors.green("\n✨ Server starting up...\n")
  334. # Server Access Information
  335. protocol = "https" if args.ssl else "http"
  336. if args.host == "0.0.0.0":
  337. ASCIIColors.magenta("\n🌐 Server Access Information:")
  338. ASCIIColors.white(" ├─ WebUI (local): ", end="")
  339. ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
  340. ASCIIColors.white(" ├─ Remote Access: ", end="")
  341. ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
  342. ASCIIColors.white(" ├─ API Documentation (local): ", end="")
  343. ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
  344. ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
  345. ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
  346. ASCIIColors.magenta("\n📝 Note:")
  347. ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
  348. - Use 'localhost' or '127.0.0.1' for local access
  349. - Use your machine's IP address for remote access
  350. - To find your IP address:
  351. • Windows: Run 'ipconfig' in terminal
  352. • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal
  353. """)
  354. else:
  355. base_url = f"{protocol}://{args.host}:{args.port}"
  356. ASCIIColors.magenta("\n🌐 Server Access Information:")
  357. ASCIIColors.white(" ├─ WebUI (local): ", end="")
  358. ASCIIColors.yellow(f"{base_url}")
  359. ASCIIColors.white(" ├─ API Documentation: ", end="")
  360. ASCIIColors.yellow(f"{base_url}/docs")
  361. ASCIIColors.white(" └─ Alternative Documentation: ", end="")
  362. ASCIIColors.yellow(f"{base_url}/redoc")
  363. # Security Notice
  364. if args.key:
  365. ASCIIColors.yellow("\n⚠️ Security Notice:")
  366. ASCIIColors.white(""" API Key authentication is enabled.
  367. Make sure to include the X-API-Key header in all your requests.
  368. """)
  369. if args.auth_accounts:
  370. ASCIIColors.yellow("\n⚠️ Security Notice:")
  371. ASCIIColors.white(""" JWT authentication is enabled.
  372. Make sure to login before making the request, and include the 'Authorization' in the header.
  373. """)
  374. # Ensure splash output flush to system log
  375. sys.stdout.flush()