| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
"""Unit tests for the task-aware (asymmetric) embedding feature.

Covers:

* ``wrap_embedding_func_with_attrs`` auto-detects ``supports_asymmetric``
  from the wrapped function's signature, so users cannot silently disable
  the feature by forgetting the flag.
* ``EmbeddingFunc.__call__`` strips the ``context`` kwarg when the wrapped
  function does not declare ``supports_asymmetric=True`` (legacy back-compat).
* ``api_config.resolve_asymmetric_embedding_opt_in`` and
  ``api_config.get_embedding_prefix_config`` resolve the API-level
  asymmetric opt-in and the query/document prefix settings.
* ``jina_embed`` selects the right ``task`` from ``context`` when the caller
  leaves the new ``task=None`` default in place.
* ``gemini_embed`` selects the right ``task_type`` from ``context``.
* ``voyageai_embed`` selects the right ``input_type`` from ``context``.

All tests are fully mocked; no live API calls.
"""
- from __future__ import annotations
- import base64
- from unittest.mock import MagicMock, patch
- import numpy as np
- import pytest
- from lightrag.api import config as api_config
- # ---------------------------------------------------------------------------
- # wrap_embedding_func_with_attrs auto-detection
- # ---------------------------------------------------------------------------
def test_wrap_auto_detects_supports_asymmetric_when_context_present():
    """A wrapped function whose signature declares ``context`` is flagged asymmetric."""
    from lightrag.utils import wrap_embedding_func_with_attrs

    async def my_embed(texts, context="document"):
        return np.zeros((len(texts), 4), dtype=np.float32)

    # Apply the decorator explicitly instead of with ``@`` syntax.
    my_embed = wrap_embedding_func_with_attrs(embedding_dim=4, max_token_size=64)(
        my_embed
    )
    assert my_embed.supports_asymmetric is True
def test_wrap_auto_detects_no_supports_asymmetric_for_legacy_func():
    """Without a ``context`` parameter, ``supports_asymmetric`` defaults to False."""
    from lightrag.utils import wrap_embedding_func_with_attrs

    decorate = wrap_embedding_func_with_attrs(embedding_dim=4, max_token_size=64)

    @decorate
    async def legacy_embed(texts):
        return np.zeros((len(texts), 4), dtype=np.float32)

    assert legacy_embed.supports_asymmetric is False
def test_wrap_explicit_supports_asymmetric_overrides_auto_detect():
    """An explicit ``supports_asymmetric`` kwarg takes precedence over signature inspection.

    The wrapped function declares ``context``, which auto-detection alone
    would treat as asymmetric, yet the explicit ``False`` must be honoured.
    """
    from lightrag.utils import wrap_embedding_func_with_attrs

    forced_off = wrap_embedding_func_with_attrs(
        embedding_dim=4,
        max_token_size=64,
        supports_asymmetric=False,
    )

    @forced_off
    async def my_embed(texts, context="document"):
        return np.zeros((len(texts), 4), dtype=np.float32)

    assert my_embed.supports_asymmetric is False
def test_wrap_auto_detects_per_function_when_decorator_reused():
    """A single decorator instance applied twice must inspect each function separately."""
    from lightrag.utils import wrap_embedding_func_with_attrs

    shared_decorator = wrap_embedding_func_with_attrs(
        embedding_dim=4, max_token_size=64
    )

    async def legacy_embed(texts):
        return np.zeros((len(texts), 4), dtype=np.float32)

    async def aware_embed(texts, context="document"):
        return np.zeros((len(texts), 4), dtype=np.float32)

    # Decorate legacy first, then aware, so any state leaked by the first
    # application would show up on the second.
    wrapped_legacy = shared_decorator(legacy_embed)
    wrapped_aware = shared_decorator(aware_embed)

    assert wrapped_legacy.supports_asymmetric is False
    assert wrapped_aware.supports_asymmetric is True
