| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import asyncio
- import json
- from collections.abc import AsyncGenerator
- from typing import Any
- import pytest
- from agency_swarm.agent.execution_stream_response import StreamingRunResponse
- from agency_swarm.integrations.fastapi_utils.endpoint_handlers import (
- ActiveRunRegistry,
- make_stream_endpoint,
- )
- from agency_swarm.integrations.fastapi_utils.request_models import BaseRequest
class _StubRequest:
    """Stand-in for an HTTP request that only models the client-disconnect flag."""

    def __init__(self) -> None:
        # Starts out "connected"; flipped by ``disconnect``.
        self._client_gone: bool = False

    def disconnect(self) -> None:
        """Mark the simulated client as disconnected."""
        self._client_gone = True

    async def is_disconnected(self) -> bool:
        """Return ``True`` once ``disconnect`` has been called, ``False`` before."""
        return self._client_gone
class _StubThreadManager:
    """Thread-manager stub backed by an in-memory message list that starts empty."""

    def __init__(self) -> None:
        self._messages: list[dict[str, Any]] = []

    def get_all_messages(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the stored messages, not the internal list itself."""
        return self._messages.copy()
class _StubAgency:
    """Agency stand-in that hands back a pre-built stream from ``get_response_stream``."""

    def __init__(self, stream: StreamingRunResponse) -> None:
        # The stream is fixed at construction time and returned unchanged later.
        self._stream = stream
        self.agents = {}
        self.entry_points = []
        self.mcp_servers: list[Any] = []
        self.thread_manager = _StubThreadManager()

    def get_response_stream(self, **_kwargs: Any) -> StreamingRunResponse:
        """Return the stream supplied to the constructor; keyword arguments are ignored."""
        return self._stream
async def _simple_stream() -> AsyncGenerator[dict[str, Any]]:
    """Yield a single ``delta`` event, then give control back to the event loop once."""
    hello_event: dict[str, Any] = {"type": "delta", "content": "hello"}
    yield hello_event
    await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_stream_endpoint_cleans_up_on_normal_completion(monkeypatch: pytest.MonkeyPatch) -> None:
    """A stream that runs to completion must leave no entry in the run registry."""

    async def _skip_mcp_attach(_agency: Any) -> None:
        return None

    # Swap the endpoint module's attach_persistent_mcp_servers for a no-op coroutine.
    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.endpoint_handlers.attach_persistent_mcp_servers",
        _skip_mcp_attach,
    )

    stub_agency = _StubAgency(StreamingRunResponse(_simple_stream()))

    def agency_factory(**_kwargs: Any) -> _StubAgency:
        return stub_agency

    run_registry = ActiveRunRegistry()
    handler = make_stream_endpoint(BaseRequest, agency_factory, lambda: None, run_registry)

    response = await handler(
        http_request=_StubRequest(),
        request=BaseRequest(message="hi there"),
        token=None,
    )

    # Drain the whole stream, remembering the last run_id seen in any
    # "data: {...}" line (the run_id arrives in "event: meta\ndata: {...}" form).
    run_id = None
    async for sse_chunk in response.body_iterator:
        for raw_line in sse_chunk.splitlines():
            if not raw_line.startswith("data: "):
                continue
            payload_text = raw_line.removeprefix("data: ").strip()
            if payload_text == "[DONE]":
                continue
            try:
                payload = json.loads(payload_text)
            except json.JSONDecodeError:
                # Non-JSON data lines are ignored on purpose.
                continue
            if "run_id" in payload:
                run_id = payload["run_id"]

    assert run_id is not None, "run_id should have been received"

    remaining = await run_registry.get(run_id)
    assert remaining is None, "Active run registry entry must be removed after stream completes"
@pytest.mark.asyncio
async def test_stream_endpoint_cleans_up_on_disconnect(monkeypatch: pytest.MonkeyPatch) -> None:
    """Client disconnect mid-stream must remove the run from the registry.

    Steps: read the first streamed event to learn the ``run_id``, confirm the run
    is registered, mark the stub request as disconnected, release the source
    stream, drain the response, and check that the registry entry is gone.
    """
    disconnect_triggered = asyncio.Event()

    async def _slow_stream() -> AsyncGenerator[dict[str, Any]]:
        """Yield one event, wait for the disconnect signal, then yield one more."""
        yield {"type": "delta", "content": "first"}
        # Hold the stream open until the test has marked the client as disconnected.
        await disconnect_triggered.wait()
        # Yield one more to trigger the disconnect check
        yield {"type": "delta", "content": "second"}

    async def _noop_attach(_agency: Any) -> None:
        return None

    # Swap the endpoint module's attach_persistent_mcp_servers for a no-op coroutine.
    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.endpoint_handlers.attach_persistent_mcp_servers",
        _noop_attach,
    )

    stream = StreamingRunResponse(_slow_stream())
    agency = _StubAgency(stream)

    def agency_factory(**_kwargs: Any) -> _StubAgency:
        return agency

    run_registry = ActiveRunRegistry()
    handler = make_stream_endpoint(BaseRequest, agency_factory, lambda: None, run_registry)

    http_request = _StubRequest()
    request = BaseRequest(message="hi there")
    response = await handler(http_request=http_request, request=request, token=None)
    generator = response.body_iterator

    # Get the meta event with run_id; fail with a clear message (instead of a
    # bare IndexError) if the first event carries no data line.
    meta_event = await generator.__anext__()
    data_line = next(
        (line for line in meta_event.splitlines() if line.startswith("data: ")),
        None,
    )
    assert data_line is not None, "First streamed event should contain a data line with the run_id"
    run_id = json.loads(data_line.split("data: ", 1)[1])["run_id"]

    # Verify run is registered
    active_run = await run_registry.get(run_id)
    assert active_run is not None, "Run should be registered"

    # Simulate client disconnect, then let the source stream proceed.
    http_request.disconnect()
    disconnect_triggered.set()

    # Consume remaining events until stream ends
    async for _ in generator:
        pass

    # Verify cleanup happened
    remaining = await run_registry.get(run_id)
    assert remaining is None, "Active run registry entry must be removed after disconnect"
|