test_batch_embeddings.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """
  2. Tests for batch embedding pre-computation in _perform_kg_search().
  3. Verifies that kg_query batches all needed embeddings (query, ll_keywords,
  4. hl_keywords) into a single embedding API call instead of 3 sequential calls.
  5. """
  6. from unittest.mock import AsyncMock, MagicMock
  7. import numpy as np
  8. import pytest
  9. from lightrag.base import QueryParam
  10. def _make_mock_embedding_func(dim=1536):
  11. """Create a mock async embedding function that returns distinct vectors per input."""
  12. async def _embed(texts, **kwargs):
  13. return np.array(
  14. [np.full(dim, i + 1, dtype=np.float32) for i in range(len(texts))]
  15. )
  16. mock = AsyncMock(side_effect=_embed)
  17. return mock
  18. def _make_mock_kv_storage(embedding_func, global_config=None):
  19. mock = MagicMock()
  20. mock.embedding_func = embedding_func
  21. mock.global_config = global_config or {"kg_chunk_pick_method": "VECTOR"}
  22. return mock
  23. def _make_mock_vdb():
  24. """Create a mock VDB whose query() records the query_embedding it receives."""
  25. mock = AsyncMock()
  26. mock.query = AsyncMock(return_value=[])
  27. mock.cosine_better_than_threshold = 0.2
  28. return mock
  29. def _make_mock_graph():
  30. mock = AsyncMock()
  31. return mock
  32. @pytest.mark.offline
  33. @pytest.mark.asyncio
  34. async def test_hybrid_mode_batches_embeddings():
  35. """In hybrid mode with both keywords, embedding_func should be called exactly once."""
  36. from lightrag.operate import _perform_kg_search
  37. embed_func = _make_mock_embedding_func()
  38. text_chunks_db = _make_mock_kv_storage(embed_func)
  39. entities_vdb = _make_mock_vdb()
  40. relationships_vdb = _make_mock_vdb()
  41. knowledge_graph = _make_mock_graph()
  42. query_param = QueryParam(mode="hybrid", top_k=5)
  43. await _perform_kg_search(
  44. query="test query",
  45. ll_keywords="entity1, entity2",
  46. hl_keywords="theme1, theme2",
  47. knowledge_graph_inst=knowledge_graph,
  48. entities_vdb=entities_vdb,
  49. relationships_vdb=relationships_vdb,
  50. text_chunks_db=text_chunks_db,
  51. query_param=query_param,
  52. )
  53. # The embedding function should be called exactly once with all 3 texts batched
  54. assert (
  55. embed_func.call_count == 1
  56. ), f"Expected 1 batched embedding call, got {embed_func.call_count}"
  57. call_args = embed_func.call_args[0][0]
  58. assert len(call_args) == 3, f"Expected 3 texts in batch, got {len(call_args)}"
  59. assert call_args == ["test query", "entity1, entity2", "theme1, theme2"]
  60. @pytest.mark.offline
  61. @pytest.mark.asyncio
  62. async def test_hybrid_mode_passes_embeddings_to_vdbs():
  63. """Pre-computed embeddings should be forwarded to entities and relationships VDB queries."""
  64. from lightrag.operate import _perform_kg_search
  65. embed_func = _make_mock_embedding_func()
  66. text_chunks_db = _make_mock_kv_storage(embed_func)
  67. entities_vdb = _make_mock_vdb()
  68. relationships_vdb = _make_mock_vdb()
  69. knowledge_graph = _make_mock_graph()
  70. query_param = QueryParam(mode="hybrid", top_k=5)
  71. await _perform_kg_search(
  72. query="test query",
  73. ll_keywords="entity keywords",
  74. hl_keywords="theme keywords",
  75. knowledge_graph_inst=knowledge_graph,
  76. entities_vdb=entities_vdb,
  77. relationships_vdb=relationships_vdb,
  78. text_chunks_db=text_chunks_db,
  79. query_param=query_param,
  80. )
  81. # entities_vdb.query should receive ll_embedding (index 1 → all 2s)
  82. entities_call = entities_vdb.query.call_args
  83. assert entities_call is not None, "entities_vdb.query was not called"
  84. ll_embedding = entities_call.kwargs.get("query_embedding")
  85. assert ll_embedding is not None, "ll_embedding was not passed to entities_vdb.query"
  86. assert np.all(
  87. ll_embedding == 2.0
  88. ), f"Expected ll_embedding=[2,2,...], got {ll_embedding[:3]}"
  89. # relationships_vdb.query should receive hl_embedding (index 2 → all 3s)
  90. rel_call = relationships_vdb.query.call_args
  91. assert rel_call is not None, "relationships_vdb.query was not called"
  92. hl_embedding = rel_call.kwargs.get("query_embedding")
  93. assert (
  94. hl_embedding is not None
  95. ), "hl_embedding was not passed to relationships_vdb.query"
  96. assert np.all(
  97. hl_embedding == 3.0
  98. ), f"Expected hl_embedding=[3,3,...], got {hl_embedding[:3]}"
  99. @pytest.mark.offline
  100. @pytest.mark.asyncio
  101. async def test_local_mode_skips_hl_keywords():
  102. """In local mode, should only embed query + ll_keywords (skip hl_keywords)."""
  103. from lightrag.operate import _perform_kg_search
  104. embed_func = _make_mock_embedding_func()
  105. text_chunks_db = _make_mock_kv_storage(embed_func)
  106. entities_vdb = _make_mock_vdb()
  107. relationships_vdb = _make_mock_vdb()
  108. knowledge_graph = _make_mock_graph()
  109. query_param = QueryParam(mode="local", top_k=5)
  110. await _perform_kg_search(
  111. query="test query",
  112. ll_keywords="entity keywords",
  113. hl_keywords="theme keywords",
  114. knowledge_graph_inst=knowledge_graph,
  115. entities_vdb=entities_vdb,
  116. relationships_vdb=relationships_vdb,
  117. text_chunks_db=text_chunks_db,
  118. query_param=query_param,
  119. )
  120. assert embed_func.call_count == 1
  121. call_args = embed_func.call_args[0][0]
  122. assert len(call_args) == 2, f"Expected 2 texts (query + ll), got {len(call_args)}"
  123. assert "theme keywords" not in call_args
  124. @pytest.mark.offline
  125. @pytest.mark.asyncio
  126. async def test_global_mode_skips_ll_keywords():
  127. """In global mode, should only embed query + hl_keywords (skip ll_keywords)."""
  128. from lightrag.operate import _perform_kg_search
  129. embed_func = _make_mock_embedding_func()
  130. text_chunks_db = _make_mock_kv_storage(embed_func)
  131. entities_vdb = _make_mock_vdb()
  132. relationships_vdb = _make_mock_vdb()
  133. knowledge_graph = _make_mock_graph()
  134. query_param = QueryParam(mode="global", top_k=5)
  135. await _perform_kg_search(
  136. query="test query",
  137. ll_keywords="entity keywords",
  138. hl_keywords="theme keywords",
  139. knowledge_graph_inst=knowledge_graph,
  140. entities_vdb=entities_vdb,
  141. relationships_vdb=relationships_vdb,
  142. text_chunks_db=text_chunks_db,
  143. query_param=query_param,
  144. )
  145. assert embed_func.call_count == 1
  146. call_args = embed_func.call_args[0][0]
  147. assert len(call_args) == 2, f"Expected 2 texts (query + hl), got {len(call_args)}"
  148. assert "entity keywords" not in call_args
  149. @pytest.mark.offline
  150. @pytest.mark.asyncio
  151. async def test_embedding_failure_falls_back_gracefully():
  152. """If batch embedding fails, VDB queries should still work (fallback to individual calls)."""
  153. from lightrag.operate import _perform_kg_search
  154. embed_func = AsyncMock(side_effect=RuntimeError("API error"))
  155. text_chunks_db = _make_mock_kv_storage(embed_func)
  156. entities_vdb = _make_mock_vdb()
  157. relationships_vdb = _make_mock_vdb()
  158. knowledge_graph = _make_mock_graph()
  159. query_param = QueryParam(mode="hybrid", top_k=5)
  160. # Should not raise — graceful degradation
  161. await _perform_kg_search(
  162. query="test query",
  163. ll_keywords="entity keywords",
  164. hl_keywords="theme keywords",
  165. knowledge_graph_inst=knowledge_graph,
  166. entities_vdb=entities_vdb,
  167. relationships_vdb=relationships_vdb,
  168. text_chunks_db=text_chunks_db,
  169. query_param=query_param,
  170. )
  171. # VDB queries should still be called (with query_embedding=None fallback)
  172. entities_call = entities_vdb.query.call_args
  173. assert entities_call is not None
  174. assert entities_call.kwargs.get("query_embedding") is None
  175. rel_call = relationships_vdb.query.call_args
  176. assert rel_call is not None
  177. assert rel_call.kwargs.get("query_embedding") is None