- # ---------------------------------------------------------------------------
- # EmbeddingFunc.__call__ strips context for legacy embeds
- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_embedding_func_strips_context_for_legacy_func():
    """A legacy embed lacking a ``context`` parameter must never receive it.

    If ``context`` were forwarded, calling ``legacy_embed`` would raise
    ``TypeError`` before anything is recorded, so reaching the assertions
    proves the kwarg was stripped.
    """
    from lightrag.utils import EmbeddingFunc

    calls: list[dict] = []

    async def legacy_embed(texts):
        calls.append({"texts": texts})
        return np.zeros((len(texts), 4), dtype=np.float32)

    wrapper = EmbeddingFunc(
        embedding_dim=4,
        max_token_size=64,
        supports_asymmetric=False,
        func=legacy_embed,
    )
    result = await wrapper(["a", "b"], context="query")

    assert result.shape == (2, 4)
    first_call = calls[0]
    assert first_call == {"texts": ["a", "b"]}
@pytest.mark.asyncio
async def test_embedding_func_forwards_context_when_supported():
    """An embed declared ``supports_asymmetric=True`` receives each caller's ``context``."""
    from lightrag.utils import EmbeddingFunc

    seen_contexts: list[str] = []

    async def aware_embed(texts, context="document"):
        seen_contexts.append(context)
        return np.zeros((len(texts), 4), dtype=np.float32)

    wrapper = EmbeddingFunc(
        embedding_dim=4,
        max_token_size=64,
        supports_asymmetric=True,
        func=aware_embed,
    )

    # One query call followed by one document call, in that order.
    for batch, ctx in ((["a"], "query"), (["b"], "document")):
        await wrapper(batch, context=ctx)

    assert seen_contexts == ["query", "document"]
- # ---------------------------------------------------------------------------
- # API asymmetric opt-in resolution
- # ---------------------------------------------------------------------------
def test_asymmetric_opt_in_is_off_when_toggle_is_unset_even_with_prefixes():
    """A configured query prefix alone does not opt in when the toggle was never configured."""
    settings = dict(
        binding="ollama",
        embedding_asymmetric=False,
        embedding_asymmetric_configured=False,
        query_prefix="search_query: ",
        query_prefix_configured=True,
        document_prefix=None,
        document_prefix_configured=False,
    )
    opted_in = api_config.resolve_asymmetric_embedding_opt_in(**settings)
    assert opted_in is False
def test_asymmetric_opt_in_explicit_false_disables_even_with_prefixes():
    """An explicitly configured ``False`` toggle wins over a configured query prefix."""
    opted_in = api_config.resolve_asymmetric_embedding_opt_in(
        binding="ollama",
        embedding_asymmetric=False,
        embedding_asymmetric_configured=True,
        query_prefix="search_query: ",
        query_prefix_configured=True,
        document_prefix=None,
        document_prefix_configured=False,
    )
    assert opted_in is False
@pytest.mark.parametrize("binding", ["jina", "gemini", "voyageai"])
def test_asymmetric_opt_in_explicit_true_allows_provider_level_bindings(binding):
    """Explicit opt-in succeeds for jina/gemini/voyageai bindings with no prefixes set."""
    no_prefixes = {
        "query_prefix": None,
        "query_prefix_configured": False,
        "document_prefix": None,
        "document_prefix_configured": False,
    }
    resolved = api_config.resolve_asymmetric_embedding_opt_in(
        binding=binding,
        embedding_asymmetric=True,
        embedding_asymmetric_configured=True,
        **no_prefixes,
    )
    assert resolved is True
def test_asymmetric_opt_in_explicit_true_ignores_provider_prefixes():
    """For the jina binding, a lone query prefix does not block an explicit opt-in."""
    kwargs = {
        "binding": "jina",
        "embedding_asymmetric": True,
        "embedding_asymmetric_configured": True,
        "query_prefix": "search_query: ",
        "query_prefix_configured": True,
        "document_prefix": None,
        "document_prefix_configured": False,
    }
    resolved = api_config.resolve_asymmetric_embedding_opt_in(**kwargs)
    assert resolved is True
def test_asymmetric_opt_in_explicit_true_requires_both_prefix_settings():
    """For ollama, opting in with only the query prefix configured raises ``ValueError``."""
    only_query_prefix = {
        "binding": "ollama",
        "embedding_asymmetric": True,
        "embedding_asymmetric_configured": True,
        "query_prefix": "search_query: ",
        "query_prefix_configured": True,
        "document_prefix": None,
        "document_prefix_configured": False,
    }
    with pytest.raises(ValueError, match="requires both"):
        api_config.resolve_asymmetric_embedding_opt_in(**only_query_prefix)
