| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514 |
- """Unit tests for MilvusVectorDBStorage's deferred-embedding flush pipeline.
- All tests use mocks — no running Milvus instance required.
- Mirrors the structure of tests/kg/opensearch_impl/test_opensearch_storage.py's
- TestVectorStorageBatching to keep behaviour aligned across backends.
- """
- import asyncio
- import os
- import numpy as np
- import pytest
- from unittest.mock import MagicMock, patch
- from lightrag.kg.milvus_impl import MilvusVectorDBStorage
- pytestmark = pytest.mark.offline
- # ---------------------------------------------------------------------------
- # Fixtures
- # ---------------------------------------------------------------------------
class MockEmbeddingFunc:
    """Async embedding stand-in that produces random float32 vectors.

    Attributes:
        embedding_dim: Width of each generated vector.
        max_token_size: Fixed at 512.
        model_name: Fixed at ``"mock-embed"``.
    """

    def __init__(self, dim=8):
        self.embedding_dim = dim
        self.max_token_size = 512
        self.model_name = "mock-embed"

    async def __call__(self, texts, **kwargs):
        """Return one random row per entry of ``texts``; ``kwargs`` are ignored."""
        row_count = len(texts)
        raw = np.random.rand(row_count, self.embedding_dim)
        return raw.astype(np.float32)
class CountingEmbeddingFunc(MockEmbeddingFunc):
    """Embedding double that logs every invocation and can fail on demand.

    Attributes:
        fail_times: Remaining number of calls that will raise ``RuntimeError``.
        call_count: Total number of calls, including the failing ones.
        batches: One list of texts per call, in call order.
        texts: Every text seen so far, flattened across calls.
    """

    def __init__(self, dim=8, fail_times=0):
        super().__init__(dim=dim)
        self.fail_times = fail_times
        self.call_count = 0
        self.batches: list[list[str]] = []
        self.texts: list[str] = []

    async def __call__(self, texts, **kwargs):
        """Record the call, then either raise or delegate to the parent.

        The call is recorded *before* the failure check, so failed attempts
        still show up in ``call_count``, ``batches`` and ``texts``.
        """
        self.call_count += 1
        received = list(texts)
        self.batches.append(received)
        self.texts.extend(received)

        if self.fail_times <= 0:
            return await super().__call__(texts, **kwargs)

        # Consume one of the scheduled failures.
        self.fail_times -= 1
        raise RuntimeError("embedding failed")
@pytest.fixture(autouse=True)
def patch_namespace_lock():
    """Replace ``get_namespace_lock`` with a factory that memoizes real locks.

    Locks are real ``asyncio.Lock`` objects keyed by ``(namespace, workspace)``,
    with a missing/empty workspace normalised to ``""``. Repeated requests for
    the same key return the same lock object, which lets tests assert lock
    sharing between storage instances. The fixture yields the backing cache so
    tests can inspect which keys were requested.
    """
    cache: dict[tuple[str, str | None], asyncio.Lock] = {}

    def factory(namespace, workspace=None, enable_logging=False):
        key = (namespace, workspace or "")
        try:
            return cache[key]
        except KeyError:
            created = asyncio.Lock()
            cache[key] = created
            return created

    target = "lightrag.kg.milvus_impl.get_namespace_lock"
    with patch(target, side_effect=factory):
        yield cache
