test_llm_role_runtime.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117
  1. """Offline tests for role-specific LLM runtime configuration."""
  2. import asyncio
  3. import logging
  4. from argparse import Namespace
  5. import numpy as np
  6. import pytest
  7. from lightrag import LightRAG, ROLES, RoleLLMConfig
  8. from lightrag.llm.binding_options import OpenAILLMOptions
  9. from lightrag.utils import EmbeddingFunc, Tokenizer, priority_limit_async_func_call
# Module-level marker: pytest applies ``offline`` to every test in this file,
# so the whole module can be selected or deselected with ``-m offline``.
pytestmark = pytest.mark.offline
  11. @pytest.fixture
  12. def lightrag_logger_propagating(monkeypatch):
  13. """Force the lightrag logger to propagate so caplog can capture records."""
  14. monkeypatch.setattr(logging.getLogger("lightrag"), "propagate", True)
  15. class _SimpleTokenizerImpl:
  16. def encode(self, content: str) -> list[int]:
  17. return [ord(ch) for ch in content]
  18. def decode(self, tokens: list[int]) -> str:
  19. return "".join(chr(t) for t in tokens)
  20. async def _mock_embedding(texts: list[str]) -> np.ndarray:
  21. return np.random.rand(len(texts), 16)
  22. async def _base_llm(*args, **kwargs) -> str:
  23. return "base"
  24. _ROLE_FIELD_SUFFIXES = (
  25. ("_llm_model_func", "func"),
  26. ("_llm_model_kwargs", "kwargs"),
  27. ("_llm_model_max_async", "max_async"),
  28. ("_llm_timeout", "timeout"),
  29. )
  30. def _make_rag(tmp_path, **kwargs) -> LightRAG:
  31. """Create a LightRAG for role tests.
  32. Accepts both the canonical ``role_llm_configs={...}`` style and shorthand
  33. ``{role}_llm_model_func`` / ``{role}_llm_model_kwargs`` etc. keyword
  34. arguments. Shorthand kwargs are folded into ``role_llm_configs`` so the
  35. body of each test reads clearly.
  36. """
  37. role_configs: dict[str, RoleLLMConfig] = {}
  38. explicit = kwargs.pop("role_llm_configs", None)
  39. if explicit is not None:
  40. for name, cfg in explicit.items():
  41. role_configs[name] = (
  42. cfg if isinstance(cfg, RoleLLMConfig) else RoleLLMConfig(**dict(cfg))
  43. )
  44. for spec in ROLES:
  45. bucket = {}
  46. for suffix, target in _ROLE_FIELD_SUFFIXES:
  47. key = f"{spec.name}{suffix}"
  48. if key in kwargs:
  49. bucket[target] = kwargs.pop(key)
  50. if bucket:
  51. existing = role_configs.get(spec.name)
  52. if existing is not None:
  53. for target, value in bucket.items():
  54. if getattr(existing, target) is None:
  55. setattr(existing, target, value)
  56. else:
  57. role_configs[spec.name] = RoleLLMConfig(**bucket)
  58. if role_configs:
  59. kwargs["role_llm_configs"] = role_configs
  60. return LightRAG(
  61. working_dir=str(tmp_path / "role-runtime"),
  62. workspace="role-runtime",
  63. llm_model_func=_base_llm,
  64. embedding_func=EmbeddingFunc(
  65. embedding_dim=16,
  66. max_token_size=4096,
  67. func=_mock_embedding,
  68. ),
  69. tokenizer=Tokenizer("mock-tokenizer", _SimpleTokenizerImpl()),
  70. **kwargs,
  71. )
  72. def _captured_messages(caplog) -> list[str]:
  73. return [record.getMessage() for record in caplog.records]
  74. def _role_config_headers(caplog) -> list[str]:
  75. return [
  76. message
  77. for message in _captured_messages(caplog)
  78. if "Role LLM Configuration" in message
  79. ]
  80. def _clear_role_provider_env(monkeypatch, role: str, options_cls) -> None:
  81. for arg_item in options_cls.args_env_name_type_value():
  82. monkeypatch.delenv(f"{role.upper()}_{arg_item['env_name']}", raising=False)
  83. ROLE_MAX_ASYNC_ENV_KEYS = (
  84. "EXTRACT_MAX_ASYNC_LLM",
  85. "KEYWORD_MAX_ASYNC_LLM",
  86. "QUERY_MAX_ASYNC_LLM",
  87. "VLM_MAX_ASYNC_LLM",
  88. )
@pytest.mark.asyncio
async def test_priority_queue_stats_track_running_and_queued():
    """Queue stats report running vs. queued work and lifetime totals.

    With max_async=1, one call is held inside the worker while a second
    waits. The stats must show 1 running / 1 queued / 2 in flight, then
    all-zero running/queued and 2 completed once both calls finish.
    """
    started = asyncio.Event()
    release = asyncio.Event()

    async def slow_func(value: str, **_kwargs):
        # Signal that the worker picked the call up, then block until released.
        started.set()
        await release.wait()
        return value

    # Concurrency limit of 1: the second submission has to wait in the queue.
    wrapped = priority_limit_async_func_call(1, queue_name="test LLM func")(slow_func)
    first = asyncio.create_task(wrapped("first"))
    await started.wait()
    second = asyncio.create_task(wrapped("second"))
    # Timing-based: give the second submission a moment to reach the queue.
    await asyncio.sleep(0.05)
    stats = await wrapped.get_queue_stats()
    assert stats["max_async"] == 1
    assert stats["running"] == 1
    assert stats["queued"] == 1
    assert stats["in_flight"] == 2
    assert stats["submitted_total"] == 2
    release.set()
    assert await asyncio.gather(first, second) == ["first", "second"]
    # Yield once more before re-reading stats; presumably lets post-call
    # bookkeeping settle — confirm against the queue implementation.
    await asyncio.sleep(0)
    stats = await wrapped.get_queue_stats()
    assert stats["running"] == 0
    assert stats["queued"] == 0
    assert stats["completed_total"] == 2
    assert stats["rejected_total"] == 0
    await wrapped.shutdown()