def test_asymmetric_opt_in_explicit_true_accepts_no_prefix_sentinel_side():
    """An explicitly empty (configured) document prefix satisfies the both-sides rule."""
    resolved = api_config.resolve_asymmetric_embedding_opt_in(
        binding="ollama",
        embedding_asymmetric=True,
        embedding_asymmetric_configured=True,
        query_prefix="search_query: ",
        query_prefix_configured=True,
        # Empty string + configured=True is how the "no prefix" side is expressed.
        document_prefix="",
        document_prefix_configured=True,
    )
    assert resolved is True
def test_asymmetric_opt_in_explicit_true_rejects_both_sides_no_prefix():
    """Configuring both prefixes as empty is rejected: at least one side needs text."""
    both_empty = dict(
        binding="ollama",
        embedding_asymmetric=True,
        embedding_asymmetric_configured=True,
        query_prefix="",
        query_prefix_configured=True,
        document_prefix="",
        document_prefix_configured=True,
    )
    with pytest.raises(ValueError, match="At least one"):
        api_config.resolve_asymmetric_embedding_opt_in(**both_empty)
def test_get_embedding_prefix_config_uses_no_prefix_sentinel(monkeypatch):
    """The no-prefix sentinel env value parses to an empty, configured prefix."""
    env_name = "EMBEDDING_DOCUMENT_PREFIX"
    monkeypatch.setenv(env_name, api_config.NO_PREFIX_SENTINEL)

    parsed = api_config.get_embedding_prefix_config(env_name)
    assert parsed == ("", True)
def test_get_embedding_prefix_config_rejects_empty_env_value(monkeypatch):
    """An empty env value is rejected with a ``ValueError`` that mentions the sentinel.

    ``pytest.raises(match=...)`` treats its argument as a regular expression
    (applied with ``re.search``), so the sentinel is escaped to be matched
    literally even if it contains regex metacharacters.
    """
    import re

    monkeypatch.setenv("EMBEDDING_DOCUMENT_PREFIX", "")
    with pytest.raises(ValueError, match=re.escape(api_config.NO_PREFIX_SENTINEL)):
        api_config.get_embedding_prefix_config("EMBEDDING_DOCUMENT_PREFIX")
- # ---------------------------------------------------------------------------
- # jina_embed: task auto-selection from context
- # ---------------------------------------------------------------------------
def _fake_jina_response(num: int, dim: int = 4) -> list[dict]:
    """Build a fake Jina embeddings payload holding ``num`` all-zero vectors.

    Each item is ``{"embedding": <str>}`` where the string is the base64
    encoding of one ``dim``-length float32 zero row's raw bytes.
    """
    vectors = np.zeros((num, dim), dtype=np.float32)
    return [{"embedding": base64.b64encode(row.tobytes()).decode()} for row in vectors]
@pytest.mark.asyncio
async def test_jina_default_task_is_query_when_context_query(monkeypatch):
    """With the default ``task=None``, ``context='query'`` sends ``retrieval.query``."""
    monkeypatch.setenv("JINA_API_KEY", "fake")
    from lightrag.llm import jina as jina_mod

    payloads: list[dict] = []

    async def record_and_reply(url, headers, data):
        payloads.append(data)
        return _fake_jina_response(len(data["input"]))

    with patch.object(jina_mod, "fetch_data", side_effect=record_and_reply):
        await jina_mod.jina_embed.func(texts=["q1"], context="query")

    sent_task = payloads[0]["task"]
    assert sent_task == "retrieval.query"
@pytest.mark.asyncio
async def test_jina_default_task_is_passage_when_context_document(monkeypatch):
    """With the default ``task=None``, ``context='document'`` sends ``retrieval.passage``."""
    monkeypatch.setenv("JINA_API_KEY", "fake")
    from lightrag.llm import jina as jina_mod

    sent_payloads: list[dict] = []

    async def stub_fetch(url, headers, data):
        sent_payloads.append(data)
        return _fake_jina_response(len(data["input"]))

    documents = ["d1", "d2"]
    with patch.object(jina_mod, "fetch_data", side_effect=stub_fetch):
        await jina_mod.jina_embed.func(texts=documents, context="document")

    assert sent_payloads[0]["task"] == "retrieval.passage"
