test_qdrant_deferred_embedding.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. """Unit tests for QdrantVectorDBStorage's deferred-embedding flush pipeline.
  2. All tests use mocks — no running Qdrant instance required.
  3. Mirrors the structure of tests/kg/opensearch_impl/test_opensearch_storage.py's
  4. TestVectorStorageBatching to keep behaviour aligned across backends.
  5. """
  6. import asyncio
  7. import os
  8. import numpy as np
  9. import pytest
  10. from unittest.mock import MagicMock, patch
  11. pytest.importorskip(
  12. "qdrant_client",
  13. reason="qdrant-client is required for Qdrant storage tests",
  14. )
  15. from lightrag.kg.qdrant_impl import ( # noqa: E402
  16. QdrantVectorDBStorage,
  17. compute_mdhash_id_for_qdrant,
  18. )
  19. pytestmark = pytest.mark.offline
  20. # ---------------------------------------------------------------------------
  21. # Fixtures and helpers
  22. # ---------------------------------------------------------------------------
  23. class MockEmbeddingFunc:
  24. def __init__(self, dim=8):
  25. self.embedding_dim = dim
  26. self.max_token_size = 512
  27. self.model_name = "mock-embed"
  28. async def __call__(self, texts, **kwargs):
  29. return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  30. class CountingEmbeddingFunc(MockEmbeddingFunc):
  31. def __init__(self, dim=8, fail_times=0):
  32. super().__init__(dim=dim)
  33. self.fail_times = fail_times
  34. self.call_count = 0
  35. self.batches: list[list[str]] = []
  36. self.texts: list[str] = []
  37. async def __call__(self, texts, **kwargs):
  38. self.call_count += 1
  39. batch = list(texts)
  40. self.batches.append(batch)
  41. self.texts.extend(batch)
  42. if self.fail_times > 0:
  43. self.fail_times -= 1
  44. raise RuntimeError("embedding failed")
  45. return await super().__call__(texts, **kwargs)
  46. @pytest.fixture(autouse=True)
  47. def patch_namespace_lock():
  48. """Cache real asyncio.Locks per (namespace, workspace) for shared semantics."""
  49. cache: dict[tuple[str, str | None], asyncio.Lock] = {}
  50. def factory(namespace, workspace=None, enable_logging=False):
  51. key = (namespace, workspace or "")
  52. lock = cache.get(key)
  53. if lock is None:
  54. lock = asyncio.Lock()
  55. cache[key] = lock
  56. return lock
  57. with patch("lightrag.kg.qdrant_impl.get_namespace_lock", side_effect=factory):
  58. yield cache
  59. def _make_storage(
  60. embed_func,
  61. *,
  62. namespace="entities",
  63. workspace="test_ws",
  64. meta_fields=None,
  65. ):
  66. if meta_fields is None:
  67. meta_fields = {"content", "entity_name", "src_id", "tgt_id"}
  68. # Bypass real initialization paths (e.g. model suffix generation),
  69. # mirroring the existing pattern in test_qdrant_upsert_batching.py.
  70. storage = QdrantVectorDBStorage.__new__(QdrantVectorDBStorage)
  71. storage.workspace = workspace
  72. storage.namespace = namespace
  73. storage.effective_workspace = workspace
  74. storage.final_namespace = f"lightrag_vdb_{namespace}_mock"
  75. storage.meta_fields = meta_fields
  76. storage.embedding_func = embed_func
  77. storage._max_batch_size = 10
  78. storage._max_upsert_payload_bytes = 16 * 1024 * 1024
  79. storage._max_upsert_points_per_batch = 128
  80. storage._pending_vector_docs = {}
  81. storage._pending_vector_deletes = set()
  82. storage._client = MagicMock()
  83. storage._client.upsert = MagicMock()
  84. storage._client.delete = MagicMock()
  85. storage._client.retrieve = MagicMock(return_value=[])
  86. storage._client.scroll = MagicMock(return_value=([], None))
  87. from lightrag.kg.qdrant_impl import get_namespace_lock
  88. storage._flush_lock = get_namespace_lock(
  89. namespace=storage.final_namespace, workspace=storage.effective_workspace
  90. )
  91. return storage
  92. # ---------------------------------------------------------------------------
  93. # Tests
  94. # ---------------------------------------------------------------------------
  95. @pytest.mark.asyncio
  96. async def test_upsert_buffers_without_embedding():
  97. embed = CountingEmbeddingFunc()
  98. s = _make_storage(embed)
  99. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  100. assert embed.call_count == 0
  101. assert set(s._pending_vector_docs.keys()) == {"v1", "v2"}
  102. assert s._pending_vector_docs["v1"].vector is None
  103. s._client.upsert.assert_not_called()
  104. @pytest.mark.asyncio
  105. async def test_index_done_callback_triggers_flush():
  106. embed = CountingEmbeddingFunc()
  107. s = _make_storage(embed)
  108. await s.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})
  109. await s.index_done_callback()
  110. assert embed.call_count == 1
  111. s._client.upsert.assert_called_once()
  112. kwargs = s._client.upsert.call_args.kwargs
  113. assert kwargs["collection_name"] == s.final_namespace
  114. points = kwargs["points"]
  115. assert len(points) == 2
  116. expected_ids = {
  117. compute_mdhash_id_for_qdrant("v1", prefix=s.effective_workspace),
  118. compute_mdhash_id_for_qdrant("v2", prefix=s.effective_workspace),
  119. }
  120. assert {p.id for p in points} == expected_ids
  121. assert s._pending_vector_docs == {}
  122. @pytest.mark.asyncio
  123. async def test_repeated_upsert_same_id_embeds_once():
  124. embed = CountingEmbeddingFunc()
  125. s = _make_storage(embed)
  126. await s.upsert({"v1": {"content": "first"}})
  127. await s.upsert({"v1": {"content": "second"}})
  128. await s.upsert({"v1": {"content": "third"}})
  129. await s.index_done_callback()
  130. assert embed.call_count == 1
  131. assert embed.texts == ["third"]
  132. s._client.upsert.assert_called_once()
  133. @pytest.mark.asyncio
  134. async def test_deferred_embeddings_respect_batch_size():
  135. embed = CountingEmbeddingFunc()
  136. s = _make_storage(embed)
  137. s._max_batch_size = 2
  138. await s.upsert({f"v{i}": {"content": f"doc {i}"} for i in range(5)})
  139. await s.index_done_callback()
  140. assert embed.call_count == 3
  141. assert [len(b) for b in embed.batches] == [2, 2, 1]
  142. @pytest.mark.asyncio
  143. async def test_get_vectors_by_ids_lazy_embed_then_reuse_in_flush():
  144. embed = CountingEmbeddingFunc()
  145. s = _make_storage(embed)
  146. await s.upsert({"v1": {"content": "hello"}})
  147. vectors = await s.get_vectors_by_ids(["v1"])
  148. assert "v1" in vectors
  149. assert embed.call_count == 1
  150. assert s._pending_vector_docs["v1"].vector is not None
  151. await s.index_done_callback()
  152. assert embed.call_count == 1
  153. s._client.upsert.assert_called_once()
  154. @pytest.mark.asyncio
  155. async def test_flush_failure_keeps_buffer_no_double_embed_on_retry():
  156. embed = CountingEmbeddingFunc(fail_times=1)
  157. s = _make_storage(embed)
  158. await s.upsert({"v1": {"content": "hello"}})
  159. with pytest.raises(RuntimeError, match="embedding failed"):
  160. await s.index_done_callback()
  161. assert "v1" in s._pending_vector_docs
  162. assert s._pending_vector_docs["v1"].vector is None
  163. s._client.upsert.assert_not_called()
  164. await s.index_done_callback()
  165. assert embed.call_count == 2
  166. s._client.upsert.assert_called_once()
  167. assert s._pending_vector_docs == {}
  168. @pytest.mark.asyncio
  169. async def test_server_upsert_failure_keeps_buffer():
  170. embed = CountingEmbeddingFunc()
  171. s = _make_storage(embed)
  172. s._client.upsert.side_effect = RuntimeError("qdrant down")
  173. await s.upsert({"v1": {"content": "hello"}})
  174. with pytest.raises(RuntimeError, match="qdrant down"):
  175. await s.index_done_callback()
  176. assert "v1" in s._pending_vector_docs
  177. assert s._pending_vector_docs["v1"].vector is not None
  178. s._client.upsert.side_effect = None
  179. await s.index_done_callback()
  180. assert embed.call_count == 1
  181. @pytest.mark.asyncio
  182. async def test_finalize_raises_when_buffer_unflushed():
  183. embed = CountingEmbeddingFunc()
  184. s = _make_storage(embed)
  185. s._client.upsert.side_effect = RuntimeError("transient qdrant error")
  186. await s.upsert({"v1": {"content": "hello"}})
  187. with pytest.raises(RuntimeError, match="finalize.*flush raised"):
  188. await s.finalize()
  189. assert "v1" in s._pending_vector_docs
  190. @pytest.mark.asyncio
  191. async def test_delete_then_upsert_same_id_keeps_upsert():
  192. embed = CountingEmbeddingFunc()
  193. s = _make_storage(embed)
  194. await s.delete(["v1"])
  195. assert "v1" in s._pending_vector_deletes
  196. await s.upsert({"v1": {"content": "hello"}})
  197. assert "v1" in s._pending_vector_docs
  198. assert "v1" not in s._pending_vector_deletes
  199. await s.index_done_callback()
  200. s._client.upsert.assert_called_once()
  201. s._client.delete.assert_not_called()
  202. @pytest.mark.asyncio
  203. async def test_upsert_then_delete_same_id_keeps_delete():
  204. embed = CountingEmbeddingFunc()
  205. s = _make_storage(embed)
  206. await s.upsert({"v1": {"content": "hello"}})
  207. await s.delete(["v1"])
  208. assert "v1" not in s._pending_vector_docs
  209. assert "v1" in s._pending_vector_deletes
  210. await s.index_done_callback()
  211. s._client.upsert.assert_not_called()
  212. s._client.delete.assert_called_once()
  213. qdrant_delete_ids = s._client.delete.call_args.kwargs["points_selector"].points
  214. assert qdrant_delete_ids == [
  215. compute_mdhash_id_for_qdrant("v1", prefix=s.effective_workspace)
  216. ]
  217. @pytest.mark.asyncio
  218. async def test_delete_entity_relation_raises_on_server_failure():
  219. """scroll-then-delete pattern: server-side failure must bubble up."""
  220. embed = CountingEmbeddingFunc()
  221. s = _make_storage(embed)
  222. fake_point = MagicMock()
  223. fake_point.id = "qid1"
  224. s._client.scroll.return_value = ([fake_point], None)
  225. s._client.delete.side_effect = RuntimeError("qdrant delete failed")
  226. with pytest.raises(RuntimeError, match="qdrant delete failed"):
  227. await s.delete_entity_relation("X")
  228. @pytest.mark.asyncio
  229. async def test_delete_entity_relation_prunes_pending_buffer():
  230. embed = CountingEmbeddingFunc()
  231. s = _make_storage(embed)
  232. await s.upsert(
  233. {
  234. "rel-A-B": {"content": "A→B", "src_id": "A", "tgt_id": "B"},
  235. "rel-C-D": {"content": "C→D", "src_id": "C", "tgt_id": "D"},
  236. }
  237. )
  238. s._client.scroll.return_value = ([], None)
  239. await s.delete_entity_relation("A")
  240. assert "rel-A-B" not in s._pending_vector_docs
  241. assert "rel-C-D" in s._pending_vector_docs
  242. @pytest.mark.asyncio
  243. async def test_get_by_id_reads_pending_buffer_without_vector():
  244. embed = CountingEmbeddingFunc()
  245. s = _make_storage(embed)
  246. await s.upsert({"v1": {"content": "hello", "entity_name": "E1"}})
  247. doc = await s.get_by_id("v1")
  248. assert doc is not None
  249. assert doc.get("entity_name") == "E1"
  250. assert "vector" not in doc
  251. s._client.retrieve.assert_not_called()
  252. @pytest.mark.asyncio
  253. async def test_get_by_id_returns_none_for_pending_delete():
  254. embed = CountingEmbeddingFunc()
  255. s = _make_storage(embed)
  256. await s.delete(["v1"])
  257. assert await s.get_by_id("v1") is None
  258. s._client.retrieve.assert_not_called()
  259. @pytest.mark.asyncio
  260. async def test_flush_uses_build_upsert_batches_for_multiple_batches():
  261. """When the points exceed the per-batch point limit, flush calls
  262. `_client.upsert` multiple times — and a mid-batch failure keeps the
  263. entire buffer for retry.
  264. """
  265. embed = CountingEmbeddingFunc()
  266. s = _make_storage(embed)
  267. s._max_upsert_points_per_batch = 2 # force batching
  268. await s.upsert({f"v{i}": {"content": f"c{i}"} for i in range(5)})
  269. s._client.upsert.side_effect = [None, RuntimeError("batch 2 failed"), None]
  270. with pytest.raises(RuntimeError, match="batch 2 failed"):
  271. await s.index_done_callback()
  272. # Stopped at batch 2, total 2 calls so far.
  273. assert s._client.upsert.call_count == 2
  274. # Buffer preserved.
  275. assert len(s._pending_vector_docs) == 5
  276. @pytest.mark.asyncio
  277. async def test_env_workspace_override_shares_flush_lock(patch_namespace_lock):
  278. cache = patch_namespace_lock
  279. embed = CountingEmbeddingFunc()
  280. with patch.dict(os.environ, {"QDRANT_WORKSPACE": "shared_ws"}, clear=False):
  281. # Two callers passing different `workspace` would both be redirected
  282. # by the env override to "shared_ws". Since `_make_storage` skips
  283. # __post_init__, simulate the override directly:
  284. a = _make_storage(embed, workspace="shared_ws")
  285. b = _make_storage(embed, workspace="shared_ws")
  286. assert a.final_namespace == b.final_namespace
  287. assert a.effective_workspace == b.effective_workspace == "shared_ws"
  288. assert a._flush_lock is b._flush_lock
  289. assert len([k for k in cache if k == (a.final_namespace, "shared_ws")]) == 1
  290. @pytest.mark.asyncio
  291. async def test_distinct_workspaces_in_same_collection_get_independent_locks(
  292. patch_namespace_lock,
  293. ):
  294. """Same final_namespace but different workspaces → independent locks."""
  295. embed = CountingEmbeddingFunc()
  296. a = _make_storage(embed, workspace="ws_a")
  297. b = _make_storage(embed, workspace="ws_b")
  298. # final_namespace depends on namespace only (model suffix is mocked),
  299. # so the two share it, but workspaces differ → different locks.
  300. assert a.final_namespace == b.final_namespace
  301. assert a.effective_workspace != b.effective_workspace
  302. assert a._flush_lock is not b._flush_lock
  303. @pytest.mark.asyncio
  304. async def test_drop_clears_pending_buffers():
  305. embed = CountingEmbeddingFunc()
  306. s = _make_storage(embed)
  307. await s.upsert({"v1": {"content": "hello"}})
  308. await s.delete(["v2"])
  309. assert s._pending_vector_docs and s._pending_vector_deletes
  310. result = await s.drop()
  311. assert result["status"] == "success"
  312. assert s._pending_vector_docs == {}
  313. assert s._pending_vector_deletes == set()