@pytest.mark.asyncio
async def test_priority_queue_graceful_shutdown_timeout_falls_back_to_force(
    caplog, lightrag_logger_propagating
):
    """A graceful shutdown that cannot drain in time falls back to cancelling.

    The in-flight call sleeps far longer than the 0.1 s drain timeout. The
    shutdown must log "Graceful drain timed out", cancel the call, and count
    the cancellation in the queue stats.
    """
    started = asyncio.Event()

    async def stuck_func(value: str, **_kwargs):
        started.set()
        # Much longer than the shutdown timeout; only cancellation ends it.
        await asyncio.sleep(60)
        return value

    wrapped = priority_limit_async_func_call(1, queue_name="stuck LLM func")(stuck_func)
    in_flight = asyncio.create_task(wrapped("hold"))
    await started.wait()
    with caplog.at_level("WARNING", logger="lightrag"):
        await wrapped.shutdown(graceful=True, timeout=0.1)
    assert any(
        "Graceful drain timed out" in record.getMessage() for record in caplog.records
    )
    # The forced fallback cancels the stuck call.
    with pytest.raises(asyncio.CancelledError):
        await in_flight
    stats = await wrapped.get_queue_stats()
    assert stats["cancelled_total"] >= 1
  138. @pytest.mark.asyncio
  139. async def test_priority_queue_rejects_submissions_after_shutdown():
  140. async def fast_func(value: str, **_kwargs):
  141. return value
  142. wrapped = priority_limit_async_func_call(1, queue_name="reject LLM func")(fast_func)
  143. assert await wrapped("warmup") == "warmup"
  144. await wrapped.shutdown()
  145. with pytest.raises(RuntimeError, match="Queue is shutting down"):
  146. await wrapped("rejected")
  147. stats = await wrapped.get_queue_stats()
  148. assert stats["rejected_total"] == 1
  149. def test_role_max_async_defaults_inherit_base(tmp_path, monkeypatch):
  150. # Use the literal "None" string rather than delenv: storage modules
  151. # (e.g. lightrag.kg.networkx_impl) are imported lazily during
  152. # LightRAG() and re-run load_dotenv(override=False), which would
  153. # restore deleted vars from .env. Setting "None" keeps the variable
  154. # present so load_dotenv leaves it alone, and _optional_env_int
  155. # interprets the string as Python None via special_none=True.
  156. for env_key in ROLE_MAX_ASYNC_ENV_KEYS:
  157. monkeypatch.setenv(env_key, "None")
  158. rag = _make_rag(tmp_path, llm_model_max_async=10)
  159. assert rag._role_llm_states["extract"].max_async is None
  160. assert rag._role_llm_states["keyword"].max_async is None
  161. assert rag._role_llm_states["query"].max_async is None
  162. assert rag._role_llm_states["vlm"].max_async is None
  163. assert rag._get_effective_role_llm_max_async("extract") == 10
  164. assert rag._get_effective_role_llm_max_async("keyword") == 10
  165. assert rag._get_effective_role_llm_max_async("query") == 10
  166. assert rag._get_effective_role_llm_max_async("vlm") == 10
  167. def test_role_max_async_env_override_keeps_other_roles_inherited(tmp_path, monkeypatch):
  168. # See note in test_role_max_async_defaults_inherit_base: lazy
  169. # storage imports re-run load_dotenv, so we mark unwanted keys with
  170. # "None" instead of deleting them.
  171. for env_key in ROLE_MAX_ASYNC_ENV_KEYS:
  172. monkeypatch.setenv(env_key, "None")
  173. monkeypatch.setenv("EXTRACT_MAX_ASYNC_LLM", "7")
  174. rag = _make_rag(tmp_path, llm_model_max_async=10)
  175. assert rag._role_llm_states["extract"].max_async == 7
  176. assert rag._role_llm_states["keyword"].max_async is None
  177. assert rag._role_llm_states["query"].max_async is None
  178. assert rag._role_llm_states["vlm"].max_async is None
  179. assert rag._get_effective_role_llm_max_async("extract") == 7
  180. assert rag._get_effective_role_llm_max_async("keyword") == 10
  181. assert rag._get_effective_role_llm_max_async("query") == 10
  182. assert rag._get_effective_role_llm_max_async("vlm") == 10
  183. @pytest.mark.asyncio
  184. async def test_role_functions_are_isolated_and_vlm_present(tmp_path):
  185. rag = _make_rag(tmp_path)
  186. funcs = [
  187. rag.llm_model_func,
  188. rag.role_llm_funcs["extract"],
  189. rag.role_llm_funcs["keyword"],
  190. rag.role_llm_funcs["query"],
  191. rag.role_llm_funcs["vlm"],
  192. ]
  193. assert all(callable(func) for func in funcs)
  194. assert len({id(func) for func in funcs}) == len(funcs)
  195. @pytest.mark.asyncio
  196. async def test_no_role_configs_keeps_base_raw_and_wraps_each_role(tmp_path):
  197. """Regression: base llm_model_func must stay raw; each role still gets
  198. its own queue wrapper around the base func when no override is given."""
  199. rag = _make_rag(tmp_path)
  200. # Base is the user-provided callable, untouched by any wrapper.
  201. assert rag.llm_model_func is _base_llm
  202. # Every role has a wrapped (queue-managed) func that's distinct from base.
  203. for spec in ROLES:
  204. wrapped = rag.role_llm_funcs[spec.name]
  205. assert callable(wrapped)
  206. assert wrapped is not _base_llm
  207. # All four role wrappers are independent (separate queues).
  208. wrappers = [rag.role_llm_funcs[spec.name] for spec in ROLES]
  209. assert len({id(w) for w in wrappers}) == len(wrappers)
  210. # Calling any role wrapper hits the base function.
  211. assert await rag.role_llm_funcs["extract"]("p") == "base"
  212. assert await rag.role_llm_funcs["vlm"]("p") == "base"
  213. # get_llm_queue_status no longer reports a 'base' entry.
  214. status = await rag.get_llm_queue_status()
  215. assert "base" not in status
  216. assert set(status) == {spec.name for spec in ROLES}
  217. @pytest.mark.asyncio
  218. async def test_role_llm_configs_accepts_dict_form(tmp_path):
  219. """Init accepts plain dicts in role_llm_configs (auto-normalized to RoleLLMConfig)."""
  220. async def query_fn(*args, **kwargs):
  221. return "query-via-dict"
  222. rag = LightRAG(
  223. working_dir=str(tmp_path / "dict-form"),
  224. workspace="dict-form",
  225. llm_model_func=_base_llm,
  226. embedding_func=EmbeddingFunc(
  227. embedding_dim=16, max_token_size=4096, func=_mock_embedding
  228. ),
  229. tokenizer=Tokenizer("mock-tokenizer", _SimpleTokenizerImpl()),
  230. role_llm_configs={"query": {"func": query_fn, "max_async": 5}},
  231. )
  232. assert rag._role_llm_states["query"].raw_func is query_fn
  233. assert rag._role_llm_states["query"].max_async == 5
  234. # Roles not present in the dict still wrap the base function.
  235. assert rag._role_llm_states["extract"].raw_func is _base_llm
  236. assert await rag.role_llm_funcs["query"]("ping") == "query-via-dict"
  237. def test_role_llm_configs_rejects_unknown_role_keys(tmp_path):
  238. with pytest.raises(ValueError, match="qurey"):
  239. _make_rag(tmp_path, role_llm_configs={"qurey": {}})
  240. def test_role_llm_config_logs_once_on_init_with_metadata(
  241. tmp_path, caplog, lightrag_logger_propagating
  242. ):
  243. with caplog.at_level("INFO", logger="lightrag"):
  244. rag = _make_rag(
  245. tmp_path,
  246. role_llm_configs={
  247. "query": RoleLLMConfig(
  248. max_async=7,
  249. timeout=42,
  250. metadata={
  251. "binding": "openai",
  252. "model": "gpt-test",
  253. "host": "https://api.example.com/v1",
  254. "api_key": "secret-key",
  255. "provider_options": {
  256. "temperature": 0.1,
  257. "token": "nested-token",
  258. },
  259. "bedrock_aws_options": {
  260. "region_name": "us-east-1",
  261. "aws_secret_access_key": "aws-secret",
  262. },
  263. },
  264. )
  265. },
  266. )
  267. snapshot = rag.get_llm_role_config("query")
  268. assert snapshot["binding"] == "openai"
  269. assert snapshot["model"] == "gpt-test"
  270. assert snapshot["host"] == "https://api.example.com/v1"
  271. assert snapshot["max_async"] == 7
  272. assert snapshot["timeout"] == 42
  273. headers = _role_config_headers(caplog)
  274. assert len(headers) == 1
  275. assert "initialized" in headers[0]
  276. messages = "\n".join(_captured_messages(caplog))
  277. assert " - query: openai/gpt-test" in messages
  278. assert "max_async=7" in messages
  279. assert "timeout=42" in messages
  280. assert "secret-key" not in messages
  281. assert "nested-token" not in messages
  282. assert "aws-secret" not in messages
  283. @pytest.mark.asyncio
  284. async def test_role_specific_kwargs_and_fallback(tmp_path):
  285. extract_calls = []
  286. vlm_calls = []
  287. async def extract_func(*args, **kwargs):
  288. extract_calls.append(kwargs)
  289. return "extract"
  290. async def vlm_func(*args, **kwargs):
  291. vlm_calls.append(kwargs)
  292. return "vlm"
  293. rag = _make_rag(
  294. tmp_path,
  295. llm_model_kwargs={"shared": "base"},
  296. extract_llm_model_func=extract_func,
  297. extract_llm_model_kwargs={"shared": "extract", "tag": "extract"},
  298. vlm_llm_model_func=vlm_func,
  299. vlm_llm_model_kwargs={"shared": "vlm", "tag": "vlm"},
  300. )
  301. await rag.role_llm_funcs["extract"]("extract prompt")
  302. await rag.role_llm_funcs["keyword"]("keyword prompt")
  303. await rag.role_llm_funcs["vlm"]("vlm prompt")
  304. assert extract_calls[-1]["tag"] == "extract"
  305. assert extract_calls[-1]["shared"] == "extract"
  306. assert "hashing_kv" in extract_calls[-1]
  307. # Keyword role falls back to base kwargs when no role kwargs are configured.
  308. # We do not inspect base function internals, but the call must succeed.
  309. assert vlm_calls[-1]["tag"] == "vlm"
  310. assert vlm_calls[-1]["shared"] == "vlm"
  311. @pytest.mark.asyncio
  312. async def test_update_llm_role_config_rewraps_without_double_call(tmp_path):
  313. call_count = 0
  314. seen_tags = []
  315. async def query_func(*args, **kwargs):
  316. nonlocal call_count
  317. call_count += 1
  318. seen_tags.append(kwargs.get("tag"))
  319. return "query"
  320. rag = _make_rag(
  321. tmp_path,
  322. query_llm_model_func=query_func,
  323. query_llm_model_kwargs={"tag": "v1"},
  324. )
  325. await rag.role_llm_funcs["query"]("first")
  326. assert call_count == 1
  327. assert seen_tags[-1] == "v1"
  328. for value in (3, 5, 7):
  329. rag.update_llm_role_config("query", max_async=value)
  330. await rag.role_llm_funcs["query"]("next")
  331. rag.update_llm_role_config("query", model_kwargs={"tag": "v2"})
  332. await rag.role_llm_funcs["query"]("final")
  333. assert call_count == 5
  334. assert seen_tags[-1] == "v2"
  335. assert rag._role_llm_states["query"].max_async == 7
  336. await rag.wait_for_retired_llm_queues()
