test_milvus_deferred_embedding.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. """Unit tests for MilvusVectorDBStorage's deferred-embedding flush pipeline.
  2. All tests use mocks — no running Milvus instance required.
  3. Mirrors the structure of tests/kg/opensearch_impl/test_opensearch_storage.py's
  4. TestVectorStorageBatching to keep behaviour aligned across backends.
  5. """
  6. import asyncio
  7. import os
  8. import numpy as np
  9. import pytest
  10. from unittest.mock import MagicMock, patch
  11. from lightrag.kg.milvus_impl import MilvusVectorDBStorage
  12. pytestmark = pytest.mark.offline
  13. # ---------------------------------------------------------------------------
  14. # Fixtures
  15. # ---------------------------------------------------------------------------
  16. class MockEmbeddingFunc:
  17. """Mock embedding function that returns random vectors."""
  18. def __init__(self, dim=8):
  19. self.embedding_dim = dim
  20. self.max_token_size = 512
  21. self.model_name = "mock-embed"
  22. async def __call__(self, texts, **kwargs):
  23. return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  24. class CountingEmbeddingFunc(MockEmbeddingFunc):
  25. """Embedding test double that records calls and can fail a fixed number of times."""
  26. def __init__(self, dim=8, fail_times=0):
  27. super().__init__(dim=dim)
  28. self.fail_times = fail_times
  29. self.call_count = 0
  30. self.batches: list[list[str]] = []
  31. self.texts: list[str] = []
  32. async def __call__(self, texts, **kwargs):
  33. self.call_count += 1
  34. batch = list(texts)
  35. self.batches.append(batch)
  36. self.texts.extend(batch)
  37. if self.fail_times > 0:
  38. self.fail_times -= 1
  39. raise RuntimeError("embedding failed")
  40. return await super().__call__(texts, **kwargs)
  41. @pytest.fixture(autouse=True)
  42. def patch_namespace_lock():
  43. """Cache real asyncio.Locks per (namespace, workspace) for shared semantics.
  44. Two storage instances whose ``final_namespace`` matches must observe the
  45. same Lock instance — this fixture lets us assert that and also exercises
  46. real serialization between concurrent flush/upsert coroutines.
  47. """
  48. cache: dict[tuple[str, str | None], asyncio.Lock] = {}
  49. def factory(namespace, workspace=None, enable_logging=False):
  50. key = (namespace, workspace or "")
  51. lock = cache.get(key)
  52. if lock is None:
  53. lock = asyncio.Lock()
  54. cache[key] = lock
  55. return lock
  56. with patch("lightrag.kg.milvus_impl.get_namespace_lock", side_effect=factory):
  57. yield cache
  58. def _make_storage(
  59. embed_func,
  60. *,
  61. namespace="entities",
  62. workspace="test",
  63. meta_fields=None,
  64. ):
  65. """Build a MilvusVectorDBStorage skipping `initialize()` (no real client)."""
  66. if meta_fields is None:
  67. meta_fields = {"content", "entity_name", "src_id", "tgt_id"}
  68. storage = MilvusVectorDBStorage(
  69. namespace=namespace,
  70. workspace=workspace,
  71. global_config={
  72. "embedding_batch_num": 10,
  73. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  74. },
  75. embedding_func=embed_func,
  76. meta_fields=meta_fields,
  77. )
  78. # Bypass real Milvus client; manually wire the bits initialize() would set.
  79. # The flush lock is already constructed in __post_init__ via the patched
  80. # get_namespace_lock factory, so no manual lock wiring is needed here.
  81. storage._client = MagicMock()
  82. storage._client.has_collection.return_value = True
  83. storage._client.upsert = MagicMock(return_value={"upsert_count": 0})
  84. storage._client.delete = MagicMock(return_value={"delete_count": 0})
  85. storage._client.query = MagicMock(return_value=[])
  86. storage._client.load_collection = MagicMock()
  87. storage._initialized = True
  88. return storage
  89. # ---------------------------------------------------------------------------
  90. # Tests: deferred embedding + batched flush
  91. # ---------------------------------------------------------------------------
  92. @pytest.mark.asyncio
  93. async def test_upsert_buffers_without_embedding():
  94. embed = CountingEmbeddingFunc()
  95. s = _make_storage(embed)
  96. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  97. assert embed.call_count == 0
  98. assert set(s._pending_vector_docs.keys()) == {"v1", "v2"}
  99. assert s._pending_vector_docs["v1"].vector is None
  100. assert s._pending_vector_docs["v2"].vector is None
  101. s._client.upsert.assert_not_called()
  102. @pytest.mark.asyncio
  103. async def test_index_done_callback_triggers_flush():
  104. embed = CountingEmbeddingFunc()
  105. s = _make_storage(embed)
  106. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  107. await s.index_done_callback()
  108. assert embed.call_count == 1
  109. s._client.upsert.assert_called_once()
  110. call_kwargs = s._client.upsert.call_args.kwargs
  111. assert call_kwargs["collection_name"] == s.final_namespace
  112. upserted = call_kwargs["data"]
  113. assert {row["id"] for row in upserted} == {"v1", "v2"}
  114. assert all("vector" in row for row in upserted)
  115. # Buffers cleared after a successful flush.
  116. assert s._pending_vector_docs == {}
  117. assert s._pending_vector_deletes == set()
  118. @pytest.mark.asyncio
  119. async def test_repeated_upsert_same_id_embeds_once():
  120. embed = CountingEmbeddingFunc()
  121. s = _make_storage(embed)
  122. await s.upsert({"v1": {"content": "first"}})
  123. await s.upsert({"v1": {"content": "second"}})
  124. await s.upsert({"v1": {"content": "third"}})
  125. await s.index_done_callback()
  126. assert embed.call_count == 1
  127. # Only the latest content survives and was embedded.
  128. assert embed.texts == ["third"]
  129. s._client.upsert.assert_called_once()
  130. @pytest.mark.asyncio
  131. async def test_deferred_embeddings_respect_batch_size():
  132. embed = CountingEmbeddingFunc()
  133. s = _make_storage(embed)
  134. s._max_batch_size = 2
  135. await s.upsert({f"v{i}": {"content": f"doc {i}"} for i in range(5)})
  136. await s.index_done_callback()
  137. # 5 docs / batch 2 → 3 batches → 3 embedding calls
  138. assert embed.call_count == 3
  139. assert [len(b) for b in embed.batches] == [2, 2, 1]
  140. @pytest.mark.asyncio
  141. async def test_get_vectors_by_ids_lazy_embed_then_reuse_in_flush():
  142. embed = CountingEmbeddingFunc()
  143. s = _make_storage(embed)
  144. await s.upsert({"v1": {"content": "hello"}})
  145. vectors = await s.get_vectors_by_ids(["v1"])
  146. assert "v1" in vectors
  147. assert embed.call_count == 1 # lazy embed inside get_vectors_by_ids
  148. # The lazy-embedded vector is cached on the pending doc.
  149. assert s._pending_vector_docs["v1"].vector is not None
  150. await s.index_done_callback()
  151. # Flush reused the cached vector — no extra embedding call.
  152. assert embed.call_count == 1
  153. s._client.upsert.assert_called_once()
  154. @pytest.mark.asyncio
  155. async def test_flush_failure_keeps_buffer_and_no_double_embed_on_retry():
  156. embed = CountingEmbeddingFunc(fail_times=1) # first flush raises
  157. s = _make_storage(embed)
  158. await s.upsert({"v1": {"content": "hello"}})
  159. with pytest.raises(RuntimeError, match="embedding failed"):
  160. await s.index_done_callback()
  161. # Buffer must remain so the next flush can retry.
  162. assert "v1" in s._pending_vector_docs
  163. assert s._pending_vector_docs["v1"].vector is None
  164. s._client.upsert.assert_not_called()
  165. # Second attempt succeeds; total embed calls is 2 (one failed + one ok),
  166. # not 3 — the same content was retried exactly once.
  167. await s.index_done_callback()
  168. assert embed.call_count == 2
  169. s._client.upsert.assert_called_once()
  170. assert s._pending_vector_docs == {}
  171. @pytest.mark.asyncio
  172. async def test_server_upsert_failure_keeps_buffer():
  173. embed = CountingEmbeddingFunc()
  174. s = _make_storage(embed)
  175. s._client.upsert.side_effect = RuntimeError("milvus down")
  176. await s.upsert({"v1": {"content": "hello"}})
  177. with pytest.raises(RuntimeError, match="milvus down"):
  178. await s.index_done_callback()
  179. # Embedding ran but server write failed; buffer must remain populated.
  180. assert "v1" in s._pending_vector_docs
  181. # Vector should be cached so retry doesn't re-embed.
  182. assert s._pending_vector_docs["v1"].vector is not None
  183. # On retry, no further embedding call; only the server write is reattempted.
  184. s._client.upsert.side_effect = None
  185. s._client.upsert.return_value = {"upsert_count": 1}
  186. await s.index_done_callback()
  187. assert embed.call_count == 1
  188. assert s._pending_vector_docs == {}
  189. @pytest.mark.asyncio
  190. async def test_finalize_raises_when_buffer_unflushed():
  191. embed = CountingEmbeddingFunc()
  192. s = _make_storage(embed)
  193. s._client.upsert.side_effect = RuntimeError("transient milvus error")
  194. await s.upsert({"v1": {"content": "hello"}})
  195. with pytest.raises(RuntimeError, match="finalize.*flush raised"):
  196. await s.finalize()
  197. # Buffer still populated — caller knows data was lost.
  198. assert "v1" in s._pending_vector_docs
  199. @pytest.mark.asyncio
  200. async def test_delete_then_upsert_same_id_keeps_upsert():
  201. embed = CountingEmbeddingFunc()
  202. s = _make_storage(embed)
  203. await s.delete(["v1"])
  204. assert "v1" in s._pending_vector_deletes
  205. await s.upsert({"v1": {"content": "hello"}})
  206. assert "v1" in s._pending_vector_docs
  207. assert "v1" not in s._pending_vector_deletes
  208. await s.index_done_callback()
  209. s._client.upsert.assert_called_once()
  210. s._client.delete.assert_not_called()
  211. @pytest.mark.asyncio
  212. async def test_upsert_then_delete_same_id_keeps_delete():
  213. embed = CountingEmbeddingFunc()
  214. s = _make_storage(embed)
  215. await s.upsert({"v1": {"content": "hello"}})
  216. await s.delete(["v1"])
  217. assert "v1" not in s._pending_vector_docs
  218. assert "v1" in s._pending_vector_deletes
  219. await s.index_done_callback()
  220. # No upsert payload, only the delete batch.
  221. s._client.upsert.assert_not_called()
  222. s._client.delete.assert_called_once()
  223. assert s._client.delete.call_args.kwargs["pks"] == ["v1"]
  224. @pytest.mark.asyncio
  225. async def test_delete_entity_relation_raises_on_server_failure():
  226. """Server-side failure must bubble up — no log-and-swallow."""
  227. embed = CountingEmbeddingFunc()
  228. s = _make_storage(embed)
  229. s._client.query.return_value = [{"id": "rel1"}, {"id": "rel2"}]
  230. s._client.delete.side_effect = RuntimeError("milvus delete failed")
  231. with pytest.raises(RuntimeError, match="milvus delete failed"):
  232. await s.delete_entity_relation("X")
  233. @pytest.mark.asyncio
  234. async def test_delete_entity_relation_prunes_pending_buffer():
  235. embed = CountingEmbeddingFunc()
  236. s = _make_storage(embed)
  237. await s.upsert(
  238. {
  239. "rel-A-B": {"content": "A → B", "src_id": "A", "tgt_id": "B"},
  240. "rel-C-D": {"content": "C → D", "src_id": "C", "tgt_id": "D"},
  241. }
  242. )
  243. s._client.query.return_value = [] # no server-side hits
  244. await s.delete_entity_relation("A")
  245. # Pending doc whose src_id == A is pruned, the other survives.
  246. assert "rel-A-B" not in s._pending_vector_docs
  247. assert "rel-C-D" in s._pending_vector_docs
  248. @pytest.mark.asyncio
  249. async def test_delete_entity_relation_diverges_when_buffer_overwrites_persisted():
  250. """Pins the deferred ↔ eager semantic divergence documented on
  251. ``delete_entity_relation``.
  252. Scenario: a persisted row ``rel-X-Y`` has ``src_id="X" / tgt_id="Y"``,
  253. and a pending upsert is about to rewrite that same id so it would
  254. instead carry ``src_id="A" / tgt_id="B"``. A call to
  255. ``delete_entity_relation("A")`` arrives before the buffer is flushed.
  256. Expected (deferred mode, current implementation):
  257. * server-side filter ``src_id == "A" or tgt_id == "A"`` does NOT
  258. match the persisted row (its src/tgt are still X/Y), so the
  259. server-side delete is a no-op;
  260. * the buffered upsert IS pruned (its buffered src/tgt match);
  261. * net effect: persisted ``rel-X-Y`` (old values) survives and the
  262. pending overwrite is lost.
  263. Under eager ordering (upsert → flush → delete) the persisted row
  264. would have been rewritten first and then matched by the filter, so
  265. the final state would have been a deleted ``rel-X-Y``. This test
  266. locks in the divergence so a future refactor can't silently change
  267. it without touching the docstring.
  268. """
  269. embed = CountingEmbeddingFunc()
  270. s = _make_storage(embed)
  271. # Buffered upsert rewriting an (assumed) already-persisted rel-X-Y
  272. # so that its new src/tgt would match entity "A".
  273. await s.upsert({"rel-X-Y": {"content": "A → B", "src_id": "A", "tgt_id": "B"}})
  274. assert "rel-X-Y" in s._pending_vector_docs
  275. # Server still sees the OLD persisted row (src_id="X" / tgt_id="Y"),
  276. # so a filter on entity "A" returns nothing.
  277. s._client.query.return_value = []
  278. await s.delete_entity_relation("A")
  279. # Buffered overwrite is pruned (matches buffered src/tgt view) …
  280. assert "rel-X-Y" not in s._pending_vector_docs
  281. # … but the server-side delete is not issued, because the filter
  282. # didn't match the persisted row's actual src/tgt.
  283. s._client.delete.assert_not_called()
  284. @pytest.mark.asyncio
  285. async def test_delete_entity_relation_eager_ordering_matches_persisted():
  286. """Counterpart to the divergence test: if the caller flushes before
  287. invoking ``delete_entity_relation``, the persisted row reflects the
  288. buffered overwrite and the server-side filter catches it.
  289. This documents the recommended workaround called out in the
  290. ``delete_entity_relation`` docstring: ``index_done_callback()`` first
  291. when eager-equivalent semantics are required.
  292. """
  293. embed = CountingEmbeddingFunc()
  294. s = _make_storage(embed)
  295. await s.upsert({"rel-X-Y": {"content": "A → B", "src_id": "A", "tgt_id": "B"}})
  296. await s.index_done_callback() # buffered upsert is now persisted
  297. assert s._pending_vector_docs == {}
  298. s._client.upsert.assert_called_once()
  299. # With the row persisted, the server filter on entity "A" now hits.
  300. s._client.query.return_value = [{"id": "rel-X-Y"}]
  301. await s.delete_entity_relation("A")
  302. s._client.delete.assert_called_once()
  303. assert s._client.delete.call_args.kwargs["pks"] == ["rel-X-Y"]
  304. @pytest.mark.asyncio
  305. async def test_get_by_id_reads_pending_buffer_without_vector():
  306. embed = CountingEmbeddingFunc()
  307. s = _make_storage(embed)
  308. await s.upsert({"v1": {"content": "hello", "entity_name": "E1"}})
  309. doc = await s.get_by_id("v1")
  310. assert doc is not None
  311. assert doc["id"] == "v1"
  312. assert doc.get("entity_name") == "E1"
  313. assert "vector" not in doc
  314. # Server was not queried because the buffer answered the read.
  315. s._client.query.assert_not_called()
  316. @pytest.mark.asyncio
  317. async def test_get_by_id_returns_none_for_pending_delete():
  318. embed = CountingEmbeddingFunc()
  319. s = _make_storage(embed)
  320. await s.delete(["v1"])
  321. assert await s.get_by_id("v1") is None
  322. s._client.query.assert_not_called()
  323. @pytest.mark.asyncio
  324. async def test_env_workspace_override_shares_flush_lock(patch_namespace_lock):
  325. """Two instances whose final_namespace collides must share the flush lock."""
  326. cache = patch_namespace_lock
  327. embed = CountingEmbeddingFunc()
  328. with patch.dict(os.environ, {"MILVUS_WORKSPACE": "shared_ws"}, clear=False):
  329. a = _make_storage(embed, workspace="caller_a")
  330. b = _make_storage(embed, workspace="caller_b")
  331. assert a.final_namespace == b.final_namespace == "shared_ws_entities"
  332. assert a._flush_lock is b._flush_lock
  333. # Sanity: only one lock object was cached for that final_namespace.
  334. assert len([k for k in cache if k[0] == "shared_ws_entities"]) == 1
  335. @pytest.mark.asyncio
  336. async def test_distinct_namespaces_get_independent_locks(patch_namespace_lock):
  337. """Different final_namespaces must NOT share a lock."""
  338. embed = CountingEmbeddingFunc()
  339. # Two instances with no env override and different workspaces produce
  340. # different final_namespaces ("a_entities" vs "b_entities").
  341. a = _make_storage(embed, workspace="a")
  342. b = _make_storage(embed, workspace="b")
  343. assert a.final_namespace != b.final_namespace
  344. assert a._flush_lock is not b._flush_lock
  345. @pytest.mark.asyncio
  346. async def test_mixed_upsert_and_delete_in_single_flush():
  347. """A flush carrying both pending upserts and pending deletes (on disjoint
  348. ids) must dispatch one server upsert and one server delete in a single
  349. pass, then clear both buffers."""
  350. embed = CountingEmbeddingFunc()
  351. s = _make_storage(embed)
  352. await s.upsert({"a": {"content": "alpha"}})
  353. await s.delete(["b"])
  354. assert set(s._pending_vector_docs.keys()) == {"a"}
  355. assert s._pending_vector_deletes == {"b"}
  356. await s.index_done_callback()
  357. s._client.upsert.assert_called_once()
  358. upsert_kwargs = s._client.upsert.call_args.kwargs
  359. assert {row["id"] for row in upsert_kwargs["data"]} == {"a"}
  360. s._client.delete.assert_called_once()
  361. assert s._client.delete.call_args.kwargs["pks"] == ["b"]
  362. # Both buffers cleared after a successful flush.
  363. assert s._pending_vector_docs == {}
  364. assert s._pending_vector_deletes == set()
  365. @pytest.mark.asyncio
  366. async def test_finalize_clean_flush_no_raise():
  367. """Happy-path counterpart to test_finalize_raises_when_buffer_unflushed:
  368. a successful flush during finalize() must leave both buffers empty and
  369. must not raise."""
  370. embed = CountingEmbeddingFunc()
  371. s = _make_storage(embed)
  372. await s.upsert({"v1": {"content": "hello"}})
  373. await s.delete(["v2"])
  374. await s.finalize() # must not raise
  375. s._client.upsert.assert_called_once()
  376. s._client.delete.assert_called_once()
  377. assert s._pending_vector_docs == {}
  378. assert s._pending_vector_deletes == set()
  379. @pytest.mark.asyncio
  380. async def test_drop_clears_pending_buffers():
  381. embed = CountingEmbeddingFunc()
  382. s = _make_storage(embed)
  383. s._client.has_collection.return_value = False # skip drop_collection call
  384. # Stub out _create_collection_if_not_exist to avoid hitting MilvusIndexConfig logic.
  385. with patch.object(s, "_create_collection_if_not_exist"):
  386. await s.upsert({"v1": {"content": "hello"}})
  387. await s.delete(["v2"])
  388. assert s._pending_vector_docs and s._pending_vector_deletes
  389. result = await s.drop()
  390. assert result["status"] == "success"
  391. assert s._pending_vector_docs == {}
  392. assert s._pending_vector_deletes == set()