"""
Unit tests for OpenSearch storage implementations.

All tests use mocks — no running OpenSearch instance required.
Run with: pytest tests/kg/opensearch_impl/test_opensearch_storage.py -v
"""

import asyncio
import math
import pytest
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch

import numpy as np

pytest.importorskip(
    "opensearchpy",
    reason="opensearchpy is required for OpenSearch storage tests",
)

from opensearchpy.exceptions import NotFoundError, OpenSearchException  # type: ignore

from lightrag.kg.opensearch_impl import (
    OpenSearchKVStorage,
    OpenSearchDocStatusStorage,
    OpenSearchGraphStorage,
    OpenSearchVectorDBStorage,
    ClientManager,
    _build_index_name,
    _resolve_workspace,
    _sanitize_index_name,
    _verify_mirrored_id_mapping,
)
from lightrag.base import DocStatus, DocProcessingStatus

pytestmark = pytest.mark.offline


# ---------------------------------------------------------------------------
# Mock the shared storage lock so tests don't need full LightRAG init
# ---------------------------------------------------------------------------


@asynccontextmanager
async def _mock_lock():
    """No-op async context manager used in place of the data-init lock."""
    yield


def _mock_lock_factory():
    """Return a fresh no-op lock context manager on every call."""
    return _mock_lock()


def _missing_index_error() -> NotFoundError:
    """Build the NotFoundError used to simulate a missing index (HTTP 404)."""
    return NotFoundError(404, "index_not_found_exception", "no such index")


@pytest.fixture(autouse=True)
def patch_data_init_lock():
    """Patch get_data_init_lock globally so initialize() works without shared storage."""
    with patch(
        "lightrag.kg.opensearch_impl.get_data_init_lock", side_effect=_mock_lock_factory
    ):
        yield


@pytest.fixture(autouse=True)
def patch_namespace_lock():
    """Patch get_namespace_lock to return real asyncio.Lock instances.

    Returning a real Lock (not a no-op) preserves the in-process blocking
    semantics the storage relies on, so concurrent flush / read / write tests
    can observe actual serialization. Locks are cached per
    (namespace, workspace) tuple so multiple calls from the same storage pick
    up the same Lock instance.
    """
    # A missing workspace is normalised to "" below, so keys never hold None.
    cache: dict[tuple[str, str], asyncio.Lock] = {}

    def factory(namespace, workspace=None, enable_logging=False):
        key = (namespace, workspace or "")
        lock = cache.get(key)
        if lock is None:
            lock = asyncio.Lock()
            cache[key] = lock
        return lock

    with patch("lightrag.kg.opensearch_impl.get_namespace_lock", side_effect=factory):
        yield


@pytest.fixture(autouse=True)
def patch_shard_doc_supported():
    """Default tests to OpenSearch >= 3.3.0 so the __mirrored_id verification
    is a no-op. Tests covering the < 3.3.0 fallback should override this with
    their own patch.
    """
    with patch("lightrag.kg.opensearch_impl._shard_doc_supported", True):
        yield


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


class MockEmbeddingFunc:
    """Mock embedding function that returns random vectors."""

    def __init__(self, dim=128):
        self.embedding_dim = dim
        self.max_token_size = 512
        self.model_name = "mock-embed"

    async def __call__(self, texts, **kwargs):
        return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)


class CountingEmbeddingFunc(MockEmbeddingFunc):
    """Embedding test double that records calls and can fail a fixed number of times."""

    def __init__(self, dim=128, fail_times=0):
        super().__init__(dim=dim)
        self.fail_times = fail_times
        self.call_count = 0
        self.batches: list[list[str]] = []
        self.texts: list[str] = []

    async def __call__(self, texts, **kwargs):
        # Record the call before deciding whether to fail, so failed
        # attempts are counted as well.
        self.call_count += 1
        batch = list(texts)
        self.batches.append(batch)
        self.texts.extend(batch)
        if self.fail_times > 0:
            self.fail_times -= 1
            raise RuntimeError("embedding failed")
        return await super().__call__(texts, **kwargs)


@pytest.fixture
def global_config():
    """Standard global config fixture for all storage tests."""
    return {
        "embedding_batch_num": 10,
        "max_graph_nodes": 1000,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
    }


@pytest.fixture
def embed_func():
    """Mock embedding function fixture."""
    return MockEmbeddingFunc()


def _make_client():
    """Create a fully-mocked AsyncOpenSearch client with spec validation."""
    from opensearchpy import AsyncOpenSearch

    client = AsyncMock(spec=AsyncOpenSearch)

    # indices sub-client
    client.indices = AsyncMock()
    client.indices.exists = AsyncMock(return_value=False)
    client.indices.create = AsyncMock()
    client.indices.delete = AsyncMock()
    client.indices.refresh = AsyncMock()
    client.indices.get_mapping = AsyncMock(return_value={})

    # transport for PPL
    client.transport = AsyncMock()
    client.transport.perform_request = AsyncMock(
        side_effect=Exception("PPL not available")
    )

    # document operations
    client.exists = AsyncMock(return_value=False)
    client.index = AsyncMock()
    client.delete = AsyncMock()
    client.delete_by_query = AsyncMock()
    client.get = AsyncMock(
        return_value={
            "_id": "doc1",
            "_source": {"content": "hello", "create_time": 0, "update_time": 0},
        }
    )
    client.mget = AsyncMock(
        return_value={
            "docs": [
                {"_id": "id1", "found": True, "_source": {"content": "c1"}},
                {"_id": "id2", "found": True, "_source": {"content": "c2"}},
            ]
        }
    )
    client.count = AsyncMock(return_value={"count": 5})
    client.search = AsyncMock(
        return_value={
            "hits": {"hits": [], "total": {"value": 0}},
            "aggregations": {
                "status_counts": {"buckets": []},
                "src": {"buckets": []},
                "tgt": {"buckets": []},
                "source_degrees": {"buckets": []},
                "target_degrees": {"buckets": []},
            },
        }
    )

    # PIT operations
    client.create_pit = AsyncMock(return_value={"pit_id": "mock_pit_id_123"})
    client.delete_pit = AsyncMock()
    return client


@pytest.fixture
def mock_client():
    """Fully-mocked AsyncOpenSearch client fixture."""
    return _make_client()


# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------


class TestHelpers:
    """Tests for module-level helper functions (_build_index_name, _resolve_workspace, _sanitize_index_name)."""

    def test_build_index_name_with_workspace(self):
        ws, ns, idx = _build_index_name("myws", "text_chunks")
        assert ws == "myws"
        assert ns == "myws_text_chunks"
        assert idx == _sanitize_index_name("myws_text_chunks")

    def test_build_index_name_no_workspace(self):
        ws, ns, idx = _build_index_name("", "chunks")
        assert ws == ""
        assert idx == _sanitize_index_name("chunks")

    def test_resolve_workspace_env_override(self):
        with patch.dict("os.environ", {"OPENSEARCH_WORKSPACE": "forced"}):
            assert _resolve_workspace("original", "ns") == "forced"

    def test_resolve_workspace_fallback(self):
        with patch.dict("os.environ", {}, clear=True):
            assert _resolve_workspace("original", "ns") == "original"

    def test_sanitize_index_name(self):
        assert _sanitize_index_name("Hello_World") == "hello_world"
        assert _sanitize_index_name("-bad") == "x-bad"
        assert _sanitize_index_name("a.b/c") == "a_b_c"


# ---------------------------------------------------------------------------
# ClientManager
# ---------------------------------------------------------------------------


class TestClientManager:
    """Tests for ClientManager singleton pattern and reference counting."""

    @staticmethod
    def _stub_client(version: str = "3.3.0") -> AsyncMock:
        """Build an AsyncMock client with a concrete .info() payload.

        Without this stub, _detect_shard_doc_support's chained .get(...) calls
        on an AsyncMock would leak un-awaited coroutines.
        """
        client = AsyncMock()
        client.info = AsyncMock(return_value={"version": {"number": version}})
        return client

    @pytest.mark.asyncio
    async def test_singleton_and_refcount(self):
        """Two get_client() calls share one client; releases count back down to zero."""
        ClientManager._instances = {"client": None, "ref_count": 0}
        with patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls:
            mock_cls.return_value = self._stub_client()
            c1 = await ClientManager.get_client()
            c2 = await ClientManager.get_client()
            assert c1 is c2
            assert ClientManager._instances["ref_count"] == 2
            await ClientManager.release_client(c1)
            assert ClientManager._instances["ref_count"] == 1
            await ClientManager.release_client(c2)
            assert ClientManager._instances["ref_count"] == 0
            assert ClientManager._instances["client"] is None

    @pytest.mark.asyncio
    async def test_close_called_on_last_release(self):
        """Releasing the only reference awaits close() on the underlying client."""
        ClientManager._instances = {"client": None, "ref_count": 0}
        with patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls:
            inner = self._stub_client()
            mock_cls.return_value = inner
            c = await ClientManager.get_client()
            await ClientManager.release_client(c)
            inner.close.assert_awaited_once()


# ---------------------------------------------------------------------------
# _verify_mirrored_id_mapping helper
# ---------------------------------------------------------------------------


class TestMirroredIdVerification:
    """Tests for the _verify_mirrored_id_mapping fail-fast helper."""

    @pytest.mark.asyncio
    async def test_skipped_on_modern_cluster(self, mock_client):
        """On OpenSearch >= 3.3.0 the mapping check is short-circuited."""
        # _shard_doc_supported is True via autouse fixture.
        await _verify_mirrored_id_mapping(mock_client, "any_index")
        mock_client.indices.get_mapping.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_passes_when_mapping_present(self, mock_client):
        """On OpenSearch < 3.3.0 a mapping containing __mirrored_id is accepted."""
        mock_client.indices.get_mapping = AsyncMock(
            return_value={
                "my_index": {
                    "mappings": {"properties": {"__mirrored_id": {"type": "keyword"}}}
                }
            }
        )
        with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
            await _verify_mirrored_id_mapping(mock_client, "my_index")

    @pytest.mark.asyncio
    async def test_fails_fast_when_mapping_missing(self, mock_client):
        """On OpenSearch < 3.3.0 a legacy index without __mirrored_id raises."""
        mock_client.indices.get_mapping = AsyncMock(
            return_value={
                "my_index": {
                    "mappings": {"properties": {"other_field": {"type": "text"}}}
                }
            }
        )
        with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
            with pytest.raises(RuntimeError, match="__mirrored_id"):
                await _verify_mirrored_id_mapping(mock_client, "my_index")

    @pytest.mark.asyncio
    async def test_swallows_get_mapping_error(self, mock_client):
        """Mapping-fetch failures should not block initialization."""
        mock_client.indices.get_mapping = AsyncMock(
            side_effect=OpenSearchException("transport error")
        )
        with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
            await _verify_mirrored_id_mapping(mock_client, "my_index")


# ---------------------------------------------------------------------------
# KV Storage
# ---------------------------------------------------------------------------