@pytest.mark.asyncio
async def test_aupdate_llm_role_config_drains_old_queue(tmp_path):
    """aupdate waits for the old queue to drain while new calls use the new func.

    An old call is held open. While it is in flight, the async update must not
    complete, yet the role already answers new calls with the new func.
    Releasing the old call lets it finish with the old result and lets the
    update return.
    """
    started = asyncio.Event()
    release = asyncio.Event()

    async def old_query_func(*args, **kwargs):
        # Blocks until the test releases it, keeping the old queue busy.
        started.set()
        await release.wait()
        return "old"

    async def new_query_func(*args, **kwargs):
        return "new"

    rag = _make_rag(tmp_path, query_llm_model_func=old_query_func)
    old_call = asyncio.create_task(rag.role_llm_funcs["query"]("old"))
    await started.wait()
    update_call = asyncio.create_task(
        rag.aupdate_llm_role_config("query", model_func=new_query_func)
    )
    # Timing-based: after a short wait the update must still be draining.
    await asyncio.sleep(0.05)
    assert not update_call.done()
    # The swap is already visible to new callers while the drain is pending.
    assert await rag.role_llm_funcs["query"]("new") == "new"
    release.set()
    assert await old_call == "old"
    await update_call
  359. @pytest.mark.asyncio
  360. async def test_sync_update_tracks_retired_queue_cleanup(tmp_path):
  361. async def query_func(*args, **kwargs):
  362. return "old"
  363. async def new_query_func(*args, **kwargs):
  364. return "new"
  365. rag = _make_rag(tmp_path, query_llm_model_func=query_func)
  366. assert await rag.role_llm_funcs["query"]("before") == "old"
  367. rag.update_llm_role_config("query", model_func=new_query_func)
  368. assert await rag.role_llm_funcs["query"]("after") == "new"
  369. await rag.wait_for_retired_llm_queues()
  370. assert not rag._retired_llm_queue_cleanup_tasks
  371. def test_sync_update_without_event_loop_skips_cleanup(
  372. tmp_path, caplog, lightrag_logger_propagating
  373. ):
  374. async def query_func(*args, **kwargs):
  375. return "old"
  376. async def new_query_func(*args, **kwargs):
  377. return "new"
  378. rag = _make_rag(tmp_path, query_llm_model_func=query_func)
  379. with caplog.at_level("WARNING", logger="lightrag"):
  380. rag.update_llm_role_config("query", model_func=new_query_func)
  381. assert not rag._retired_llm_queue_cleanup_tasks
  382. assert any(
  383. "no event loop is running" in record.getMessage() for record in caplog.records
  384. )
  385. async def call_new() -> str:
  386. return await rag.role_llm_funcs["query"]("after")
  387. assert asyncio.run(call_new()) == "new"
