# test_postgres_vector_deferred.py
"""Unit tests for PGVectorStorage deferred-embedding contract.

PGVectorStorage now buffers upserts/deletes in process-local pending buffers
and embeds + persists only during ``index_done_callback()`` / ``finalize()``.
This mirrors OpenSearchVectorDBStorage and NanoVectorDBStorage.

These tests use the same ``MagicMock``-based DB stub as
``test_postgres_upsert.py``, plus a counting embedding function adapted from
``tests/kg/opensearch_impl/test_opensearch_storage.py``.
"""
import asyncio
import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest

from lightrag.kg.postgres_impl import (
    PGVectorStorage,
    _PendingPGVectorDoc,
)
from lightrag.namespace import NameSpace
from lightrag.utils import EmbeddingFunc, compute_mdhash_id

  20. # ---------------------------------------------------------------------------
  21. # Helpers
  22. # ---------------------------------------------------------------------------
  23. class CountingEmbed:
  24. """Embedding test double that records calls and can fail N times first."""
  25. def __init__(self, dim: int = 3, fail_times: int = 0):
  26. self.embedding_dim = dim
  27. self.max_token_size = 512
  28. self.model_name = "test_model"
  29. self.fail_times = fail_times
  30. self.call_count = 0
  31. self.batches: list[list[str]] = []
  32. async def __call__(self, texts, **kwargs):
  33. self.call_count += 1
  34. batch = list(texts)
  35. self.batches.append(batch)
  36. if self.fail_times > 0:
  37. self.fail_times -= 1
  38. raise RuntimeError("embedding failed")
  39. return np.array(
  40. [[float(self.call_count), 0.0, 0.0] for _ in batch], dtype=np.float32
  41. )
  42. def _make_storage(
  43. namespace: str = NameSpace.VECTOR_STORE_CHUNKS,
  44. embed: CountingEmbed | None = None,
  45. embedding_batch_num: int = 10,
  46. fail_run_with_retry: bool = False,
  47. ) -> PGVectorStorage:
  48. """Construct a PGVectorStorage with a stubbed DB and embedding func."""
  49. db = MagicMock()
  50. captured_executemany: list[tuple] = []
  51. captured_execute: list[tuple] = []
  52. retry_kwargs: list[dict] = []
  53. retry_call_count = {"n": 0}
  54. async def fake_run_with_retry(operation, **kwargs):
  55. retry_kwargs.append(kwargs)
  56. retry_call_count["n"] += 1
  57. if fail_run_with_retry:
  58. raise RuntimeError("simulated PG failure")
  59. mock_conn = AsyncMock()
  60. tx_cm = AsyncMock()
  61. tx_cm.__aenter__.return_value = None
  62. tx_cm.__aexit__.return_value = None
  63. mock_conn.transaction = MagicMock(return_value=tx_cm)
  64. await operation(mock_conn)
  65. for call in mock_conn.executemany.call_args_list:
  66. captured_executemany.append((call.args[0], call.args[1]))
  67. for call in mock_conn.execute.call_args_list:
  68. captured_execute.append((call.args[0], call.args[1:]))
  69. db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
  70. # db.execute is used by delete_entity, delete_entity_relation, drop.
  71. db.execute = AsyncMock(return_value=None)
  72. db.query = AsyncMock(return_value=[])
  73. db.workspace = "test_ws"
  74. embedding = embed or CountingEmbed()
  75. embedding_func = EmbeddingFunc(
  76. embedding_dim=embedding.embedding_dim,
  77. func=embedding,
  78. model_name=embedding.model_name,
  79. )
  80. storage = PGVectorStorage(
  81. namespace=namespace,
  82. workspace="test_ws",
  83. global_config={
  84. "embedding_batch_num": embedding_batch_num,
  85. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.5},
  86. },
  87. embedding_func=embedding_func,
  88. )
  89. storage.db = db
  90. storage._flush_lock = asyncio.Lock()
  91. storage._counting_embed = embedding
  92. storage._captured_executemany = captured_executemany
  93. storage._captured_execute = captured_execute
  94. storage._retry_kwargs = retry_kwargs
  95. storage._retry_call_count = retry_call_count
  96. return storage
  97. def _chunk_data(**overrides):
  98. base = {
  99. "tokens": 1,
  100. "chunk_order_index": 0,
  101. "full_doc_id": "doc-1",
  102. "content": "alpha",
  103. "file_path": "/a.txt",
  104. }
  105. base.update(overrides)
  106. return base
  107. def _entity_data(name: str = "Alice", **overrides):
  108. base = {
  109. "entity_name": name,
  110. "content": f"{name} content",
  111. "source_id": "chunk-1",
  112. "file_path": "/e.txt",
  113. }
  114. base.update(overrides)
  115. return base
  116. def _relation_data(src: str = "Alice", tgt: str = "Bob", **overrides):
  117. base = {
  118. "src_id": src,
  119. "tgt_id": tgt,
  120. "content": f"{src}->{tgt}",
  121. "source_id": "chunk-1",
  122. "file_path": "/r.txt",
  123. }
  124. base.update(overrides)
  125. return base
  126. # ---------------------------------------------------------------------------
  127. # 1. upsert() buffers only
  128. # ---------------------------------------------------------------------------
  129. @pytest.mark.asyncio
  130. async def test_upsert_buffers_without_embedding_or_db_call():
  131. storage = _make_storage()
  132. await storage.upsert({"c1": _chunk_data(content="alpha")})
  133. assert storage._counting_embed.call_count == 0
  134. assert storage._retry_call_count["n"] == 0
  135. assert "c1" in storage._pending_vector_docs
  136. pending = storage._pending_vector_docs["c1"]
  137. assert isinstance(pending, _PendingPGVectorDoc)
  138. assert pending.vector is None
  139. assert pending.item["__id__"] == "c1"
  140. assert pending.item["content"] == "alpha"
  141. # ---------------------------------------------------------------------------
  142. # 2. Deferred batching across many upsert() calls
  143. # ---------------------------------------------------------------------------
  144. @pytest.mark.asyncio
  145. async def test_many_upserts_flush_in_one_executemany():
  146. storage = _make_storage(embedding_batch_num=3)
  147. for i in range(5):
  148. await storage.upsert({f"c{i}": _chunk_data(content=f"doc {i}")})
  149. assert storage._counting_embed.call_count == 0
  150. await storage.index_done_callback()
  151. # Embedding split only by embedding_batch_num (3 + 2).
  152. assert [len(b) for b in storage._counting_embed.batches] == [3, 2]
  153. # One executemany for 5 records (not one per upsert call).
  154. assert len(storage._captured_executemany) == 1
  155. sql, rows = storage._captured_executemany[0]
  156. assert len(rows) == 5
  157. assert "LIGHTRAG_VDB_CHUNKS" in sql
  158. # ---------------------------------------------------------------------------
  159. # 3. Same-id upsert overwrite
  160. # ---------------------------------------------------------------------------
  161. @pytest.mark.asyncio
  162. async def test_same_id_upsert_overwrites():
  163. storage = _make_storage()
  164. await storage.upsert({"x": _chunk_data(content="a")})
  165. await storage.upsert({"x": _chunk_data(content="b")})
  166. await storage.index_done_callback()
  167. rows = storage._captured_executemany[0][1]
  168. assert len(rows) == 1
  169. # The chunk tuple position 6 ($6) is content.
  170. assert rows[0][5] == "b"
  171. # ---------------------------------------------------------------------------
  172. # 4. Lazy vector cache: get_vectors_by_ids embeds, flush reuses
  173. # ---------------------------------------------------------------------------
  174. @pytest.mark.asyncio
  175. async def test_lazy_vector_cache_reused_by_flush():
  176. storage = _make_storage()
  177. await storage.upsert({"c1": _chunk_data(content="alpha")})
  178. vecs = await storage.get_vectors_by_ids(["c1"])
  179. assert "c1" in vecs
  180. assert storage._counting_embed.call_count == 1
  181. await storage.index_done_callback()
  182. # Flush must not re-embed; total call count stays 1.
  183. assert storage._counting_embed.call_count == 1
  184. # The vector that landed in the executemany row equals what get_vectors_by_ids returned.
  185. rows = storage._captured_executemany[0][1]
  186. persisted_vec = rows[0][6] # chunks tuple: $7 is content_vector
  187. assert list(np.asarray(persisted_vec, dtype=np.float32)) == list(
  188. np.asarray(vecs["c1"], dtype=np.float32)
  189. )
  190. # ---------------------------------------------------------------------------
  191. # 5. Upsert after lazy cache discards the cached vector
  192. # ---------------------------------------------------------------------------
  193. @pytest.mark.asyncio
  194. async def test_upsert_after_lazy_cache_discards_cached_vector():
  195. storage = _make_storage()
  196. await storage.upsert({"c1": _chunk_data(content="a")})
  197. await storage.get_vectors_by_ids(["c1"]) # embed call #1
  198. assert storage._pending_vector_docs["c1"].vector is not None
  199. await storage.upsert({"c1": _chunk_data(content="b")}) # discards cache
  200. assert storage._pending_vector_docs["c1"].vector is None
  201. await storage.index_done_callback()
  202. # Two embed calls total: one lazy, one for the new content during flush.
  203. assert storage._counting_embed.call_count == 2
  204. # And the persisted content is "b".
  205. rows = storage._captured_executemany[0][1]
  206. assert rows[0][5] == "b"
  207. # ---------------------------------------------------------------------------
  208. # 6. Embedding failure leaves buffers intact
  209. # ---------------------------------------------------------------------------
  210. @pytest.mark.asyncio
  211. async def test_embedding_failure_leaves_pending_for_retry():
  212. embed = CountingEmbed(fail_times=1)
  213. storage = _make_storage(embed=embed)
  214. await storage.upsert({"c1": _chunk_data(content="retry me")})
  215. with pytest.raises(RuntimeError, match="embedding failed"):
  216. await storage.index_done_callback()
  217. assert storage._retry_call_count["n"] == 0
  218. assert "c1" in storage._pending_vector_docs
  219. assert storage._pending_vector_docs["c1"].vector is None
  220. # Next flush succeeds; embed called twice total (one failure + one success).
  221. await storage.index_done_callback()
  222. assert embed.call_count == 2
  223. assert storage._pending_vector_docs == {}
  224. assert len(storage._captured_executemany) == 1
  225. # ---------------------------------------------------------------------------
  226. # 7. _run_with_retry failure leaves buffers + cached vectors intact
  227. # ---------------------------------------------------------------------------
  228. @pytest.mark.asyncio
  229. async def test_persistence_failure_keeps_buffers_and_cached_vectors():
  230. storage = _make_storage(fail_run_with_retry=True)
  231. await storage.upsert({"c1": _chunk_data(content="alpha")})
  232. with pytest.raises(RuntimeError, match="simulated PG failure"):
  233. await storage.index_done_callback()
  234. # Buffer intact, vector cached (so next flush won't re-embed).
  235. assert "c1" in storage._pending_vector_docs
  236. assert storage._pending_vector_docs["c1"].vector is not None
  237. embed_calls_before = storage._counting_embed.call_count
  238. # Repair the DB and flush again.
  239. storage.db._run_with_retry.side_effect = None
  240. storage.db._run_with_retry.return_value = None
  241. # We need to actually persist this time; re-attach a working side_effect.
  242. captured_em = storage._captured_executemany
  243. captured_ex = storage._captured_execute
  244. async def working_retry(operation, **kwargs):
  245. mock_conn = AsyncMock()
  246. tx_cm = AsyncMock()
  247. tx_cm.__aenter__.return_value = None
  248. tx_cm.__aexit__.return_value = None
  249. mock_conn.transaction = MagicMock(return_value=tx_cm)
  250. await operation(mock_conn)
  251. for call in mock_conn.executemany.call_args_list:
  252. captured_em.append((call.args[0], call.args[1]))
  253. for call in mock_conn.execute.call_args_list:
  254. captured_ex.append((call.args[0], call.args[1:]))
  255. storage.db._run_with_retry.side_effect = working_retry
  256. await storage.index_done_callback()
  257. # No re-embed thanks to the cached vector.
  258. assert storage._counting_embed.call_count == embed_calls_before
  259. assert storage._pending_vector_docs == {}
  260. assert len(captured_em) == 1
  261. # ---------------------------------------------------------------------------
  262. # 8. Delete cancels pending upsert
  263. # ---------------------------------------------------------------------------
  264. @pytest.mark.asyncio
  265. async def test_delete_cancels_pending_upsert():
  266. storage = _make_storage()
  267. await storage.upsert({"c1": _chunk_data()})
  268. await storage.delete(["c1"])
  269. assert "c1" not in storage._pending_vector_docs
  270. assert "c1" in storage._pending_vector_deletes
  271. await storage.index_done_callback()
  272. # Only a delete went out, no upsert executemany.
  273. assert storage._captured_executemany == []
  274. assert len(storage._captured_execute) == 1
  275. sql, args = storage._captured_execute[0]
  276. assert "DELETE FROM" in sql
  277. assert args[0] == "test_ws"
  278. assert args[1] == ["c1"]
  279. # ---------------------------------------------------------------------------
  280. # 9. Upsert cancels pending delete
  281. # ---------------------------------------------------------------------------
  282. @pytest.mark.asyncio
  283. async def test_upsert_cancels_pending_delete():
  284. storage = _make_storage()
  285. await storage.delete(["c1"])
  286. await storage.upsert({"c1": _chunk_data(content="new")})
  287. assert "c1" in storage._pending_vector_docs
  288. assert "c1" not in storage._pending_vector_deletes
  289. await storage.index_done_callback()
  290. assert len(storage._captured_executemany) == 1
  291. # And no DELETE in the same flush.
  292. assert storage._captured_execute == []
  293. # ---------------------------------------------------------------------------
  294. # 10. delete_entity prunes pending docs and runs SQL predicate under lock
  295. # ---------------------------------------------------------------------------
  296. @pytest.mark.asyncio
  297. async def test_delete_entity_prunes_pending_and_runs_sql():
  298. storage = _make_storage(namespace=NameSpace.VECTOR_STORE_ENTITIES)
  299. entity_id = compute_mdhash_id("Alice", prefix="ent-")
  300. # Pending entity keyed by the hash id.
  301. await storage.upsert({entity_id: _entity_data(name="Alice")})
  302. await storage.delete_entity("Alice")
  303. # Pending pruned.
  304. assert entity_id not in storage._pending_vector_docs
  305. # SQL predicate fired against db.execute (the immediate path).
  306. storage.db.execute.assert_awaited_once()
  307. sql_arg = storage.db.execute.await_args.args[0]
  308. params_arg = storage.db.execute.await_args.args[1]
  309. assert "entity_name=$2" in sql_arg
  310. assert params_arg == {"workspace": "test_ws", "entity_name": "Alice"}
  311. # ---------------------------------------------------------------------------
  312. # 11. delete_entity_relation prunes pending relation docs + runs SQL predicate
  313. # ---------------------------------------------------------------------------
  314. @pytest.mark.asyncio
  315. async def test_delete_entity_relation_prunes_pending_and_runs_sql():
  316. storage = _make_storage(namespace=NameSpace.VECTOR_STORE_RELATIONSHIPS)
  317. await storage.upsert(
  318. {
  319. "r1": _relation_data(src="Alice", tgt="Bob"),
  320. "r2": _relation_data(src="Carol", tgt="Alice"),
  321. "r3": _relation_data(src="Eve", tgt="Mallory"),
  322. }
  323. )
  324. await storage.delete_entity_relation("Alice")
  325. assert "r1" not in storage._pending_vector_docs
  326. assert "r2" not in storage._pending_vector_docs
  327. assert "r3" in storage._pending_vector_docs
  328. storage.db.execute.assert_awaited_once()
  329. sql_arg = storage.db.execute.await_args.args[0]
  330. assert "source_id=$2 OR target_id=$2" in sql_arg
  331. # ---------------------------------------------------------------------------
  332. # 12. drop() clears buffers and runs workspace delete
  333. # ---------------------------------------------------------------------------
  334. @pytest.mark.asyncio
  335. async def test_drop_clears_buffers_and_runs_delete():
  336. storage = _make_storage()
  337. await storage.upsert({"c1": _chunk_data()})
  338. await storage.delete(["c2"])
  339. assert storage._pending_vector_docs and storage._pending_vector_deletes
  340. result = await storage.drop()
  341. assert result["status"] == "success"
  342. assert storage._pending_vector_docs == {}
  343. assert storage._pending_vector_deletes == set()
  344. storage.db.execute.assert_awaited_once()
  345. # ---------------------------------------------------------------------------
  346. # 13. Read-your-writes: get_by_id, get_by_ids, get_vectors_by_ids
  347. # ---------------------------------------------------------------------------
  348. @pytest.mark.asyncio
  349. async def test_get_by_id_returns_pending_and_hides_deletes():
  350. storage = _make_storage()
  351. await storage.upsert({"c1": _chunk_data(content="hello")})
  352. doc = await storage.get_by_id("c1")
  353. assert doc is not None
  354. assert doc["id"] == "c1"
  355. assert doc["content"] == "hello"
  356. assert "__vector__" not in doc
  357. assert "__id__" not in doc
  358. assert "created_at" in doc
  359. # SQL not touched for buffered hits.
  360. storage.db.query.assert_not_called()
  361. # Now delete and ensure the buffered tombstone wins over SQL.
  362. await storage.delete(["c1"])
  363. assert (await storage.get_by_id("c1")) is None
  364. @pytest.mark.asyncio
  365. async def test_get_by_ids_preserves_order_and_uses_any_sql():
  366. storage = _make_storage()
  367. await storage.upsert({"c1": _chunk_data(content="a")})
  368. await storage.delete(["c2"])
  369. # c3 will fall through to SQL.
  370. storage.db.query = AsyncMock(
  371. return_value=[{"id": "c3", "content": "from-pg", "created_at": 0}]
  372. )
  373. docs = await storage.get_by_ids(["c1", "c2", "c3"])
  374. assert docs[0] is not None and docs[0]["id"] == "c1" and docs[0]["content"] == "a"
  375. assert docs[1] is None # pending delete
  376. assert docs[2] is not None and docs[2]["id"] == "c3"
  377. # SQL fallback used `id = ANY($2)` (not string-built IN).
  378. sql_used = storage.db.query.await_args.args[0]
  379. assert "id = ANY($2)" in sql_used
  380. assert storage.db.query.await_args.args[1] == ["test_ws", ["c3"]]
  381. @pytest.mark.asyncio
  382. async def test_get_vectors_by_ids_returns_cached_and_skips_deletes():
  383. storage = _make_storage()
  384. await storage.upsert({"c1": _chunk_data(content="a")})
  385. await storage.upsert({"c2": _chunk_data(content="b")})
  386. await storage.delete(["c2"])
  387. # c3 falls through to SQL.
  388. storage.db.query = AsyncMock(
  389. return_value=[{"id": "c3", "content_vector": [0.5, 0.6, 0.7]}]
  390. )
  391. vecs = await storage.get_vectors_by_ids(["c1", "c2", "c3"])
  392. # c1 lazily embedded; c2 skipped; c3 from SQL.
  393. assert "c1" in vecs and len(vecs["c1"]) == 3
  394. assert "c2" not in vecs
  395. assert vecs["c3"] == [0.5, 0.6, 0.7]
  396. sql_used = storage.db.query.await_args.args[0]
  397. assert "id = ANY($2)" in sql_used
  398. # ---------------------------------------------------------------------------
  399. # 14. finalize() raises with pending counts if flush failed
  400. # ---------------------------------------------------------------------------
  401. @pytest.mark.asyncio
  402. async def test_finalize_raises_when_flush_fails_and_releases_client():
  403. storage = _make_storage(fail_run_with_retry=True)
  404. await storage.upsert({"c1": _chunk_data()})
  405. await storage.delete(["c2"])
  406. # Patch ClientManager.release_client to a no-op so we don't touch real state.
  407. from lightrag.kg import postgres_impl
  408. release_mock = AsyncMock()
  409. original = postgres_impl.ClientManager.release_client
  410. postgres_impl.ClientManager.release_client = release_mock
  411. try:
  412. with pytest.raises(RuntimeError, match="pending upserts"):
  413. await storage.finalize()
  414. release_mock.assert_awaited_once()
  415. assert storage.db is None
  416. finally:
  417. postgres_impl.ClientManager.release_client = original
  418. @pytest.mark.asyncio
  419. async def test_finalize_clean_path_flushes_then_releases_client():
  420. storage = _make_storage()
  421. await storage.upsert({"c1": _chunk_data()})
  422. from lightrag.kg import postgres_impl
  423. release_mock = AsyncMock()
  424. original = postgres_impl.ClientManager.release_client
  425. postgres_impl.ClientManager.release_client = release_mock
  426. try:
  427. await storage.finalize()
  428. finally:
  429. postgres_impl.ClientManager.release_client = original
  430. release_mock.assert_awaited_once()
  431. assert storage.db is None
  432. assert storage._pending_vector_docs == {}
  433. # ---------------------------------------------------------------------------
  434. # 15. Empty input no-ops
  435. # ---------------------------------------------------------------------------
  436. @pytest.mark.asyncio
  437. async def test_empty_inputs_are_noops():
  438. storage = _make_storage()
  439. await storage.upsert({})
  440. await storage.delete([])
  441. await storage.index_done_callback()
  442. assert storage._retry_call_count["n"] == 0
  443. assert storage._counting_embed.call_count == 0
  444. # ---------------------------------------------------------------------------
  445. # 16. delete_entity serializes against an in-flight flush via _flush_lock
  446. # ---------------------------------------------------------------------------
  447. class _GatedEmbed:
  448. """Embedding func that blocks on a gate so a flush can be paused mid-call."""
  449. def __init__(self, dim: int = 3):
  450. self.embedding_dim = dim
  451. self.max_token_size = 512
  452. self.model_name = "test_model"
  453. self.gate = asyncio.Event()
  454. self.entered = asyncio.Event()
  455. self.call_count = 0
  456. async def __call__(self, texts, **kwargs):
  457. self.call_count += 1
  458. self.entered.set()
  459. await self.gate.wait()
  460. return np.array([[1.0, 0.0, 0.0] for _ in texts], dtype=np.float32)
  461. @pytest.mark.asyncio
  462. async def test_delete_entity_serializes_against_in_flight_flush():
  463. """A `delete_entity` issued while a flush is mid-embedding must wait for
  464. the flush's lock to release before its SQL predicate runs — otherwise the
  465. flush could persist the entity row a microsecond after the predicate
  466. deleted it. This pins the lock-then-SQL contract in the source.
  467. """
  468. embed = _GatedEmbed()
  469. storage = _make_storage(namespace=NameSpace.VECTOR_STORE_ENTITIES, embed=embed)
  470. entity_id = compute_mdhash_id("Alice", prefix="ent-")
  471. await storage.upsert({entity_id: _entity_data(name="Alice")})
  472. flush_task = asyncio.create_task(storage.index_done_callback())
  473. # Wait until the flush is blocked inside the embedding call (it now holds
  474. # _flush_lock).
  475. await asyncio.wait_for(embed.entered.wait(), timeout=1.0)
  476. # Kick off delete_entity; it must block on the same lock.
  477. delete_task = asyncio.create_task(storage.delete_entity("Alice"))
  478. # Give the event loop a few turns to confirm delete_entity is blocked.
  479. for _ in range(5):
  480. await asyncio.sleep(0)
  481. assert (
  482. not delete_task.done()
  483. ), "delete_entity should be waiting on _flush_lock while flush holds it"
  484. # Unblock the flush; both should complete.
  485. embed.gate.set()
  486. await asyncio.wait_for(flush_task, timeout=1.0)
  487. await asyncio.wait_for(delete_task, timeout=1.0)
  488. # Flush ran its executemany, then delete_entity ran its predicate SQL.
  489. assert len(storage._captured_executemany) == 1
  490. storage.db.execute.assert_awaited_once()
  491. # ---------------------------------------------------------------------------
  492. # 17. Deletes-only flush: no executemany, single ANY($2) DELETE
  493. # ---------------------------------------------------------------------------
  494. @pytest.mark.asyncio
  495. async def test_deletes_only_flush_skips_executemany():
  496. """A flush that has only buffered deletes (no upserts) must still issue
  497. the parameterized DELETE under the transaction, and must NOT call
  498. executemany with an empty batch.
  499. """
  500. storage = _make_storage()
  501. await storage.delete(["c1", "c2", "c3"])
  502. assert storage._pending_vector_docs == {}
  503. assert len(storage._pending_vector_deletes) == 3
  504. await storage.index_done_callback()
  505. # No embedding was needed.
  506. assert storage._counting_embed.call_count == 0
  507. # No upsert executemany ran.
  508. assert storage._captured_executemany == []
  509. # Exactly one parameterized DELETE under the transaction.
  510. assert len(storage._captured_execute) == 1
  511. sql, args = storage._captured_execute[0]
  512. assert "DELETE FROM" in sql
  513. assert "id = ANY($2)" in sql
  514. assert args[0] == "test_ws"
  515. assert sorted(args[1]) == ["c1", "c2", "c3"]
  516. # Buffers cleared on success.
  517. assert storage._pending_vector_deletes == set()
  518. # ---------------------------------------------------------------------------
  519. # 18. Embedding count mismatch raises and preserves the buffer
  520. # ---------------------------------------------------------------------------
  521. @pytest.mark.asyncio
  522. async def test_embedding_count_mismatch_raises_and_preserves_buffer():
  523. """The in-flush ``len(embeddings) != len(docs_to_embed)`` check is
  524. defense-in-depth against an embedding provider that bypasses the
  525. ``EmbeddingFunc`` wrapper validation. Bypass the wrapper by replacing
  526. ``storage.embedding_func`` with a bare async callable that returns
  527. fewer rows than requested.
  528. """
  529. storage = _make_storage(embedding_batch_num=10)
  530. async def short_embed(texts, **kwargs):
  531. rows = max(0, len(list(texts)) - 1)
  532. return np.array([[1.0, 0.0, 0.0] for _ in range(rows)], dtype=np.float32)
  533. storage.embedding_func = short_embed
  534. await storage.upsert({"c1": _chunk_data(content="a")})
  535. await storage.upsert({"c2": _chunk_data(content="b")})
  536. with pytest.raises(RuntimeError, match="Embedding count mismatch"):
  537. await storage.index_done_callback()
  538. # Buffer survives the mismatch; nothing was persisted.
  539. assert {"c1", "c2"} == set(storage._pending_vector_docs.keys())
  540. assert storage._retry_call_count["n"] == 0
  541. # ---------------------------------------------------------------------------
  542. # 19. delete_entity discards a matching pending delete for the same hash id
  543. # ---------------------------------------------------------------------------
  544. @pytest.mark.asyncio
  545. async def test_delete_entity_discards_matching_pending_delete():
  546. """If both `delete()` (which buffers a tombstone) and `delete_entity()`
  547. fire for the same entity, the pending tombstone for the entity's hash id
  548. must be discarded — the predicate SQL covers it and we don't want a
  549. redundant ANY-DELETE for the same id in the next flush.
  550. """
  551. storage = _make_storage(namespace=NameSpace.VECTOR_STORE_ENTITIES)
  552. entity_id = compute_mdhash_id("Alice", prefix="ent-")
  553. await storage.delete([entity_id])
  554. assert entity_id in storage._pending_vector_deletes
  555. await storage.delete_entity("Alice")
  556. # The pending tombstone for the hash id was discarded.
  557. assert entity_id not in storage._pending_vector_deletes
  558. # And the predicate SQL ran.
  559. storage.db.execute.assert_awaited_once()
  560. # A subsequent flush has nothing to do.
  561. await storage.index_done_callback()
  562. assert storage._retry_call_count["n"] == 0
  563. # ---------------------------------------------------------------------------
  564. # 20. delete_entity / delete_entity_relation raise pre-initialize()
  565. # ---------------------------------------------------------------------------
  566. @pytest.mark.asyncio
  567. async def test_delete_entity_pre_initialize_raises():
  568. """Calling delete_entity / delete_entity_relation before initialize()
  569. must raise RuntimeError, not silently drop the destructive intent.
  570. Silent no-op would defeat the data-loss visibility that finalize() and
  571. _flush_pending_vector_ops enforce on the symmetric paths.
  572. """
  573. db = MagicMock()
  574. db.execute = AsyncMock(return_value=None)
  575. embed = CountingEmbed()
  576. embedding_func = EmbeddingFunc(
  577. embedding_dim=embed.embedding_dim,
  578. func=embed,
  579. model_name=embed.model_name,
  580. )
  581. storage = PGVectorStorage(
  582. namespace=NameSpace.VECTOR_STORE_ENTITIES,
  583. workspace="test_ws",
  584. global_config={
  585. "embedding_batch_num": 10,
  586. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.5},
  587. },
  588. embedding_func=embedding_func,
  589. )
  590. storage.db = db
  591. # Intentionally do NOT set _flush_lock (mimics pre-initialize state).
  592. assert storage._flush_lock is None
  593. with pytest.raises(RuntimeError, match="called before initialize"):
  594. await storage.delete_entity("Alice")
  595. with pytest.raises(RuntimeError, match="called before initialize"):
  596. await storage.delete_entity_relation("Alice")
  597. # No SQL fired (the methods short-circuited before touching db.execute).
  598. db.execute.assert_not_called()
  599. # ---------------------------------------------------------------------------
  600. # 21. _flush_pending_vector_ops raises on lifecycle violations
  601. # ---------------------------------------------------------------------------
  602. @pytest.mark.asyncio
  603. async def test_flush_after_client_release_raises_with_counts():
  604. """Direct call to _flush_pending_vector_ops after db release with a
  605. non-empty buffer must raise — silent return would lose data without any
  606. operator-visible signal (the symmetric path in finalize() already raises).
  607. """
  608. storage = _make_storage()
  609. await storage.upsert({"c1": _chunk_data()})
  610. await storage.delete(["c2"])
  611. # Mimic post-finalize state: client released, buffers preserved.
  612. storage.db = None
  613. with pytest.raises(RuntimeError, match="after client release"):
  614. await storage._flush_pending_vector_ops()
  615. # Buffers untouched — the call must not have eaten the data on its way out.
  616. assert "c1" in storage._pending_vector_docs
  617. assert "c2" in storage._pending_vector_deletes
  618. @pytest.mark.asyncio
  619. async def test_flush_pre_initialize_with_pending_raises():
  620. """Pre-initialize call with a non-empty buffer (programmer error path:
  621. direct buffer manipulation before initialize) also raises rather than
  622. silently returning."""
  623. storage = _make_storage()
  624. # Reset to pre-initialize state but seed the buffer to simulate the
  625. # programmer-error scenario.
  626. storage._flush_lock = None
  627. storage._pending_vector_docs["c1"] = _PendingPGVectorDoc(
  628. item={"__id__": "c1", **_chunk_data()},
  629. created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
  630. )
  631. with pytest.raises(RuntimeError, match="called before initialize"):
  632. await storage._flush_pending_vector_ops()
  633. # ---------------------------------------------------------------------------
  634. # 22. get_vectors_by_ids drops embeddings whose pending record changed
  635. # ---------------------------------------------------------------------------
  636. @pytest.mark.asyncio
  637. async def test_get_vectors_by_ids_drops_when_pending_record_swapped():
  638. """If a concurrent upsert replaces the pending record while embedding I/O
  639. is in flight (outside the lock), the freshly-computed vector no longer
  640. matches the current buffer state and must be dropped from the response
  641. rather than returned stale."""
  642. embed = _GatedEmbed()
  643. storage = _make_storage(embed=embed)
  644. # Seed the original pending doc.
  645. await storage.upsert({"c1": _chunk_data(content="original")})
  646. original_pending = storage._pending_vector_docs["c1"]
  647. # Kick off get_vectors_by_ids — it will block inside the embedding gate
  648. # *outside* _flush_lock.
  649. task = asyncio.create_task(storage.get_vectors_by_ids(["c1"]))
  650. await asyncio.wait_for(embed.entered.wait(), timeout=1.0)
  651. # While embedding is in flight, replace the pending record via upsert.
  652. # The new doc has a different content and a vector=None.
  653. await storage.upsert({"c1": _chunk_data(content="replaced")})
  654. assert storage._pending_vector_docs["c1"] is not original_pending
  655. # Release the gate; the embedding completes and the identity check fires.
  656. embed.gate.set()
  657. result = await asyncio.wait_for(task, timeout=1.0)
  658. # The stale embedding is NOT returned, and is NOT cached on the new
  659. # pending record (which keeps vector=None for the next flush to embed).
  660. assert result == {}
  661. assert storage._pending_vector_docs["c1"].vector is None
  662. @pytest.mark.asyncio
  663. async def test_get_vectors_by_ids_drops_when_pending_record_removed():
  664. """Same identity-check guard but for the delete-while-embedding race."""
  665. embed = _GatedEmbed()
  666. storage = _make_storage(embed=embed)
  667. await storage.upsert({"c1": _chunk_data(content="original")})
  668. task = asyncio.create_task(storage.get_vectors_by_ids(["c1"]))
  669. await asyncio.wait_for(embed.entered.wait(), timeout=1.0)
  670. # Delete the pending record mid-embedding.
  671. await storage.delete(["c1"])
  672. assert "c1" not in storage._pending_vector_docs
  673. embed.gate.set()
  674. result = await asyncio.wait_for(task, timeout=1.0)
  675. # The vector for the removed id is dropped from the response.
  676. assert result == {}