test_bedrock_llm.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. import importlib
  2. import logging
  3. import os
  4. import sys
  5. from types import SimpleNamespace
  6. from unittest.mock import AsyncMock, Mock, patch
  7. import pytest
  8. from fastapi import APIRouter
  9. from fastapi.testclient import TestClient
  10. from lightrag.llm.bedrock import (
  11. bedrock_complete,
  12. bedrock_complete_if_cache,
  13. bedrock_embed,
  14. )
  15. def _reload_api_modules_if_mocked() -> None:
  16. """Drop Mock-replaced lightrag.api entries so importlib reloads the real modules.
  17. Other test files (e.g. test_token_auto_renewal.py) replace
  18. ``sys.modules["lightrag.api.config"]`` with a Mock at import time. When
  19. pytest collects those files before ours, any subsequent
  20. ``from .config import global_args`` inside lightrag_server picks up the
  21. Mock, which breaks ``create_app`` in create_app_* tests below.
  22. """
  23. for modname in (
  24. "lightrag.api.lightrag_server",
  25. "lightrag.api.auth",
  26. "lightrag.api.config",
  27. ):
  28. if isinstance(sys.modules.get(modname), Mock):
  29. sys.modules.pop(modname, None)
  30. class _FakeBedrockClient:
  31. def __init__(self, captured_calls: list[dict]):
  32. self._captured_calls = captured_calls
  33. async def __aenter__(self):
  34. return self
  35. async def __aexit__(self, exc_type, exc, tb):
  36. return False
  37. async def converse(self, **kwargs):
  38. self._captured_calls.append(kwargs)
  39. return {
  40. "output": {
  41. "message": {
  42. "content": [
  43. {
  44. "text": '{"high_level_keywords":["AI"],"low_level_keywords":["RAG"]}'
  45. }
  46. ]
  47. }
  48. }
  49. }
  50. class _FakeSession:
  51. def __init__(self, captured_calls: list[dict], client_kwargs_calls: list[dict]):
  52. self._captured_calls = captured_calls
  53. self._client_kwargs_calls = client_kwargs_calls
  54. def client(self, *_args, **kwargs):
  55. self._client_kwargs_calls.append(dict(kwargs))
  56. return _FakeBedrockClient(self._captured_calls)
  57. class _FakeReasoningClient(_FakeBedrockClient):
  58. async def converse(self, **kwargs):
  59. self._captured_calls.append(kwargs)
  60. return {
  61. "output": {
  62. "message": {
  63. "content": [
  64. {
  65. "reasoningContent": {
  66. "reasoningText": {"text": "internal thought"}
  67. }
  68. },
  69. {"text": "final answer"},
  70. ]
  71. }
  72. }
  73. }
  74. class _FakeReasoningSession(_FakeSession):
  75. def client(self, *_args, **kwargs):
  76. self._client_kwargs_calls.append(dict(kwargs))
  77. return _FakeReasoningClient(self._captured_calls)
  78. @pytest.mark.offline
  79. @pytest.mark.asyncio
  80. async def test_bedrock_complete_skips_reasoning_content_block(monkeypatch):
  81. monkeypatch.delenv("AWS_REGION", raising=False)
  82. captured_calls: list[dict] = []
  83. with patch(
  84. "lightrag.llm.bedrock.aioboto3.Session",
  85. return_value=_FakeReasoningSession(captured_calls, []),
  86. ):
  87. result = await bedrock_complete_if_cache(
  88. model="bedrock-model",
  89. prompt="hello",
  90. extra_fields={"reasoning_config": {"type": "enabled"}},
  91. )
  92. assert result == "final answer"
  93. @pytest.mark.offline
  94. @pytest.mark.asyncio
  95. async def test_bedrock_complete_forwards_keyword_extraction_to_if_cache():
  96. hashing_kv = SimpleNamespace(global_config={"llm_model_name": "bedrock-model"})
  97. with patch(
  98. "lightrag.llm.bedrock.bedrock_complete_if_cache",
  99. AsyncMock(return_value="{}"),
  100. ) as mocked_complete:
  101. await bedrock_complete(
  102. prompt="hello",
  103. hashing_kv=hashing_kv,
  104. keyword_extraction=True,
  105. )
  106. assert mocked_complete.await_args.kwargs["keyword_extraction"] is True
  107. @pytest.mark.offline
  108. @pytest.mark.asyncio
  109. async def test_bedrock_keyword_extraction_does_not_inject_system_prompt(monkeypatch):
  110. captured_calls: list[dict] = []
  111. client_kwargs_calls: list[dict] = []
  112. monkeypatch.delenv("AWS_REGION", raising=False)
  113. with patch(
  114. "lightrag.llm.bedrock.aioboto3.Session",
  115. return_value=_FakeSession(captured_calls, client_kwargs_calls),
  116. ):
  117. result = await bedrock_complete_if_cache(
  118. model="bedrock-model",
  119. prompt="hello",
  120. response_format={"type": "json_object"},
  121. )
  122. assert result == '{"high_level_keywords":["AI"],"low_level_keywords":["RAG"]}'
  123. assert len(captured_calls) == 1
  124. assert "system" not in captured_calls[0]
  125. assert client_kwargs_calls[-1] == {"region_name": None}
  126. @pytest.mark.offline
  127. @pytest.mark.asyncio
  128. async def test_bedrock_default_endpoint_sentinel_uses_sdk_default(monkeypatch):
  129. captured_calls: list[dict] = []
  130. client_kwargs_calls: list[dict] = []
  131. monkeypatch.delenv("AWS_REGION", raising=False)
  132. with patch(
  133. "lightrag.llm.bedrock.aioboto3.Session",
  134. return_value=_FakeSession(captured_calls, client_kwargs_calls),
  135. ):
  136. await bedrock_complete_if_cache(
  137. model="bedrock-model",
  138. prompt="hello",
  139. endpoint_url="DEFAULT_BEDROCK_ENDPOINT",
  140. )
  141. assert client_kwargs_calls[-1] == {"region_name": None}
  142. @pytest.mark.offline
  143. @pytest.mark.asyncio
  144. async def test_bedrock_empty_endpoint_url_uses_sdk_default(monkeypatch):
  145. captured_calls: list[dict] = []
  146. client_kwargs_calls: list[dict] = []
  147. monkeypatch.delenv("AWS_REGION", raising=False)
  148. with patch(
  149. "lightrag.llm.bedrock.aioboto3.Session",
  150. return_value=_FakeSession(captured_calls, client_kwargs_calls),
  151. ):
  152. await bedrock_complete_if_cache(
  153. model="bedrock-model",
  154. prompt="hello",
  155. endpoint_url="",
  156. )
  157. assert client_kwargs_calls[-1] == {"region_name": None}
  158. @pytest.mark.offline
  159. @pytest.mark.asyncio
  160. async def test_bedrock_custom_endpoint_url_is_forwarded(monkeypatch):
  161. captured_calls: list[dict] = []
  162. client_kwargs_calls: list[dict] = []
  163. monkeypatch.delenv("AWS_REGION", raising=False)
  164. with patch(
  165. "lightrag.llm.bedrock.aioboto3.Session",
  166. return_value=_FakeSession(captured_calls, client_kwargs_calls),
  167. ):
  168. await bedrock_complete_if_cache(
  169. model="bedrock-model",
  170. prompt="hello",
  171. endpoint_url="https://proxy.example.com",
  172. )
  173. assert client_kwargs_calls[-1] == {
  174. "region_name": None,
  175. "endpoint_url": "https://proxy.example.com",
  176. }
  177. class _FakeEmbeddingBody:
  178. async def json(self):
  179. return {"embedding": [0.1] * 1024}
  180. class _FakeEmbeddingResponse:
  181. def get(self, key):
  182. assert key == "body"
  183. return _FakeEmbeddingBody()
  184. class _FakeEmbeddingClient(_FakeBedrockClient):
  185. async def invoke_model(self, **_kwargs):
  186. return _FakeEmbeddingResponse()
  187. class _FakeEmbeddingSession(_FakeSession):
  188. def client(self, *_args, **kwargs):
  189. self._client_kwargs_calls.append(dict(kwargs))
  190. return _FakeEmbeddingClient(self._captured_calls)
  191. @pytest.mark.offline
  192. @pytest.mark.asyncio
  193. async def test_bedrock_embed_custom_endpoint_url_is_forwarded(monkeypatch):
  194. captured_calls: list[dict] = []
  195. client_kwargs_calls: list[dict] = []
  196. monkeypatch.delenv("AWS_REGION", raising=False)
  197. with patch(
  198. "lightrag.llm.bedrock.aioboto3.Session",
  199. return_value=_FakeEmbeddingSession(captured_calls, client_kwargs_calls),
  200. ):
  201. await bedrock_embed(
  202. texts=["hello"],
  203. endpoint_url="https://proxy.example.com",
  204. )
  205. assert client_kwargs_calls[-1] == {
  206. "region_name": None,
  207. "endpoint_url": "https://proxy.example.com",
  208. }
  209. @pytest.mark.offline
  210. @pytest.mark.asyncio
  211. async def test_bedrock_embed_default_endpoint_sentinel_uses_sdk_default(monkeypatch):
  212. captured_calls: list[dict] = []
  213. client_kwargs_calls: list[dict] = []
  214. monkeypatch.delenv("AWS_REGION", raising=False)
  215. with patch(
  216. "lightrag.llm.bedrock.aioboto3.Session",
  217. return_value=_FakeEmbeddingSession(captured_calls, client_kwargs_calls),
  218. ):
  219. await bedrock_embed(
  220. texts=["hello"],
  221. endpoint_url="DEFAULT_BEDROCK_ENDPOINT",
  222. )
  223. assert client_kwargs_calls[-1] == {"region_name": None}
  224. @pytest.mark.offline
  225. @pytest.mark.asyncio
  226. async def test_bedrock_embed_empty_endpoint_url_uses_sdk_default(monkeypatch):
  227. captured_calls: list[dict] = []
  228. client_kwargs_calls: list[dict] = []
  229. monkeypatch.delenv("AWS_REGION", raising=False)
  230. with patch(
  231. "lightrag.llm.bedrock.aioboto3.Session",
  232. return_value=_FakeEmbeddingSession(captured_calls, client_kwargs_calls),
  233. ):
  234. await bedrock_embed(
  235. texts=["hello"],
  236. endpoint_url="",
  237. )
  238. assert client_kwargs_calls[-1] == {"region_name": None}
  239. @pytest.mark.offline
  240. @pytest.mark.asyncio
  241. async def test_bedrock_complete_forwards_explicit_sigv4_client_kwargs(monkeypatch):
  242. monkeypatch.delenv("AWS_REGION", raising=False)
  243. captured_calls: list[dict] = []
  244. client_kwargs_calls: list[dict] = []
  245. with patch(
  246. "lightrag.llm.bedrock.aioboto3.Session",
  247. return_value=_FakeSession(captured_calls, client_kwargs_calls),
  248. ):
  249. await bedrock_complete_if_cache(
  250. model="bedrock-model",
  251. prompt="hello",
  252. aws_region="us-west-2",
  253. aws_access_key_id="akid",
  254. aws_secret_access_key="secret",
  255. aws_session_token="session",
  256. endpoint_url="https://proxy.example.com",
  257. )
  258. assert client_kwargs_calls[-1] == {
  259. "region_name": "us-west-2",
  260. "endpoint_url": "https://proxy.example.com",
  261. "aws_access_key_id": "akid",
  262. "aws_secret_access_key": "secret",
  263. "aws_session_token": "session",
  264. }
  265. @pytest.mark.offline
  266. @pytest.mark.asyncio
  267. async def test_bedrock_extra_fields_maps_to_additional_model_request_fields(
  268. monkeypatch,
  269. ):
  270. monkeypatch.delenv("AWS_REGION", raising=False)
  271. captured_calls: list[dict] = []
  272. with patch(
  273. "lightrag.llm.bedrock.aioboto3.Session",
  274. return_value=_FakeSession(captured_calls, []),
  275. ):
  276. await bedrock_complete_if_cache(
  277. model="bedrock-model",
  278. prompt="hello",
  279. extra_fields={"reasoning_config": {"type": "enabled"}},
  280. )
  281. assert captured_calls[-1]["additionalModelRequestFields"] == {
  282. "reasoning_config": {"type": "enabled"}
  283. }
  284. @pytest.mark.offline
  285. @pytest.mark.asyncio
  286. async def test_bedrock_empty_extra_fields_is_dropped(monkeypatch):
  287. monkeypatch.delenv("AWS_REGION", raising=False)
  288. captured_calls: list[dict] = []
  289. with patch(
  290. "lightrag.llm.bedrock.aioboto3.Session",
  291. return_value=_FakeSession(captured_calls, []),
  292. ):
  293. await bedrock_complete_if_cache(
  294. model="bedrock-model",
  295. prompt="hello",
  296. extra_fields=None,
  297. )
  298. await bedrock_complete_if_cache(
  299. model="bedrock-model",
  300. prompt="hello",
  301. extra_fields={},
  302. )
  303. for call in captured_calls:
  304. assert "additionalModelRequestFields" not in call
  305. @pytest.mark.offline
  306. @pytest.mark.asyncio
  307. async def test_bedrock_api_key_is_ignored_and_does_not_mutate_env(monkeypatch):
  308. monkeypatch.delenv("AWS_REGION", raising=False)
  309. monkeypatch.setenv("AWS_BEARER_TOKEN_BEDROCK", "absk-from-env")
  310. monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
  311. monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
  312. monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False)
  313. with patch(
  314. "lightrag.llm.bedrock.aioboto3.Session",
  315. return_value=_FakeSession([], []),
  316. ):
  317. with pytest.warns(DeprecationWarning, match="api_key=.*ignored"):
  318. await bedrock_complete_if_cache(
  319. model="bedrock-model",
  320. prompt="hello",
  321. api_key="absk-should-be-ignored",
  322. aws_access_key_id="akid",
  323. aws_secret_access_key="secret",
  324. aws_session_token="session",
  325. )
  326. assert os.environ.get("AWS_BEARER_TOKEN_BEDROCK") == "absk-from-env"
  327. assert os.environ.get("AWS_ACCESS_KEY_ID") is None
  328. assert os.environ.get("AWS_SECRET_ACCESS_KEY") is None
  329. assert os.environ.get("AWS_SESSION_TOKEN") is None
  330. @pytest.mark.offline
  331. @pytest.mark.asyncio
  332. async def test_bedrock_embed_forwards_sigv4_and_ignores_api_key(monkeypatch):
  333. monkeypatch.delenv("AWS_REGION", raising=False)
  334. monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False)
  335. client_kwargs_calls: list[dict] = []
  336. with patch(
  337. "lightrag.llm.bedrock.aioboto3.Session",
  338. return_value=_FakeEmbeddingSession([], client_kwargs_calls),
  339. ):
  340. with pytest.warns(DeprecationWarning, match="api_key=.*ignored"):
  341. await bedrock_embed(
  342. texts=["hello"],
  343. api_key="absk-embedding-key",
  344. aws_region="us-east-1",
  345. aws_access_key_id="akid",
  346. aws_secret_access_key="secret",
  347. aws_session_token="session",
  348. )
  349. assert client_kwargs_calls[-1] == {
  350. "region_name": "us-east-1",
  351. "aws_access_key_id": "akid",
  352. "aws_secret_access_key": "secret",
  353. "aws_session_token": "session",
  354. }
  355. assert os.environ.get("AWS_BEARER_TOKEN_BEDROCK") is None
  356. @pytest.mark.offline
  357. def test_bedrock_auth_docstrings_describe_generic_api_key_behavior():
  358. assert "AWS_BEARER_TOKEN_BEDROCK" in bedrock_complete_if_cache.__doc__
  359. assert "LLM_BINDING_API_KEY" in bedrock_complete_if_cache.__doc__
  360. assert "EMBEDDING_BINDING_API_KEY" in bedrock_embed.func.__doc__
  361. class _FakeLightRAG:
  362. last_init_kwargs = None
  363. last_instance = None
  364. def __init__(self, **kwargs):
  365. type(self).last_init_kwargs = dict(kwargs)
  366. type(self).last_instance = self
  367. self.role_config_snapshot = {}
  368. for role, cfg in (kwargs.get("role_llm_configs") or {}).items():
  369. metadata = dict(getattr(cfg, "metadata", None) or {})
  370. self.role_config_snapshot[role] = {
  371. "binding": metadata.get("binding"),
  372. "model": metadata.get("model"),
  373. "host": metadata.get("host"),
  374. "is_cross_provider": metadata.get("is_cross_provider", False),
  375. "max_async": getattr(cfg, "max_async", None),
  376. "timeout": getattr(cfg, "timeout", None),
  377. "has_model_kwargs": getattr(cfg, "kwargs", None) is not None,
  378. "metadata": metadata,
  379. }
  380. self.queue_status_snapshot = {}
  381. self.embedding_queue_status_snapshot = {}
  382. self.rerank_queue_status_snapshot = {}
  383. def register_role_llm_builder(self, _builder) -> None:
  384. return None
  385. def set_role_llm_metadata(self, _role: str, **_metadata) -> None:
  386. return None
  387. def get_llm_role_config(self):
  388. return self.role_config_snapshot
  389. async def get_llm_queue_status(self, include_base=True):
  390. return self.queue_status_snapshot
  391. async def get_embedding_queue_status(self):
  392. return self.embedding_queue_status_snapshot
  393. async def get_rerank_queue_status(self):
  394. return self.rerank_queue_status_snapshot
  395. class _FakeOllamaAPI:
  396. def __init__(self, *_args, **_kwargs):
  397. self.router = APIRouter()
  398. def _make_args(tmp_path) -> SimpleNamespace:
  399. return SimpleNamespace(
  400. host="127.0.0.1",
  401. port=9621,
  402. log_level="INFO",
  403. verbose=False,
  404. cors_origins="*",
  405. whitelist_paths="/health,/api/*",
  406. auth_accounts="",
  407. token_secret=None,
  408. token_expire_hours=48,
  409. guest_token_expire_hours=24,
  410. jwt_algorithm="HS256",
  411. token_auto_renew=True,
  412. token_renew_threshold=0.5,
  413. llm_binding="bedrock",
  414. embedding_binding="bedrock",
  415. llm_binding_host="DEFAULT_BEDROCK_ENDPOINT",
  416. embedding_binding_host="DEFAULT_BEDROCK_ENDPOINT",
  417. ssl=False,
  418. ssl_certfile=None,
  419. ssl_keyfile=None,
  420. key=None,
  421. input_dir=str(tmp_path / "inputs"),
  422. workspace="",
  423. working_dir=str(tmp_path / "rag_storage"),
  424. llm_binding_api_key=None,
  425. embedding_binding_api_key="",
  426. aws_region="us-east-1",
  427. aws_access_key_id="global-akid",
  428. aws_secret_access_key="global-secret",
  429. aws_session_token="global-session",
  430. query_aws_region=None,
  431. query_aws_access_key_id=None,
  432. query_aws_secret_access_key=None,
  433. query_aws_session_token=None,
  434. llm_model="us.amazon.nova-lite-v1:0",
  435. embedding_model=None,
  436. embedding_dim=None,
  437. embedding_send_dim=False,
  438. embedding_token_limit=None,
  439. embedding_document_prefix=None,
  440. embedding_document_prefix_configured=False,
  441. embedding_query_prefix=None,
  442. embedding_query_prefix_configured=False,
  443. embedding_prefix_no_prefix_sentinel="NO_PREFIX",
  444. embedding_prefixes_configured=False,
  445. embedding_asymmetric=False,
  446. embedding_asymmetric_configured=False,
  447. max_async=4,
  448. summary_max_tokens=512,
  449. summary_context_size=4096,
  450. force_llm_summary_on_merge=8,
  451. chunk_size=1200,
  452. chunk_overlap_size=100,
  453. kv_storage="JsonKVStorage",
  454. graph_storage="NetworkXStorage",
  455. vector_storage="NanoVectorDBStorage",
  456. doc_status_storage="JsonDocStatusStorage",
  457. cosine_threshold=0.2,
  458. enable_llm_cache_for_extract=True,
  459. enable_llm_cache=True,
  460. vlm_process_enable=False,
  461. max_parallel_insert=2,
  462. max_graph_nodes=1000,
  463. simulated_model_name="lightrag",
  464. simulated_model_tag="latest",
  465. summary_language="English",
  466. rerank_binding="null",
  467. rerank_model=None,
  468. rerank_binding_host=None,
  469. rerank_binding_api_key=None,
  470. embedding_func_max_async=8,
  471. embedding_batch_num=10,
  472. min_rerank_score=0.0,
  473. related_chunk_number=5,
  474. top_k=10,
  475. llm_timeout=180,
  476. embedding_timeout=30,
  477. rerank_max_async=4,
  478. rerank_timeout=30,
  479. )
  480. @pytest.mark.offline
  481. @pytest.mark.asyncio
  482. async def test_create_app_query_role_uses_bedrock_binding(tmp_path, monkeypatch):
  483. _reload_api_modules_if_mocked()
  484. monkeypatch.setattr(sys, "argv", ["pytest"])
  485. config = importlib.import_module("lightrag.api.config")
  486. config.initialize_config(_make_args(tmp_path), force=True)
  487. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  488. monkeypatch.setattr(lightrag_server, "LightRAG", _FakeLightRAG)
  489. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  490. monkeypatch.setattr(
  491. lightrag_server, "create_document_routes", lambda *_args, **_kwargs: APIRouter()
  492. )
  493. monkeypatch.setattr(
  494. lightrag_server, "create_query_routes", lambda *_args, **_kwargs: APIRouter()
  495. )
  496. monkeypatch.setattr(
  497. lightrag_server, "create_graph_routes", lambda *_args, **_kwargs: APIRouter()
  498. )
  499. monkeypatch.setattr(lightrag_server, "OllamaAPI", _FakeOllamaAPI)
  500. args = _make_args(tmp_path)
  501. with (
  502. patch(
  503. "lightrag.llm.bedrock.bedrock_complete_if_cache",
  504. AsyncMock(return_value="bedrock-ok"),
  505. ) as mocked_bedrock,
  506. patch(
  507. "lightrag.llm.openai.openai_complete_if_cache",
  508. AsyncMock(side_effect=AssertionError("OpenAI fallback should not be used")),
  509. ) as mocked_openai,
  510. ):
  511. lightrag_server.create_app(args)
  512. query_cfg = _FakeLightRAG.last_init_kwargs["role_llm_configs"]["query"]
  513. query_func = query_cfg.func
  514. result = await query_func("hello")
  515. assert query_cfg.metadata["binding"] == "bedrock"
  516. assert query_cfg.metadata["model"] == "us.amazon.nova-lite-v1:0"
  517. assert query_cfg.metadata["host"] == "DEFAULT_BEDROCK_ENDPOINT"
  518. assert query_cfg.metadata["api_key"] is None
  519. assert query_cfg.metadata["bedrock_aws_options"]["aws_region"] == "us-east-1"
  520. assert result == "bedrock-ok"
  521. assert mocked_openai.await_count == 0
  522. assert mocked_bedrock.await_count == 1
  523. assert mocked_bedrock.await_args.args[:2] == ("us.amazon.nova-lite-v1:0", "hello")
  524. assert "api_key" not in mocked_bedrock.await_args.kwargs
  525. assert (
  526. mocked_bedrock.await_args.kwargs["endpoint_url"] == "DEFAULT_BEDROCK_ENDPOINT"
  527. )
  528. assert mocked_bedrock.await_args.kwargs["aws_region"] == "us-east-1"
  529. assert mocked_bedrock.await_args.kwargs["aws_access_key_id"] == "global-akid"
  530. @pytest.mark.offline
  531. @pytest.mark.asyncio
  532. async def test_create_app_bedrock_query_role_uses_role_sigv4_credentials(
  533. tmp_path, monkeypatch
  534. ):
  535. _reload_api_modules_if_mocked()
  536. monkeypatch.setattr(sys, "argv", ["pytest"])
  537. config = importlib.import_module("lightrag.api.config")
  538. config.initialize_config(_make_args(tmp_path), force=True)
  539. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  540. monkeypatch.setattr(lightrag_server, "LightRAG", _FakeLightRAG)
  541. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  542. monkeypatch.setattr(
  543. lightrag_server, "create_document_routes", lambda *_args, **_kwargs: APIRouter()
  544. )
  545. monkeypatch.setattr(
  546. lightrag_server, "create_query_routes", lambda *_args, **_kwargs: APIRouter()
  547. )
  548. monkeypatch.setattr(
  549. lightrag_server, "create_graph_routes", lambda *_args, **_kwargs: APIRouter()
  550. )
  551. monkeypatch.setattr(lightrag_server, "OllamaAPI", _FakeOllamaAPI)
  552. args = _make_args(tmp_path)
  553. args.query_aws_region = "us-west-2"
  554. args.query_aws_access_key_id = "query-akid"
  555. args.query_aws_secret_access_key = "query-secret"
  556. args.query_aws_session_token = "query-session"
  557. with patch(
  558. "lightrag.llm.bedrock.bedrock_complete_if_cache",
  559. AsyncMock(return_value="bedrock-ok"),
  560. ) as mocked_bedrock:
  561. lightrag_server.create_app(args)
  562. query_func = _FakeLightRAG.last_init_kwargs["role_llm_configs"]["query"].func
  563. await query_func("hello")
  564. assert mocked_bedrock.await_args.kwargs["aws_region"] == "us-west-2"
  565. assert mocked_bedrock.await_args.kwargs["aws_access_key_id"] == "query-akid"
  566. assert mocked_bedrock.await_args.kwargs["aws_secret_access_key"] == "query-secret"
  567. assert mocked_bedrock.await_args.kwargs["aws_session_token"] == "query-session"
  568. @pytest.mark.offline
  569. @pytest.mark.asyncio
  570. async def test_create_app_keyword_openai_role_forwards_nested_extra_body(
  571. tmp_path, monkeypatch, caplog
  572. ):
  573. _reload_api_modules_if_mocked()
  574. monkeypatch.setattr(sys, "argv", ["pytest"])
  575. monkeypatch.setattr(logging.getLogger("lightrag"), "propagate", True)
  576. monkeypatch.setenv(
  577. "KEYWORD_OPENAI_LLM_EXTRA_BODY",
  578. '{"chat_template_kwargs": {"enable_thinking": false}}',
  579. )
  580. config = importlib.import_module("lightrag.api.config")
  581. config.initialize_config(_make_args(tmp_path), force=True)
  582. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  583. monkeypatch.setattr(lightrag_server, "LightRAG", _FakeLightRAG)
  584. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  585. monkeypatch.setattr(
  586. lightrag_server, "create_document_routes", lambda *_args, **_kwargs: APIRouter()
  587. )
  588. monkeypatch.setattr(
  589. lightrag_server, "create_query_routes", lambda *_args, **_kwargs: APIRouter()
  590. )
  591. monkeypatch.setattr(
  592. lightrag_server, "create_graph_routes", lambda *_args, **_kwargs: APIRouter()
  593. )
  594. monkeypatch.setattr(lightrag_server, "OllamaAPI", _FakeOllamaAPI)
  595. args = _make_args(tmp_path)
  596. args.keyword_llm_binding = "openai"
  597. args.keyword_llm_model = "xhd/Qwen3.5-35B-A3B"
  598. args.keyword_llm_binding_host = "https://keyword.example/v1"
  599. args.keyword_llm_binding_api_key = "keyword-secret"
  600. with (
  601. caplog.at_level("INFO", logger="lightrag"),
  602. patch(
  603. "lightrag.llm.openai.openai_complete_if_cache",
  604. AsyncMock(
  605. return_value='{"high_level_keywords":[],"low_level_keywords":[]}'
  606. ),
  607. ) as mocked_openai,
  608. ):
  609. lightrag_server.create_app(args)
  610. keyword_cfg = _FakeLightRAG.last_init_kwargs["role_llm_configs"]["keyword"]
  611. result = await keyword_cfg.func(
  612. "keyword prompt", response_format={"type": "json_object"}
  613. )
  614. assert result == '{"high_level_keywords":[],"low_level_keywords":[]}'
  615. assert keyword_cfg.metadata["binding"] == "openai"
  616. assert keyword_cfg.metadata["provider_options"]["extra_body"] == {
  617. "chat_template_kwargs": {"enable_thinking": False}
  618. }
  619. assert mocked_openai.await_count == 1
  620. assert mocked_openai.await_args.args[:2] == (
  621. "xhd/Qwen3.5-35B-A3B",
  622. "keyword prompt",
  623. )
  624. kwargs = mocked_openai.await_args.kwargs
  625. assert kwargs["base_url"] == "https://keyword.example/v1"
  626. assert kwargs["api_key"] == "keyword-secret"
  627. assert kwargs["response_format"] == {"type": "json_object"}
  628. assert kwargs["extra_body"] == {"chat_template_kwargs": {"enable_thinking": False}}
  629. messages = "\n".join(record.getMessage() for record in caplog.records)
  630. assert "Role LLM Option:" in messages
  631. assert " - extract: Bedrock {}" in messages
  632. assert " - keyword: OpenAI {'extra_body':" in messages
  633. assert " - query: Bedrock {}" in messages
  634. assert " - vlm: Bedrock {}" in messages
  635. assert "chat_template_kwargs" in messages
  636. assert "reasoning_effort" not in messages
  637. assert "frequency_penalty" not in messages
  638. assert "keyword-secret" not in messages
  639. @pytest.mark.offline
  640. def test_create_app_rejects_bedrock_role_api_key(tmp_path, monkeypatch):
  641. _reload_api_modules_if_mocked()
  642. monkeypatch.setattr(sys, "argv", ["pytest"])
  643. config = importlib.import_module("lightrag.api.config")
  644. config.initialize_config(_make_args(tmp_path), force=True)
  645. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  646. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  647. args = _make_args(tmp_path)
  648. args.query_llm_binding_api_key = "absk-role"
  649. with pytest.raises(ValueError, match="does not support role-specific"):
  650. lightrag_server.create_app(args)
  651. @pytest.mark.offline
  652. def test_health_role_llm_config_uses_runtime_snapshot(tmp_path, monkeypatch):
  653. _reload_api_modules_if_mocked()
  654. monkeypatch.setattr(sys, "argv", ["pytest"])
  655. config = importlib.import_module("lightrag.api.config")
  656. config.initialize_config(_make_args(tmp_path), force=True)
  657. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  658. monkeypatch.setattr(lightrag_server, "LightRAG", _FakeLightRAG)
  659. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  660. monkeypatch.setattr(
  661. lightrag_server, "create_document_routes", lambda *_args, **_kwargs: APIRouter()
  662. )
  663. monkeypatch.setattr(
  664. lightrag_server, "create_query_routes", lambda *_args, **_kwargs: APIRouter()
  665. )
  666. monkeypatch.setattr(
  667. lightrag_server, "create_graph_routes", lambda *_args, **_kwargs: APIRouter()
  668. )
  669. monkeypatch.setattr(lightrag_server, "OllamaAPI", _FakeOllamaAPI)
  670. monkeypatch.setattr(
  671. lightrag_server,
  672. "get_namespace_data",
  673. AsyncMock(return_value={"busy": False}),
  674. )
  675. monkeypatch.setattr(lightrag_server, "get_default_workspace", lambda: "default")
  676. monkeypatch.setattr(
  677. lightrag_server,
  678. "cleanup_keyed_lock",
  679. lambda: {"cleanup_performed": {}, "current_status": {}},
  680. )
  681. app = lightrag_server.create_app(_make_args(tmp_path))
  682. _FakeLightRAG.last_instance.role_config_snapshot = {
  683. "query": {
  684. "binding": "runtime-binding",
  685. "model": "runtime-model",
  686. "host": "https://runtime.example/v1",
  687. "max_async": 9,
  688. "metadata": {"binding": "runtime-binding"},
  689. }
  690. }
  691. _FakeLightRAG.last_instance.queue_status_snapshot = {
  692. "query": {"available": True, "rejected_total": 2}
  693. }
  694. _FakeLightRAG.last_instance.embedding_queue_status_snapshot = {
  695. "available": True,
  696. "running": 1,
  697. }
  698. _FakeLightRAG.last_instance.rerank_queue_status_snapshot = {
  699. "available": False,
  700. }
  701. response = TestClient(app).get("/health")
  702. assert response.status_code == 200
  703. body = response.json()
  704. role_cfg = body["configuration"]["role_llm_config"]["query"]
  705. assert role_cfg["binding"] == "runtime-binding"
  706. assert role_cfg["model"] == "runtime-model"
  707. assert role_cfg["host"] == "https://runtime.example/v1"
  708. assert role_cfg["max_async"] == 9
  709. assert role_cfg["model"] != "us.amazon.nova-lite-v1:0"
  710. assert body["llm_queue_status"]["query"]["rejected_total"] == 2
  711. assert body["embedding_queue_status"]["running"] == 1
  712. assert body["rerank_queue_status"]["available"] is False
  713. @pytest.mark.offline
  714. @pytest.mark.parametrize(
  715. "pipeline_state, expected_active",
  716. [
  717. ({"busy": False}, False),
  718. ({"busy": True}, True),
  719. ({"busy": False, "scanning": True}, True),
  720. ({"busy": False, "destructive_busy": True}, True),
  721. ({"busy": False, "pending_enqueues": 2}, True),
  722. (
  723. {
  724. "busy": False,
  725. "scanning": False,
  726. "destructive_busy": False,
  727. "pending_enqueues": 0,
  728. },
  729. False,
  730. ),
  731. ],
  732. )
  733. def test_health_pipeline_active_derivation(
  734. tmp_path, monkeypatch, pipeline_state, expected_active
  735. ):
  736. _reload_api_modules_if_mocked()
  737. monkeypatch.setattr(sys, "argv", ["pytest"])
  738. config = importlib.import_module("lightrag.api.config")
  739. config.initialize_config(_make_args(tmp_path), force=True)
  740. lightrag_server = importlib.import_module("lightrag.api.lightrag_server")
  741. monkeypatch.setattr(lightrag_server, "LightRAG", _FakeLightRAG)
  742. monkeypatch.setattr(lightrag_server, "check_frontend_build", lambda: (True, False))
  743. monkeypatch.setattr(
  744. lightrag_server, "create_document_routes", lambda *_args, **_kwargs: APIRouter()
  745. )
  746. monkeypatch.setattr(
  747. lightrag_server, "create_query_routes", lambda *_args, **_kwargs: APIRouter()
  748. )
  749. monkeypatch.setattr(
  750. lightrag_server, "create_graph_routes", lambda *_args, **_kwargs: APIRouter()
  751. )
  752. monkeypatch.setattr(lightrag_server, "OllamaAPI", _FakeOllamaAPI)
  753. monkeypatch.setattr(
  754. lightrag_server,
  755. "get_namespace_data",
  756. AsyncMock(return_value=pipeline_state),
  757. )
  758. monkeypatch.setattr(lightrag_server, "get_default_workspace", lambda: "default")
  759. monkeypatch.setattr(
  760. lightrag_server,
  761. "cleanup_keyed_lock",
  762. lambda: {"cleanup_performed": {}, "current_status": {}},
  763. )
  764. app = lightrag_server.create_app(_make_args(tmp_path))
  765. response = TestClient(app).get("/health")
  766. assert response.status_code == 200
  767. body = response.json()
  768. assert body["pipeline_busy"] is bool(pipeline_state.get("busy", False))
  769. assert body["pipeline_scanning"] is bool(pipeline_state.get("scanning", False))
  770. assert body["pipeline_destructive_busy"] is bool(
  771. pipeline_state.get("destructive_busy", False)
  772. )
  773. assert body["pipeline_pending_enqueues"] == int(
  774. pipeline_state.get("pending_enqueues", 0)
  775. )
  776. assert body["pipeline_active"] is expected_active