@pytest.mark.asyncio
async def test_aupdate_llm_role_config_with_builder_drains_old_queue(tmp_path):
    """Builder-based async updates also drain the old queue before returning.

    The builder returns a blocking func for "old-model" and an immediate one
    for any other model. While an "old-model" call is held open, an async
    switch to "new-model" must stay pending and new calls must already be
    answered by the new func. Once the update returns, no retired-queue
    cleanup task may remain.
    """
    started = asyncio.Event()
    release = asyncio.Event()

    def builder(role, meta):
        model_name = meta["model"]
        if model_name == "old-model":
            # Blocks until released so the old queue stays busy.
            async def built_func(*args, **kwargs):
                started.set()
                await release.wait()
                return model_name
        else:
            async def built_func(*args, **kwargs):
                return model_name
        return built_func, None

    rag = _make_rag(tmp_path)
    rag.register_role_llm_builder(builder)
    rag.set_role_llm_metadata(
        "query",
        binding="openai",
        model="seed",
        host="https://seed",
        api_key="seed-key",
    )
    # Switch to the blocking "old-model" func and flush the retired queue.
    rag.update_llm_role_config("query", binding="openai", model="old-model")
    await rag.wait_for_retired_llm_queues()
    in_flight = asyncio.create_task(rag.role_llm_funcs["query"]("hold"))
    await started.wait()
    update_call = asyncio.create_task(
        rag.aupdate_llm_role_config("query", binding="openai", model="new-model")
    )
    # Timing-based: the update must still be draining the old queue.
    await asyncio.sleep(0.05)
    assert not update_call.done()
    assert await rag.role_llm_funcs["query"]("hello") == "new-model"
    release.set()
    assert await in_flight == "old-model"
    await update_call
    assert not rag._retired_llm_queue_cleanup_tasks
  426. @pytest.mark.asyncio
  427. async def test_aupdate_llm_role_config_updates_cache_identity(tmp_path):
  428. async def query_func(*_args, **_kwargs):
  429. return "query"
  430. rag = _make_rag(tmp_path)
  431. rag.register_role_llm_builder(lambda _role, _meta: (query_func, {}))
  432. await rag.aupdate_llm_role_config(
  433. "query",
  434. binding="openai",
  435. model="gpt-cache-test",
  436. host="https://api.example.com/v1",
  437. )
  438. identity = rag._build_global_config()["llm_cache_identities"]["query"]
  439. assert identity == {
  440. "role": "query",
  441. "binding": "openai",
  442. "model": "gpt-cache-test",
  443. "host": "https://api.example.com/v1",
  444. }
  445. await rag.wait_for_retired_llm_queues()
  446. @pytest.mark.asyncio
  447. async def test_update_llm_role_config_with_builder_metadata(tmp_path):
  448. built_calls = []
  449. def builder(role: str, meta: dict):
  450. async def built_func(*args, **kwargs):
  451. built_calls.append(
  452. {"role": role, "meta": dict(meta), "kwargs": dict(kwargs)}
  453. )
  454. return f"{meta['model']}"
  455. return built_func, {
  456. "runtime_host": meta["host"],
  457. "provider_options": meta["provider_options"],
  458. }
  459. rag = _make_rag(tmp_path)
  460. rag.register_role_llm_builder(builder)
  461. rag.set_role_llm_metadata(
  462. "query",
  463. binding="openai",
  464. model="old-model",
  465. host="https://old-host",
  466. api_key="old-key",
  467. provider_options={"temperature": 0.1},
  468. )
  469. rag.update_llm_role_config(
  470. "query",
  471. binding="gemini",
  472. model="gemini-2.0-flash",
  473. host="https://new-host",
  474. api_key="new-key",
  475. provider_options={"temperature": 0.3, "top_k": 8},
  476. )
  477. result = await rag.role_llm_funcs["query"]("hello")
  478. assert result == "gemini-2.0-flash"
  479. assert built_calls[-1]["role"] == "query"
  480. assert built_calls[-1]["meta"]["binding"] == "gemini"
  481. assert built_calls[-1]["meta"]["model"] == "gemini-2.0-flash"
  482. assert built_calls[-1]["kwargs"]["runtime_host"] == "https://new-host"
  483. assert built_calls[-1]["kwargs"]["provider_options"]["top_k"] == 8
  484. def test_update_llm_role_config_logs_after_success(
  485. tmp_path, caplog, lightrag_logger_propagating
  486. ):
  487. async def built_func(*args, **kwargs):
  488. return "ok"
  489. def builder(role: str, meta: dict):
  490. return built_func, None
  491. rag = _make_rag(
  492. tmp_path,
  493. role_llm_configs={
  494. "query": RoleLLMConfig(
  495. metadata={
  496. "base_binding": "openai",
  497. "binding": "openai",
  498. "model": "old-model",
  499. "host": "https://old.example/v1",
  500. },
  501. )
  502. },
  503. )
  504. rag.register_role_llm_builder(builder)
  505. caplog.clear()
  506. with caplog.at_level("INFO", logger="lightrag"):
  507. rag.update_llm_role_config(
  508. "query",
  509. binding="gemini",
  510. model="gemini-2.0-flash",
  511. host="https://gemini.example/v1",
  512. api_key="new-secret",
  513. provider_options={"token": "nested-token"},
  514. )
  515. headers = _role_config_headers(caplog)
  516. assert len(headers) == 1
  517. assert "updated: query" in headers[0]
  518. messages = "\n".join(_captured_messages(caplog))
  519. assert " - query: gemini/gemini-2.0-flash" in messages
  520. assert "host=https://gemini.example/v1" in messages
  521. assert "is_cross_provider" not in messages
  522. assert "new-secret" not in messages
  523. assert "nested-token" not in messages
  524. @pytest.mark.asyncio
  525. async def test_aupdate_llm_role_config_logs_after_success(
  526. tmp_path, caplog, lightrag_logger_propagating
  527. ):
  528. async def new_query_func(*args, **kwargs):
  529. return "new-query"
  530. rag = _make_rag(
  531. tmp_path,
  532. role_llm_configs={
  533. "query": RoleLLMConfig(
  534. metadata={
  535. "binding": "openai",
  536. "model": "old-model",
  537. "host": "https://old.example/v1",
  538. },
  539. )
  540. },
  541. )
  542. caplog.clear()
  543. with caplog.at_level("INFO", logger="lightrag"):
  544. await rag.aupdate_llm_role_config(
  545. "query",
  546. model_func=new_query_func,
  547. max_async=2,
  548. timeout=180,
  549. )
  550. headers = _role_config_headers(caplog)
  551. assert len(headers) == 1
  552. assert "updated: query" in headers[0]
  553. messages = "\n".join(_captured_messages(caplog))
  554. assert " - query: openai/old-model" in messages
  555. assert "max_async=2" in messages
  556. assert "timeout=180" in messages
  557. @pytest.mark.asyncio
  558. async def test_aupdate_llm_role_config_metadata_without_builder_raises(tmp_path):
  559. """Pin down the public-API contract: updating any metadata field
  560. (binding/model/host/api_key/provider_options) without a registered
  561. builder and without an explicit model_func must fail loudly with a
  562. ValueError. State must be intact so the caller can recover."""
  563. rag = _make_rag(tmp_path)
  564. original_wrapped = rag.role_llm_funcs["query"]
  565. original_metadata = dict(rag._role_llm_states["query"].metadata)
  566. with pytest.raises(ValueError, match="Runtime role builder is not configured"):
  567. await rag.aupdate_llm_role_config("query", binding="openai")
  568. assert rag.role_llm_funcs["query"] is original_wrapped
  569. assert rag._role_llm_states["query"].metadata == original_metadata
  570. assert await rag.role_llm_funcs["query"]("ping") == "base"
  571. @pytest.mark.asyncio
  572. async def test_aupdate_llm_role_config_rejects_non_callable_model_func(tmp_path):
  573. """model_func type check must reject non-callables before any state
  574. mutation happens."""
  575. rag = _make_rag(tmp_path)
  576. original_wrapped = rag.role_llm_funcs["query"]
  577. with pytest.raises(TypeError, match="model_func must be callable"):
  578. await rag.aupdate_llm_role_config("query", model_func="not-a-func")
  579. assert rag.role_llm_funcs["query"] is original_wrapped
  580. assert await rag.role_llm_funcs["query"]("ping") == "base"
  581. @pytest.mark.asyncio
  582. async def test_aupdate_llm_role_config_rejects_unknown_role(tmp_path):
  583. """Typos in the role name must surface as ValueError, not KeyError,
  584. via the shared _normalize_llm_role guard."""
  585. rag = _make_rag(tmp_path)
  586. with pytest.raises(ValueError, match="Invalid LLM role"):
  587. await rag.aupdate_llm_role_config("qurey", max_async=2)
  588. @pytest.mark.asyncio
  589. async def test_aupdate_llm_role_config_rolls_back_and_keeps_old_wrapped(tmp_path):
  590. """When the builder raises, the async path must roll state back AND
  591. skip the retired-wrapper shutdown — the swap effectively never
  592. happened, so the old queue must remain live and accept new work."""
  593. async def query_func(*args, **kwargs):
  594. return "old"
  595. rag = _make_rag(tmp_path, query_llm_model_func=query_func)
  596. rag.set_role_llm_metadata(
  597. "query",
  598. binding="openai",
  599. model="base-model",
  600. host="https://base",
  601. )
  602. original_wrapped = rag.role_llm_funcs["query"]
  603. original_raw = rag._role_llm_states["query"].raw_func
  604. original_metadata = dict(rag._role_llm_states["query"].metadata)
  605. def failing_builder(_role, _meta):
  606. raise RuntimeError("builder boom")
  607. rag.register_role_llm_builder(failing_builder)
  608. with pytest.raises(RuntimeError, match="builder boom"):
  609. await rag.aupdate_llm_role_config(
  610. "query",
  611. binding="gemini",
  612. model="new-model",
  613. )
  614. assert rag.role_llm_funcs["query"] is original_wrapped
  615. assert rag._role_llm_states["query"].raw_func is original_raw
  616. assert rag._role_llm_states["query"].metadata == original_metadata
  617. # Critical: old wrapper was NOT shut down — it still serves calls.
  618. assert await rag.role_llm_funcs["query"]("ping") == "old"
