test_nano_deferred_embedding.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. """Deferred-embedding coverage for ``NanoVectorDBStorage``.
  2. The storage no longer embeds eagerly in ``upsert``: it buffers a pending doc
  3. and embeds once per id at flush time (``index_done_callback`` / ``finalize``).
  4. These tests pin that contract using a counting mock embedding function — no
  5. live model or network. They mirror the protocol proven for
  6. ``OpenSearchVectorDBStorage`` (issue #2785).
  7. """
  8. import numpy as np
  9. import pytest
  10. nano_vectordb = pytest.importorskip("nano_vectordb") # noqa: F841
  11. from lightrag.kg.nano_vector_db_impl import NanoVectorDBStorage # noqa: E402
  12. from lightrag.kg.shared_storage import ( # noqa: E402
  13. initialize_share_data,
  14. finalize_share_data,
  15. )
  16. from lightrag.utils import EmbeddingFunc # noqa: E402
  17. DIM = 8
  18. @pytest.fixture(autouse=True)
  19. def _shared_data():
  20. finalize_share_data()
  21. initialize_share_data()
  22. yield
  23. finalize_share_data()
  24. class _CountingEmbed:
  25. """Async embedding callable that records how many texts it embedded and how
  26. many times it was invoked (one invocation == one batch)."""
  27. def __init__(self, dim: int = DIM):
  28. self.dim = dim
  29. self.call_count = 0
  30. self.embedded_texts: list[str] = []
  31. async def __call__(self, texts, **kwargs):
  32. self.call_count += 1
  33. self.embedded_texts.extend(texts)
  34. # Deterministic per-text vector so duplicates are still 1-1.
  35. return np.array(
  36. [
  37. np.full(self.dim, (abs(hash(t)) % 97) + 1, dtype=np.float32)
  38. for t in texts
  39. ]
  40. )
  41. def _make_storage(tmp_path, embed: _CountingEmbed) -> NanoVectorDBStorage:
  42. return NanoVectorDBStorage(
  43. namespace="test_vectors",
  44. workspace="ws",
  45. global_config={
  46. "working_dir": str(tmp_path),
  47. "embedding_batch_num": 32,
  48. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  49. },
  50. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  51. meta_fields={"content"},
  52. )
  53. @pytest.mark.offline
  54. @pytest.mark.asyncio
  55. async def test_upsert_defers_embedding_to_index_done_callback(tmp_path):
  56. embed = _CountingEmbed()
  57. storage = _make_storage(tmp_path, embed)
  58. await storage.initialize()
  59. await storage.upsert(
  60. {
  61. "id1": {"content": "alpha"},
  62. "id2": {"content": "beta"},
  63. }
  64. )
  65. assert embed.call_count == 0, "upsert must not embed"
  66. assert len(storage._client) == 0, "nothing should be materialized yet"
  67. await storage.index_done_callback()
  68. assert embed.call_count == 1, "flush should embed in a single batch"
  69. assert sorted(embed.embedded_texts) == ["alpha", "beta"]
  70. assert len(storage._client) == 2
  71. @pytest.mark.offline
  72. @pytest.mark.asyncio
  73. async def test_repeated_upserts_same_id_embed_once_per_flush(tmp_path):
  74. embed = _CountingEmbed()
  75. storage = _make_storage(tmp_path, embed)
  76. await storage.initialize()
  77. await storage.upsert({"id1": {"content": "v1"}})
  78. await storage.upsert({"id1": {"content": "v2"}})
  79. await storage.upsert({"id1": {"content": "v3"}})
  80. await storage.index_done_callback()
  81. assert embed.call_count == 1
  82. assert embed.embedded_texts == ["v3"], "only the latest content is embedded"
  83. assert len(storage._client) == 1
  84. @pytest.mark.offline
  85. @pytest.mark.asyncio
  86. async def test_get_vectors_caches_and_flush_reuses(tmp_path):
  87. embed = _CountingEmbed()
  88. storage = _make_storage(tmp_path, embed)
  89. await storage.initialize()
  90. await storage.upsert({"id1": {"content": "alpha"}})
  91. vecs = await storage.get_vectors_by_ids(["id1"])
  92. assert "id1" in vecs and len(vecs["id1"]) == DIM
  93. assert embed.call_count == 1, "get_vectors_by_ids embeds pending lazily"
  94. # Flush must reuse the cached vector, not re-embed.
  95. await storage.index_done_callback()
  96. assert embed.call_count == 1, "flush should reuse the cached temp vector"
  97. assert len(storage._client) == 1
  98. @pytest.mark.offline
  99. @pytest.mark.asyncio
  100. async def test_reupsert_after_get_vectors_clears_cached_vector(tmp_path):
  101. embed = _CountingEmbed()
  102. storage = _make_storage(tmp_path, embed)
  103. await storage.initialize()
  104. await storage.upsert({"id1": {"content": "old"}})
  105. await storage.get_vectors_by_ids(["id1"]) # caches a temp vector for "old"
  106. assert embed.call_count == 1
  107. # New content version must clear the cached vector and re-embed at flush.
  108. await storage.upsert({"id1": {"content": "new"}})
  109. await storage.index_done_callback()
  110. assert embed.call_count == 2
  111. assert embed.embedded_texts == ["old", "new"]
  112. @pytest.mark.offline
  113. @pytest.mark.asyncio
  114. async def test_delete_cancels_pending_and_removes_materialized(tmp_path):
  115. embed = _CountingEmbed()
  116. storage = _make_storage(tmp_path, embed)
  117. await storage.initialize()
  118. # Materialize id1; leave id2 only as a pending (unflushed) upsert.
  119. await storage.upsert({"id1": {"content": "alpha"}})
  120. await storage.index_done_callback()
  121. await storage.upsert({"id2": {"content": "beta"}})
  122. await storage.delete(["id1", "id2"])
  123. assert "id2" not in storage._pending_upserts, "delete cancels pending upsert"
  124. assert len(storage._client) == 0, "delete removes the materialized row immediately"
  125. assert await storage.get_by_id("id1") is None
  126. assert await storage.get_by_id("id2") is None
  127. @pytest.mark.offline
  128. @pytest.mark.asyncio
  129. async def test_stale_client_reload_still_flushes_pending_upsert(tmp_path):
  130. embed = _CountingEmbed()
  131. writer = _make_storage(tmp_path, embed)
  132. stale_writer = _make_storage(tmp_path, embed)
  133. await writer.initialize()
  134. await stale_writer.initialize()
  135. await writer.upsert({"id1": {"content": "alpha"}})
  136. assert await writer.index_done_callback() is True
  137. assert stale_writer.storage_updated.value is True
  138. await stale_writer.upsert({"id2": {"content": "beta"}})
  139. assert await stale_writer.index_done_callback() is True
  140. reader = _make_storage(tmp_path, embed)
  141. await reader.initialize()
  142. rows = await reader.get_by_ids(["id1", "id2"])
  143. assert [row["id"] for row in rows] == ["id1", "id2"]
  144. assert stale_writer._pending_upserts == {}
  145. @pytest.mark.offline
  146. @pytest.mark.asyncio
  147. async def test_delete_reloads_stale_client_before_mutating(tmp_path):
  148. embed = _CountingEmbed()
  149. writer = _make_storage(tmp_path, embed)
  150. stale_deleter = _make_storage(tmp_path, embed)
  151. await writer.initialize()
  152. await stale_deleter.initialize()
  153. await writer.upsert({"id1": {"content": "alpha"}})
  154. assert await writer.index_done_callback() is True
  155. assert stale_deleter.storage_updated.value is True
  156. await stale_deleter.delete(["id1"])
  157. assert stale_deleter.storage_updated.value is False
  158. assert await stale_deleter.index_done_callback() is True
  159. reader = _make_storage(tmp_path, embed)
  160. await reader.initialize()
  161. assert await reader.get_by_id("id1") is None
  162. @pytest.mark.offline
  163. @pytest.mark.asyncio
  164. async def test_finalize_reloads_stale_client_before_flushing(tmp_path):
  165. embed = _CountingEmbed()
  166. writer = _make_storage(tmp_path, embed)
  167. stale_finalizer = _make_storage(tmp_path, embed)
  168. await writer.initialize()
  169. await stale_finalizer.initialize()
  170. await writer.upsert({"id1": {"content": "alpha"}})
  171. assert await writer.index_done_callback() is True
  172. assert stale_finalizer.storage_updated.value is True
  173. await stale_finalizer.upsert({"id2": {"content": "beta"}})
  174. await stale_finalizer.finalize()
  175. reader = _make_storage(tmp_path, embed)
  176. await reader.initialize()
  177. rows = await reader.get_by_ids(["id1", "id2"])
  178. assert [row["id"] for row in rows] == ["id1", "id2"]
  179. assert stale_finalizer._pending_upserts == {}
  180. @pytest.mark.offline
  181. @pytest.mark.asyncio
  182. async def test_read_your_writes_and_query_after_flush(tmp_path):
  183. embed = _CountingEmbed()
  184. storage = _make_storage(tmp_path, embed)
  185. await storage.initialize()
  186. await storage.upsert({"id1": {"content": "alpha"}})
  187. # Before flush: read paths see the pending row, query does not.
  188. hit = await storage.get_by_id("id1")
  189. assert hit is not None and hit["id"] == "id1" and hit["content"] == "alpha"
  190. by_ids = await storage.get_by_ids(["id1", "missing"])
  191. assert by_ids[0]["id"] == "id1" and by_ids[1] is None
  192. assert await storage.query("alpha", top_k=5) == [], "query ignores unflushed data"
  193. # After flush: query returns the row.
  194. await storage.index_done_callback()
  195. results = await storage.query("alpha", top_k=5)
  196. assert any(r["id"] == "id1" for r in results)
  197. @pytest.mark.offline
  198. @pytest.mark.asyncio
  199. async def test_finalize_flushes_pending(tmp_path):
  200. embed = _CountingEmbed()
  201. storage = _make_storage(tmp_path, embed)
  202. await storage.initialize()
  203. await storage.upsert({"id1": {"content": "alpha"}})
  204. await storage.finalize()
  205. assert embed.call_count == 1
  206. assert storage._pending_upserts == {}
  207. assert len(storage._client) == 1
  208. @pytest.mark.offline
  209. @pytest.mark.asyncio
  210. async def test_delete_entity_relation_cancels_pending(tmp_path):
  211. embed = _CountingEmbed()
  212. storage = NanoVectorDBStorage(
  213. namespace="test_relations",
  214. workspace="ws",
  215. global_config={
  216. "working_dir": str(tmp_path),
  217. "embedding_batch_num": 32,
  218. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  219. },
  220. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  221. meta_fields={"content", "src_id", "tgt_id"},
  222. )
  223. await storage.initialize()
  224. # Materialize r1 (A->B), leave r2 (A->C) and r3 (X->Y) as pending.
  225. await storage.upsert({"r1": {"content": "rel1", "src_id": "A", "tgt_id": "B"}})
  226. await storage.index_done_callback()
  227. await storage.upsert(
  228. {
  229. "r2": {"content": "rel2", "src_id": "A", "tgt_id": "C"},
  230. "r3": {"content": "rel3", "src_id": "X", "tgt_id": "Y"},
  231. }
  232. )
  233. await storage.delete_entity_relation("A")
  234. assert "r2" not in storage._pending_upserts, "incident pending entry cancelled"
  235. assert "r3" in storage._pending_upserts, "unrelated pending entry preserved"
  236. assert len(storage._client) == 0, "materialized A->B removed"
  237. @pytest.mark.offline
  238. @pytest.mark.asyncio
  239. async def test_flush_embedding_failure_raises_and_keeps_pending(tmp_path):
  240. class _FailingEmbed:
  241. def __init__(self):
  242. self.call_count = 0
  243. async def __call__(self, texts, **kwargs):
  244. self.call_count += 1
  245. raise RuntimeError("embed boom")
  246. embed = _FailingEmbed()
  247. storage = NanoVectorDBStorage(
  248. namespace="test_vectors",
  249. workspace="ws",
  250. global_config={
  251. "working_dir": str(tmp_path),
  252. "embedding_batch_num": 32,
  253. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  254. },
  255. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  256. meta_fields={"content"},
  257. )
  258. await storage.initialize()
  259. await storage.upsert({"id1": {"content": "alpha"}})
  260. with pytest.raises(RuntimeError, match="embed boom"):
  261. await storage.index_done_callback()
  262. assert "id1" in storage._pending_upserts, "pending preserved for retry"
  263. assert len(storage._client) == 0, "nothing materialized on embed failure"
  264. # Embed failure happens before self._client.upsert in _flush_pending_locked,
  265. # so _client_dirty must NOT be set. (A save-stage failure would leave it True
  266. # — see test_finalize_retries_save_after_flush_failure.)
  267. assert storage._client_dirty is False
  268. @pytest.mark.offline
  269. @pytest.mark.asyncio
  270. async def test_drop_discards_pending_without_embedding(tmp_path):
  271. embed = _CountingEmbed()
  272. storage = _make_storage(tmp_path, embed)
  273. await storage.initialize()
  274. await storage.upsert({"id1": {"content": "alpha"}})
  275. assert "id1" in storage._pending_upserts
  276. result = await storage.drop()
  277. assert result["status"] == "success"
  278. assert storage._pending_upserts == {}, "drop discards buffered upserts"
  279. assert embed.call_count == 0, "drop must not embed"
  280. assert storage._client_dirty is False
  281. @pytest.mark.offline
  282. @pytest.mark.asyncio
  283. async def test_finalize_retries_save_after_flush_failure(tmp_path):
  284. embed = _CountingEmbed()
  285. storage = _make_storage(tmp_path, embed)
  286. await storage.initialize()
  287. await storage.upsert({"id1": {"content": "alpha"}})
  288. original_save = storage._save_to_disk_locked
  289. save_calls = 0
  290. def fail_once():
  291. nonlocal save_calls
  292. save_calls += 1
  293. if save_calls == 1:
  294. raise OSError("boom")
  295. original_save()
  296. storage._save_to_disk_locked = fail_once
  297. with pytest.raises(OSError, match="boom"):
  298. await storage.finalize()
  299. assert storage._pending_upserts == {}
  300. assert storage._client_dirty is True
  301. await storage.finalize()
  302. assert save_calls == 2
  303. assert storage._client_dirty is False
  304. reader = _make_storage(tmp_path, embed)
  305. await reader.initialize()
  306. hit = await reader.get_by_id("id1")
  307. assert hit is not None and hit["id"] == "id1"