utils.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791
  1. #!/usr/bin/env python3
  2. """
  3. ABOUTME: Shared token estimation utilities for audit scripts
  4. ABOUTME: XML sanitization helpers for document processing
  5. """
  6. import json
  7. import os
  8. import re
  9. try:
  10. from google import genai
  11. from google.genai import types
  12. HAS_GEMINI = True
  13. except ImportError: # pragma: no cover - optional dependency
  14. genai = None
  15. types = None
  16. HAS_GEMINI = False
  17. try:
  18. import openai
  19. HAS_OPENAI = True
  20. except ImportError: # pragma: no cover - optional dependency
  21. openai = None
  22. HAS_OPENAI = False
  23. def estimate_tokens(text: str) -> int:
  24. """
  25. Estimate token count for LLM context management.
  26. Uses a weighted formula based on character types:
  27. - Chinese characters: ~0.75 tokens per character (subword tokenization)
  28. - JSON structural characters (brackets, quotes, commas): ~1 tokens per character
  29. - Other characters (English, numbers, symbols): ~0.4 tokens per character (~3 chars/token)
  30. Includes 5% buffer and safety offset for special formatting and system prompt overhead.
  31. Args:
  32. text: Input text to estimate tokens for
  33. Returns:
  34. int: Estimated token count
  35. """
  36. if not text:
  37. return 0
  38. chinese_count = len(re.findall(r"[\u4e00-\u9fa5]", text))
  39. json_chars_count = len(re.findall(r'[\[\]",{}]', text))
  40. other_count = len(text) - chinese_count - json_chars_count
  41. base_estimate = (
  42. (chinese_count * 0.75) + (json_chars_count * 1) + (other_count * 0.4)
  43. )
  44. final_tokens = int(base_estimate * 1.05) + 2
  45. return final_tokens
  46. def sanitize_xml_string(text: str) -> str:
  47. """
  48. Remove control characters that are illegal in XML 1.0.
  49. XML 1.0 allows: #x9 (tab), #xA (LF), #xD (CR), and #x20-#xD7FF, #xE000-#xFFFD, #x10000-#x10FFFF
  50. This function removes all other control characters (0x00-0x08, 0x0B, 0x0C, 0x0E-0x1F).
  51. Args:
  52. text: Text that may contain control characters
  53. Returns:
  54. Sanitized text safe for XML. Returns input unchanged if not a non-empty string.
  55. """
  56. if not text or not isinstance(text, str):
  57. return text
  58. # Build a translation table to remove illegal control characters
  59. # Keep: \t (0x09), \n (0x0A), \r (0x0D)
  60. # Remove: 0x00-0x08, 0x0B, 0x0C, 0x0E-0x1F
  61. illegal_chars = "".join(chr(c) for c in range(0x20) if c not in (0x09, 0x0A, 0x0D))
  62. return text.translate(str.maketrans("", "", illegal_chars))
  63. def is_vertex_ai_mode() -> bool:
  64. """
  65. Check if Vertex AI mode is enabled via environment variable.
  66. Returns:
  67. True if GOOGLE_GENAI_USE_VERTEXAI is set to 'true', False otherwise
  68. """
  69. return os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
  70. def create_gemini_client(use_async: bool = False):
  71. """
  72. Create Gemini client for AI Studio or Vertex AI.
  73. Supports two modes:
  74. - AI Studio (default): Uses GOOGLE_API_KEY for authentication
  75. - Vertex AI: Uses ADC (GOOGLE_APPLICATION_CREDENTIALS or gcloud auth)
  76. Environment variables for Vertex AI mode:
  77. - GOOGLE_GENAI_USE_VERTEXAI: Set to 'true' to enable Vertex AI mode
  78. - GOOGLE_CLOUD_PROJECT: Required GCP project ID
  79. - GOOGLE_CLOUD_LOCATION: Optional region (default: us-central1)
  80. - GOOGLE_VERTEX_BASE_URL: Optional custom API endpoint (for API gateway proxies)
  81. - GOOGLE_APPLICATION_CREDENTIALS: Path to service account JSON (or use gcloud auth)
  82. Args:
  83. use_async: If True, return the async client (.aio), otherwise return sync client
  84. Returns:
  85. Gemini client instance (sync or async based on use_async parameter)
  86. Raises:
  87. ValueError: If required environment variables are not set
  88. """
  89. use_vertex = is_vertex_ai_mode()
  90. if use_vertex:
  91. # Vertex AI mode - uses ADC (GOOGLE_APPLICATION_CREDENTIALS or gcloud auth)
  92. project = os.getenv("GOOGLE_CLOUD_PROJECT")
  93. location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
  94. base_url = os.getenv("GOOGLE_VERTEX_BASE_URL")
  95. if not project:
  96. raise ValueError(
  97. "GOOGLE_CLOUD_PROJECT is required for Vertex AI mode. "
  98. "Set GOOGLE_GENAI_USE_VERTEXAI=false to use AI Studio mode instead."
  99. )
  100. # Build http_options only if custom base_url is specified
  101. http_options = None
  102. if base_url:
  103. http_options = {"base_url": base_url}
  104. # Note: ADC handles authentication automatically
  105. # via GOOGLE_APPLICATION_CREDENTIALS env var or gcloud auth
  106. client = genai.Client(
  107. vertexai=True, project=project, location=location, http_options=http_options
  108. )
  109. else:
  110. # AI Studio mode - requires API key
  111. api_key = os.getenv("GOOGLE_API_KEY")
  112. if not api_key:
  113. raise ValueError(
  114. "GOOGLE_API_KEY is required for AI Studio mode. "
  115. "Set GOOGLE_GENAI_USE_VERTEXAI=true and configure GCP credentials for Vertex AI mode."
  116. )
  117. client = genai.Client(api_key=api_key)
  118. # Return async or sync client based on parameter
  119. return client.aio if use_async else client
  120. def get_gemini_provider_name() -> str:
  121. """
  122. Get the Gemini provider name based on current mode.
  123. Returns:
  124. Provider name string for display purposes
  125. """
  126. if is_vertex_ai_mode():
  127. project = os.getenv("GOOGLE_CLOUD_PROJECT", "unknown")
  128. location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
  129. return f"Google Gemini (Vertex AI: {project}/{location})"
  130. return "Google Gemini (AI Studio)"
  131. def create_openai_client(use_async: bool = True):
  132. """
  133. Create OpenAI client with optional custom base URL.
  134. Environment variables:
  135. - OPENAI_API_KEY: Required API key
  136. - OPENAI_BASE_URL: Optional custom API endpoint (for proxies, Azure, etc.)
  137. Args:
  138. use_async: If True, return AsyncOpenAI, otherwise return OpenAI
  139. Returns:
  140. OpenAI client instance (async or sync based on use_async parameter)
  141. Raises:
  142. ValueError: If OPENAI_API_KEY is not set
  143. """
  144. if not HAS_OPENAI:
  145. raise ValueError("openai library is not installed.")
  146. api_key = os.getenv("OPENAI_API_KEY")
  147. if not api_key:
  148. raise ValueError("OPENAI_API_KEY is required for OpenAI mode.")
  149. base_url = os.getenv("OPENAI_BASE_URL")
  150. if use_async:
  151. return openai.AsyncOpenAI(base_url=base_url)
  152. return openai.OpenAI(base_url=base_url)
  153. def get_openai_provider_name() -> str:
  154. """
  155. Get the OpenAI provider name, including custom endpoint if configured.
  156. Returns:
  157. Provider name string for display purposes
  158. """
  159. base_url = os.getenv("OPENAI_BASE_URL")
  160. if base_url:
  161. return f"OpenAI (Custom: {base_url})"
  162. return "OpenAI"
  163. def is_openai_reasoning_model(model_name: str) -> bool:
  164. """
  165. Check if the OpenAI model supports reasoning_effort parameter.
  166. Models that support reasoning_effort:
  167. - o-series: o1, o3, o4 and their variants (o1-mini, o1-2024-12-17, etc.)
  168. - gpt-5 series: gpt-5, gpt-5.2, gpt-5-turbo, etc.
  169. Non-reasoning models like gpt-4.1, gpt-4o, etc. will reject this parameter.
  170. Handles proxy/router prefixes like "openai/o1-mini" or "openrouter/gpt-5.2".
  171. Args:
  172. model_name: The OpenAI model name (may include path prefix)
  173. Returns:
  174. True if the model supports reasoning_effort, False otherwise
  175. """
  176. model_lower = model_name.lower()
  177. # Handle proxy/router prefixes like "openai/o1-mini", "openrouter/gpt-5.2"
  178. # Extract the base model name after the last "/"
  179. if "/" in model_lower:
  180. model_lower = model_lower.rsplit("/", 1)[-1]
  181. # Match o-series and gpt-5 series
  182. return model_lower.startswith(("o1", "o3", "o4", "gpt-5"))
  183. def is_openai_retryable(error: Exception) -> bool:
  184. """
  185. Determine if an OpenAI error should be retried.
  186. Non-retryable errors:
  187. - AuthenticationError (401): Invalid API key
  188. - PermissionDeniedError (403): No access to resource
  189. - BadRequestError (400): Invalid request format
  190. - NotFoundError (404): Model or resource not found
  191. Retryable errors:
  192. - RateLimitError (429): Rate limit exceeded
  193. - APIConnectionError: Network issues
  194. - InternalServerError (500): Server errors
  195. - APIStatusError with 502, 503, 504: Gateway/service errors
  196. Args:
  197. error: The exception from OpenAI API call
  198. Returns:
  199. True if the error should be retried, False otherwise
  200. """
  201. if not HAS_OPENAI:
  202. return True
  203. # Authentication error - invalid API key (401)
  204. if isinstance(error, openai.AuthenticationError):
  205. return False
  206. # Permission denied - no access to resource (403)
  207. if isinstance(error, openai.PermissionDeniedError):
  208. return False
  209. # Bad request - invalid request format (400)
  210. if isinstance(error, openai.BadRequestError):
  211. return False
  212. # Not found - model or resource doesn't exist (404)
  213. if isinstance(error, openai.NotFoundError):
  214. return False
  215. # Rate limit exceeded - should retry with backoff (429)
  216. if isinstance(error, openai.RateLimitError):
  217. return True
  218. # API connection error - network issues, should retry
  219. if isinstance(error, openai.APIConnectionError):
  220. return True
  221. # Internal server error - should retry (500)
  222. if isinstance(error, openai.InternalServerError):
  223. return True
  224. # For other APIStatusError, check HTTP status code
  225. if isinstance(error, openai.APIStatusError):
  226. # Retryable server-side errors
  227. return error.status_code in (429, 500, 502, 503, 504)
  228. # For unknown errors, default to retry (network issues, timeouts, etc.)
  229. return True
  230. def is_gemini_retryable(error: Exception) -> bool:
  231. """
  232. Determine if a Gemini error should be retried.
  233. Uses string matching on error messages since google-genai may not have
  234. well-defined exception types for all error cases.
  235. Non-retryable errors:
  236. - API key errors
  237. - Authentication/permission errors
  238. - Invalid request errors
  239. - Model not found errors
  240. - Billing/quota permanently exceeded
  241. Retryable errors:
  242. - Rate limit (429)
  243. - Server errors (500, 502, 503, 504)
  244. - Timeout/connection errors
  245. Args:
  246. error: The exception from Gemini API call
  247. Returns:
  248. True if the error should be retried, False otherwise
  249. """
  250. error_str = str(error).lower()
  251. # API key / authentication errors - do not retry
  252. if "api_key" in error_str or "api key" in error_str:
  253. return False
  254. if "authentication" in error_str or "authenticate" in error_str:
  255. return False
  256. if "invalid_api_key" in error_str or "invalid api key" in error_str:
  257. return False
  258. # Permission / forbidden errors - do not retry
  259. if "permission" in error_str and "denied" in error_str:
  260. return False
  261. if "forbidden" in error_str or "403" in error_str:
  262. return False
  263. # Invalid request errors - do not retry
  264. if "invalid" in error_str and ("request" in error_str or "argument" in error_str):
  265. return False
  266. if "400" in error_str and "bad request" in error_str:
  267. return False
  268. # Model not found - do not retry
  269. if "model" in error_str and ("not found" in error_str or "not exist" in error_str):
  270. return False
  271. if "404" in error_str:
  272. return False
  273. # Billing / permanent quota errors - do not retry
  274. if "billing" in error_str:
  275. return False
  276. if "quota" in error_str and ("exceeded" in error_str or "exhausted" in error_str):
  277. # Check if it mentions billing which indicates permanent quota issue
  278. if "billing" in error_str or "payment" in error_str:
  279. return False
  280. # Temporary quota (rate limit) - should retry
  281. return True
  282. # Rate limit errors - should retry (429)
  283. if "rate" in error_str and "limit" in error_str:
  284. return True
  285. if "429" in error_str or "resource_exhausted" in error_str:
  286. return True
  287. # Server errors - should retry (500, 502, 503, 504)
  288. if any(code in error_str for code in ["500", "502", "503", "504"]):
  289. return True
  290. if "internal" in error_str and ("error" in error_str or "server" in error_str):
  291. return True
  292. if "service" in error_str and "unavailable" in error_str:
  293. return True
  294. if "gateway" in error_str:
  295. return True
  296. # Timeout / connection errors - should retry
  297. if "timeout" in error_str or "timed out" in error_str:
  298. return True
  299. if "connection" in error_str:
  300. return True
  301. if "network" in error_str:
  302. return True
  303. # Unknown errors - default to retry with limited attempts
  304. return True
  305. # JSON Schema for LLM structured output
  306. AUDIT_RESULT_SCHEMA = {
  307. "type": "object",
  308. "additionalProperties": False,
  309. "properties": {
  310. "is_violation": {
  311. "type": "boolean",
  312. "description": "Whether any violations were found",
  313. },
  314. "violations": {
  315. "type": "array",
  316. "description": "List of violations found",
  317. "items": {
  318. "type": "object",
  319. "additionalProperties": False,
  320. "properties": {
  321. "rule_id": {
  322. "type": "string",
  323. "description": "ID of the violated rule (e.g., R001)",
  324. },
  325. "violation_text": {
  326. "type": "string",
  327. "description": "The problematic text directly verbatim quote from the source content, and not span multiple cells",
  328. },
  329. "violation_reason": {
  330. "type": "string",
  331. "description": "Explanation of why this violates the rule",
  332. },
  333. "fix_action": {
  334. "type": "string",
  335. "enum": ["replace", "manual"],
  336. "description": "Action type: replace substitutes text (including deletion-via-replace), manual requires human review",
  337. },
  338. "revised_text": {
  339. "type": "string",
  340. "description": "For replace: complete replacement text (including deletion-via-replace). For manual: additional guidance for human reviewer",
  341. },
  342. },
  343. "required": [
  344. "rule_id",
  345. "violation_text",
  346. "violation_reason",
  347. "fix_action",
  348. "revised_text",
  349. ],
  350. },
  351. },
  352. },
  353. "required": ["is_violation", "violations"],
  354. }
  355. # JSON Schema for global extraction output
  356. GLOBAL_EXTRACT_SCHEMA = {
  357. "type": "object",
  358. "additionalProperties": False,
  359. "properties": {
  360. "results": {
  361. "type": "array",
  362. "items": {
  363. "type": "object",
  364. "additionalProperties": False,
  365. "properties": {
  366. "rule_id": {"type": "string"},
  367. "extracted_results": {
  368. "type": "array",
  369. "items": {
  370. "type": "object",
  371. "additionalProperties": False,
  372. "properties": {
  373. "entity": {"type": "string"},
  374. "fields": {
  375. "type": "array",
  376. "items": {
  377. "type": "object",
  378. "additionalProperties": False,
  379. "properties": {
  380. "name": {"type": "string"},
  381. "value": {"type": "string"},
  382. "evidence": {"type": "string"},
  383. },
  384. "required": ["name", "value", "evidence"],
  385. },
  386. },
  387. },
  388. "required": ["entity", "fields"],
  389. },
  390. },
  391. },
  392. "required": ["rule_id", "extracted_results"],
  393. },
  394. }
  395. },
  396. "required": ["results"],
  397. }
  398. # JSON Schema for global verification output
  399. GLOBAL_VERIFY_SCHEMA = {
  400. "type": "object",
  401. "additionalProperties": False,
  402. "properties": {
  403. "violations": {
  404. "type": "array",
  405. "items": {
  406. "type": "object",
  407. "additionalProperties": False,
  408. "properties": {
  409. "rule_id": {"type": "string"},
  410. "uuid": {"type": "string"},
  411. "uuid_end": {"type": "string"},
  412. "violation_text": {"type": "string"},
  413. "violation_reason": {"type": "string"},
  414. "fix_action": {"type": "string", "enum": ["replace", "manual"]},
  415. "revised_text": {"type": "string"},
  416. },
  417. "required": [
  418. "rule_id",
  419. "uuid",
  420. "uuid_end",
  421. "violation_text",
  422. "violation_reason",
  423. "fix_action",
  424. "revised_text",
  425. ],
  426. },
  427. }
  428. },
  429. "required": ["violations"],
  430. }
  431. async def global_extract_gemini_async(
  432. user_prompt: str,
  433. system_prompt: str,
  434. model_name: str,
  435. client,
  436. thinking_level: str = None,
  437. thinking_budget: int = None,
  438. ) -> dict:
  439. thinking_config = None
  440. if thinking_level and thinking_level.upper() in (
  441. "MINIMAL",
  442. "LOW",
  443. "MEDIUM",
  444. "HIGH",
  445. ):
  446. level_map = {
  447. "MINIMAL": types.ThinkingLevel.MINIMAL,
  448. "LOW": types.ThinkingLevel.LOW,
  449. "MEDIUM": types.ThinkingLevel.MEDIUM,
  450. "HIGH": types.ThinkingLevel.HIGH,
  451. }
  452. thinking_config = types.ThinkingConfig(
  453. thinking_level=level_map[thinking_level.upper()]
  454. )
  455. elif thinking_budget is not None:
  456. thinking_config = types.ThinkingConfig(thinking_budget=int(thinking_budget))
  457. config_params = {
  458. "system_instruction": system_prompt,
  459. "response_mime_type": "application/json",
  460. "response_schema": GLOBAL_EXTRACT_SCHEMA,
  461. }
  462. if thinking_config:
  463. config_params["thinking_config"] = thinking_config
  464. response = await client.models.generate_content(
  465. model=model_name,
  466. contents=user_prompt,
  467. config=types.GenerateContentConfig(**config_params),
  468. )
  469. return json.loads(response.text)
  470. async def global_extract_openai_async(
  471. user_prompt: str,
  472. system_prompt: str,
  473. model_name: str,
  474. client,
  475. reasoning_effort: str = None,
  476. ) -> dict:
  477. request_params = {
  478. "model": model_name,
  479. "messages": [
  480. {"role": "system", "content": system_prompt},
  481. {"role": "user", "content": user_prompt},
  482. ],
  483. "response_format": {
  484. "type": "json_schema",
  485. "json_schema": {
  486. "name": "global_extract",
  487. "strict": True,
  488. "schema": GLOBAL_EXTRACT_SCHEMA,
  489. },
  490. },
  491. }
  492. if (
  493. reasoning_effort
  494. and reasoning_effort.lower() in ("low", "medium", "high")
  495. and is_openai_reasoning_model(model_name)
  496. ):
  497. request_params["reasoning_effort"] = reasoning_effort.lower()
  498. response = await client.chat.completions.create(**request_params)
  499. return json.loads(response.choices[0].message.content)
  500. async def global_verify_gemini_async(
  501. user_prompt: str,
  502. system_prompt: str,
  503. model_name: str,
  504. client,
  505. thinking_level: str = None,
  506. thinking_budget: int = None,
  507. ) -> dict:
  508. thinking_config = None
  509. if thinking_level and thinking_level.upper() in (
  510. "MINIMAL",
  511. "LOW",
  512. "MEDIUM",
  513. "HIGH",
  514. ):
  515. level_map = {
  516. "MINIMAL": types.ThinkingLevel.MINIMAL,
  517. "LOW": types.ThinkingLevel.LOW,
  518. "MEDIUM": types.ThinkingLevel.MEDIUM,
  519. "HIGH": types.ThinkingLevel.HIGH,
  520. }
  521. thinking_config = types.ThinkingConfig(
  522. thinking_level=level_map[thinking_level.upper()]
  523. )
  524. elif thinking_budget is not None:
  525. thinking_config = types.ThinkingConfig(thinking_budget=int(thinking_budget))
  526. config_params = {
  527. "system_instruction": system_prompt,
  528. "response_mime_type": "application/json",
  529. "response_schema": GLOBAL_VERIFY_SCHEMA,
  530. }
  531. if thinking_config:
  532. config_params["thinking_config"] = thinking_config
  533. response = await client.models.generate_content(
  534. model=model_name,
  535. contents=user_prompt,
  536. config=types.GenerateContentConfig(**config_params),
  537. )
  538. return json.loads(response.text)
  539. async def global_verify_openai_async(
  540. user_prompt: str,
  541. system_prompt: str,
  542. model_name: str,
  543. client,
  544. reasoning_effort: str = None,
  545. ) -> dict:
  546. request_params = {
  547. "model": model_name,
  548. "messages": [
  549. {"role": "system", "content": system_prompt},
  550. {"role": "user", "content": user_prompt},
  551. ],
  552. "response_format": {
  553. "type": "json_schema",
  554. "json_schema": {
  555. "name": "global_verify",
  556. "strict": True,
  557. "schema": GLOBAL_VERIFY_SCHEMA,
  558. },
  559. },
  560. }
  561. if (
  562. reasoning_effort
  563. and reasoning_effort.lower() in ("low", "medium", "high")
  564. and is_openai_reasoning_model(model_name)
  565. ):
  566. request_params["reasoning_effort"] = reasoning_effort.lower()
  567. response = await client.chat.completions.create(**request_params)
  568. return json.loads(response.choices[0].message.content)
  569. async def audit_block_gemini_async(
  570. user_prompt: str,
  571. system_prompt: str,
  572. model_name: str,
  573. client,
  574. thinking_level: str = None,
  575. thinking_budget: int = None,
  576. ) -> dict:
  577. """
  578. Audit a text block using Google Gemini with strict JSON mode (async version).
  579. Args:
  580. user_prompt: User prompt to audit
  581. system_prompt: Cached system prompt with rules and instructions
  582. model_name: Gemini model to use
  583. client: Gemini async client instance (client.aio)
  584. thinking_level: Thinking level for Gemini 3 models (MINIMAL, LOW, MEDIUM, HIGH)
  585. thinking_budget: Thinking token budget for Gemini 2.5 models (integer)
  586. Returns:
  587. Audit result dictionary
  588. """
  589. # Build thinking config based on model and parameters
  590. thinking_config = None
  591. if thinking_level and thinking_level.upper() in (
  592. "MINIMAL",
  593. "LOW",
  594. "MEDIUM",
  595. "HIGH",
  596. ):
  597. # For Gemini 3 models
  598. level_map = {
  599. "MINIMAL": types.ThinkingLevel.MINIMAL,
  600. "LOW": types.ThinkingLevel.LOW,
  601. "MEDIUM": types.ThinkingLevel.MEDIUM,
  602. "HIGH": types.ThinkingLevel.HIGH,
  603. }
  604. thinking_config = types.ThinkingConfig(
  605. thinking_level=level_map[thinking_level.upper()]
  606. )
  607. elif thinking_budget is not None:
  608. # For Gemini 2.5 models
  609. thinking_config = types.ThinkingConfig(thinking_budget=int(thinking_budget))
  610. config_params = {
  611. "system_instruction": system_prompt,
  612. "response_mime_type": "application/json",
  613. "response_schema": AUDIT_RESULT_SCHEMA,
  614. }
  615. # Only add thinking_config if it's configured
  616. if thinking_config:
  617. config_params["thinking_config"] = thinking_config
  618. response = await client.models.generate_content(
  619. model=model_name,
  620. contents=user_prompt,
  621. config=types.GenerateContentConfig(**config_params),
  622. )
  623. # With structured output, response is guaranteed to be valid JSON
  624. result = json.loads(response.text)
  625. return result
  626. async def audit_block_openai_async(
  627. user_prompt: str,
  628. system_prompt: str,
  629. model_name: str,
  630. client,
  631. reasoning_effort: str = None,
  632. ) -> dict:
  633. """
  634. Audit a text block using OpenAI with strict JSON mode (async version).
  635. Args:
  636. user_prompt: User prompt to audit
  637. system_prompt: Cached system prompt with rules and instructions
  638. model_name: OpenAI model to use
  639. client: AsyncOpenAI client instance
  640. reasoning_effort: Reasoning effort for o-series models (low, medium, high)
  641. Returns:
  642. Audit result dictionary
  643. """
  644. request_params = {
  645. "model": model_name,
  646. "messages": [
  647. {"role": "system", "content": system_prompt},
  648. {"role": "user", "content": user_prompt},
  649. ],
  650. "response_format": {
  651. "type": "json_schema",
  652. "json_schema": {
  653. "name": "audit_result",
  654. "strict": True,
  655. "schema": AUDIT_RESULT_SCHEMA,
  656. },
  657. },
  658. }
  659. # Add reasoning_effort only for o-series models that support it
  660. if (
  661. reasoning_effort
  662. and reasoning_effort.lower() in ("low", "medium", "high")
  663. and is_openai_reasoning_model(model_name)
  664. ):
  665. request_params["reasoning_effort"] = reasoning_effort.lower()
  666. response = await client.chat.completions.create(**request_params)
  667. # With structured output, response is guaranteed to be valid JSON
  668. result = json.loads(response.choices[0].message.content)
  669. return result