llm_roles.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """LLM role registry, configuration types, and runtime mixin.
  2. LightRAG can route different stages of work (entity extraction, keyword
  3. extraction, query, vlm) to distinct LLM bindings. This module owns the
  4. static role registry (:data:`ROLES`), the per-role configuration
  5. (:class:`RoleLLMConfig`), and the :class:`_RoleLLMMixin` that drives the
  6. runtime: builder registration, wrapper rebuilding, hot config updates,
  7. queue cleanup, and queue-status reporting.
  8. """
  9. from __future__ import annotations
  10. import asyncio
  11. import inspect
  12. from copy import deepcopy
  13. from dataclasses import dataclass, field
  14. from functools import partial
  15. from typing import Any, Callable, Mapping
  16. from lightrag.utils import (
  17. get_env_value,
  18. logger,
  19. priority_limit_async_func_call,
  20. )
  21. def _optional_env_int(env_key: str) -> int | None:
  22. return get_env_value(env_key, None, int, special_none=True)
  23. @dataclass(frozen=True)
  24. class RoleSpec:
  25. """Static descriptor for a known LLM role.
  26. Adding a new role anywhere in LightRAG is a single-line edit: append a
  27. ``RoleSpec`` to :data:`ROLES`. Every other component (env var loop in
  28. ``api/config.py``, queue observability, role config update flow) iterates
  29. this registry rather than hard-coding role names.
  30. """
  31. name: str
  32. """Canonical lowercase role key (used in ``role_llm_configs`` dict and CLI/log output)."""
  33. env_prefix: str
  34. """Uppercase prefix used by the API env-var layer, e.g. ``"EXTRACT"`` for
  35. ``EXTRACT_LLM_BINDING`` / ``EXTRACT_MAX_ASYNC_LLM`` / ``EXTRACT_LLM_TIMEOUT``."""
  36. queue_name: str
  37. """Display name passed to ``priority_limit_async_func_call`` for log lines."""
  38. ROLES: tuple[RoleSpec, ...] = (
  39. RoleSpec("extract", "EXTRACT", "extract LLM func"),
  40. RoleSpec("keyword", "KEYWORD", "keyword LLM func"),
  41. RoleSpec("query", "QUERY", "query LLM func"),
  42. RoleSpec("vlm", "VLM", "vlm LLM func"),
  43. )
  44. ROLE_NAMES: frozenset[str] = frozenset(spec.name for spec in ROLES)
  45. ROLES_BY_NAME: dict[str, RoleSpec] = {spec.name: spec for spec in ROLES}
  46. @dataclass
  47. class RoleLLMConfig:
  48. """Per-role LLM override accepted at :class:`LightRAG` init time.
  49. Any field left as ``None`` falls back to the corresponding base LLM
  50. setting (``llm_model_func`` / ``llm_model_kwargs`` / ``llm_model_max_async``
  51. / ``default_llm_timeout``). When ``max_async`` is None at init and the
  52. user did not pass a ``role_llm_configs`` entry for the role, the value is
  53. additionally seeded from ``{ROLE_PREFIX}_MAX_ASYNC_LLM``. ``metadata`` seeds
  54. runtime observability and role-builder context.
  55. """
  56. func: Callable[..., object] | None = None
  57. kwargs: dict[str, Any] | None = None
  58. max_async: int | None = None
  59. timeout: int | None = None
  60. metadata: dict[str, Any] | None = None
  61. @dataclass
  62. class _RoleLLMState:
  63. """Runtime state for one role. Internal — not part of the public API."""
  64. raw_func: Callable[..., object]
  65. kwargs: dict[str, Any] | None
  66. max_async: int | None
  67. timeout: int | None
  68. metadata: dict[str, Any] = field(default_factory=dict)
  69. wrapped: Callable[..., object] | None = None
  70. class _RoleLLMMixin:
  71. """Mixin that owns the role LLM runtime on :class:`LightRAG`.
  72. Mixed into LightRAG only. Relies on attributes that the main class
  73. initializes in ``__post_init__`` (``_role_llm_states``, ``_role_llm_builders``,
  74. ``llm_model_func``, ``llm_model_kwargs``, ``llm_model_max_async``,
  75. ``default_llm_timeout``, ``embedding_func``, ``rerank_model_func``).
  76. """
  77. _SECRET_MARKERS = (
  78. "api_key",
  79. "api-key",
  80. "apikey",
  81. "access_key",
  82. "access-key",
  83. "secret",
  84. "token",
  85. "credential",
  86. "password",
  87. "passphrase",
  88. "pwd",
  89. "auth",
  90. "session",
  91. )
  92. @staticmethod
  93. def _normalize_llm_role(role: str) -> str:
  94. normalized = role.strip().lower()
  95. if normalized not in ROLE_NAMES:
  96. raise ValueError(f"Invalid LLM role: {role}")
  97. return normalized
  98. def register_role_llm_builder(
  99. self,
  100. builder: Callable[
  101. [str, dict[str, Any]], tuple[Callable[..., object], dict[str, Any] | None]
  102. ],
  103. ) -> None:
  104. """Register a runtime builder used by update_llm_role_config for binding/model updates."""
  105. self._llm_role_builder = builder
  106. def set_role_llm_metadata(self, role: str, **metadata: Any) -> None:
  107. """Store role metadata used when rebuilding a role-specific LLM function."""
  108. role = self._normalize_llm_role(role)
  109. state = self._role_llm_states[role]
  110. for key, value in metadata.items():
  111. if value is None:
  112. continue
  113. state.metadata[key] = value
  114. @property
  115. def role_llm_funcs(self) -> Mapping[str, Callable[..., object]]:
  116. """Read-only mapping of role name → wrapped (queue-managed) LLM func."""
  117. return {
  118. name: state.wrapped
  119. for name, state in self._role_llm_states.items()
  120. if state.wrapped is not None
  121. }
  122. @property
  123. def role_llm_kwargs(self) -> Mapping[str, dict[str, Any] | None]:
  124. """Read-only mapping of role name → effective LLM kwargs (None means inherit base)."""
  125. return {name: state.kwargs for name, state in self._role_llm_states.items()}
  126. def _get_effective_role_llm_kwargs(self, role: str) -> dict[str, Any]:
  127. state = self._role_llm_states[self._normalize_llm_role(role)]
  128. if state.kwargs is not None:
  129. return state.kwargs
  130. if state.metadata.get("is_cross_provider"):
  131. return {}
  132. return self.llm_model_kwargs
  133. def _get_effective_role_llm_timeout(self, role: str) -> int:
  134. state = self._role_llm_states[self._normalize_llm_role(role)]
  135. return state.timeout if state.timeout is not None else self.default_llm_timeout
  136. def _get_effective_role_llm_max_async(self, role: str) -> int:
  137. state = self._role_llm_states[self._normalize_llm_role(role)]
  138. return (
  139. state.max_async if state.max_async is not None else self.llm_model_max_async
  140. )
  141. def _wrap_llm_role_func(
  142. self,
  143. role_name: str,
  144. raw_func: Callable[..., object],
  145. max_async: int,
  146. timeout: int,
  147. model_kwargs: dict[str, Any],
  148. ) -> Callable[..., object]:
  149. spec = ROLES_BY_NAME[role_name]
  150. return priority_limit_async_func_call(
  151. max_async,
  152. llm_timeout=timeout,
  153. queue_name=spec.queue_name,
  154. )(
  155. partial(
  156. raw_func,
  157. hashing_kv=self.llm_response_cache,
  158. **model_kwargs,
  159. )
  160. )
  161. def _rebuild_role_llm_funcs(self) -> None:
  162. """Wrap each role's raw_func with its own priority queue.
  163. Base ``llm_model_func`` is intentionally NOT wrapped — concurrency
  164. for the base function is enforced at the role layer (every code path
  165. that calls an LLM goes through a role wrapper).
  166. """
  167. for spec in ROLES:
  168. self._rebuild_single_role_llm_func(spec.name)
  169. def _rebuild_single_role_llm_func(self, role: str) -> None:
  170. role = self._normalize_llm_role(role)
  171. state = self._role_llm_states[role]
  172. state.wrapped = self._wrap_llm_role_func(
  173. role,
  174. state.raw_func,
  175. self._get_effective_role_llm_max_async(role),
  176. self._get_effective_role_llm_timeout(role),
  177. self._get_effective_role_llm_kwargs(role),
  178. )
  179. async def _shutdown_llm_wrapper(self, wrapped_func: Callable[..., object]) -> None:
  180. shutdown = getattr(wrapped_func, "shutdown", None)
  181. if callable(shutdown):
  182. await shutdown(graceful=True)
  183. def _schedule_retired_llm_queue_cleanup(
  184. self, wrapped_func: Callable[..., object] | None
  185. ) -> None:
  186. if wrapped_func is None or not callable(
  187. getattr(wrapped_func, "shutdown", None)
  188. ):
  189. return
  190. try:
  191. loop = asyncio.get_running_loop()
  192. except RuntimeError:
  193. # The retired wrapper's queue and worker tasks are tied to the
  194. # event loop that first used them. Spinning up a fresh loop via
  195. # asyncio.run would either hang on queue.join() or touch
  196. # primitives bound to a closed loop. Skip cleanup with a warning
  197. # — call aupdate_llm_role_config() from an async context for
  198. # deterministic shutdown.
  199. logger.warning(
  200. "update_llm_role_config: skipping retired LLM queue cleanup "
  201. "because no event loop is running; call aupdate_llm_role_config() "
  202. "from an async context for deterministic shutdown"
  203. )
  204. return
  205. task = loop.create_task(self._shutdown_llm_wrapper(wrapped_func))
  206. self._retired_llm_queue_cleanup_tasks.add(task)
  207. task.add_done_callback(self._finalize_retired_llm_queue_cleanup)
  208. def _finalize_retired_llm_queue_cleanup(self, task: asyncio.Task) -> None:
  209. self._retired_llm_queue_cleanup_tasks.discard(task)
  210. try:
  211. task.result()
  212. except asyncio.CancelledError:
  213. pass
  214. except Exception as e:
  215. logger.warning(f"Retired LLM queue cleanup failed: {e}")
  216. async def wait_for_retired_llm_queues(self) -> None:
  217. """Wait until all retired role LLM queues have drained and shut down.
  218. Cleanup failures are logged by ``_finalize_retired_llm_queue_cleanup``
  219. and intentionally swallowed here so callers can rely on this method
  220. always returning once every retired wrapper has finished.
  221. """
  222. while self._retired_llm_queue_cleanup_tasks:
  223. tasks = list(self._retired_llm_queue_cleanup_tasks)
  224. await asyncio.gather(*tasks, return_exceptions=True)
  225. def _apply_llm_role_config_update(
  226. self,
  227. role: str,
  228. *,
  229. model_func: Callable[..., object] | None = None,
  230. model_kwargs: dict[str, Any] | None = None,
  231. max_async: int | None = None,
  232. timeout: int | None = None,
  233. binding: str | None = None,
  234. model: str | None = None,
  235. host: str | None = None,
  236. api_key: str | None = None,
  237. provider_options: dict[str, Any] | None = None,
  238. ) -> Callable[..., object] | None:
  239. role = self._normalize_llm_role(role)
  240. state = self._role_llm_states[role]
  241. old_wrapped = state.wrapped
  242. snapshot = _RoleLLMState(
  243. raw_func=state.raw_func,
  244. kwargs=deepcopy(state.kwargs),
  245. max_async=state.max_async,
  246. timeout=state.timeout,
  247. metadata=deepcopy(state.metadata),
  248. wrapped=state.wrapped,
  249. )
  250. try:
  251. if model_func is not None and not callable(model_func):
  252. raise TypeError("model_func must be callable")
  253. if model_kwargs is not None:
  254. state.kwargs = model_kwargs
  255. if max_async is not None:
  256. state.max_async = max_async
  257. if timeout is not None:
  258. state.timeout = timeout
  259. if model_func is not None:
  260. state.raw_func = model_func
  261. metadata_updated = any(
  262. value is not None
  263. for value in (binding, model, host, api_key, provider_options)
  264. )
  265. if binding is not None:
  266. state.metadata["binding"] = binding
  267. if model is not None:
  268. state.metadata["model"] = model
  269. if host is not None:
  270. state.metadata["host"] = host
  271. if api_key is not None:
  272. state.metadata["api_key"] = api_key
  273. if provider_options is not None:
  274. state.metadata["provider_options"] = provider_options
  275. if "base_binding" in state.metadata and "binding" in state.metadata:
  276. state.metadata["is_cross_provider"] = (
  277. state.metadata["binding"] != state.metadata["base_binding"]
  278. )
  279. if metadata_updated:
  280. builder = getattr(self, "_llm_role_builder", None)
  281. if builder is None and model_func is None:
  282. raise ValueError(
  283. "Runtime role builder is not configured; provide model_func or register_role_llm_builder() first"
  284. )
  285. if builder is not None:
  286. built_func, built_kwargs = builder(role, state.metadata)
  287. state.raw_func = built_func
  288. if model_kwargs is None and built_kwargs is not None:
  289. state.kwargs = built_kwargs
  290. self._rebuild_single_role_llm_func(role)
  291. except Exception:
  292. state.raw_func = snapshot.raw_func
  293. state.kwargs = snapshot.kwargs
  294. state.max_async = snapshot.max_async
  295. state.timeout = snapshot.timeout
  296. state.metadata = snapshot.metadata
  297. state.wrapped = snapshot.wrapped
  298. raise
  299. self._log_llm_role_config("updated", role=role)
  300. return old_wrapped
  301. def update_llm_role_config(
  302. self,
  303. role: str,
  304. *,
  305. model_func: Callable[..., object] | None = None,
  306. model_kwargs: dict[str, Any] | None = None,
  307. max_async: int | None = None,
  308. timeout: int | None = None,
  309. binding: str | None = None,
  310. model: str | None = None,
  311. host: str | None = None,
  312. api_key: str | None = None,
  313. provider_options: dict[str, Any] | None = None,
  314. ) -> None:
  315. """
  316. Update a role-specific LLM configuration at runtime.
  317. Supports lightweight updates (kwargs/max_async/timeout/model_func) directly.
  318. For binding/model/host/api_key/provider_options updates, a role builder must
  319. be registered via register_role_llm_builder().
  320. """
  321. old_wrapped = self._apply_llm_role_config_update(
  322. role,
  323. model_func=model_func,
  324. model_kwargs=model_kwargs,
  325. max_async=max_async,
  326. timeout=timeout,
  327. binding=binding,
  328. model=model,
  329. host=host,
  330. api_key=api_key,
  331. provider_options=provider_options,
  332. )
  333. self._schedule_retired_llm_queue_cleanup(old_wrapped)
  334. async def aupdate_llm_role_config(
  335. self,
  336. role: str,
  337. *,
  338. model_func: Callable[..., object] | None = None,
  339. model_kwargs: dict[str, Any] | None = None,
  340. max_async: int | None = None,
  341. timeout: int | None = None,
  342. binding: str | None = None,
  343. model: str | None = None,
  344. host: str | None = None,
  345. api_key: str | None = None,
  346. provider_options: dict[str, Any] | None = None,
  347. ) -> None:
  348. """Async variant of update_llm_role_config that waits for queue cleanup.
  349. Blocking behavior:
  350. This coroutine awaits a graceful shutdown of the retired role
  351. wrapper's priority queue. The shutdown blocks on
  352. ``queue.join()`` until every already-queued LLM call has been
  353. executed (workers always call ``task_done()`` in ``finally``,
  354. so in-flight requests are not cut off).
  355. The wait is bounded by ``max_task_duration`` of the retired
  356. queue, which is computed as ``llm_timeout * 2 + 15`` seconds
  357. (default ``180 * 2 + 15 = 375`` seconds, ~6 min 15 s). When
  358. this bound is reached, the drain times out and the shutdown
  359. falls through to forced cancellation: pending futures are
  360. cancelled, the queue is cleared, workers are stopped. So this
  361. method **never blocks indefinitely**, but with a deep backlog
  362. of slow LLM calls it can take up to that bound to return, and
  363. in-flight calls past the bound will be cancelled.
  364. If you need a non-blocking switch, use the sync
  365. ``update_llm_role_config()`` (which schedules cleanup as a
  366. background task) and await ``wait_for_retired_llm_queues()``
  367. separately when you want to confirm the old queue is gone.
  368. """
  369. old_wrapped = self._apply_llm_role_config_update(
  370. role,
  371. model_func=model_func,
  372. model_kwargs=model_kwargs,
  373. max_async=max_async,
  374. timeout=timeout,
  375. binding=binding,
  376. model=model,
  377. host=host,
  378. api_key=api_key,
  379. provider_options=provider_options,
  380. )
  381. if old_wrapped is not None:
  382. await self._shutdown_llm_wrapper(old_wrapped)
  383. @classmethod
  384. def _is_secret_key(cls, key: str) -> bool:
  385. lowered = key.lower()
  386. return any(marker in lowered for marker in cls._SECRET_MARKERS)
  387. def _scrubbed_llm_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
  388. """Return a deep copy of ``metadata`` with auth-bearing fields removed.
  389. Auth-bearing fields are stripped entirely — not masked — because a
  390. masked ``"***"`` carries no information for an external consumer
  391. (operators already see ``binding`` / ``host`` to confirm a role is
  392. configured). Stripping makes the invariant simple: anything that
  393. appears in this output is safe to log, cache, ship over the wire.
  394. Components that legitimately need the raw secret (the role builder,
  395. provider clients) read it directly off the private
  396. ``_role_llm_states[role].metadata`` dict.
  397. """
  398. def scrub_value(value: Any) -> Any:
  399. if isinstance(value, Mapping):
  400. return {
  401. key: scrub_value(inner_value)
  402. for key, inner_value in value.items()
  403. if not self._is_secret_key(str(key))
  404. }
  405. if isinstance(value, list):
  406. return [scrub_value(item) for item in value]
  407. if isinstance(value, tuple):
  408. return tuple(scrub_value(item) for item in value)
  409. return deepcopy(value)
  410. return scrub_value(metadata)
  411. def get_llm_role_config(self, role: str | None = None) -> dict[str, Any]:
  412. """Return effective role LLM runtime configuration (observability snapshot).
  413. Each role entry exposes ``binding`` / ``model`` / ``host`` at the top
  414. level for convenience and again inside ``metadata`` as part of the
  415. full runtime snapshot (which may contain extra builder-specific
  416. keys). Auth-bearing fields (``api_key``, ``aws_secret_access_key``,
  417. ``password``, …) are **stripped entirely** from ``metadata`` — this
  418. method is intended for ``/health`` / WebUI / audit output and must
  419. never leak credentials. There is no escape hatch; runtime components
  420. that legitimately need the raw value read it from
  421. ``_role_llm_states[role].metadata`` directly.
  422. """
  423. def role_config(role_name: str) -> dict[str, Any]:
  424. state = self._role_llm_states[role_name]
  425. metadata = self._scrubbed_llm_metadata(state.metadata)
  426. return {
  427. "binding": metadata.get("binding"),
  428. "model": metadata.get("model"),
  429. "host": metadata.get("host"),
  430. "is_cross_provider": metadata.get("is_cross_provider", False),
  431. "max_async": self._get_effective_role_llm_max_async(role_name),
  432. "timeout": self._get_effective_role_llm_timeout(role_name),
  433. "has_model_kwargs": state.kwargs is not None,
  434. "metadata": metadata,
  435. }
  436. if role is not None:
  437. return role_config(self._normalize_llm_role(role))
  438. return {spec.name: role_config(spec.name) for spec in ROLES}
  439. def _log_llm_role_config(self, reason: str, role: str | None = None) -> None:
  440. """Log the sanitized role LLM runtime configuration."""
  441. if role is None:
  442. configs = self.get_llm_role_config()
  443. role_names = [spec.name for spec in ROLES]
  444. logger.info(f"Role LLM Configuration ({reason}):")
  445. else:
  446. normalized_role = self._normalize_llm_role(role)
  447. configs = {normalized_role: self.get_llm_role_config(normalized_role)}
  448. role_names = [normalized_role]
  449. logger.info(f"Role LLM Configuration ({reason}: {normalized_role}):")
  450. for role_name in role_names:
  451. cfg = configs[role_name]
  452. logger.info(
  453. " - %s: %s/%s, host=%s, max_async=%s, timeout=%s",
  454. role_name,
  455. cfg["binding"],
  456. cfg["model"],
  457. cfg["host"],
  458. cfg["max_async"],
  459. cfg["timeout"],
  460. )
  461. async def _queue_status_for_func(
  462. self, func: Callable[..., object] | None
  463. ) -> dict[str, Any]:
  464. if func is None:
  465. return {"available": False}
  466. get_stats = getattr(func, "get_queue_stats", None)
  467. if not callable(get_stats):
  468. return {"available": False}
  469. stats = get_stats()
  470. if inspect.isawaitable(stats):
  471. stats = await stats
  472. stats["available"] = True
  473. return stats
  474. async def get_llm_queue_status(self, include_base: bool = True) -> dict[str, Any]:
  475. """Return queue status for each role's wrapped LLM func.
  476. The base ``llm_model_func`` is no longer queue-wrapped, so it is not
  477. reported here. ``include_base`` is kept for signature compatibility
  478. but has no effect.
  479. """
  480. del include_base # base is unwrapped — see docstring
  481. result: dict[str, Any] = {}
  482. for spec in ROLES:
  483. state = self._role_llm_states.get(spec.name)
  484. result[spec.name] = await self._queue_status_for_func(
  485. state.wrapped if state else None
  486. )
  487. return result
  488. async def get_embedding_queue_status(self) -> dict[str, Any]:
  489. """Return queue status for the wrapped embedding function."""
  490. return await self._queue_status_for_func(
  491. self.embedding_func.func if self.embedding_func is not None else None
  492. )
  493. async def get_rerank_queue_status(self) -> dict[str, Any]:
  494. """Return queue status for the wrapped rerank function."""
  495. return await self._queue_status_for_func(self.rerank_model_func)