test_entity_extraction_stability.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170
  1. """Tests for entity extraction stability after refactoring.
  2. Covers:
  3. - entity_types_guidance injected into prompts (text mode and JSON mode)
  4. - custom entity_types_guidance via addon_params overrides default
  5. - ENTITY_TYPES env var raises SystemExit at LightRAG init
  6. - JSON mode passes response_format={'type': 'json_object'} to the LLM (the entity_extraction kwarg is not set to True)
  7. - Default entity type guidance constant is present and non-empty
  8. """
  9. import json
  10. import os
  11. from pathlib import Path
  12. from unittest.mock import AsyncMock, patch
  13. import pytest
  14. from lightrag.utils import EmbeddingFunc, Tokenizer, TokenizerInterface
  15. class DummyTokenizer(TokenizerInterface):
  16. """Simple 1:1 character-to-token mapping for testing."""
  17. def encode(self, content: str):
  18. return [ord(ch) for ch in content]
  19. def decode(self, tokens):
  20. return "".join(chr(token) for token in tokens)
  21. def _make_global_config(
  22. addon_params: dict | None = None,
  23. use_json: bool = False,
  24. max_gleaning: int = 0,
  25. prompt_profile: dict | None = None,
  26. ) -> dict:
  27. tokenizer = Tokenizer("dummy", DummyTokenizer())
  28. extract_func = AsyncMock(return_value="")
  29. return {
  30. "llm_model_func": extract_func,
  31. "role_llm_funcs": {
  32. "extract": extract_func,
  33. "keyword": extract_func,
  34. "query": extract_func,
  35. "vlm": extract_func,
  36. },
  37. "entity_extract_max_gleaning": max_gleaning,
  38. "entity_extract_max_records": 100,
  39. "entity_extract_max_entities": 40,
  40. "addon_params": addon_params if addon_params is not None else {},
  41. "tokenizer": tokenizer,
  42. "max_extract_input_tokens": 20480,
  43. "llm_model_max_async": 1,
  44. "entity_extraction_use_json": use_json,
  45. "_entity_extraction_prompt_profile": prompt_profile,
  46. }
  47. def _make_chunks(content: str = "Alice founded Acme Corp in 1990.") -> dict[str, dict]:
  48. return {
  49. "chunk-001": {
  50. "tokens": len(content),
  51. "content": content,
  52. "full_doc_id": "doc-001",
  53. "chunk_order_index": 0,
  54. }
  55. }
  56. def _require_yaml() -> None:
  57. pytest.importorskip("yaml")
  58. def _write_prompt_profile(
  59. path: Path,
  60. *,
  61. guidance: str | None = None,
  62. text_examples: list[str] | None = None,
  63. json_examples: list[str] | None = None,
  64. ) -> None:
  65. lines: list[str] = []
  66. def _append_block(key: str, value: str) -> None:
  67. lines.append(f"{key}: |")
  68. for line in value.strip("\n").splitlines():
  69. lines.append(f" {line}")
  70. def _append_examples(key: str, values: list[str]) -> None:
  71. lines.append(f"{key}:")
  72. for value in values:
  73. lines.append(" - |")
  74. for line in value.strip("\n").splitlines():
  75. lines.append(f" {line}")
  76. if guidance is not None:
  77. _append_block("entity_types_guidance", guidance)
  78. if text_examples is not None:
  79. _append_examples("entity_extraction_examples", text_examples)
  80. if json_examples is not None:
  81. _append_examples("entity_extraction_json_examples", json_examples)
  82. path.write_text("\n".join(lines) + "\n", encoding="utf-8")
  83. def _dummy_embedding_func() -> EmbeddingFunc:
  84. async def _embed(texts):
  85. return [[0.0, 0.0, 0.0] for _ in texts]
  86. return EmbeddingFunc(embedding_dim=3, func=_embed)
  87. def _patch_prompt_dir(path: Path):
  88. return patch("lightrag.prompt.get_entity_type_prompt_dir", return_value=path)
  89. def _text_profile_example(label: str) -> str:
  90. return f"""---Entity Types---
  91. - ExampleType: Test type
  92. ---Input Text---
  93. ```
  94. {label}
  95. ```
  96. ---Output---
  97. entity{{tuple_delimiter}}{label}{{tuple_delimiter}}ExampleType{{tuple_delimiter}}{label} description.
  98. {{completion_delimiter}}"""
  99. def _json_profile_example(label: str) -> str:
  100. return f"""---Entity Types---
  101. - ExampleType: Test type
  102. ---Input Text---
  103. ```
  104. {label}
  105. ```
  106. ---Output---
  107. {{
  108. "entities": [
  109. {{"name": "{label}", "type": "ExampleType", "description": "{label} description."}}
  110. ],
  111. "relationships": []
  112. }}"""
  113. # --- Minimal valid LLM responses ---
  114. _TEXT_MODE_RESPONSE = (
  115. "entity<|#|>Alice<|#|>Person<|#|>Alice is the founder of Acme Corp."
  116. "\nentity<|#|>Acme Corp<|#|>Organization<|#|>Acme Corp is a company founded by Alice."
  117. "\nrelation<|#|>Alice<|#|>Acme Corp<|#|>founded<|#|>Alice founded Acme Corp."
  118. "\n<|COMPLETE|>"
  119. )
  120. _TEXT_MODE_MISPREFIXED_RELATION_RESPONSE = (
  121. "entity<|#|>Alice<|#|>Person<|#|>Alice is the founder of Acme Corp."
  122. "\nentity<|#|>Acme Corp<|#|>Organization<|#|>Acme Corp is a company founded by Alice."
  123. "\nentity<|#|>Alice<|#|>Acme Corp<|#|>founded<|#|>Alice founded Acme Corp."
  124. "\n<|COMPLETE|>"
  125. )
  126. _TEXT_MODE_GLEANED_RELATION_RESPONSES = [
  127. _TEXT_MODE_MISPREFIXED_RELATION_RESPONSE,
  128. "\nrelation<|#|>Alice<|#|>Acme Corp<|#|>founded<|#|>Alice founded Acme Corp.\n<|COMPLETE|>",
  129. ]
  130. _TEXT_MODE_CROSS_PASS_RELATION_RESPONSES = [
  131. "entity<|#|>Alice<|#|>Person<|#|>Alice founded a company.\n<|COMPLETE|>",
  132. "entity<|#|>Acme Corp<|#|>Organization<|#|>Acme Corp was founded by Alice."
  133. "\nrelation<|#|>Alice<|#|>Acme Corp<|#|>founded<|#|>Alice founded Acme Corp.\n<|COMPLETE|>",
  134. ]
  135. _JSON_MODE_RESPONSE = json.dumps(
  136. {
  137. "entities": [
  138. {
  139. "name": "Alice",
  140. "type": "Person",
  141. "description": "Alice is the founder of Acme Corp.",
  142. },
  143. {
  144. "name": "Acme Corp",
  145. "type": "Organization",
  146. "description": "Acme Corp is a company founded by Alice.",
  147. },
  148. ],
  149. "relationships": [
  150. {
  151. "source": "Alice",
  152. "target": "Acme Corp",
  153. "keywords": "founded",
  154. "description": "Alice founded Acme Corp.",
  155. },
  156. ],
  157. }
  158. )
  159. class _DummyTextChunksStorage:
  160. async def get_by_id(self, chunk_id: str):
  161. return {"file_path": "test.md"}
  162. # ---------------------------------------------------------------------------
  163. # 1. Default entity_types_guidance constant
  164. # ---------------------------------------------------------------------------
  165. @pytest.mark.offline
  166. def test_default_entity_types_guidance_exists():
  167. """PROMPTS['default_entity_types_guidance'] must be a non-empty string."""
  168. from lightrag.prompt import PROMPTS
  169. guidance = PROMPTS["default_entity_types_guidance"]
  170. assert isinstance(guidance, str)
  171. assert len(guidance.strip()) > 0
  172. @pytest.mark.offline
  173. def test_default_entity_types_guidance_covers_all_types():
  174. """Default guidance must mention all 11 canonical entity types."""
  175. from lightrag.prompt import PROMPTS
  176. guidance = PROMPTS["default_entity_types_guidance"]
  177. expected_types = [
  178. "Person",
  179. "Creature",
  180. "Organization",
  181. "Location",
  182. "Event",
  183. "Concept",
  184. "Method",
  185. "Content",
  186. "Data",
  187. "Artifact",
  188. "NaturalObject",
  189. ]
  190. for t in expected_types:
  191. assert (
  192. t in guidance
  193. ), f"Expected entity type '{t}' missing from default_entity_types_guidance"
  194. @pytest.mark.offline
  195. def test_json_examples_define_all_relationship_endpoints_as_entities():
  196. """JSON examples must define every relationship endpoint in the entities list."""
  197. from lightrag.prompt import PROMPTS
  198. for example in PROMPTS["entity_extraction_json_examples"]:
  199. if "<Output>" in example:
  200. output = example.split("<Output>", 1)[1].strip()
  201. else:
  202. output = example.split("---Output---", 1)[1].strip()
  203. parsed = json.loads(output)
  204. entity_names = {
  205. entity["name"] for entity in parsed.get("entities", []) if entity
  206. }
  207. for relationship in parsed.get("relationships", []):
  208. assert relationship["source"] in entity_names
  209. assert relationship["target"] in entity_names
  210. # ---------------------------------------------------------------------------
  211. # 2. DEFAULT_ENTITY_TYPES is gone from constants
  212. # ---------------------------------------------------------------------------
  213. @pytest.mark.offline
  214. def test_default_entity_types_removed_from_constants():
  215. """DEFAULT_ENTITY_TYPES must no longer exist in lightrag.constants."""
  216. import lightrag.constants as constants
  217. assert not hasattr(
  218. constants, "DEFAULT_ENTITY_TYPES"
  219. ), "DEFAULT_ENTITY_TYPES should have been removed from constants.py"
  220. # ---------------------------------------------------------------------------
  221. # 3. ENTITY_TYPES env var raises SystemExit
  222. # ---------------------------------------------------------------------------
  223. @pytest.mark.offline
  224. def test_entity_types_env_var_raises_system_exit(tmp_path):
  225. """LightRAG.__post_init__ must raise SystemExit when ENTITY_TYPES env var is set."""
  226. from lightrag import LightRAG
  227. with patch.dict(os.environ, {"ENTITY_TYPES": '["Person"]'}):
  228. with pytest.raises(SystemExit) as exc_info:
  229. LightRAG(
  230. working_dir=str(tmp_path),
  231. llm_model_func=AsyncMock(),
  232. embedding_func=None,
  233. )
  234. assert "ENTITY_TYPES" in str(exc_info.value)
  235. # ---------------------------------------------------------------------------
  236. # 4. Text mode: entity_types_guidance injected into prompt
  237. # ---------------------------------------------------------------------------
  238. @pytest.mark.offline
  239. @pytest.mark.asyncio
  240. async def test_text_mode_default_guidance_injected_into_prompt():
  241. """Default entity_types_guidance is passed to LLM system prompt in text mode."""
  242. from lightrag.operate import extract_entities
  243. from lightrag.prompt import PROMPTS
  244. global_config = _make_global_config(use_json=False)
  245. llm_func = global_config["llm_model_func"]
  246. llm_func.return_value = _TEXT_MODE_RESPONSE
  247. with patch("lightrag.operate.logger"):
  248. await extract_entities(
  249. chunks=_make_chunks(),
  250. global_config=global_config,
  251. )
  252. # The system prompt passed to the LLM must contain the default guidance
  253. assert llm_func.await_count >= 1
  254. call_kwargs = llm_func.call_args_list[0][1]
  255. system_prompt = call_kwargs.get("system_prompt", "")
  256. assert PROMPTS["default_entity_types_guidance"] in system_prompt
  257. assert "must start with `relation`, never `entity`" in system_prompt
  258. assert "After the last entity row, switch prefixes to `relation`" in system_prompt
  259. assert "Output at most 100 total rows" in system_prompt
  260. assert "Output at most 40 entity rows" in system_prompt
  261. @pytest.mark.offline
  262. @pytest.mark.asyncio
  263. async def test_text_mode_custom_guidance_overrides_default():
  264. """Custom entity_types_guidance in addon_params overrides default."""
  265. from lightrag.operate import extract_entities
  266. custom_guidance = "- Widget: A test widget type"
  267. global_config = _make_global_config(
  268. addon_params={"entity_types_guidance": custom_guidance},
  269. use_json=False,
  270. )
  271. llm_func = global_config["llm_model_func"]
  272. llm_func.return_value = _TEXT_MODE_RESPONSE
  273. with patch("lightrag.operate.logger"):
  274. await extract_entities(
  275. chunks=_make_chunks(),
  276. global_config=global_config,
  277. )
  278. call_kwargs = llm_func.call_args_list[0][1]
  279. system_prompt = call_kwargs.get("system_prompt", "")
  280. assert custom_guidance in system_prompt
  281. @pytest.mark.offline
  282. def test_text_continue_prompt_requires_relation_prefix_for_corrections():
  283. from lightrag.prompt import PROMPTS
  284. prompt = PROMPTS["entity_continue_extraction_user_prompt"]
  285. assert (
  286. "Any corrected relationship row must be emitted with the literal `relation` prefix"
  287. in prompt
  288. )
  289. assert (
  290. "output at most {max_total_records} total rows and at most {max_entity_records} entity rows"
  291. in prompt
  292. )
  293. assert (
  294. "may reference entities that were already extracted correctly in the previous response"
  295. in prompt
  296. )
  297. assert (
  298. "whose source and target entities are both included in this response"
  299. not in prompt
  300. )
  301. @pytest.mark.offline
  302. def test_text_user_prompt_includes_quantity_limits():
  303. from lightrag.prompt import PROMPTS
  304. prompt = PROMPTS["entity_extraction_user_prompt"]
  305. assert (
  306. "output at most {max_total_records} total rows and at most {max_entity_records} entity rows"
  307. in prompt
  308. )
  309. assert (
  310. "If the row limit is reached, output `{completion_delimiter}` immediately"
  311. in prompt
  312. )
  313. # ---------------------------------------------------------------------------
  314. # 5. JSON mode: entity_types_guidance injected + response_format kwarg set
  315. # ---------------------------------------------------------------------------
  316. @pytest.mark.offline
  317. @pytest.mark.asyncio
  318. async def test_rebuild_from_cached_fenced_json_uses_json_parser():
  319. """Cached JSON wrapped in markdown fences must not fall back to text parsing."""
  320. from lightrag import operate
  321. fenced_json = f"```json\n{_JSON_MODE_RESPONSE}\n```"
  322. with patch(
  323. "lightrag.operate._process_extraction_result",
  324. side_effect=AssertionError("text parser should not be used"),
  325. ):
  326. nodes, edges = await operate._rebuild_from_extraction_result(
  327. text_chunks_storage=_DummyTextChunksStorage(),
  328. extraction_result=fenced_json,
  329. chunk_id="chunk-001",
  330. timestamp=123,
  331. )
  332. assert set(nodes) == {"Alice", "Acme Corp"}
  333. assert ("Alice", "Acme Corp") in edges
  334. assert nodes["Alice"][0]["file_path"] == "test.md"
  335. @pytest.mark.offline
  336. @pytest.mark.asyncio
  337. async def test_json_mode_default_guidance_injected_into_prompt():
  338. """Default entity_types_guidance is passed to LLM system prompt in JSON mode."""
  339. from lightrag.operate import extract_entities
  340. from lightrag.prompt import PROMPTS
  341. global_config = _make_global_config(use_json=True)
  342. llm_func = global_config["llm_model_func"]
  343. llm_func.return_value = _JSON_MODE_RESPONSE
  344. with patch("lightrag.operate.logger"):
  345. await extract_entities(
  346. chunks=_make_chunks(),
  347. global_config=global_config,
  348. )
  349. assert llm_func.await_count >= 1
  350. call_kwargs = llm_func.call_args_list[0][1]
  351. system_prompt = call_kwargs.get("system_prompt", "")
  352. assert PROMPTS["default_entity_types_guidance"] in system_prompt
  353. assert "Output at most 100 total records" in system_prompt
  354. assert "Output at most 40 entity objects" in system_prompt
  355. @pytest.mark.offline
  356. @pytest.mark.asyncio
  357. async def test_json_mode_entity_extraction_kwarg_passed():
  358. """JSON mode must pass response_format={'type':'json_object'} to the LLM function."""
  359. from lightrag.operate import extract_entities
  360. global_config = _make_global_config(use_json=True)
  361. llm_func = global_config["llm_model_func"]
  362. llm_func.return_value = _JSON_MODE_RESPONSE
  363. with patch("lightrag.operate.logger"):
  364. await extract_entities(
  365. chunks=_make_chunks(),
  366. global_config=global_config,
  367. )
  368. assert llm_func.await_count >= 1
  369. call_kwargs = llm_func.call_args_list[0][1]
  370. assert call_kwargs.get("response_format") == {"type": "json_object"}
  371. assert call_kwargs.get("entity_extraction") is not True
  372. @pytest.mark.offline
  373. @pytest.mark.asyncio
  374. async def test_json_mode_custom_guidance_overrides_default():
  375. """Custom entity_types_guidance in addon_params overrides default in JSON mode."""
  376. from lightrag.operate import extract_entities
  377. custom_guidance = "- Gadget: A test gadget type"
  378. global_config = _make_global_config(
  379. addon_params={"entity_types_guidance": custom_guidance},
  380. use_json=True,
  381. )
  382. llm_func = global_config["llm_model_func"]
  383. llm_func.return_value = _JSON_MODE_RESPONSE
  384. with patch("lightrag.operate.logger"):
  385. await extract_entities(
  386. chunks=_make_chunks(),
  387. global_config=global_config,
  388. )
  389. call_kwargs = llm_func.call_args_list[0][1]
  390. system_prompt = call_kwargs.get("system_prompt", "")
  391. assert custom_guidance in system_prompt
  392. @pytest.mark.offline
  393. def test_json_user_prompt_includes_quantity_limits():
  394. from lightrag.prompt import PROMPTS
  395. prompt = PROMPTS["entity_extraction_json_user_prompt"]
  396. assert (
  397. "output at most {max_total_records} total records and at most {max_entity_records} entity objects"
  398. in prompt
  399. )
  400. assert (
  401. "Only output relationship objects whose `source` and `target` are both included"
  402. in prompt
  403. )
  404. @pytest.mark.offline
  405. def test_json_continue_prompt_includes_quantity_limits():
  406. from lightrag.prompt import PROMPTS
  407. prompt = PROMPTS["entity_continue_extraction_json_user_prompt"]
  408. assert (
  409. "output at most {max_total_records} total records and at most {max_entity_records} entity objects"
  410. in prompt
  411. )
  412. assert (
  413. "may reference entities already extracted correctly in the previous response"
  414. in prompt
  415. )
  416. # ---------------------------------------------------------------------------
  417. # 6. Text mode: entity_extraction kwarg NOT passed
  418. # ---------------------------------------------------------------------------
  419. @pytest.mark.offline
  420. @pytest.mark.asyncio
  421. async def test_text_mode_no_entity_extraction_kwarg():
  422. """Text mode must NOT pass entity_extraction=True to the LLM function."""
  423. from lightrag.operate import extract_entities
  424. global_config = _make_global_config(use_json=False)
  425. llm_func = global_config["llm_model_func"]
  426. llm_func.return_value = _TEXT_MODE_RESPONSE
  427. with patch("lightrag.operate.logger"):
  428. await extract_entities(
  429. chunks=_make_chunks(),
  430. global_config=global_config,
  431. )
  432. call_kwargs = llm_func.call_args_list[0][1]
  433. assert not call_kwargs.get("entity_extraction", False)
  434. @pytest.mark.offline
  435. @pytest.mark.asyncio
  436. async def test_text_mode_recovers_mis_prefixed_relationship_row():
  437. from lightrag.operate import extract_entities
  438. global_config = _make_global_config(use_json=False)
  439. llm_func = global_config["llm_model_func"]
  440. llm_func.return_value = _TEXT_MODE_MISPREFIXED_RELATION_RESPONSE
  441. with patch("lightrag.operate.logger"):
  442. chunk_results = await extract_entities(
  443. chunks=_make_chunks(),
  444. global_config=global_config,
  445. )
  446. entities, relationships = chunk_results[0]
  447. assert len(entities) == 2
  448. assert len(relationships) == 1
  449. assert next(iter(relationships.keys())) == ("Alice", "Acme Corp")
  450. @pytest.mark.offline
  451. @pytest.mark.asyncio
  452. async def test_text_mode_gleaned_relation_merges_cleanly_after_recovery():
  453. from lightrag.operate import extract_entities
  454. global_config = _make_global_config(use_json=False, max_gleaning=1)
  455. llm_func = global_config["llm_model_func"]
  456. llm_func.side_effect = _TEXT_MODE_GLEANED_RELATION_RESPONSES
  457. with patch("lightrag.operate.logger"):
  458. chunk_results = await extract_entities(
  459. chunks=_make_chunks(),
  460. global_config=global_config,
  461. )
  462. entities, relationships = chunk_results[0]
  463. assert len(entities) == 2
  464. assert len(relationships) == 1
  465. relation_data = next(iter(relationships.values()))[0]
  466. assert relation_data["src_id"] == "Alice"
  467. assert relation_data["tgt_id"] == "Acme Corp"
  468. @pytest.mark.offline
  469. @pytest.mark.asyncio
  470. async def test_text_mode_gleaned_relation_can_reference_prior_entity():
  471. from lightrag.operate import extract_entities
  472. global_config = _make_global_config(use_json=False, max_gleaning=1)
  473. llm_func = global_config["llm_model_func"]
  474. llm_func.side_effect = _TEXT_MODE_CROSS_PASS_RELATION_RESPONSES
  475. with patch("lightrag.operate.logger"):
  476. chunk_results = await extract_entities(
  477. chunks=_make_chunks(),
  478. global_config=global_config,
  479. )
  480. entities, relationships = chunk_results[0]
  481. assert set(entities.keys()) == {"Alice", "Acme Corp"}
  482. assert len(relationships) == 1
  483. relation_data = next(iter(relationships.values()))[0]
  484. assert relation_data["src_id"] == "Alice"
  485. assert relation_data["tgt_id"] == "Acme Corp"
  486. @pytest.mark.offline
  487. def test_addon_params_default_includes_entity_type_prompt_file_env(tmp_path):
  488. _require_yaml()
  489. from lightrag import LightRAG
  490. prompt_dir = tmp_path / "entity_type"
  491. prompt_dir.mkdir()
  492. _write_prompt_profile(
  493. prompt_dir / "entity_type_prompt.sample.yml",
  494. text_examples=[_text_profile_example("Env Default Example")],
  495. )
  496. with patch.dict(
  497. os.environ,
  498. {
  499. "SUMMARY_LANGUAGE": "English",
  500. "ENTITY_TYPE_PROMPT_FILE": "entity_type_prompt.sample.yml",
  501. },
  502. ):
  503. with _patch_prompt_dir(prompt_dir):
  504. rag = LightRAG(
  505. working_dir=str(tmp_path / "rag-default-env"),
  506. llm_model_func=AsyncMock(),
  507. embedding_func=_dummy_embedding_func(),
  508. entity_extraction_use_json=False,
  509. )
  510. assert (
  511. rag.addon_params["entity_type_prompt_file"] == "entity_type_prompt.sample.yml"
  512. )
  513. @pytest.mark.offline
  514. @pytest.mark.asyncio
  515. async def test_text_mode_prompt_file_injects_examples_and_guidance():
  516. _require_yaml()
  517. from lightrag.operate import extract_entities
  518. guidance = "- ExampleType: Injected guidance"
  519. example_label = "Custom Text Example"
  520. prompt_profile = {
  521. "entity_types_guidance": guidance,
  522. "entity_extraction_examples": [_text_profile_example(example_label)],
  523. "entity_extraction_json_examples": [],
  524. }
  525. global_config = _make_global_config(
  526. prompt_profile=prompt_profile,
  527. use_json=False,
  528. )
  529. llm_func = global_config["llm_model_func"]
  530. llm_func.return_value = _TEXT_MODE_RESPONSE
  531. with patch("lightrag.operate.logger"):
  532. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  533. call_kwargs = llm_func.call_args_list[0][1]
  534. system_prompt = call_kwargs.get("system_prompt", "")
  535. assert guidance in system_prompt
  536. assert example_label in system_prompt
  537. @pytest.mark.offline
  538. @pytest.mark.asyncio
  539. async def test_json_mode_prompt_file_injects_examples_and_guidance():
  540. _require_yaml()
  541. from lightrag.operate import extract_entities
  542. guidance = "- ExampleType: Injected JSON guidance"
  543. example_label = "Custom Json Example"
  544. prompt_profile = {
  545. "entity_types_guidance": guidance,
  546. "entity_extraction_examples": [],
  547. "entity_extraction_json_examples": [_json_profile_example(example_label)],
  548. }
  549. global_config = _make_global_config(
  550. prompt_profile=prompt_profile,
  551. use_json=True,
  552. )
  553. llm_func = global_config["llm_model_func"]
  554. llm_func.return_value = _JSON_MODE_RESPONSE
  555. with patch("lightrag.operate.logger"):
  556. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  557. call_kwargs = llm_func.call_args_list[0][1]
  558. system_prompt = call_kwargs.get("system_prompt", "")
  559. assert guidance in system_prompt
  560. assert example_label in system_prompt
  561. @pytest.mark.offline
  562. @pytest.mark.asyncio
  563. async def test_prompt_file_guidance_falls_back_to_default_when_missing():
  564. _require_yaml()
  565. from lightrag.operate import extract_entities
  566. from lightrag.prompt import PROMPTS
  567. global_config = _make_global_config(
  568. prompt_profile={
  569. "entity_types_guidance": PROMPTS["default_entity_types_guidance"].rstrip(),
  570. "entity_extraction_examples": [
  571. _text_profile_example("Fallback Guidance Example")
  572. ],
  573. "entity_extraction_json_examples": [],
  574. },
  575. use_json=False,
  576. )
  577. llm_func = global_config["llm_model_func"]
  578. llm_func.return_value = _TEXT_MODE_RESPONSE
  579. with patch("lightrag.operate.logger"):
  580. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  581. call_kwargs = llm_func.call_args_list[0][1]
  582. system_prompt = call_kwargs.get("system_prompt", "")
  583. assert PROMPTS["default_entity_types_guidance"] in system_prompt
  584. @pytest.mark.offline
  585. @pytest.mark.asyncio
  586. async def test_cached_prompt_profile_supplies_merged_guidance():
  587. from lightrag.operate import extract_entities
  588. merged_guidance = "- ExampleType: Addon override"
  589. global_config = _make_global_config(
  590. prompt_profile={
  591. "entity_types_guidance": merged_guidance,
  592. "entity_extraction_examples": [_text_profile_example("Override Example")],
  593. "entity_extraction_json_examples": [],
  594. },
  595. use_json=False,
  596. )
  597. llm_func = global_config["llm_model_func"]
  598. llm_func.return_value = _TEXT_MODE_RESPONSE
  599. with patch("lightrag.operate.logger"):
  600. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  601. call_kwargs = llm_func.call_args_list[0][1]
  602. system_prompt = call_kwargs.get("system_prompt", "")
  603. assert merged_guidance in system_prompt
  604. @pytest.mark.offline
  605. def test_text_mode_prompt_file_can_omit_json_examples(tmp_path):
  606. _require_yaml()
  607. from lightrag import LightRAG
  608. prompt_dir = tmp_path / "entity_type"
  609. prompt_dir.mkdir()
  610. _write_prompt_profile(
  611. prompt_dir / "text_only.yml",
  612. text_examples=[_text_profile_example("Text Only Example")],
  613. )
  614. with _patch_prompt_dir(prompt_dir):
  615. rag = LightRAG(
  616. working_dir=str(tmp_path / "rag-text"),
  617. llm_model_func=AsyncMock(),
  618. embedding_func=_dummy_embedding_func(),
  619. entity_extraction_use_json=False,
  620. addon_params={"entity_type_prompt_file": "text_only.yml"},
  621. )
  622. assert rag.addon_params["entity_type_prompt_file"] == "text_only.yml"
  623. @pytest.mark.offline
  624. def test_json_mode_prompt_file_can_omit_text_examples(tmp_path):
  625. _require_yaml()
  626. from lightrag import LightRAG
  627. prompt_dir = tmp_path / "entity_type"
  628. prompt_dir.mkdir()
  629. _write_prompt_profile(
  630. prompt_dir / "json_only.yml",
  631. json_examples=[_json_profile_example("Json Only Example")],
  632. )
  633. with _patch_prompt_dir(prompt_dir):
  634. rag = LightRAG(
  635. working_dir=str(tmp_path / "rag-json"),
  636. llm_model_func=AsyncMock(),
  637. embedding_func=_dummy_embedding_func(),
  638. entity_extraction_use_json=True,
  639. addon_params={"entity_type_prompt_file": "json_only.yml"},
  640. )
  641. assert rag.addon_params["entity_type_prompt_file"] == "json_only.yml"
  642. @pytest.mark.offline
  643. def test_text_mode_prompt_file_requires_text_examples(tmp_path):
  644. _require_yaml()
  645. from lightrag import LightRAG
  646. prompt_dir = tmp_path / "entity_type"
  647. prompt_dir.mkdir()
  648. _write_prompt_profile(
  649. prompt_dir / "missing_text_examples.yml",
  650. json_examples=[_json_profile_example("Wrong Mode Only")],
  651. )
  652. with _patch_prompt_dir(prompt_dir):
  653. with pytest.raises(ValueError) as exc_info:
  654. LightRAG(
  655. working_dir=str(tmp_path / "rag-missing-text"),
  656. llm_model_func=AsyncMock(),
  657. embedding_func=None,
  658. entity_extraction_use_json=False,
  659. addon_params={"entity_type_prompt_file": "missing_text_examples.yml"},
  660. )
  661. assert "entity_extraction_examples" in str(exc_info.value)
  662. @pytest.mark.offline
  663. def test_json_mode_prompt_file_requires_json_examples(tmp_path):
  664. _require_yaml()
  665. from lightrag import LightRAG
  666. prompt_dir = tmp_path / "entity_type"
  667. prompt_dir.mkdir()
  668. _write_prompt_profile(
  669. prompt_dir / "missing_json_examples.yml",
  670. text_examples=[_text_profile_example("Wrong Mode Only")],
  671. )
  672. with _patch_prompt_dir(prompt_dir):
  673. with pytest.raises(ValueError) as exc_info:
  674. LightRAG(
  675. working_dir=str(tmp_path / "rag-missing-json"),
  676. llm_model_func=AsyncMock(),
  677. embedding_func=None,
  678. entity_extraction_use_json=True,
  679. addon_params={"entity_type_prompt_file": "missing_json_examples.yml"},
  680. )
  681. assert "entity_extraction_json_examples" in str(exc_info.value)
  682. @pytest.mark.offline
  683. def test_prompt_file_rejects_directory_segments(tmp_path):
  684. _require_yaml()
  685. from lightrag import LightRAG
  686. with pytest.raises(ValueError) as exc_info:
  687. LightRAG(
  688. working_dir=str(tmp_path / "rag-bad-path"),
  689. llm_model_func=AsyncMock(),
  690. embedding_func=None,
  691. addon_params={"entity_type_prompt_file": "../outside.yml"},
  692. )
  693. assert "file name only" in str(exc_info.value)
  694. @pytest.mark.offline
  695. def test_prompt_file_rejects_absolute_paths(tmp_path):
  696. _require_yaml()
  697. from lightrag import LightRAG
  698. with pytest.raises(ValueError) as exc_info:
  699. LightRAG(
  700. working_dir=str(tmp_path / "rag-abs-path"),
  701. llm_model_func=AsyncMock(),
  702. embedding_func=None,
  703. addon_params={"entity_type_prompt_file": str(tmp_path / "abs.yml")},
  704. )
  705. assert "file name only" in str(exc_info.value)
  706. @pytest.mark.offline
  707. @pytest.mark.asyncio
  708. async def test_extract_entities_uses_cached_prompt_profile_without_reloading():
  709. from lightrag.operate import extract_entities
  710. cached_profile = {
  711. "entity_types_guidance": "- ExampleType: Cached guidance",
  712. "entity_extraction_examples": [_text_profile_example("Cached Text Example")],
  713. "entity_extraction_json_examples": [],
  714. }
  715. global_config = _make_global_config(use_json=False, prompt_profile=cached_profile)
  716. llm_func = global_config["llm_model_func"]
  717. llm_func.return_value = _TEXT_MODE_RESPONSE
  718. with patch(
  719. "lightrag.operate.resolve_entity_extraction_prompt_profile",
  720. side_effect=AssertionError("should not resolve profile when cache exists"),
  721. ):
  722. with patch("lightrag.operate.logger"):
  723. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  724. await extract_entities(chunks=_make_chunks(), global_config=global_config)
  725. system_prompt = llm_func.call_args_list[0][1].get("system_prompt", "")
  726. assert "Cached Text Example" in system_prompt
  727. assert "Cached guidance" in system_prompt
  728. @pytest.mark.offline
  729. def test_sample_prompt_file_matches_builtin_prompt_data():
  730. _require_yaml()
  731. from lightrag.prompt import (
  732. get_default_entity_extraction_prompt_profile,
  733. load_entity_extraction_prompt_profile,
  734. )
  735. sample_file = (
  736. Path(__file__).resolve().parents[2]
  737. / "prompts"
  738. / "samples"
  739. / "entity_type_prompt.sample.yml"
  740. )
  741. loaded_profile = load_entity_extraction_prompt_profile(sample_file)
  742. assert loaded_profile == get_default_entity_extraction_prompt_profile()
  743. @pytest.mark.offline
  744. def test_prompt_dir_env_var_overrides_default(tmp_path, monkeypatch):
  745. _require_yaml()
  746. from lightrag.prompt import (
  747. get_entity_type_prompt_dir,
  748. resolve_entity_type_prompt_path,
  749. )
  750. monkeypatch.setenv("PROMPT_DIR", str(tmp_path))
  751. expected_dir = (tmp_path / "entity_type").resolve()
  752. assert get_entity_type_prompt_dir() == expected_dir
  753. resolved = resolve_entity_type_prompt_path("custom.yml")
  754. assert resolved == expected_dir / "custom.yml"
  755. @pytest.mark.offline
  756. def test_prompt_dir_defaults_to_cwd_relative(tmp_path, monkeypatch):
  757. _require_yaml()
  758. from lightrag.prompt import get_entity_type_prompt_dir
  759. monkeypatch.delenv("PROMPT_DIR", raising=False)
  760. monkeypatch.chdir(tmp_path)
  761. assert (
  762. get_entity_type_prompt_dir() == (tmp_path / "prompts" / "entity_type").resolve()
  763. )
  764. @pytest.mark.offline
  765. def test_prompt_file_rejects_unsupported_extension(tmp_path):
  766. _require_yaml()
  767. from lightrag import LightRAG
  768. with pytest.raises(ValueError, match="'.yml' or '.yaml'"):
  769. LightRAG(
  770. working_dir=str(tmp_path / "rag-bad-ext"),
  771. llm_model_func=AsyncMock(),
  772. embedding_func=None,
  773. addon_params={"entity_type_prompt_file": "profile.txt"},
  774. )
  775. @pytest.mark.offline
  776. def test_prompt_file_malformed_yaml_raises_valueerror(tmp_path):
  777. _require_yaml()
  778. from lightrag.prompt import load_entity_extraction_prompt_profile
  779. bad_file = tmp_path / "broken.yml"
  780. bad_file.write_text("entity_types_guidance: [unclosed", encoding="utf-8")
  781. with pytest.raises(ValueError, match="invalid YAML"):
  782. load_entity_extraction_prompt_profile(bad_file)
  783. @pytest.mark.offline
  784. def test_addon_guidance_overrides_file_profile(tmp_path):
  785. _require_yaml()
  786. from lightrag.prompt import resolve_entity_extraction_prompt_profile
  787. prompt_dir = tmp_path / "entity_type"
  788. prompt_dir.mkdir()
  789. _write_prompt_profile(
  790. prompt_dir / "profile.yml",
  791. guidance="- FileType: from file",
  792. text_examples=[_text_profile_example("Merged Example")],
  793. )
  794. with _patch_prompt_dir(prompt_dir):
  795. profile = resolve_entity_extraction_prompt_profile(
  796. addon_params={
  797. "entity_type_prompt_file": "profile.yml",
  798. "entity_types_guidance": "- AddonType: from addon_params",
  799. },
  800. use_json=False,
  801. )
  802. assert profile["entity_types_guidance"] == "- AddonType: from addon_params"
  803. # File-provided examples must still be honored.
  804. assert any(
  805. "Merged Example" in example for example in profile["entity_extraction_examples"]
  806. )
  807. @pytest.mark.offline
  808. def test_explicit_addon_params_still_picks_up_env_defaults(tmp_path, monkeypatch):
  809. """Passing addon_params explicitly must not drop env-based defaults."""
  810. _require_yaml()
  811. from lightrag import LightRAG
  812. prompt_dir = tmp_path / "entity_type"
  813. prompt_dir.mkdir()
  814. _write_prompt_profile(
  815. prompt_dir / "from_env.yml",
  816. text_examples=[_text_profile_example("Env Example")],
  817. )
  818. monkeypatch.setenv("ENTITY_TYPE_PROMPT_FILE", "from_env.yml")
  819. with _patch_prompt_dir(prompt_dir):
  820. rag = LightRAG(
  821. working_dir=str(tmp_path / "rag-env-default"),
  822. llm_model_func=AsyncMock(),
  823. embedding_func=_dummy_embedding_func(),
  824. entity_extraction_use_json=False,
  825. addon_params={"language": "English"},
  826. )
  827. assert rag.addon_params["entity_type_prompt_file"] == "from_env.yml"
  828. @pytest.mark.offline
  829. def test_runtime_addon_params_item_update_refreshes_cached_values(tmp_path):
  830. _require_yaml()
  831. from lightrag import LightRAG
  832. prompt_dir = tmp_path / "entity_type"
  833. prompt_dir.mkdir()
  834. _write_prompt_profile(
  835. prompt_dir / "initial.yml",
  836. text_examples=[_text_profile_example("Initial Example")],
  837. )
  838. _write_prompt_profile(
  839. prompt_dir / "updated.yml",
  840. guidance="- UpdatedType: runtime update",
  841. text_examples=[_text_profile_example("Updated Example")],
  842. )
  843. with _patch_prompt_dir(prompt_dir):
  844. rag = LightRAG(
  845. working_dir=str(tmp_path / "rag-runtime-update"),
  846. llm_model_func=AsyncMock(),
  847. embedding_func=_dummy_embedding_func(),
  848. entity_extraction_use_json=False,
  849. addon_params={
  850. "entity_type_prompt_file": "initial.yml",
  851. "language": "English",
  852. },
  853. )
  854. rag.addon_params["entity_type_prompt_file"] = "updated.yml"
  855. rag.addon_params["language"] = "French"
  856. global_config = rag._build_global_config()
  857. assert global_config["addon_params"]["language"] == "French"
  858. assert global_config["_resolved_summary_language"] == "French"
  859. assert (
  860. global_config["_entity_extraction_prompt_profile"]["entity_types_guidance"]
  861. == "- UpdatedType: runtime update"
  862. )
  863. assert any(
  864. "Updated Example" in example
  865. for example in global_config["_entity_extraction_prompt_profile"][
  866. "entity_extraction_examples"
  867. ]
  868. )
  869. @pytest.mark.offline
  870. def test_runtime_addon_params_replacement_refreshes_cached_values(tmp_path):
  871. _require_yaml()
  872. from lightrag import LightRAG
  873. rag = LightRAG(
  874. working_dir=str(tmp_path / "rag-runtime-replace"),
  875. llm_model_func=AsyncMock(),
  876. embedding_func=_dummy_embedding_func(),
  877. entity_extraction_use_json=False,
  878. addon_params={"language": "English"},
  879. )
  880. rag.addon_params = {
  881. "language": "German",
  882. "entity_types_guidance": "- ReplacementType: runtime replace",
  883. }
  884. global_config = rag._build_global_config()
  885. assert global_config["addon_params"]["language"] == "German"
  886. assert global_config["_resolved_summary_language"] == "German"
  887. assert (
  888. global_config["_entity_extraction_prompt_profile"]["entity_types_guidance"]
  889. == "- ReplacementType: runtime replace"
  890. )
  891. @pytest.mark.offline
  892. def test_runtime_mode_flip_invalidates_cached_prompt_profile(tmp_path):
  893. _require_yaml()
  894. from lightrag import LightRAG
  895. prompt_dir = tmp_path / "entity_type"
  896. prompt_dir.mkdir()
  897. _write_prompt_profile(
  898. prompt_dir / "text_only.yml",
  899. text_examples=[_text_profile_example("Text Only Example")],
  900. )
  901. with _patch_prompt_dir(prompt_dir):
  902. rag = LightRAG(
  903. working_dir=str(tmp_path / "rag-mode-flip"),
  904. llm_model_func=AsyncMock(),
  905. embedding_func=_dummy_embedding_func(),
  906. entity_extraction_use_json=False,
  907. addon_params={"entity_type_prompt_file": "text_only.yml"},
  908. )
  909. rag._build_global_config()
  910. rag.entity_extraction_use_json = True
  911. with pytest.raises(ValueError) as exc_info:
  912. rag._build_global_config()
  913. assert "entity_extraction_json_examples" in str(exc_info.value)