| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534 |
- """Unit tests for MongoVectorDBStorage's deferred-embedding flush pipeline.
- All tests use mocks — no running MongoDB instance required.
- Mirrors the structure of tests/kg/opensearch_impl/test_opensearch_storage.py's
- TestVectorStorageBatching to keep behaviour aligned across backends.
- """
- import asyncio
- import os
- import numpy as np
- import pytest
- from unittest.mock import AsyncMock, MagicMock, patch
- pytest.importorskip(
- "pymongo",
- reason="pymongo is required for Mongo storage tests",
- )
- from pymongo import UpdateOne, DeleteOne # type: ignore
- from lightrag.kg.mongo_impl import MongoVectorDBStorage
- pytestmark = pytest.mark.offline
- # ---------------------------------------------------------------------------
- # Fixtures and helpers
- # ---------------------------------------------------------------------------
class MockEmbeddingFunc:
    """Stand-in embedding callable that returns random float32 vectors.

    Carries the ``embedding_dim``, ``max_token_size`` and ``model_name``
    attributes and, when awaited, produces one random row per input text.
    """

    def __init__(self, dim=8):
        self.model_name = "mock-embed"
        self.max_token_size = 512
        self.embedding_dim = dim

    async def __call__(self, texts, **kwargs):
        # One row per text, ``embedding_dim`` columns; extra kwargs are ignored.
        shape = (len(texts), self.embedding_dim)
        return np.random.rand(*shape).astype(np.float32)
class CountingEmbeddingFunc(MockEmbeddingFunc):
    """Embedding mock that records every call and can fail on demand.

    ``call_count`` counts invocations, ``batches`` keeps each call's texts as
    its own list, and ``texts`` is the flat concatenation of all batches.
    While ``fail_times`` is positive, a call is still recorded, the counter is
    decremented, and ``RuntimeError("embedding failed")`` is raised.
    """

    def __init__(self, dim=8, fail_times=0):
        super().__init__(dim=dim)
        self.call_count = 0
        self.fail_times = fail_times
        self.texts: list[str] = []
        self.batches: list[list[str]] = []

    async def __call__(self, texts, **kwargs):
        self.call_count += 1
        received = list(texts)
        self.batches.append(received)
        self.texts.extend(received)
        # Bookkeeping above happens before any injected failure is raised.
        if self.fail_times <= 0:
            return await super().__call__(texts, **kwargs)
        self.fail_times -= 1
        raise RuntimeError("embedding failed")
- class _AsyncCursor:
- def __init__(self, docs):
- self._docs = list(docs)
- async def to_list(self, length=None):
- return list(self._docs)
@pytest.fixture(autouse=True)
def patch_namespace_lock(monkeypatch):
    """Replace ``get_namespace_lock`` with a cache of real ``asyncio.Lock`` objects.

    Locks are keyed by ``(namespace, workspace or "")`` so repeated requests
    for the same key hand back the same lock object. The cache dict is what
    the fixture yields.

    ``MONGODB_WORKSPACE`` is always removed from the environment first so a
    value leaked from the shell (e.g. ``space2``) cannot influence
    ``final_namespace`` and silently merge distinct workspaces into one
    namespace. Tests that want an override set it themselves via
    ``patch.dict(os.environ, ...)``.
    """
    monkeypatch.delenv("MONGODB_WORKSPACE", raising=False)
    cache: dict[tuple[str, str | None], asyncio.Lock] = {}

    def factory(namespace, workspace=None, enable_logging=False):
        key = (namespace, workspace or "")
        if cache.get(key) is None:
            cache[key] = asyncio.Lock()
        return cache[key]

    patcher = patch("lightrag.kg.mongo_impl.get_namespace_lock", side_effect=factory)
    with patcher:
        yield cache