@pytest.mark.asyncio
async def test_aupdate_llm_role_config_drain_timeout_does_not_propagate(
    tmp_path, monkeypatch, caplog, lightrag_logger_propagating
):
    """A drain timeout during aupdate must not leak to the caller.

    The retired queue's in-flight call never finishes on its own, so the
    graceful drain times out and the shutdown falls back to forced
    cancellation. aupdate must still return normally (no TimeoutError). The
    stuck call must end with CancelledError, the new func must serve calls,
    and the timeout warning must be logged.
    """
    started = asyncio.Event()

    async def stuck_func(*args, **kwargs):
        started.set()
        # Far longer than the drain timeout; only cancellation ends it.
        await asyncio.sleep(60)
        return "never"

    async def new_func(*args, **kwargs):
        return "new"

    rag = _make_rag(tmp_path, query_llm_model_func=stuck_func)

    # Stand-in for LightRAG._shutdown_llm_wrapper that drains with a short
    # 0.05 s graceful timeout, so the test reaches the fallback quickly.
    async def fast_shutdown(_self, wrapped_func):
        shutdown = getattr(wrapped_func, "shutdown", None)
        if callable(shutdown):
            await shutdown(graceful=True, timeout=0.05)

    monkeypatch.setattr(LightRAG, "_shutdown_llm_wrapper", fast_shutdown)
    in_flight = asyncio.create_task(rag.role_llm_funcs["query"]("hold"))
    await started.wait()
    with caplog.at_level("WARNING", logger="lightrag"):
        await rag.aupdate_llm_role_config("query", model_func=new_func)
    # The forced fallback cancelled the stuck call.
    with pytest.raises(asyncio.CancelledError):
        await in_flight
    assert await rag.role_llm_funcs["query"]("now") == "new"
    assert any(
        "Graceful drain timed out" in record.getMessage() for record in caplog.records
    )
  650. @pytest.mark.asyncio
  651. async def test_llm_role_config_and_queue_status_are_observable(tmp_path):
  652. rag = _make_rag(tmp_path, query_llm_model_kwargs={"tag": "query"})
  653. rag.set_role_llm_metadata(
  654. "query",
  655. binding="openai",
  656. model="gpt-test",
  657. host="https://api.example.com/v1",
  658. api_key="secret-key",
  659. provider_options={"temperature": 0.1},
  660. )
  661. all_configs = rag.get_llm_role_config()
  662. assert set(all_configs) == {"extract", "keyword", "query", "vlm"}
  663. assert all_configs["query"]["binding"] == "openai"
  664. assert all_configs["query"]["model"] == "gpt-test"
  665. # Auth-bearing fields are dropped from the observability snapshot,
  666. # not masked — there is no "***" placeholder to confuse consumers.
  667. assert "api_key" not in all_configs["query"]["metadata"]
  668. assert all_configs["query"]["has_model_kwargs"] is True
  669. # Raw secrets remain accessible to in-process components that legitimately
  670. # need them (role builder, provider clients), but are not exposed via the
  671. # public observability method.
  672. assert rag._role_llm_states["query"].metadata["api_key"] == "secret-key"
  673. queue_status = await rag.get_llm_queue_status()
  674. assert set(queue_status) == {"extract", "keyword", "query", "vlm"}
  675. assert queue_status["query"]["available"] is True
  676. assert queue_status["query"]["queue_name"] == "query LLM func"
  677. @pytest.mark.asyncio
  678. async def test_embedding_and_rerank_queue_status_are_observable(tmp_path):
  679. async def rerank_func(*args, **kwargs):
  680. return []
  681. rag = _make_rag(tmp_path, rerank_model_func=rerank_func)
  682. embedding_status = await rag.get_embedding_queue_status()
  683. rerank_status = await rag.get_rerank_queue_status()
  684. assert embedding_status["available"] is True
  685. assert embedding_status["queue_name"] == "Embedding func"
  686. assert embedding_status["max_async"] == rag.embedding_func_max_async
  687. assert rerank_status["available"] is True
  688. assert rerank_status["queue_name"] == "Rerank func"
  689. assert rerank_status["max_async"] == rag.rerank_model_max_async
  690. def test_get_llm_role_config_strips_bedrock_and_password_fields(tmp_path):
  691. rag = _make_rag(tmp_path)
  692. rag.set_role_llm_metadata(
  693. "query",
  694. binding="bedrock",
  695. model="claude-3",
  696. password="proxy-password",
  697. provider_options={
  698. "temperature": 0.1,
  699. "extra_body": {
  700. "safe_option": True,
  701. "api_key": "nested-api-key",
  702. "headers": {
  703. "Authorization": "Bearer nested-token",
  704. "X-API-Key": "nested-api-key",
  705. "Accept": "application/json",
  706. },
  707. "tools": [
  708. {"name": "safe-tool", "token": "nested-token"},
  709. ],
  710. },
  711. },
  712. bedrock_aws_options={
  713. "region_name": "us-east-1",
  714. "aws_access_key_id": "AKIA-secret",
  715. "aws_secret_access_key": "TOPSECRET",
  716. "aws_session_token": "SESSION",
  717. },
  718. )
  719. snapshot = rag.get_llm_role_config("query")
  720. assert "password" not in snapshot["metadata"]
  721. provider_options = snapshot["metadata"]["provider_options"]
  722. assert provider_options["temperature"] == 0.1
  723. extra_body = provider_options["extra_body"]
  724. assert extra_body["safe_option"] is True
  725. assert "api_key" not in extra_body
  726. assert extra_body["headers"] == {"Accept": "application/json"}
  727. assert extra_body["tools"] == [{"name": "safe-tool"}]
  728. bedrock = snapshot["metadata"]["bedrock_aws_options"]
  729. # Non-secret fields stay; auth-bearing fields are removed entirely.
  730. assert bedrock["region_name"] == "us-east-1"
  731. assert "aws_access_key_id" not in bedrock
  732. assert "aws_secret_access_key" not in bedrock
  733. assert "aws_session_token" not in bedrock
  734. # Mutating the returned snapshot must not affect the live state.
  735. snapshot["metadata"]["bedrock_aws_options"]["region_name"] = "tampered"
  736. assert (
  737. rag._role_llm_states["query"].metadata["bedrock_aws_options"]["region_name"]
  738. == "us-east-1"
  739. )
  740. def test_get_llm_role_config_has_no_secret_escape_hatch(tmp_path):
  741. """Security guarantee: no parameter on get_llm_role_config can flip
  742. secret stripping off. This pins down the public-API contract so a future
  743. change can't accidentally re-introduce an ``include_secrets`` knob."""
  744. rag = _make_rag(tmp_path)
  745. rag.set_role_llm_metadata("query", api_key="super-secret")
  746. with pytest.raises(TypeError):
  747. rag.get_llm_role_config("query", include_secrets=True) # type: ignore[call-arg]
  748. assert "api_key" not in rag.get_llm_role_config("query")["metadata"]
  749. @pytest.mark.asyncio
  750. async def test_cross_provider_update_does_not_inherit_base_kwargs(tmp_path):
  751. built_calls = []
  752. def builder(role: str, meta: dict):
  753. async def built_func(*args, **kwargs):
  754. built_calls.append(
  755. {"role": role, "meta": dict(meta), "kwargs": dict(kwargs)}
  756. )
  757. return "ok"
  758. return built_func, None
  759. rag = _make_rag(
  760. tmp_path,
  761. llm_model_kwargs={
  762. "host": "http://base-host:11434",
  763. "options": {"temperature": 0.1},
  764. "api_key": "base-key",
  765. },
  766. )
  767. rag.register_role_llm_builder(builder)
  768. rag.set_role_llm_metadata(
  769. "query",
  770. base_binding="ollama",
  771. binding="ollama",
  772. model="base-ollama",
  773. host="http://base-host:11434",
  774. api_key="base-key",
  775. provider_options={"temperature": 0.1},
  776. is_cross_provider=False,
  777. )
  778. rag.update_llm_role_config(
  779. "query",
  780. binding="openai",
  781. model="gpt-4o-mini",
  782. host="https://api.example.com/v1",
  783. api_key="role-key",
  784. provider_options={"temperature": 0.4},
  785. )
  786. await rag.role_llm_funcs["query"]("hello")
  787. call_kwargs = built_calls[-1]["kwargs"]
  788. assert call_kwargs["hashing_kv"] is not None
  789. assert "host" not in call_kwargs
  790. assert "options" not in call_kwargs
  791. assert "api_key" not in call_kwargs
  792. @pytest.mark.asyncio
  793. async def test_update_llm_role_config_rolls_back_on_failure(
  794. tmp_path, caplog, lightrag_logger_propagating
  795. ):
  796. rag = _make_rag(tmp_path, extract_llm_model_kwargs={"tag": "before"})
  797. original_raw = rag._role_llm_states["extract"].raw_func
  798. original_wrapped = rag.role_llm_funcs["extract"]
  799. original_kwargs = dict(rag.role_llm_kwargs["extract"])
  800. def failing_builder(role: str, meta: dict):
  801. raise RuntimeError("boom")
  802. rag.register_role_llm_builder(failing_builder)
  803. rag.set_role_llm_metadata(
  804. "extract",
  805. binding="openai",
  806. model="base-model",
  807. host="https://base",
  808. api_key="key",
  809. provider_options={"temperature": 0.1},
  810. )
  811. caplog.clear()
  812. with caplog.at_level("INFO", logger="lightrag"):
  813. with pytest.raises(RuntimeError, match="boom"):
  814. rag.update_llm_role_config(
  815. "extract",
  816. binding="gemini",
  817. provider_options={"temperature": 0.9},
  818. )
  819. assert rag._role_llm_states["extract"].raw_func is original_raw
  820. assert rag.role_llm_funcs["extract"] is original_wrapped
  821. assert rag.role_llm_kwargs["extract"] == original_kwargs
  822. assert not _role_config_headers(caplog)
  823. def test_options_dict_for_role_inherits_same_provider(monkeypatch):
  824. args = Namespace(
  825. openai_llm_temperature=0.2,
  826. openai_llm_top_p=0.8,
  827. openai_llm_extra_body={"base": True},
  828. )
  829. _clear_role_provider_env(monkeypatch, "extract", OpenAILLMOptions)
  830. monkeypatch.setenv("EXTRACT_OPENAI_LLM_TEMPERATURE", "0.7")
  831. options = OpenAILLMOptions.options_dict_for_role(args, "extract")
  832. assert options["temperature"] == 0.7
  833. assert options["top_p"] == 0.8
  834. assert options["extra_body"] == {"base": True}
  835. def test_options_dict_for_role_resets_cross_provider(monkeypatch):
  836. args = Namespace(
  837. openai_llm_temperature=0.2,
  838. openai_llm_top_p=0.8,
  839. openai_llm_extra_body={"base": True},
  840. )
  841. _clear_role_provider_env(monkeypatch, "query", OpenAILLMOptions)
  842. monkeypatch.setenv("QUERY_OPENAI_LLM_TOP_P", "0.6")
  843. options = OpenAILLMOptions.options_dict_for_role(
  844. args, "query", is_cross_provider=True
  845. )
  846. assert options == {"top_p": 0.6}
  847. def test_options_dict_for_role_parses_nested_extra_body_cross_provider(monkeypatch):
  848. args = Namespace(openai_llm_extra_body={"base": True})
  849. _clear_role_provider_env(monkeypatch, "keyword", OpenAILLMOptions)
  850. monkeypatch.setenv(
  851. "KEYWORD_OPENAI_LLM_EXTRA_BODY",
  852. '{"chat_template_kwargs": {"enable_thinking": false}}',
  853. )
  854. options = OpenAILLMOptions.options_dict_for_role(
  855. args, "keyword", is_cross_provider=True
  856. )
  857. assert options["extra_body"] == {"chat_template_kwargs": {"enable_thinking": False}}
  858. @pytest.mark.asyncio
  859. async def test_vlm_role_supports_runtime_update(tmp_path):
  860. vlm_calls = []
  861. async def vlm_func(*args, **kwargs):
  862. vlm_calls.append(kwargs)
  863. return "vlm"
  864. rag = _make_rag(
  865. tmp_path,
  866. vlm_llm_model_func=vlm_func,
  867. vlm_llm_model_kwargs={"tag": "initial"},
  868. )
  869. await rag.role_llm_funcs["vlm"]("before")
  870. rag.update_llm_role_config(
  871. "vlm",
  872. model_kwargs={"tag": "updated"},
  873. max_async=2,
  874. timeout=240,
  875. )
  876. await rag.role_llm_funcs["vlm"]("after")
  877. assert vlm_calls[0]["tag"] == "initial"
  878. assert vlm_calls[1]["tag"] == "updated"
  879. assert rag._role_llm_states["vlm"].max_async == 2
  880. assert rag._role_llm_states["vlm"].timeout == 240