| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- """LLM role registry, configuration types, and runtime mixin.
- LightRAG can route different stages of work (entity extraction, keyword
- extraction, query, vlm) to distinct LLM bindings. This module owns the
- static role registry (:data:`ROLES`), the per-role configuration
- (:class:`RoleLLMConfig`), and the :class:`_RoleLLMMixin` that drives the
- runtime: builder registration, wrapper rebuilding, hot config updates,
- queue cleanup, and queue-status reporting.
- """
- from __future__ import annotations
- import asyncio
- import inspect
- from copy import deepcopy
- from dataclasses import dataclass, field
- from functools import partial
- from typing import Any, Callable, Mapping
- from lightrag.utils import (
- get_env_value,
- logger,
- priority_limit_async_func_call,
- )
- def _optional_env_int(env_key: str) -> int | None:
- return get_env_value(env_key, None, int, special_none=True)
- @dataclass(frozen=True)
- class RoleSpec:
- """Static descriptor for a known LLM role.
- Adding a new role anywhere in LightRAG is a single-line edit: append a
- ``RoleSpec`` to :data:`ROLES`. Every other component (env var loop in
- ``api/config.py``, queue observability, role config update flow) iterates
- this registry rather than hard-coding role names.
- """
- name: str
- """Canonical lowercase role key (used in ``role_llm_configs`` dict and CLI/log output)."""
- env_prefix: str
- """Uppercase prefix used by the API env-var layer, e.g. ``"EXTRACT"`` for
- ``EXTRACT_LLM_BINDING`` / ``EXTRACT_MAX_ASYNC_LLM`` / ``EXTRACT_LLM_TIMEOUT``."""
- queue_name: str
- """Display name passed to ``priority_limit_async_func_call`` for log lines."""
- ROLES: tuple[RoleSpec, ...] = (
- RoleSpec("extract", "EXTRACT", "extract LLM func"),
- RoleSpec("keyword", "KEYWORD", "keyword LLM func"),
- RoleSpec("query", "QUERY", "query LLM func"),
- RoleSpec("vlm", "VLM", "vlm LLM func"),
- )
- ROLE_NAMES: frozenset[str] = frozenset(spec.name for spec in ROLES)
- ROLES_BY_NAME: dict[str, RoleSpec] = {spec.name: spec for spec in ROLES}
- @dataclass
- class RoleLLMConfig:
- """Per-role LLM override accepted at :class:`LightRAG` init time.
- Any field left as ``None`` falls back to the corresponding base LLM
- setting (``llm_model_func`` / ``llm_model_kwargs`` / ``llm_model_max_async``
- / ``default_llm_timeout``). When ``max_async`` is None at init and the
- user did not pass a ``role_llm_configs`` entry for the role, the value is
- additionally seeded from ``{ROLE_PREFIX}_MAX_ASYNC_LLM``. ``metadata`` seeds
- runtime observability and role-builder context.
- """
- func: Callable[..., object] | None = None
- kwargs: dict[str, Any] | None = None
- max_async: int | None = None
- timeout: int | None = None
- metadata: dict[str, Any] | None = None
- @dataclass
- class _RoleLLMState:
- """Runtime state for one role. Internal — not part of the public API."""
- raw_func: Callable[..., object]
- kwargs: dict[str, Any] | None
- max_async: int | None
- timeout: int | None
- metadata: dict[str, Any] = field(default_factory=dict)
- wrapped: Callable[..., object] | None = None
- class _RoleLLMMixin:
- """Mixin that owns the role LLM runtime on :class:`LightRAG`.
- Mixed into LightRAG only. Relies on attributes that the main class
- initializes in ``__post_init__`` (``_role_llm_states``, ``_role_llm_builders``,
- ``llm_model_func``, ``llm_model_kwargs``, ``llm_model_max_async``,
- ``default_llm_timeout``, ``embedding_func``, ``rerank_model_func``).
- """
- _SECRET_MARKERS = (
- "api_key",
- "api-key",
- "apikey",
- "access_key",
- "access-key",
- "secret",
- "token",
- "credential",
- "password",
- "passphrase",
- "pwd",
- "auth",
- "session",
- )
- @staticmethod
- def _normalize_llm_role(role: str) -> str:
- normalized = role.strip().lower()
- if normalized not in ROLE_NAMES:
- raise ValueError(f"Invalid LLM role: {role}")
- return normalized
- def register_role_llm_builder(
- self,
- builder: Callable[
- [str, dict[str, Any]], tuple[Callable[..., object], dict[str, Any] | None]
- ],
- ) -> None:
- """Register a runtime builder used by update_llm_role_config for binding/model updates."""
- self._llm_role_builder = builder
- def set_role_llm_metadata(self, role: str, **metadata: Any) -> None:
- """Store role metadata used when rebuilding a role-specific LLM function."""
- role = self._normalize_llm_role(role)
- state = self._role_llm_states[role]
- for key, value in metadata.items():
- if value is None:
- continue
- state.metadata[key] = value
- @property
- def role_llm_funcs(self) -> Mapping[str, Callable[..., object]]:
- """Read-only mapping of role name → wrapped (queue-managed) LLM func."""
- return {
- name: state.wrapped
- for name, state in self._role_llm_states.items()
- if state.wrapped is not None
- }
- @property
- def role_llm_kwargs(self) -> Mapping[str, dict[str, Any] | None]:
- """Read-only mapping of role name → effective LLM kwargs (None means inherit base)."""
- return {name: state.kwargs for name, state in self._role_llm_states.items()}
- def _get_effective_role_llm_kwargs(self, role: str) -> dict[str, Any]:
- state = self._role_llm_states[self._normalize_llm_role(role)]
- if state.kwargs is not None:
- return state.kwargs
- if state.metadata.get("is_cross_provider"):
- return {}
- return self.llm_model_kwargs
- def _get_effective_role_llm_timeout(self, role: str) -> int:
- state = self._role_llm_states[self._normalize_llm_role(role)]
- return state.timeout if state.timeout is not None else self.default_llm_timeout
- def _get_effective_role_llm_max_async(self, role: str) -> int:
- state = self._role_llm_states[self._normalize_llm_role(role)]
- return (
- state.max_async if state.max_async is not None else self.llm_model_max_async
- )
- def _wrap_llm_role_func(
- self,
- role_name: str,
- raw_func: Callable[..., object],
- max_async: int,
- timeout: int,
- model_kwargs: dict[str, Any],
- ) -> Callable[..., object]:
- spec = ROLES_BY_NAME[role_name]
- return priority_limit_async_func_call(
- max_async,
- llm_timeout=timeout,
- queue_name=spec.queue_name,
- )(
- partial(
- raw_func,
- hashing_kv=self.llm_response_cache,
- **model_kwargs,
- )
- )
- def _rebuild_role_llm_funcs(self) -> None:
- """Wrap each role's raw_func with its own priority queue.
- Base ``llm_model_func`` is intentionally NOT wrapped — concurrency
- for the base function is enforced at the role layer (every code path
- that calls an LLM goes through a role wrapper).
- """
- for spec in ROLES:
- self._rebuild_single_role_llm_func(spec.name)
- def _rebuild_single_role_llm_func(self, role: str) -> None:
- role = self._normalize_llm_role(role)
- state = self._role_llm_states[role]
- state.wrapped = self._wrap_llm_role_func(
- role,
- state.raw_func,
- self._get_effective_role_llm_max_async(role),
- self._get_effective_role_llm_timeout(role),
- self._get_effective_role_llm_kwargs(role),
- )
- async def _shutdown_llm_wrapper(self, wrapped_func: Callable[..., object]) -> None:
- shutdown = getattr(wrapped_func, "shutdown", None)
- if callable(shutdown):
- await shutdown(graceful=True)
- def _schedule_retired_llm_queue_cleanup(
- self, wrapped_func: Callable[..., object] | None
- ) -> None:
- if wrapped_func is None or not callable(
- getattr(wrapped_func, "shutdown", None)
- ):
- return
- try:
- loop = asyncio.get_running_loop()
- except RuntimeError:
- # The retired wrapper's queue and worker tasks are tied to the
- # event loop that first used them. Spinning up a fresh loop via
- # asyncio.run would either hang on queue.join() or touch
- # primitives bound to a closed loop. Skip cleanup with a warning
- # — call aupdate_llm_role_config() from an async context for
- # deterministic shutdown.
- logger.warning(
- "update_llm_role_config: skipping retired LLM queue cleanup "
- "because no event loop is running; call aupdate_llm_role_config() "
- "from an async context for deterministic shutdown"
- )
- return
- task = loop.create_task(self._shutdown_llm_wrapper(wrapped_func))
- self._retired_llm_queue_cleanup_tasks.add(task)
- task.add_done_callback(self._finalize_retired_llm_queue_cleanup)
- def _finalize_retired_llm_queue_cleanup(self, task: asyncio.Task) -> None:
- self._retired_llm_queue_cleanup_tasks.discard(task)
- try:
- task.result()
- except asyncio.CancelledError:
- pass
- except Exception as e:
- logger.warning(f"Retired LLM queue cleanup failed: {e}")
- async def wait_for_retired_llm_queues(self) -> None:
- """Wait until all retired role LLM queues have drained and shut down.
- Cleanup failures are logged by ``_finalize_retired_llm_queue_cleanup``
- and intentionally swallowed here so callers can rely on this method
- always returning once every retired wrapper has finished.
- """
- while self._retired_llm_queue_cleanup_tasks:
- tasks = list(self._retired_llm_queue_cleanup_tasks)
- await asyncio.gather(*tasks, return_exceptions=True)
- def _apply_llm_role_config_update(
- self,
- role: str,
- *,
- model_func: Callable[..., object] | None = None,
- model_kwargs: dict[str, Any] | None = None,
- max_async: int | None = None,
- timeout: int | None = None,
- binding: str | None = None,
- model: str | None = None,
- host: str | None = None,
- api_key: str | None = None,
- provider_options: dict[str, Any] | None = None,
- ) -> Callable[..., object] | None:
- role = self._normalize_llm_role(role)
- state = self._role_llm_states[role]
- old_wrapped = state.wrapped
- snapshot = _RoleLLMState(
- raw_func=state.raw_func,
- kwargs=deepcopy(state.kwargs),
- max_async=state.max_async,
- timeout=state.timeout,
- metadata=deepcopy(state.metadata),
- wrapped=state.wrapped,
- )
- try:
- if model_func is not None and not callable(model_func):
- raise TypeError("model_func must be callable")
- if model_kwargs is not None:
- state.kwargs = model_kwargs
- if max_async is not None:
- state.max_async = max_async
- if timeout is not None:
- state.timeout = timeout
- if model_func is not None:
- state.raw_func = model_func
- metadata_updated = any(
- value is not None
- for value in (binding, model, host, api_key, provider_options)
- )
- if binding is not None:
- state.metadata["binding"] = binding
- if model is not None:
- state.metadata["model"] = model
- if host is not None:
- state.metadata["host"] = host
- if api_key is not None:
- state.metadata["api_key"] = api_key
- if provider_options is not None:
- state.metadata["provider_options"] = provider_options
- if "base_binding" in state.metadata and "binding" in state.metadata:
- state.metadata["is_cross_provider"] = (
- state.metadata["binding"] != state.metadata["base_binding"]
- )
- if metadata_updated:
- builder = getattr(self, "_llm_role_builder", None)
- if builder is None and model_func is None:
- raise ValueError(
- "Runtime role builder is not configured; provide model_func or register_role_llm_builder() first"
- )
- if builder is not None:
- built_func, built_kwargs = builder(role, state.metadata)
- state.raw_func = built_func
- if model_kwargs is None and built_kwargs is not None:
- state.kwargs = built_kwargs
- self._rebuild_single_role_llm_func(role)
- except Exception:
- state.raw_func = snapshot.raw_func
- state.kwargs = snapshot.kwargs
- state.max_async = snapshot.max_async
- state.timeout = snapshot.timeout
- state.metadata = snapshot.metadata
- state.wrapped = snapshot.wrapped
- raise
- self._log_llm_role_config("updated", role=role)
- return old_wrapped
- def update_llm_role_config(
- self,
- role: str,
- *,
- model_func: Callable[..., object] | None = None,
- model_kwargs: dict[str, Any] | None = None,
- max_async: int | None = None,
- timeout: int | None = None,
- binding: str | None = None,
- model: str | None = None,
- host: str | None = None,
- api_key: str | None = None,
- provider_options: dict[str, Any] | None = None,
- ) -> None:
- """
- Update a role-specific LLM configuration at runtime.
- Supports lightweight updates (kwargs/max_async/timeout/model_func) directly.
- For binding/model/host/api_key/provider_options updates, a role builder must
- be registered via register_role_llm_builder().
- """
- old_wrapped = self._apply_llm_role_config_update(
- role,
- model_func=model_func,
- model_kwargs=model_kwargs,
- max_async=max_async,
- timeout=timeout,
- binding=binding,
- model=model,
- host=host,
- api_key=api_key,
- provider_options=provider_options,
- )
- self._schedule_retired_llm_queue_cleanup(old_wrapped)
- async def aupdate_llm_role_config(
- self,
- role: str,
- *,
- model_func: Callable[..., object] | None = None,
- model_kwargs: dict[str, Any] | None = None,
- max_async: int | None = None,
- timeout: int | None = None,
- binding: str | None = None,
- model: str | None = None,
- host: str | None = None,
- api_key: str | None = None,
- provider_options: dict[str, Any] | None = None,
- ) -> None:
- """Async variant of update_llm_role_config that waits for queue cleanup.
- Blocking behavior:
- This coroutine awaits a graceful shutdown of the retired role
- wrapper's priority queue. The shutdown blocks on
- ``queue.join()`` until every already-queued LLM call has been
- executed (workers always call ``task_done()`` in ``finally``,
- so in-flight requests are not cut off).
- The wait is bounded by ``max_task_duration`` of the retired
- queue, which is computed as ``llm_timeout * 2 + 15`` seconds
- (default ``180 * 2 + 15 = 375`` seconds, ~6 min 15 s). When
- this bound is reached, the drain times out and the shutdown
- falls through to forced cancellation: pending futures are
- cancelled, the queue is cleared, workers are stopped. So this
- method **never blocks indefinitely**, but with a deep backlog
- of slow LLM calls it can take up to that bound to return, and
- in-flight calls past the bound will be cancelled.
- If you need a non-blocking switch, use the sync
- ``update_llm_role_config()`` (which schedules cleanup as a
- background task) and await ``wait_for_retired_llm_queues()``
- separately when you want to confirm the old queue is gone.
- """
- old_wrapped = self._apply_llm_role_config_update(
- role,
- model_func=model_func,
- model_kwargs=model_kwargs,
- max_async=max_async,
- timeout=timeout,
- binding=binding,
- model=model,
- host=host,
- api_key=api_key,
- provider_options=provider_options,
- )
- if old_wrapped is not None:
- await self._shutdown_llm_wrapper(old_wrapped)
- @classmethod
- def _is_secret_key(cls, key: str) -> bool:
- lowered = key.lower()
- return any(marker in lowered for marker in cls._SECRET_MARKERS)
- def _scrubbed_llm_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Return a deep copy of ``metadata`` with auth-bearing fields removed.
- Auth-bearing fields are stripped entirely — not masked — because a
- masked ``"***"`` carries no information for an external consumer
- (operators already see ``binding`` / ``host`` to confirm a role is
- configured). Stripping makes the invariant simple: anything that
- appears in this output is safe to log, cache, ship over the wire.
- Components that legitimately need the raw secret (the role builder,
- provider clients) read it directly off the private
- ``_role_llm_states[role].metadata`` dict.
- """
- def scrub_value(value: Any) -> Any:
- if isinstance(value, Mapping):
- return {
- key: scrub_value(inner_value)
- for key, inner_value in value.items()
- if not self._is_secret_key(str(key))
- }
- if isinstance(value, list):
- return [scrub_value(item) for item in value]
- if isinstance(value, tuple):
- return tuple(scrub_value(item) for item in value)
- return deepcopy(value)
- return scrub_value(metadata)
- def get_llm_role_config(self, role: str | None = None) -> dict[str, Any]:
- """Return effective role LLM runtime configuration (observability snapshot).
- Each role entry exposes ``binding`` / ``model`` / ``host`` at the top
- level for convenience and again inside ``metadata`` as part of the
- full runtime snapshot (which may contain extra builder-specific
- keys). Auth-bearing fields (``api_key``, ``aws_secret_access_key``,
- ``password``, …) are **stripped entirely** from ``metadata`` — this
- method is intended for ``/health`` / WebUI / audit output and must
- never leak credentials. There is no escape hatch; runtime components
- that legitimately need the raw value read it from
- ``_role_llm_states[role].metadata`` directly.
- """
- def role_config(role_name: str) -> dict[str, Any]:
- state = self._role_llm_states[role_name]
- metadata = self._scrubbed_llm_metadata(state.metadata)
- return {
- "binding": metadata.get("binding"),
- "model": metadata.get("model"),
- "host": metadata.get("host"),
- "is_cross_provider": metadata.get("is_cross_provider", False),
- "max_async": self._get_effective_role_llm_max_async(role_name),
- "timeout": self._get_effective_role_llm_timeout(role_name),
- "has_model_kwargs": state.kwargs is not None,
- "metadata": metadata,
- }
- if role is not None:
- return role_config(self._normalize_llm_role(role))
- return {spec.name: role_config(spec.name) for spec in ROLES}
- def _log_llm_role_config(self, reason: str, role: str | None = None) -> None:
- """Log the sanitized role LLM runtime configuration."""
- if role is None:
- configs = self.get_llm_role_config()
- role_names = [spec.name for spec in ROLES]
- logger.info(f"Role LLM Configuration ({reason}):")
- else:
- normalized_role = self._normalize_llm_role(role)
- configs = {normalized_role: self.get_llm_role_config(normalized_role)}
- role_names = [normalized_role]
- logger.info(f"Role LLM Configuration ({reason}: {normalized_role}):")
- for role_name in role_names:
- cfg = configs[role_name]
- logger.info(
- " - %s: %s/%s, host=%s, max_async=%s, timeout=%s",
- role_name,
- cfg["binding"],
- cfg["model"],
- cfg["host"],
- cfg["max_async"],
- cfg["timeout"],
- )
- async def _queue_status_for_func(
- self, func: Callable[..., object] | None
- ) -> dict[str, Any]:
- if func is None:
- return {"available": False}
- get_stats = getattr(func, "get_queue_stats", None)
- if not callable(get_stats):
- return {"available": False}
- stats = get_stats()
- if inspect.isawaitable(stats):
- stats = await stats
- stats["available"] = True
- return stats
- async def get_llm_queue_status(self, include_base: bool = True) -> dict[str, Any]:
- """Return queue status for each role's wrapped LLM func.
- The base ``llm_model_func`` is no longer queue-wrapped, so it is not
- reported here. ``include_base`` is kept for signature compatibility
- but has no effect.
- """
- del include_base # base is unwrapped — see docstring
- result: dict[str, Any] = {}
- for spec in ROLES:
- state = self._role_llm_states.get(spec.name)
- result[spec.name] = await self._queue_status_for_func(
- state.wrapped if state else None
- )
- return result
- async def get_embedding_queue_status(self) -> dict[str, Any]:
- """Return queue status for the wrapped embedding function."""
- return await self._queue_status_for_func(
- self.embedding_func.func if self.embedding_func is not None else None
- )
- async def get_rerank_queue_status(self) -> dict[str, Any]:
- """Return queue status for the wrapped rerank function."""
- return await self._queue_status_for_func(self.rerank_model_func)
|