def _make_storage(
    embed_func,
    *,
    namespace="entities",
    workspace="test",
    meta_fields=None,
):
    """Build a ``MongoVectorDBStorage`` whose Mongo handles are all mocks.

    Args:
        embed_func: Embedding callable handed to the storage.
        namespace: Storage namespace (keyword-only).
        workspace: Workspace passed to the constructor (keyword-only).
        meta_fields: Metadata field names; when ``None`` the set
            ``{"content", "entity_name", "src_id", "tgt_id"}`` is used.

    Returns:
        The storage instance with ``_data``, ``db`` and ``_flush_lock``
        replaced by test doubles.
    """
    fields = (
        {"content", "entity_name", "src_id", "tgt_id"}
        if meta_fields is None
        else meta_fields
    )
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }
    storage = MongoVectorDBStorage(
        namespace=namespace,
        workspace=workspace,
        global_config=config,
        embedding_func=embed_func,
        meta_fields=fields,
    )

    # Wire a fake AsyncCollection (the only Mongo surface our code touches).
    collection = MagicMock()
    collection.bulk_write = AsyncMock()
    collection.delete_many = AsyncMock(return_value=MagicMock(deleted_count=0))
    collection.find_one = AsyncMock(return_value=None)
    collection.find = MagicMock(return_value=_AsyncCursor([]))
    storage._data = collection
    storage.db = MagicMock()  # non-None so finalize releases it

    # Imported at call time, after construction, as in the original helper.
    from lightrag.kg.mongo_impl import get_namespace_lock

    storage._flush_lock = get_namespace_lock(
        namespace=storage.final_namespace, workspace=""
    )
    return storage
- # ---------------------------------------------------------------------------
- # Tests
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_upsert_buffers_without_embedding():
    """upsert() only stages docs: no embed call, no bulk_write, vector unset."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    payload = {"v1": {"content": "hello"}, "v2": {"content": "world"}}

    await storage.upsert(payload)

    assert embedder.call_count == 0
    pending = storage._pending_vector_docs
    assert set(pending.keys()) == {"v1", "v2"}
    assert pending["v1"].vector is None
    storage._data.bulk_write.assert_not_called()
@pytest.mark.asyncio
async def test_index_done_callback_triggers_flush():
    """index_done_callback() embeds once and issues one unordered bulk_write."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}, "v2": {"content": "world"}})

    await storage.index_done_callback()

    assert embedder.call_count == 1
    bulk_write = storage._data.bulk_write
    bulk_write.assert_called_once()
    # First positional argument is the operation list; options arrive as kwargs.
    ops = bulk_write.call_args.args[0]
    kwargs = bulk_write.call_args.kwargs
    assert kwargs.get("ordered") is False
    assert all(isinstance(op, UpdateOne) for op in ops)
    assert len(ops) == 2
    assert storage._pending_vector_docs == {}
@pytest.mark.asyncio
async def test_repeated_upsert_same_id_embeds_once():
    """Three upserts of one id flush as a single embed of the last content."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    for content in ("first", "second", "third"):
        await storage.upsert({"v1": {"content": content}})
    await storage.index_done_callback()

    assert embedder.call_count == 1
    assert embedder.texts == ["third"]
    storage._data.bulk_write.assert_called_once()
@pytest.mark.asyncio
async def test_deferred_embeddings_respect_batch_size():
    """With ``_max_batch_size`` of 2, five docs are embedded as 2 + 2 + 1."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    storage._max_batch_size = 2
    docs = {f"v{i}": {"content": f"doc {i}"} for i in range(5)}

    await storage.upsert(docs)
    await storage.index_done_callback()

    assert embedder.call_count == 3
    batch_sizes = list(map(len, embedder.batches))
    assert batch_sizes == [2, 2, 1]
@pytest.mark.asyncio
async def test_get_vectors_by_ids_lazy_embed_then_reuse_in_flush():
    """A vector computed by get_vectors_by_ids() is not re-embedded at flush."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}})

    # Reading the vector of a still-buffered doc triggers exactly one embed.
    fetched = await storage.get_vectors_by_ids(["v1"])
    assert "v1" in fetched
    assert embedder.call_count == 1
    assert storage._pending_vector_docs["v1"].vector is not None

    # The flush writes the doc without another embed call.
    await storage.index_done_callback()
    assert embedder.call_count == 1
    storage._data.bulk_write.assert_called_once()
@pytest.mark.asyncio
async def test_flush_failure_keeps_buffer_no_double_embed_on_retry():
    """A failed embed leaves the doc buffered; the retry embeds and writes it."""
    embedder = CountingEmbeddingFunc(fail_times=1)
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello"}})

    # First flush: the embedding mock raises, so nothing reaches bulk_write.
    with pytest.raises(RuntimeError, match="embedding failed"):
        await storage.index_done_callback()
    pending = storage._pending_vector_docs
    assert "v1" in pending
    assert pending["v1"].vector is None
    storage._data.bulk_write.assert_not_called()

    # Second flush: one failed call plus one successful call in total.
    await storage.index_done_callback()
    assert embedder.call_count == 2
    storage._data.bulk_write.assert_called_once()
    assert storage._pending_vector_docs == {}
@pytest.mark.asyncio
async def test_server_write_failure_keeps_buffer():
    """A bulk_write error keeps the embedded doc buffered for a later retry."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    bulk_write = storage._data.bulk_write
    bulk_write.side_effect = RuntimeError("mongo down")
    await storage.upsert({"v1": {"content": "hello"}})

    with pytest.raises(RuntimeError, match="mongo down"):
        await storage.index_done_callback()
    assert "v1" in storage._pending_vector_docs
    assert storage._pending_vector_docs["v1"].vector is not None

    # Once the write succeeds the buffer drains with no second embed call.
    bulk_write.side_effect = None
    await storage.index_done_callback()
    assert embedder.call_count == 1
    assert storage._pending_vector_docs == {}