def _make_storage(
    embed_func,
    *,
    namespace="entities",
    workspace="test",
    meta_fields=None,
):
    """Construct a MilvusVectorDBStorage wired to a mocked client.

    `initialize()` is never called; instead the attributes it would set are
    assigned by hand so no real Milvus connection is needed.

    Args:
        embed_func: Embedding callable handed to the storage.
        namespace: Storage namespace.
        workspace: Storage workspace.
        meta_fields: Metadata field names; when ``None`` a default set of
            ``content``, ``entity_name``, ``src_id`` and ``tgt_id`` is used.

    Returns:
        The storage instance with ``_client`` replaced by a ``MagicMock`` and
        ``_initialized`` set to ``True``.
    """
    fields = (
        {"content", "entity_name", "src_id", "tgt_id"}
        if meta_fields is None
        else meta_fields
    )
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }
    storage = MilvusVectorDBStorage(
        namespace=namespace,
        workspace=workspace,
        global_config=config,
        embedding_func=embed_func,
        meta_fields=fields,
    )

    # Swap in a mock client and configure the calls the tests rely on.
    # No lock is wired here: the tests read ``storage._flush_lock`` straight
    # off the freshly constructed instance.
    client = storage._client = MagicMock()
    client.has_collection.return_value = True
    client.upsert = MagicMock(return_value={"upsert_count": 0})
    client.delete = MagicMock(return_value={"delete_count": 0})
    client.query = MagicMock(return_value=[])
    client.load_collection = MagicMock()

    storage._initialized = True
    return storage
- # ---------------------------------------------------------------------------
- # Tests: deferred embedding + batched flush
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_upsert_buffers_without_embedding():
    """upsert() only buffers documents: no embedding call and no server write."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    await storage.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})

    assert embedder.call_count == 0
    assert set(storage._pending_vector_docs.keys()) == {"v1", "v2"}
    # Vectors stay unset until something forces an embedding.
    for doc_id in ("v1", "v2"):
        assert storage._pending_vector_docs[doc_id].vector is None
    storage._client.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_index_done_callback_triggers_flush():
    """index_done_callback() embeds buffered docs and sends them in one upsert."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})

    await storage.index_done_callback()

    assert embedder.call_count == 1
    storage._client.upsert.assert_called_once()

    sent = storage._client.upsert.call_args.kwargs
    assert sent["collection_name"] == storage.final_namespace
    rows = sent["data"]
    assert {row["id"] for row in rows} == {"v1", "v2"}
    for row in rows:
        assert "vector" in row

    # A successful flush leaves both buffers empty.
    assert storage._pending_vector_docs == {}
    assert storage._pending_vector_deletes == set()
@pytest.mark.asyncio
async def test_repeated_upsert_same_id_embeds_once():
    """Re-upserting one id before a flush embeds only the last content."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    for content in ("first", "second", "third"):
        await storage.upsert({"v1": {"content": content}})
    await storage.index_done_callback()

    assert embedder.call_count == 1
    # Earlier versions were overwritten in the buffer and never embedded.
    assert embedder.texts == ["third"]
    storage._client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_deferred_embeddings_respect_batch_size():
    """Flush splits embedding work into chunks of at most ``_max_batch_size``."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    storage._max_batch_size = 2

    docs = {f"v{i}": {"content": f"doc {i}"} for i in range(5)}
    await storage.upsert(docs)
    await storage.index_done_callback()

    # Five docs with a batch size of two: batches of 2, 2 and 1.
    assert embedder.call_count == 3
    batch_sizes = [len(batch) for batch in embedder.batches]
    assert batch_sizes == [2, 2, 1]
