| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- """Integration tests to verify additional_instructions handling in FastAPI endpoints."""
- import pytest
- from agents.result import RunResult
- from agents.run_context import RunContextWrapper
- from agents.usage import Usage
- from fastapi.testclient import TestClient
- from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
- from agency_swarm import Agency, Agent, run_fastapi
- from agency_swarm.context import MasterContext
- from agency_swarm.utils.thread import ThreadManager
def _make_fake_run_result(*, agent: Agent, message: str, final_output: str) -> RunResult:
    """Assemble a stub ``RunResult`` so tests can bypass a real agent run.

    Args:
        agent: Agent recorded as the last agent and registered in the context.
        message: Value stored as the result's ``input``.
        final_output: Value stored as the result's ``final_output``.

    Returns:
        A ``RunResult`` with zeroed usage counters and empty item, response,
        and guardrail lists.
    """
    zero_usage = Usage(
        requests=0,
        input_tokens=0,
        output_tokens=0,
        total_tokens=0,
        input_tokens_details=InputTokensDetails(cached_tokens=0),
        output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
    )
    master_context = MasterContext(
        thread_manager=ThreadManager(),
        agents={agent.name: agent},
        user_context={},
        agent_runtime_state={},
        current_agent_name=agent.name,
        shared_instructions=None,
    )
    context_wrapper = RunContextWrapper(context=master_context, usage=zero_usage)

    # Each guardrail field gets its own fresh empty list.
    empty_guardrails = {
        field_name: []
        for field_name in (
            "input_guardrail_results",
            "output_guardrail_results",
            "tool_input_guardrail_results",
            "tool_output_guardrail_results",
        )
    }
    return RunResult(
        input=message,
        new_items=[],
        raw_responses=[],
        final_output=final_output,
        context_wrapper=context_wrapper,
        _last_agent=agent,
        **empty_guardrails,
    )
def test_non_streaming_additional_instructions(monkeypatch, agency_factory):
    """Check the non-streaming endpoint forwards ``additional_instructions``."""
    seen = {}

    async def fake_get_response(self, message, additional_instructions=None, **kwargs):
        # Record what the endpoint handed over, then return a canned result.
        seen.update(additional_instructions=additional_instructions)
        return _make_fake_run_result(agent=self, message=message, final_output="Test response")

    monkeypatch.setattr(Agent, "get_response", fake_get_response)

    agencies = {"test_agency": agency_factory}
    app = run_fastapi(agencies=agencies, return_app=True, app_token_env="")
    http = TestClient(app)

    request_body = {"message": "Hello", "additional_instructions": "Be very brief"}
    reply = http.post("/test_agency/get_response", json=request_body)

    assert reply.status_code == 200
    assert seen["additional_instructions"] == "Be very brief"
def test_streaming_additional_instructions(monkeypatch, agency_factory):
    """Check the streaming endpoint forwards ``additional_instructions``."""
    seen = {}

    async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
        seen.update(additional_instructions=additional_instructions)
        # Emit a single event so the stream is not empty.
        yield {"type": "text", "data": "Test"}

    monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)

    agencies = {"test_agency": agency_factory}
    app = run_fastapi(agencies=agencies, return_app=True, app_token_env="")
    http = TestClient(app)

    request_body = {"message": "Hello", "additional_instructions": "Be very brief"}
    with http.stream("POST", "/test_agency/get_response_stream", json=request_body) as reply:
        assert reply.status_code == 200
        # Drain the stream to completion.
        for _line in reply.iter_lines():
            pass

    assert seen["additional_instructions"] == "Be very brief"
def test_agui_additional_instructions(monkeypatch, agency_factory):
    """Check the AG-UI streaming endpoint forwards ``additional_instructions``."""
    seen = {}

    async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
        seen.update(additional_instructions=additional_instructions)
        # Emit a single event so the stream is not empty.
        yield {"type": "text", "data": "Test"}

    monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)

    agencies = {"test_agency": agency_factory}
    app = run_fastapi(agencies=agencies, return_app=True, app_token_env="", enable_agui=True)
    http = TestClient(app)

    user_message = {"id": "msg1", "role": "user", "content": "Hello"}
    request_body = {
        "thread_id": "test_thread",
        "run_id": "test_run",
        "state": None,
        "messages": [user_message],
        "tools": [],
        "context": [],
        "forwardedProps": None,
        "additional_instructions": "Be very brief",
    }

    with http.stream("POST", "/test_agency/get_response_stream", json=request_body) as reply:
        assert reply.status_code == 200
        # Drain the stream to completion.
        for _line in reply.iter_lines():
            pass

    assert seen["additional_instructions"] == "Be very brief"
def test_agui_chat_history_additional_instructions(monkeypatch, agency_factory):
    """Check an AG-UI request carrying ``chat_history`` still forwards ``additional_instructions``."""
    seen = {}

    async def fake_get_response_stream(self, message, additional_instructions=None, **kwargs):
        seen.update(additional_instructions=additional_instructions)
        yield {"type": "text", "data": "Test"}

    monkeypatch.setattr(Agent, "get_response_stream", fake_get_response_stream)

    agencies = {"test_agency": agency_factory}
    app = run_fastapi(agencies=agencies, return_app=True, app_token_env="", enable_agui=True)
    http = TestClient(app)

    user_message = {"id": "msg1", "role": "user", "content": "Hello"}
    history_entry = {
        "agent": "TestAgent",
        "callerAgent": None,
        "timestamp": 0,
        "role": "user",
        "content": "Hello",
    }
    request_body = {
        "thread_id": "test_thread",
        "run_id": "test_run",
        "state": None,
        "messages": [user_message],
        "chat_history": [history_entry],
        "tools": [],
        "context": [],
        "forwardedProps": None,
        "additional_instructions": "Be very brief",
    }

    with http.stream("POST", "/test_agency/get_response_stream", json=request_body) as reply:
        assert reply.status_code == 200
        for _line in reply.iter_lines():
            pass

    assert seen["additional_instructions"] == "Be very brief"
def test_additional_instructions_none_handling(monkeypatch, agency_factory):
    """Check that omitting ``additional_instructions`` from the request yields ``None``."""
    seen = {}

    async def fake_get_response(self, message, additional_instructions=None, **kwargs):
        seen.update(additional_instructions=additional_instructions)
        return _make_fake_run_result(agent=self, message=message, final_output="Test response")

    monkeypatch.setattr(Agent, "get_response", fake_get_response)

    agencies = {"test_agency": agency_factory}
    app = run_fastapi(agencies=agencies, return_app=True, app_token_env="")
    http = TestClient(app)

    # The request body deliberately has no additional_instructions field.
    request_body = {"message": "Hello"}
    reply = http.post("/test_agency/get_response", json=request_body)

    assert reply.status_code == 200
    assert seen["additional_instructions"] is None
@pytest.mark.asyncio
async def test_additional_instructions_real_integration(agency_factory):
    """Run an unmocked ``Agency.get_response`` call with ``additional_instructions`` set."""
    agency = Agency(
        Agent(
            name="TestAgent",
            instructions="You are a test agent. Follow any additional instructions carefully.",
        )
    )

    # Passing additional_instructions must not break the real call.
    result = await agency.get_response(
        message="Say hello",
        additional_instructions="Keep it under 10 words",
    )

    # Only the presence of a non-empty string reply is checked; whether the
    # model obeyed the extra instructions is not verified here.
    output = result.final_output
    assert output is not None
    assert isinstance(output, str)
    assert len(output) > 0
if __name__ == "__main__":
    # Debugging convenience: run this file's tests verbosely when executed directly.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
|