test_asymmetric_embedding.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
"""Unit tests for the task-aware (asymmetric) embedding feature.

Covers:

* ``wrap_embedding_func_with_attrs`` auto-detects ``supports_asymmetric``
  from the wrapped function's signature, so users cannot silently disable
  the feature by forgetting the flag.
* ``EmbeddingFunc.__call__`` strips the ``context`` kwarg when the wrapped
  function does not declare ``supports_asymmetric=True`` (legacy back-compat).
* ``lightrag.api.config.resolve_asymmetric_embedding_opt_in`` and
  ``get_embedding_prefix_config`` resolve the API-level opt-in toggle and the
  query/document prefix settings (including the no-prefix sentinel).
* ``jina_embed`` selects the right ``task`` from ``context`` when the caller
  leaves the new ``task=None`` default in place.
* ``gemini_embed`` selects the right ``task_type`` from ``context``.
* ``voyageai_embed`` selects the right ``input_type`` from ``context``.

All tests are fully mocked; no live API calls.
"""
  14. from __future__ import annotations
  15. import base64
  16. from unittest.mock import MagicMock, patch
  17. import numpy as np
  18. import pytest
  19. from lightrag.api import config as api_config
  20. # ---------------------------------------------------------------------------
  21. # wrap_embedding_func_with_attrs auto-detection
  22. # ---------------------------------------------------------------------------
  23. def test_wrap_auto_detects_supports_asymmetric_when_context_present():
  24. """If the wrapped function takes ``context``, supports_asymmetric should be True."""
  25. from lightrag.utils import wrap_embedding_func_with_attrs
  26. @wrap_embedding_func_with_attrs(embedding_dim=4, max_token_size=64)
  27. async def my_embed(texts, context="document"):
  28. return np.zeros((len(texts), 4), dtype=np.float32)
  29. assert my_embed.supports_asymmetric is True
  30. def test_wrap_auto_detects_no_supports_asymmetric_for_legacy_func():
  31. """Legacy embed without ``context`` should default to supports_asymmetric=False."""
  32. from lightrag.utils import wrap_embedding_func_with_attrs
  33. @wrap_embedding_func_with_attrs(embedding_dim=4, max_token_size=64)
  34. async def legacy_embed(texts):
  35. return np.zeros((len(texts), 4), dtype=np.float32)
  36. assert legacy_embed.supports_asymmetric is False
  37. def test_wrap_explicit_supports_asymmetric_overrides_auto_detect():
  38. """Explicit kwarg must win over signature inspection."""
  39. from lightrag.utils import wrap_embedding_func_with_attrs
  40. @wrap_embedding_func_with_attrs(
  41. embedding_dim=4, max_token_size=64, supports_asymmetric=False
  42. )
  43. async def my_embed(texts, context="document"):
  44. return np.zeros((len(texts), 4), dtype=np.float32)
  45. assert my_embed.supports_asymmetric is False
  46. def test_wrap_auto_detects_per_function_when_decorator_reused():
  47. """Reusing a decorator must not share auto-detected support between functions."""
  48. from lightrag.utils import wrap_embedding_func_with_attrs
  49. decorator = wrap_embedding_func_with_attrs(embedding_dim=4, max_token_size=64)
  50. @decorator
  51. async def legacy_embed(texts):
  52. return np.zeros((len(texts), 4), dtype=np.float32)
  53. @decorator
  54. async def aware_embed(texts, context="document"):
  55. return np.zeros((len(texts), 4), dtype=np.float32)
  56. assert legacy_embed.supports_asymmetric is False
  57. assert aware_embed.supports_asymmetric is True
  58. # ---------------------------------------------------------------------------
  59. # EmbeddingFunc.__call__ strips context for legacy embeds
  60. # ---------------------------------------------------------------------------
  61. @pytest.mark.asyncio
  62. async def test_embedding_func_strips_context_for_legacy_func():
  63. """Legacy func that doesn't accept ``context`` must not see it (no TypeError)."""
  64. from lightrag.utils import EmbeddingFunc
  65. received_kwargs: list[dict] = []
  66. async def legacy_embed(texts):
  67. # If `context` were still in kwargs we'd never get here -- the call
  68. # would raise TypeError. So just record what we did receive.
  69. received_kwargs.append({"texts": texts})
  70. return np.zeros((len(texts), 4), dtype=np.float32)
  71. func = EmbeddingFunc(
  72. embedding_dim=4, max_token_size=64, supports_asymmetric=False, func=legacy_embed
  73. )
  74. out = await func(["a", "b"], context="query")
  75. assert out.shape == (2, 4)
  76. assert received_kwargs[0] == {"texts": ["a", "b"]}
  77. @pytest.mark.asyncio
  78. async def test_embedding_func_forwards_context_when_supported():
  79. from lightrag.utils import EmbeddingFunc
  80. received: list[str] = []
  81. async def aware_embed(texts, context="document"):
  82. received.append(context)
  83. return np.zeros((len(texts), 4), dtype=np.float32)
  84. func = EmbeddingFunc(
  85. embedding_dim=4, max_token_size=64, supports_asymmetric=True, func=aware_embed
  86. )
  87. await func(["a"], context="query")
  88. await func(["b"], context="document")
  89. assert received == ["query", "document"]
  90. # ---------------------------------------------------------------------------
  91. # API asymmetric opt-in resolution
  92. # ---------------------------------------------------------------------------
  93. def test_asymmetric_opt_in_is_off_when_toggle_is_unset_even_with_prefixes():
  94. assert (
  95. api_config.resolve_asymmetric_embedding_opt_in(
  96. binding="ollama",
  97. embedding_asymmetric=False,
  98. embedding_asymmetric_configured=False,
  99. query_prefix="search_query: ",
  100. query_prefix_configured=True,
  101. document_prefix=None,
  102. document_prefix_configured=False,
  103. )
  104. is False
  105. )
  106. def test_asymmetric_opt_in_explicit_false_disables_even_with_prefixes():
  107. assert (
  108. api_config.resolve_asymmetric_embedding_opt_in(
  109. binding="ollama",
  110. embedding_asymmetric=False,
  111. embedding_asymmetric_configured=True,
  112. query_prefix="search_query: ",
  113. query_prefix_configured=True,
  114. document_prefix=None,
  115. document_prefix_configured=False,
  116. )
  117. is False
  118. )
  119. @pytest.mark.parametrize("binding", ["jina", "gemini", "voyageai"])
  120. def test_asymmetric_opt_in_explicit_true_allows_provider_level_bindings(binding):
  121. assert (
  122. api_config.resolve_asymmetric_embedding_opt_in(
  123. binding=binding,
  124. embedding_asymmetric=True,
  125. embedding_asymmetric_configured=True,
  126. query_prefix=None,
  127. query_prefix_configured=False,
  128. document_prefix=None,
  129. document_prefix_configured=False,
  130. )
  131. is True
  132. )
  133. def test_asymmetric_opt_in_explicit_true_ignores_provider_prefixes():
  134. assert (
  135. api_config.resolve_asymmetric_embedding_opt_in(
  136. binding="jina",
  137. embedding_asymmetric=True,
  138. embedding_asymmetric_configured=True,
  139. query_prefix="search_query: ",
  140. query_prefix_configured=True,
  141. document_prefix=None,
  142. document_prefix_configured=False,
  143. )
  144. is True
  145. )
  146. def test_asymmetric_opt_in_explicit_true_requires_both_prefix_settings():
  147. with pytest.raises(ValueError, match="requires both"):
  148. api_config.resolve_asymmetric_embedding_opt_in(
  149. binding="ollama",
  150. embedding_asymmetric=True,
  151. embedding_asymmetric_configured=True,
  152. query_prefix="search_query: ",
  153. query_prefix_configured=True,
  154. document_prefix=None,
  155. document_prefix_configured=False,
  156. )
  157. def test_asymmetric_opt_in_explicit_true_accepts_no_prefix_sentinel_side():
  158. assert (
  159. api_config.resolve_asymmetric_embedding_opt_in(
  160. binding="ollama",
  161. embedding_asymmetric=True,
  162. embedding_asymmetric_configured=True,
  163. query_prefix="search_query: ",
  164. query_prefix_configured=True,
  165. document_prefix="",
  166. document_prefix_configured=True,
  167. )
  168. is True
  169. )
  170. def test_asymmetric_opt_in_explicit_true_rejects_both_sides_no_prefix():
  171. with pytest.raises(ValueError, match="At least one"):
  172. api_config.resolve_asymmetric_embedding_opt_in(
  173. binding="ollama",
  174. embedding_asymmetric=True,
  175. embedding_asymmetric_configured=True,
  176. query_prefix="",
  177. query_prefix_configured=True,
  178. document_prefix="",
  179. document_prefix_configured=True,
  180. )
  181. def test_get_embedding_prefix_config_uses_no_prefix_sentinel(monkeypatch):
  182. monkeypatch.setenv("EMBEDDING_DOCUMENT_PREFIX", api_config.NO_PREFIX_SENTINEL)
  183. assert api_config.get_embedding_prefix_config("EMBEDDING_DOCUMENT_PREFIX") == (
  184. "",
  185. True,
  186. )
  187. def test_get_embedding_prefix_config_rejects_empty_env_value(monkeypatch):
  188. monkeypatch.setenv("EMBEDDING_DOCUMENT_PREFIX", "")
  189. with pytest.raises(ValueError, match=api_config.NO_PREFIX_SENTINEL):
  190. api_config.get_embedding_prefix_config("EMBEDDING_DOCUMENT_PREFIX")
  191. # ---------------------------------------------------------------------------
  192. # jina_embed: task auto-selection from context
  193. # ---------------------------------------------------------------------------
  194. def _fake_jina_response(num: int, dim: int = 4) -> list[dict]:
  195. arr = np.zeros((num, dim), dtype=np.float32)
  196. return [
  197. {"embedding": base64.b64encode(arr[i].tobytes()).decode()} for i in range(num)
  198. ]
  199. @pytest.mark.asyncio
  200. async def test_jina_default_task_is_query_when_context_query(monkeypatch):
  201. """Default ``task=None`` + ``context='query'`` must produce ``retrieval.query``."""
  202. monkeypatch.setenv("JINA_API_KEY", "fake")
  203. from lightrag.llm import jina as jina_mod
  204. captured: list[dict] = []
  205. async def fake_fetch(url, headers, data):
  206. captured.append(data)
  207. return _fake_jina_response(len(data["input"]))
  208. with patch.object(jina_mod, "fetch_data", side_effect=fake_fetch):
  209. await jina_mod.jina_embed.func(texts=["q1"], context="query")
  210. assert captured[0]["task"] == "retrieval.query"
  211. @pytest.mark.asyncio
  212. async def test_jina_default_task_is_passage_when_context_document(monkeypatch):
  213. monkeypatch.setenv("JINA_API_KEY", "fake")
  214. from lightrag.llm import jina as jina_mod
  215. captured: list[dict] = []
  216. async def fake_fetch(url, headers, data):
  217. captured.append(data)
  218. return _fake_jina_response(len(data["input"]))
  219. with patch.object(jina_mod, "fetch_data", side_effect=fake_fetch):
  220. await jina_mod.jina_embed.func(texts=["d1", "d2"], context="document")
  221. assert captured[0]["task"] == "retrieval.passage"
  222. @pytest.mark.asyncio
  223. async def test_jina_explicit_task_overrides_context(monkeypatch):
  224. monkeypatch.setenv("JINA_API_KEY", "fake")
  225. from lightrag.llm import jina as jina_mod
  226. captured: list[dict] = []
  227. async def fake_fetch(url, headers, data):
  228. captured.append(data)
  229. return _fake_jina_response(len(data["input"]))
  230. with patch.object(jina_mod, "fetch_data", side_effect=fake_fetch):
  231. await jina_mod.jina_embed.func(
  232. texts=["x"], context="query", task="text-matching"
  233. )
  234. assert captured[0]["task"] == "text-matching"
  235. # ---------------------------------------------------------------------------
  236. # gemini_embed: task_type auto-selection from context
  237. # ---------------------------------------------------------------------------
  238. @pytest.fixture
  239. def gemini_client_cache_cleared():
  240. """gemini.py caches its Client via lru_cache; clear it between tests."""
  241. pytest.importorskip("google.genai")
  242. from lightrag.llm import gemini as gemini_mod
  243. gemini_mod._get_gemini_client.cache_clear()
  244. yield
  245. gemini_mod._get_gemini_client.cache_clear()
  246. @pytest.mark.asyncio
  247. async def test_gemini_task_type_query_for_query_context(gemini_client_cache_cleared):
  248. pytest.importorskip("google.genai")
  249. from lightrag.llm import gemini as gemini_mod
  250. captured: list[dict] = []
  251. async def fake_embed_content(*, model, contents, config):
  252. captured.append({"task_type": getattr(config, "task_type", None)})
  253. resp = MagicMock()
  254. resp.embeddings = [MagicMock(values=[0.1] * 4) for _ in contents]
  255. return resp
  256. fake_client = MagicMock()
  257. fake_client.aio.models.embed_content = fake_embed_content
  258. with patch.object(gemini_mod.genai, "Client", return_value=fake_client):
  259. await gemini_mod.gemini_embed.func(
  260. texts=["q"], api_key="fake", context="query", task_type=None
  261. )
  262. assert captured[0]["task_type"] == "RETRIEVAL_QUERY"
  263. @pytest.mark.asyncio
  264. async def test_gemini_task_type_document_for_document_context(
  265. gemini_client_cache_cleared,
  266. ):
  267. pytest.importorskip("google.genai")
  268. from lightrag.llm import gemini as gemini_mod
  269. captured: list[dict] = []
  270. async def fake_embed_content(*, model, contents, config):
  271. captured.append({"task_type": getattr(config, "task_type", None)})
  272. resp = MagicMock()
  273. resp.embeddings = [MagicMock(values=[0.1] * 4) for _ in contents]
  274. return resp
  275. fake_client = MagicMock()
  276. fake_client.aio.models.embed_content = fake_embed_content
  277. with patch.object(gemini_mod.genai, "Client", return_value=fake_client):
  278. await gemini_mod.gemini_embed.func(
  279. texts=["d"], api_key="fake", context="document", task_type=None
  280. )
  281. assert captured[0]["task_type"] == "RETRIEVAL_DOCUMENT"
  282. @pytest.mark.asyncio
  283. async def test_gemini_explicit_task_type_overrides_context(gemini_client_cache_cleared):
  284. pytest.importorskip("google.genai")
  285. from lightrag.llm import gemini as gemini_mod
  286. captured: list[dict] = []
  287. async def fake_embed_content(*, model, contents, config):
  288. captured.append({"task_type": getattr(config, "task_type", None)})
  289. resp = MagicMock()
  290. resp.embeddings = [MagicMock(values=[0.1] * 4) for _ in contents]
  291. return resp
  292. fake_client = MagicMock()
  293. fake_client.aio.models.embed_content = fake_embed_content
  294. with patch.object(gemini_mod.genai, "Client", return_value=fake_client):
  295. await gemini_mod.gemini_embed.func(
  296. texts=["x"],
  297. api_key="fake",
  298. context="query",
  299. task_type="CLASSIFICATION",
  300. )
  301. assert captured[0]["task_type"] == "CLASSIFICATION"