test_agui_adapter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. import dataclasses
  2. import json
  3. from types import SimpleNamespace
  4. from unittest.mock import MagicMock
  5. import pytest
  6. pytest.importorskip("ag_ui")
  7. from ag_ui.core import (
  8. AssistantMessage,
  9. CustomEvent,
  10. DeveloperMessage,
  11. EventType,
  12. FunctionCall,
  13. MessagesSnapshotEvent,
  14. RawEvent,
  15. TextMessageContentEvent,
  16. TextMessageEndEvent,
  17. TextMessageStartEvent,
  18. ToolCall,
  19. ToolCallArgsEvent,
  20. ToolCallEndEvent,
  21. ToolMessage,
  22. UserMessage,
  23. )
  24. from openai.types.responses import (
  25. ResponseCodeInterpreterToolCall,
  26. ResponseFileSearchToolCall,
  27. ResponseFunctionToolCall,
  28. ResponseOutputMessage,
  29. ResponseOutputText,
  30. )
  31. from openai.types.responses.response_file_search_tool_call import Result as FileSearchResult
  32. from openai.types.responses.response_function_call_arguments_delta_event import ResponseFunctionCallArgumentsDeltaEvent
  33. from openai.types.responses.response_output_item_added_event import ResponseOutputItemAddedEvent
  34. from openai.types.responses.response_output_item_done_event import ResponseOutputItemDoneEvent
  35. from openai.types.responses.response_output_text import AnnotationFileCitation
  36. from openai.types.responses.response_text_delta_event import Logprob, ResponseTextDeltaEvent
  37. from agency_swarm.ui.core.agui_adapter import AguiAdapter
  38. def make_raw_event(data):
  39. return SimpleNamespace(type="raw_response_event", data=data)
  40. def make_stream_event(name, item):
  41. return SimpleNamespace(type="run_item_stream_event", name=name, item=item)
  42. def test_agui_messages_to_chat_history_converts_roles_and_tool_calls():
  43. tool_call = ToolCall(id="call-1", type="function", function=FunctionCall(name="Weather", arguments="{}"))
  44. assistant_msg = AssistantMessage(id="a1", role="assistant", content="Hi", tool_calls=[tool_call])
  45. user_msg = UserMessage(id="u1", role="user", content="Hello")
  46. tool_msg = ToolMessage(id="t1", role="tool", content="Done", tool_call_id="call-1")
  47. dev_msg = DeveloperMessage(id="d1", role="developer", content="Dev note")
  48. history = AguiAdapter().agui_messages_to_chat_history([user_msg, assistant_msg, tool_msg, dev_msg])
  49. assert history[0] == {"role": "user", "content": "Hello"}
  50. assert history[1]["type"] == "function_call"
  51. assert history[1]["name"] == "Weather"
  52. assert history[2] == {"call_id": "call-1", "output": "Done", "type": "function_call_output"}
  53. assert history[3] == {"role": "system", "content": "Dev note"}
  54. def test_agui_messages_to_chat_history_handles_file_search_call():
  55. tool_call = ToolCall(
  56. id="call-99",
  57. type="function",
  58. function=FunctionCall(name="FileSearchTool", arguments='{"queries": ["foo"], "results": ["bar"]}'),
  59. )
  60. assistant_msg = AssistantMessage(id="a2", role="assistant", content="", tool_calls=[tool_call])
  61. history = AguiAdapter().agui_messages_to_chat_history([assistant_msg])
  62. assert history[0]["type"] == "file_search_call"
  63. assert history[0]["queries"] == ["foo"]
  64. assert history[0]["results"] == ["bar"]
  65. def test_agui_messages_to_chat_history_handles_code_interpreter_call():
  66. tool_call = ToolCall(
  67. id="ci-1",
  68. type="function",
  69. function=FunctionCall(
  70. name="CodeInterpreterTool",
  71. arguments='{"code": "print(1)", "container_id": "cid", "outputs": ["1"]}',
  72. ),
  73. )
  74. assistant_msg = AssistantMessage(id="a3", role="assistant", content="", tool_calls=[tool_call])
  75. history = AguiAdapter().agui_messages_to_chat_history([assistant_msg])
  76. assert history[0]["type"] == "code_interpreter_call"
  77. assert history[0]["code"] == "print(1)"
  78. assert history[0]["outputs"] == ["1"]
  79. def test_agui_messages_to_chat_history_handles_plain_assistant_message():
  80. assistant_msg = AssistantMessage(id="a4", role="assistant", content="Result ready", tool_calls=[])
  81. history = AguiAdapter().agui_messages_to_chat_history([assistant_msg])
  82. assert history == [{"role": "assistant", "content": "Result ready"}]
  83. def test_openai_events_emit_message_lifecycle():
  84. adapter = AguiAdapter()
  85. run_id = "run-1"
  86. message = ResponseOutputMessage(
  87. id="m-1",
  88. content=[ResponseOutputText(annotations=[], text="Hi", type="output_text")],
  89. role="assistant",
  90. status="completed",
  91. type="message",
  92. )
  93. start_event = make_raw_event(
  94. ResponseOutputItemAddedEvent(
  95. item=message,
  96. output_index=0,
  97. sequence_number=1,
  98. type="response.output_item.added",
  99. )
  100. )
  101. delta_event = make_raw_event(
  102. ResponseTextDeltaEvent(
  103. content_index=0,
  104. delta="Hi",
  105. item_id="m-1",
  106. logprobs=[],
  107. output_index=0,
  108. sequence_number=2,
  109. type="response.output_text.delta",
  110. )
  111. )
  112. done_event = make_raw_event(
  113. ResponseOutputItemDoneEvent(
  114. item=message,
  115. output_index=0,
  116. sequence_number=3,
  117. type="response.output_item.done",
  118. )
  119. )
  120. start = adapter.openai_to_agui_events(start_event, run_id=run_id)
  121. delta = adapter.openai_to_agui_events(delta_event, run_id=run_id)
  122. done = adapter.openai_to_agui_events(done_event, run_id=run_id)
  123. assert isinstance(start, TextMessageStartEvent)
  124. assert isinstance(delta, TextMessageContentEvent)
  125. assert isinstance(done, TextMessageEndEvent)
  126. assert delta.message_id == "m-1"
  127. def test_openai_events_track_tool_calls_and_arguments():
  128. adapter = AguiAdapter()
  129. run_id = "run-2"
  130. raw_tool = ResponseFunctionToolCall(
  131. arguments="{}",
  132. call_id="call-1",
  133. name="search",
  134. type="function_call",
  135. id="item-1",
  136. status="in_progress",
  137. )
  138. adapter.openai_to_agui_events(
  139. make_raw_event(
  140. ResponseOutputItemAddedEvent(
  141. item=raw_tool,
  142. output_index=0,
  143. sequence_number=1,
  144. type="response.output_item.added",
  145. )
  146. ),
  147. run_id=run_id,
  148. )
  149. args_event = adapter.openai_to_agui_events(
  150. make_raw_event(
  151. ResponseFunctionCallArgumentsDeltaEvent(
  152. item_id="item-1",
  153. delta='{"q": "',
  154. output_index=0,
  155. sequence_number=2,
  156. type="response.function_call_arguments.delta",
  157. )
  158. ),
  159. run_id=run_id,
  160. )
  161. done_events = adapter.openai_to_agui_events(
  162. make_raw_event(
  163. ResponseOutputItemDoneEvent(
  164. type="response.output_item.done",
  165. item=ResponseFunctionToolCall(
  166. arguments='{"q": "weather"}',
  167. call_id="call-1",
  168. name="search",
  169. type="function_call",
  170. id="item-1",
  171. status="completed",
  172. ),
  173. output_index=0,
  174. sequence_number=3,
  175. )
  176. ),
  177. run_id=run_id,
  178. )
  179. assert isinstance(args_event, ToolCallArgsEvent)
  180. assert args_event.tool_call_id == "call-1"
  181. assert isinstance(done_events, list)
  182. assert isinstance(done_events[0], ToolCallEndEvent)
  183. assert isinstance(done_events[1], MessagesSnapshotEvent)
  184. def test_openai_typed_events_emit_message_lifecycle():
  185. adapter = AguiAdapter()
  186. run_id = "typed-run"
  187. message = ResponseOutputMessage(
  188. id="msg-typed",
  189. content=[ResponseOutputText(annotations=[], text="Hello world", type="output_text")],
  190. role="assistant",
  191. status="completed",
  192. type="message",
  193. )
  194. start_event = ResponseOutputItemAddedEvent(
  195. item=message,
  196. output_index=0,
  197. sequence_number=1,
  198. type="response.output_item.added",
  199. )
  200. delta_event = ResponseTextDeltaEvent(
  201. content_index=0,
  202. delta="!",
  203. item_id="msg-typed",
  204. logprobs=[Logprob(token="!", logprob=0.0, top_logprobs=[])],
  205. output_index=0,
  206. sequence_number=2,
  207. type="response.output_text.delta",
  208. )
  209. done_event = ResponseOutputItemDoneEvent(
  210. item=message,
  211. output_index=0,
  212. sequence_number=3,
  213. type="response.output_item.done",
  214. )
  215. start = adapter.openai_to_agui_events(make_raw_event(start_event), run_id=run_id)
  216. delta = adapter.openai_to_agui_events(make_raw_event(delta_event), run_id=run_id)
  217. done = adapter.openai_to_agui_events(make_raw_event(done_event), run_id=run_id)
  218. assert isinstance(start, TextMessageStartEvent)
  219. assert isinstance(delta, TextMessageContentEvent)
  220. assert isinstance(done, TextMessageEndEvent)
  221. assert delta.message_id == "msg-typed"
  222. def test_openai_events_handles_exceptions_with_run_error():
  223. adapter = AguiAdapter()
  224. event = MagicMock()
  225. event.type = "raw_response_event"
  226. type(event).data = property(lambda self: (_ for _ in ()).throw(RuntimeError("boom")))
  227. result = adapter.openai_to_agui_events(event, run_id="oops")
  228. assert result.type == EventType.RUN_ERROR
  229. assert "boom" in result.message
  230. def test_openai_events_ignore_message_without_id():
  231. adapter = AguiAdapter()
  232. event = make_raw_event(
  233. SimpleNamespace(
  234. type="response.output_item.added",
  235. item=SimpleNamespace(type="message", role="assistant", id=None),
  236. )
  237. )
  238. result = adapter.openai_to_agui_events(event, run_id="missing-message")
  239. assert isinstance(result, RawEvent)
  240. assert result.type == EventType.RAW
  241. assert result.event["data"]["type"] == "response.output_item.added"
  242. def test_openai_events_ignore_tool_call_without_call_id():
  243. adapter = AguiAdapter()
  244. run_id = "missing-tool"
  245. tool = SimpleNamespace(type="function_call", id="item-99", call_id=None, name="search", arguments="{}")
  246. adapter.openai_to_agui_events(
  247. make_raw_event(SimpleNamespace(type="response.output_item.added", item=tool)),
  248. run_id=run_id,
  249. )
  250. args_event = adapter.openai_to_agui_events(
  251. make_raw_event(SimpleNamespace(type="response.function_call_arguments.delta", item_id="item-99", delta="{}")),
  252. run_id=run_id,
  253. )
  254. assert isinstance(args_event, RawEvent)
  255. assert args_event.type == EventType.RAW
  256. assert args_event.event["data"]["type"] == "response.function_call_arguments.delta"
  257. def test_openai_events_ignore_text_delta_without_item_id():
  258. adapter = AguiAdapter()
  259. event = make_raw_event(SimpleNamespace(type="response.output_text.delta", item_id=None, delta="Hi"))
  260. result = adapter.openai_to_agui_events(event, run_id="missing-delta-id")
  261. assert isinstance(result, RawEvent)
  262. assert result.type == EventType.RAW
  263. assert result.event["data"]["type"] == "response.output_text.delta"
  264. def test_openai_events_ignore_tool_done_without_call_id():
  265. adapter = AguiAdapter()
  266. raw_item = SimpleNamespace(type="function_call", id="item-9", call_id=None, name="search", arguments="{}")
  267. event = make_raw_event(SimpleNamespace(type="response.output_item.done", item=raw_item))
  268. result = adapter.openai_to_agui_events(event, run_id="tool-done-missing")
  269. assert isinstance(result, RawEvent)
  270. assert result.type == EventType.RAW
  271. assert result.event["data"]["type"] == "response.output_item.done"
  272. def test_run_item_stream_events_emit_snapshots():
  273. adapter = AguiAdapter()
  274. run_id = "run-3"
  275. output_content = ResponseOutputText(annotations=[], text="Answer", type="output_text")
  276. raw_item = ResponseOutputMessage(
  277. id="msg-1",
  278. content=[output_content],
  279. role="assistant",
  280. status="completed",
  281. type="message",
  282. )
  283. item = SimpleNamespace(raw_item=raw_item)
  284. events = adapter.openai_to_agui_events(make_stream_event("message_output_created", item), run_id=run_id)
  285. assert isinstance(events, list)
  286. assert all(isinstance(e, MessagesSnapshotEvent | CustomEvent) for e in events)
  287. assert any(isinstance(e, MessagesSnapshotEvent) for e in events)
  288. def test_run_item_stream_with_annotations_returns_custom_event():
  289. adapter = AguiAdapter()
  290. run_id = "annotated"
  291. annotation = AnnotationFileCitation(file_id="file-annot", filename="doc.pdf", index=1, type="file_citation")
  292. output_content = ResponseOutputText(annotations=[annotation], text="Answer", type="output_text")
  293. raw_item = ResponseOutputMessage(
  294. id="msg-annot",
  295. content=[output_content],
  296. role="assistant",
  297. status="completed",
  298. type="message",
  299. )
  300. item = SimpleNamespace(raw_item=raw_item)
  301. events = adapter.openai_to_agui_events(make_stream_event("message_output_created", item), run_id=run_id)
  302. assert isinstance(events, list)
  303. assert any(isinstance(e, CustomEvent) for e in events)
  304. custom = next(e for e in events if isinstance(e, CustomEvent))
  305. assert custom.value["annotations"] == [annotation.model_dump()]
  306. def test_run_item_stream_ignores_message_without_text():
  307. adapter = AguiAdapter()
  308. run_id = "missing-text"
  309. output_content = SimpleNamespace(text=None, annotations=None)
  310. raw_item = SimpleNamespace(id="msg-empty", content=[output_content])
  311. item = SimpleNamespace(raw_item=raw_item)
  312. result = adapter.openai_to_agui_events(make_stream_event("message_output_created", item), run_id=run_id)
  313. assert isinstance(result, RawEvent)
  314. assert result.type == EventType.RAW
  315. assert result.event["name"] == "message_output_created"
  316. def test_tool_output_stream_event_converts_to_tool_message():
  317. adapter = AguiAdapter()
  318. run_id = "run-4"
  319. item = SimpleNamespace(raw_item={"call_id": "call-7"}, call_id="call-7", output="done")
  320. event = adapter.openai_to_agui_events(make_stream_event("tool_output", item), run_id=run_id)
  321. assert isinstance(event, MessagesSnapshotEvent)
  322. message = event.messages[0]
  323. assert isinstance(message, ToolMessage)
  324. assert message.tool_call_id == "call-7"
  325. assert message.content == "done"
  326. def test_tool_output_without_call_id_is_ignored():
  327. adapter = AguiAdapter()
  328. item = SimpleNamespace(raw_item={}, call_id=None, output="done")
  329. result = adapter.openai_to_agui_events(make_stream_event("tool_output", item), run_id="tool-missing")
  330. assert isinstance(result, RawEvent)
  331. assert result.type == EventType.RAW
  332. assert result.event["name"] == "tool_output"
  333. def test_run_item_stream_unknown_event_is_returned_as_raw_event():
  334. adapter = AguiAdapter()
  335. run_id = "unknown-stream"
  336. unknown_event = make_stream_event("unhandled_event", None)
  337. result = adapter.openai_to_agui_events(unknown_event, run_id=run_id)
  338. assert isinstance(result, RawEvent)
  339. assert result.type == EventType.RAW
  340. assert result.event["name"] == "unhandled_event"
  341. assert result.event["type"] == "run_item_stream_event"
  342. def test_tool_meta_handles_non_function_tools():
  343. adapter = AguiAdapter()
  344. typed_file_search = ResponseFileSearchToolCall(
  345. id="file-1",
  346. queries=["foo"],
  347. status="completed",
  348. type="file_search_call",
  349. results=[FileSearchResult(file_id="doc", text="bar")],
  350. )
  351. typed_code_interpreter = ResponseCodeInterpreterToolCall(
  352. code="print(42)",
  353. container_id="cont",
  354. id="ci-7",
  355. outputs=[{"type": "logs", "logs": "42"}],
  356. type="code_interpreter_call",
  357. status="completed",
  358. )
  359. @dataclasses.dataclass
  360. class LegacyFileSearchCall:
  361. type: str
  362. id: str
  363. queries: list[str]
  364. results: list[dict]
  365. @dataclasses.dataclass
  366. class LegacyCodeInterpreterCall:
  367. type: str
  368. id: str
  369. code: str
  370. container_id: str
  371. outputs: list[dict]
  372. file_search = LegacyFileSearchCall(
  373. type="file_search_call",
  374. id=typed_file_search.id,
  375. queries=typed_file_search.queries or [],
  376. results=json.loads(typed_file_search.model_dump_json())["results"],
  377. )
  378. code_interpreter = LegacyCodeInterpreterCall(
  379. type="code_interpreter_call",
  380. id=typed_code_interpreter.id,
  381. code=typed_code_interpreter.code or "",
  382. container_id=typed_code_interpreter.container_id,
  383. outputs=[{"type": "logs", "logs": "42"}],
  384. )
  385. file_meta = adapter._tool_meta(file_search)
  386. code_meta = adapter._tool_meta(code_interpreter)
  387. assert file_meta[0] == "file-1"
  388. assert file_meta[1] == "FileSearchTool"
  389. assert json.loads(file_meta[2])["queries"] == ["foo"]
  390. assert code_meta[0] == "ci-7"
  391. assert code_meta[1] == "CodeInterpreterTool"
  392. assert json.loads(code_meta[2])["code"] == "print(42)"