test_fastapi_user_context.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. from __future__ import annotations
  2. import json
  3. import typing
  4. from copy import deepcopy
  5. from dataclasses import dataclass, field
  6. import pytest
  7. from agents.result import RunResult, RunResultStreaming
  8. from agents.run_context import RunContextWrapper
  9. from agents.usage import Usage
  10. from fastapi.testclient import TestClient
  11. from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
  12. from agency_swarm import Agency, Agent, run_fastapi
  13. from agency_swarm.agent.execution_stream_response import StreamingRunResponse
  14. from agency_swarm.context import MasterContext
  15. from agency_swarm.streaming.utils import StreamingContext
  16. from agency_swarm.utils.thread import ThreadManager
  17. class _HasMainAgentModel(typing.Protocol):
  18. _main_agent_model: str
  19. @dataclass
  20. class ContextTracker:
  21. """Keep the latest contexts observed by the test agent."""
  22. last_response_context: dict[str, str] | None = None
  23. last_stream_context: dict[str, str | StreamingContext] | None = None
  24. def reset(self) -> None:
  25. self.last_response_context = None
  26. self.last_stream_context = None
  27. def record_response(self, context: dict[str, str] | None) -> None:
  28. self.last_response_context = deepcopy(context) if context is not None else None
  29. def record_stream(self, context: dict[str, str | StreamingContext] | None) -> None:
  30. self.last_stream_context = deepcopy(context) if context is not None else None
  31. class TrackingAgent(Agent):
  32. """Agent subclass that records incoming context instead of calling the LLM."""
  33. def __init__(self, tracker: ContextTracker):
  34. super().__init__(name="TestAgent", instructions="Base instructions")
  35. self._tracker = tracker
  36. async def get_response(
  37. self,
  38. message,
  39. sender_name=None,
  40. context_override: dict[str, str] | None = None,
  41. **kwargs: str | int | float | bool | None | list[str],
  42. ):
  43. self._tracker.record_response(context_override)
  44. usage = Usage(
  45. requests=1,
  46. input_tokens=10,
  47. output_tokens=20,
  48. total_tokens=30,
  49. input_tokens_details=InputTokensDetails(cached_tokens=0),
  50. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  51. )
  52. thread_manager = ThreadManager()
  53. master_context = MasterContext(
  54. thread_manager=thread_manager,
  55. agents={self.name: self},
  56. user_context=context_override or {},
  57. agent_runtime_state={},
  58. current_agent_name=self.name,
  59. shared_instructions=None,
  60. )
  61. run_result = RunResult(
  62. input=str(message),
  63. new_items=[],
  64. raw_responses=[],
  65. final_output="Test response",
  66. input_guardrail_results=[],
  67. output_guardrail_results=[],
  68. tool_input_guardrail_results=[],
  69. tool_output_guardrail_results=[],
  70. context_wrapper=RunContextWrapper(context=master_context, usage=usage),
  71. _last_agent=self,
  72. )
  73. # Enables cost fallback calculation in calculate_usage_with_cost(...)
  74. typing.cast(_HasMainAgentModel, run_result)._main_agent_model = "gpt-5.4-mini"
  75. return run_result
  76. def get_response_stream(
  77. self,
  78. message,
  79. sender_name=None,
  80. context_override: dict[str, str | StreamingContext] | None = None,
  81. **kwargs: str | int | float | bool | None | list[str],
  82. ):
  83. self._tracker.record_stream(context_override)
  84. usage = Usage(
  85. requests=1,
  86. input_tokens=10,
  87. output_tokens=20,
  88. total_tokens=30,
  89. input_tokens_details=InputTokensDetails(cached_tokens=0),
  90. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  91. )
  92. thread_manager = ThreadManager()
  93. master_context = MasterContext(
  94. thread_manager=thread_manager,
  95. agents={self.name: self},
  96. user_context=(context_override or {}),
  97. agent_runtime_state={},
  98. current_agent_name=self.name,
  99. shared_instructions=None,
  100. )
  101. final_result = RunResultStreaming(
  102. input=str(message),
  103. new_items=[],
  104. raw_responses=[],
  105. final_output="Test response",
  106. input_guardrail_results=[],
  107. output_guardrail_results=[],
  108. tool_input_guardrail_results=[],
  109. tool_output_guardrail_results=[],
  110. context_wrapper=RunContextWrapper(context=master_context, usage=usage),
  111. current_agent=self,
  112. current_turn=1,
  113. max_turns=1,
  114. _current_agent_output_schema=None,
  115. trace=None,
  116. )
  117. typing.cast(_HasMainAgentModel, final_result)._main_agent_model = "gpt-5.4-mini"
  118. stream_ref: dict[str, StreamingRunResponse] = {}
  119. async def _generator():
  120. yield {"type": "text", "data": "Test"}
  121. # Make final_result available to the FastAPI endpoint handler.
  122. stream_ref["stream"]._resolve_final_result(final_result) # noqa: SLF001
  123. stream = StreamingRunResponse(_generator())
  124. stream_ref["stream"] = stream
  125. return stream
  126. @dataclass
  127. class RecordingAgencyFactory:
  128. """Factory that creates agencies with context-tracking agents."""
  129. tracker: ContextTracker = field(default_factory=ContextTracker)
  130. def __call__(self, load_threads_callback=None, save_threads_callback=None):
  131. self.tracker.reset()
  132. agent = TrackingAgent(self.tracker)
  133. return Agency(
  134. agent,
  135. load_threads_callback=load_threads_callback,
  136. save_threads_callback=save_threads_callback,
  137. )
  138. @pytest.fixture
  139. def recording_agency_factory() -> RecordingAgencyFactory:
  140. return RecordingAgencyFactory()
  141. def test_non_streaming_user_context(recording_agency_factory: RecordingAgencyFactory):
  142. """Ensure user_context is forwarded to non-streaming endpoint."""
  143. app = run_fastapi(agencies={"test_agency": recording_agency_factory}, return_app=True, app_token_env="")
  144. client = TestClient(app)
  145. response = client.post(
  146. "/test_agency/get_response",
  147. json={"message": "Hello", "user_context": {"plan": "pro", "user_id": "123"}},
  148. )
  149. assert response.status_code == 200
  150. payload = response.json()
  151. assert "usage" in payload
  152. usage = payload["usage"]
  153. assert usage["request_count"] == 1
  154. assert usage["input_tokens"] == 10
  155. assert usage["output_tokens"] == 20
  156. assert usage["total_tokens"] == 30
  157. assert isinstance(usage["total_cost"], int | float)
  158. assert recording_agency_factory.tracker.last_response_context == {"plan": "pro", "user_id": "123"}
  159. def test_streaming_user_context(recording_agency_factory: RecordingAgencyFactory):
  160. """Ensure user_context is forwarded to streaming endpoint."""
  161. app = run_fastapi(agencies={"test_agency": recording_agency_factory}, return_app=True, app_token_env="")
  162. client = TestClient(app)
  163. with client.stream(
  164. "POST",
  165. "/test_agency/get_response_stream",
  166. json={"message": "Hello", "user_context": {"plan": "pro"}},
  167. ) as response:
  168. assert response.status_code == 200
  169. lines = list(response.iter_lines())
  170. stream_context = recording_agency_factory.tracker.last_stream_context
  171. assert stream_context is not None
  172. assert {k: v for k, v in stream_context.items() if k != "streaming_context"} == {"plan": "pro"}
  173. assert "streaming_context" in stream_context
  174. messages_payload = _extract_last_messages_payload(lines)
  175. usage = messages_payload["usage"]
  176. assert usage["request_count"] == 1
  177. assert usage["input_tokens"] == 10
  178. assert usage["output_tokens"] == 20
  179. assert usage["total_tokens"] == 30
  180. assert isinstance(usage["total_cost"], int | float)
  181. def test_agui_user_context(recording_agency_factory: RecordingAgencyFactory):
  182. """Ensure AG-UI streaming endpoint forwards user_context."""
  183. app = run_fastapi(
  184. agencies={"test_agency": recording_agency_factory},
  185. return_app=True,
  186. app_token_env="",
  187. enable_agui=True,
  188. )
  189. client = TestClient(app)
  190. agui_payload = {
  191. "thread_id": "test_thread",
  192. "run_id": "test_run",
  193. "state": None,
  194. "messages": [{"id": "msg1", "role": "user", "content": "Hello"}],
  195. "tools": [],
  196. "context": [],
  197. "forwardedProps": None,
  198. "user_context": {"plan": "pro", "customer_tier": "gold"},
  199. }
  200. with client.stream("POST", "/test_agency/get_response_stream", json=agui_payload) as response:
  201. assert response.status_code == 200
  202. list(response.iter_lines())
  203. stream_context = recording_agency_factory.tracker.last_stream_context
  204. assert stream_context is not None
  205. assert {k: v for k, v in stream_context.items() if k != "streaming_context"} == {
  206. "plan": "pro",
  207. "customer_tier": "gold",
  208. }
  209. assert "streaming_context" in stream_context
  210. def test_user_context_defaults_to_none(recording_agency_factory: RecordingAgencyFactory):
  211. """Requests without user_context should not inject overrides."""
  212. app = run_fastapi(agencies={"test_agency": recording_agency_factory}, return_app=True, app_token_env="")
  213. client = TestClient(app)
  214. response = client.post("/test_agency/get_response", json={"message": "Hello"})
  215. assert response.status_code == 200
  216. assert recording_agency_factory.tracker.last_response_context is None
  217. def _extract_last_messages_payload(lines: list[bytes | str]) -> dict[str, object]:
  218. """Return the last SSE `event: messages` payload as a dict."""
  219. current_event: str | None = None
  220. messages_payloads: list[dict[str, object]] = []
  221. for raw in lines:
  222. if not raw:
  223. continue
  224. line = raw.decode("utf-8") if isinstance(raw, bytes | bytearray) else raw
  225. if line.startswith("event:"):
  226. current_event = line.split("event:", 1)[1].strip()
  227. continue
  228. if not line.startswith("data:"):
  229. continue
  230. data_str = line.split("data:", 1)[1].strip()
  231. if data_str == "[DONE]":
  232. continue
  233. if current_event != "messages":
  234. continue
  235. payload = json.loads(data_str)
  236. if isinstance(payload, dict):
  237. messages_payloads.append(payload)
  238. assert messages_payloads, "Expected a final 'messages' SSE event payload"
  239. return messages_payloads[-1]