| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
"""Deferred-embedding coverage for ``NanoVectorDBStorage``.

The storage no longer embeds eagerly in ``upsert``: it buffers a pending doc
and embeds it once per id at flush time (``index_done_callback`` /
``finalize``). These tests pin that contract using a counting mock embedding
function, with no live model or network. They mirror the protocol proven for
``OpenSearchVectorDBStorage`` (issue #2785).
"""
import zlib

import numpy as np
import pytest

nano_vectordb = pytest.importorskip("nano_vectordb")  # noqa: F841

from lightrag.kg.nano_vector_db_impl import NanoVectorDBStorage  # noqa: E402
from lightrag.kg.shared_storage import (  # noqa: E402
    initialize_share_data,
    finalize_share_data,
)
from lightrag.utils import EmbeddingFunc  # noqa: E402
# Width of every mock embedding vector; shared by the mock embedder and the
# storage's EmbeddingFunc so the two always agree.
DIM: int = 8
@pytest.fixture(autouse=True)
def _shared_data():
    """Run every test against freshly initialized shared storage.

    Shared data is finalized before initialization (clearing anything a
    previous test left behind) and finalized again once the test is done.
    """
    finalize_share_data()  # clear leftovers from an earlier test
    initialize_share_data()
    yield
    finalize_share_data()  # teardown
class _CountingEmbed:
    """Async mock embedding callable that records its workload.

    ``call_count`` counts invocations (one invocation == one batch) and
    ``embedded_texts`` collects every text it was asked to embed, in order.
    """

    def __init__(self, dim: int = DIM):
        self.dim = dim
        self.call_count = 0
        self.embedded_texts: list[str] = []

    @staticmethod
    def _bucket(text: str) -> int:
        """Map *text* to a fill value in ``1..97`` that is stable across runs."""
        # Built-in hash() on str is salted per interpreter (PYTHONHASHSEED), so
        # it produced different vectors on every run; crc32 is reproducible.
        return (zlib.crc32(text.encode("utf-8")) % 97) + 1

    async def __call__(self, texts, **kwargs):
        self.call_count += 1
        # Materialize once so a one-shot iterable is both recorded and embedded.
        batch = list(texts)
        self.embedded_texts.extend(batch)
        if not batch:
            # np.array([]) would be 1-D with shape (0,); keep the (n, dim) shape.
            return np.empty((0, self.dim), dtype=np.float32)
        # Identical texts always map to identical vectors.
        return np.array(
            [np.full(self.dim, self._bucket(t), dtype=np.float32) for t in batch]
        )
def _make_storage(tmp_path, embed: _CountingEmbed) -> NanoVectorDBStorage:
    """Build a ``NanoVectorDBStorage`` under *tmp_path* that embeds via *embed*.

    Uses namespace ``test_vectors`` in workspace ``ws``, a batch size of 32,
    a cosine threshold of 0.2 and ``content`` as the only meta field.
    """
    config = {
        "working_dir": str(tmp_path),
        "embedding_batch_num": 32,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }
    wrapped = EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=embed)
    return NanoVectorDBStorage(
        namespace="test_vectors",
        workspace="ws",
        global_config=config,
        embedding_func=wrapped,
        meta_fields={"content"},
    )
@pytest.mark.offline
@pytest.mark.asyncio
async def test_upsert_defers_embedding_to_index_done_callback(tmp_path):
    """``upsert`` only buffers; the flush embeds everything in one batch."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()

    docs = {"id1": {"content": "alpha"}, "id2": {"content": "beta"}}
    await store.upsert(docs)
    assert counter.call_count == 0, "upsert must not embed"
    assert len(store._client) == 0, "nothing should be materialized yet"

    await store.index_done_callback()
    assert counter.call_count == 1, "flush should embed in a single batch"
    assert sorted(counter.embedded_texts) == ["alpha", "beta"]
    assert len(store._client) == 2
@pytest.mark.offline
@pytest.mark.asyncio
async def test_repeated_upserts_same_id_embed_once_per_flush(tmp_path):
    """Several upserts of one id embed only the newest content, exactly once."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()

    for version in ("v1", "v2", "v3"):
        await store.upsert({"id1": {"content": version}})
    await store.index_done_callback()

    assert counter.call_count == 1
    assert counter.embedded_texts == ["v3"], "only the latest content is embedded"
    assert len(store._client) == 1
@pytest.mark.offline
@pytest.mark.asyncio
async def test_get_vectors_caches_and_flush_reuses(tmp_path):
    """A vector computed by ``get_vectors_by_ids`` is reused by the next flush."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})

    vectors = await store.get_vectors_by_ids(["id1"])
    assert "id1" in vectors
    assert len(vectors["id1"]) == DIM
    assert counter.call_count == 1, "get_vectors_by_ids embeds pending lazily"

    # The flush must pick up the cached vector rather than embedding again.
    await store.index_done_callback()
    assert counter.call_count == 1, "flush should reuse the cached temp vector"
    assert len(store._client) == 1
@pytest.mark.offline
@pytest.mark.asyncio
async def test_reupsert_after_get_vectors_clears_cached_vector(tmp_path):
    """A new content version invalidates a vector cached by ``get_vectors_by_ids``."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()

    await store.upsert({"id1": {"content": "old"}})
    await store.get_vectors_by_ids(["id1"])  # caches a temp vector for "old"
    assert counter.call_count == 1

    # A newer version must drop the cached vector and be re-embedded at flush.
    await store.upsert({"id1": {"content": "new"}})
    await store.index_done_callback()
    assert counter.call_count == 2
    assert counter.embedded_texts == ["old", "new"]
