zhipu.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import sys
  2. import warnings
  3. from ..utils import verbose_debug
  4. if sys.version_info < (3, 9):
  5. pass
  6. else:
  7. pass
  8. import pipmaster as pm # Pipmaster for dynamic library install
  9. # install specific modules
  10. if not pm.is_installed("zhipuai"):
  11. pm.install("zhipuai")
  12. from openai import (
  13. APIConnectionError,
  14. RateLimitError,
  15. APITimeoutError,
  16. )
  17. from tenacity import (
  18. retry,
  19. stop_after_attempt,
  20. wait_exponential,
  21. retry_if_exception_type,
  22. )
  23. from lightrag.utils import (
  24. wrap_embedding_func_with_attrs,
  25. logger,
  26. )
  27. import numpy as np
  28. from typing import Union, List, Optional, Dict
  29. @retry(
  30. stop=stop_after_attempt(3),
  31. wait=wait_exponential(multiplier=1, min=4, max=10),
  32. retry=retry_if_exception_type(
  33. (RateLimitError, APIConnectionError, APITimeoutError)
  34. ),
  35. )
  36. async def zhipu_complete_if_cache(
  37. prompt: Union[str, List[Dict[str, str]]],
  38. model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
  39. api_key: Optional[str] = None,
  40. system_prompt: Optional[str] = None,
  41. history_messages: List[Dict[str, str]] = [],
  42. enable_cot: bool = False, # LightRAG output switch: include reasoning_content as <think>...</think>
  43. thinking: Optional[
  44. Dict[str, object]
  45. ] = None, # Zhipu request param: use {"type": "enabled"} to enable thinking
  46. **kwargs,
  47. ) -> str:
  48. """Call Zhipu chat completions with optional official thinking support.
  49. Parameter roles:
  50. - `thinking`: forwarded to the Zhipu API as-is. To enable thinking output,
  51. pass a config such as `{"type": "enabled"}`.
  52. - `enable_cot`: LightRAG-only formatting switch. When True and the API
  53. returns `reasoning_content`, it is preserved in the final string as
  54. `<think>...</think>`.
  55. - `response_format`: forwarded as Zhipu's OpenAI-compatible structured
  56. output parameter when supplied by callers.
  57. - Deprecated `keyword_extraction` and `entity_extraction` booleans are
  58. compatibility shims; when no explicit `response_format` is supplied,
  59. they are mapped to `{"type": "json_object"}`.
  60. """
  61. # dynamically load ZhipuAI
  62. try:
  63. from zhipuai import ZhipuAI
  64. except ImportError:
  65. raise ImportError("Please install zhipuai before initialize zhipuai backend.")
  66. if api_key:
  67. client = ZhipuAI(api_key=api_key)
  68. else:
  69. # please set ZHIPUAI_API_KEY in your environment
  70. # os.environ["ZHIPUAI_API_KEY"]
  71. client = ZhipuAI()
  72. messages = []
  73. if not system_prompt:
  74. system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
  75. # Add system prompt if provided
  76. if system_prompt:
  77. messages.append({"role": "system", "content": system_prompt})
  78. messages.extend(history_messages)
  79. messages.append({"role": "user", "content": prompt})
  80. # Add debug logging
  81. logger.debug("===== Query Input to LLM =====")
  82. logger.debug(f"Query: {prompt}")
  83. verbose_debug(f"System prompt: {system_prompt}")
  84. # Deprecation shims: map legacy extraction booleans to response_format only
  85. # when an explicit response_format was not supplied by the caller. The
  86. # legacy path also forces enable_cot=False so reasoning_content cannot
  87. # corrupt the JSON payload expected by callers relying on it.
  88. keyword_extraction = kwargs.pop("keyword_extraction", False)
  89. entity_extraction = kwargs.pop("entity_extraction", False)
  90. if kwargs.get("response_format") is None:
  91. if entity_extraction:
  92. warnings.warn(
  93. "zhipu_complete_if_cache(entity_extraction=True) is deprecated; "
  94. "pass response_format={'type': 'json_object'} instead.",
  95. DeprecationWarning,
  96. stacklevel=2,
  97. )
  98. kwargs["response_format"] = {"type": "json_object"}
  99. enable_cot = False
  100. elif keyword_extraction:
  101. warnings.warn(
  102. "zhipu_complete_if_cache(keyword_extraction=True) is deprecated; "
  103. "pass response_format={'type': 'json_object'} instead.",
  104. DeprecationWarning,
  105. stacklevel=2,
  106. )
  107. kwargs["response_format"] = {"type": "json_object"}
  108. enable_cot = False
  109. # Structured output and COT are mutually exclusive here because
  110. # reasoning_content would corrupt the JSON payload expected by callers.
  111. if kwargs.get("response_format") is not None:
  112. enable_cot = False
  113. # Remove unsupported kwargs
  114. kwargs = {
  115. k: v
  116. for k, v in kwargs.items()
  117. if k not in ["hashing_kv", "keyword_extraction", "entity_extraction"]
  118. }
  119. # `thinking` is an official Zhipu request field. Example:
  120. # {"type": "enabled"} enables reasoning output on supported models.
  121. if thinking is not None:
  122. kwargs["thinking"] = thinking
  123. response = client.chat.completions.create(model=model, messages=messages, **kwargs)
  124. if not response.choices or response.choices[0].message is None:
  125. return ""
  126. message = response.choices[0].message
  127. content = message.content or ""
  128. reasoning_content = getattr(message, "reasoning_content", "") or ""
  129. if enable_cot and reasoning_content.strip():
  130. if content:
  131. return f"<think>{reasoning_content}</think>{content}"
  132. return f"<think>{reasoning_content}</think>"
  133. return content
  134. async def zhipu_complete(
  135. prompt,
  136. system_prompt=None,
  137. history_messages=[],
  138. keyword_extraction=False,
  139. entity_extraction=False,
  140. enable_cot: bool = False,
  141. **kwargs,
  142. ):
  143. """Zhipu completion wrapper with LightRAG structured-output shims.
  144. Structured output note:
  145. - This adapter accepts OpenAI-style ``response_format`` and forwards it to
  146. Zhipu's compatible chat-completions API.
  147. - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
  148. compatibility shims; when no explicit ``response_format`` is supplied,
  149. they are mapped to ``{"type": "json_object"}``.
  150. """
  151. # Pop legacy extraction flags from kwargs to avoid passing them downstream.
  152. keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
  153. entity_extraction = kwargs.pop("entity_extraction", entity_extraction)
  154. # Deprecation shims: map legacy boolean flags to response_format only when
  155. # an explicit response_format was not supplied by the caller. The legacy
  156. # path also forces enable_cot=False so that reasoning_content cannot
  157. # corrupt the JSON payload expected by callers that were relying on it.
  158. if kwargs.get("response_format") is None:
  159. if entity_extraction:
  160. warnings.warn(
  161. "zhipu_complete(entity_extraction=True) is deprecated; "
  162. "pass response_format={'type': 'json_object'} instead.",
  163. DeprecationWarning,
  164. stacklevel=2,
  165. )
  166. kwargs["response_format"] = {"type": "json_object"}
  167. enable_cot = False
  168. elif keyword_extraction:
  169. warnings.warn(
  170. "zhipu_complete(keyword_extraction=True) is deprecated; "
  171. "pass response_format={'type': 'json_object'} instead.",
  172. DeprecationWarning,
  173. stacklevel=2,
  174. )
  175. kwargs["response_format"] = {"type": "json_object"}
  176. enable_cot = False
  177. return await zhipu_complete_if_cache(
  178. prompt=prompt,
  179. system_prompt=system_prompt,
  180. history_messages=history_messages,
  181. enable_cot=enable_cot,
  182. **kwargs,
  183. )
  184. @wrap_embedding_func_with_attrs(
  185. embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
  186. )
  187. @retry(
  188. stop=stop_after_attempt(3),
  189. wait=wait_exponential(multiplier=1, min=4, max=60),
  190. retry=retry_if_exception_type(
  191. (RateLimitError, APIConnectionError, APITimeoutError)
  192. ),
  193. )
  194. async def zhipu_embedding(
  195. texts: list[str],
  196. model: str = "embedding-3",
  197. api_key: str = None,
  198. embedding_dim: int | None = None,
  199. **kwargs,
  200. ) -> np.ndarray:
  201. # dynamically load ZhipuAI
  202. try:
  203. from zhipuai import ZhipuAI
  204. except ImportError:
  205. raise ImportError("Please install zhipuai before initialize zhipuai backend.")
  206. if api_key:
  207. client = ZhipuAI(api_key=api_key)
  208. else:
  209. # please set ZHIPUAI_API_KEY in your environment
  210. # os.environ["ZHIPUAI_API_KEY"]
  211. client = ZhipuAI()
  212. # Convert single text to list if needed
  213. if isinstance(texts, str):
  214. texts = [texts]
  215. embeddings = []
  216. for text in texts:
  217. try:
  218. request_kwargs = dict(kwargs)
  219. if embedding_dim is not None:
  220. request_kwargs["dimensions"] = embedding_dim
  221. response = client.embeddings.create(
  222. model=model, input=[text], **request_kwargs
  223. )
  224. embeddings.append(response.data[0].embedding)
  225. except Exception as e:
  226. raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
  227. return np.array(embeddings)