@pytest.mark.asyncio
async def test_finalize_raises_when_buffer_unflushed_and_still_releases_client():
    """finalize() raises on a failed flush yet still releases the Mongo client."""
    from lightrag.kg.mongo_impl import ClientManager

    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    storage._data.bulk_write.side_effect = RuntimeError("mongo down")
    await storage.upsert({"v1": {"content": "hello"}})

    release_mock = AsyncMock()
    with patch.object(ClientManager, "release_client", new=release_mock) as rel:
        with pytest.raises(RuntimeError, match="finalize.*flush raised") as exc_info:
            await storage.finalize()
        rel.assert_awaited_once()

    # The message must carry the pending counts so the buffered data loss can
    # be audited from the log alone (1 upsert pre-loaded, 0 deletes).
    message = str(exc_info.value)
    assert "1 pending upserts" in message
    assert "0 pending deletes" in message
    # Client references cleared so a second finalize doesn't release twice.
    assert storage.db is None
@pytest.mark.asyncio
async def test_delete_then_upsert_same_id_keeps_upsert():
    """A delete followed by an upsert of the same id flushes as an upsert only.

    Checks the buffers after each step (the upsert removes the id from the
    pending deletes) and then that the flushed operation list is exactly one
    ``UpdateOne``.
    """
    embed = CountingEmbeddingFunc()
    s = _make_storage(embed)

    await s.delete(["v1"])
    assert "v1" in s._pending_vector_deletes

    await s.upsert({"v1": {"content": "hello"}})
    assert "v1" in s._pending_vector_docs
    assert "v1" not in s._pending_vector_deletes

    await s.index_done_callback()
    ops = s._data.bulk_write.call_args.args[0]
    # Pin the length first: ``all(...)`` over an empty list is vacuously true,
    # so without this an empty flush would pass unnoticed.
    assert len(ops) == 1
    assert all(isinstance(op, UpdateOne) for op in ops)
@pytest.mark.asyncio
async def test_upsert_then_delete_same_id_keeps_delete():
    """An upsert followed by a delete of the same id flushes as one DeleteOne."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)

    await storage.upsert({"v1": {"content": "hello"}})
    await storage.delete(["v1"])
    # The delete replaces the buffered upsert for that id.
    assert "v1" not in storage._pending_vector_docs
    assert "v1" in storage._pending_vector_deletes

    await storage.index_done_callback()
    flushed_ops = storage._data.bulk_write.call_args.args[0]
    assert len(flushed_ops) == 1
    (only_op,) = flushed_ops
    assert isinstance(only_op, DeleteOne)
@pytest.mark.asyncio
async def test_bulk_write_uses_update_one_and_delete_one_mix():
    """Two upserts plus two deletes flush as two UpdateOne and two DeleteOne ops."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"u1": {"content": "u1"}, "u2": {"content": "u2"}})
    await storage.delete(["d1", "d2"])

    await storage.index_done_callback()

    flushed_ops = storage._data.bulk_write.call_args.args[0]
    kinds = {type(op).__name__ for op in flushed_ops}
    assert kinds == {"UpdateOne", "DeleteOne"}
    updates = [op for op in flushed_ops if isinstance(op, UpdateOne)]
    deletes = [op for op in flushed_ops if isinstance(op, DeleteOne)]
    assert len(updates) == 2
    assert len(deletes) == 2