@pytest.mark.offline
@pytest.mark.asyncio
async def test_delete_cancels_pending_and_removes_materialized(tmp_path):
    """``delete`` drops both a flushed row and a still-buffered upsert."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()

    # id1 gets flushed into the client; id2 stays buffered only.
    await store.upsert({"id1": {"content": "alpha"}})
    await store.index_done_callback()
    await store.upsert({"id2": {"content": "beta"}})

    await store.delete(["id1", "id2"])

    assert "id2" not in store._pending_upserts, "delete cancels pending upsert"
    assert len(store._client) == 0, "delete removes the materialized row immediately"
    for doc_id in ("id1", "id2"):
        assert await store.get_by_id(doc_id) is None
@pytest.mark.offline
@pytest.mark.asyncio
async def test_stale_client_reload_still_flushes_pending_upsert(tmp_path):
    """A stale instance keeps its own pending upsert when it flushes.

    ``writer`` persists ``id1``, which flags ``stale_writer`` as updated
    (``storage_updated``). ``stale_writer`` then buffers ``id2`` and flushes;
    a fresh reader must see both rows and the buffer must be empty.
    """
    embed = _CountingEmbed()
    writer = _make_storage(tmp_path, embed)
    stale_writer = _make_storage(tmp_path, embed)
    await writer.initialize()
    await stale_writer.initialize()

    await writer.upsert({"id1": {"content": "alpha"}})
    assert await writer.index_done_callback() is True
    assert stale_writer.storage_updated.value is True

    await stale_writer.upsert({"id2": {"content": "beta"}})
    assert await stale_writer.index_done_callback() is True

    reader = _make_storage(tmp_path, embed)
    await reader.initialize()
    rows = await reader.get_by_ids(["id1", "id2"])
    # get_by_ids yields None for a missing id; map it through so a lost row
    # shows up in the assertion diff instead of a TypeError on None["id"].
    found = [row["id"] if row is not None else None for row in rows]
    assert found == ["id1", "id2"]
    assert stale_writer._pending_upserts == {}
@pytest.mark.offline
@pytest.mark.asyncio
async def test_delete_reloads_stale_client_before_mutating(tmp_path):
    """``delete`` on a stale instance reloads first, so the deletion persists."""
    counter = _CountingEmbed()
    writer = _make_storage(tmp_path, counter)
    stale_deleter = _make_storage(tmp_path, counter)
    for instance in (writer, stale_deleter):
        await instance.initialize()

    await writer.upsert({"id1": {"content": "alpha"}})
    assert await writer.index_done_callback() is True
    assert stale_deleter.storage_updated.value is True

    await stale_deleter.delete(["id1"])
    assert stale_deleter.storage_updated.value is False
    assert await stale_deleter.index_done_callback() is True

    reader = _make_storage(tmp_path, counter)
    await reader.initialize()
    assert await reader.get_by_id("id1") is None
@pytest.mark.offline
@pytest.mark.asyncio
async def test_finalize_reloads_stale_client_before_flushing(tmp_path):
    """``finalize`` on a stale instance keeps both the remote and its own rows.

    ``writer`` persists ``id1`` (flagging ``stale_finalizer`` as updated);
    ``stale_finalizer`` buffers ``id2`` and finalizes. A fresh reader must see
    both rows and the buffer must be empty.
    """
    embed = _CountingEmbed()
    writer = _make_storage(tmp_path, embed)
    stale_finalizer = _make_storage(tmp_path, embed)
    await writer.initialize()
    await stale_finalizer.initialize()

    await writer.upsert({"id1": {"content": "alpha"}})
    assert await writer.index_done_callback() is True
    assert stale_finalizer.storage_updated.value is True

    await stale_finalizer.upsert({"id2": {"content": "beta"}})
    await stale_finalizer.finalize()

    reader = _make_storage(tmp_path, embed)
    await reader.initialize()
    rows = await reader.get_by_ids(["id1", "id2"])
    # get_by_ids yields None for a missing id; map it through so a lost row
    # shows up in the assertion diff instead of a TypeError on None["id"].
    found = [row["id"] if row is not None else None for row in rows]
    assert found == ["id1", "id2"]
    assert stale_finalizer._pending_upserts == {}
@pytest.mark.offline
@pytest.mark.asyncio
async def test_read_your_writes_and_query_after_flush(tmp_path):
    """Point reads see buffered rows at once; similarity queries only after flush."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})

    # Before the flush: id lookups see the pending row, query does not.
    single = await store.get_by_id("id1")
    assert single is not None and single["id"] == "id1" and single["content"] == "alpha"
    many = await store.get_by_ids(["id1", "missing"])
    assert many[0]["id"] == "id1" and many[1] is None
    assert await store.query("alpha", top_k=5) == [], "query ignores unflushed data"

    # After the flush: query returns the row.
    await store.index_done_callback()
    hits = await store.query("alpha", top_k=5)
    assert any(hit["id"] == "id1" for hit in hits)