@pytest.mark.asyncio
async def test_get_vectors_by_ids_lazy_embed_then_reuse_in_flush():
    """A vector embedded on read is cached and reused by the later flush."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}})

    fetched = await storage.get_vectors_by_ids(["v1"])

    assert "v1" in fetched
    # Reading the vector of a pending doc forces exactly one embedding call.
    assert embedder.call_count == 1
    # That vector is stored back on the buffered doc.
    assert storage._pending_vector_docs["v1"].vector is not None

    await storage.index_done_callback()

    # The flush did not embed again.
    assert embedder.call_count == 1
    storage._client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_flush_failure_keeps_buffer_and_no_double_embed_on_retry():
    """An embedding failure keeps the buffer intact; the retry embeds once."""
    # The embedder raises on its first call only.
    embedder = CountingEmbeddingFunc(fail_times=1)
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}})

    with pytest.raises(RuntimeError, match="embedding failed"):
        await storage.index_done_callback()

    # Nothing was written and the doc is still pending without a vector,
    # so a later flush can retry it.
    assert "v1" in storage._pending_vector_docs
    assert storage._pending_vector_docs["v1"].vector is None
    storage._client.upsert.assert_not_called()

    await storage.index_done_callback()

    # Two embedding calls in total: the failed one plus one successful retry.
    assert embedder.call_count == 2
    storage._client.upsert.assert_called_once()
    assert storage._pending_vector_docs == {}
@pytest.mark.asyncio
async def test_server_upsert_failure_keeps_buffer():
    """A failed server write keeps the embedded doc buffered for a cheap retry."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    client_upsert = storage._client.upsert
    client_upsert.side_effect = RuntimeError("milvus down")
    await storage.upsert({"v1": {"content": "hello"}})

    with pytest.raises(RuntimeError, match="milvus down"):
        await storage.index_done_callback()

    # The doc is still pending, and its vector was kept from the failed
    # attempt so the retry has nothing left to embed.
    assert "v1" in storage._pending_vector_docs
    assert storage._pending_vector_docs["v1"].vector is not None

    # Let the server write succeed on the second attempt.
    client_upsert.side_effect = None
    client_upsert.return_value = {"upsert_count": 1}
    await storage.index_done_callback()

    # Still a single embedding call overall.
    assert embedder.call_count == 1
    assert storage._pending_vector_docs == {}
@pytest.mark.asyncio
async def test_finalize_raises_when_buffer_unflushed():
    """finalize() surfaces a flush failure instead of silently dropping data."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    storage._client.upsert.side_effect = RuntimeError("transient milvus error")
    await storage.upsert({"v1": {"content": "hello"}})

    with pytest.raises(RuntimeError, match="finalize.*flush raised"):
        await storage.finalize()

    # The unflushed doc is still in the buffer after the failed finalize.
    assert "v1" in storage._pending_vector_docs
@pytest.mark.asyncio
async def test_delete_then_upsert_same_id_keeps_upsert():
    """An upsert issued after a pending delete of the same id wins."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    await storage.delete(["v1"])
    assert "v1" in storage._pending_vector_deletes

    await storage.upsert({"v1": {"content": "hello"}})
    # The later upsert replaces the pending delete for that id.
    assert "v1" in storage._pending_vector_docs
    assert "v1" not in storage._pending_vector_deletes

    await storage.index_done_callback()

    storage._client.upsert.assert_called_once()
    storage._client.delete.assert_not_called()
@pytest.mark.asyncio
async def test_upsert_then_delete_same_id_keeps_delete():
    """A delete issued after a pending upsert of the same id wins."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    await storage.upsert({"v1": {"content": "hello"}})
    await storage.delete(["v1"])
    # The later delete replaces the pending upsert for that id.
    assert "v1" not in storage._pending_vector_docs
    assert "v1" in storage._pending_vector_deletes

    await storage.index_done_callback()

    # The flush carries the delete only.
    storage._client.upsert.assert_not_called()
    storage._client.delete.assert_called_once()
    deleted_pks = storage._client.delete.call_args.kwargs["pks"]
    assert deleted_pks == ["v1"]
@pytest.mark.asyncio
async def test_delete_entity_relation_raises_on_server_failure():
    """A client delete error propagates out of delete_entity_relation()."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    client = storage._client
    # The query finds two relations, but deleting them fails.
    client.query.return_value = [{"id": "rel1"}, {"id": "rel2"}]
    client.delete.side_effect = RuntimeError("milvus delete failed")

    with pytest.raises(RuntimeError, match="milvus delete failed"):
        await storage.delete_entity_relation("X")
