test_streaming_final_result.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import asyncio
  2. import threading
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from agency_swarm import Agent
  6. def _build_simple_agent(name: str = "TestAgent") -> Agent:
  7. return Agent(
  8. name=name,
  9. instructions="Return deterministic streaming output for test validation.",
  10. description="Minimal agent used to exercise streaming wrappers.",
  11. )
  12. def test_wait_final_result_without_event_loop(monkeypatch):
  13. """Streaming wrapper must resolve to None when created before an event loop starts."""
  14. from agency_swarm.agent.execution_streaming import StreamingRunResponse
  15. async def _empty_stream():
  16. if False: # pragma: no cover
  17. yield
  18. def _stubbed_run_stream(**_kwargs):
  19. wrapper = StreamingRunResponse(_empty_stream())
  20. wrapper._resolve_final_result(None)
  21. return wrapper
  22. monkeypatch.setattr("agency_swarm.agent.execution.run_stream_with_guardrails", _stubbed_run_stream)
  23. agent = _build_simple_agent()
  24. stream = agent.get_response_stream("Trigger stream")
  25. async def _drive_stream() -> None:
  26. async for _ in stream:
  27. pass
  28. async def _await_result() -> None:
  29. result = await asyncio.wait_for(stream.wait_final_result(), timeout=0.5)
  30. assert result is None
  31. async def _run() -> None:
  32. await asyncio.gather(_drive_stream(), _await_result())
  33. asyncio.run(_run())
  34. @pytest.mark.asyncio
  35. async def test_wait_final_result_before_adoption(monkeypatch):
  36. """Awaiting wait_final_result before iterating events must resolve once the inner stream finishes."""
  37. from agency_swarm.agent.execution_streaming import StreamingRunResponse
  38. final_result = MagicMock()
  39. async def _single_event_stream():
  40. yield {"type": "test_event"}
  41. def _stubbed_run_stream(**_kwargs):
  42. wrapper = StreamingRunResponse(_single_event_stream())
  43. wrapper._resolve_final_result(final_result)
  44. return wrapper
  45. monkeypatch.setattr("agency_swarm.agent.execution.run_stream_with_guardrails", _stubbed_run_stream)
  46. agent = _build_simple_agent("PreAdoptionAgent")
  47. stream = agent.get_response_stream("Trigger stream")
  48. wait_task = asyncio.create_task(asyncio.wait_for(stream.wait_final_result(), timeout=0.5))
  49. events = []
  50. async for event in stream:
  51. events.append(event)
  52. result = await wait_task
  53. assert events == [{"type": "test_event"}]
  54. assert result is final_result
  55. @pytest.mark.asyncio
  56. async def test_adopt_stream_syncs_futures_across_event_loops():
  57. """Adopting a stream must safely synchronize final futures from another loop."""
  58. from agency_swarm.agent.execution_streaming import StreamingRunResponse
  59. async def _empty_stream():
  60. if False: # pragma: no cover
  61. yield
  62. external_loop_ready = threading.Event()
  63. external_loop = asyncio.new_event_loop()
  64. def _loop_runner() -> None:
  65. asyncio.set_event_loop(external_loop)
  66. external_loop_ready.set()
  67. external_loop.run_forever()
  68. runner_thread = threading.Thread(target=_loop_runner, daemon=True)
  69. runner_thread.start()
  70. external_loop_ready.wait()
  71. async def _create_future() -> asyncio.Future[object | None]:
  72. loop = asyncio.get_running_loop()
  73. return loop.create_future()
  74. try:
  75. outer_future = asyncio.run_coroutine_threadsafe(_create_future(), external_loop).result(timeout=1)
  76. outer_wrapper = StreamingRunResponse(_empty_stream())
  77. outer_wrapper._final_future = outer_future
  78. inner_wrapper = StreamingRunResponse(_empty_stream())
  79. inner_wrapper._final_future = asyncio.get_running_loop().create_future()
  80. outer_wrapper._adopt_stream(inner_wrapper)
  81. completion = threading.Event()
  82. external_loop.call_soon_threadsafe(outer_future.add_done_callback, lambda _fut: completion.set())
  83. sentinel = object()
  84. inner_wrapper._final_future.set_result(sentinel)
  85. assert await asyncio.to_thread(completion.wait, 1.0)
  86. async def _await_external() -> object | None:
  87. return await outer_future
  88. result = asyncio.run_coroutine_threadsafe(_await_external(), external_loop).result(timeout=1)
  89. assert result is sentinel
  90. finally:
  91. external_loop.call_soon_threadsafe(external_loop.stop)
  92. runner_thread.join(timeout=1)
  93. external_loop.close()