test_fastapi_additional_instructions.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. """Integration tests to verify additional_instructions handling in FastAPI endpoints."""
  2. import pytest
  3. from agents.result import RunResult
  4. from agents.run_context import RunContextWrapper
  5. from agents.usage import Usage
  6. from fastapi.testclient import TestClient
  7. from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
  8. from agency_swarm import Agency, Agent, run_fastapi
  9. from agency_swarm.context import MasterContext
  10. from agency_swarm.utils.thread import ThreadManager
  11. def _make_fake_run_result(*, agent: Agent, message: str, final_output: str) -> RunResult:
  12. usage = Usage(
  13. requests=0,
  14. input_tokens=0,
  15. output_tokens=0,
  16. total_tokens=0,
  17. input_tokens_details=InputTokensDetails(cached_tokens=0),
  18. output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
  19. )
  20. wrapper = RunContextWrapper(
  21. context=MasterContext(
  22. thread_manager=ThreadManager(),
  23. agents={agent.name: agent},
  24. user_context={},
  25. agent_runtime_state={},
  26. current_agent_name=agent.name,
  27. shared_instructions=None,
  28. ),
  29. usage=usage,
  30. )
  31. return RunResult(
  32. input=message,
  33. new_items=[],
  34. raw_responses=[],
  35. final_output=final_output,
  36. input_guardrail_results=[],
  37. output_guardrail_results=[],
  38. tool_input_guardrail_results=[],
  39. tool_output_guardrail_results=[],
  40. context_wrapper=wrapper,
  41. _last_agent=agent,
  42. )
  43. def test_non_streaming_additional_instructions(monkeypatch, agency_factory):
  44. """Test that additional_instructions are passed to non-streaming endpoint."""
  45. captured_params = {}
  46. async def fake_get_response(self, message, additional_instructions=None, **kwargs):
  47. captured_params["additional_instructions"] = additional_instructions
  48. return _make_fake_run_result(agent=self, message=message, final_output="Test response")
  49. monkeypatch.setattr(Agent, "get_response", fake_get_response)
  50. app = run_fastapi(agencies={"test_agency": agency_factory}, return_app=True, app_token_env="")
  51. client = TestClient(app)
  52. response = client.post(
  53. "/test_agency/get_response",
  54. json={"message": "Hello", "additional_instructions": "Be very brief"},
  55. )
  56. assert response.status_code == 200
  57. assert captured_params["additional_instructions"] == "Be very brief"
  58. def test_streaming_additional_instructions(monkeypatch, agency_factory):
  59. """Test that additional_instructions are passed to streaming endpoint."""
  60. captured_params = {}
  61. async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
  62. captured_params["additional_instructions"] = additional_instructions
  63. # Yield at least one event
  64. yield {"type": "text", "data": "Test"}
  65. monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)
  66. app = run_fastapi(agencies={"test_agency": agency_factory}, return_app=True, app_token_env="")
  67. client = TestClient(app)
  68. with client.stream(
  69. "POST",
  70. "/test_agency/get_response_stream",
  71. json={"message": "Hello", "additional_instructions": "Be very brief"},
  72. ) as response:
  73. assert response.status_code == 200
  74. # Consume the stream
  75. list(response.iter_lines())
  76. assert captured_params["additional_instructions"] == "Be very brief"
  77. def test_agui_additional_instructions(monkeypatch, agency_factory):
  78. """Test that additional_instructions are passed to AG-UI endpoint."""
  79. captured_params = {}
  80. async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
  81. captured_params["additional_instructions"] = additional_instructions
  82. # Yield at least one event
  83. yield {"type": "text", "data": "Test"}
  84. monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)
  85. app = run_fastapi(
  86. agencies={"test_agency": agency_factory},
  87. return_app=True,
  88. app_token_env="",
  89. enable_agui=True,
  90. )
  91. client = TestClient(app)
  92. agui_payload = {
  93. "thread_id": "test_thread",
  94. "run_id": "test_run",
  95. "state": None,
  96. "messages": [{"id": "msg1", "role": "user", "content": "Hello"}],
  97. "tools": [],
  98. "context": [],
  99. "forwardedProps": None,
  100. "additional_instructions": "Be very brief",
  101. }
  102. with client.stream("POST", "/test_agency/get_response_stream", json=agui_payload) as response:
  103. assert response.status_code == 200
  104. # Consume the stream
  105. list(response.iter_lines())
  106. assert captured_params["additional_instructions"] == "Be very brief"
  107. def test_agui_chat_history_additional_instructions(monkeypatch, agency_factory):
  108. """Test that chat_history works with additional_instructions in AG-UI endpoint."""
  109. captured_params = {}
  110. async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
  111. captured_params["additional_instructions"] = additional_instructions
  112. yield {"type": "text", "data": "Test"}
  113. monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)
  114. app = run_fastapi(
  115. agencies={"test_agency": agency_factory},
  116. return_app=True,
  117. app_token_env="",
  118. enable_agui=True,
  119. )
  120. client = TestClient(app)
  121. agui_payload = {
  122. "thread_id": "test_thread",
  123. "run_id": "test_run",
  124. "state": None,
  125. "messages": [{"id": "msg1", "role": "user", "content": "Hello"}],
  126. "chat_history": [
  127. {
  128. "agent": "TestAgent",
  129. "callerAgent": None,
  130. "timestamp": 0,
  131. "role": "user",
  132. "content": "Hello",
  133. }
  134. ],
  135. "tools": [],
  136. "context": [],
  137. "forwardedProps": None,
  138. "additional_instructions": "Be very brief",
  139. }
  140. with client.stream("POST", "/test_agency/get_response_stream", json=agui_payload) as response:
  141. assert response.status_code == 200
  142. list(response.iter_lines())
  143. assert captured_params["additional_instructions"] == "Be very brief"
  144. def test_additional_instructions_none_handling(monkeypatch, agency_factory):
  145. """Test that None additional_instructions are handled properly."""
  146. captured_params = {}
  147. async def fake_get_response(self, message, additional_instructions=None, **kwargs):
  148. captured_params["additional_instructions"] = additional_instructions
  149. return _make_fake_run_result(agent=self, message=message, final_output="Test response")
  150. monkeypatch.setattr(Agent, "get_response", fake_get_response)
  151. app = run_fastapi(agencies={"test_agency": agency_factory}, return_app=True, app_token_env="")
  152. client = TestClient(app)
  153. # Test without additional_instructions field
  154. response = client.post("/test_agency/get_response", json={"message": "Hello"})
  155. assert response.status_code == 200
  156. assert captured_params["additional_instructions"] is None
  157. @pytest.mark.asyncio
  158. async def test_additional_instructions_real_integration(agency_factory):
  159. """Test with a real agency instance (without mocking) to ensure end-to-end functionality."""
  160. agent = Agent(
  161. name="TestAgent",
  162. instructions="You are a test agent. Follow any additional instructions carefully.",
  163. )
  164. agency = Agency(agent)
  165. # Test that additional_instructions don't break the real call
  166. response = await agency.get_response(message="Say hello", additional_instructions="Keep it under 10 words")
  167. # Verify we get a response (even if we can't verify the LLM actually followed the instructions)
  168. assert response.final_output is not None
  169. assert isinstance(response.final_output, str)
  170. assert len(response.final_output) > 0
  171. if __name__ == "__main__":
  172. # Allow direct execution for debugging
  173. pytest.main([__file__, "-v"])