| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- """
- Tests for batch embedding pre-computation in _perform_kg_search().
- Verifies that kg_query batches all needed embeddings (query, ll_keywords,
- hl_keywords) into a single embedding API call instead of 3 sequential calls.
- """
- from unittest.mock import AsyncMock, MagicMock
- import numpy as np
- import pytest
- from lightrag.base import QueryParam
def _make_mock_embedding_func(dim=1536):
    """Create a mock async embedding function with one distinct vector per input.

    Args:
        dim: Length of each returned embedding vector.

    Returns:
        An ``AsyncMock`` whose awaited result is a ``float32`` array of shape
        ``(len(texts), dim)``. Row ``i`` is filled with the value ``i + 1``, so
        a test can tell which batch position an embedding came from.
    """

    async def _embed(texts, **kwargs):
        # Build the column 1..n and repeat it across `dim` columns. Unlike
        # stacking a Python list of rows, this keeps the 2-D float32 layout
        # even when the batch is empty.
        row_values = np.arange(1, len(texts) + 1, dtype=np.float32)
        return np.repeat(row_values[:, np.newaxis], dim, axis=1)

    return AsyncMock(side_effect=_embed)
def _make_mock_kv_storage(embedding_func, global_config=None):
    """Create a mock KV storage exposing ``embedding_func`` and ``global_config``.

    Args:
        embedding_func: Object stored on the mock's ``embedding_func`` attribute.
        global_config: Mapping stored on the mock's ``global_config`` attribute.
            When omitted (``None``), it defaults to
            ``{"kg_chunk_pick_method": "VECTOR"}``.

    Returns:
        A ``MagicMock`` carrying the two attributes above.
    """
    # Compare against None rather than truthiness so that an explicitly
    # passed empty config is kept instead of silently replaced by the default.
    if global_config is None:
        global_config = {"kg_chunk_pick_method": "VECTOR"}
    storage = MagicMock()
    storage.embedding_func = embedding_func
    storage.global_config = global_config
    return storage
def _make_mock_vdb():
    """Build a vector-DB test double.

    The returned ``AsyncMock`` has a ``query`` coroutine mock that resolves to
    an empty list, and a ``cosine_better_than_threshold`` attribute of ``0.2``.
    Because ``query`` is a mock, the arguments of each call (including any
    ``query_embedding`` keyword) stay available through ``query.call_args``.
    """
    vdb = AsyncMock()
    vdb.cosine_better_than_threshold = 0.2
    vdb.query = AsyncMock(return_value=[])
    return vdb
def _make_mock_graph():
    """Return a bare ``AsyncMock`` used as the knowledge-graph test double."""
    return AsyncMock()
@pytest.mark.offline
@pytest.mark.asyncio
async def test_hybrid_mode_batches_embeddings():
    """Hybrid mode with both keyword sets must issue exactly one embedding call.

    The single call is expected to carry the query, the low-level keywords and
    the high-level keywords, in that order.
    """
    from lightrag.operate import _perform_kg_search

    embedder = _make_mock_embedding_func()
    search_kwargs = dict(
        query="test query",
        ll_keywords="entity1, entity2",
        hl_keywords="theme1, theme2",
        knowledge_graph_inst=_make_mock_graph(),
        entities_vdb=_make_mock_vdb(),
        relationships_vdb=_make_mock_vdb(),
        text_chunks_db=_make_mock_kv_storage(embedder),
        query_param=QueryParam(mode="hybrid", top_k=5),
    )

    await _perform_kg_search(**search_kwargs)

    # One call in total, and that call holds all three texts as a single batch.
    n_calls = embedder.call_count
    assert n_calls == 1, f"Expected 1 batched embedding call, got {n_calls}"

    batched_texts = embedder.call_args.args[0]
    n_texts = len(batched_texts)
    assert n_texts == 3, f"Expected 3 texts in batch, got {n_texts}"
    assert batched_texts == ["test query", "entity1, entity2", "theme1, theme2"]
