test_faiss_deferred_embedding.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
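"""Deferred-embedding coverage for ``FaissVectorDBStorage``.

The storage no longer embeds eagerly in ``upsert``: it buffers a pending doc
and embeds once per id at flush time (``index_done_callback`` / ``finalize``).
These tests pin that contract using a counting mock embedding function — no
live model or network. They mirror the protocol proven for
``NanoVectorDBStorage`` (issue #2785), plus the following Faiss-specific cases:

- ``test_reupsert_after_flush_replaces_single_fid`` — Faiss has no in-place
  upsert; verify the rebuild keeps a single fid per custom id.
- ``test_index_done_callback_save_failure_raises`` — flush succeeds but the
  save IO fails: pending is empty, ``_index_dirty`` stays True, and the
  materialized index is preserved for a finalize retry.
- ``test_reload_warns_on_index_meta_skew`` — ``index > meta`` on-disk skew
  (from a crash between the two atomic_writes) is logged on reload but **not**
  auto-repaired.
- ``test_query_skips_orphan_faiss_hits`` — an orphan vector left by that skew
  is skipped by ``query`` instead of surfacing as an ``{"id": None, ...}`` row.
- ``test_reupsert_cleans_duplicate_custom_id_rows`` — several fids sharing one
  custom id (legacy / corrupted stores) collapse to a single row on re-upsert
  and are all removed by ``delete``.
- ``test_delete_propagates_errors`` — ``delete`` surfaces index-rebuild
  errors instead of swallowing them (intentionally diverges from Nano).
- ``test_flush_recovers_from_index_add_failure_without_re_embedding`` — an
  ``index.add`` failure after embedding keeps the cached vectors pending, so a
  ``finalize`` retry succeeds without re-embedding.
"""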
  16. import json
  17. import os
  18. import numpy as np
  19. import pytest
  20. faiss = pytest.importorskip("faiss")
  21. from lightrag.kg.faiss_impl import FaissVectorDBStorage # noqa: E402
  22. from lightrag.kg.shared_storage import ( # noqa: E402
  23. initialize_share_data,
  24. finalize_share_data,
  25. )
  26. from lightrag.utils import EmbeddingFunc # noqa: E402
  27. DIM = 8
  28. @pytest.fixture(autouse=True)
  29. def _shared_data():
  30. finalize_share_data()
  31. initialize_share_data()
  32. yield
  33. finalize_share_data()
  34. class _CountingEmbed:
  35. """Async embedding callable that records how many texts it embedded and how
  36. many times it was invoked (one invocation == one batch)."""
  37. def __init__(self, dim: int = DIM):
  38. self.dim = dim
  39. self.call_count = 0
  40. self.embedded_texts: list[str] = []
  41. async def __call__(self, texts, **kwargs):
  42. self.call_count += 1
  43. self.embedded_texts.extend(texts)
  44. # Deterministic per-text vector so duplicates are still 1-1.
  45. return np.array(
  46. [
  47. np.full(self.dim, (abs(hash(t)) % 97) + 1, dtype=np.float32)
  48. for t in texts
  49. ]
  50. )
  51. def _make_storage(tmp_path, embed: _CountingEmbed) -> FaissVectorDBStorage:
  52. return FaissVectorDBStorage(
  53. namespace="test_vectors",
  54. workspace="ws",
  55. global_config={
  56. "working_dir": str(tmp_path),
  57. "embedding_batch_num": 32,
  58. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  59. },
  60. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  61. meta_fields={"content"},
  62. )
  63. def _assert_consistent(storage: FaissVectorDBStorage) -> None:
  64. """Faiss has two structures (index + meta dict); the root failure mode is
  65. them diverging. Every test that mutates state asserts they match."""
  66. assert storage._index.ntotal == len(storage._id_to_meta), (
  67. f"index ntotal ({storage._index.ntotal}) != meta length "
  68. f"({len(storage._id_to_meta)})"
  69. )
  70. # ---------------------------------------------------------------------------
  71. # (A) Nano-ported tests
  72. # ---------------------------------------------------------------------------
  73. @pytest.mark.offline
  74. @pytest.mark.asyncio
  75. async def test_upsert_defers_embedding_to_index_done_callback(tmp_path):
  76. embed = _CountingEmbed()
  77. storage = _make_storage(tmp_path, embed)
  78. await storage.initialize()
  79. await storage.upsert(
  80. {
  81. "id1": {"content": "alpha"},
  82. "id2": {"content": "beta"},
  83. }
  84. )
  85. assert embed.call_count == 0, "upsert must not embed"
  86. assert storage._index.ntotal == 0, "nothing should be materialized yet"
  87. _assert_consistent(storage)
  88. await storage.index_done_callback()
  89. assert embed.call_count == 1, "flush should embed in a single batch"
  90. assert sorted(embed.embedded_texts) == ["alpha", "beta"]
  91. assert storage._index.ntotal == 2
  92. _assert_consistent(storage)
  93. @pytest.mark.offline
  94. @pytest.mark.asyncio
  95. async def test_repeated_upserts_same_id_embed_once_per_flush(tmp_path):
  96. embed = _CountingEmbed()
  97. storage = _make_storage(tmp_path, embed)
  98. await storage.initialize()
  99. await storage.upsert({"id1": {"content": "v1"}})
  100. await storage.upsert({"id1": {"content": "v2"}})
  101. await storage.upsert({"id1": {"content": "v3"}})
  102. await storage.index_done_callback()
  103. assert embed.call_count == 1
  104. assert embed.embedded_texts == ["v3"], "only the latest content is embedded"
  105. assert storage._index.ntotal == 1
  106. _assert_consistent(storage)
  107. @pytest.mark.offline
  108. @pytest.mark.asyncio
  109. async def test_get_vectors_caches_and_flush_reuses(tmp_path):
  110. embed = _CountingEmbed()
  111. storage = _make_storage(tmp_path, embed)
  112. await storage.initialize()
  113. await storage.upsert({"id1": {"content": "alpha"}})
  114. vecs = await storage.get_vectors_by_ids(["id1"])
  115. assert "id1" in vecs and len(vecs["id1"]) == DIM
  116. assert embed.call_count == 1, "get_vectors_by_ids embeds pending lazily"
  117. # Flush must reuse the cached vector, not re-embed.
  118. await storage.index_done_callback()
  119. assert embed.call_count == 1, "flush should reuse the cached temp vector"
  120. assert storage._index.ntotal == 1
  121. _assert_consistent(storage)
  122. @pytest.mark.offline
  123. @pytest.mark.asyncio
  124. async def test_reupsert_after_get_vectors_clears_cached_vector(tmp_path):
  125. embed = _CountingEmbed()
  126. storage = _make_storage(tmp_path, embed)
  127. await storage.initialize()
  128. await storage.upsert({"id1": {"content": "old"}})
  129. await storage.get_vectors_by_ids(["id1"]) # caches a temp vector for "old"
  130. assert embed.call_count == 1
  131. # New content version must clear the cached vector and re-embed at flush.
  132. await storage.upsert({"id1": {"content": "new"}})
  133. await storage.index_done_callback()
  134. assert embed.call_count == 2
  135. assert embed.embedded_texts == ["old", "new"]
  136. _assert_consistent(storage)
  137. @pytest.mark.offline
  138. @pytest.mark.asyncio
  139. async def test_delete_cancels_pending_and_removes_materialized(tmp_path):
  140. embed = _CountingEmbed()
  141. storage = _make_storage(tmp_path, embed)
  142. await storage.initialize()
  143. # Materialize id1; leave id2 only as a pending (unflushed) upsert.
  144. await storage.upsert({"id1": {"content": "alpha"}})
  145. await storage.index_done_callback()
  146. await storage.upsert({"id2": {"content": "beta"}})
  147. await storage.delete(["id1", "id2"])
  148. assert "id2" not in storage._pending_upserts, "delete cancels pending upsert"
  149. assert storage._index.ntotal == 0, "delete removes the materialized row"
  150. assert await storage.get_by_id("id1") is None
  151. assert await storage.get_by_id("id2") is None
  152. _assert_consistent(storage)
  153. @pytest.mark.offline
  154. @pytest.mark.asyncio
  155. async def test_stale_client_reload_still_flushes_pending_upsert(tmp_path):
  156. embed = _CountingEmbed()
  157. writer = _make_storage(tmp_path, embed)
  158. stale_writer = _make_storage(tmp_path, embed)
  159. await writer.initialize()
  160. await stale_writer.initialize()
  161. await writer.upsert({"id1": {"content": "alpha"}})
  162. assert await writer.index_done_callback() is True
  163. assert stale_writer.storage_updated.value is True
  164. await stale_writer.upsert({"id2": {"content": "beta"}})
  165. assert await stale_writer.index_done_callback() is True
  166. reader = _make_storage(tmp_path, embed)
  167. await reader.initialize()
  168. rows = await reader.get_by_ids(["id1", "id2"])
  169. assert [row["id"] for row in rows] == ["id1", "id2"]
  170. assert stale_writer._pending_upserts == {}
  171. _assert_consistent(reader)
  172. @pytest.mark.offline
  173. @pytest.mark.asyncio
  174. async def test_delete_reloads_stale_client_before_mutating(tmp_path):
  175. embed = _CountingEmbed()
  176. writer = _make_storage(tmp_path, embed)
  177. stale_deleter = _make_storage(tmp_path, embed)
  178. await writer.initialize()
  179. await stale_deleter.initialize()
  180. await writer.upsert({"id1": {"content": "alpha"}})
  181. assert await writer.index_done_callback() is True
  182. assert stale_deleter.storage_updated.value is True
  183. await stale_deleter.delete(["id1"])
  184. assert stale_deleter.storage_updated.value is False
  185. assert await stale_deleter.index_done_callback() is True
  186. reader = _make_storage(tmp_path, embed)
  187. await reader.initialize()
  188. assert await reader.get_by_id("id1") is None
  189. _assert_consistent(reader)
  190. @pytest.mark.offline
  191. @pytest.mark.asyncio
  192. async def test_finalize_reloads_stale_client_before_flushing(tmp_path):
  193. embed = _CountingEmbed()
  194. writer = _make_storage(tmp_path, embed)
  195. stale_finalizer = _make_storage(tmp_path, embed)
  196. await writer.initialize()
  197. await stale_finalizer.initialize()
  198. await writer.upsert({"id1": {"content": "alpha"}})
  199. assert await writer.index_done_callback() is True
  200. assert stale_finalizer.storage_updated.value is True
  201. await stale_finalizer.upsert({"id2": {"content": "beta"}})
  202. await stale_finalizer.finalize()
  203. reader = _make_storage(tmp_path, embed)
  204. await reader.initialize()
  205. rows = await reader.get_by_ids(["id1", "id2"])
  206. assert [row["id"] for row in rows] == ["id1", "id2"]
  207. assert stale_finalizer._pending_upserts == {}
  208. _assert_consistent(reader)
  209. @pytest.mark.offline
  210. @pytest.mark.asyncio
  211. async def test_read_your_writes_and_query_after_flush(tmp_path):
  212. embed = _CountingEmbed()
  213. storage = _make_storage(tmp_path, embed)
  214. await storage.initialize()
  215. await storage.upsert({"id1": {"content": "alpha"}})
  216. # Before flush: read paths see the pending row, query does not.
  217. hit = await storage.get_by_id("id1")
  218. assert hit is not None and hit["id"] == "id1" and hit["content"] == "alpha"
  219. by_ids = await storage.get_by_ids(["id1", "missing"])
  220. assert by_ids[0]["id"] == "id1" and by_ids[1] is None
  221. assert await storage.query("alpha", top_k=5) == [], "query ignores unflushed data"
  222. # After flush: query returns the row.
  223. await storage.index_done_callback()
  224. results = await storage.query("alpha", top_k=5)
  225. assert any(r["id"] == "id1" for r in results)
  226. _assert_consistent(storage)
  227. @pytest.mark.offline
  228. @pytest.mark.asyncio
  229. async def test_finalize_flushes_pending(tmp_path):
  230. embed = _CountingEmbed()
  231. storage = _make_storage(tmp_path, embed)
  232. await storage.initialize()
  233. await storage.upsert({"id1": {"content": "alpha"}})
  234. await storage.finalize()
  235. assert embed.call_count == 1
  236. assert storage._pending_upserts == {}
  237. assert storage._index.ntotal == 1
  238. _assert_consistent(storage)
  239. @pytest.mark.offline
  240. @pytest.mark.asyncio
  241. async def test_delete_entity_relation_cancels_pending(tmp_path):
  242. embed = _CountingEmbed()
  243. storage = FaissVectorDBStorage(
  244. namespace="test_relations",
  245. workspace="ws",
  246. global_config={
  247. "working_dir": str(tmp_path),
  248. "embedding_batch_num": 32,
  249. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  250. },
  251. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  252. meta_fields={"content", "src_id", "tgt_id"},
  253. )
  254. await storage.initialize()
  255. # Materialize r1 (A->B), leave r2 (A->C) and r3 (X->Y) as pending.
  256. await storage.upsert({"r1": {"content": "rel1", "src_id": "A", "tgt_id": "B"}})
  257. await storage.index_done_callback()
  258. await storage.upsert(
  259. {
  260. "r2": {"content": "rel2", "src_id": "A", "tgt_id": "C"},
  261. "r3": {"content": "rel3", "src_id": "X", "tgt_id": "Y"},
  262. }
  263. )
  264. await storage.delete_entity_relation("A")
  265. assert "r2" not in storage._pending_upserts, "incident pending entry cancelled"
  266. assert "r3" in storage._pending_upserts, "unrelated pending entry preserved"
  267. assert storage._index.ntotal == 0, "materialized A->B removed"
  268. _assert_consistent(storage)
  269. @pytest.mark.offline
  270. @pytest.mark.asyncio
  271. async def test_flush_embedding_failure_raises_and_keeps_pending(tmp_path):
  272. class _FailingEmbed:
  273. def __init__(self):
  274. self.call_count = 0
  275. async def __call__(self, texts, **kwargs):
  276. self.call_count += 1
  277. raise RuntimeError("embed boom")
  278. embed = _FailingEmbed()
  279. storage = FaissVectorDBStorage(
  280. namespace="test_vectors",
  281. workspace="ws",
  282. global_config={
  283. "working_dir": str(tmp_path),
  284. "embedding_batch_num": 32,
  285. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  286. },
  287. embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed),
  288. meta_fields={"content"},
  289. )
  290. await storage.initialize()
  291. await storage.upsert({"id1": {"content": "alpha"}})
  292. with pytest.raises(RuntimeError, match="embed boom"):
  293. await storage.index_done_callback()
  294. assert "id1" in storage._pending_upserts, "pending preserved for retry"
  295. assert storage._index.ntotal == 0, "nothing materialized on embed failure"
  296. # Embed failure happens before self._index.add in _flush_pending_locked,
  297. # so _index_dirty must NOT be set. (A save-stage failure would leave it True
  298. # — see test_index_done_callback_save_failure_raises.)
  299. assert storage._index_dirty is False
  300. _assert_consistent(storage)
  301. @pytest.mark.offline
  302. @pytest.mark.asyncio
  303. async def test_drop_discards_pending_without_embedding(tmp_path):
  304. embed = _CountingEmbed()
  305. storage = _make_storage(tmp_path, embed)
  306. await storage.initialize()
  307. await storage.upsert({"id1": {"content": "alpha"}})
  308. assert "id1" in storage._pending_upserts
  309. result = await storage.drop()
  310. assert result["status"] == "success"
  311. assert storage._pending_upserts == {}, "drop discards buffered upserts"
  312. assert embed.call_count == 0, "drop must not embed"
  313. assert storage._index_dirty is False
  314. _assert_consistent(storage)
  315. @pytest.mark.offline
  316. @pytest.mark.asyncio
  317. async def test_finalize_retries_save_after_flush_failure(tmp_path):
  318. embed = _CountingEmbed()
  319. storage = _make_storage(tmp_path, embed)
  320. await storage.initialize()
  321. await storage.upsert({"id1": {"content": "alpha"}})
  322. original_save = storage._save_faiss_index
  323. save_calls = 0
  324. def fail_once():
  325. nonlocal save_calls
  326. save_calls += 1
  327. if save_calls == 1:
  328. raise OSError("boom")
  329. original_save()
  330. storage._save_faiss_index = fail_once
  331. with pytest.raises(OSError, match="boom"):
  332. await storage.finalize()
  333. assert storage._pending_upserts == {}
  334. assert storage._index_dirty is True
  335. await storage.finalize()
  336. assert save_calls == 2
  337. assert storage._index_dirty is False
  338. reader = _make_storage(tmp_path, embed)
  339. await reader.initialize()
  340. hit = await reader.get_by_id("id1")
  341. assert hit is not None and hit["id"] == "id1"
  342. _assert_consistent(reader)
  343. # ---------------------------------------------------------------------------
  344. # (B) Faiss-specific tests
  345. # ---------------------------------------------------------------------------
  346. @pytest.mark.offline
  347. @pytest.mark.asyncio
  348. async def test_reupsert_after_flush_replaces_single_fid(tmp_path):
  349. """Faiss has no in-place upsert: re-upserting an already-materialized id
  350. must rebuild the index without the old fid, so we still end up with
  351. exactly one row per custom id."""
  352. embed = _CountingEmbed()
  353. storage = _make_storage(tmp_path, embed)
  354. await storage.initialize()
  355. await storage.upsert({"id1": {"content": "old"}})
  356. await storage.index_done_callback()
  357. assert storage._index.ntotal == 1
  358. _assert_consistent(storage)
  359. await storage.upsert({"id1": {"content": "new"}})
  360. await storage.index_done_callback()
  361. assert storage._index.ntotal == 1, "rebuild must remove old fid before adding new"
  362. assert len(storage._id_to_meta) == 1
  363. _assert_consistent(storage)
  364. hit = await storage.get_by_id("id1")
  365. assert hit is not None and hit["content"] == "new"
  366. assert embed.call_count == 2, "each flush embeds the latest content once"
  367. @pytest.mark.offline
  368. @pytest.mark.asyncio
  369. async def test_index_done_callback_save_failure_raises(tmp_path):
  370. """Save failure in index_done_callback must propagate, leave pending empty
  371. (flush already succeeded), and keep _index_dirty=True so finalize retries."""
  372. embed = _CountingEmbed()
  373. storage = _make_storage(tmp_path, embed)
  374. await storage.initialize()
  375. await storage.upsert({"id1": {"content": "alpha"}})
  376. original_save = storage._save_faiss_index
  377. def fail_save():
  378. raise OSError("save boom")
  379. storage._save_faiss_index = fail_save
  380. with pytest.raises(OSError, match="save boom"):
  381. await storage.index_done_callback()
  382. assert storage._pending_upserts == {}, "flush succeeded so pending is empty"
  383. assert storage._index_dirty is True, "save failure preserves dirty for retry"
  384. assert storage._index.ntotal == 1, "materialized state is preserved"
  385. _assert_consistent(storage)
  386. # Restore real save; finalize must retry only the save (no re-embed).
  387. storage._save_faiss_index = original_save
  388. embed_before = embed.call_count
  389. await storage.finalize()
  390. assert embed.call_count == embed_before, "save retry must not re-embed"
  391. assert storage._index_dirty is False
  392. reader = _make_storage(tmp_path, embed)
  393. await reader.initialize()
  394. hit = await reader.get_by_id("id1")
  395. assert hit is not None and hit["id"] == "id1"
  396. _assert_consistent(reader)
  397. @pytest.mark.offline
  398. @pytest.mark.asyncio
  399. async def test_reload_warns_on_index_meta_skew(tmp_path, caplog):
  400. """A crash between the .index write and the .meta.json write leaves
  401. ``ntotal(.index) > rows(.meta)``. ``_load_faiss_index`` must log a warning
  402. on reload; auto-repair is intentionally not in scope here."""
  403. import logging
  404. from lightrag.utils import logger as lightrag_logger
  405. embed = _CountingEmbed()
  406. writer = _make_storage(tmp_path, embed)
  407. await writer.initialize()
  408. await writer.upsert({"id1": {"content": "alpha"}, "id2": {"content": "beta"}})
  409. await writer.index_done_callback()
  410. # Corrupt the meta file: drop one entry so disk has index > meta.
  411. with open(writer._meta_file, "r", encoding="utf-8") as f:
  412. meta = json.load(f)
  413. assert len(meta) == 2
  414. dropped_key = next(iter(meta))
  415. del meta[dropped_key]
  416. with open(writer._meta_file, "w", encoding="utf-8") as f:
  417. json.dump(meta, f)
  418. # The lightrag logger sets propagate=False (lightrag/utils.py), so caplog —
  419. # which attaches to root by default — never sees its records. Flip propagate
  420. # for the duration of the reload, then restore.
  421. caplog.clear()
  422. old_propagate = lightrag_logger.propagate
  423. lightrag_logger.propagate = True
  424. try:
  425. with caplog.at_level(logging.WARNING, logger="lightrag"):
  426. reader = _make_storage(tmp_path, embed)
  427. await reader.initialize()
  428. finally:
  429. lightrag_logger.propagate = old_propagate
  430. # The reader's index still has 2 vectors but only 1 reachable via meta —
  431. # this is the "known risk, not auto-repaired" state.
  432. assert reader._index.ntotal == 2
  433. assert len(reader._id_to_meta) == 1
  434. skew_messages = [
  435. rec.message
  436. for rec in caplog.records
  437. if "skew" in rec.message or "index > meta" in rec.message
  438. ]
  439. assert skew_messages, (
  440. f"expected an index>meta skew warning; got: "
  441. f"{[r.message for r in caplog.records]}"
  442. )
  443. # Sanity: state files exist where we left them.
  444. assert os.path.exists(writer._faiss_index_file)
  445. assert os.path.exists(writer._meta_file)
  446. @pytest.mark.offline
  447. @pytest.mark.asyncio
  448. async def test_query_skips_orphan_faiss_hits(tmp_path):
  449. """After an ``index > meta`` skew the orphan vector is still searchable by
  450. similarity, but ``query`` must skip it instead of leaking a ghost
  451. ``{"id": None, ...}`` row to the caller."""
  452. embed = _CountingEmbed()
  453. storage = _make_storage(tmp_path, embed)
  454. await storage.initialize()
  455. # Materialize two rows.
  456. await storage.upsert({"id1": {"content": "alpha"}, "id2": {"content": "beta"}})
  457. await storage.index_done_callback()
  458. assert storage._index.ntotal == 2
  459. # Synthesize the skew: drop one meta row in memory, keeping the faiss
  460. # index untouched. This mirrors what _load_faiss_index would surface on
  461. # reload after a crash between the two atomic_writes.
  462. orphan_fid = next(iter(storage._id_to_meta))
  463. del storage._id_to_meta[orphan_fid]
  464. assert storage._index.ntotal == 2
  465. assert len(storage._id_to_meta) == 1
  466. # The orphan vector still scores high in similarity search; query must
  467. # filter it out instead of returning {"id": None, ...}.
  468. results = await storage.query("anything", top_k=5)
  469. for row in results:
  470. assert row["id"] is not None, f"orphan hit leaked: {row}"
  471. # And the surviving row is still returned.
  472. surviving_id = next(iter(storage._id_to_meta.values()))["__id__"]
  473. assert any(r["id"] == surviving_id for r in results)
  474. @pytest.mark.offline
  475. @pytest.mark.asyncio
  476. async def test_reupsert_cleans_duplicate_custom_id_rows(tmp_path):
  477. """Defends against legacy / externally corrupted stores where multiple
  478. fids in ``_id_to_meta`` share the same ``__id__``. A re-upsert + flush
  479. must collapse them to a single row; a ``delete`` must remove all of them."""
  480. embed = _CountingEmbed()
  481. storage = _make_storage(tmp_path, embed)
  482. await storage.initialize()
  483. # Hand-craft a corrupt state: two fids carry the same custom id "dup".
  484. matrix = np.array([[1.0] * DIM, [2.0] * DIM], dtype=np.float32)
  485. faiss.normalize_L2(matrix)
  486. storage._index.add(matrix)
  487. storage._id_to_meta[0] = {
  488. "__id__": "dup",
  489. "__created_at__": 1,
  490. "content": "v1",
  491. "__vector__": matrix[0].tolist(),
  492. }
  493. storage._id_to_meta[1] = {
  494. "__id__": "dup",
  495. "__created_at__": 1,
  496. "content": "v2",
  497. "__vector__": matrix[1].tolist(),
  498. }
  499. _assert_consistent(storage)
  500. assert storage._find_faiss_ids_by_custom_id("dup") == [0, 1]
  501. # Re-upsert + flush: both duplicates must be removed in the rebuild
  502. # before the new vector is added; final state is a single row.
  503. await storage.upsert({"dup": {"content": "v3"}})
  504. await storage.index_done_callback()
  505. assert storage._index.ntotal == 1, "flush rebuild must drop both duplicates"
  506. assert len(storage._id_to_meta) == 1
  507. assert storage._find_faiss_ids_by_custom_id("dup") == list(
  508. storage._id_to_meta.keys()
  509. )
  510. hit = await storage.get_by_id("dup")
  511. assert hit is not None and hit["content"] == "v3"
  512. _assert_consistent(storage)
  513. # Re-seed two more duplicates and verify delete also removes them all.
  514. matrix2 = np.array([[3.0] * DIM, [4.0] * DIM], dtype=np.float32)
  515. faiss.normalize_L2(matrix2)
  516. storage._index.add(matrix2)
  517. next_fid = max(storage._id_to_meta) + 1
  518. storage._id_to_meta[next_fid] = {
  519. "__id__": "dup",
  520. "__created_at__": 2,
  521. "content": "dup-a",
  522. "__vector__": matrix2[0].tolist(),
  523. }
  524. storage._id_to_meta[next_fid + 1] = {
  525. "__id__": "dup",
  526. "__created_at__": 2,
  527. "content": "dup-b",
  528. "__vector__": matrix2[1].tolist(),
  529. }
  530. assert len(storage._find_faiss_ids_by_custom_id("dup")) == 3
  531. await storage.delete(["dup"])
  532. assert storage._find_faiss_ids_by_custom_id("dup") == []
  533. assert storage._index.ntotal == 0
  534. _assert_consistent(storage)
  535. @pytest.mark.offline
  536. @pytest.mark.asyncio
  537. async def test_delete_propagates_errors(tmp_path, monkeypatch):
  538. """Faiss ``delete`` must NOT swallow errors — the caller (document
  539. deletion / status update path) needs to abort if vectors weren't
  540. actually removed. This intentionally diverges from Nano."""
  541. embed = _CountingEmbed()
  542. storage = _make_storage(tmp_path, embed)
  543. await storage.initialize()
  544. await storage.upsert({"id1": {"content": "alpha"}})
  545. await storage.index_done_callback()
  546. def boom(_self, _fids):
  547. raise RuntimeError("rebuild boom")
  548. # _remove_faiss_ids_locked is what delete calls under the hood.
  549. monkeypatch.setattr(
  550. FaissVectorDBStorage, "_remove_faiss_ids_locked", boom, raising=True
  551. )
  552. with pytest.raises(RuntimeError, match="rebuild boom"):
  553. await storage.delete(["id1"])
  554. @pytest.mark.offline
  555. @pytest.mark.asyncio
  556. async def test_flush_recovers_from_index_add_failure_without_re_embedding(tmp_path):
  557. """Self-heal contract: if ``index.add`` raises mid-flush (after embedding
  558. already succeeded), the pending buffer keeps the cached vectors and a
  559. subsequent ``finalize`` retries the flush **without re-embedding**. Pins
  560. the "pending is the source of truth on mid-write failure" invariant
  561. documented on ``_flush_pending_locked``."""
  562. class _AddFailsOnce:
  563. """Wraps a real faiss index, raising on the first ``.add`` call. After
  564. the second add succeeds it swaps the storage's ``_index`` attribute
  565. back to the real instance, so ``faiss.write_index`` (which requires a
  566. real SWIG-wrapped object) can run during the retry's save step. This
  567. is a test-only shim — in production ``self._index`` is always a real
  568. faiss index throughout the retry.
  569. """
  570. def __init__(self, storage, real):
  571. self._storage = storage
  572. self._real = real
  573. self._calls = 0
  574. def __getattr__(self, name):
  575. return getattr(self._real, name)
  576. def add(self, arr):
  577. self._calls += 1
  578. if self._calls == 1:
  579. raise RuntimeError("add boom")
  580. result = self._real.add(arr)
  581. self._storage._index = self._real
  582. return result
  583. embed = _CountingEmbed()
  584. storage = _make_storage(tmp_path, embed)
  585. await storage.initialize()
  586. await storage.upsert({"id1": {"content": "alpha"}})
  587. real_index = storage._index
  588. storage._index = _AddFailsOnce(storage, real_index)
  589. with pytest.raises(RuntimeError, match="add boom"):
  590. await storage.index_done_callback()
  591. # Embedding completed once (failure happened after embed, in index.add).
  592. assert embed.call_count == 1
  593. # Pending preserved with cached vectors — that's the self-healing key.
  594. assert "id1" in storage._pending_upserts
  595. assert storage._pending_upserts["id1"].vector is not None
  596. # _index_dirty stays False: docstring says we deliberately don't flip it
  597. # on mid-write failure (pending is the source of truth).
  598. assert storage._index_dirty is False
  599. assert storage._index.ntotal == 0
  600. # Retry through the same public entry point. The wrapper's second add
  601. # succeeds, unwraps itself, and the rest of finalize (save + notify)
  602. # runs against the real index.
  603. await storage.finalize()
  604. assert embed.call_count == 1, "retry must reuse cached vectors, not re-embed"
  605. assert storage._index is real_index, "wrapper unwrapped itself on the second add"
  606. assert storage._index.ntotal == 1
  607. assert storage._pending_upserts == {}
  608. assert storage._index_dirty is False
  609. _assert_consistent(storage)
  610. # And the row was persisted to disk by the retry's save.
  611. reader = _make_storage(tmp_path, embed)
  612. await reader.initialize()
  613. hit = await reader.get_by_id("id1")
  614. assert hit is not None and hit["content"] == "alpha"