class TestKVStorage:
    """Tests for OpenSearchKVStorage CRUD operations, timestamps, refresh behavior."""

    def _make(self, global_config, embed_func, workspace="test"):
        """Build an OpenSearchKVStorage for the "text_chunks" namespace."""
        return OpenSearchKVStorage(
            namespace="text_chunks",
            global_config=global_config,
            embedding_func=embed_func,
            workspace=workspace,
        )

    @pytest.mark.asyncio
    async def test_index_name(self, global_config, embed_func):
        """The index name combines workspace and namespace."""
        s = self._make(global_config, embed_func, workspace="proj_a")
        assert s._index_name == "proj_a_text_chunks"

    @pytest.mark.asyncio
    async def test_initialize_creates_index(
        self, global_config, embed_func, mock_client
    ):
        """initialize() checks for the index once and creates it when absent."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            mock_client.indices.exists.assert_awaited_once()
            mock_client.indices.create.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_initialize_skips_existing_index(
        self, global_config, embed_func, mock_client
    ):
        """initialize() does not create an index that already exists."""
        mock_client.indices.exists = AsyncMock(return_value=True)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            mock_client.indices.create.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_initialize_fails_on_legacy_index_without_mirrored_id(
        self, global_config, embed_func, mock_client
    ):
        """On OpenSearch < 3.3.0, an existing index lacking __mirrored_id must fail-fast."""
        mock_client.indices.exists = AsyncMock(return_value=True)
        mock_client.indices.get_mapping = AsyncMock(
            return_value={
                "test_text_chunks": {
                    "mappings": {"properties": {"content": {"type": "text"}}}
                }
            }
        )
        with (
            patch.object(ClientManager, "get_client", return_value=mock_client),
            patch("lightrag.kg.opensearch_impl._shard_doc_supported", False),
        ):
            s = self._make(global_config, embed_func)
            with pytest.raises(RuntimeError, match="__mirrored_id"):
                await s.initialize()
            mock_client.indices.create.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_get_by_id(self, global_config, embed_func, mock_client):
        """get_by_id() fetches through mget and returns the source plus _id."""
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {
                        "_id": "doc1",
                        "found": True,
                        "_source": {
                            "content": "hello",
                            "create_time": 0,
                            "update_time": 0,
                        },
                    }
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            doc = await s.get_by_id("doc1")
            assert doc is not None
            assert doc["content"] == "hello"
            assert doc["_id"] == "doc1"
            mock_client.mget.assert_awaited_once_with(
                index=s._index_name, body={"ids": ["doc1"]}
            )

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
        """A found=False mget entry yields None and never falls back to client.get."""
        mock_client.mget = AsyncMock(
            return_value={"docs": [{"_id": "missing", "found": False}]}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            assert await s.get_by_id("missing") is None
            mock_client.get.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_get_by_ids_preserves_order(
        self, global_config, embed_func, mock_client
    ):
        """get_by_ids() returns documents in the order the ids were requested."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            docs = await s.get_by_ids(["id1", "id2"])
            assert docs[0]["content"] == "c1"
            assert docs[1]["content"] == "c2"

    @pytest.mark.asyncio
    async def test_filter_keys(self, global_config, embed_func, mock_client):
        """filter_keys() returns only the keys mget reports as not found."""
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {"_id": "a", "found": True},
                    {"_id": "b", "found": False},
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            result = await s.filter_keys({"a", "b"})
            assert result == {"b"}

    @pytest.mark.asyncio
    async def test_upsert_no_per_operation_refresh(
        self, global_config, embed_func, mock_client
    ):
        """The flush (during index_done_callback) must not request per-op refresh."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "v1"}})
                # upsert buffers; bulk fires on flush.
                mock_bulk.assert_not_awaited()
                await s.index_done_callback()
                _, kwargs = mock_bulk.call_args
                assert "refresh" not in kwargs

    @pytest.mark.asyncio
    async def test_upsert_sets_timestamps(self, global_config, embed_func, mock_client):
        """Buffered docs carry create_time / update_time set eagerly during upsert."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "v1"}})
                # Timestamps are visible in the pending buffer immediately.
                assert "create_time" in s._pending_upserts["k1"]
                assert "update_time" in s._pending_upserts["k1"]
                await s.index_done_callback()
                actions = mock_bulk.call_args[0][1]
                src = actions[0]["_source"]
                assert "create_time" in src
                assert "update_time" in src

    @pytest.mark.asyncio
    async def test_is_empty(self, global_config, embed_func, mock_client):
        """is_empty() is True when the server-side count is zero."""
        mock_client.count = AsyncMock(return_value={"count": 0})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            assert await s.is_empty() is True

    @pytest.mark.asyncio
    async def test_delete(self, global_config, embed_func, mock_client):
        """delete() buffers tombstones; the bulk delete fires on flush."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (2, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.delete(["a", "b"])
                mock_bulk.assert_not_awaited()
                assert s._pending_kv_deletes == {"a", "b"}
                await s.index_done_callback()
                actions = mock_bulk.call_args[0][1]
                assert len(actions) == 2
                assert all(a["_op_type"] == "delete" for a in actions)

    @pytest.mark.asyncio
    async def test_drop(self, global_config, embed_func, mock_client):
        """drop() reports success and deletes the index exactly once."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            result = await s.drop()
            assert result["status"] == "success"
            mock_client.indices.delete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index(
        self, global_config, embed_func, mock_client
    ):
        """A failed drop() reports an error, clears _index_ready, and the next
        upsert() re-runs index creation."""
        mock_client.indices.delete = AsyncMock(
            side_effect=OpenSearchException("drop failed")
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                with patch.object(
                    s, "_create_index_if_not_exists", new_callable=AsyncMock
                ) as mock_create:
                    result = await s.drop()
                    assert result["status"] == "error"
                    assert s._index_ready is False
                    await s.upsert({"k1": {"content": "v1"}})
                    mock_create.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_upsert_after_drop_recreates_index(
        self, global_config, embed_func, mock_client
    ):
        """After a drop(), the next upsert() re-runs index creation once."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                with patch.object(
                    s, "_create_index_if_not_exists", new_callable=AsyncMock
                ) as mock_create:
                    await s.drop()
                    await s.upsert({"k1": {"content": "v1"}})
                    mock_create.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_reads_short_circuit_after_drop(
        self, global_config, embed_func, mock_client
    ):
        """After drop(), reads answer locally without calling mget or count."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.drop()
            assert await s.get_by_id("doc1") is None
            assert await s.get_by_ids(["doc1", "doc2"]) == [None, None]
            assert await s.is_empty() is True
            mock_client.mget.assert_not_awaited()
            mock_client.count.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_read_missing_index_demotes_readiness(
        self, global_config, embed_func, mock_client
    ):
        """A missing-index error on read clears _index_ready, so the second
        read does not reach mget again."""
        mock_client.mget = AsyncMock(side_effect=_missing_index_error())
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            assert await s.get_by_id("doc1") is None
            assert await s.get_by_id("doc1") is None
            assert s._index_ready is False
            assert mock_client.mget.await_count == 1

    @pytest.mark.asyncio
    async def test_iter_raw_docs_uses_pit_and_search_after(
        self, global_config, embed_func, mock_client
    ):
        """_iter_raw_docs pages with search_after inside a single PIT."""
        mock_client.search = AsyncMock(
            side_effect=[
                {
                    "hits": {
                        "hits": [
                            {"_id": "d1", "_source": {"content": "a"}, "sort": [1]},
                            {"_id": "d2", "_source": {"content": "b"}, "sort": [2]},
                        ]
                    }
                },
                {
                    "hits": {
                        "hits": [
                            {"_id": "d3", "_source": {"content": "c"}, "sort": [3]}
                        ]
                    }
                },
            ]
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            batches = [batch async for batch in s._iter_raw_docs(batch_size=2)]
            assert [[doc["_id"] for doc in batch] for batch in batches] == [
                ["d1", "d2"],
                ["d3"],
            ]
            # The first page has no cursor; the second resumes from the last
            # sort value of the first page.
            assert (
                "search_after"
                not in mock_client.search.await_args_list[0].kwargs["body"]
            )
            assert mock_client.search.await_args_list[1].kwargs["body"][
                "search_after"
            ] == [2]
            mock_client.create_pit.assert_awaited_once()
            mock_client.delete_pit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_iter_raw_docs_missing_index_demotes_readiness(
        self, global_config, embed_func, mock_client
    ):
        """A missing-index error during iteration yields no batches, clears
        _index_ready, and still deletes the PIT."""
        mock_client.search = AsyncMock(side_effect=_missing_index_error())
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            batches = [batch async for batch in s._iter_raw_docs(batch_size=2)]
            assert batches == []
            assert s._index_ready is False
            mock_client.create_pit.assert_awaited_once()
            mock_client.delete_pit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_finalize(self, global_config, embed_func, mock_client):
        """finalize() releases the shared client and clears the reference."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch.object(
                ClientManager, "release_client", new_callable=AsyncMock
            ) as mock_release:
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.finalize()
                mock_release.assert_awaited_once()
                assert s.client is None


# ---------------------------------------------------------------------------
# KV storage write batching (derived from issue #2785 / PR #2822)
# ---------------------------------------------------------------------------


class TestKVStorageBatching:
    """Tests for the buffered upsert/delete + flush behaviour."""

    def _make(self, global_config, embed_func, workspace="test"):
        """Build an OpenSearchKVStorage for the "text_chunks" namespace."""
        return OpenSearchKVStorage(
            namespace="text_chunks",
            global_config=global_config,
            embedding_func=embed_func,
            workspace=workspace,
        )

    @pytest.mark.asyncio
    async def test_repeated_kv_upserts_flush_in_single_bulk_call(
        self, global_config, embed_func, mock_client
    ):
        """Many small upsert() calls collapse to one async_bulk on flush."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (5, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                for i in range(5):
                    await s.upsert({f"k{i}": {"content": f"doc {i}"}})
                mock_bulk.assert_not_awaited()
                await s.index_done_callback()
                mock_bulk.assert_awaited_once()
                actions = mock_bulk.call_args[0][1]
                assert len(actions) == 5
                assert {a["_id"] for a in actions} == {f"k{i}" for i in range(5)}

    @pytest.mark.asyncio
    async def test_kv_upsert_overwrites_pending_doc_for_same_id(
        self, global_config, embed_func, mock_client
    ):
        """Upserting the same id twice keeps only the latest payload."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "first"}})
                await s.upsert({"k1": {"content": "second"}})
                await s.index_done_callback()
                actions = mock_bulk.call_args[0][1]
                assert len(actions) == 1
                assert actions[0]["_source"]["content"] == "second"

    @pytest.mark.asyncio
    async def test_kv_delete_cancels_pending_upsert(
        self, global_config, embed_func, mock_client
    ):
        """A delete after a buffered upsert removes the upsert from the buffer.

        Without this, the flush would re-index the doc and silently resurrect
        a logically-deleted key.
        """
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "doomed"}})
                await s.delete(["k1"])
                assert "k1" not in s._pending_upserts
                assert "k1" in s._pending_kv_deletes
                await s.index_done_callback()
                actions = mock_bulk.call_args[0][1]
                assert len(actions) == 1
                assert actions[0]["_op_type"] == "delete"

    @pytest.mark.asyncio
    async def test_kv_upsert_cancels_pending_delete(
        self, global_config, embed_func, mock_client
    ):
        """An upsert after a buffered delete removes the tombstone."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.delete(["k1"])
                await s.upsert({"k1": {"content": "resurrected"}})
                assert "k1" not in s._pending_kv_deletes
                assert "k1" in s._pending_upserts
                await s.index_done_callback()
                actions = mock_bulk.call_args[0][1]
                assert len(actions) == 1
                assert actions[0]["_op_type"] == "index"

    @pytest.mark.asyncio
    async def test_kv_delete_works_when_index_not_ready(
        self, global_config, embed_func, mock_client
    ):
        """delete() must invalidate pending upserts even if the index has been
        marked missing -- otherwise the next flush would resurrect the
        logically-deleted key.
        """
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "x"}})
                s._mark_index_missing()
                await s.delete(["k1"])
                # Buffer invariants hold regardless of _index_ready.
                assert "k1" not in s._pending_upserts
                assert "k1" in s._pending_kv_deletes

    @pytest.mark.asyncio
    async def test_kv_get_by_id_reads_pending_buffer(
        self, global_config, embed_func, mock_client
    ):
        """Buffered upserts are visible to get_by_id without hitting OpenSearch."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.upsert({"k1": {"content": "buffered"}})
            doc = await s.get_by_id("k1")
            assert doc is not None
            assert doc["_id"] == "k1"
            assert doc["content"] == "buffered"
            mock_client.mget.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_get_by_id_returns_none_for_pending_delete(
        self, global_config, embed_func, mock_client
    ):
        """A pending tombstone shadows any persisted doc, without mget RTT."""
        mock_client.mget = AsyncMock()
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.delete(["k1"])
            assert await s.get_by_id("k1") is None
            mock_client.mget.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_get_by_id_strips_mirrored_id_from_buffer_path(
        self, global_config, embed_func, mock_client
    ):
        """Buffered docs internally carry __mirrored_id (used for PIT sort);
        the returned dict must NOT expose it, matching the mget read path."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.upsert({"k1": {"content": "x"}})
            # Sanity: the buffer entry itself carries __mirrored_id.
            assert s._pending_upserts["k1"]["__mirrored_id"] == "k1"
            doc = await s.get_by_id("k1")
            assert doc is not None
            assert "__mirrored_id" not in doc
            assert doc["_id"] == "k1"

    @pytest.mark.asyncio
    async def test_kv_get_by_ids_merges_buffer_and_mget(
        self, global_config, embed_func, mock_client
    ):
        """get_by_ids returns buffered docs and falls back to mget for the rest."""
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {
                        "_id": "k2",
                        "found": True,
                        "_source": {"content": "from_index"},
                    },
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.upsert({"k1": {"content": "buffered"}})
            docs = await s.get_by_ids(["k1", "k2"])
            assert docs[0]["content"] == "buffered"
            assert "__mirrored_id" not in docs[0]
            assert docs[1]["content"] == "from_index"
            mock_client.mget.assert_awaited_once_with(
                index=s._index_name, body={"ids": ["k2"]}
            )

    @pytest.mark.asyncio
    async def test_kv_filter_keys_excludes_buffered_upserts(
        self, global_config, embed_func, mock_client
    ):
        """Buffered upserts shadow OpenSearch: filter_keys treats them as
        existing and never queries them via mget."""
        mock_client.mget = AsyncMock(
            return_value={"docs": [{"_id": "k2", "found": False}]}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.upsert({"k1": {"content": "x"}})
            missing = await s.filter_keys({"k1", "k2"})
            assert missing == {"k2"}
            # Only the unbuffered id is queried server-side.
            mget_kwargs = mock_client.mget.await_args_list[0].kwargs
            assert mget_kwargs["body"] == {"ids": ["k2"]}

    @pytest.mark.asyncio
    async def test_kv_filter_keys_treats_buffered_deletes_as_missing(
        self, global_config, embed_func, mock_client
    ):
        """A persisted-but-pending-delete key must be reported as missing AND
        must NOT be looked up via mget (otherwise the still-persisted row
        would be misclassified as existing)."""
        mock_client.mget = AsyncMock(
            return_value={"docs": [{"_id": "k3", "found": True}]}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.delete(["k1"])  # tombstone
            missing = await s.filter_keys({"k1", "k3"})
            assert "k1" in missing  # tombstoned key counts as missing
            assert "k3" not in missing  # exists on server
            # The tombstone id was NOT sent to mget.
            mget_kwargs = mock_client.mget.await_args_list[0].kwargs
            assert mget_kwargs["body"] == {"ids": ["k3"]}

    @pytest.mark.asyncio
    async def test_kv_is_empty_returns_false_with_pending_upsert(
        self, global_config, embed_func, mock_client
    ):
        """is_empty short-circuits to False when the buffer has pending
        upserts -- avoiding the counterintuitive "I just upserted but
        is_empty returned True" outcome."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            s = self._make(global_config, embed_func)
            await s.initialize()
            await s.upsert({"k1": {"content": "x"}})
            assert await s.is_empty() is False
            mock_client.count.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_finalize_flushes_pending(
        self, global_config, embed_func, mock_client
    ):
        """finalize() flushes the buffer before releasing the client."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
            ) as mock_bulk:
                mock_bulk.return_value = (1, [])
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "to flush"}})
                await s.finalize()
                mock_bulk.assert_awaited_once()
                assert s.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_raises_when_retryable_buffer_remains(
        self, global_config, embed_func, mock_client
    ):
        """finalize() must surface a RuntimeError when retryable bulk failures
        left rows buffered, otherwise the upstream finalize_storages() call
        would log the storage as successfully finalized while writes are
        silently lost. The client is still released so we don't leak a
        connection on shutdown.
        """
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch.object(
                ClientManager, "release_client", new_callable=AsyncMock
            ) as mock_release:
                with patch(
                    "lightrag.kg.opensearch_impl.helpers.async_bulk",
                    new_callable=AsyncMock,
                ) as mock_bulk:
                    # 503 is retryable; flush keeps it in the buffer.
                    mock_bulk.return_value = (
                        0,
                        [{"index": {"_id": "k1", "status": 503, "error": "down"}}],
                    )
                    s = self._make(global_config, embed_func)
                    await s.initialize()
                    await s.upsert({"k1": {"content": "stuck"}})
                    with pytest.raises(RuntimeError, match="pending upserts"):
                        await s.finalize()
                    # Client released regardless of the failure.
                    mock_release.assert_awaited_once()
                    assert s.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_propagates_flush_exception(
        self, global_config, embed_func, mock_client
    ):
        """If async_bulk itself raises, finalize() still releases the client
        and wraps the original error in a RuntimeError that names the
        unflushed buffer counts.
        """
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch.object(
                ClientManager, "release_client", new_callable=AsyncMock
            ) as mock_release:
                with patch(
                    "lightrag.kg.opensearch_impl.helpers.async_bulk",
                    new_callable=AsyncMock,
                ) as mock_bulk:
                    mock_bulk.side_effect = OpenSearchException("connection reset")
                    s = self._make(global_config, embed_func)
                    await s.initialize()
                    await s.upsert({"k1": {"content": "stuck"}})
                    with pytest.raises(RuntimeError) as exc_info:
                        await s.finalize()
                    # Wrapped: cause is the original OpenSearchException.
                    assert isinstance(exc_info.value.__cause__, OpenSearchException)
                    mock_release.assert_awaited_once()
                    assert s.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_propagates_cancellation(
        self, global_config, embed_func, mock_client
    ):
        """asyncio.CancelledError raised during the final flush must propagate
        UN-wrapped so the shutdown sequence honours the cancellation signal.
        The client is still released (finally block) before the cancellation
        continues.
        """
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch.object(
                ClientManager, "release_client", new_callable=AsyncMock
            ) as mock_release:
                with patch(
                    "lightrag.kg.opensearch_impl.helpers.async_bulk",
                    new_callable=AsyncMock,
                ) as mock_bulk:
                    mock_bulk.side_effect = asyncio.CancelledError()
                    s = self._make(global_config, embed_func)
                    await s.initialize()
                    await s.upsert({"k1": {"content": "stuck"}})
                    with pytest.raises(asyncio.CancelledError):
                        await s.finalize()
                    # finally block still released the client.
                    mock_release.assert_awaited_once()
                    assert s.client is None

    @pytest.mark.asyncio
    async def test_kv_drop_discards_buffers_and_serialises_with_flush(
        self, global_config, embed_func, mock_client
    ):
        """drop() drops both buffers and is serialised with any in-flight
        flush so indices.delete cannot land mid-bulk."""
        flush_started = asyncio.Event()
        flush_can_finish = asyncio.Event()
        drop_delete_started = asyncio.Event()

        async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
            # Signal that the flush is in flight, then park until released.
            flush_started.set()
            await flush_can_finish.wait()
            return (len(actions), [])

        async def watch_indices_delete(**kwargs):
            drop_delete_started.set()

        mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)

        with patch.object(ClientManager, "get_client", return_value=mock_client):
            with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
                s = self._make(global_config, embed_func)
                await s.initialize()
                await s.upsert({"k1": {"content": "x"}})
                await s.delete(["k2"])

                flush_task = asyncio.create_task(s.index_done_callback())
                await flush_started.wait()

                drop_task = asyncio.create_task(s.drop())
                # Yield to the loop a few times so drop() gets a chance to run.
                for _ in range(5):
                    await asyncio.sleep(0)
                assert (
                    not drop_delete_started.is_set()
                ), "indices.delete should be blocked behind the flush lock"
                assert not drop_task.done()

                flush_can_finish.set()
                await flush_task
                await drop_task
                assert drop_delete_started.is_set()
                # Even though flush flushed k1/k2, drop() then cleared the
                # buffer state (no-op here because flush already drained
                # them, but the assertion confirms drop() does not crash
                # against the now-empty buffer).
