test_agency_responses.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. import warnings
  2. from typing import Any
  3. import pytest
  4. from agents import RunHooks
  5. from agency_swarm import Agency
  6. from agency_swarm.utils.thread import ThreadManager
  7. from tests.test_agency_modules._response_test_helpers import CapturingAgent, _make_agent
  8. # --- Fixtures ---
  9. @pytest.fixture
  10. def mock_agent():
  11. """Provides an Agent instance for testing."""
  12. return CapturingAgent("MockAgent")
  13. @pytest.fixture
  14. def mock_agent2():
  15. """Provides a second Agent instance for testing."""
  16. return _make_agent("MockAgent2")
  17. # --- Agency Response Method Tests ---
  18. @pytest.mark.asyncio
  19. async def test_agency_get_response_basic(mock_agent):
  20. """Test basic Agency.get_response functionality."""
  21. agency = Agency(mock_agent)
  22. result = await agency.get_response("Test message", "MockAgent")
  23. assert result.final_output == "Test response"
  24. @pytest.mark.asyncio
  25. async def test_agency_get_response_sync_inside_running_event_loop(mock_agent):
  26. """Ensure Agency.get_response_sync works when called from a running event loop."""
  27. agency = Agency(mock_agent)
  28. with warnings.catch_warnings():
  29. warnings.simplefilter("error", RuntimeWarning)
  30. result = agency.get_response_sync("Test message", "MockAgent")
  31. assert result.final_output == "Test response"
  32. @pytest.mark.asyncio
  33. async def test_agency_get_response_with_hooks(mock_agent):
  34. """Test Agency.get_response with hooks."""
  35. saved_messages: list[list[dict[str, Any]]] = []
  36. def mock_load_cb():
  37. return []
  38. def mock_save_cb(messages):
  39. saved_messages.append(messages)
  40. agency = Agency(mock_agent, load_threads_callback=mock_load_cb, save_threads_callback=mock_save_cb)
  41. hooks_override = RunHooks()
  42. result = await agency.get_response("Test message", "MockAgent", hooks_override=hooks_override)
  43. assert result.final_output == "Test response"
  44. assert saved_messages
  45. assert mock_agent.last_hooks_override is hooks_override
  46. @pytest.mark.asyncio
  47. async def test_agency_get_response_preserves_positional_hooks_override(mock_agent):
  48. """Adding agency_context_override must not break legacy positional hooks calls."""
  49. agency = Agency(mock_agent)
  50. hooks_override = RunHooks()
  51. result = await agency.get_response("Test message", "MockAgent", None, hooks_override)
  52. assert result.final_output == "Test response"
  53. assert mock_agent.last_hooks_override is hooks_override
  54. @pytest.mark.asyncio
  55. async def test_agency_get_response_sync_preserves_positional_hooks_override(mock_agent):
  56. """The sync entrypoint should keep the old positional argument order."""
  57. agency = Agency(mock_agent)
  58. hooks_override = RunHooks()
  59. with warnings.catch_warnings():
  60. warnings.simplefilter("error", RuntimeWarning)
  61. result = agency.get_response_sync("Test message", "MockAgent", None, hooks_override)
  62. assert result.final_output == "Test response"
  63. assert mock_agent.last_hooks_override is hooks_override
  64. @pytest.mark.asyncio
  65. async def test_agency_get_response_invalid_recipient_warning(mock_agent):
  66. """Test Agency.get_response with invalid recipient agent name."""
  67. agency = Agency(mock_agent)
  68. with pytest.raises(ValueError, match="Agent with name 'InvalidAgent' not found"):
  69. await agency.get_response("Test message", "InvalidAgent")
  70. @pytest.mark.asyncio
  71. async def test_agency_get_response_stream_basic(mock_agent):
  72. """Test basic Agency.get_response_stream functionality."""
  73. agency = Agency(mock_agent)
  74. events = []
  75. stream = agency.get_response_stream("Test message", "MockAgent")
  76. async for event in stream:
  77. events.append(event)
  78. assert stream.final_result is not None
  79. assert stream.final_result.final_output == "Test response"
  80. assert isinstance(events, list)
  81. @pytest.mark.asyncio
  82. async def test_agency_get_response_stream_with_hooks(mock_agent):
  83. """Test Agency.get_response_stream with hooks."""
  84. saved_messages: list[list[dict[str, Any]]] = []
  85. def mock_load_cb():
  86. return []
  87. def mock_save_cb(messages):
  88. saved_messages.append(messages)
  89. agency = Agency(mock_agent, load_threads_callback=mock_load_cb, save_threads_callback=mock_save_cb)
  90. hooks_override = RunHooks()
  91. events = []
  92. stream = agency.get_response_stream("Test message", "MockAgent", hooks_override=hooks_override)
  93. async for event in stream:
  94. events.append(event)
  95. assert stream.final_result is not None
  96. assert stream.final_result.final_output == "Test response"
  97. assert saved_messages
  98. @pytest.mark.asyncio
  99. async def test_agency_get_response_stream_preserves_positional_hooks_override(mock_agent):
  100. """The streaming entrypoint should keep the old positional argument order."""
  101. agency = Agency(mock_agent)
  102. hooks_override = RunHooks()
  103. stream = agency.get_response_stream("Test message", "MockAgent", None, hooks_override)
  104. async for _event in stream:
  105. pass
  106. assert stream.final_result is not None
  107. assert stream.final_result.final_output == "Test response"
  108. assert mock_agent.last_hooks_override is hooks_override
  109. @pytest.mark.asyncio
  110. async def test_agency_get_response_stream_does_not_mutate_context_override(mock_agent):
  111. """Ensure streaming runs leave the caller-provided context untouched."""
  112. capturing_agent = CapturingAgent("CaptureAgent")
  113. agency = Agency(capturing_agent)
  114. context_override = {"test_key": "test_value"}
  115. events = []
  116. stream = agency.get_response_stream("Test message", "CaptureAgent", context_override=context_override)
  117. async for event in stream:
  118. events.append(event)
  119. # Streaming still works while the user's dict stays clean
  120. assert stream.final_result is not None
  121. assert context_override == {"test_key": "test_value"}
  122. assert "streaming_context" not in context_override
  123. assert capturing_agent.last_context_override is not None
  124. assert capturing_agent.last_context_override is not context_override
  125. assert "streaming_context" in capturing_agent.last_context_override
  126. assert isinstance(events, list)
  127. @pytest.mark.asyncio
  128. async def test_agency_agent_to_agent_communication(mock_agent, mock_agent2):
  129. """Test agent-to-agent communication through Agency."""
  130. agency = Agency(mock_agent, communication_flows=[(mock_agent, mock_agent2)])
  131. result = await agency.get_response("Test message", "MockAgent")
  132. assert result.final_output == "Test response"
  133. @pytest.mark.asyncio
  134. async def test_agency_get_response_uses_agency_context_override_thread_manager(mock_agent):
  135. """Agency entrypoints should allow per-run thread manager isolation."""
  136. agency = Agency(mock_agent)
  137. isolated_thread_manager = ThreadManager()
  138. isolated_context = agency.get_agent_context("MockAgent", thread_manager_override=isolated_thread_manager)
  139. result = await agency.get_response(
  140. "Test message",
  141. "MockAgent",
  142. agency_context_override=isolated_context,
  143. )
  144. assert result.final_output == "Test response"
  145. assert mock_agent.last_agency_context is isolated_context
  146. assert isolated_thread_manager.get_all_messages()
  147. assert agency.thread_manager.get_all_messages() == []
  148. @pytest.mark.asyncio
  149. async def test_agency_get_response_stream_uses_agency_context_override_thread_manager(mock_agent):
  150. """Streaming entrypoints should respect a run-scoped agency context override."""
  151. agency = Agency(mock_agent)
  152. isolated_thread_manager = ThreadManager()
  153. isolated_context = agency.get_agent_context("MockAgent", thread_manager_override=isolated_thread_manager)
  154. stream = agency.get_response_stream(
  155. "Test message",
  156. "MockAgent",
  157. agency_context_override=isolated_context,
  158. )
  159. async for _event in stream:
  160. pass
  161. assert stream.final_result is not None
  162. assert stream.final_result.final_output == "Test response"
  163. assert mock_agent.last_agency_context is isolated_context
  164. assert isolated_thread_manager.get_all_messages()
  165. assert agency.thread_manager.get_all_messages() == []
  166. @pytest.mark.asyncio
  167. async def test_agent_communication_context_hooks_propagation(mock_agent, mock_agent2):
  168. """Test that context and hooks are properly propagated in agent communication."""
  169. saved_messages: list[list[dict[str, Any]]] = []
  170. def mock_load_cb():
  171. return []
  172. def mock_save_cb(messages):
  173. saved_messages.append(messages)
  174. agency = Agency(
  175. mock_agent,
  176. communication_flows=[(mock_agent, mock_agent2)],
  177. load_threads_callback=mock_load_cb,
  178. save_threads_callback=mock_save_cb,
  179. )
  180. context_override = {"test_key": "test_value"}
  181. hooks_override = RunHooks()
  182. result = await agency.get_response(
  183. "Test message", "MockAgent", context_override=context_override, hooks_override=hooks_override
  184. )
  185. assert result.final_output == "Test response"
  186. assert saved_messages
  187. assert mock_agent.last_context_override is context_override
  188. assert mock_agent.last_hooks_override is hooks_override