test_voyageai_embed.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. """Unit tests for lightrag.llm.voyageai.
  2. These tests mock voyageai.AsyncClient so they run fully offline.
  3. """
  4. from __future__ import annotations
  5. from unittest.mock import MagicMock, patch
  6. import numpy as np
  7. import pytest
  8. @pytest.fixture
  9. def fake_voyage_response():
  10. """Build a fake VoyageAI embed response with N rows of fixed-dim vectors."""
  11. def _make(n: int, dim: int = 8) -> MagicMock:
  12. rng = np.linspace(0.0, 1.0, num=dim, dtype=np.float32)
  13. rows = [rng.tolist() for _ in range(n)]
  14. resp = MagicMock()
  15. resp.embeddings = rows
  16. return resp
  17. return _make
  18. @pytest.fixture
  19. def patched_async_client(fake_voyage_response):
  20. """Patch voyageai.AsyncClient so each call returns a recorded response."""
  21. captured: list[dict] = []
  22. async def fake_embed(**kwargs):
  23. captured.append(kwargs)
  24. return fake_voyage_response(len(kwargs["texts"]))
  25. fake_client = MagicMock()
  26. fake_client.embed = fake_embed
  27. with patch(
  28. "lightrag.llm.voyageai.voyageai.AsyncClient", return_value=fake_client
  29. ) as m:
  30. yield captured, m
  31. @pytest.mark.asyncio
  32. async def test_voyageai_embed_passes_model(patched_async_client):
  33. """The function should forward the model parameter to the SDK."""
  34. captured, _ = patched_async_client
  35. from lightrag.llm.voyageai import voyageai_embed
  36. out = await voyageai_embed.func(
  37. texts=["hello", "world"], model="voyage-3-lite", api_key="fake"
  38. )
  39. assert isinstance(out, np.ndarray)
  40. assert out.shape[0] == 2
  41. assert len(captured) == 1
  42. assert captured[0]["model"] == "voyage-3-lite"
  43. @pytest.mark.asyncio
  44. async def test_voyageai_embed_accepts_legacy_voyage_api_key(
  45. patched_async_client, monkeypatch
  46. ):
  47. """Setting only VOYAGE_API_KEY (the SDK's name) must work for backward compat."""
  48. captured, _ = patched_async_client
  49. monkeypatch.delenv("VOYAGEAI_API_KEY", raising=False)
  50. monkeypatch.setenv("VOYAGE_API_KEY", "key-from-legacy-name")
  51. from lightrag.llm.voyageai import voyageai_embed
  52. await voyageai_embed.func(texts=["x"], model="voyage-3")
  53. assert len(captured) == 1
  54. @pytest.mark.asyncio
  55. async def test_voyageai_embed_accepts_voyageai_api_key(
  56. patched_async_client, monkeypatch
  57. ):
  58. """The newer VOYAGEAI_API_KEY name must also still work."""
  59. captured, _ = patched_async_client
  60. monkeypatch.delenv("VOYAGE_API_KEY", raising=False)
  61. monkeypatch.setenv("VOYAGEAI_API_KEY", "key-from-new-name")
  62. from lightrag.llm.voyageai import voyageai_embed
  63. await voyageai_embed.func(texts=["x"], model="voyage-3")
  64. assert len(captured) == 1
  65. @pytest.mark.asyncio
  66. async def test_voyageai_embed_raises_when_no_api_key(monkeypatch):
  67. """Without any API key configured the call should raise ValueError."""
  68. monkeypatch.delenv("VOYAGE_API_KEY", raising=False)
  69. monkeypatch.delenv("VOYAGEAI_API_KEY", raising=False)
  70. from lightrag.llm.voyageai import voyageai_embed
  71. with pytest.raises(ValueError, match="VOYAGE_API_KEY"):
  72. await voyageai_embed.func(texts=["x"])
  73. @pytest.mark.asyncio
  74. async def test_voyageai_embed_forwards_input_type(patched_async_client):
  75. """input_type kwarg must reach the SDK so callers can drive query/document selection."""
  76. captured, _ = patched_async_client
  77. from lightrag.llm.voyageai import voyageai_embed
  78. await voyageai_embed.func(texts=["q"], api_key="fake", input_type="query")
  79. await voyageai_embed.func(texts=["d"], api_key="fake", input_type="document")
  80. assert captured[0]["input_type"] == "query"
  81. assert captured[1]["input_type"] == "document"
  82. @pytest.mark.asyncio
  83. async def test_voyageai_embed_maps_context_to_input_type(patched_async_client):
  84. """LightRAG query/document context should drive VoyageAI's input_type."""
  85. captured, _ = patched_async_client
  86. from lightrag.llm.voyageai import voyageai_embed
  87. await voyageai_embed.func(texts=["q"], api_key="fake", context="query")
  88. await voyageai_embed.func(texts=["d"], api_key="fake", context="document")
  89. assert captured[0]["input_type"] == "query"
  90. assert captured[1]["input_type"] == "document"
  91. @pytest.mark.asyncio
  92. async def test_voyageai_embed_explicit_input_type_overrides_context(
  93. patched_async_client,
  94. ):
  95. """Explicit input_type must keep direct callers backward compatible."""
  96. captured, _ = patched_async_client
  97. from lightrag.llm.voyageai import voyageai_embed
  98. await voyageai_embed.func(
  99. texts=["x"], api_key="fake", input_type="document", context="query"
  100. )
  101. assert captured[0]["input_type"] == "document"
  102. def test_voyageai_embed_declares_asymmetric_support():
  103. from lightrag.llm.voyageai import voyageai_embed
  104. assert voyageai_embed.supports_asymmetric is True
  105. def test_anthropic_embed_deprecation_shim():
  106. """``anthropic_embed`` must remain importable and emit DeprecationWarning."""
  107. import warnings
  108. from lightrag.llm.anthropic import anthropic_embed # must not ImportError
  109. with warnings.catch_warnings(record=True) as caught:
  110. warnings.simplefilter("always")
  111. coro = anthropic_embed(texts=["x"], api_key="ignored-mock")
  112. # Close the coroutine to silence "never awaited" runtime warnings;
  113. # we only care that the deprecation warning fired at call time.
  114. if hasattr(coro, "close"):
  115. coro.close()
  116. assert any(
  117. issubclass(w.category, DeprecationWarning) for w in caught
  118. ), "anthropic_embed should warn DeprecationWarning"