@pytest.mark.offline
@pytest.mark.asyncio
async def test_hybrid_mode_passes_embeddings_to_vdbs():
    """The batched embeddings must reach the entity and relationship VDB queries.

    The mock embedder fills row ``i`` with ``i + 1``, so the low-level keyword
    vector (batch index 1) is all 2s and the high-level keyword vector
    (batch index 2) is all 3s.
    """
    from lightrag.operate import _perform_kg_search

    embedder = _make_mock_embedding_func()
    entity_store = _make_mock_vdb()
    relation_store = _make_mock_vdb()

    await _perform_kg_search(
        query="test query",
        ll_keywords="entity keywords",
        hl_keywords="theme keywords",
        knowledge_graph_inst=_make_mock_graph(),
        entities_vdb=entity_store,
        relationships_vdb=relation_store,
        text_chunks_db=_make_mock_kv_storage(embedder),
        query_param=QueryParam(mode="hybrid", top_k=5),
    )

    # Entity lookup: expects the low-level keyword embedding (all 2s).
    entity_query_call = entity_store.query.call_args
    assert entity_query_call is not None, "entities_vdb.query was not called"
    low_level_vec = entity_query_call.kwargs.get("query_embedding")
    assert (
        low_level_vec is not None
    ), "ll_embedding was not passed to entities_vdb.query"
    assert np.all(
        low_level_vec == 2.0
    ), f"Expected ll_embedding=[2,2,...], got {low_level_vec[:3]}"

    # Relationship lookup: expects the high-level keyword embedding (all 3s).
    relation_query_call = relation_store.query.call_args
    assert relation_query_call is not None, "relationships_vdb.query was not called"
    high_level_vec = relation_query_call.kwargs.get("query_embedding")
    assert (
        high_level_vec is not None
    ), "hl_embedding was not passed to relationships_vdb.query"
    assert np.all(
        high_level_vec == 3.0
    ), f"Expected hl_embedding=[3,3,...], got {high_level_vec[:3]}"
@pytest.mark.offline
@pytest.mark.asyncio
async def test_local_mode_skips_hl_keywords():
    """Local mode must embed only the query and ll_keywords, never hl_keywords."""
    from lightrag.operate import _perform_kg_search

    embedder = _make_mock_embedding_func()
    search_kwargs = dict(
        query="test query",
        ll_keywords="entity keywords",
        hl_keywords="theme keywords",
        knowledge_graph_inst=_make_mock_graph(),
        entities_vdb=_make_mock_vdb(),
        relationships_vdb=_make_mock_vdb(),
        text_chunks_db=_make_mock_kv_storage(embedder),
        query_param=QueryParam(mode="local", top_k=5),
    )

    await _perform_kg_search(**search_kwargs)

    assert embedder.call_count == 1

    # The single batch holds two texts and leaves out the high-level keywords.
    batched_texts = embedder.call_args.args[0]
    n_texts = len(batched_texts)
    assert n_texts == 2, f"Expected 2 texts (query + ll), got {n_texts}"
    assert "theme keywords" not in batched_texts
@pytest.mark.offline
@pytest.mark.asyncio
async def test_global_mode_skips_ll_keywords():
    """Global mode must embed only the query and hl_keywords, never ll_keywords."""
    from lightrag.operate import _perform_kg_search

    embedder = _make_mock_embedding_func()
    search_kwargs = dict(
        query="test query",
        ll_keywords="entity keywords",
        hl_keywords="theme keywords",
        knowledge_graph_inst=_make_mock_graph(),
        entities_vdb=_make_mock_vdb(),
        relationships_vdb=_make_mock_vdb(),
        text_chunks_db=_make_mock_kv_storage(embedder),
        query_param=QueryParam(mode="global", top_k=5),
    )

    await _perform_kg_search(**search_kwargs)

    assert embedder.call_count == 1

    # The single batch holds two texts and leaves out the low-level keywords.
    batched_texts = embedder.call_args.args[0]
    n_texts = len(batched_texts)
    assert n_texts == 2, f"Expected 2 texts (query + hl), got {n_texts}"
    assert "entity keywords" not in batched_texts
@pytest.mark.offline
@pytest.mark.asyncio
async def test_embedding_failure_falls_back_gracefully():
    """A failing batch embedding call must not abort the search.

    Both VDB queries are still expected to run, and neither should receive a
    pre-computed ``query_embedding``.
    """
    from lightrag.operate import _perform_kg_search

    # Every embedding attempt raises, simulating an unavailable embedding API.
    failing_embedder = AsyncMock(side_effect=RuntimeError("API error"))
    entity_store = _make_mock_vdb()
    relation_store = _make_mock_vdb()

    # The call itself must complete without propagating the RuntimeError.
    await _perform_kg_search(
        query="test query",
        ll_keywords="entity keywords",
        hl_keywords="theme keywords",
        knowledge_graph_inst=_make_mock_graph(),
        entities_vdb=entity_store,
        relationships_vdb=relation_store,
        text_chunks_db=_make_mock_kv_storage(failing_embedder),
        query_param=QueryParam(mode="hybrid", top_k=5),
    )

    # Each store was queried, and no pre-computed embedding was handed over.
    for store in (entity_store, relation_store):
        recorded_call = store.query.call_args
        assert recorded_call is not None
        assert recorded_call.kwargs.get("query_embedding") is None
|