test_batch_graph_operations.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872
  1. """
  2. Unit tests for batch graph operations (PR #2910 follow-up).
  3. Verifies:
  4. 1. BaseGraphStorage default batch methods fall back to serial single-item calls.
  5. 2. NetworkXStorage overrides batch methods with optimized in-memory operations.
  6. 3. ainsert_custom_kg uses the batch interface end-to-end (no hasattr guards).
  7. 4. has_nodes_batch returns only existing nodes, including newly inserted ones.
  8. 5. upsert_edges_batch and upsert_nodes_batch are idempotent (safe to call twice).
  9. """
  10. import time
  11. import tempfile
  12. import pytest
  13. import numpy as np
  14. from unittest.mock import AsyncMock
  15. from lightrag.kg.networkx_impl import NetworkXStorage
  16. from lightrag.kg.shared_storage import initialize_share_data
  17. from lightrag.utils import EmbeddingFunc, make_relation_vdb_ids
  18. # ---------------------------------------------------------------------------
  19. # Helpers
  20. # ---------------------------------------------------------------------------
  # Baseline storage configuration; make_networkx_storage copies it and
  # replaces "working_dir" with the per-test temporary directory.
  GLOBAL_CONFIG = dict(
      embedding_batch_num=10,
      vector_db_storage_cls_kwargs=dict(cosine_better_than_threshold=0.5),
      working_dir="/tmp/test_batch_graph",
  )
  26. async def _raw_embedding_func(texts):
  27. return np.random.rand(len(texts), 10)
  # Embedding stub handed to LightRAG in the end-to-end tests: wraps
  # _raw_embedding_func and declares 10 dimensions and a 512-token limit.
  mock_embedding_func = EmbeddingFunc(
      func=_raw_embedding_func,
      embedding_dim=10,
      max_token_size=512,
  )
  def make_networkx_storage(tmp_dir: str) -> NetworkXStorage:
      """Build a NetworkXStorage whose working directory is ``tmp_dir``.

      The instance is only constructed here; callers in this file run
      ``await storage.initialize()`` themselves before using it.
      """
      storage_config = {**GLOBAL_CONFIG, "working_dir": tmp_dir}
      initialize_share_data()
      return NetworkXStorage(
          namespace="test_graph",
          workspace="test_ws",
          global_config=storage_config,
          embedding_func=_raw_embedding_func,
      )
  43. def _make_node(entity_id: str, entity_type: str = "TEST") -> dict:
  44. return {
  45. "entity_id": entity_id,
  46. "entity_type": entity_type,
  47. "description": f"Description of {entity_id}",
  48. "source_id": "chunk-1",
  49. "file_path": "test.txt",
  50. "created_at": int(time.time()),
  51. }
  52. def _make_edge(weight: float = 1.0) -> dict:
  53. return {
  54. "weight": weight,
  55. "description": "test edge",
  56. "keywords": "test",
  57. "source_id": "chunk-1",
  58. "file_path": "test.txt",
  59. "created_at": int(time.time()),
  60. }
  61. # ---------------------------------------------------------------------------
  62. # 1. BaseGraphStorage default implementations delegate to single-item methods
  63. # ---------------------------------------------------------------------------
  class TestBaseGraphStorageDefaults:
      """
      Serial-fallback checks driven through a concrete NetworkXStorage.

      Each test swaps one single-item method on the instance for a recording
      wrapper, runs a locally defined serial loop over the inputs, and checks
      that the single-item method was invoked once per input, in order.

      NOTE(review): the serial loops live inside the tests rather than being
      taken from BaseGraphStorage -- confirm they still mirror the base-class
      defaults.
      """

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_calls_upsert_node(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              payload = [(name, _make_node(name)) for name in ("NodeA", "NodeB")]
              seen_ids: list[str] = []
              real_upsert_node = storage.upsert_node

              async def recording_upsert_node(node_id, *, node_data):
                  seen_ids.append(node_id)
                  return await real_upsert_node(node_id, node_data=node_data)

              # Stand-in for the serial default, bypassing the optimised override.
              async def serial_upsert_nodes(self, nodes):
                  for node_id, node_data in nodes:
                      await self.upsert_node(node_id, node_data=node_data)

              storage.upsert_node = recording_upsert_node  # type: ignore[assignment]
              await serial_upsert_nodes(storage, payload)

              assert seen_ids == ["NodeA", "NodeB"]

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_has_nodes_batch_calls_has_node(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()
              # Only NodeA exists, so the lookup below must report just that one.
              await storage.upsert_node("NodeA", node_data=_make_node("NodeA"))

              probed_ids: list[str] = []
              real_has_node = storage.has_node

              async def recording_has_node(node_id):
                  probed_ids.append(node_id)
                  return await real_has_node(node_id)

              async def serial_has_nodes(self, node_ids):
                  return {
                      node_id for node_id in node_ids if await self.has_node(node_id)
                  }

              storage.has_node = recording_has_node  # type: ignore[assignment]
              found = await serial_has_nodes(storage, ["NodeA", "NodeB"])

              assert probed_ids == ["NodeA", "NodeB"]
              assert found == {"NodeA"}

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_calls_upsert_edge(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()
              for name in ("NodeA", "NodeB", "NodeC"):
                  await storage.upsert_node(name, node_data=_make_node(name))

              seen_pairs: list[tuple] = []
              real_upsert_edge = storage.upsert_edge

              async def recording_upsert_edge(src, tgt, *, edge_data):
                  seen_pairs.append((src, tgt))
                  return await real_upsert_edge(src, tgt, edge_data=edge_data)

              async def serial_upsert_edges(self, edges):
                  for src, tgt, edge_data in edges:
                      await self.upsert_edge(src, tgt, edge_data=edge_data)

              payload = [
                  ("NodeA", "NodeB", _make_edge()),
                  ("NodeB", "NodeC", _make_edge()),
              ]
              storage.upsert_edge = recording_upsert_edge  # type: ignore[assignment]
              await serial_upsert_edges(storage, payload)

              assert seen_pairs == [("NodeA", "NodeB"), ("NodeB", "NodeC")]
  137. # ---------------------------------------------------------------------------
  138. # 2. NetworkXStorage optimised batch implementations
  139. # ---------------------------------------------------------------------------
  class TestNetworkXBatchOperations:
      """Behavioural checks for NetworkXStorage's batch node/edge methods."""

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_inserts_all_nodes(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              names = [f"Entity{i}" for i in range(5)]
              await storage.upsert_nodes_batch(
                  [(name, _make_node(name)) for name in names]
              )

              for name in names:
                  assert await storage.has_node(name), f"{name} should exist"

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_is_idempotent(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              batch = [("Alpha", _make_node("Alpha"))]
              # Same payload twice: the second call must leave the node intact.
              await storage.upsert_nodes_batch(batch)
              await storage.upsert_nodes_batch(batch)

              assert await storage.has_node("Alpha")
              stored = await storage.get_node("Alpha")
              assert stored["entity_id"] == "Alpha"

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_has_nodes_batch_returns_existing_subset(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              present = ("Present1", "Present2")
              await storage.upsert_nodes_batch(
                  [(name, _make_node(name)) for name in present]
              )

              found = await storage.has_nodes_batch(
                  ["Present1", "Present2", "Missing"]
              )
              assert found == {"Present1", "Present2"}

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_has_nodes_batch_empty_input(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              found = await storage.has_nodes_batch([])
              assert found == set()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_creates_edges(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()
              await storage.upsert_nodes_batch(
                  [(name, _make_node(name)) for name in ("A", "B", "C")]
              )

              await storage.upsert_edges_batch(
                  [
                      ("A", "B", _make_edge(1.5)),
                      ("B", "C", _make_edge(2.0)),
                  ]
              )

              # Each inserted edge must be retrievable with the weight it was given.
              for src, tgt, expected_weight in (("A", "B", 1.5), ("B", "C", 2.0)):
                  stored = await storage.get_edge(src, tgt)
                  assert stored is not None
                  assert float(stored["weight"]) == pytest.approx(expected_weight)

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_is_idempotent(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()
              await storage.upsert_nodes_batch(
                  [(name, _make_node(name)) for name in ("X", "Y")]
              )

              batch = [("X", "Y", _make_edge(3.0))]
              # Same payload twice: the edge must still be there afterwards.
              await storage.upsert_edges_batch(batch)
              await storage.upsert_edges_batch(batch)

              stored = await storage.get_edge("X", "Y")
              assert stored is not None

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_updates_existing_node(self):
          with tempfile.TemporaryDirectory() as workdir:
              storage = make_networkx_storage(workdir)
              await storage.initialize()

              first_version = _make_node("Node1")
              await storage.upsert_nodes_batch([("Node1", first_version)])

              second_version = {**first_version, "description": "Updated description"}
              await storage.upsert_nodes_batch([("Node1", second_version)])

              stored = await storage.get_node("Node1")
              assert stored["description"] == "Updated description"
  238. # ---------------------------------------------------------------------------
  239. # 3. ainsert_custom_kg uses batch interface end-to-end
  240. # ---------------------------------------------------------------------------
  class TestAinsertCustomKgBatchPath:
      """
      End-to-end checks of ainsert_custom_kg on a LightRAG instance whose
      vector/KV write paths are replaced by AsyncMocks.

      The tests verify that the graph storage batch methods are used and inspect
      the payloads handed to each store. Every test finalizes its storages in a
      ``finally`` clause so a failing assertion does not skip the cleanup.
      """

      async def _start_rag(self, working_dir):
          """Return an initialized LightRAG with its VDB/KV write methods mocked."""
          from lightrag import LightRAG

          rag = LightRAG(
              working_dir=working_dir,
              llm_model_func=AsyncMock(return_value=""),
              embedding_func=mock_embedding_func,
          )
          await rag.initialize_storages()
          # Mock VDB/KV writes to avoid needing real embeddings.
          rag.entities_vdb.upsert = AsyncMock()
          rag.relationships_vdb.upsert = AsyncMock()
          rag.relationships_vdb.delete = AsyncMock()
          rag.text_chunks.upsert = AsyncMock()
          rag.doc_status.upsert = AsyncMock()
          return rag

      def _make_custom_kg(self):
          """Return a minimal custom KG: one chunk, one entity, one relationship."""
          return {
              "chunks": [
                  {
                      "content": "chunk content",
                      "chunk_order_index": 0,
                      "source_id": "src-1",
                  }
              ],
              "entities": [
                  {
                      "entity_name": "EntityA",
                      "entity_type": "CONCEPT",
                      "description": "An entity",
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  }
              ],
              "relationships": [
                  {
                      "src_id": "EntityA",
                      "tgt_id": "EntityB",
                      "description": "relates to",
                      "keywords": "relation",
                      "weight": 1.0,
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  }
              ],
          }

      def _make_reversed_relation_kg(self):
          """Return a KG with EntityA, EntityB and one EntityB -> EntityA relation."""
          return {
              "chunks": [
                  {
                      "content": "chunk content",
                      "chunk_order_index": 0,
                      "source_id": "src-1",
                  }
              ],
              "entities": [
                  {
                      "entity_name": "EntityA",
                      "entity_type": "CONCEPT",
                      "description": "Entity A",
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
                  {
                      "entity_name": "EntityB",
                      "entity_type": "CONCEPT",
                      "description": "Entity B",
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
              ],
              "relationships": [
                  {
                      "src_id": "EntityB",
                      "tgt_id": "EntityA",
                      "description": "latest relation",
                      "keywords": "second",
                      "weight": 2.0,
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
              ],
          }

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_calls_batch_methods(self):
          """upsert_nodes_batch, has_nodes_batch, upsert_edges_batch must all be called."""
          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  graph = rag.chunk_entity_relation_graph
                  # Wrap (not replace) the real methods so the insert still
                  # runs against the actual graph while calls are recorded.
                  upsert_nodes_batch = AsyncMock(wraps=graph.upsert_nodes_batch)
                  has_nodes_batch = AsyncMock(wraps=graph.has_nodes_batch)
                  upsert_edges_batch = AsyncMock(wraps=graph.upsert_edges_batch)
                  graph.upsert_nodes_batch = upsert_nodes_batch
                  graph.has_nodes_batch = has_nodes_batch
                  graph.upsert_edges_batch = upsert_edges_batch

                  await rag.ainsert_custom_kg(self._make_custom_kg())

                  upsert_nodes_batch.assert_called()
                  has_nodes_batch.assert_called()
                  upsert_edges_batch.assert_called()
              finally:
                  await rag.finalize_storages()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_canonicalizes_file_paths_before_upsert(self):
          """custom KG ingestion normalizes file names before touching storage."""
          custom_kg = self._make_custom_kg()
          for section in ("chunks", "entities", "relationships"):
              for item in custom_kg[section]:
                  item["file_path"] = "/tmp/uploads/test.[native-Fi].pdf"

          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  await rag.ainsert_custom_kg(custom_kg)

                  # Every store must have received the canonical file name.
                  text_chunks = rag.text_chunks.upsert.call_args.args[0]
                  assert next(iter(text_chunks.values()))["file_path"] == "test.pdf"
                  entities = rag.entities_vdb.upsert.call_args.args[0]
                  assert next(iter(entities.values()))["file_path"] == "test.pdf"
                  relationships = rag.relationships_vdb.upsert.call_args.args[0]
                  assert (
                      next(iter(relationships.values()))["file_path"] == "test.pdf"
                  )
              finally:
                  await rag.finalize_storages()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_no_hasattr_needed(self):
          """
          The batch methods are declared on BaseGraphStorage itself, so callers
          need no hasattr() guard.

          This only asserts that the three attributes exist on the base class;
          it does not run a backend that relies on the default fallback.
          """
          from lightrag.base import BaseGraphStorage

          # All three batch methods should exist on the base class
          assert hasattr(BaseGraphStorage, "upsert_nodes_batch")
          assert hasattr(BaseGraphStorage, "has_nodes_batch")
          assert hasattr(BaseGraphStorage, "upsert_edges_batch")

      @pytest.mark.offline
      def test_neo4j_has_nodes_batch_uses_read_retry(self):
          """Each Neo4JStorage batch method must expose a ``retry`` attribute."""
          pytest.importorskip("neo4j")
          from lightrag.kg.neo4j_impl import Neo4JStorage

          assert hasattr(Neo4JStorage.has_nodes_batch, "retry")
          assert hasattr(Neo4JStorage.upsert_nodes_batch, "retry")
          assert hasattr(Neo4JStorage.upsert_edges_batch, "retry")

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_missing_entity_nodes_created(self):
          """
          Nodes referenced in relationships but not in the entity list must
          be created as placeholder UNKNOWN nodes.
          """
          custom_kg = {
              "chunks": [
                  {"content": "text", "chunk_order_index": 0, "source_id": "s1"}
              ],
              "entities": [],  # No entities declared
              "relationships": [
                  {
                      "src_id": "ImplicitNode",
                      "tgt_id": "AnotherImplicit",
                      "description": "connects",
                      "keywords": "link",
                      "weight": 1.0,
                      "source_id": "s1",
                      "file_path": "test.pdf",
                  }
              ],
          }

          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  await rag.ainsert_custom_kg(custom_kg)

                  graph = rag.chunk_entity_relation_graph
                  assert await graph.has_node(
                      "ImplicitNode"
                  ), "Implicit node should be created"
                  assert await graph.has_node(
                      "AnotherImplicit"
                  ), "Implicit node should be created"
              finally:
                  await rag.finalize_storages()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_deduplicates_entities_and_undirected_edges(self):
          """Duplicate entities and reversed-direction edges collapse to the last one given."""
          custom_kg = {
              "chunks": [
                  {
                      "content": "chunk content",
                      "chunk_order_index": 0,
                      "source_id": "src-1",
                  }
              ],
              "entities": [
                  {
                      "entity_name": "EntityA",
                      "entity_type": "CONCEPT",
                      "description": "first version",
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
                  {
                      "entity_name": "EntityA",
                      "entity_type": "CONCEPT",
                      "description": "latest version",
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
              ],
              "relationships": [
                  {
                      "src_id": "EntityA",
                      "tgt_id": "EntityB",
                      "description": "old relation",
                      "keywords": "first",
                      "weight": 1.0,
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
                  {
                      "src_id": "EntityB",
                      "tgt_id": "EntityA",
                      "description": "latest relation",
                      "keywords": "second",
                      "weight": 2.0,
                      "source_id": "src-1",
                      "file_path": "test.pdf",
                  },
              ],
          }

          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  graph = rag.chunk_entity_relation_graph
                  # Fully mocked graph: only EntityA is reported as existing,
                  # so EntityB has to be added as a placeholder.
                  graph.upsert_nodes_batch = AsyncMock()
                  graph.has_nodes_batch = AsyncMock(return_value={"EntityA"})
                  graph.upsert_edges_batch = AsyncMock()

                  await rag.ainsert_custom_kg(custom_kg)

                  # First node batch: the declared entity, latest payload only.
                  entity_batch = graph.upsert_nodes_batch.await_args_list[0].args[0]
                  assert len(entity_batch) == 1
                  assert entity_batch[0][0] == "EntityA"
                  assert entity_batch[0][1]["entity_type"] == "CONCEPT"
                  assert entity_batch[0][1]["description"] == "latest version"
                  assert entity_batch[0][1]["file_path"] == "test.pdf"
                  assert entity_batch[0][1]["source_id"]

                  # Second node batch: the placeholder for the undeclared endpoint.
                  placeholder_batch = graph.upsert_nodes_batch.await_args_list[
                      1
                  ].args[0]
                  assert len(placeholder_batch) == 1
                  assert placeholder_batch[0][0] == "EntityB"

                  edge_batch = graph.upsert_edges_batch.await_args.args[0]
                  assert len(edge_batch) == 1
                  assert edge_batch[0][0] == "EntityB"
                  assert edge_batch[0][1] == "EntityA"
                  assert edge_batch[0][2]["description"] == "latest relation"
                  assert edge_batch[0][2]["weight"] == 2.0

                  entity_vdb_payload = rag.entities_vdb.upsert.await_args.args[0]
                  assert len(entity_vdb_payload) == 1
                  only_entity = next(iter(entity_vdb_payload.values()))
                  assert only_entity["description"] == "latest version"

                  rel_vdb_payload = rag.relationships_vdb.upsert.await_args.args[0]
                  assert len(rel_vdb_payload) == 1
                  only_rel = next(iter(rel_vdb_payload.values()))
                  assert only_rel["src_id"] == "EntityA"
                  assert only_rel["tgt_id"] == "EntityB"
                  assert only_rel["description"] == "latest relation"

                  assert rag.relationships_vdb.delete.await_args.args[0] == [
                      make_relation_vdb_ids("EntityA", "EntityB")[1]
                  ]
              finally:
                  await rag.finalize_storages()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_ainsert_custom_kg_keeps_legacy_relation_rows_if_upsert_fails(self):
          """A failing relationship upsert must propagate and must not trigger delete."""
          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  rag.relationships_vdb.upsert = AsyncMock(
                      side_effect=RuntimeError("boom")
                  )

                  with pytest.raises(RuntimeError, match="boom"):
                      await rag.ainsert_custom_kg(self._make_reversed_relation_kg())

                  rag.relationships_vdb.delete.assert_not_called()
              finally:
                  await rag.finalize_storages()

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_get_relation_info_falls_back_to_legacy_relation_vdb_id(self):
          """get_relation_info tries the normalized id first, then the legacy id."""
          with tempfile.TemporaryDirectory() as tmp:
              rag = await self._start_rag(tmp)
              try:
                  await rag.ainsert_custom_kg(self._make_reversed_relation_kg())

                  normalized_rel_id, legacy_rel_id = make_relation_vdb_ids(
                      "EntityA", "EntityB"
                  )
                  # Only the legacy id resolves, forcing the fallback lookup.
                  rag.relationships_vdb.get_by_id = AsyncMock(
                      side_effect=lambda rid: (
                          {"ok": True} if rid == legacy_rel_id else None
                      )
                  )

                  result_ab = await rag.get_relation_info(
                      "EntityA", "EntityB", include_vector_data=True
                  )
                  result_ba = await rag.get_relation_info(
                      "EntityB", "EntityA", include_vector_data=True
                  )

                  assert result_ab["vector_data"] == {"ok": True}
                  assert result_ba["vector_data"] == {"ok": True}
                  assert [
                      call.args[0]
                      for call in rag.relationships_vdb.get_by_id.await_args_list
                  ] == [
                      normalized_rel_id,
                      legacy_rel_id,
                      normalized_rel_id,
                      legacy_rel_id,
                  ]
              finally:
                  await rag.finalize_storages()
  class TestPostgresBatchOrdering:
      """
      Checks that PGGraphStorage's batch upserts collapse duplicate keys to the
      last payload given before delegating to the single-item upserts.

      The single-item methods are replaced by recording stubs on a bare
      instance (``__init__`` is skipped), so no database is contacted.
      """

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_preserves_last_write_wins(self):
          # Skip instead of failing when the Postgres driver is absent, in line
          # with the importorskip guards on the Mongo and Neo4j tests in this file.
          pytest.importorskip("asyncpg")
          from lightrag.kg.postgres_impl import PGGraphStorage

          storage = PGGraphStorage.__new__(PGGraphStorage)
          call_log: list[tuple[str, str]] = []

          async def spy(node_id, *, node_data):
              call_log.append((node_id, node_data["description"]))

          storage.upsert_node = spy  # type: ignore[assignment]

          await PGGraphStorage.upsert_nodes_batch(
              storage,
              [
                  ("EntityA", _make_node("EntityA")),
                  ("EntityA", dict(_make_node("EntityA"), description="latest")),
                  ("EntityB", _make_node("EntityB")),
              ],
          )

          # EntityA appears once, carrying the later of its two payloads.
          assert call_log == [
              ("EntityA", "latest"),
              ("EntityB", "Description of EntityB"),
          ]

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_preserves_last_write_wins(self):
          # Skip instead of failing when the Postgres driver is absent, in line
          # with the importorskip guards on the Mongo and Neo4j tests in this file.
          pytest.importorskip("asyncpg")
          from lightrag.kg.postgres_impl import PGGraphStorage

          storage = PGGraphStorage.__new__(PGGraphStorage)
          call_log: list[tuple[str, str, float]] = []

          async def spy(src, tgt, *, edge_data):
              call_log.append((src, tgt, edge_data["weight"]))

          storage.upsert_edge = spy  # type: ignore[assignment]

          await PGGraphStorage.upsert_edges_batch(
              storage,
              [
                  ("EntityA", "EntityB", _make_edge(1.0)),
                  ("EntityB", "EntityA", _make_edge(2.0)),
                  ("EntityB", "EntityC", _make_edge(3.0)),
              ],
          )

          # A->B and B->A count as the same edge; only the later one survives.
          assert call_log == [
              ("EntityB", "EntityA", 2.0),
              ("EntityB", "EntityC", 3.0),
          ]
  class TestMongoBatchOrdering:
      """
      Checks on the bulk_write calls MongoGraphStorage issues for batch upserts.

      Each test builds a bare instance (``__init__`` is skipped) with AsyncMock
      collections and inspects the arguments bulk_write was awaited with.
      """

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_nodes_batch_uses_ordered_bulk_write(self):
          pytest.importorskip("pymongo")
          from lightrag.kg.mongo_impl import MongoGraphStorage

          mongo = MongoGraphStorage.__new__(MongoGraphStorage)
          mongo.collection = AsyncMock()

          first_write = _make_node("EntityA")
          second_write = dict(_make_node("EntityA"), description="latest")
          await MongoGraphStorage.upsert_nodes_batch(
              mongo,
              [("EntityA", first_write), ("EntityA", second_write)],
          )

          bulk_kwargs = mongo.collection.bulk_write.await_args.kwargs
          assert bulk_kwargs["ordered"] is True

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_uses_ordered_bulk_write(self):
          pytest.importorskip("pymongo")
          from lightrag.kg.mongo_impl import MongoGraphStorage

          mongo = MongoGraphStorage.__new__(MongoGraphStorage)
          mongo.collection = AsyncMock()
          mongo.edge_collection = AsyncMock()

          forward = ("EntityA", "EntityB", _make_edge(1.0))
          backward = ("EntityB", "EntityA", _make_edge(2.0))
          await MongoGraphStorage.upsert_edges_batch(mongo, [forward, backward])

          bulk_kwargs = mongo.edge_collection.bulk_write.await_args.kwargs
          assert bulk_kwargs["ordered"] is True

      @pytest.mark.offline
      @pytest.mark.asyncio
      async def test_upsert_edges_batch_deduplicates_source_node_upserts(self):
          pytest.importorskip("pymongo")
          from lightrag.kg.mongo_impl import MongoGraphStorage

          mongo = MongoGraphStorage.__new__(MongoGraphStorage)
          mongo.collection = AsyncMock()
          mongo.edge_collection = AsyncMock()

          # Two edges sharing the source EntityA.
          to_b = ("EntityA", "EntityB", _make_edge(1.0))
          to_c = ("EntityA", "EntityC", _make_edge(2.0))
          await MongoGraphStorage.upsert_edges_batch(mongo, [to_b, to_c])

          # The node collection must receive a single operation, for EntityA.
          node_ops = mongo.collection.bulk_write.await_args.args[0]
          assert len(node_ops) == 1
          assert node_ops[0]._filter == {"_id": "EntityA"}