@pytest.mark.asyncio
async def test_delete_entity_relation_raises_on_server_failure():
    """A delete_many error propagates out of delete_entity_relation()."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    matching_relations = [{"_id": "rel1"}, {"_id": "rel2"}]
    storage._data.find = MagicMock(return_value=_AsyncCursor(matching_relations))
    storage._data.delete_many = AsyncMock(
        side_effect=RuntimeError("mongo delete failed")
    )

    with pytest.raises(RuntimeError, match="mongo delete failed"):
        await storage.delete_entity_relation("X")
@pytest.mark.asyncio
async def test_delete_entity_relation_prunes_pending_buffer():
    """delete_entity_relation() drops buffered relations that touch the entity."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    buffered_relations = {
        "rel-A-B": {"content": "A → B", "src_id": "A", "tgt_id": "B"},
        "rel-C-D": {"content": "C → D", "src_id": "C", "tgt_id": "D"},
    }
    await storage.upsert(buffered_relations)
    # The mocked server returns no matches; only the local buffer holds "A".
    storage._data.find = MagicMock(return_value=_AsyncCursor([]))

    await storage.delete_entity_relation("A")

    pending = storage._pending_vector_docs
    assert "rel-A-B" not in pending
    assert "rel-C-D" in pending
@pytest.mark.asyncio
async def test_get_by_id_buffer_excludes_vector():
    """get_by_id() on a buffered doc returns it without ``vector`` or a find_one."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.upsert({"v1": {"content": "hello", "entity_name": "E1"}})

    found = await storage.get_by_id("v1")

    assert found is not None
    assert found["id"] == "v1"
    assert found.get("entity_name") == "E1"
    assert "vector" not in found
    storage._data.find_one.assert_not_called()
@pytest.mark.asyncio
async def test_get_by_id_fallback_projects_out_vector():
    """The find_one fallback of get_by_id() must pass projection ``{'vector': 0}``."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    server_doc = {"_id": "v9", "entity_name": "X", "created_at": 0}
    storage._data.find_one = AsyncMock(return_value=server_doc)

    found = await storage.get_by_id("v9")

    assert found is not None
    assert "vector" not in found
    recorded = storage._data.find_one.call_args
    args, kwargs = recorded.args, recorded.kwargs
    # Accept the projection as the second positional argument or as a keyword.
    if len(args) > 1:
        projection = args[1]
    else:
        projection = kwargs.get("projection")
    assert projection == {"vector": 0}
@pytest.mark.asyncio
async def test_get_by_ids_fallback_projects_out_vector():
    """The find fallback of get_by_ids() must pass projection ``{'vector': 0}``."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    server_docs = [{"_id": "a", "entity_name": "A"}, {"_id": "b", "entity_name": "B"}]
    storage._data.find = MagicMock(return_value=_AsyncCursor(server_docs))

    found = await storage.get_by_ids(["a", "b"])

    assert len(found) == 2
    # Falsy entries are skipped; no remaining doc may expose a vector.
    assert not any("vector" in doc for doc in found if doc)
    recorded = storage._data.find.call_args
    args, kwargs = recorded.args, recorded.kwargs
    if len(args) > 1:
        projection = args[1]
    else:
        projection = kwargs.get("projection")
    assert projection == {"vector": 0}
@pytest.mark.asyncio
async def test_get_by_id_returns_none_for_pending_delete():
    """An id with a buffered delete reads as ``None`` without calling find_one."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    await storage.delete(["v1"])

    found = await storage.get_by_id("v1")

    assert found is None
    storage._data.find_one.assert_not_called()
@pytest.mark.asyncio
async def test_env_workspace_override_shares_flush_lock(patch_namespace_lock):
    """With MONGODB_WORKSPACE set, different ctor workspaces share one lock."""
    lock_cache = patch_namespace_lock
    embedder = CountingEmbeddingFunc()
    env_override = {"MONGODB_WORKSPACE": "shared_ws"}

    with patch.dict(os.environ, env_override, clear=False):
        first = _make_storage(embedder, workspace="caller_a")
        second = _make_storage(embedder, workspace="caller_b")

        # Both instances resolve to the env-provided workspace prefix.
        assert first.final_namespace == second.final_namespace == "shared_ws_entities"
        assert first._flush_lock is second._flush_lock
        matching_keys = sum(1 for key in lock_cache if key[0] == "shared_ws_entities")
        assert matching_keys == 1
@pytest.mark.asyncio
async def test_same_workspace_param_shares_flush_lock(patch_namespace_lock):
    """Two instances built with the same ``workspace`` argument share one lock.

    This is the plain-constructor counterpart of
    ``test_env_workspace_override_shares_flush_lock``: the autouse fixture
    removes ``MONGODB_WORKSPACE``, so the shared ``final_namespace`` here comes
    from the constructor argument alone rather than from the env override.
    """
    lock_cache = patch_namespace_lock
    embedder = CountingEmbeddingFunc()

    first = _make_storage(embedder, workspace="caller")
    second = _make_storage(embedder, workspace="caller")

    assert first.final_namespace == second.final_namespace == "caller_entities"
    assert first._flush_lock is second._flush_lock
    matching_keys = sum(1 for key in lock_cache if key[0] == "caller_entities")
    assert matching_keys == 1
