deterministic_model.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. from __future__ import annotations
  2. import json
  3. import re
  4. import time
  5. import uuid
  6. from collections.abc import AsyncIterator
  7. from typing import Any
  8. from agents import Tool
  9. from agents.agent_output import AgentOutputSchemaBase
  10. from agents.handoffs import Handoff
  11. from agents.items import ModelResponse, TResponseInputItem, TResponseStreamEvent
  12. from agents.model_settings import ModelSettings
  13. from agents.models.interface import Model, ModelTracing
  14. from agents.usage import Usage
  15. from openai.types.responses import (
  16. Response,
  17. ResponseCompletedEvent,
  18. ResponseContentPartAddedEvent,
  19. ResponseContentPartDoneEvent,
  20. ResponseCreatedEvent,
  21. ResponseFunctionToolCall,
  22. ResponseOutputItemAddedEvent,
  23. ResponseOutputItemDoneEvent,
  24. ResponseOutputMessage,
  25. ResponseOutputText,
  26. ResponseTextDeltaEvent,
  27. ResponseTextDoneEvent,
  28. ResponseUsage,
  29. )
  30. from openai.types.responses.response_prompt_param import ResponsePromptParam
  31. from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
  32. _STORE_RE = re.compile(r"store\s+(?P<key>\w+)\s+with\s+value\s+(?P<value>\w+)", re.IGNORECASE)
  33. _GET_RE = re.compile(r"value\s+for\s+(?P<key>\w+)", re.IGNORECASE)
  34. _MESSAGE_RE = re.compile(r"message:\s*(?P<message>.+)$", re.IGNORECASE)
  35. _SECRET_RE = re.compile(r"secret code:\s*(?P<secret>[\w-]+)", re.IGNORECASE)
  36. _HANDLE_RE = re.compile(r"handle\s+(?P<task>[\w-]+)", re.IGNORECASE)
  37. _TASK_RE = re.compile(r"task\s+(?P<task>[\w-]+)", re.IGNORECASE)
  38. def _extract_text_from_content(content: Any) -> str | None:
  39. if isinstance(content, str):
  40. return content
  41. if isinstance(content, list):
  42. parts: list[str] = []
  43. for part in content:
  44. if not isinstance(part, dict):
  45. continue
  46. text_value = part.get("text")
  47. if isinstance(text_value, str):
  48. parts.append(text_value)
  49. if parts:
  50. return "".join(parts)
  51. return None
  def _extract_last_user_text(items: str | list[TResponseInputItem]) -> str | None:
      """Return the text of the most recent user message in *items*.

      A plain-string input is returned unchanged. For a list, dict items with
      ``role == "user"`` are examined newest first. Any whose content yields no
      text is skipped, so an earlier user message may be returned instead.
      Returns ``None`` when no user text is found.
      """
      if isinstance(items, str):
          return items
      user_items = (
          entry
          for entry in reversed(items)
          if isinstance(entry, dict) and entry.get("role") == "user"
      )
      for entry in user_items:
          candidate = _extract_text_from_content(entry.get("content"))
          if isinstance(candidate, str):
              return candidate
      return None
  64. def _extract_last_tool_output(items: str | list[TResponseInputItem]) -> str | None:
  65. if isinstance(items, str):
  66. return None
  67. for item in reversed(items):
  68. if not isinstance(item, dict):
  69. continue
  70. if item.get("role") == "user":
  71. return None
  72. if item.get("type") not in {"function_call_output", "tool_call_output_item"}:
  73. continue
  74. output = item.get("output")
  75. if isinstance(output, str):
  76. return output
  77. if output is not None:
  78. return json.dumps(output)
  79. return None
  def _extract_secret_from_history(items: list[TResponseInputItem]) -> str | None:
      """Return the most recent ``secret code: <value>`` found in *items*.

      Dict items of any role are searched newest first, using the text
      extracted from their ``content``. Returns the captured value of the
      first match, or ``None`` if no item contains one.
      """
      for entry in reversed(items):
          body = _extract_text_from_content(entry.get("content")) if isinstance(entry, dict) else None
          hit = _SECRET_RE.search(body) if isinstance(body, str) else None
          if hit:
              return hit.group("secret")
      return None
  92. def _select_recipient(user_text: str, recipients: list[str]) -> str | None:
  93. lower = user_text.lower()
  94. matches = [recipient for recipient in recipients if recipient.lower() in lower]
  95. if matches:
  96. return max(matches, key=len)
  97. return None
  def _extract_relay_message(user_text: str) -> str:
      """Derive the payload to relay via ``send_message`` from *user_text*.

      Precedence:

      1. If the text mentions "remember", "recall" or "secret code"
         (case-insensitively), the whole stripped text is relayed.
      2. Otherwise the first pattern that matches supplies the payload:
         ``message: ...``, then ``secret code: ...``, then ``handle ...``,
         then ``task ...``.
      3. Failing all of those, the whole stripped text is relayed.

      Because rule 1 already catches "secret code", the secret-code pattern
      in rule 2 only fires in unusual Unicode case-folding situations.
      """
      lowered = user_text.lower()
      if any(marker in lowered for marker in ("remember", "recall", "secret code")):
          return user_text.strip()
      extractors = (
          (_MESSAGE_RE, "message"),
          (_SECRET_RE, "secret"),
          (_HANDLE_RE, "task"),
          (_TASK_RE, "task"),
      )
      for pattern, group_name in extractors:
          found = pattern.search(user_text)
          if found:
              return found.group(group_name).strip()
      return user_text.strip()
  def _build_message_response(text: str, model_name: str) -> ModelResponse:
      """Wrap *text* in a completed assistant-message ``ModelResponse``.

      Usage reports one request, zero input tokens, and the whitespace-split
      word count of *text* (at least 1) as both output and total tokens.
      Fresh ``msg_``/``resp_`` ids are generated on every call.
      *model_name* is not used by this function.
      """
      word_count = max(1, len(text.split()))
      body = ResponseOutputText(text=text, type="output_text", annotations=[], logprobs=[])
      assistant_message = ResponseOutputMessage(
          id=f"msg_{uuid.uuid4().hex}",
          content=[body],
          role="assistant",
          status="completed",
          type="message",
      )
      return ModelResponse(
          output=[assistant_message],
          usage=Usage(
              requests=1,
              input_tokens=0,
              output_tokens=word_count,
              total_tokens=word_count,
              input_tokens_details=InputTokensDetails(cached_tokens=0),
              output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
          ),
          response_id=f"resp_{uuid.uuid4().hex}",
      )
  def _build_tool_call_response(tool_name: str, arguments: dict[str, Any]) -> ModelResponse:
      """Return a ``ModelResponse`` whose only output is a call to *tool_name*.

      *arguments* are JSON-encoded into the call (non-serializable values make
      ``json.dumps`` raise). Usage reports a single request with zero tokens.
      Fresh ``call_``/``fc_``/``resp_`` ids are generated on every call.
      """
      no_tokens = Usage(
          requests=1,
          input_tokens=0,
          output_tokens=0,
          total_tokens=0,
          input_tokens_details=InputTokensDetails(cached_tokens=0),
          output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
      )
      function_call = ResponseFunctionToolCall(
          type="function_call",
          name=tool_name,
          arguments=json.dumps(arguments),
          call_id=f"call_{uuid.uuid4().hex}",
          id=f"fc_{uuid.uuid4().hex}",
          status="completed",
      )
      return ModelResponse(output=[function_call], usage=no_tokens, response_id=f"resp_{uuid.uuid4().hex}")
  async def _stream_text_events(text: str, model_name: str) -> AsyncIterator[TResponseStreamEvent]:
      """Yield a minimal Responses-API event stream for one assistant message.

      Emits, in order: ``response.created``, ``response.output_item.added``,
      ``response.content_part.added``, a single ``response.output_text.delta``
      carrying all of *text*, ``response.output_text.done``,
      ``response.content_part.done``, ``response.output_item.done`` and
      ``response.completed``. Sequence numbers run 0..7 in that order.

      Args:
          text: Complete assistant message to stream.
          model_name: Reported as ``model`` on the created/completed responses.
      """
      resp_id = f"resp_{uuid.uuid4().hex}"
      msg_id = f"msg_{uuid.uuid4().hex}"
      started_at = int(time.time())

      def envelope(output: list[Any], usage: ResponseUsage | None) -> Response:
          # The "created" and "completed" snapshots share everything except
          # the output list and the usage block.
          return Response(
              id=resp_id,
              created_at=started_at,
              model=model_name,
              object="response",
              output=output,
              tool_choice="none",
              tools=[],
              parallel_tool_calls=False,
              usage=usage,
          )

      # Location of the single text part inside the single output message.
      part_location = {"content_index": 0, "item_id": msg_id, "output_index": 0}

      yield ResponseCreatedEvent(
          response=envelope([], None),
          sequence_number=0,
          type="response.created",
      )
      yield ResponseOutputItemAddedEvent(
          item=ResponseOutputMessage(
              id=msg_id,
              content=[],
              role="assistant",
              status="in_progress",
              type="message",
          ),
          output_index=0,
          sequence_number=1,
          type="response.output_item.added",
      )
      yield ResponseContentPartAddedEvent(
          **part_location,
          part=ResponseOutputText(text="", type="output_text", annotations=[], logprobs=[]),
          sequence_number=2,
          type="response.content_part.added",
      )
      yield ResponseTextDeltaEvent(
          **part_location,
          delta=text,
          logprobs=[],
          sequence_number=3,
          type="response.output_text.delta",
      )
      yield ResponseTextDoneEvent(
          **part_location,
          text=text,
          logprobs=[],
          sequence_number=4,
          type="response.output_text.done",
      )

      finished_part = ResponseOutputText(text=text, type="output_text", annotations=[], logprobs=[])
      yield ResponseContentPartDoneEvent(
          **part_location,
          part=finished_part,
          sequence_number=5,
          type="response.content_part.done",
      )

      finished_message = ResponseOutputMessage(
          id=msg_id,
          content=[finished_part],
          role="assistant",
          status="completed",
          type="message",
      )
      yield ResponseOutputItemDoneEvent(
          item=finished_message,
          output_index=0,
          sequence_number=6,
          type="response.output_item.done",
      )

      word_count = max(1, len(text.split()))
      final_usage = ResponseUsage(
          input_tokens=0,
          input_tokens_details=InputTokensDetails(cached_tokens=0),
          output_tokens=word_count,
          output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
          total_tokens=word_count,
      )
      yield ResponseCompletedEvent(
          response=envelope([finished_message], final_usage),
          sequence_number=7,
          type="response.completed",
      )
  class DeterministicModel(Model):
      """Rule-based ``Model`` that answers from the latest turn without any API call.

      Recognizable phrases in the newest user message are mapped to canned
      text replies, tool calls or handoffs (see ``get_response``). Streaming
      always emits the default response.
      """

      def __init__(self, model: str = "test-deterministic", default_response: str = "OK") -> None:
          # Reported as the model name on streamed responses.
          self.model = model
          # Reply used when no rule matches, and the only text ever streamed.
          self._default_response = default_response

      async def get_response(
          self,
          system_instructions: str | None,
          input: str | list[TResponseInputItem],
          model_settings: ModelSettings,
          tools: list[Tool],
          output_schema: AgentOutputSchemaBase | None,
          handoffs: list[Handoff],
          tracing: ModelTracing,
          *,
          previous_response_id: str | None,
          conversation_id: str | None,
          prompt: ResponsePromptParam | None,
      ) -> ModelResponse:
          """Choose a reply, tool call or handoff from *input*.

          Rules are tried in order and the first that fires wins:

          1. A tool output after the latest user message is echoed as text.
          2. ``store_data`` / ``get_data`` are called when those tools exist
             and the user text matches their pattern.
          3. ``send_message`` is called when the text names one of the
             ``recipient_agent`` enum values in that tool's JSON schema.
          4. A handoff is taken when the text names its agent; otherwise the
             first handoff when the text mentions "data agent" or
             "name and age".
          5. "remember" / "recall" / "secret code" reply with the newest secret
             found in list-shaped history (a plain-string input is not
             searched).
          6. "task" / "handle" reply ``TASK_COMPLETED: <user text>``.
          7. Otherwise the configured default response.

          System instructions, settings, output schema, tracing and the
          keyword-only arguments are not consulted.
          """
          echoed = _extract_last_tool_output(input)
          if echoed is not None:
              return _build_message_response(echoed, self.model)

          said = _extract_last_user_text(input) or ""
          said_lower = said.lower()
          available = {tool.name: tool for tool in tools}

          if "store_data" in available:
              found = _STORE_RE.search(said)
              if found:
                  return _build_tool_call_response(
                      "store_data",
                      {"key": found.group("key"), "value": found.group("value")},
                  )

          if "get_data" in available:
              found = _GET_RE.search(said)
              if found:
                  return _build_tool_call_response("get_data", {"key": found.group("key")})

          if "send_message" in available:
              properties = available["send_message"].params_json_schema.get("properties", {})
              declared = properties.get("recipient_agent", {}).get("enum", [])
              target = _select_recipient(said, [name for name in declared if isinstance(name, str)])
              if target:
                  return _build_tool_call_response(
                      "send_message",
                      {
                          "recipient_agent": target,
                          "message": _extract_relay_message(said),
                          "additional_instructions": "",
                      },
                  )

          chosen = next((h for h in handoffs if h.agent_name.lower() in said_lower), None)
          if chosen is None and handoffs and ("data agent" in said_lower or "name and age" in said_lower):
              chosen = handoffs[0]
          if chosen is not None:
              return _build_tool_call_response(chosen.tool_name, {"recipient_agent": chosen.agent_name})

          if "remember" in said_lower or "recall" in said_lower or "secret code" in said_lower:
              secret = _extract_secret_from_history(input if isinstance(input, list) else [])
              if secret:
                  # "remember" takes priority over "recall"/"secret code".
                  if "remember" in said_lower:
                      return _build_message_response(f"REMEMBERED: {secret}", self.model)
                  return _build_message_response(f"RECALLED: {secret}", self.model)

          if "task" in said_lower or "handle" in said_lower:
              return _build_message_response(f"TASK_COMPLETED: {said}", self.model)

          return _build_message_response(self._default_response, self.model)

      def stream_response(
          self,
          system_instructions: str | None,
          input: str | list[TResponseInputItem],
          model_settings: ModelSettings,
          tools: list[Tool],
          output_schema: AgentOutputSchemaBase | None,
          handoffs: list[Handoff],
          tracing: ModelTracing,
          *,
          previous_response_id: str | None,
          conversation_id: str | None,
          prompt: ResponsePromptParam | None,
      ) -> AsyncIterator[TResponseStreamEvent]:
          """Stream the configured default response as Responses-API events.

          NOTE(review): unlike ``get_response``, this ignores *input*, *tools*
          and *handoffs*. The streamed text is always the default response.
          Confirm this is intended before relying on streamed runs in tests.
          """
          return _stream_text_events(self._default_response, self.model)