@pytest.mark.offline
@pytest.mark.asyncio
async def test_finalize_flushes_pending(tmp_path):
    """``finalize`` embeds and materializes whatever is still buffered."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})

    await store.finalize()

    assert counter.call_count == 1
    assert store._pending_upserts == {}
    assert len(store._client) == 1
@pytest.mark.offline
@pytest.mark.asyncio
async def test_delete_entity_relation_cancels_pending(tmp_path):
    """Deleting an entity drops its incident relations, buffered or stored.

    ``r1`` (A->B) is flushed while ``r2`` (A->C) and ``r3`` (X->Y) stay
    pending; ``delete_entity_relation("A")`` must remove ``r1`` and ``r2``
    and leave ``r3`` alone.
    """
    counter = _CountingEmbed()
    config = {
        "working_dir": str(tmp_path),
        "embedding_batch_num": 32,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }
    store = NanoVectorDBStorage(
        namespace="test_relations",
        workspace="ws",
        global_config=config,
        embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=counter),
        meta_fields={"content", "src_id", "tgt_id"},
    )
    await store.initialize()

    await store.upsert({"r1": {"content": "rel1", "src_id": "A", "tgt_id": "B"}})
    await store.index_done_callback()
    unflushed = {
        "r2": {"content": "rel2", "src_id": "A", "tgt_id": "C"},
        "r3": {"content": "rel3", "src_id": "X", "tgt_id": "Y"},
    }
    await store.upsert(unflushed)

    await store.delete_entity_relation("A")

    assert "r2" not in store._pending_upserts, "incident pending entry cancelled"
    assert "r3" in store._pending_upserts, "unrelated pending entry preserved"
    assert len(store._client) == 0, "materialized A->B removed"
@pytest.mark.offline
@pytest.mark.asyncio
async def test_flush_embedding_failure_raises_and_keeps_pending(tmp_path):
    """An embedding error during flush propagates and leaves the buffer intact."""

    class _FailingEmbed:
        """Embedding callable that counts its invocations and always raises."""

        def __init__(self):
            self.call_count = 0

        async def __call__(self, texts, **kwargs):
            self.call_count += 1
            raise RuntimeError("embed boom")

    failing = _FailingEmbed()
    config = {
        "working_dir": str(tmp_path),
        "embedding_batch_num": 32,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }
    store = NanoVectorDBStorage(
        namespace="test_vectors",
        workspace="ws",
        global_config=config,
        embedding_func=EmbeddingFunc(embedding_dim=DIM, max_token_size=512, func=failing),
        meta_fields={"content"},
    )
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})

    with pytest.raises(RuntimeError, match="embed boom"):
        await store.index_done_callback()

    assert "id1" in store._pending_upserts, "pending preserved for retry"
    assert len(store._client) == 0, "nothing materialized on embed failure"
    # The embedding error is raised before self._client.upsert runs in
    # _flush_pending_locked, so _client_dirty must NOT be set. (A save-stage
    # failure would leave it True; see
    # test_finalize_retries_save_after_flush_failure.)
    assert store._client_dirty is False
@pytest.mark.offline
@pytest.mark.asyncio
async def test_drop_discards_pending_without_embedding(tmp_path):
    """``drop`` throws away buffered upserts without ever embedding them."""
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})
    assert "id1" in store._pending_upserts

    outcome = await store.drop()

    assert outcome["status"] == "success"
    assert store._pending_upserts == {}, "drop discards buffered upserts"
    assert counter.call_count == 0, "drop must not embed"
    assert store._client_dirty is False
@pytest.mark.offline
@pytest.mark.asyncio
async def test_finalize_retries_save_after_flush_failure(tmp_path):
    """A failed save leaves the client dirty so the next ``finalize`` persists it.

    The first ``_save_to_disk_locked`` call raises ``OSError``. Afterwards the
    pending buffer is empty but ``_client_dirty`` is still True. A second
    ``finalize`` must retry the save, and a fresh reader must see the row.
    """
    counter = _CountingEmbed()
    store = _make_storage(tmp_path, counter)
    await store.initialize()
    await store.upsert({"id1": {"content": "alpha"}})

    real_save = store._save_to_disk_locked
    attempts = 0

    def flaky_save():
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise OSError("boom")
        real_save()

    store._save_to_disk_locked = flaky_save

    with pytest.raises(OSError, match="boom"):
        await store.finalize()
    assert store._pending_upserts == {}
    assert store._client_dirty is True

    await store.finalize()
    assert attempts == 2
    assert store._client_dirty is False

    reader = _make_storage(tmp_path, counter)
    await reader.initialize()
    row = await reader.get_by_id("id1")
    assert row is not None
    assert row["id"] == "id1"
|