test_stream_id_normalization.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from collections.abc import AsyncGenerator
  2. from typing import Any, cast
  3. import pytest
  4. from agents.items import MessageOutputItem
  5. from agents.models.fake_id import FAKE_RESPONSES_ID
  6. from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
  7. from openai.types.responses import ResponseFunctionToolCall, ResponseOutputMessage, ResponseOutputText
  8. from openai.types.responses.response_function_call_arguments_delta_event import (
  9. ResponseFunctionCallArgumentsDeltaEvent,
  10. )
  11. from openai.types.responses.response_output_item_added_event import ResponseOutputItemAddedEvent
  12. from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
  13. from agency_swarm import Agent
  14. @pytest.mark.asyncio
  15. async def test_agent_stream_rewrites_fake_ids_in_raw_and_run_item_events(
  16. monkeypatch: pytest.MonkeyPatch,
  17. ) -> None:
  18. """Streaming must not expose `id/item_id=__fake_id__` on LiteLLM/ChatCompletions surfaces."""
  19. async def dummy_stream_events(agent: Agent) -> AsyncGenerator[Any]:
  20. yield RawResponsesStreamEvent(
  21. data=ResponseTextDeltaEvent(
  22. content_index=0,
  23. delta="A",
  24. item_id=FAKE_RESPONSES_ID,
  25. logprobs=[],
  26. output_index=0,
  27. sequence_number=1,
  28. type="response.output_text.delta",
  29. )
  30. )
  31. yield RunItemStreamEvent(
  32. name="message_output_created",
  33. item=MessageOutputItem(
  34. agent=agent,
  35. raw_item=ResponseOutputMessage(
  36. id=FAKE_RESPONSES_ID,
  37. content=[ResponseOutputText(text="A", type="output_text", annotations=[])],
  38. role="assistant",
  39. status="completed",
  40. type="message",
  41. ),
  42. ),
  43. )
  44. class DummyStreamedResult:
  45. def __init__(self, agent: Agent) -> None:
  46. self._agent = agent
  47. def stream_events(self):
  48. return dummy_stream_events(self._agent)
  49. def run_streamed_stub(*_args: Any, **kwargs: Any) -> DummyStreamedResult:
  50. return DummyStreamedResult(cast(Agent, kwargs["starting_agent"]))
  51. monkeypatch.setattr("agents.Runner.run_streamed", run_streamed_stub)
  52. agent = Agent(name="TestAgent", instructions="noop")
  53. events = [event async for event in agent.get_response_stream("hi")]
  54. raw_event = events[0]
  55. assert getattr(raw_event, "type", None) == "raw_response_event"
  56. assert hasattr(raw_event, "agent") and raw_event.agent == "TestAgent"
  57. assert raw_event.data.item_id != FAKE_RESPONSES_ID
  58. stable_id = raw_event.data.item_id
  59. assert getattr(raw_event, "item_id", None) == stable_id
  60. run_item_event = events[1]
  61. assert getattr(run_item_event, "type", None) == "run_item_stream_event"
  62. assert run_item_event.name == "message_output_created"
  63. assert hasattr(run_item_event, "agent") and run_item_event.agent == "TestAgent"
  64. assert run_item_event.item.raw_item.id == stable_id
  65. assert getattr(run_item_event, "item_id", None) == stable_id
  66. @pytest.mark.asyncio
  67. async def test_agent_stream_rewrites_tool_argument_delta_item_id_to_call_id(
  68. monkeypatch: pytest.MonkeyPatch,
  69. ) -> None:
  70. """Tool arg deltas must correlate via call_id, not the placeholder item_id."""
  71. async def dummy_stream_events() -> AsyncGenerator[Any]:
  72. tool_call = ResponseFunctionToolCall(
  73. arguments="{}",
  74. call_id="call_1",
  75. name="Tool",
  76. type="function_call",
  77. id=FAKE_RESPONSES_ID,
  78. status="in_progress",
  79. )
  80. yield RawResponsesStreamEvent(
  81. data=ResponseOutputItemAddedEvent(
  82. item=tool_call,
  83. output_index=0,
  84. sequence_number=1,
  85. type="response.output_item.added",
  86. )
  87. )
  88. yield RawResponsesStreamEvent(
  89. data=ResponseFunctionCallArgumentsDeltaEvent(
  90. item_id=FAKE_RESPONSES_ID,
  91. delta='{"x": 1}',
  92. output_index=0,
  93. sequence_number=2,
  94. type="response.function_call_arguments.delta",
  95. )
  96. )
  97. class DummyStreamedResult:
  98. def stream_events(self):
  99. return dummy_stream_events()
  100. monkeypatch.setattr("agents.Runner.run_streamed", lambda *_a, **_k: DummyStreamedResult())
  101. agent = Agent(name="ToolAgent", instructions="noop")
  102. events = [event async for event in agent.get_response_stream("hi")]
  103. output_item = events[0].data.item
  104. assert output_item.id == "call_1"
  105. args_delta = events[1].data
  106. assert args_delta.item_id == "call_1"
  107. assert getattr(events[1], "item_id", None) == "call_1"