assert s._pending_upserts == {} assert s._pending_kv_deletes == set() @pytest.mark.asyncio async def test_kv_failed_flush_retains_retryable( self, global_config, embed_func, mock_client ): """Transient (5xx) per-doc failures stay buffered for the next flush.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = ( 1, [{"index": {"_id": "k2", "status": 503, "error": "down"}}], ) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"k1": {"content": "ok"}, "k2": {"content": "boom"}}) await s.index_done_callback() assert "k1" not in s._pending_upserts assert "k2" in s._pending_upserts @pytest.mark.asyncio async def test_kv_failed_flush_drops_non_retryable( self, global_config, embed_func, mock_client ): """Permanent (4xx, e.g. mapping error) failures are cleared from the buffer rather than retried forever.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = ( 0, [ { "index": { "_id": "k1", "status": 400, "error": { "type": "mapper_parsing_exception", "reason": "bad", }, } }, {"index": {"_id": "k2", "status": 503, "error": "down"}}, ], ) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"k1": {"content": "x"}, "k2": {"content": "y"}}) await s.index_done_callback() assert "k1" not in s._pending_upserts assert "k2" in s._pending_upserts @pytest.mark.asyncio async def test_kv_concurrent_upsert_during_flush_blocked( self, global_config, embed_func, mock_client ): """A concurrent upsert that lands while async_bulk is in flight is blocked by the namespace lock and lands in the buffer only after the flush completes.""" flush_started = asyncio.Event() flush_can_finish = asyncio.Event() async def slow_bulk(client, actions, 
raise_on_error=False, **kwargs): flush_started.set() await flush_can_finish.wait() return (len(actions), []) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"k1": {"content": "first"}}) flush_task = asyncio.create_task(s.index_done_callback()) await flush_started.wait() concurrent_task = asyncio.create_task( s.upsert({"k2": {"content": "concurrent"}}) ) for _ in range(5): await asyncio.sleep(0) assert ( not concurrent_task.done() ), "concurrent upsert should be blocked by the flush lock" assert "k2" not in s._pending_upserts flush_can_finish.set() await flush_task await concurrent_task # k1 flushed and cleared; k2 added after flush released. assert "k1" not in s._pending_upserts assert "k2" in s._pending_upserts # --------------------------------------------------------------------------- # DocStatus Storage # --------------------------------------------------------------------------- class TestDocStatusStorage: """Tests for OpenSearchDocStatusStorage including aggregations, pagination, and data normalization.""" def _make(self, global_config, embed_func, workspace="test"): return OpenSearchDocStatusStorage( namespace="doc_status", global_config=global_config, embedding_func=embed_func, workspace=workspace, ) @pytest.mark.asyncio async def test_index_name(self, global_config, embed_func): s = self._make(global_config, embed_func) assert s._index_name == "test_doc_status" @pytest.mark.asyncio async def test_initialize_creates_index( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.indices.create.assert_awaited_once() @pytest.mark.asyncio async def test_get_by_id(self, global_config, embed_func, mock_client): mock_client.mget = AsyncMock( 
return_value={ "docs": [ { "_id": "doc-abc", "found": True, "_source": {"status": "processed", "file_path": "/a.txt"}, } ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() doc = await s.get_by_id("doc-abc") assert doc["status"] == "processed" assert doc["_id"] == "doc-abc" mock_client.mget.assert_awaited_once_with( index=s._index_name, body={"ids": ["doc-abc"]} ) @pytest.mark.asyncio async def test_get_by_id_not_found(self, global_config, embed_func, mock_client): mock_client.mget = AsyncMock( return_value={"docs": [{"_id": "missing", "found": False}]} ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_by_id("missing") is None mock_client.get.assert_not_awaited() @pytest.mark.asyncio async def test_upsert_sets_chunks_list_default( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"d1": {"status": "pending"}}) actions = mock_bulk.call_args[0][1] assert actions[0]["_source"]["chunks_list"] == [] @pytest.mark.asyncio async def test_get_status_counts(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 0}}, "aggregations": { "status_counts": { "buckets": [ {"key": "processed", "doc_count": 3}, {"key": "pending", "doc_count": 1}, ] } }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() counts = await s.get_status_counts() assert counts == {"processed": 3, "pending": 1} @pytest.mark.asyncio async def 
test_get_all_status_counts_includes_all( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 0}}, "aggregations": { "status_counts": { "buckets": [ {"key": "processed", "doc_count": 5}, {"key": "failed", "doc_count": 2}, ] } }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() counts = await s.get_all_status_counts() assert counts["all"] == 7 assert counts["processed"] == 5 @pytest.mark.asyncio async def test_get_docs_by_status(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "d1", "_source": { "status": "processed", "file_path": "/a.txt", "content_summary": "s", "content_length": 10, "chunks_count": 1, "created_at": 100, "updated_at": 200, }, "sort": ["d1"], }, ], "total": {"value": 1}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() result = await s.get_docs_by_status(DocStatus.PROCESSED) assert "d1" in result assert isinstance(result["d1"], DocProcessingStatus) @pytest.mark.asyncio async def test_get_docs_paginated(self, global_config, embed_func, mock_client): """Page 1 returns results directly without search_after.""" mock_client.count = AsyncMock(return_value={"count": 50}) mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "d1", "_source": { "status": "processed", "file_path": "/a.txt", "content_summary": "s", "content_length": 10, "chunks_count": 1, "created_at": 100, "updated_at": 200, }, "sort": [200, "d1"], }, ], "total": {"value": 50}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated(page=1, page_size=10) assert total == 50 assert len(docs) == 1 assert 
docs[0][0] == "d1" # Page 1: no search_after needed, single search call assert mock_client.search.await_count == 1 body = mock_client.search.call_args.kwargs.get( "body" ) or mock_client.search.call_args[1].get("body", {}) assert "search_after" not in body @pytest.mark.asyncio async def test_get_docs_paginated_page2_uses_search_after( self, global_config, embed_func, mock_client ): """Page 2 skips page 1 results via search_after.""" mock_client.count = AsyncMock(return_value={"count": 50}) call_count = {"n": 0} async def search_side_effect(*args, **kwargs): call_count["n"] += 1 body = kwargs.get("body", {}) if "search_after" not in body: # First call: skip batch return { "hits": { "hits": [ { "_id": f"skip{i}", "_source": { "status": "processed", "file_path": f"/{i}.txt", "content_summary": "s", "content_length": 1, "chunks_count": 1, "created_at": 100, "updated_at": 100 + i, }, "sort": [100 + i, f"skip{i}"], } for i in range(10) ], "total": {"value": 50}, } } else: # Second call: actual page return { "hits": { "hits": [ { "_id": "page2_doc", "_source": { "status": "pending", "file_path": "/p2.txt", "content_summary": "s", "content_length": 1, "chunks_count": 1, "created_at": 200, "updated_at": 300, }, "sort": [300, "page2_doc"], } ], "total": {"value": 50}, } } mock_client.search = AsyncMock(side_effect=search_side_effect) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated(page=2, page_size=10) assert total == 50 assert len(docs) == 1 assert docs[0][0] == "page2_doc" # 2 search calls: 1 skip + 1 fetch assert mock_client.search.await_count == 2 @pytest.mark.asyncio async def test_get_docs_paginated_empty_index( self, global_config, embed_func, mock_client ): """Empty index returns empty list with total 0.""" mock_client.count = AsyncMock(return_value={"count": 0}) with patch.object(ClientManager, "get_client", return_value=mock_client): s = 
self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated(page=1, page_size=10) assert total == 0 assert docs == [] mock_client.search.assert_not_awaited() @pytest.mark.asyncio async def test_get_docs_paginated_page_beyond_total( self, global_config, embed_func, mock_client ): """Requesting a page beyond total docs returns empty list.""" mock_client.count = AsyncMock(return_value={"count": 5}) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated(page=100, page_size=10) assert total == 5 assert docs == [] @pytest.mark.asyncio async def test_get_docs_paginated_with_status_filter( self, global_config, embed_func, mock_client ): """Status filter is passed as term query.""" mock_client.count = AsyncMock(return_value={"count": 3}) mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 3}}, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated( status_filter=DocStatus.PROCESSED, page=1, page_size=10 ) assert total == 3 # Verify count query used the status filter count_body = mock_client.count.call_args.kwargs.get("body", {}) assert count_body["query"] == {"term": {"status": "processed"}} @pytest.mark.asyncio async def test_get_docs_paginated_with_status_filters( self, global_config, embed_func, mock_client ): """Multi-status filters are passed as terms query and override status_filter.""" mock_client.count = AsyncMock(return_value={"count": 2}) mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 2}}, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs, total = await s.get_docs_paginated( status_filter=DocStatus.PROCESSED, 
status_filters=[DocStatus.PARSING, DocStatus.ANALYZING], page=1, page_size=10, ) assert total == 2 assert docs == [] count_body = mock_client.count.call_args.kwargs.get("body", {}) assert count_body["query"] == { "terms": {"status": ["analyzing", "parsing"]} } @pytest.mark.asyncio async def test_get_doc_by_file_path(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "d1", "_source": { "file_path": "/test.txt", "status": "processed", }, }, ], "total": {"value": 1}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() doc = await s.get_doc_by_file_path("/test.txt") assert doc is not None assert doc["_id"] == "d1" @pytest.mark.asyncio async def test_get_doc_by_file_path_not_found( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 0}}, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_doc_by_file_path("/nope.txt") is None @pytest.mark.asyncio async def test_get_doc_by_file_basename_returns_tuple_on_hit( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "doc-1", "_source": { "file_path": "report.pdf", "status": "processed", }, }, ], "total": {"value": 1}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() result = await s.get_doc_by_file_basename("report.pdf") assert result is not None doc_id, doc = result assert doc_id == "doc-1" assert doc["file_path"] == "report.pdf" body = mock_client.search.call_args.kwargs.get( "body" ) or mock_client.search.call_args[1].get("body", {}) assert body["query"] == {"term": {"file_path": "report.pdf"}} 
@pytest.mark.asyncio async def test_get_doc_by_file_basename_empty_short_circuits( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.search.reset_mock() assert await s.get_doc_by_file_basename("") is None mock_client.search.assert_not_awaited() @pytest.mark.asyncio async def test_get_doc_by_file_basename_unknown_source_sentinel( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.search.reset_mock() assert await s.get_doc_by_file_basename("unknown_source") is None mock_client.search.assert_not_awaited() @pytest.mark.asyncio async def test_get_doc_by_file_basename_miss_returns_none( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={"hits": {"hits": [], "total": {"value": 0}}} ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_doc_by_file_basename("missing.pdf") is None @pytest.mark.asyncio async def test_get_doc_by_content_hash_returns_tuple_on_hit( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "doc-1", "_source": { "file_path": "report.pdf", "content_hash": "abc123", "status": "processed", }, }, ], "total": {"value": 1}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() result = await s.get_doc_by_content_hash("abc123") assert result is not None doc_id, doc = result assert doc_id == "doc-1" assert doc["content_hash"] == "abc123" body = mock_client.search.call_args.kwargs.get( "body" ) or mock_client.search.call_args[1].get("body", {}) assert body["query"] == 
{"term": {"content_hash": "abc123"}} @pytest.mark.asyncio async def test_get_doc_by_content_hash_empty_short_circuits( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.search.reset_mock() assert await s.get_doc_by_content_hash("") is None mock_client.search.assert_not_awaited() @pytest.mark.asyncio async def test_get_doc_by_content_hash_miss_returns_none( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock( return_value={"hits": {"hits": [], "total": {"value": 0}}} ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_doc_by_content_hash("zzz999") is None @pytest.mark.asyncio async def test_ensure_content_hash_mapping_added_when_missing( self, global_config, embed_func, mock_client ): """Pre-existing indices without content_hash mapping should get one added.""" mock_client.indices.exists = AsyncMock(return_value=True) mock_client.indices.get_mapping = AsyncMock( return_value={ "test_doc_status": { "mappings": { "properties": { "__mirrored_id": {"type": "keyword"}, "status": {"type": "keyword"}, "file_path": {"type": "keyword"}, } } } } ) mock_client.indices.put_mapping = AsyncMock() with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.indices.put_mapping.assert_awaited_once() kwargs = mock_client.indices.put_mapping.call_args.kwargs assert kwargs["body"] == { "properties": {"content_hash": {"type": "keyword"}} } @pytest.mark.asyncio async def test_ensure_content_hash_mapping_skipped_when_present( self, global_config, embed_func, mock_client ): """Indices that already have content_hash mapping should not be touched.""" mock_client.indices.exists = AsyncMock(return_value=True) 
mock_client.indices.get_mapping = AsyncMock( return_value={ "test_doc_status": { "mappings": { "properties": { "__mirrored_id": {"type": "keyword"}, "content_hash": {"type": "keyword"}, } } } } ) mock_client.indices.put_mapping = AsyncMock() with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() mock_client.indices.put_mapping.assert_not_awaited() @pytest.mark.asyncio async def test_prepare_doc_status_data(self, global_config, embed_func): s = self._make(global_config, embed_func) raw = {"_id": "x", "status": "processed", "error": "oops"} data = s._prepare_doc_status_data(raw) assert "_id" not in data assert data["error_msg"] == "oops" assert "error" not in data assert data["file_path"] == "no-file-path" assert data["metadata"] == {} @pytest.mark.asyncio async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index( self, global_config, embed_func, mock_client ): mock_client.indices.delete = AsyncMock( side_effect=OpenSearchException("drop failed") ) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() with patch.object( s, "_create_index_if_not_exists", new_callable=AsyncMock ) as mock_create: result = await s.drop() assert result["status"] == "error" assert s._index_ready is False await s.upsert({"d1": {"status": "pending"}}) mock_create.assert_awaited_once() @pytest.mark.asyncio async def test_upsert_after_drop_recreates_index( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() with 
patch.object( s, "_create_index_if_not_exists", new_callable=AsyncMock ) as mock_create: await s.drop() await s.upsert({"d1": {"status": "pending"}}) mock_create.assert_awaited_once() @pytest.mark.asyncio async def test_reads_short_circuit_after_drop( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.drop() assert await s.get_all_status_counts() == {} assert await s.get_docs_paginated(page=1, page_size=10) == ([], 0) assert await s.get_doc_by_file_path("/a.txt") is None assert await s.get_docs_by_status(DocStatus.PROCESSED) == {} mock_client.count.assert_not_awaited() mock_client.search.assert_not_awaited() mock_client.create_pit.assert_not_awaited() @pytest.mark.asyncio async def test_read_missing_index_demotes_readiness( self, global_config, embed_func, mock_client ): mock_client.search = AsyncMock(side_effect=_missing_index_error()) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_all_status_counts() == {} assert await s.get_all_status_counts() == {} assert s._index_ready is False assert mock_client.search.await_count == 1 # --------------------------------------------------------------------------- # Graph Storage # --------------------------------------------------------------------------- class TestGraphStorage: """Tests for OpenSearchGraphStorage node/edge CRUD, batch ops, BFS, and label queries.""" def _make(self, global_config, embed_func, workspace="test"): return OpenSearchGraphStorage( namespace="chunk_entity_relation", global_config=global_config, embedding_func=embed_func, workspace=workspace, ) @pytest.mark.asyncio async def test_index_names(self, global_config, embed_func): s = self._make(global_config, embed_func) assert s._nodes_index == "test_chunk_entity_relation-nodes" assert s._edges_index == 
"test_chunk_entity_relation-edges" @pytest.mark.asyncio async def test_initialize_creates_both_indices( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert mock_client.indices.create.await_count == 2 @pytest.mark.asyncio async def test_has_node_true(self, global_config, embed_func, mock_client): mock_client.exists = AsyncMock(return_value=True) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.has_node("Alice") is True @pytest.mark.asyncio async def test_has_node_false(self, global_config, embed_func, mock_client): mock_client.exists = AsyncMock(return_value=False) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.has_node("Nobody") is False @pytest.mark.asyncio async def test_has_edge(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 1}}, "aggregations": { "status_counts": {"buckets": []}, "src": {"buckets": []}, "tgt": {"buckets": []}, "source_degrees": {"buckets": []}, "target_degrees": {"buckets": []}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.has_edge("A", "B") is True @pytest.mark.asyncio async def test_node_degree(self, global_config, embed_func, mock_client): mock_client.count = AsyncMock(return_value={"count": 3}) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.node_degree("A") == 3 @pytest.mark.asyncio async def test_get_node(self, global_config, embed_func, mock_client): mock_client.mget = AsyncMock( return_value={ 
"docs": [ { "_id": "Alice", "found": True, "_source": { "entity_type": "person", "description": "A researcher", }, } ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() node = await s.get_node("Alice") assert node["entity_type"] == "person" assert node["_id"] == "Alice" mock_client.mget.assert_awaited_once_with( index=s._nodes_index, body={"ids": ["Alice"]} ) @pytest.mark.asyncio async def test_get_node_not_found(self, global_config, embed_func, mock_client): mock_client.mget = AsyncMock( return_value={"docs": [{"_id": "Nobody", "found": False}]} ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert await s.get_node("Nobody") is None mock_client.get.assert_not_awaited() @pytest.mark.asyncio async def test_get_edge(self, global_config, embed_func, mock_client): # get_edge now uses mget (translog real-time) instead of search. 
mock_client.mget = AsyncMock( return_value={ "docs": [ { "_id": "e1", "found": True, "_source": { "source_node_id": "A", "target_node_id": "B", "weight": 1.0, }, }, { "_id": "e2", "found": False, }, ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() edge = await s.get_edge("A", "B") assert edge is not None assert edge["weight"] == 1.0 @pytest.mark.asyncio async def test_get_node_edges(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "e1", "_source": {"source_node_id": "A", "target_node_id": "B"}, "sort": [1], }, { "_id": "e2", "_source": {"source_node_id": "C", "target_node_id": "A"}, "sort": [2], }, ], "total": {"value": 2}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() edges = await s.get_node_edges("A") assert len(edges) == 2 assert ("A", "B") in edges @pytest.mark.asyncio async def test_get_nodes_batch(self, global_config, embed_func, mock_client): mock_client.mget = AsyncMock( return_value={ "docs": [ {"_id": "A", "found": True, "_source": {"entity_type": "person"}}, {"_id": "B", "found": False}, ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() result = await s.get_nodes_batch(["A", "B"]) assert "A" in result assert "B" not in result @pytest.mark.asyncio async def test_node_degrees_batch(self, global_config, embed_func, mock_client): mock_client.search = AsyncMock( return_value={ "hits": {"hits": [], "total": {"value": 0}}, "aggregations": { "source_degrees": {"buckets": [{"key": "A", "doc_count": 2}]}, "target_degrees": { "buckets": [ {"key": "A", "doc_count": 1}, {"key": "B", "doc_count": 3}, ] }, "status_counts": {"buckets": []}, "src": {"buckets": []}, "tgt": {"buckets": []}, }, } ) with 
patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() degrees = await s.node_degrees_batch(["A", "B"]) assert degrees["A"] == 3 # 2 + 1 assert degrees["B"] == 3 @pytest.mark.asyncio async def test_upsert_node(self, global_config, embed_func, mock_client): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.upsert_node( "Alice", {"entity_type": "person", "source_id": "c1c2"} ) mock_client.index.assert_awaited() call_kwargs = mock_client.index.call_args assert call_kwargs.kwargs["id"] == "Alice" body = call_kwargs.kwargs["body"] assert body["source_ids"] == ["c1", "c2"] assert body["entity_id"] == "Alice" @pytest.mark.asyncio async def test_upsert_edge(self, global_config, embed_func, mock_client): mock_client.exists = AsyncMock(return_value=False) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.upsert_edge("A", "B", {"weight": "1.0", "description": "knows"}) # Should call index twice: once for ensuring source node, once for edge assert mock_client.index.await_count == 2 @pytest.mark.asyncio async def test_upsert_edges_batch_reuses_id_for_reciprocal_edges( self, global_config, embed_func, mock_client ): with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() bulk_calls = [] async def capture_bulk(_client, actions, *args, **kwargs): bulk_calls.append(list(actions)) return (len(bulk_calls[-1]), []) mock_client.mget = AsyncMock( side_effect=[ {"docs": []}, {"docs": [{"_id": "edge-ba", "found": False}] * 2}, ] ) with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new=AsyncMock(side_effect=capture_bulk), ): await s.upsert_edges_batch( [ ("A", "B", {"weight": "1.0"}), ("B", "A", {"weight": "2.0"}), ] ) edge_actions = 
@pytest.mark.asyncio
async def test_reads_short_circuit_after_drop(
    self, global_config, embed_func, mock_client
):
    """After drop(), read calls return empty results without awaiting the client."""
    # Make every transport request fail with "PPL not available".
    transport = AsyncMock()
    transport.perform_request = AsyncMock(side_effect=Exception("PPL not available"))
    mock_client.transport = transport

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.drop()

        subgraph = await storage.get_knowledge_graph("A", max_depth=2)
        node = await storage.get_node("A")
        labels = await storage.get_all_labels()
        edge_present = await storage.has_edge("A", "B")
        degree = await storage.node_degree("A")

        assert node is None
        assert labels == []
        assert edge_present is False
        assert degree == 0
        assert len(subgraph.nodes) == 0
        assert len(subgraph.edges) == 0

    # None of the read-side client APIs may have been awaited.
    for api_name in ("mget", "search", "create_pit", "count"):
        getattr(mock_client, api_name).assert_not_awaited()
@pytest.mark.asyncio
async def test_remove_nodes(self, global_config, embed_func, mock_client):
    """remove_nodes awaits delete_by_query once and async_bulk once."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )

    with client_patch, bulk_patch as bulk_mock:
        # async_bulk reports (success_count, errors).
        bulk_mock.return_value = (2, [])

        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.remove_nodes(["A", "B"])

        mock_client.delete_by_query.assert_awaited_once()
        bulk_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_popular_labels(self, global_config, embed_func, mock_client):
    """The label with the larger combined src + tgt count comes first."""

    def terms(*counts):
        # Build a terms-aggregation payload from (key, doc_count) pairs.
        return {"buckets": [{"key": key, "doc_count": n} for key, n in counts]}

    aggregation_response = {
        "hits": {"hits": [], "total": {"value": 0}},
        "aggregations": {
            "src": terms(("A", 5), ("B", 2)),
            "tgt": terms(("A", 3)),
            "status_counts": terms(),
        },
    }
    mock_client.search = AsyncMock(return_value=aggregation_response)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        ranked = await storage.get_popular_labels(limit=10)

    # "A" totals 5 + 3 = 8, "B" only 2.
    assert ranked[0] == "A"
@pytest.mark.asyncio
async def test_get_knowledge_graph_all_paginates_edges_between_selected_nodes(
    self, global_config, embed_func, mock_client
):
    """A "*" query issues a further edge search after a full 10000-hit page."""

    def edge_hit(doc_id, src, tgt, sort_key):
        # One edge document as it appears in a search hit.
        return {
            "_id": doc_id,
            "_source": {
                "source_node_id": src,
                "target_node_id": tgt,
                "relationship": "knows",
            },
            "sort": [sort_key],
        }

    def page(hits, total):
        # Wrap hits in the search-response envelope.
        return {"hits": {"hits": hits, "total": {"value": total}}}

    def person(node_id):
        # One found node document as returned by mget.
        return {"_id": node_id, "found": True, "_source": {"entity_type": "person"}}

    full_page = [edge_hit(f"edge-{n}", "A", "B", n) for n in range(10000)]
    last_page = [edge_hit("edge-last", "B", "A", 10000)]

    mock_client.count = AsyncMock(return_value={"count": 2})
    # Search order: node listing, first edge page, final edge page.
    mock_client.search = AsyncMock(
        side_effect=[
            page([{"_id": "A"}, {"_id": "B"}], 2),
            page(full_page, 10001),
            page(last_page, 10001),
        ]
    )
    mock_client.mget = AsyncMock(return_value={"docs": [person("A"), person("B")]})

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        graph = await storage.get_knowledge_graph("*", max_nodes=2)

    directed_pairs = {(edge.source, edge.target) for edge in graph.edges}
    assert len(graph.nodes) == 2
    assert len(graph.edges) == 2
    assert directed_pairs == {("A", "B"), ("B", "A")}
    assert mock_client.search.await_count == 3
@pytest.mark.asyncio
async def test_construct_graph_node(self, global_config, embed_func):
    """_construct_graph_node keeps entity_type but drops _id and entity_id."""
    raw_source = {
        "entity_type": "person",
        "description": "A researcher",
        "_id": "Alice",
        "entity_id": "Alice",
    }

    storage = self._make(global_config, embed_func)
    built = storage._construct_graph_node("Alice", raw_source)
    props = built.properties

    assert built.id == "Alice"
    assert "entity_type" in props
    assert "_id" not in props
    assert "entity_id" not in props
@pytest.mark.asyncio
async def test_env_override_false(self, global_config, embed_func, mock_client):
    """OPENSEARCH_USE_PPL_GRAPHLOOKUP=false disables PPL even when the endpoint answers."""
    # The transport would answer successfully if it were called.
    transport = AsyncMock()
    transport.perform_request = AsyncMock(return_value={"datarows": [], "schema": []})
    mock_client.transport = transport

    env_patch = patch.dict("os.environ", {"OPENSEARCH_USE_PPL_GRAPHLOOKUP": "false"})
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)

    with env_patch, client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        assert storage._ppl_graphlookup_available is False
embed_func, mock_client ): """When PPL is available, get_knowledge_graph should use PPL endpoint.""" mock_client.transport = AsyncMock() # PPL response: connected_edges contains dicts with source_node_id/target_node_id ppl_response = { "schema": [ {"name": "entity_id", "type": "string"}, {"name": "connected_edges", "type": "struct"}, ], "datarows": [ [ "A", [ # connected_edges array { "source_node_id": "A", "target_node_id": "B", "weight": 1.0, "_depth": 0, }, { "source_node_id": "B", "target_node_id": "C", "weight": 0.5, "_depth": 1, }, ], ] ], } mock_client.transport.perform_request = AsyncMock(return_value=ppl_response) # get_node for start node verification mock_client.get = AsyncMock( return_value={ "_id": "A", "_source": {"entity_type": "person", "description": "Node A"}, } ) # mget for batch node fetch (only B and C, A is already added) mock_client.mget = AsyncMock( return_value={ "docs": [ {"_id": "B", "found": True, "_source": {"entity_type": "person"}}, {"_id": "C", "found": True, "_source": {"entity_type": "person"}}, ] } ) # search for final edge fetch mock_client.search = AsyncMock( return_value={ "hits": { "hits": [ { "_id": "e1", "_source": { "source_node_id": "A", "target_node_id": "B", "relationship": "knows", }, }, { "_id": "e2", "_source": { "source_node_id": "B", "target_node_id": "C", "relationship": "knows", }, }, ], "total": {"value": 2}, }, "aggregations": { "status_counts": {"buckets": []}, "src": {"buckets": []}, "tgt": {"buckets": []}, }, } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() assert s._ppl_graphlookup_available is True result = await s.get_knowledge_graph("A", max_depth=2) assert len(result.nodes) == 3 assert len(result.edges) == 2 # Verify PPL was called (2 for detection + 1 for actual query) assert mock_client.transport.perform_request.await_count == 3 # Verify the PPL query uses nodes index as source actual_query = 
@pytest.mark.asyncio
async def test_escape_ppl(self, global_config, embed_func):
    """_escape_ppl backslash-escapes single quotes and backslashes."""
    storage = self._make(global_config, embed_func)

    # (raw input, expected escaped output)
    cases = [
        ("it's", "it\\'s"),
        ("normal", "normal"),
        ("back\\slash", "back\\\\slash"),
        ("both\\and'quote", "both\\\\and\\'quote"),
    ]
    for raw, escaped in cases:
        assert storage._escape_ppl(raw) == escaped
@pytest.mark.asyncio
async def test_ppl_bfs_truncates_nodes_by_depth_then_weight(
    self, global_config, embed_func, mock_client
):
    """With max_nodes=3 the PPL traversal keeps A, B, D and reports truncation."""

    def traversal_edge(src, tgt, weight, depth):
        # One entry of the PPL "connected_edges" array.
        return {
            "source_node_id": src,
            "target_node_id": tgt,
            "weight": weight,
            "_depth": depth,
        }

    def person(node_id):
        # One found node document as returned by mget.
        return {"_id": node_id, "found": True, "_source": {"entity_type": "person"}}

    def stored_edge(doc_id, src, tgt, sort_key):
        # One edge document as it appears in a search hit.
        return {
            "_id": doc_id,
            "_source": {
                "source_node_id": src,
                "target_node_id": tgt,
                "relationship": "knows",
            },
            "sort": [sort_key],
        }

    connected_edges = [
        traversal_edge("A", "C", 1.0, 1),
        traversal_edge("B", "D", 10.0, 1),
        traversal_edge("A", "B", 1.0, 0),
    ]
    ppl_payload = {
        "schema": [
            {"name": "entity_id", "type": "string"},
            {"name": "connected_edges", "type": "struct"},
        ],
        "datarows": [["A", connected_edges]],
    }

    transport = AsyncMock()
    transport.perform_request = AsyncMock(return_value=ppl_payload)
    mock_client.transport = transport
    # First mget answers with A, the second with B and D.
    mock_client.mget = AsyncMock(
        side_effect=[
            {"docs": [person("A")]},
            {"docs": [person("B"), person("D")]},
        ]
    )
    mock_client.search = AsyncMock(
        return_value={
            "hits": {
                "hits": [
                    stored_edge("e1", "A", "B", 1),
                    stored_edge("e2", "B", "D", 2),
                ],
                "total": {"value": 2},
            }
        }
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        graph = await storage.get_knowledge_graph("A", max_depth=2, max_nodes=3)

    node_ids = [node.id for node in graph.nodes]
    directed_pairs = {(edge.source, edge.target) for edge in graph.edges}
    assert node_ids == ["A", "B", "D"]
    assert graph.is_truncated is True
    assert directed_pairs == {("A", "B"), ("B", "D")}
@pytest.mark.asyncio
async def test_initialize_creates_knn_index(
    self, global_config, embed_func, mock_client
):
    """initialize() creates one k-NN index with a 128-dim lucene vector field."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        create_mock = mock_client.indices.create
        create_mock.assert_awaited_once()

        index_body = create_mock.call_args.kwargs["body"]
        vector_mapping = index_body["mappings"]["properties"]["vector"]

        assert index_body["settings"]["index"]["knn"] is True
        assert vector_mapping["dimension"] == 128
        assert vector_mapping["method"]["engine"] == "lucene"
@pytest.mark.asyncio
async def test_query_cosine_score_conversion(
    self, global_config, embed_func, mock_client
):
    """A hit scored 0.85 is returned with distance 0.85 unchanged."""
    no_buckets = {"buckets": []}
    single_hit = {
        "_id": "v1",
        "_score": 0.85,
        "_source": {"content": "match", "entity_name": "E1"},
    }
    mock_client.search = AsyncMock(
        return_value={
            "hits": {"hits": [single_hit], "total": {"value": 1}},
            "aggregations": {
                "status_counts": dict(no_buckets),
                "src": dict(no_buckets),
                "tgt": dict(no_buckets),
            },
        }
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        matches = await storage.query("test", top_k=5)

    assert len(matches) == 1
    assert matches[0]["distance"] == 0.85
@pytest.mark.asyncio
async def test_get_by_id(self, global_config, embed_func, mock_client):
    """get_by_id returns id and content without the vector, via a single mget."""
    stored_doc = {
        "_id": "v1",
        "found": True,
        "_source": {"content": "hello", "vector": [0.1] * 128},
    }
    mock_client.mget = AsyncMock(return_value={"docs": [stored_doc]})

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        fetched = await storage.get_by_id("v1")

        assert fetched["id"] == "v1"
        assert fetched["content"] == "hello"
        # The vector field is stripped on the mget path to match NanoVectorDB.
        assert "vector" not in fetched

        expected_call = {
            "index": storage._index_name,
            "body": {"ids": ["v1"]},
            "_source_excludes": ["vector"],
        }
        mock_client.mget.assert_awaited_once_with(**expected_call)
@pytest.mark.asyncio
async def test_delete(self, global_config, embed_func, mock_client):
    """delete() only records ids; the bulk delete is sent by index_done_callback()."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )

    with client_patch, bulk_patch as bulk_mock:
        bulk_mock.return_value = (2, [])

        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.delete(["v1", "v2"])

        # Nothing is written until the flush.
        bulk_mock.assert_not_awaited()
        assert storage._pending_vector_deletes == {"v1", "v2"}

        await storage.index_done_callback()
        bulk_mock.assert_awaited_once()

        # Positional argument 1 of async_bulk is the action list.
        sent_actions = bulk_mock.call_args.args[1]
        assert len(sent_actions) == 2
        assert all(action["_op_type"] == "delete" for action in sent_actions)
@pytest.mark.asyncio
async def test_drop_recreates_index(self, global_config, embed_func, mock_client):
    """drop() deletes the index once and creates it again afterwards."""
    # indices.exists always answers False, so both the initial setup and the
    # post-drop recreation go through indices.create.
    mock_client.indices.exists = AsyncMock(return_value=False)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        outcome = await storage.drop()

        assert outcome["status"] == "success"
        mock_client.indices.delete.assert_awaited_once()
        # One create during initialize(), one during the drop recreate.
        assert mock_client.indices.create.await_count == 2
@pytest.mark.asyncio
async def test_reads_short_circuit_when_index_not_ready(
    self, global_config, embed_func, mock_client
):
    """With _index_ready False, reads return empty values without awaiting the client."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        storage._index_ready = False

        hits = await storage.query("test", top_k=5)
        single = await storage.get_by_id("v1")
        vectors = await storage.get_vectors_by_ids(["v1"])

        assert hits == []
        assert single is None
        assert vectors == {}

    for api_name in ("search", "mget"):
        getattr(mock_client, api_name).assert_not_awaited()
namespace="entities", global_config=global_config, embedding_func=embed_func, workspace=workspace, meta_fields={"content", "entity_name", "src_id", "tgt_id"}, ) @pytest.mark.asyncio async def test_repeated_upserts_flush_in_single_bulk_call( self, global_config, embed_func, mock_client ): """Many small upsert() calls collapse to one async_bulk on flush.""" embed_func = CountingEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (5, []) s = self._make(global_config, embed_func) await s.initialize() for i in range(5): await s.upsert({f"v{i}": {"content": f"doc {i}"}}) mock_bulk.assert_not_awaited() assert embed_func.call_count == 0 await s.index_done_callback() assert embed_func.call_count == 1 assert embed_func.batches == [[f"doc {i}" for i in range(5)]] mock_bulk.assert_awaited_once() actions = mock_bulk.call_args[0][1] assert len(actions) == 5 assert {a["_id"] for a in actions} == {f"v{i}" for i in range(5)} @pytest.mark.asyncio async def test_deferred_embeddings_respect_batch_size( self, global_config, embed_func, mock_client ): """Flush batches deferred embeddings by embedding_batch_num.""" embed_func = CountingEmbeddingFunc() config = {**global_config, "embedding_batch_num": 2} with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (5, []) s = self._make(config, embed_func) await s.initialize() for i in range(5): await s.upsert({f"v{i}": {"content": f"doc {i}"}}) await s.index_done_callback() assert embed_func.batches == [ ["doc 0", "doc 1"], ["doc 2", "doc 3"], ["doc 4"], ] mock_bulk.assert_awaited_once() @pytest.mark.asyncio async def test_upsert_overwrites_pending_doc_for_same_id( self, global_config, embed_func, mock_client ): """Upserting the same id 
twice keeps only the latest payload.""" embed_func = CountingEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "first"}}) await s.upsert({"v1": {"content": "second"}}) await s.index_done_callback() assert embed_func.call_count == 1 assert embed_func.texts == ["second"] actions = mock_bulk.call_args[0][1] assert len(actions) == 1 assert actions[0]["_source"]["content"] == "second" @pytest.mark.asyncio async def test_delete_cancels_pending_upsert( self, global_config, embed_func, mock_client ): """A delete after a buffered upsert removes the upsert from the buffer.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "doomed"}}) await s.delete(["v1"]) assert "v1" not in s._pending_vector_docs assert "v1" in s._pending_vector_deletes await s.index_done_callback() actions = mock_bulk.call_args[0][1] assert len(actions) == 1 assert actions[0]["_op_type"] == "delete" @pytest.mark.asyncio async def test_upsert_cancels_pending_delete( self, global_config, embed_func, mock_client ): """An upsert after a buffered delete removes the tombstone.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.delete(["v1"]) await s.upsert({"v1": {"content": "resurrected"}}) assert "v1" not in s._pending_vector_deletes assert "v1" in 
s._pending_vector_docs await s.index_done_callback() actions = mock_bulk.call_args[0][1] assert len(actions) == 1 assert actions[0]["_op_type"] == "index" @pytest.mark.asyncio async def test_get_by_id_reads_pending_buffer( self, global_config, embed_func, mock_client ): """Buffered upserts are visible to get_by_id without hitting OpenSearch.""" with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "buffered"}}) doc = await s.get_by_id("v1") assert doc is not None assert doc["id"] == "v1" assert doc["content"] == "buffered" # Vector field is hidden from get_by_id results, mirroring the # _source excludes used by query(). assert "vector" not in doc mock_client.mget.assert_not_awaited() @pytest.mark.asyncio async def test_get_by_id_returns_none_for_pending_delete( self, global_config, embed_func, mock_client ): """A pending tombstone shadows any persisted doc.""" mock_client.mget = AsyncMock() # would be wrong to invoke with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.delete(["v1"]) assert await s.get_by_id("v1") is None mock_client.mget.assert_not_awaited() @pytest.mark.asyncio async def test_get_by_ids_merges_buffer_and_index( self, global_config, embed_func, mock_client ): """get_by_ids returns buffered docs and falls back to mget for the rest.""" mock_client.mget = AsyncMock( return_value={ "docs": [ {"_id": "v2", "found": True, "_source": {"content": "from_index"}}, ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "buffered"}}) docs = await s.get_by_ids(["v1", "v2"]) assert docs[0]["content"] == "buffered" assert docs[1]["content"] == "from_index" # Only the unbuffered id is requested from OpenSearch, # and vector is excluded 
server-side. mock_client.mget.assert_awaited_once_with( index=s._index_name, body={"ids": ["v2"]}, _source_excludes=["vector"], ) @pytest.mark.asyncio async def test_get_vectors_by_ids_uses_buffer( self, global_config, embed_func, mock_client ): """get_vectors_by_ids returns buffered embeddings without an mget roundtrip.""" embed_func = CountingEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "x"}}) assert embed_func.call_count == 0 vecs = await s.get_vectors_by_ids(["v1"]) assert "v1" in vecs assert len(vecs["v1"]) == 128 assert embed_func.call_count == 1 assert s._pending_vector_docs["v1"].vector == vecs["v1"] mock_client.mget.assert_not_awaited() @pytest.mark.asyncio async def test_lazy_get_vectors_cache_is_reused_by_flush( self, global_config, embed_func, mock_client ): """A lazy pending-vector read should not force a second embedding during flush.""" embed_func = CountingEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "x"}}) vecs = await s.get_vectors_by_ids(["v1"]) await s.index_done_callback() assert embed_func.call_count == 1 actions = mock_bulk.call_args[0][1] assert actions[0]["_source"]["vector"] == vecs["v1"] @pytest.mark.asyncio async def test_finalize_flushes_pending_ops( self, global_config, embed_func, mock_client ): """finalize() flushes buffered writes before releasing the client.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await 
s.initialize() await s.upsert({"v1": {"content": "to flush"}}) await s.finalize() mock_bulk.assert_awaited_once() assert s.client is None @pytest.mark.asyncio async def test_vector_finalize_raises_when_retryable_buffer_remains( self, global_config, embed_func, mock_client ): """finalize() must surface a RuntimeError when retryable bulk failures left vector rows buffered, otherwise the upstream finalize_storages() call would log the storage as successfully finalized while writes are silently lost. The client is still released regardless to avoid connection leak. """ with patch.object(ClientManager, "get_client", return_value=mock_client): with patch.object( ClientManager, "release_client", new_callable=AsyncMock ) as mock_release: with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock, ) as mock_bulk: mock_bulk.return_value = ( 0, [{"index": {"_id": "v1", "status": 503, "error": "down"}}], ) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "stuck"}}) with pytest.raises(RuntimeError, match="pending upserts"): await s.finalize() mock_release.assert_awaited_once() assert s.client is None @pytest.mark.asyncio async def test_vector_finalize_propagates_flush_exception( self, global_config, embed_func, mock_client ): """If async_bulk raises during the final flush, finalize() still releases the client and wraps the original error in a RuntimeError that names the unflushed buffer counts. 
""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch.object( ClientManager, "release_client", new_callable=AsyncMock ) as mock_release: with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock, ) as mock_bulk: mock_bulk.side_effect = OpenSearchException("connection reset") s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "stuck"}}) with pytest.raises(RuntimeError) as exc_info: await s.finalize() assert isinstance(exc_info.value.__cause__, OpenSearchException) mock_release.assert_awaited_once() assert s.client is None @pytest.mark.asyncio async def test_vector_finalize_propagates_cancellation( self, global_config, embed_func, mock_client ): """asyncio.CancelledError raised during the final flush must propagate UN-wrapped so the shutdown sequence honours the cancellation signal. The client is still released (finally block) before the cancellation continues. """ with patch.object(ClientManager, "get_client", return_value=mock_client): with patch.object( ClientManager, "release_client", new_callable=AsyncMock ) as mock_release: with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock, ) as mock_bulk: mock_bulk.side_effect = asyncio.CancelledError() s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "stuck"}}) with pytest.raises(asyncio.CancelledError): await s.finalize() mock_release.assert_awaited_once() assert s.client is None @pytest.mark.asyncio async def test_drop_discards_pending_buffers( self, global_config, embed_func, mock_client ): """drop() throws away pending writes; nothing is flushed to a deleted index.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": 
"doomed"}}) await s.delete(["v2"]) await s.drop() assert s._pending_vector_docs == {} assert s._pending_vector_deletes == set() mock_bulk.assert_not_awaited() @pytest.mark.asyncio async def test_failed_flush_entries_retained_for_retry( self, global_config, embed_func, mock_client ): """Transient (5xx) per-doc failures stay buffered for the next flush.""" embed_func = CountingEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: # First flush: v1 succeeds, v2 fails with 503 (retryable). mock_bulk.side_effect = [ ( 1, [{"index": {"_id": "v2", "status": 503, "error": "down"}}], ), (1, []), ] s = self._make(global_config, embed_func) await s.initialize() await s.upsert( { "v1": {"content": "ok"}, "v2": {"content": "boom"}, } ) await s.index_done_callback() # v1 cleared, v2 retained for retry. assert "v1" not in s._pending_vector_docs assert "v2" in s._pending_vector_docs assert s._pending_vector_docs["v2"].vector is not None assert embed_func.call_count == 1 await s.index_done_callback() assert "v2" not in s._pending_vector_docs assert embed_func.call_count == 1 assert mock_bulk.await_count == 2 @pytest.mark.asyncio async def test_embedding_failure_leaves_pending_for_retry( self, global_config, embed_func, mock_client ): """Embedding failures behave like flush failures: buffers stay intact.""" embed_func = CountingEmbeddingFunc(fail_times=1) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "retry me"}}) with pytest.raises(RuntimeError, match="embedding failed"): await s.index_done_callback() mock_bulk.assert_not_awaited() assert "v1" in s._pending_vector_docs assert 
s._pending_vector_docs["v1"].vector is None await s.index_done_callback() mock_bulk.assert_awaited_once() assert "v1" not in s._pending_vector_docs assert embed_func.call_count == 2 @pytest.mark.asyncio async def test_finalize_wraps_embedding_failure( self, global_config, embed_func, mock_client ): """finalize() reports pending buffers when deferred embedding fails.""" embed_func = CountingEmbeddingFunc(fail_times=1) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch.object( ClientManager, "release_client", new_callable=AsyncMock ) as mock_release: with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock, ) as mock_bulk: s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "stuck"}}) with pytest.raises(RuntimeError, match="pending upserts"): await s.finalize() mock_bulk.assert_not_awaited() mock_release.assert_awaited_once() assert s.client is None assert "v1" in s._pending_vector_docs assert s._pending_vector_docs["v1"].vector is None @pytest.mark.asyncio async def test_delete_entity_relation_prunes_pending_buffer( self, global_config, embed_func, mock_client ): """Pending docs whose src_id/tgt_id match the entity are dropped before delete_by_query.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert( { "rel-1": { "content": "Alice -> Bob", "src_id": "Alice", "tgt_id": "Bob", }, "rel-2": { "content": "Carol -> Dave", "src_id": "Carol", "tgt_id": "Dave", }, } ) await s.delete_entity_relation("Alice") assert "rel-1" not in s._pending_vector_docs assert "rel-2" in s._pending_vector_docs mock_client.delete_by_query.assert_awaited_once() def test_extract_bulk_failed_ids_classifies_by_status(self): from lightrag.kg.opensearch_impl 
import _extract_bulk_failed_ids # No failures -> empty containers. retryable, non_retryable = _extract_bulk_failed_ids(None) assert retryable == set() assert non_retryable == [] retryable, non_retryable = _extract_bulk_failed_ids([]) assert retryable == set() assert non_retryable == [] retryable, non_retryable = _extract_bulk_failed_ids( [ # Retryable: 5xx server error. {"index": {"_id": "r-500", "status": 500}}, # Retryable: rate-limited. {"index": {"_id": "r-429", "status": 429}}, # Retryable: missing status (network / parse failure). {"create": {"_id": "r-none"}}, # Non-retryable: bad request with dict-shape error. { "index": { "_id": "n-400", "status": 400, "error": { "type": "mapper_parsing_exception", "reason": "vector must be array", }, } }, # Non-retryable: not found on update (doc disappeared). {"update": {"_id": "n-404", "status": 404, "error": "not found"}}, # Special case: delete of missing doc -> dropped from BOTH # sets, since the row is already gone. {"delete": {"_id": "drop-404", "status": 404}}, # Malformed entries are skipped silently. 
"garbage", {"update": {}}, ] ) assert retryable == {"r-500", "r-429", "r-none"} non_retryable_ids = {op.doc_id for op in non_retryable} assert non_retryable_ids == {"n-400", "n-404"} by_id = {op.doc_id: op for op in non_retryable} # dict-shape error is summarised via "reason" assert by_id["n-400"].op == "index" assert by_id["n-400"].status == 400 assert "vector must be array" in by_id["n-400"].error # string-shape error is passed through assert by_id["n-404"].op == "update" assert by_id["n-404"].status == 404 assert by_id["n-404"].error == "not found" def test_extract_bulk_failed_ids_truncates_long_errors(self): from lightrag.kg.opensearch_impl import ( _extract_bulk_failed_ids, _BULK_ERROR_SUMMARY_MAX_LEN, ) long_reason = "x" * 1000 _, non_retryable = _extract_bulk_failed_ids( [ { "index": { "_id": "n-400", "status": 400, "error": {"reason": long_reason}, } } ] ) assert len(non_retryable) == 1 assert len(non_retryable[0].error) <= _BULK_ERROR_SUMMARY_MAX_LEN assert non_retryable[0].error.endswith("...") @pytest.mark.asyncio async def test_failed_flush_drops_non_retryable_entries( self, global_config, embed_func, mock_client ): """4xx (non-429) failures are dropped, not perpetually retried.""" with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: # v1 fails permanently (400 mapping error); v2 fails # transiently (503). mock_bulk.return_value = ( 0, [ {"index": {"_id": "v1", "status": 400, "error": "bad mapping"}}, {"index": {"_id": "v2", "status": 503, "error": "down"}}, ], ) s = self._make(global_config, embed_func) await s.initialize() await s.upsert( {"v1": {"content": "bad"}, "v2": {"content": "transient"}} ) await s.index_done_callback() # v1 is dropped (non-retryable), v2 is retained (retryable). 
assert "v1" not in s._pending_vector_docs assert "v2" in s._pending_vector_docs @pytest.mark.asyncio async def test_concurrent_writes_during_flush_are_serialised( self, global_config, embed_func, mock_client ): """All buffer writes acquire the namespace lock, so an upsert issued while a flush is in flight is blocked until the flush completes and then lands in the live buffer for the next flush. """ flush_started = asyncio.Event() flush_can_finish = asyncio.Event() async def slow_bulk(client, actions, raise_on_error=False, **kwargs): flush_started.set() await flush_can_finish.wait() return (len(actions), []) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "first"}}) flush_task = asyncio.create_task(s.index_done_callback()) await flush_started.wait() # The flush is holding the lock and awaiting async_bulk. # Issue a concurrent upsert via create_task so we can # assert it is blocked (a direct await would deadlock the # single-threaded event loop on the lock acquisition). concurrent_task = asyncio.create_task( s.upsert({"v2": {"content": "concurrent"}}) ) # Yield so the concurrent task gets a chance to start its # embedding computation and arrive at the lock. for _ in range(5): await asyncio.sleep(0) assert ( not concurrent_task.done() ), "concurrent upsert should be blocked by the flush lock" # v2 must not be visible in the buffer yet. assert "v2" not in s._pending_vector_docs # Release the bulk call; flush completes and the concurrent # upsert then finally writes v2 into the (now-empty) buffer. 
flush_can_finish.set() await flush_task await concurrent_task assert "v1" not in s._pending_vector_docs assert "v2" in s._pending_vector_docs @pytest.mark.asyncio async def test_concurrent_delete_during_flush_supersedes_retried_upsert( self, global_config, embed_func, mock_client ): """A delete that lands after a flush retains a transient failure wins over the retried upsert for the same id. Under the lock-everywhere model the delete runs strictly after the flush; the merge-back of the retryable v1 upsert is then cancelled by the delete in a single, sequential pass. """ flush_started = asyncio.Event() flush_can_finish = asyncio.Event() async def slow_bulk(client, actions, raise_on_error=False, **kwargs): flush_started.set() await flush_can_finish.wait() # Report v1's upsert as a transient failure so the flush # leaves it in the buffer for retry. return ( 0, [{"index": {"_id": "v1", "status": 503, "error": "down"}}], ) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "first"}}) flush_task = asyncio.create_task(s.index_done_callback()) await flush_started.wait() # Issue the concurrent delete; it queues behind the lock. delete_task = asyncio.create_task(s.delete(["v1"])) for _ in range(5): await asyncio.sleep(0) assert ( not delete_task.done() ), "concurrent delete should be blocked by the flush lock" flush_can_finish.set() await flush_task await delete_task # The retry left v1 in the docs buffer; the subsequent # delete then cancelled that upsert and replaced it with a # tombstone. 
assert "v1" not in s._pending_vector_docs assert "v1" in s._pending_vector_deletes @pytest.mark.asyncio async def test_get_by_id_strips_vector_from_mget_path( self, global_config, embed_func, mock_client ): """The mget fallback path returns the same shape as NanoVectorDB: no ``vector`` key, and the server-side _source_excludes is set so the embedding never crosses the wire in the first place. """ mock_client.mget = AsyncMock( return_value={ "docs": [ { "_id": "v1", "found": True, # defensive: server-side excludes might be ignored # in misconfigured indices; we still pop client-side. "_source": {"content": "from_index", "vector": [0.1] * 128}, } ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() # No upsert: buffer empty, falls through to mget. doc = await s.get_by_id("v1") assert doc is not None assert doc["id"] == "v1" assert doc["content"] == "from_index" assert "vector" not in doc mock_client.mget.assert_awaited_once_with( index=s._index_name, body={"ids": ["v1"]}, _source_excludes=["vector"], ) @pytest.mark.asyncio async def test_get_by_ids_strips_vector_from_mget_path( self, global_config, embed_func, mock_client ): """get_by_ids strips vector on the fallback path and forwards _source_excludes to mget.""" mock_client.mget = AsyncMock( return_value={ "docs": [ { "_id": "v1", "found": True, "_source": {"content": "a", "vector": [0.1] * 128}, }, { "_id": "v2", "found": True, "_source": {"content": "b", "vector": [0.2] * 128}, }, ] } ) with patch.object(ClientManager, "get_client", return_value=mock_client): s = self._make(global_config, embed_func) await s.initialize() docs = await s.get_by_ids(["v1", "v2"]) assert all(d is not None for d in docs) assert all("vector" not in d for d in docs) assert docs[0]["content"] == "a" assert docs[1]["content"] == "b" mock_client.mget.assert_awaited_once_with( index=s._index_name, body={"ids": ["v1", "v2"]}, 
_source_excludes=["vector"], ) @pytest.mark.asyncio async def test_non_retryable_logs_sample_ids( self, global_config, embed_func, mock_client, caplog ): """Non-retryable bulk failures log a sample with id/status/error.""" import logging as _logging failed = [ { "index": { "_id": f"v{i}", "status": 400, "error": { "type": "mapper_parsing_exception", "reason": f"bad field {i}", }, } } for i in range(6) ] # lightrag logger has propagate=False, so caplog's root handler # would miss these records. Re-enable propagation just for this # test so caplog can capture the warning we emit. lightrag_logger = _logging.getLogger("lightrag") original_propagate = lightrag_logger.propagate lightrag_logger.propagate = True try: with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock, ) as mock_bulk: mock_bulk.return_value = (0, failed) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({f"v{i}": {"content": f"d{i}"} for i in range(6)}) with caplog.at_level("WARNING", logger="lightrag"): await s.index_done_callback() finally: lightrag_logger.propagate = original_propagate warning_text = "\n".join( rec.message for rec in caplog.records if rec.levelname == "WARNING" ) # Sample contains the first 5 ids with op/status/reason text. for i in range(5): assert f"v{i}" in warning_text assert "status=400" in warning_text assert "bad field" in warning_text # 6 permanent failures reported in aggregate. assert "6 vector ops" in warning_text @pytest.mark.asyncio async def test_index_done_callback_flushes_when_index_recreated( self, global_config, embed_func, mock_client ): """If the index was marked missing after writes were buffered, the callback must still flush — _flush_pending_vector_ops recreates the index via _ensure_index_ready before issuing the bulk call. 
""" # Sequence the indices.exists results so the second _create # invocation actually creates the index again. exists_responses = [False, False] mock_client.indices.exists = AsyncMock( side_effect=lambda **kw: exists_responses.pop(0) if exists_responses else False ) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "ok"}}) # Simulate the index disappearing (e.g. via a read 404) # AFTER the write was buffered. s._mark_index_missing() await s.index_done_callback() # The buffer was flushed, even though _index_ready was # False at callback entry. mock_bulk.assert_awaited_once() assert s._pending_vector_docs == {} # The index was recreated as part of flush. assert mock_client.indices.create.await_count >= 2 @pytest.mark.asyncio async def test_delete_entity_relation_serialised_with_flush( self, global_config, embed_func, mock_client ): """delete_entity_relation runs entirely under the flush lock, so it cannot race with an in-flight bulk indexing operation.""" flush_started = asyncio.Event() flush_can_finish = asyncio.Event() delete_started = asyncio.Event() async def slow_bulk(client, actions, raise_on_error=False, **kwargs): flush_started.set() await flush_can_finish.wait() return (len(actions), []) async def watch_delete_by_query(**kwargs): delete_started.set() return {"deleted": 0} mock_client.delete_by_query = AsyncMock(side_effect=watch_delete_by_query) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk): s = self._make(global_config, embed_func) await s.initialize() await s.upsert( { "rel-1": { "content": "X", "src_id": "Alice", "tgt_id": "Bob", } } ) flush_task = asyncio.create_task(s.index_done_callback()) await 
flush_started.wait() # delete_by_query must NOT fire while bulk is still in flight. rel_task = asyncio.create_task(s.delete_entity_relation("Alice")) for _ in range(5): await asyncio.sleep(0) assert ( not delete_started.is_set() ), "delete_by_query should be blocked behind the flush lock" assert not rel_task.done() flush_can_finish.set() await flush_task await rel_task assert delete_started.is_set() @pytest.mark.asyncio async def test_drop_serialised_with_flush( self, global_config, embed_func, mock_client ): """drop must serialise with an in-flight flush; the index delete cannot land while bulk indexing is mid-request. """ flush_started = asyncio.Event() flush_can_finish = asyncio.Event() drop_delete_started = asyncio.Event() async def slow_bulk(client, actions, raise_on_error=False, **kwargs): flush_started.set() await flush_can_finish.wait() return (len(actions), []) async def watch_indices_delete(**kwargs): drop_delete_started.set() mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete) with patch.object(ClientManager, "get_client", return_value=mock_client): with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk): s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "x"}}) flush_task = asyncio.create_task(s.index_done_callback()) await flush_started.wait() drop_task = asyncio.create_task(s.drop()) for _ in range(5): await asyncio.sleep(0) assert ( not drop_delete_started.is_set() ), "indices.delete should be blocked behind the flush lock" assert not drop_task.done() flush_can_finish.set() await flush_task await drop_task assert drop_delete_started.is_set() @pytest.mark.asyncio async def test_drop_serialised_with_flush_embedding_phase( self, global_config, mock_client ): """drop must also wait while deferred embedding runs under the flush lock.""" embedding_started = asyncio.Event() embedding_can_finish = asyncio.Event() drop_delete_started = asyncio.Event() class 
GatedEmbeddingFunc(MockEmbeddingFunc): async def __call__(self, texts, **kwargs): embedding_started.set() await embedding_can_finish.wait() return await super().__call__(texts, **kwargs) async def watch_indices_delete(**kwargs): drop_delete_started.set() mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete) embed_func = GatedEmbeddingFunc() with patch.object(ClientManager, "get_client", return_value=mock_client): with patch( "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock ) as mock_bulk: mock_bulk.return_value = (1, []) s = self._make(global_config, embed_func) await s.initialize() await s.upsert({"v1": {"content": "x"}}) flush_task = asyncio.create_task(s.index_done_callback()) await embedding_started.wait() drop_task = asyncio.create_task(s.drop()) for _ in range(5): await asyncio.sleep(0) assert ( not drop_delete_started.is_set() ), "indices.delete should be blocked during deferred embedding" assert not drop_task.done() embedding_can_finish.set() await flush_task await drop_task assert drop_delete_started.is_set() # --------------------------------------------------------------------------- # Cosine score edge cases # --------------------------------------------------------------------------- class TestScoreThreshold: """Verify that raw OpenSearch scores are compared directly against threshold.""" def test_above_threshold(self): assert 0.85 >= 0.2 def test_below_threshold(self): assert 0.15 < 0.2 def test_exact_threshold(self): assert 0.2 >= 0.2 # --------------------------------------------------------------------------- # Why raising EMBEDDING_BATCH_NUM does not lower the embedding call count # --------------------------------------------------------------------------- class TestEmbeddingBatchNumDiagnosis: """Pin down why bumping EMBEDDING_BATCH_NUM leaves the embedding call count (get_embedding_queue_status -> submitted_total) unchanged for entities/relations. 
class TestEmbeddingBatchNumDiagnosis:
    """Explain why a larger EMBEDDING_BATCH_NUM can leave the embedding call
    count (get_embedding_queue_status -> submitted_total) unchanged for
    entities/relations.

    ``merge_nodes_and_edges`` upserts entities/relations ONE id at a time:
    ``_merge_nodes_then_upsert`` calls ``entity_vdb.upsert({single})`` and
    ``_merge_edges_then_upsert`` calls ``relationships_vdb.upsert({single})``
    (lightrag/operate.py). ``EMBEDDING_BATCH_NUM`` only slices the items
    *within one embedding pass* (``contents[i:i+batch]``). The call count is
    therefore driven by how many items reach a single embedding pass, not by
    the batch size -- a bigger batch size only pays off once >= 2 items are
    embedded together.
    """

    def _make(self, batch_num, embed_func, workspace="diag"):
        """Build an ``entities`` vector storage configured with ``batch_num``."""
        return OpenSearchVectorDBStorage(
            namespace="entities",
            global_config={
                "embedding_batch_num": batch_num,
                "vector_db_storage_cls_kwargs": {
                    "cosine_better_than_threshold": 0.2
                },
            },
            embedding_func=embed_func,
            workspace=workspace,
            meta_fields={"content", "entity_name"},
        )

    @staticmethod
    def _fake_bulk(_client, actions, *_args, **_kwargs):
        """Stand-in for ``async_bulk`` that reports every action as persisted.

        Mirrors the ``(success_count, failed_list)`` shape of
        ``async_bulk(raise_on_error=False)``; the failed list is always empty.
        """
        persisted = len(actions)
        return persisted, []

    @asynccontextmanager
    async def _initialized_storage(self, batch_num, embed_func):
        """Yield an initialized storage while the client and bulk helper are patched.

        The patches stay active for the whole ``async with`` body and are
        undone when it exits.
        """
        client = _make_client()
        with patch.object(
            ClientManager, "get_client", return_value=client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as bulk:
            bulk.side_effect = self._fake_bulk
            storage = self._make(batch_num, embed_func)
            await storage.initialize()
            yield storage

    async def _run_per_item(self, batch_num, *, flush_each, n=100):
        """Upsert ``n`` entities one id per call, like the merge path does.

        flush_each=True  -> ``index_done_callback`` runs after every
                            single-item upsert, so each embedding pass sees
                            exactly 1 item (the pre-defer / eager behaviour
                            where ``upsert`` embeds inline).
        flush_each=False -> all single-item upserts are buffered and flushed
                            once at the end (the deferred-embedding design on
                            this branch).

        Returns the ``CountingEmbeddingFunc`` that recorded the embedding calls.
        """
        recorder = CountingEmbeddingFunc()
        async with self._initialized_storage(batch_num, recorder) as storage:
            for idx in range(n):
                record = {"content": f"entity {idx}", "entity_name": f"E{idx}"}
                await storage.upsert({f"ent-{idx}": record})
                if flush_each:
                    await storage.index_done_callback()
            if not flush_each:
                await storage.index_done_callback()
        return recorder

    @pytest.mark.asyncio
    async def test_per_item_embedding_makes_batch_num_a_noop(self):
        """Eager pattern: one embedding pass per single-item upsert.

        Reproduces the billing observation -- each embedding call carries
        exactly ONE item (~one entity's tokens) -- and moving
        EMBEDDING_BATCH_NUM from 16 to 32 makes no difference.
        """
        with_16 = await self._run_per_item(16, flush_each=True)
        with_32 = await self._run_per_item(32, flush_each=True)

        assert with_16.call_count == 100
        assert with_32.call_count == 100
        # Whatever the batch size, every pass carried a single item.
        for recorder in (with_16, with_32):
            sizes = [len(batch) for batch in recorder.batches]
            assert sizes == [1] * len(sizes)
        # The crux: the larger batch size did not lower the call count.
        assert with_16.call_count == with_32.call_count

    @pytest.mark.asyncio
    async def test_deferred_flush_makes_batch_num_effective(self):
        """Deferred pattern: buffer every single-item upsert, flush once.

        Here EMBEDDING_BATCH_NUM does govern the count:
        ceil(100/16)=7 vs ceil(100/32)=4.
        """
        with_16 = await self._run_per_item(16, flush_each=False)
        with_32 = await self._run_per_item(32, flush_each=False)

        expected_16 = math.ceil(100 / 16)
        expected_32 = math.ceil(100 / 32)
        assert expected_16 == 7 and with_16.call_count == expected_16
        assert expected_32 == 4 and with_32.call_count == expected_32
        assert with_16.call_count != with_32.call_count
        # No flushed batch exceeds its configured cap, and no text is dropped.
        assert max(map(len, with_16.batches), default=0) <= 16
        assert max(map(len, with_32.batches), default=0) <= 32
        assert len(with_16.texts) == 100
        assert len(with_32.texts) == 100

    @pytest.mark.asyncio
    async def test_single_multiitem_upsert_is_batched_like_chunks_vdb(self):
        """Contrast: chunks_vdb upserts a whole document's chunks in ONE call.

        When many items arrive in a single upsert/embedding pass,
        EMBEDDING_BATCH_NUM behaves as expected even with an immediate flush --
        showing that items-per-embedding-pass is the deciding factor, not the
        storage backend. That is why batch_num visibly affects chunks but not
        per-id entity/relation upserts.
        """
        recorders = {16: CountingEmbeddingFunc(), 32: CountingEmbeddingFunc()}

        for batch_num, recorder in recorders.items():
            async with self._initialized_storage(batch_num, recorder) as storage:
                # chunks_vdb.upsert(chunks): a single call carrying 100 items.
                chunks = {
                    f"chunk-{idx}": {"content": f"chunk {idx}", "entity_name": ""}
                    for idx in range(100)
                }
                await storage.upsert(chunks)
                await storage.index_done_callback()

        with_16, with_32 = recorders[16], recorders[32]
        expected_16 = math.ceil(100 / 16)
        expected_32 = math.ceil(100 / 32)
        assert expected_16 == 7 and with_16.call_count == expected_16
        assert expected_32 == 4 and with_32.call_count == expected_32
        assert with_16.call_count != with_32.call_count