@pytest.mark.asyncio
async def test_delete_entity_relation_prunes_pending_buffer():
    """Buffered relations touching the entity are dropped; others are kept."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    relations = {
        "rel-A-B": {"content": "A → B", "src_id": "A", "tgt_id": "B"},
        "rel-C-D": {"content": "C → D", "src_id": "C", "tgt_id": "D"},
    }
    await storage.upsert(relations)
    # The server reports no matching rows.
    storage._client.query.return_value = []

    await storage.delete_entity_relation("A")

    pending = storage._pending_vector_docs
    # Only the buffered relation with src_id "A" is removed.
    assert "rel-A-B" not in pending
    assert "rel-C-D" in pending
@pytest.mark.asyncio
async def test_delete_entity_relation_diverges_when_buffer_overwrites_persisted():
    """Lock in how deferred mode differs from eager mode for
    ``delete_entity_relation``.

    Setup being modelled: the server already holds ``rel-X-Y`` with
    ``src_id="X" / tgt_id="Y"``. A buffered, not-yet-flushed upsert would
    rewrite that id to ``src_id="A" / tgt_id="B"``. Then
    ``delete_entity_relation("A")`` is called before any flush.

    What deferred mode does today:
        * the server-side filter on entity "A" finds nothing, because the
          persisted row still carries X/Y, so no server delete happens;
        * the buffered upsert is dropped, because its buffered src/tgt do
          match "A";
        * overall the old persisted ``rel-X-Y`` stays and the pending
          rewrite is discarded.

    With eager ordering (upsert, flush, then delete) the row would have
    been rewritten first and then matched by the filter, ending with
    ``rel-X-Y`` deleted. This test pins the current divergence so that
    changing it requires a deliberate update here and in the docstring of
    ``delete_entity_relation``.
    """
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    # Buffer a rewrite of the (assumed already persisted) rel-X-Y whose new
    # src/tgt refer to entity "A".
    rewrite = {"rel-X-Y": {"content": "A → B", "src_id": "A", "tgt_id": "B"}}
    await storage.upsert(rewrite)
    assert "rel-X-Y" in storage._pending_vector_docs

    # The server only knows the old X/Y row, so querying for "A" is empty.
    storage._client.query.return_value = []

    await storage.delete_entity_relation("A")

    # The buffered rewrite is pruned based on its buffered src/tgt ...
    assert "rel-X-Y" not in storage._pending_vector_docs
    # ... while no server delete is sent, since the query matched nothing.
    storage._client.delete.assert_not_called()
@pytest.mark.asyncio
async def test_delete_entity_relation_eager_ordering_matches_persisted():
    """Flushing first makes ``delete_entity_relation`` hit the persisted row.

    Companion to the divergence test: when ``index_done_callback()`` runs
    before ``delete_entity_relation``, the buffered rewrite has already been
    written, so the server-side filter finds the row and it is deleted. This
    is the workaround for callers that need eager-equivalent semantics.
    """
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    rewrite = {"rel-X-Y": {"content": "A → B", "src_id": "A", "tgt_id": "B"}}
    await storage.upsert(rewrite)

    # Persist the buffered upsert before deleting.
    await storage.index_done_callback()
    assert storage._pending_vector_docs == {}
    storage._client.upsert.assert_called_once()

    # Now the server query for entity "A" returns the rewritten row.
    storage._client.query.return_value = [{"id": "rel-X-Y"}]
    await storage.delete_entity_relation("A")

    client_delete = storage._client.delete
    client_delete.assert_called_once()
    assert client_delete.call_args.kwargs["pks"] == ["rel-X-Y"]
@pytest.mark.asyncio
async def test_get_by_id_reads_pending_buffer_without_vector():
    """get_by_id() serves a pending doc from the buffer, minus its vector."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello", "entity_name": "E1"}})

    found = await storage.get_by_id("v1")

    assert found is not None
    assert found["id"] == "v1"
    assert found.get("entity_name") == "E1"
    assert "vector" not in found
    # The buffer satisfied the read, so the client was never queried.
    storage._client.query.assert_not_called()
