test_zhipu_llm.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import importlib
  2. import sys
  3. from types import SimpleNamespace
  4. import numpy as np
  5. import pytest
  6. def _fake_embedding_vector(dim=1024):
  7. return [0.1] * dim
  8. def _fake_chat_response(content="", reasoning_content=""):
  9. message = SimpleNamespace(
  10. content=content,
  11. reasoning_content=reasoning_content,
  12. )
  13. return SimpleNamespace(choices=[SimpleNamespace(message=message)])
  14. def _load_zhipu_module(monkeypatch, client_factory):
  15. fake_pm = SimpleNamespace(
  16. is_installed=lambda name: True,
  17. install=lambda name: None,
  18. )
  19. fake_openai = SimpleNamespace(
  20. APIConnectionError=type("APIConnectionError", (Exception,), {}),
  21. RateLimitError=type("RateLimitError", (Exception,), {}),
  22. APITimeoutError=type("APITimeoutError", (Exception,), {}),
  23. )
  24. fake_zhipuai = SimpleNamespace(ZhipuAI=client_factory)
  25. monkeypatch.setitem(sys.modules, "pipmaster", fake_pm)
  26. monkeypatch.setitem(sys.modules, "openai", fake_openai)
  27. monkeypatch.setitem(sys.modules, "zhipuai", fake_zhipuai)
  28. sys.modules.pop("lightrag.llm.zhipu", None)
  29. return importlib.import_module("lightrag.llm.zhipu")
  30. @pytest.mark.offline
  31. @pytest.mark.asyncio
  32. async def test_zhipu_embedding_sends_dimensions_when_embedding_dim_provided(
  33. monkeypatch,
  34. ):
  35. captured_calls = []
  36. class FakeClient:
  37. def __init__(self, api_key=None):
  38. self.api_key = api_key
  39. self.embeddings = SimpleNamespace(create=self.create)
  40. def create(self, **kwargs):
  41. captured_calls.append(kwargs)
  42. return SimpleNamespace(
  43. data=[SimpleNamespace(embedding=_fake_embedding_vector())]
  44. )
  45. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  46. result = await zhipu_module.zhipu_embedding.func(
  47. ["hello"],
  48. api_key="test-key",
  49. embedding_dim=2048,
  50. )
  51. assert isinstance(result, np.ndarray)
  52. assert result.shape == (1, 1024)
  53. assert captured_calls == [
  54. {"model": "embedding-3", "input": ["hello"], "dimensions": 2048}
  55. ]
  56. @pytest.mark.offline
  57. @pytest.mark.asyncio
  58. async def test_zhipu_embedding_omits_dimensions_when_embedding_dim_not_provided(
  59. monkeypatch,
  60. ):
  61. captured_calls = []
  62. class FakeClient:
  63. def __init__(self, api_key=None):
  64. self.api_key = api_key
  65. self.embeddings = SimpleNamespace(create=self.create)
  66. def create(self, **kwargs):
  67. captured_calls.append(kwargs)
  68. return SimpleNamespace(
  69. data=[SimpleNamespace(embedding=_fake_embedding_vector())]
  70. )
  71. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  72. await zhipu_module.zhipu_embedding.func(["hello"], api_key="test-key")
  73. assert captured_calls == [{"model": "embedding-3", "input": ["hello"]}]
  74. @pytest.mark.offline
  75. @pytest.mark.asyncio
  76. async def test_zhipu_complete_forwards_official_thinking(monkeypatch):
  77. captured_calls = []
  78. class FakeClient:
  79. def __init__(self, api_key=None):
  80. self.api_key = api_key
  81. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  82. def create(self, **kwargs):
  83. captured_calls.append(kwargs)
  84. return _fake_chat_response(content="final answer")
  85. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  86. result = await zhipu_module.zhipu_complete_if_cache(
  87. prompt="hello",
  88. api_key="test-key",
  89. thinking={"type": "enabled"},
  90. )
  91. assert result == "final answer"
  92. assert captured_calls[0]["thinking"] == {"type": "enabled"}
  93. @pytest.mark.offline
  94. @pytest.mark.asyncio
  95. async def test_zhipu_complete_filters_reasoning_when_cot_disabled(monkeypatch):
  96. class FakeClient:
  97. def __init__(self, api_key=None):
  98. self.api_key = api_key
  99. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  100. def create(self, **kwargs):
  101. return _fake_chat_response(
  102. content="visible answer",
  103. reasoning_content="hidden chain of thought",
  104. )
  105. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  106. result = await zhipu_module.zhipu_complete_if_cache(
  107. prompt="hello",
  108. api_key="test-key",
  109. enable_cot=False,
  110. )
  111. assert result == "visible answer"
  112. @pytest.mark.offline
  113. @pytest.mark.asyncio
  114. async def test_zhipu_complete_includes_reasoning_when_cot_enabled(monkeypatch):
  115. class FakeClient:
  116. def __init__(self, api_key=None):
  117. self.api_key = api_key
  118. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  119. def create(self, **kwargs):
  120. return _fake_chat_response(
  121. content="visible answer",
  122. reasoning_content="hidden chain of thought",
  123. )
  124. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  125. result = await zhipu_module.zhipu_complete_if_cache(
  126. prompt="hello",
  127. api_key="test-key",
  128. enable_cot=True,
  129. )
  130. assert result == "<think>hidden chain of thought</think>visible answer"
  131. @pytest.mark.offline
  132. @pytest.mark.asyncio
  133. async def test_zhipu_keyword_extraction_ignores_reasoning_content(monkeypatch):
  134. class FakeClient:
  135. def __init__(self, api_key=None):
  136. self.api_key = api_key
  137. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  138. def create(self, **kwargs):
  139. return _fake_chat_response(
  140. content='{"high_level_keywords": ["AI"], "low_level_keywords": ["RAG"]}',
  141. reasoning_content="this should not be parsed",
  142. )
  143. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  144. with pytest.warns(DeprecationWarning):
  145. result = await zhipu_module.zhipu_complete(
  146. prompt="hello",
  147. api_key="test-key",
  148. keyword_extraction=True,
  149. enable_cot=True,
  150. )
  151. assert result == '{"high_level_keywords": ["AI"], "low_level_keywords": ["RAG"]}'
  152. @pytest.mark.offline
  153. @pytest.mark.asyncio
  154. async def test_zhipu_if_cache_entity_extraction_maps_to_json_object(monkeypatch):
  155. captured_calls = []
  156. class FakeClient:
  157. def __init__(self, api_key=None):
  158. self.api_key = api_key
  159. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  160. def create(self, **kwargs):
  161. captured_calls.append(kwargs)
  162. return _fake_chat_response(
  163. content='{"entities":[],"relationships":[]}',
  164. reasoning_content="this should not be parsed",
  165. )
  166. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  167. with pytest.warns(DeprecationWarning):
  168. result = await zhipu_module.zhipu_complete_if_cache(
  169. prompt="hello",
  170. api_key="test-key",
  171. entity_extraction=True,
  172. enable_cot=True,
  173. )
  174. assert result == '{"entities":[],"relationships":[]}'
  175. assert captured_calls[0]["response_format"] == {"type": "json_object"}
  176. assert "entity_extraction" not in captured_calls[0]
  177. @pytest.mark.offline
  178. @pytest.mark.asyncio
  179. async def test_zhipu_if_cache_structured_output_disables_cot(monkeypatch):
  180. class FakeClient:
  181. def __init__(self, api_key=None):
  182. self.api_key = api_key
  183. self.chat = SimpleNamespace(completions=SimpleNamespace(create=self.create))
  184. def create(self, **kwargs):
  185. return _fake_chat_response(
  186. content='{"answer":"ok"}',
  187. reasoning_content="this should not be included",
  188. )
  189. zhipu_module = _load_zhipu_module(monkeypatch, FakeClient)
  190. result = await zhipu_module.zhipu_complete_if_cache(
  191. prompt="hello",
  192. api_key="test-key",
  193. response_format={"type": "json_object"},
  194. enable_cot=True,
  195. )
  196. assert result == '{"answer":"ok"}'
  197. assert "<think>" not in result