# test_mongo_deferred_embedding.py
  1. """Unit tests for MongoVectorDBStorage's deferred-embedding flush pipeline.
  2. All tests use mocks — no running MongoDB instance required.
  3. Mirrors the structure of tests/kg/opensearch_impl/test_opensearch_storage.py's
  4. TestVectorStorageBatching to keep behaviour aligned across backends.
  5. """
  6. import asyncio
  7. import os
  8. import numpy as np
  9. import pytest
  10. from unittest.mock import AsyncMock, MagicMock, patch
  11. pytest.importorskip(
  12. "pymongo",
  13. reason="pymongo is required for Mongo storage tests",
  14. )
  15. from pymongo import UpdateOne, DeleteOne # type: ignore
  16. from lightrag.kg.mongo_impl import MongoVectorDBStorage
  17. pytestmark = pytest.mark.offline
  18. # ---------------------------------------------------------------------------
  19. # Fixtures and helpers
  20. # ---------------------------------------------------------------------------
  21. class MockEmbeddingFunc:
  22. def __init__(self, dim=8):
  23. self.embedding_dim = dim
  24. self.max_token_size = 512
  25. self.model_name = "mock-embed"
  26. async def __call__(self, texts, **kwargs):
  27. return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  28. class CountingEmbeddingFunc(MockEmbeddingFunc):
  29. def __init__(self, dim=8, fail_times=0):
  30. super().__init__(dim=dim)
  31. self.fail_times = fail_times
  32. self.call_count = 0
  33. self.batches: list[list[str]] = []
  34. self.texts: list[str] = []
  35. async def __call__(self, texts, **kwargs):
  36. self.call_count += 1
  37. batch = list(texts)
  38. self.batches.append(batch)
  39. self.texts.extend(batch)
  40. if self.fail_times > 0:
  41. self.fail_times -= 1
  42. raise RuntimeError("embedding failed")
  43. return await super().__call__(texts, **kwargs)
  44. class _AsyncCursor:
  45. def __init__(self, docs):
  46. self._docs = list(docs)
  47. async def to_list(self, length=None):
  48. return list(self._docs)
  49. @pytest.fixture(autouse=True)
  50. def patch_namespace_lock(monkeypatch):
  51. """Cache real asyncio.Locks per (namespace, workspace) for shared semantics.
  52. Also unconditionally clears ``MONGODB_WORKSPACE`` so tests are insulated
  53. from shell-level env leakage: ``final_namespace`` depends on this var,
  54. and a leaked value (e.g. ``space2``) silently collapses distinct
  55. workspaces into one namespace. Tests that need an override set it
  56. explicitly via ``patch.dict(os.environ, ...)``.
  57. """
  58. monkeypatch.delenv("MONGODB_WORKSPACE", raising=False)
  59. cache: dict[tuple[str, str | None], asyncio.Lock] = {}
  60. def factory(namespace, workspace=None, enable_logging=False):
  61. key = (namespace, workspace or "")
  62. lock = cache.get(key)
  63. if lock is None:
  64. lock = asyncio.Lock()
  65. cache[key] = lock
  66. return lock
  67. with patch("lightrag.kg.mongo_impl.get_namespace_lock", side_effect=factory):
  68. yield cache
  69. def _make_storage(
  70. embed_func,
  71. *,
  72. namespace="entities",
  73. workspace="test",
  74. meta_fields=None,
  75. ):
  76. if meta_fields is None:
  77. meta_fields = {"content", "entity_name", "src_id", "tgt_id"}
  78. storage = MongoVectorDBStorage(
  79. namespace=namespace,
  80. workspace=workspace,
  81. global_config={
  82. "embedding_batch_num": 10,
  83. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  84. },
  85. embedding_func=embed_func,
  86. meta_fields=meta_fields,
  87. )
  88. # Wire a fake AsyncCollection (the only Mongo surface our code touches).
  89. storage._data = MagicMock()
  90. storage._data.bulk_write = AsyncMock()
  91. storage._data.delete_many = AsyncMock(return_value=MagicMock(deleted_count=0))
  92. storage._data.find_one = AsyncMock(return_value=None)
  93. storage._data.find = MagicMock(return_value=_AsyncCursor([]))
  94. storage.db = MagicMock() # non-None so finalize releases it
  95. from lightrag.kg.mongo_impl import get_namespace_lock
  96. storage._flush_lock = get_namespace_lock(
  97. namespace=storage.final_namespace, workspace=""
  98. )
  99. return storage
  100. # ---------------------------------------------------------------------------
  101. # Tests
  102. # ---------------------------------------------------------------------------
  103. @pytest.mark.asyncio
  104. async def test_upsert_buffers_without_embedding():
  105. embed = CountingEmbeddingFunc()
  106. s = _make_storage(embed)
  107. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  108. assert embed.call_count == 0
  109. assert set(s._pending_vector_docs.keys()) == {"v1", "v2"}
  110. assert s._pending_vector_docs["v1"].vector is None
  111. s._data.bulk_write.assert_not_called()
  112. @pytest.mark.asyncio
  113. async def test_index_done_callback_triggers_flush():
  114. embed = CountingEmbeddingFunc()
  115. s = _make_storage(embed)
  116. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  117. await s.index_done_callback()
  118. assert embed.call_count == 1
  119. s._data.bulk_write.assert_called_once()
  120. ops, kwargs = (
  121. s._data.bulk_write.call_args.args[0],
  122. s._data.bulk_write.call_args.kwargs,
  123. )
  124. assert kwargs.get("ordered") is False
  125. assert all(isinstance(op, UpdateOne) for op in ops)
  126. assert len(ops) == 2
  127. assert s._pending_vector_docs == {}
  128. @pytest.mark.asyncio
  129. async def test_repeated_upsert_same_id_embeds_once():
  130. embed = CountingEmbeddingFunc()
  131. s = _make_storage(embed)
  132. await s.upsert({"v1": {"content": "first"}})
  133. await s.upsert({"v1": {"content": "second"}})
  134. await s.upsert({"v1": {"content": "third"}})
  135. await s.index_done_callback()
  136. assert embed.call_count == 1
  137. assert embed.texts == ["third"]
  138. s._data.bulk_write.assert_called_once()
  139. @pytest.mark.asyncio
  140. async def test_deferred_embeddings_respect_batch_size():
  141. embed = CountingEmbeddingFunc()
  142. s = _make_storage(embed)
  143. s._max_batch_size = 2
  144. await s.upsert({f"v{i}": {"content": f"doc {i}"} for i in range(5)})
  145. await s.index_done_callback()
  146. assert embed.call_count == 3
  147. assert [len(b) for b in embed.batches] == [2, 2, 1]
  148. @pytest.mark.asyncio
  149. async def test_get_vectors_by_ids_lazy_embed_then_reuse_in_flush():
  150. embed = CountingEmbeddingFunc()
  151. s = _make_storage(embed)
  152. await s.upsert({"v1": {"content": "hello"}})
  153. vectors = await s.get_vectors_by_ids(["v1"])
  154. assert "v1" in vectors
  155. assert embed.call_count == 1
  156. assert s._pending_vector_docs["v1"].vector is not None
  157. await s.index_done_callback()
  158. assert embed.call_count == 1
  159. s._data.bulk_write.assert_called_once()
  160. @pytest.mark.asyncio
  161. async def test_flush_failure_keeps_buffer_no_double_embed_on_retry():
  162. embed = CountingEmbeddingFunc(fail_times=1)
  163. s = _make_storage(embed)
  164. await s.upsert({"v1": {"content": "hello"}})
  165. with pytest.raises(RuntimeError, match="embedding failed"):
  166. await s.index_done_callback()
  167. assert "v1" in s._pending_vector_docs
  168. assert s._pending_vector_docs["v1"].vector is None
  169. s._data.bulk_write.assert_not_called()
  170. await s.index_done_callback()
  171. assert embed.call_count == 2
  172. s._data.bulk_write.assert_called_once()
  173. assert s._pending_vector_docs == {}
  174. @pytest.mark.asyncio
  175. async def test_server_write_failure_keeps_buffer():
  176. embed = CountingEmbeddingFunc()
  177. s = _make_storage(embed)
  178. s._data.bulk_write.side_effect = RuntimeError("mongo down")
  179. await s.upsert({"v1": {"content": "hello"}})
  180. with pytest.raises(RuntimeError, match="mongo down"):
  181. await s.index_done_callback()
  182. assert "v1" in s._pending_vector_docs
  183. assert s._pending_vector_docs["v1"].vector is not None
  184. s._data.bulk_write.side_effect = None
  185. await s.index_done_callback()
  186. assert embed.call_count == 1
  187. assert s._pending_vector_docs == {}
  188. @pytest.mark.asyncio
  189. async def test_finalize_raises_when_buffer_unflushed_and_still_releases_client():
  190. """finalize() must release the Mongo client even when the flush fails."""
  191. from lightrag.kg.mongo_impl import ClientManager
  192. embed = CountingEmbeddingFunc()
  193. s = _make_storage(embed)
  194. s._data.bulk_write.side_effect = RuntimeError("mongo down")
  195. await s.upsert({"v1": {"content": "hello"}})
  196. with patch.object(ClientManager, "release_client", new=AsyncMock()) as rel:
  197. with pytest.raises(RuntimeError, match="finalize.*flush raised") as exc_info:
  198. await s.finalize()
  199. rel.assert_awaited_once()
  200. # Operator-diagnostic counts must appear in the message so the buffered
  201. # data loss is auditable from the log alone (1 upsert pre-loaded, 0 deletes).
  202. msg = str(exc_info.value)
  203. assert "1 pending upserts" in msg
  204. assert "0 pending deletes" in msg
  205. # Client references cleared so a second finalize doesn't release twice.
  206. assert s.db is None
  207. @pytest.mark.asyncio
  208. async def test_delete_then_upsert_same_id_keeps_upsert():
  209. embed = CountingEmbeddingFunc()
  210. s = _make_storage(embed)
  211. await s.delete(["v1"])
  212. assert "v1" in s._pending_vector_deletes
  213. await s.upsert({"v1": {"content": "hello"}})
  214. assert "v1" in s._pending_vector_docs
  215. assert "v1" not in s._pending_vector_deletes
  216. await s.index_done_callback()
  217. ops = s._data.bulk_write.call_args.args[0]
  218. assert all(isinstance(op, UpdateOne) for op in ops)
  219. @pytest.mark.asyncio
  220. async def test_upsert_then_delete_same_id_keeps_delete():
  221. embed = CountingEmbeddingFunc()
  222. s = _make_storage(embed)
  223. await s.upsert({"v1": {"content": "hello"}})
  224. await s.delete(["v1"])
  225. assert "v1" not in s._pending_vector_docs
  226. assert "v1" in s._pending_vector_deletes
  227. await s.index_done_callback()
  228. ops = s._data.bulk_write.call_args.args[0]
  229. assert len(ops) == 1
  230. assert isinstance(ops[0], DeleteOne)
  231. @pytest.mark.asyncio
  232. async def test_bulk_write_uses_update_one_and_delete_one_mix():
  233. embed = CountingEmbeddingFunc()
  234. s = _make_storage(embed)
  235. await s.upsert({"u1": {"content": "u1"}, "u2": {"content": "u2"}})
  236. await s.delete(["d1", "d2"])
  237. await s.index_done_callback()
  238. ops = s._data.bulk_write.call_args.args[0]
  239. op_types = {type(op).__name__ for op in ops}
  240. assert op_types == {"UpdateOne", "DeleteOne"}
  241. assert sum(isinstance(op, UpdateOne) for op in ops) == 2
  242. assert sum(isinstance(op, DeleteOne) for op in ops) == 2
  243. @pytest.mark.asyncio
  244. async def test_delete_entity_relation_raises_on_server_failure():
  245. embed = CountingEmbeddingFunc()
  246. s = _make_storage(embed)
  247. s._data.find = MagicMock(
  248. return_value=_AsyncCursor([{"_id": "rel1"}, {"_id": "rel2"}])
  249. )
  250. s._data.delete_many = AsyncMock(side_effect=RuntimeError("mongo delete failed"))
  251. with pytest.raises(RuntimeError, match="mongo delete failed"):
  252. await s.delete_entity_relation("X")
  253. @pytest.mark.asyncio
  254. async def test_delete_entity_relation_prunes_pending_buffer():
  255. embed = CountingEmbeddingFunc()
  256. s = _make_storage(embed)
  257. await s.upsert(
  258. {
  259. "rel-A-B": {"content": "A → B", "src_id": "A", "tgt_id": "B"},
  260. "rel-C-D": {"content": "C → D", "src_id": "C", "tgt_id": "D"},
  261. }
  262. )
  263. s._data.find = MagicMock(return_value=_AsyncCursor([]))
  264. await s.delete_entity_relation("A")
  265. assert "rel-A-B" not in s._pending_vector_docs
  266. assert "rel-C-D" in s._pending_vector_docs
  267. @pytest.mark.asyncio
  268. async def test_get_by_id_buffer_excludes_vector():
  269. embed = CountingEmbeddingFunc()
  270. s = _make_storage(embed)
  271. await s.upsert({"v1": {"content": "hello", "entity_name": "E1"}})
  272. doc = await s.get_by_id("v1")
  273. assert doc is not None
  274. assert doc["id"] == "v1"
  275. assert doc.get("entity_name") == "E1"
  276. assert "vector" not in doc
  277. s._data.find_one.assert_not_called()
  278. @pytest.mark.asyncio
  279. async def test_get_by_id_fallback_projects_out_vector():
  280. """Server-side find_one must request projection={'vector': 0}."""
  281. embed = CountingEmbeddingFunc()
  282. s = _make_storage(embed)
  283. s._data.find_one = AsyncMock(
  284. return_value={"_id": "v9", "entity_name": "X", "created_at": 0}
  285. )
  286. doc = await s.get_by_id("v9")
  287. assert doc is not None
  288. assert "vector" not in doc
  289. args, kwargs = s._data.find_one.call_args.args, s._data.find_one.call_args.kwargs
  290. # projection is positional arg #2 in Mongo's API.
  291. projection = args[1] if len(args) > 1 else kwargs.get("projection")
  292. assert projection == {"vector": 0}
  293. @pytest.mark.asyncio
  294. async def test_get_by_ids_fallback_projects_out_vector():
  295. embed = CountingEmbeddingFunc()
  296. s = _make_storage(embed)
  297. s._data.find = MagicMock(
  298. return_value=_AsyncCursor(
  299. [{"_id": "a", "entity_name": "A"}, {"_id": "b", "entity_name": "B"}]
  300. )
  301. )
  302. docs = await s.get_by_ids(["a", "b"])
  303. assert len(docs) == 2
  304. assert all("vector" not in d for d in docs if d)
  305. args, kwargs = s._data.find.call_args.args, s._data.find.call_args.kwargs
  306. projection = args[1] if len(args) > 1 else kwargs.get("projection")
  307. assert projection == {"vector": 0}
  308. @pytest.mark.asyncio
  309. async def test_get_by_id_returns_none_for_pending_delete():
  310. embed = CountingEmbeddingFunc()
  311. s = _make_storage(embed)
  312. await s.delete(["v1"])
  313. assert await s.get_by_id("v1") is None
  314. s._data.find_one.assert_not_called()
  315. @pytest.mark.asyncio
  316. async def test_env_workspace_override_shares_flush_lock(patch_namespace_lock):
  317. cache = patch_namespace_lock
  318. embed = CountingEmbeddingFunc()
  319. with patch.dict(os.environ, {"MONGODB_WORKSPACE": "shared_ws"}, clear=False):
  320. a = _make_storage(embed, workspace="caller_a")
  321. b = _make_storage(embed, workspace="caller_b")
  322. assert a.final_namespace == b.final_namespace == "shared_ws_entities"
  323. assert a._flush_lock is b._flush_lock
  324. assert len([k for k in cache if k[0] == "shared_ws_entities"]) == 1
  325. @pytest.mark.asyncio
  326. async def test_same_workspace_param_shares_flush_lock(patch_namespace_lock):
  327. """Plain ctor path (no MONGODB_WORKSPACE env): same workspace → shared lock.
  328. Companion to ``test_env_workspace_override_shares_flush_lock``; together
  329. they cover both ways two instances can land on the same final_namespace.
  330. The autouse fixture clears MONGODB_WORKSPACE so this exercises the plain
  331. constructor path, not the env-override path.
  332. """
  333. cache = patch_namespace_lock
  334. embed = CountingEmbeddingFunc()
  335. a = _make_storage(embed, workspace="caller")
  336. b = _make_storage(embed, workspace="caller")
  337. assert a.final_namespace == b.final_namespace == "caller_entities"
  338. assert a._flush_lock is b._flush_lock
  339. assert len([k for k in cache if k[0] == "caller_entities"]) == 1
  340. @pytest.mark.asyncio
  341. async def test_distinct_namespaces_get_independent_locks():
  342. embed = CountingEmbeddingFunc()
  343. a = _make_storage(embed, workspace="a")
  344. b = _make_storage(embed, workspace="b")
  345. assert a.final_namespace != b.final_namespace
  346. assert a._flush_lock is not b._flush_lock
  347. @pytest.mark.asyncio
  348. async def test_concurrent_upsert_and_flush_serialize_on_lock():
  349. """upsert() and index_done_callback() racing on the same namespace must
  350. not corrupt the buffer or split a single doc across two embed calls.
  351. Drives a slow embed (asyncio.Event-gated) so the flush genuinely holds
  352. the lock while a second coroutine attempts upsert mid-flight. Asserts:
  353. - the late upsert lands in the buffer (not silently dropped)
  354. - it is *not* embedded by the in-flight flush (still pending after)
  355. - a follow-up flush picks it up cleanly with exactly one extra embed
  356. Note: we replace ``s.embedding_func`` directly (not ``patch.object`` on
  357. ``embed.__call__``) because Python dispatches ``embed(...)`` through
  358. ``type(embed).__call__``, bypassing any instance-level patch.
  359. """
  360. embed = CountingEmbeddingFunc()
  361. s = _make_storage(embed)
  362. embed_gate = asyncio.Event()
  363. flush_entered = asyncio.Event()
  364. original_embed = s.embedding_func
  365. async def gated_embed(texts, **kwargs):
  366. flush_entered.set()
  367. await embed_gate.wait()
  368. return await original_embed(texts, **kwargs)
  369. await s.upsert({"v1": {"content": "first"}})
  370. s.embedding_func = gated_embed
  371. try:
  372. flush_task = asyncio.create_task(s.index_done_callback())
  373. await flush_entered.wait() # flush is now holding _flush_lock
  374. # This upsert must wait on _flush_lock; schedule it concurrently.
  375. late_upsert = asyncio.create_task(s.upsert({"v2": {"content": "late"}}))
  376. # Give the event loop a chance to actually start late_upsert and
  377. # confirm it is blocked on the lock (still no v2 in buffer).
  378. for _ in range(5):
  379. await asyncio.sleep(0)
  380. assert "v2" not in s._pending_vector_docs
  381. assert not late_upsert.done()
  382. embed_gate.set()
  383. await flush_task
  384. await late_upsert
  385. finally:
  386. s.embedding_func = original_embed
  387. # Flush embedded v1 only; v2 arrived after the docs_to_embed snapshot.
  388. assert embed.call_count == 1
  389. assert embed.batches == [["first"]]
  390. assert "v1" not in s._pending_vector_docs # flushed
  391. assert "v2" in s._pending_vector_docs # still buffered
  392. s._data.bulk_write.assert_called_once()
  393. # Next flush picks up the late upsert without re-embedding v1.
  394. await s.index_done_callback()
  395. assert embed.call_count == 2
  396. assert embed.batches[-1] == ["late"]
  397. assert s._pending_vector_docs == {}
  398. assert s._data.bulk_write.call_count == 2
  399. @pytest.mark.asyncio
  400. async def test_drop_clears_pending_buffers():
  401. embed = CountingEmbeddingFunc()
  402. s = _make_storage(embed)
  403. with patch.object(s, "create_vector_index_if_not_exists", new=AsyncMock()):
  404. await s.upsert({"v1": {"content": "hello"}})
  405. await s.delete(["v2"])
  406. assert s._pending_vector_docs and s._pending_vector_deletes
  407. result = await s.drop()
  408. assert result["status"] == "success"
  409. assert s._pending_vector_docs == {}
  410. assert s._pending_vector_deletes == set()