@pytest.mark.asyncio
async def test_distinct_namespaces_get_independent_locks():
    """Storages in different workspaces get different namespaces and locks."""
    embedder = CountingEmbeddingFunc()
    first = _make_storage(embedder, workspace="a")
    second = _make_storage(embedder, workspace="b")

    assert first.final_namespace != second.final_namespace
    assert first._flush_lock is not second._flush_lock
@pytest.mark.asyncio
async def test_concurrent_upsert_and_flush_serialize_on_lock():
    """upsert() and index_done_callback() racing on the same namespace must
    not corrupt the buffer or split a single doc across two embed calls.

    Drives a slow embed (asyncio.Event-gated) so the flush genuinely holds
    the lock while a second coroutine attempts upsert mid-flight. Asserts:

    - the late upsert lands in the buffer (not silently dropped)
    - it is *not* embedded by the in-flight flush (still pending after)
    - a follow-up flush picks it up cleanly with exactly one extra embed

    Every wait on a cross-task signal is bounded with ``asyncio.wait_for`` so
    a flush that fails or returns before reaching the embed call fails the
    test instead of hanging it, and any task still pending on exit is
    cancelled and reaped.

    Note: we replace ``s.embedding_func`` directly (not ``patch.object`` on
    ``embed.__call__``) because Python dispatches ``embed(...)`` through
    ``type(embed).__call__``, bypassing any instance-level patch.
    """
    wait_timeout = 5  # seconds; generous upper bound, never reached on success

    embed = CountingEmbeddingFunc()
    s = _make_storage(embed)
    embed_gate = asyncio.Event()
    flush_entered = asyncio.Event()
    original_embed = s.embedding_func

    async def gated_embed(texts, **kwargs):
        flush_entered.set()
        await embed_gate.wait()
        return await original_embed(texts, **kwargs)

    await s.upsert({"v1": {"content": "first"}})
    s.embedding_func = gated_embed
    flush_task = None
    late_upsert = None
    try:
        flush_task = asyncio.create_task(s.index_done_callback())
        # Bounded: if the flush never reaches the embed call, time out rather
        # than block forever. After this, flush is holding _flush_lock.
        await asyncio.wait_for(flush_entered.wait(), timeout=wait_timeout)
        # This upsert must wait on _flush_lock; schedule it concurrently.
        late_upsert = asyncio.create_task(s.upsert({"v2": {"content": "late"}}))
        # Give the event loop a chance to actually start late_upsert and
        # confirm it is blocked on the lock (still no v2 in buffer).
        for _ in range(5):
            await asyncio.sleep(0)
        assert "v2" not in s._pending_vector_docs
        assert not late_upsert.done()
        embed_gate.set()
        await asyncio.wait_for(flush_task, timeout=wait_timeout)
        await asyncio.wait_for(late_upsert, timeout=wait_timeout)
    finally:
        s.embedding_func = original_embed
        # On a failure path the tasks may still be parked on the gate or the
        # lock; cancel and reap them so nothing leaks past this test.
        leftovers = [
            task
            for task in (flush_task, late_upsert)
            if task is not None and not task.done()
        ]
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)

    # Flush embedded v1 only; v2 arrived after the docs_to_embed snapshot.
    assert embed.call_count == 1
    assert embed.batches == [["first"]]
    assert "v1" not in s._pending_vector_docs  # flushed
    assert "v2" in s._pending_vector_docs  # still buffered
    s._data.bulk_write.assert_called_once()

    # Next flush picks up the late upsert without re-embedding v1.
    await s.index_done_callback()
    assert embed.call_count == 2
    assert embed.batches[-1] == ["late"]
    assert s._pending_vector_docs == {}
    assert s._data.bulk_write.call_count == 2
@pytest.mark.asyncio
async def test_drop_clears_pending_buffers():
    """drop() reports success and empties both pending buffers."""
    embedder = CountingEmbeddingFunc()
    storage = _make_storage(embedder)
    index_stub = AsyncMock()

    with patch.object(storage, "create_vector_index_if_not_exists", new=index_stub):
        await storage.upsert({"v1": {"content": "hello"}})
        await storage.delete(["v2"])
        # Precondition: both buffers hold something before the drop.
        assert storage._pending_vector_docs and storage._pending_vector_deletes

        outcome = await storage.drop()

        assert outcome["status"] == "success"
        assert storage._pending_vector_docs == {}
        assert storage._pending_vector_deletes == set()
|