@pytest.mark.asyncio
async def test_get_by_id_returns_none_for_pending_delete():
    """An id with a pending delete reads as missing without querying the client."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.delete(["v1"])

    found = await storage.get_by_id("v1")

    assert found is None
    storage._client.query.assert_not_called()
@pytest.mark.asyncio
async def test_env_workspace_override_shares_flush_lock(patch_namespace_lock):
    """Instances that resolve to the same final_namespace share one flush lock."""
    lock_cache = patch_namespace_lock
    embedder = CountingEmbeddingFunc()
    env_override = {"MILVUS_WORKSPACE": "shared_ws"}

    with patch.dict(os.environ, env_override, clear=False):
        first = _make_storage(embedder, workspace="caller_a")
        second = _make_storage(embedder, workspace="caller_b")

        # Both instances resolve to the env-provided workspace despite being
        # built with different workspace arguments.
        assert first.final_namespace == second.final_namespace == "shared_ws_entities"
        assert first._flush_lock is second._flush_lock

        # Exactly one cached lock exists for that final_namespace.
        matching_keys = [key for key in lock_cache if key[0] == "shared_ws_entities"]
        assert len(matching_keys) == 1
@pytest.mark.asyncio
async def test_distinct_namespaces_get_independent_locks(patch_namespace_lock):
    """Instances with different final_namespaces hold different flush locks."""
    embedder = CountingEmbeddingFunc()

    # Workspaces "a" and "b" are expected to yield distinct final_namespaces.
    first = _make_storage(embedder, workspace="a")
    second = _make_storage(embedder, workspace="b")

    assert first.final_namespace != second.final_namespace
    assert first._flush_lock is not second._flush_lock
@pytest.mark.asyncio
async def test_mixed_upsert_and_delete_in_single_flush():
    """One flush with pending upserts and deletes on different ids issues
    exactly one client upsert and one client delete, then empties both
    buffers."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    await storage.upsert({"a": {"content": "alpha"}})
    await storage.delete(["b"])
    assert set(storage._pending_vector_docs.keys()) == {"a"}
    assert storage._pending_vector_deletes == {"b"}

    await storage.index_done_callback()

    client = storage._client
    client.upsert.assert_called_once()
    upserted_ids = {row["id"] for row in client.upsert.call_args.kwargs["data"]}
    assert upserted_ids == {"a"}

    client.delete.assert_called_once()
    assert client.delete.call_args.kwargs["pks"] == ["b"]

    # Nothing is left pending after the flush succeeded.
    assert storage._pending_vector_docs == {}
    assert storage._pending_vector_deletes == set()
@pytest.mark.asyncio
async def test_finalize_clean_flush_no_raise():
    """finalize() with a working client flushes everything without raising.

    This is the success-path companion of
    test_finalize_raises_when_buffer_unflushed.
    """
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}})
    await storage.delete(["v2"])

    # Any exception here fails the test.
    await storage.finalize()

    storage._client.upsert.assert_called_once()
    storage._client.delete.assert_called_once()
    assert storage._pending_vector_docs == {}
    assert storage._pending_vector_deletes == set()
@pytest.mark.asyncio
async def test_drop_clears_pending_buffers():
    """drop() reports success and discards pending upserts and deletes."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    # Report the collection as absent so drop() has nothing to drop.
    storage._client.has_collection.return_value = False

    # Replace collection creation with a mock so drop() does not run the
    # real creation path.
    with patch.object(storage, "_create_collection_if_not_exist"):
        await storage.upsert({"v1": {"content": "hello"}})
        await storage.delete(["v2"])
        assert storage._pending_vector_docs and storage._pending_vector_deletes

        outcome = await storage.drop()

        assert outcome["status"] == "success"
        assert storage._pending_vector_docs == {}
        assert storage._pending_vector_deletes == set()
|