@pytest.mark.asyncio
async def test_jina_explicit_task_overrides_context(monkeypatch):
    """An explicit ``task`` is sent unchanged, regardless of ``context``."""
    monkeypatch.setenv("JINA_API_KEY", "fake")
    from lightrag.llm import jina as jina_mod

    requests_seen: list[dict] = []

    async def stub_fetch(url, headers, data):
        requests_seen.append(data)
        return _fake_jina_response(len(data["input"]))

    call_kwargs = {"texts": ["x"], "context": "query", "task": "text-matching"}
    with patch.object(jina_mod, "fetch_data", side_effect=stub_fetch):
        await jina_mod.jina_embed.func(**call_kwargs)

    assert requests_seen[0]["task"] == "text-matching"
- # ---------------------------------------------------------------------------
- # gemini_embed: task_type auto-selection from context
- # ---------------------------------------------------------------------------
@pytest.fixture
def gemini_client_cache_cleared():
    """Clear gemini.py's ``lru_cache``-backed client factory before and after a test.

    Skips the requesting test when ``google.genai`` is not installed.
    """
    pytest.importorskip("google.genai")
    from lightrag.llm import gemini as gemini_mod

    def _reset_client_cache():
        gemini_mod._get_gemini_client.cache_clear()

    _reset_client_cache()
    yield
    _reset_client_cache()
@pytest.mark.asyncio
async def test_gemini_task_type_query_for_query_context(gemini_client_cache_cleared):
    """``context='query'`` with ``task_type=None`` is sent as ``RETRIEVAL_QUERY``."""
    pytest.importorskip("google.genai")
    from lightrag.llm import gemini as gemini_mod

    seen_task_types: list = []

    async def fake_embed_content(*, model, contents, config):
        seen_task_types.append(getattr(config, "task_type", None))
        return MagicMock(embeddings=[MagicMock(values=[0.1] * 4) for _ in contents])

    client_stub = MagicMock()
    client_stub.aio.models.embed_content = fake_embed_content

    with patch.object(gemini_mod.genai, "Client", return_value=client_stub):
        await gemini_mod.gemini_embed.func(
            texts=["q"], api_key="fake", context="query", task_type=None
        )

    assert seen_task_types[0] == "RETRIEVAL_QUERY"
@pytest.mark.asyncio
async def test_gemini_task_type_document_for_document_context(
    gemini_client_cache_cleared,
):
    """``context='document'`` with ``task_type=None`` is sent as ``RETRIEVAL_DOCUMENT``."""
    pytest.importorskip("google.genai")
    from lightrag.llm import gemini as gemini_mod

    recorded: list = []

    async def fake_embed_content(*, model, contents, config):
        recorded.append(getattr(config, "task_type", None))
        fake_vectors = [MagicMock(values=[0.1] * 4) for _ in contents]
        return MagicMock(embeddings=fake_vectors)

    stub_client = MagicMock()
    stub_client.aio.models.embed_content = fake_embed_content

    with patch.object(gemini_mod.genai, "Client", return_value=stub_client):
        await gemini_mod.gemini_embed.func(
            texts=["d"],
            api_key="fake",
            context="document",
            task_type=None,
        )

    assert recorded[0] == "RETRIEVAL_DOCUMENT"
@pytest.mark.asyncio
async def test_gemini_explicit_task_type_overrides_context(gemini_client_cache_cleared):
    """An explicit ``task_type`` is sent unchanged, even when ``context='query'``."""
    pytest.importorskip("google.genai")
    from lightrag.llm import gemini as gemini_mod

    task_types_sent: list = []

    async def fake_embed_content(*, model, contents, config):
        task_types_sent.append(getattr(config, "task_type", None))
        return MagicMock(embeddings=[MagicMock(values=[0.1] * 4) for _ in contents])

    client_stub = MagicMock()
    client_stub.aio.models.embed_content = fake_embed_content

    embed_kwargs = {
        "texts": ["x"],
        "api_key": "fake",
        "context": "query",
        "task_type": "CLASSIFICATION",
    }
    with patch.object(gemini_mod.genai, "Client", return_value=client_stub):
        await gemini_mod.gemini_embed.func(**embed_kwargs)

    assert task_types_sent[0] == "CLASSIFICATION"
|