test_fastapi_stream_cancellation.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import asyncio
  2. import json
  3. from collections.abc import AsyncGenerator
  4. from typing import Any
  5. import pytest
  6. from agency_swarm.agent.execution_stream_response import StreamingRunResponse
  7. from agency_swarm.integrations.fastapi_utils.endpoint_handlers import (
  8. ActiveRunRegistry,
  9. make_stream_endpoint,
  10. )
  11. from agency_swarm.integrations.fastapi_utils.request_models import BaseRequest
  12. class _StubRequest:
  13. def __init__(self) -> None:
  14. self._disconnected = False
  15. async def is_disconnected(self) -> bool:
  16. return self._disconnected
  17. def disconnect(self) -> None:
  18. self._disconnected = True
  19. class _StubThreadManager:
  20. def __init__(self) -> None:
  21. self._messages: list[dict[str, Any]] = []
  22. def get_all_messages(self) -> list[dict[str, Any]]:
  23. return list(self._messages)
  24. class _StubAgency:
  25. def __init__(self, stream: StreamingRunResponse) -> None:
  26. self.thread_manager = _StubThreadManager()
  27. self.mcp_servers: list[Any] = []
  28. self.agents = {}
  29. self.entry_points = []
  30. self._stream = stream
  31. def get_response_stream(self, **_kwargs: Any) -> StreamingRunResponse:
  32. return self._stream
  33. async def _simple_stream() -> AsyncGenerator[dict[str, Any]]:
  34. yield {"type": "delta", "content": "hello"}
  35. await asyncio.sleep(0)
  36. @pytest.mark.asyncio
  37. async def test_stream_endpoint_cleans_up_on_normal_completion(monkeypatch: pytest.MonkeyPatch) -> None:
  38. """Stream completing normally must remove the run from the registry."""
  39. async def _noop_attach(_agency: Any) -> None:
  40. return None
  41. monkeypatch.setattr(
  42. "agency_swarm.integrations.fastapi_utils.endpoint_handlers.attach_persistent_mcp_servers",
  43. _noop_attach,
  44. )
  45. stream = StreamingRunResponse(_simple_stream())
  46. agency = _StubAgency(stream)
  47. def agency_factory(**_kwargs: Any) -> _StubAgency:
  48. return agency
  49. run_registry = ActiveRunRegistry()
  50. handler = make_stream_endpoint(BaseRequest, agency_factory, lambda: None, run_registry)
  51. http_request = _StubRequest()
  52. request = BaseRequest(message="hi there")
  53. response = await handler(http_request=http_request, request=request, token=None)
  54. generator = response.body_iterator
  55. # Consume all events from the stream
  56. run_id = None
  57. async for event in generator:
  58. # The run_id comes in "event: meta\ndata: {...}" format
  59. for line in event.splitlines():
  60. if line.startswith("data: "):
  61. data_str = line[6:].strip()
  62. if data_str == "[DONE]":
  63. continue
  64. try:
  65. data = json.loads(data_str)
  66. if "run_id" in data:
  67. run_id = data["run_id"]
  68. except json.JSONDecodeError:
  69. pass
  70. assert run_id is not None, "run_id should have been received"
  71. remaining = await run_registry.get(run_id)
  72. assert remaining is None, "Active run registry entry must be removed after stream completes"
  73. @pytest.mark.asyncio
  74. async def test_stream_endpoint_cleans_up_on_disconnect(monkeypatch: pytest.MonkeyPatch) -> None:
  75. """Client disconnect must clean up registry and cancel the stream."""
  76. disconnect_triggered = asyncio.Event()
  77. stream_cancelled = asyncio.Event()
  78. async def _slow_stream() -> AsyncGenerator[dict[str, Any]]:
  79. """Stream that waits for disconnect signal before completing."""
  80. yield {"type": "delta", "content": "first"}
  81. # Wait until we're signaled to check for disconnect
  82. await disconnect_triggered.wait()
  83. # Yield one more to trigger the disconnect check
  84. yield {"type": "delta", "content": "second"}
  85. stream_cancelled.set()
  86. async def _noop_attach(_agency: Any) -> None:
  87. return None
  88. monkeypatch.setattr(
  89. "agency_swarm.integrations.fastapi_utils.endpoint_handlers.attach_persistent_mcp_servers",
  90. _noop_attach,
  91. )
  92. stream = StreamingRunResponse(_slow_stream())
  93. agency = _StubAgency(stream)
  94. def agency_factory(**_kwargs: Any) -> _StubAgency:
  95. return agency
  96. run_registry = ActiveRunRegistry()
  97. handler = make_stream_endpoint(BaseRequest, agency_factory, lambda: None, run_registry)
  98. http_request = _StubRequest()
  99. request = BaseRequest(message="hi there")
  100. response = await handler(http_request=http_request, request=request, token=None)
  101. generator = response.body_iterator
  102. # Get the meta event with run_id
  103. meta_event = await generator.__anext__()
  104. data_line = [line for line in meta_event.splitlines() if line.startswith("data: ")][0]
  105. run_id = json.loads(data_line.split("data: ", 1)[1])["run_id"]
  106. # Verify run is registered
  107. active_run = await run_registry.get(run_id)
  108. assert active_run is not None, "Run should be registered"
  109. # Simulate client disconnect
  110. http_request.disconnect()
  111. disconnect_triggered.set()
  112. # Consume remaining events until stream ends
  113. async for _ in generator:
  114. pass
  115. # Verify cleanup happened
  116. remaining = await run_registry.get(run_id)
  117. assert remaining is None, "Active run registry entry must be removed after disconnect"