| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- import warnings
- from typing import Any
- import pytest
- from agents import RunHooks
- from agency_swarm import Agency
- from agency_swarm.utils.thread import ThreadManager
- from tests.test_agency_modules._response_test_helpers import CapturingAgent, _make_agent
- # --- Fixtures ---
- @pytest.fixture
- def mock_agent():
- """Provides an Agent instance for testing."""
- return CapturingAgent("MockAgent")
- @pytest.fixture
- def mock_agent2():
- """Provides a second Agent instance for testing."""
- return _make_agent("MockAgent2")
- # --- Agency Response Method Tests ---
- @pytest.mark.asyncio
- async def test_agency_get_response_basic(mock_agent):
- """Test basic Agency.get_response functionality."""
- agency = Agency(mock_agent)
- result = await agency.get_response("Test message", "MockAgent")
- assert result.final_output == "Test response"
- @pytest.mark.asyncio
- async def test_agency_get_response_sync_inside_running_event_loop(mock_agent):
- """Ensure Agency.get_response_sync works when called from a running event loop."""
- agency = Agency(mock_agent)
- with warnings.catch_warnings():
- warnings.simplefilter("error", RuntimeWarning)
- result = agency.get_response_sync("Test message", "MockAgent")
- assert result.final_output == "Test response"
- @pytest.mark.asyncio
- async def test_agency_get_response_with_hooks(mock_agent):
- """Test Agency.get_response with hooks."""
- saved_messages: list[list[dict[str, Any]]] = []
- def mock_load_cb():
- return []
- def mock_save_cb(messages):
- saved_messages.append(messages)
- agency = Agency(mock_agent, load_threads_callback=mock_load_cb, save_threads_callback=mock_save_cb)
- hooks_override = RunHooks()
- result = await agency.get_response("Test message", "MockAgent", hooks_override=hooks_override)
- assert result.final_output == "Test response"
- assert saved_messages
- assert mock_agent.last_hooks_override is hooks_override
- @pytest.mark.asyncio
- async def test_agency_get_response_preserves_positional_hooks_override(mock_agent):
- """Adding agency_context_override must not break legacy positional hooks calls."""
- agency = Agency(mock_agent)
- hooks_override = RunHooks()
- result = await agency.get_response("Test message", "MockAgent", None, hooks_override)
- assert result.final_output == "Test response"
- assert mock_agent.last_hooks_override is hooks_override
- @pytest.mark.asyncio
- async def test_agency_get_response_sync_preserves_positional_hooks_override(mock_agent):
- """The sync entrypoint should keep the old positional argument order."""
- agency = Agency(mock_agent)
- hooks_override = RunHooks()
- with warnings.catch_warnings():
- warnings.simplefilter("error", RuntimeWarning)
- result = agency.get_response_sync("Test message", "MockAgent", None, hooks_override)
- assert result.final_output == "Test response"
- assert mock_agent.last_hooks_override is hooks_override
- @pytest.mark.asyncio
- async def test_agency_get_response_invalid_recipient_warning(mock_agent):
- """Test Agency.get_response with invalid recipient agent name."""
- agency = Agency(mock_agent)
- with pytest.raises(ValueError, match="Agent with name 'InvalidAgent' not found"):
- await agency.get_response("Test message", "InvalidAgent")
- @pytest.mark.asyncio
- async def test_agency_get_response_stream_basic(mock_agent):
- """Test basic Agency.get_response_stream functionality."""
- agency = Agency(mock_agent)
- events = []
- stream = agency.get_response_stream("Test message", "MockAgent")
- async for event in stream:
- events.append(event)
- assert stream.final_result is not None
- assert stream.final_result.final_output == "Test response"
- assert isinstance(events, list)
- @pytest.mark.asyncio
- async def test_agency_get_response_stream_with_hooks(mock_agent):
- """Test Agency.get_response_stream with hooks."""
- saved_messages: list[list[dict[str, Any]]] = []
- def mock_load_cb():
- return []
- def mock_save_cb(messages):
- saved_messages.append(messages)
- agency = Agency(mock_agent, load_threads_callback=mock_load_cb, save_threads_callback=mock_save_cb)
- hooks_override = RunHooks()
- events = []
- stream = agency.get_response_stream("Test message", "MockAgent", hooks_override=hooks_override)
- async for event in stream:
- events.append(event)
- assert stream.final_result is not None
- assert stream.final_result.final_output == "Test response"
- assert saved_messages
- @pytest.mark.asyncio
- async def test_agency_get_response_stream_preserves_positional_hooks_override(mock_agent):
- """The streaming entrypoint should keep the old positional argument order."""
- agency = Agency(mock_agent)
- hooks_override = RunHooks()
- stream = agency.get_response_stream("Test message", "MockAgent", None, hooks_override)
- async for _event in stream:
- pass
- assert stream.final_result is not None
- assert stream.final_result.final_output == "Test response"
- assert mock_agent.last_hooks_override is hooks_override
- @pytest.mark.asyncio
- async def test_agency_get_response_stream_does_not_mutate_context_override(mock_agent):
- """Ensure streaming runs leave the caller-provided context untouched."""
- capturing_agent = CapturingAgent("CaptureAgent")
- agency = Agency(capturing_agent)
- context_override = {"test_key": "test_value"}
- events = []
- stream = agency.get_response_stream("Test message", "CaptureAgent", context_override=context_override)
- async for event in stream:
- events.append(event)
- # Streaming still works while the user's dict stays clean
- assert stream.final_result is not None
- assert context_override == {"test_key": "test_value"}
- assert "streaming_context" not in context_override
- assert capturing_agent.last_context_override is not None
- assert capturing_agent.last_context_override is not context_override
- assert "streaming_context" in capturing_agent.last_context_override
- assert isinstance(events, list)
- @pytest.mark.asyncio
- async def test_agency_agent_to_agent_communication(mock_agent, mock_agent2):
- """Test agent-to-agent communication through Agency."""
- agency = Agency(mock_agent, communication_flows=[(mock_agent, mock_agent2)])
- result = await agency.get_response("Test message", "MockAgent")
- assert result.final_output == "Test response"
- @pytest.mark.asyncio
- async def test_agency_get_response_uses_agency_context_override_thread_manager(mock_agent):
- """Agency entrypoints should allow per-run thread manager isolation."""
- agency = Agency(mock_agent)
- isolated_thread_manager = ThreadManager()
- isolated_context = agency.get_agent_context("MockAgent", thread_manager_override=isolated_thread_manager)
- result = await agency.get_response(
- "Test message",
- "MockAgent",
- agency_context_override=isolated_context,
- )
- assert result.final_output == "Test response"
- assert mock_agent.last_agency_context is isolated_context
- assert isolated_thread_manager.get_all_messages()
- assert agency.thread_manager.get_all_messages() == []
- @pytest.mark.asyncio
- async def test_agency_get_response_stream_uses_agency_context_override_thread_manager(mock_agent):
- """Streaming entrypoints should respect a run-scoped agency context override."""
- agency = Agency(mock_agent)
- isolated_thread_manager = ThreadManager()
- isolated_context = agency.get_agent_context("MockAgent", thread_manager_override=isolated_thread_manager)
- stream = agency.get_response_stream(
- "Test message",
- "MockAgent",
- agency_context_override=isolated_context,
- )
- async for _event in stream:
- pass
- assert stream.final_result is not None
- assert stream.final_result.final_output == "Test response"
- assert mock_agent.last_agency_context is isolated_context
- assert isolated_thread_manager.get_all_messages()
- assert agency.thread_manager.get_all_messages() == []
- @pytest.mark.asyncio
- async def test_agent_communication_context_hooks_propagation(mock_agent, mock_agent2):
- """Test that context and hooks are properly propagated in agent communication."""
- saved_messages: list[list[dict[str, Any]]] = []
- def mock_load_cb():
- return []
- def mock_save_cb(messages):
- saved_messages.append(messages)
- agency = Agency(
- mock_agent,
- communication_flows=[(mock_agent, mock_agent2)],
- load_threads_callback=mock_load_cb,
- save_threads_callback=mock_save_cb,
- )
- context_override = {"test_key": "test_value"}
- hooks_override = RunHooks()
- result = await agency.get_response(
- "Test message", "MockAgent", context_override=context_override, hooks_override=hooks_override
- )
- assert result.final_output == "Test response"
- assert saved_messages
- assert mock_agent.last_context_override is context_override
- assert mock_agent.last_hooks_override is hooks_override
|