# test_gemini_llm.py — offline tests for lightrag.llm.gemini against a stubbed Google SDK.
  1. import importlib
  2. import sys
  3. from types import ModuleType, SimpleNamespace
  4. import pytest
  5. def _load_gemini_module(monkeypatch, request):
  6. fake_pm = SimpleNamespace(
  7. is_installed=lambda name: True,
  8. install=lambda name: None,
  9. )
  10. class FakeGenerateContentConfig:
  11. def __init__(self, **kwargs):
  12. self.kwargs = kwargs
  13. class FakeHttpOptions:
  14. def __init__(self, **kwargs):
  15. self.kwargs = kwargs
  16. fake_types = SimpleNamespace(
  17. GenerateContentConfig=FakeGenerateContentConfig,
  18. HttpOptions=FakeHttpOptions,
  19. )
  20. fake_genai = SimpleNamespace(Client=lambda **kwargs: SimpleNamespace(kwargs=kwargs))
  21. fake_google_module = ModuleType("google")
  22. fake_google_module.genai = fake_genai
  23. fake_api_exceptions = SimpleNamespace(
  24. InternalServerError=type("InternalServerError", (Exception,), {}),
  25. ServiceUnavailable=type("ServiceUnavailable", (Exception,), {}),
  26. ResourceExhausted=type("ResourceExhausted", (Exception,), {}),
  27. GatewayTimeout=type("GatewayTimeout", (Exception,), {}),
  28. BadGateway=type("BadGateway", (Exception,), {}),
  29. DeadlineExceeded=type("DeadlineExceeded", (Exception,), {}),
  30. Aborted=type("Aborted", (Exception,), {}),
  31. Unknown=type("Unknown", (Exception,), {}),
  32. )
  33. fake_google_api_core = ModuleType("google.api_core")
  34. fake_google_api_core.exceptions = fake_api_exceptions
  35. monkeypatch.setitem(sys.modules, "pipmaster", fake_pm)
  36. monkeypatch.setitem(sys.modules, "google", fake_google_module)
  37. monkeypatch.setitem(sys.modules, "google.genai", SimpleNamespace(types=fake_types))
  38. monkeypatch.setitem(sys.modules, "google.api_core", fake_google_api_core)
  39. monkeypatch.setitem(sys.modules, "google.api_core.exceptions", fake_api_exceptions)
  40. # Force a fresh import of lightrag.llm.gemini against the fakes above,
  41. # and restore the original module (or absence) on teardown — otherwise
  42. # subsequent tests (e.g. tests/llm/test_asymmetric_embedding.py) inherit
  43. # this stubbed `genai.types` namespace and break with AttributeError on
  44. # types.EmbedContentConfig. Note: clearing sys.modules alone is not
  45. # enough — Python also caches the submodule as an attribute on the parent
  46. # package, and `from lightrag.llm import gemini` resolves via that
  47. # attribute. Both pointers must be cleared.
  48. parent = sys.modules.get("lightrag.llm")
  49. original_gemini = sys.modules.get("lightrag.llm.gemini")
  50. original_parent_attr = getattr(parent, "gemini", None) if parent else None
  51. sys.modules.pop("lightrag.llm.gemini", None)
  52. if parent is not None and hasattr(parent, "gemini"):
  53. delattr(parent, "gemini")
  54. def _restore_gemini():
  55. if original_gemini is not None:
  56. sys.modules["lightrag.llm.gemini"] = original_gemini
  57. else:
  58. sys.modules.pop("lightrag.llm.gemini", None)
  59. if parent is not None:
  60. if original_parent_attr is not None:
  61. parent.gemini = original_parent_attr
  62. elif hasattr(parent, "gemini"):
  63. delattr(parent, "gemini")
  64. request.addfinalizer(_restore_gemini)
  65. return importlib.import_module("lightrag.llm.gemini")
  66. def _make_fake_gemini_response(regular_text="", thought_text=""):
  67. parts = []
  68. if thought_text:
  69. parts.append(SimpleNamespace(text=thought_text, thought=True))
  70. if regular_text:
  71. parts.append(SimpleNamespace(text=regular_text, thought=False))
  72. return SimpleNamespace(
  73. candidates=[
  74. SimpleNamespace(content=SimpleNamespace(parts=parts)),
  75. ],
  76. usage_metadata=SimpleNamespace(
  77. prompt_token_count=1,
  78. candidates_token_count=2,
  79. total_token_count=3,
  80. ),
  81. )
  82. @pytest.mark.offline
  83. def test_gemini_maps_schema_response_format_to_response_json_schema(
  84. monkeypatch, request
  85. ):
  86. gemini_module = _load_gemini_module(monkeypatch, request)
  87. schema = {
  88. "type": "object",
  89. "properties": {"answer": {"type": "string"}},
  90. "required": ["answer"],
  91. }
  92. config = gemini_module._build_generation_config(
  93. base_config=None,
  94. system_prompt=None,
  95. response_format=schema,
  96. )
  97. assert config.kwargs["response_mime_type"] == "application/json"
  98. assert config.kwargs["response_json_schema"] == schema
  99. assert "response_schema" not in config.kwargs
  100. @pytest.mark.offline
  101. def test_gemini_unwraps_openai_json_schema_wrapper(monkeypatch, request):
  102. gemini_module = _load_gemini_module(monkeypatch, request)
  103. schema = {
  104. "type": "object",
  105. "properties": {"answer": {"type": "string"}},
  106. "required": ["answer"],
  107. }
  108. response_format = {
  109. "type": "json_schema",
  110. "json_schema": {
  111. "name": "answer_payload",
  112. "schema": schema,
  113. },
  114. }
  115. config = gemini_module._build_generation_config(
  116. base_config=None,
  117. system_prompt=None,
  118. response_format=response_format,
  119. )
  120. assert config.kwargs["response_mime_type"] == "application/json"
  121. assert config.kwargs["response_json_schema"] == schema
  122. @pytest.mark.offline
  123. def test_gemini_rejects_typed_response_format(monkeypatch, request):
  124. gemini_module = _load_gemini_module(monkeypatch, request)
  125. class FakeSchemaModel:
  126. pass
  127. with pytest.raises(TypeError, match="typed/Pydantic"):
  128. gemini_module._validate_gemini_response_format(FakeSchemaModel)
  129. @pytest.mark.offline
  130. def test_gemini_default_service_root_is_not_treated_as_custom_base_url(
  131. monkeypatch, request
  132. ):
  133. gemini_module = _load_gemini_module(monkeypatch, request)
  134. gemini_module._get_gemini_client.cache_clear()
  135. monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
  136. client = gemini_module._get_gemini_client(
  137. "test-key",
  138. "https://generativelanguage.googleapis.com",
  139. 1234,
  140. )
  141. assert client.kwargs["api_key"] == "test-key"
  142. assert "http_options" in client.kwargs
  143. assert client.kwargs["http_options"].kwargs == {"timeout": 1234}
  144. @pytest.mark.offline
  145. def test_gemini_custom_base_url_is_preserved(monkeypatch, request):
  146. gemini_module = _load_gemini_module(monkeypatch, request)
  147. gemini_module._get_gemini_client.cache_clear()
  148. monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
  149. client = gemini_module._get_gemini_client(
  150. "test-key",
  151. "https://proxy.example.com",
  152. 1234,
  153. )
  154. assert client.kwargs["http_options"].kwargs == {
  155. "base_url": "https://proxy.example.com",
  156. "timeout": 1234,
  157. }
  158. @pytest.mark.offline
  159. @pytest.mark.asyncio
  160. async def test_gemini_streaming_structured_output_disables_cot(monkeypatch, request):
  161. gemini_module = _load_gemini_module(monkeypatch, request)
  162. fake_stream_response = _make_fake_gemini_response(
  163. regular_text='{"answer":"ok"}',
  164. thought_text="this should not be included",
  165. )
  166. async def _single_chunk_stream(response):
  167. yield response
  168. async def _fake_generate_content_stream(**kwargs):
  169. return _single_chunk_stream(fake_stream_response)
  170. fake_client = SimpleNamespace(
  171. aio=SimpleNamespace(
  172. models=SimpleNamespace(
  173. generate_content_stream=_fake_generate_content_stream
  174. )
  175. )
  176. )
  177. monkeypatch.setattr(gemini_module, "_get_gemini_client", lambda *args: fake_client)
  178. stream = await gemini_module.gemini_complete_if_cache(
  179. model="gemini-model",
  180. prompt="hello",
  181. stream=True,
  182. enable_cot=True,
  183. response_format={"type": "json_object"},
  184. api_key="test-key",
  185. )
  186. chunks = []
  187. async for chunk in stream:
  188. chunks.append(chunk)
  189. assert "".join(chunks) == '{"answer":"ok"}'