openai_util.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. """OpenAI 调用公共工具。"""
  2. import json
  3. import logging
  4. import uuid
  5. from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List
  6. import openai
  7. from pydantic import BaseModel, ValidationError
  8. from ..config import settings
  9. from ..utils.config_manager import config_manager
  10. from ..utils.errors import AppError
  11. from .prompts.json_repair_prompts import build_json_repair_messages
  12. logger = logging.getLogger(__name__)
  13. ProgressCallback = Callable[[str], Awaitable[None]]
  14. JsonValidator = Callable[[Dict[str, Any]], None]
  15. class OpenAIUtil:
  16. """封装 OpenAI SDK 调用、日志与 JSON 修复能力。"""
  17. def __init__(self):
  18. config = config_manager.load_config()
  19. self.api_key = config.get("api_key", "")
  20. self.base_url = config.get("base_url", "")
  21. self.model_name = config.get("model_name", "gpt-3.5-turbo")
  22. if not self.api_key:
  23. raise AppError("请先配置OpenAI API密钥", status_code=400)
  24. self.client = openai.AsyncOpenAI(
  25. api_key=self.api_key,
  26. base_url=self.base_url or None,
  27. )
  28. def _chat_endpoint_url(self) -> str:
  29. """获取聊天完成接口地址。"""
  30. base_url = (self.base_url or "https://api.openai.com/v1").rstrip("/")
  31. return f"{base_url}/chat/completions"
  32. def _log_ai_request(
  33. self,
  34. request_id: str,
  35. messages: list[dict[str, str]],
  36. temperature: float,
  37. response_format: dict | None,
  38. ) -> None:
  39. """记录 AI 请求日志。"""
  40. if not settings.enable_file_logging:
  41. return
  42. logger.debug(
  43. "AI_REQUEST %s",
  44. json.dumps(
  45. {
  46. "request_id": request_id,
  47. "url": self._chat_endpoint_url(),
  48. "model": self.model_name,
  49. "temperature": temperature,
  50. "response_format": response_format,
  51. "messages": messages,
  52. },
  53. ensure_ascii=False,
  54. ),
  55. )
  56. def _log_ai_response(self, request_id: str, content: str) -> None:
  57. """记录 AI 响应日志。"""
  58. if not settings.enable_file_logging:
  59. return
  60. logger.debug(
  61. "AI_RESPONSE %s",
  62. json.dumps(
  63. {
  64. "request_id": request_id,
  65. "url": self._chat_endpoint_url(),
  66. "model": self.model_name,
  67. "content": content,
  68. },
  69. ensure_ascii=False,
  70. ),
  71. )
  72. def _log_ai_raw_response(
  73. self,
  74. request_id: str,
  75. raw_chunks: list[dict[str, Any]],
  76. content: str,
  77. ) -> None:
  78. """记录 AI 接口原始响应日志。"""
  79. if not settings.enable_file_logging:
  80. return
  81. logger.debug(
  82. "AI_RAW_RESPONSE %s",
  83. json.dumps(
  84. {
  85. "request_id": request_id,
  86. "url": self._chat_endpoint_url(),
  87. "model": self.model_name,
  88. "raw_chunks": raw_chunks,
  89. "content": content,
  90. },
  91. ensure_ascii=False,
  92. default=str,
  93. ),
  94. )
  95. def _log_ai_error(
  96. self,
  97. request_id: str,
  98. messages: list[dict[str, str]],
  99. temperature: float,
  100. response_format: dict | None,
  101. partial_content: str,
  102. raw_chunks: list[dict[str, Any]],
  103. error: Exception,
  104. ) -> None:
  105. """记录 AI 异常日志。"""
  106. if not settings.enable_file_logging:
  107. return
  108. logger.debug(
  109. "AI_ERROR %s",
  110. json.dumps(
  111. {
  112. "request_id": request_id,
  113. "url": self._chat_endpoint_url(),
  114. "model": self.model_name,
  115. "temperature": temperature,
  116. "response_format": response_format,
  117. "messages": messages,
  118. "partial_content": partial_content,
  119. "raw_chunks": raw_chunks,
  120. "error": str(error),
  121. },
  122. ensure_ascii=False,
  123. default=str,
  124. ),
  125. )
  126. @staticmethod
  127. def _dump_chunk(chunk: Any) -> dict[str, Any]:
  128. """序列化 OpenAI SDK 返回的 chunk。"""
  129. if hasattr(chunk, "model_dump"):
  130. return chunk.model_dump(mode="json")
  131. return {"raw": str(chunk)}
  132. @staticmethod
  133. def _extract_json_content(content: str) -> str:
  134. """提取模型响应中的 JSON 内容,兼容 Markdown 代码块包裹。"""
  135. normalized = content.strip()
  136. if not normalized.startswith("```"):
  137. return normalized
  138. lines = normalized.splitlines()
  139. if not lines:
  140. return normalized
  141. first_line = lines[0].strip().lower()
  142. last_line = lines[-1].strip()
  143. if not last_line.startswith("```"):
  144. return normalized
  145. if first_line in {"```", "```json", "```javascript", "```js"}:
  146. return "\n".join(lines[1:-1]).strip()
  147. return normalized
  148. @staticmethod
  149. def _is_response_format_unsupported_error(message: str) -> bool:
  150. """判断当前错误是否表示模型不支持 response_format。"""
  151. normalized = message.lower()
  152. if "response_format" not in normalized:
  153. return False
  154. return any(
  155. marker in normalized
  156. for marker in (
  157. "not supported",
  158. "does not support",
  159. "not support",
  160. "unsupported",
  161. "unknown parameter",
  162. "invalid parameter",
  163. )
  164. )
  165. @staticmethod
  166. async def emit_progress(
  167. progress_callback: ProgressCallback | None,
  168. message: str,
  169. ) -> None:
  170. """发送进度消息。"""
  171. if progress_callback is None:
  172. return
  173. await progress_callback(message)
  174. async def get_available_models(self) -> List[str]:
  175. """获取可用模型列表。"""
  176. try:
  177. models = await self.client.models.list()
  178. except Exception as exc:
  179. raise AppError(f"获取模型列表失败: {exc}", status_code=502) from exc
  180. chat_models: list[str] = []
  181. for model in models.data:
  182. model_id = model.id.lower()
  183. if any(
  184. keyword in model_id
  185. for keyword in ["gpt", "claude", "chat", "llama", "qwen", "deepseek"]
  186. ):
  187. chat_models.append(model.id)
  188. return sorted(set(chat_models))
  189. async def stream_chat_completion(
  190. self,
  191. messages: list[dict[str, str]],
  192. temperature: float = 0.7,
  193. response_format: dict | None = None,
  194. ) -> AsyncGenerator[str, None]:
  195. """流式调用聊天完成接口。"""
  196. request_id = uuid.uuid4().hex
  197. parts: list[str] = []
  198. raw_chunks: list[dict[str, Any]] = []
  199. self._log_ai_request(request_id, messages, temperature, response_format)
  200. try:
  201. stream = await self.client.chat.completions.create(
  202. model=self.model_name,
  203. messages=messages,
  204. temperature=temperature,
  205. stream=True,
  206. **(
  207. {"response_format": response_format}
  208. if response_format is not None
  209. else {}
  210. ),
  211. )
  212. except Exception as exc:
  213. self._log_ai_error(
  214. request_id,
  215. messages,
  216. temperature,
  217. response_format,
  218. "",
  219. raw_chunks,
  220. exc,
  221. )
  222. raise AppError(f"模型调用失败: {exc}", status_code=502) from exc
  223. try:
  224. async for chunk in stream:
  225. raw_chunks.append(self._dump_chunk(chunk))
  226. if not chunk.choices:
  227. continue
  228. content = chunk.choices[0].delta.content
  229. if content is not None:
  230. parts.append(content)
  231. yield content
  232. except Exception as exc:
  233. self._log_ai_error(
  234. request_id,
  235. messages,
  236. temperature,
  237. response_format,
  238. "".join(parts),
  239. raw_chunks,
  240. exc,
  241. )
  242. raise AppError(f"模型调用失败: {exc}", status_code=502) from exc
  243. self._log_ai_response(request_id, "".join(parts))
  244. self._log_ai_raw_response(request_id, raw_chunks, "".join(parts))
  245. async def collect_chat_completion(
  246. self,
  247. messages: list[dict[str, str]],
  248. temperature: float = 0.7,
  249. response_format: dict | None = None,
  250. ) -> str:
  251. """收集流式输出并拼接为完整文本。"""
  252. parts: list[str] = []
  253. async for chunk in self.stream_chat_completion(
  254. messages,
  255. temperature=temperature,
  256. response_format=response_format,
  257. ):
  258. parts.append(chunk)
  259. return "".join(parts)
  260. async def _collect_chat_completion_with_json_mode_fallback(
  261. self,
  262. messages: list[dict[str, str]],
  263. temperature: float,
  264. use_response_format: bool,
  265. progress_callback: ProgressCallback | None = None,
  266. ) -> tuple[str, bool]:
  267. """优先使用 JSON 模式请求,不支持时自动降级为普通请求。"""
  268. try:
  269. content = await self.collect_chat_completion(
  270. messages,
  271. temperature=temperature,
  272. response_format={"type": "json_object"}
  273. if use_response_format
  274. else None,
  275. )
  276. return content, use_response_format
  277. except AppError as exc:
  278. if (
  279. not use_response_format
  280. or not self._is_response_format_unsupported_error(exc.message)
  281. ):
  282. raise
  283. await self.emit_progress(
  284. progress_callback,
  285. "当前模型不支持结构化 JSON 响应,已降级为普通请求解析。",
  286. )
  287. content = await self.collect_chat_completion(
  288. messages,
  289. temperature=temperature,
  290. response_format=None,
  291. )
  292. return content, False
  293. @staticmethod
  294. def _normalize_json_response(
  295. content: str,
  296. schema: type[BaseModel] | None = None,
  297. validator: JsonValidator | None = None,
  298. ) -> Dict[str, Any]:
  299. """解析、校验并标准化 JSON 响应。"""
  300. json_content = OpenAIUtil._extract_json_content(content)
  301. parsed = json.loads(json_content)
  302. if schema is None:
  303. normalized = parsed
  304. else:
  305. validated = schema.model_validate(parsed)
  306. normalized = validated.model_dump(exclude_none=True)
  307. if validator is not None:
  308. validator(normalized)
  309. return normalized
  310. @staticmethod
  311. def _format_json_issues(error: Exception) -> list[str]:
  312. """格式化 JSON 解析或校验问题。"""
  313. if isinstance(error, json.JSONDecodeError):
  314. return [
  315. f"JSON 语法错误:第 {error.lineno} 行第 {error.colno} 列附近 {error.msg}。"
  316. ]
  317. if isinstance(error, ValidationError):
  318. issues: list[str] = []
  319. for item in error.errors():
  320. location = ".".join(str(part) for part in item.get("loc", [])) or "root"
  321. message = item.get("msg", "字段校验失败")
  322. issues.append(f"{location}: {message}")
  323. return issues or [str(error)]
  324. return [str(error)]
  325. async def _repair_json_response(
  326. self,
  327. invalid_content: str,
  328. issues: list[str],
  329. temperature: float,
  330. use_response_format: bool,
  331. progress_callback: ProgressCallback | None,
  332. progress_label: str,
  333. ) -> tuple[str, bool]:
  334. """基于当前结果发起一次定向 JSON 修复。"""
  335. await self.emit_progress(
  336. progress_callback,
  337. f"{progress_label}格式校验失败,正在基于当前结果进行修复。",
  338. )
  339. repair_messages = build_json_repair_messages(
  340. invalid_content=invalid_content,
  341. issues=issues,
  342. target_description=progress_label,
  343. )
  344. return await self._collect_chat_completion_with_json_mode_fallback(
  345. messages=repair_messages,
  346. temperature=temperature,
  347. use_response_format=use_response_format,
  348. progress_callback=progress_callback,
  349. )
  350. async def collect_json_response(
  351. self,
  352. messages: list[dict[str, str]],
  353. temperature: float = 0.7,
  354. schema: type[BaseModel] | None = None,
  355. validator: JsonValidator | None = None,
  356. progress_callback: ProgressCallback | None = None,
  357. progress_label: str = "JSON结果",
  358. failure_message: str = "模型返回的 JSON 数据格式无效",
  359. ) -> Dict[str, Any]:
  360. """收集并校验 JSON 响应。"""
  361. max_retries = 2
  362. total_attempts = max_retries + 1
  363. use_response_format = True
  364. for attempt in range(total_attempts):
  365. try:
  366. (
  367. content,
  368. use_response_format,
  369. ) = await self._collect_chat_completion_with_json_mode_fallback(
  370. messages=messages,
  371. temperature=temperature,
  372. use_response_format=use_response_format,
  373. progress_callback=progress_callback,
  374. )
  375. normalized = self._normalize_json_response(
  376. content,
  377. schema=schema,
  378. validator=validator,
  379. )
  380. return normalized
  381. except (json.JSONDecodeError, ValidationError, ValueError) as exc:
  382. issues = self._format_json_issues(exc)
  383. logger.warning(
  384. "模型返回非法 JSON,第 %s/%s 次尝试: %s;问题: %s",
  385. attempt + 1,
  386. total_attempts,
  387. content,
  388. " | ".join(issues),
  389. )
  390. try:
  391. (
  392. repaired_content,
  393. use_response_format,
  394. ) = await self._repair_json_response(
  395. invalid_content=content,
  396. issues=issues,
  397. temperature=temperature,
  398. use_response_format=use_response_format,
  399. progress_callback=progress_callback,
  400. progress_label=progress_label,
  401. )
  402. normalized = self._normalize_json_response(
  403. repaired_content,
  404. schema=schema,
  405. validator=validator,
  406. )
  407. return normalized
  408. except AppError as repair_error:
  409. logger.warning(
  410. "JSON 修复请求失败,第 %s/%s 次尝试: %s",
  411. attempt + 1,
  412. total_attempts,
  413. repair_error.message,
  414. )
  415. exc = repair_error
  416. except (
  417. json.JSONDecodeError,
  418. ValidationError,
  419. ValueError,
  420. ) as repair_error:
  421. logger.warning(
  422. "JSON 修复后仍校验失败,第 %s/%s 次尝试: %s;问题: %s",
  423. attempt + 1,
  424. total_attempts,
  425. repaired_content,
  426. " | ".join(self._format_json_issues(repair_error)),
  427. )
  428. exc = repair_error
  429. if attempt == max_retries:
  430. await self.emit_progress(
  431. progress_callback,
  432. f"{progress_label}连续 {total_attempts} 次校验失败。",
  433. )
  434. raise AppError(failure_message, status_code=502) from exc
  435. await self.emit_progress(
  436. progress_callback,
  437. f"{progress_label}第 {attempt + 1}/{total_attempts} 次校验失败,正在重试。",
  438. )
  439. raise AppError(failure_message, status_code=502)