test_agent_handoffs.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. """
  2. Test suite for verifying the combination of handoffs and communication flows in Agency Swarm.
  3. Key Implementation Findings:
  4. ============================
  5. 1. **Communication Flows (SendMessage tools)**:
  6. - Agency creates a unified `send_message` tool with multiple recipients for each agent's communication flows
  7. - This is a single FunctionTool instance that can send messages to any registered recipient
  8. - Control returns to the calling agent after receiving a response (orchestrator pattern)
  9. 2. **Handoffs (via Handoff tool class)**:
  10. - Handoffs are configured by setting `Handoff` as the flow tool class in `communication_flows`
  11. - Communication flows determine handoff targets (sender with Handoff can hand off to recipient)
  12. - Handoffs represent unidirectional transfer of control (agent B takes over from agent A)
  13. 3. **Expected Tool Configuration**:
  14. - AgentA (orchestrator): `send_message` tool with AgentB and AgentC as recipients
  15. - AgentB (with handoffs): No tools for handoffs (SDK handles), but retains handoffs attribute
  16. - AgentC (specialist): No communication tools
  17. 4. **Combining Both Patterns**:
  18. - Communication flows and handoffs can coexist via different send message tool classes
  19. - Agency creates SendMessage tools based on communication_flows parameter
  20. - Tool class (SendMessage vs Handoff) determines behavior
  21. - Handoffs functionality is enabled through Handoff tool class
  22. """
  23. from unittest.mock import MagicMock, patch
  24. import pytest
  25. from agents import HandoffInputData, ModelSettings, RunContextWrapper
  26. from agency_swarm import Agency, Agent
  27. from agency_swarm.tools import Handoff
  28. from agency_swarm.utils.thread import ThreadManager
  29. @pytest.fixture
  30. def orchestrator_agent():
  31. """Create an orchestrator agent that can communicate with other agents."""
  32. return Agent(
  33. name="AgentA",
  34. instructions="You are an orchestrator agent. You coordinate tasks by communicating with other agents.",
  35. model_settings=ModelSettings(temperature=0.0),
  36. )
  37. @pytest.fixture
  38. def intermediate_agent():
  39. """Create an intermediate agent that has handoffs configured via Handoff tool class."""
  40. return Agent(
  41. name="AgentB",
  42. instructions=(
  43. "You are an intermediate agent. Whenever asked to speak with agent C, use the transfer_to_AgentC tool "
  44. "immediately, without any questions."
  45. ),
  46. model_settings=ModelSettings(temperature=0.0),
  47. )
  48. @pytest.fixture
  49. def specialist_agent():
  50. """Create a specialist agent that receives handoffs."""
  51. return Agent(
  52. name="AgentC",
  53. instructions="You are a specialist agent. You process tasks handed off from other agents.",
  54. model_settings=ModelSettings(temperature=0.0),
  55. )
  56. @pytest.fixture
  57. def mixed_communication_agency(orchestrator_agent, intermediate_agent, specialist_agent):
  58. """Create an agency with both communication flows and handoffs configured."""
  59. # Create agency with communication flows: AgentA can send messages to both AgentB and AgentC
  60. # AgentB can hand off to AgentC (enabled by Handoff tool class and communication flow)
  61. agency = Agency(
  62. orchestrator_agent, # Entry point
  63. communication_flows=[
  64. orchestrator_agent > intermediate_agent, # AgentA -> AgentB (regular SendMessage)
  65. orchestrator_agent > specialist_agent, # AgentA -> AgentC (regular SendMessage)
  66. (intermediate_agent > specialist_agent, Handoff), # AgentB -> AgentC (handoff)
  67. ],
  68. shared_instructions="Test agency for mixed communication patterns.",
  69. )
  70. return agency
  71. class TestHandoffsWithCommunicationFlows:
  72. """Test suite for handoffs combined with communication flows."""
  73. def test_agent_tool_configuration(self, mixed_communication_agency):
  74. """Test that agents have the correct tools based on communication flows and handoffs."""
  75. runtime_state_a = mixed_communication_agency.get_agent_runtime_state("AgentA")
  76. runtime_state_b = mixed_communication_agency.get_agent_runtime_state("AgentB")
  77. runtime_state_c = mixed_communication_agency.get_agent_runtime_state("AgentC")
  78. send_message_tools = list(runtime_state_a.send_message_tools.values())
  79. assert len(send_message_tools) == 1, "AgentA should expose exactly one runtime send_message tool"
  80. send_msg_tool = send_message_tools[0]
  81. recipient_names = [agent.name for agent in send_msg_tool.recipients.values()]
  82. assert "AgentB" in recipient_names, f"AgentB should be in send_message recipients, got: {recipient_names}"
  83. assert "AgentC" in recipient_names, f"AgentC should be in send_message recipients, got: {recipient_names}"
  84. assert runtime_state_b.handoffs, "AgentB should register handoffs at runtime"
  85. assert not runtime_state_c.send_message_tools, "AgentC should not expose send_message tools"
  86. def test_sendmessage_tool_recipients(self, mixed_communication_agency):
  87. """Test that SendMessage tool has the correct recipients."""
  88. runtime_state_a = mixed_communication_agency.get_agent_runtime_state("AgentA")
  89. sendmessage_tools = list(runtime_state_a.send_message_tools.values())
  90. assert len(sendmessage_tools) == 1, f"AgentA should have 1 send_message tool, got: {len(sendmessage_tools)}"
  91. send_msg_tool = sendmessage_tools[0]
  92. recipient_names = [agent.name for agent in send_msg_tool.recipients.values()]
  93. assert "AgentB" in recipient_names, f"AgentB should be in recipients, got: {recipient_names}"
  94. assert "AgentC" in recipient_names, f"AgentC should be in recipients, got: {recipient_names}"
  95. assert len(recipient_names) == 2, f"Should have exactly 2 recipients, got: {recipient_names}"
  96. def test_handoff_configuration_via_sendmessage_tool_class(self, mixed_communication_agency):
  97. """Test that handoffs are properly configured via flow tool class."""
  98. runtime_state_b = mixed_communication_agency.get_agent_runtime_state("AgentB")
  99. # Verify AgentB has handoff to AgentC in .handoffs attribute (not in .tools list)
  100. assert runtime_state_b.handoffs, "AgentB runtime state should contain handoffs"
  101. assert len(runtime_state_b.handoffs) == 1, f"AgentB should have 1 handoff, got: {len(runtime_state_b.handoffs)}"
  102. # Check that the handoff targets AgentC
  103. handoff = runtime_state_b.handoffs[0]
  104. assert handoff.agent_name == "AgentC", f"AgentB's handoff should target AgentC, got: {handoff.agent_name}"
  105. def test_agency_configuration_maintains_both_patterns(self, mixed_communication_agency):
  106. """Test that Agency maintains both communication flows and handoffs."""
  107. _ = mixed_communication_agency.agents["AgentA"]
  108. _ = mixed_communication_agency.agents["AgentC"]
  109. # Verify agents are properly registered
  110. assert len(mixed_communication_agency.agents) == 3
  111. assert all(agent_name in mixed_communication_agency.agents for agent_name in ["AgentA", "AgentB", "AgentC"])
  112. runtime_state_b = mixed_communication_agency.get_agent_runtime_state("AgentB")
  113. assert runtime_state_b.handoffs, "AgentB should register handoffs at runtime"
  114. def test_tool_count_expectations(self, mixed_communication_agency):
  115. """Test that each agent has the expected number and type of tools."""
  116. runtime_state_a = mixed_communication_agency.get_agent_runtime_state("AgentA")
  117. runtime_state_b = mixed_communication_agency.get_agent_runtime_state("AgentB")
  118. runtime_state_c = mixed_communication_agency.get_agent_runtime_state("AgentC")
  119. assert len(runtime_state_a.send_message_tools) == 1, "AgentA should expose 1 send_message tool"
  120. assert not runtime_state_b.send_message_tools, "AgentB should not expose send_message tools"
  121. assert not runtime_state_c.send_message_tools, "AgentC should not expose send_message tools"
  122. @pytest.mark.asyncio
  123. async def test_orchestrator_pattern_with_handoffs(self, mixed_communication_agency):
  124. """Test the orchestrator pattern where AgentA uses AgentB which then hands off to AgentC."""
  125. agent_a = mixed_communication_agency.agents["AgentA"]
  126. agent_b = mixed_communication_agency.agents["AgentB"]
  127. agent_c = mixed_communication_agency.agents["AgentC"]
  128. # Mock responses for the chain of communication
  129. mock_c_response = MagicMock()
  130. mock_c_response.final_output = "Task completed by AgentC"
  131. mock_b_response = MagicMock()
  132. mock_b_response.final_output = "Task processed by AgentB and handed off to AgentC"
  133. try:
  134. with (
  135. patch.object(agent_c, "get_response", return_value=mock_c_response),
  136. patch.object(agent_b, "get_response", return_value=mock_b_response),
  137. ):
  138. # AgentA orchestrates by sending message to AgentB
  139. result = await agent_a.get_response(
  140. message="Send this complex task to AgentB for processing and potential handoff",
  141. )
  142. assert result is not None
  143. except Exception as e:
  144. pytest.skip(f"Orchestrator pattern with handoffs not fully implemented: {e}")
  145. @pytest.mark.asyncio
  146. async def test_handoff_reminder_handles_empty_history(self, specialist_agent):
  147. """Ensure reminder injection does not crash when the thread history is empty."""
  148. handoff_tool = Handoff().create_handoff(specialist_agent)
  149. assert handoff_tool.input_filter is not None, "Expected handoff to expose an input filter"
  150. thread_manager = ThreadManager()
  151. context = type("Context", (), {"thread_manager": thread_manager})()
  152. run_context = RunContextWrapper(context=context)
  153. handoff_input = HandoffInputData(
  154. input_history=(),
  155. pre_handoff_items=(),
  156. new_items=(),
  157. run_context=run_context,
  158. )
  159. filtered_input = await handoff_tool.input_filter(handoff_input)
  160. assert filtered_input.input_history == ()
  161. assert thread_manager.get_all_messages() == []
  162. def test_communication_flow_isolation(self, mixed_communication_agency):
  163. """Test that communication flows and handoffs maintain proper isolation."""
  164. _ = mixed_communication_agency.agents["AgentA"]
  165. _ = mixed_communication_agency.agents["AgentB"]
  166. _ = mixed_communication_agency.agents["AgentC"]
  167. # AgentA should be able to communicate with both AgentB and AgentC independently
  168. # AgentB should only be able to hand off to AgentC (not send messages)
  169. # AgentC should not be able to initiate communication with others
  170. runtime_state_a = mixed_communication_agency.get_agent_runtime_state("AgentA")
  171. runtime_state_b = mixed_communication_agency.get_agent_runtime_state("AgentB")
  172. runtime_state_c = mixed_communication_agency.get_agent_runtime_state("AgentC")
  173. assert runtime_state_a.send_message_tools, "AgentA should have send_message tools"
  174. assert not runtime_state_b.send_message_tools, "AgentB should not expose send_message tools"
  175. assert not runtime_state_c.send_message_tools, "AgentC should not expose send_message tools"
  176. class TestComplexHandoffScenarios:
  177. """Test more complex scenarios with multiple handoffs and communication flows."""
  178. def test_multiple_handoff_targets(self):
  179. """Test agent with multiple handoff targets via Handoff tool class."""
  180. agent_a = Agent(name="AgentA", instructions="Orchestrator")
  181. agent_b = Agent(name="AgentB", instructions="Multi-handoff agent")
  182. agent_c = Agent(name="AgentC", instructions="Specialist 1")
  183. agent_d = Agent(name="AgentD", instructions="Specialist 2")
  184. agency = Agency(
  185. agent_a,
  186. communication_flows=[
  187. agent_a > agent_b,
  188. (agent_b > agent_c, Handoff), # AgentB can hand off to AgentC
  189. (agent_b > agent_d, Handoff), # AgentB can hand off to AgentD
  190. ],
  191. )
  192. runtime_state_b = agency.get_agent_runtime_state("AgentB")
  193. assert len(runtime_state_b.handoffs) == 2, (
  194. f"AgentB should have 2 handoffs, got: {len(runtime_state_b.handoffs)}"
  195. )
  196. # Verify the handoff targets are correct
  197. handoff_targets = [h.agent_name for h in runtime_state_b.handoffs]
  198. assert "AgentC" in handoff_targets, "AgentB should have handoff to AgentC"
  199. assert "AgentD" in handoff_targets, "AgentB should have handoff to AgentD"
  200. def test_bidirectional_communication_with_handoffs(self):
  201. """Test bidirectional communication flows combined with Handoff tool class."""
  202. agent_a = Agent(name="AgentA", instructions="Primary orchestrator")
  203. agent_b = Agent(name="AgentB", instructions="Secondary orchestrator with handoffs")
  204. agent_c = Agent(name="AgentC", instructions="Specialist")
  205. # Configure bidirectional communication between A and B, plus handoff capability from B to C
  206. agency = Agency(
  207. agent_a,
  208. communication_flows=[
  209. agent_a > agent_b, # A can send to B
  210. (agent_b > agent_a, Handoff), # B can hand off to A
  211. agent_a > agent_c, # A can send to C
  212. (agent_b > agent_c, Handoff), # B can hand off to C
  213. ],
  214. )
  215. runtime_state_a = agency.get_agent_runtime_state("AgentA")
  216. runtime_state_b = agency.get_agent_runtime_state("AgentB")
  217. assert runtime_state_a.send_message_tools, "AgentA should expose send_message tools"
  218. send_msg_tool = next(iter(runtime_state_a.send_message_tools.values()))
  219. recipient_names = [agent.name for agent in send_msg_tool.recipients.values()]
  220. assert "AgentB" in recipient_names, f"AgentB should be reachable, got: {recipient_names}"
  221. assert "AgentC" in recipient_names, f"AgentC should be reachable, got: {recipient_names}"
  222. assert len(runtime_state_b.handoffs) == 2, (
  223. f"AgentB should have 2 handoffs, got: {len(runtime_state_b.handoffs)}"
  224. )
  225. handoff_targets = [h.agent_name for h in runtime_state_b.handoffs]
  226. assert "AgentA" in handoff_targets, f"AgentB should have handoff to AgentA, got: {handoff_targets}"
  227. assert "AgentC" in handoff_targets, f"AgentB should have handoff to AgentC, got: {handoff_targets}"
  228. assert runtime_state_b.handoffs, "AgentB should register handoffs at runtime"
  229. def test_agency_flow_handoffs(self):
  230. """Test bidirectional communication flows combined with Handoff tool class."""
  231. agent_a = Agent(name="AgentA", instructions="Primary orchestrator")
  232. agent_b = Agent(
  233. name="AgentB",
  234. instructions="Secondary orchestrator with handoffs",
  235. )
  236. agent_c = Agent(name="AgentC", instructions="Specialist")
  237. # Configure bidirectional communication between A and B, plus handoff capability from B to C
  238. agency = Agency(
  239. agent_a,
  240. communication_flows=[
  241. (agent_a > agent_b), # A can send to B
  242. (agent_b > agent_a, Handoff), # B can send to A (using Handoff tool class)
  243. (agent_a > agent_c), # A can send to C
  244. (agent_b > agent_c, Handoff), # B can hand off to C (using Handoff tool class)
  245. ],
  246. )
  247. runtime_state_a = agency.get_agent_runtime_state("AgentA")
  248. runtime_state_b = agency.get_agent_runtime_state("AgentB")
  249. assert runtime_state_a.send_message_tools, "AgentA should expose send_message tools"
  250. send_msg_tool = next(iter(runtime_state_a.send_message_tools.values()))
  251. recipient_names = [agent.name for agent in send_msg_tool.recipients.values()]
  252. assert "AgentB" in recipient_names, "AgentB should be reachable from AgentA"
  253. assert "AgentC" in recipient_names, "AgentC should be reachable from AgentA"
  254. assert len(runtime_state_b.handoffs) == 2, (
  255. f"AgentB should have 2 handoffs, got: {len(runtime_state_b.handoffs)}"
  256. )
  257. handoff_targets = [h.agent_name for h in runtime_state_b.handoffs]
  258. assert "AgentA" in handoff_targets, f"AgentB should have handoff to AgentA, got: {handoff_targets}"
  259. assert "AgentC" in handoff_targets, f"AgentB should have handoff to AgentC, got: {handoff_targets}"
  260. @pytest.mark.asyncio
  261. async def test_nested_handoffs_on_follow_ups(self, mixed_communication_agency):
  262. """Test that there are no errors on follow up messages."""
  263. # First handoff
  264. async for _ in mixed_communication_agency.get_response_stream("Ask Agent B to use transfer_to_AgentC tool."):
  265. pass
  266. # Verify handoff occurred
  267. messages = mixed_communication_agency.thread_manager.get_all_messages()
  268. tool_names = [msg.get("name") for msg in messages if msg.get("type") == "function_call"]
  269. assert "transfer_to_AgentC" in tool_names, "Should have used transfer_to_AgentC tool"
  270. # Second handoff (follow-up)
  271. async for _ in mixed_communication_agency.get_response_stream(
  272. "Ask Agent B to use transfer_to_AgentC tool again."
  273. ):
  274. pass
  275. # Verify no errors in tool outputs
  276. messages = mixed_communication_agency.thread_manager.get_all_messages()
  277. tool_outputs = [msg.get("output", "") for msg in messages if msg.get("type") == "function_call_output"]
  278. for output in tool_outputs:
  279. assert "error" not in output.lower(), f"Found error in tool output: {output}"
  280. def test_handoff_reminders(self):
  281. """Test bidirectional communication flows combined with Handoff tool class."""
  282. class NoReminder(Handoff):
  283. add_reminder = False
  284. agent_a = Agent(
  285. name="AgentA", instructions="Primary orchestrator", model_settings=ModelSettings(temperature=0.0)
  286. )
  287. agent_b = Agent(
  288. name="AgentB",
  289. instructions="Secondary orchestrator with handoffs",
  290. model_settings=ModelSettings(temperature=0.0),
  291. )
  292. agent_c = Agent(
  293. name="AgentC",
  294. instructions="Specialist",
  295. model_settings=ModelSettings(temperature=0.0),
  296. handoff_reminder="Custom reminder",
  297. )
  298. # Configure bidirectional communication between A and B, plus handoff capability from B to C
  299. agency = Agency(
  300. agent_a,
  301. agent_b,
  302. agent_c,
  303. communication_flows=[
  304. (agent_a > agent_b, Handoff), # A can send to B
  305. (agent_b > agent_c, Handoff), # A can send to C
  306. (agent_c > agent_a, NoReminder), # No-reminder handoff
  307. ],
  308. )
  309. # Check default handoff
  310. agency.get_response_sync("Transfer to AgentB agent", recipient_agent=agent_a)
  311. system_message = agency.thread_manager.get_all_messages()[1]
  312. assert system_message["role"] == "system", (
  313. f"Incorrect role, got: {system_message}, expected reminder system message"
  314. )
  315. assert system_message["content"] == "Transfer completed. You are AgentB. Please continue the task.", (
  316. f"Incorrect content, got: {system_message}, expected reminder system message"
  317. )
  318. agency.thread_manager.clear()
  319. # Check custom reminder
  320. agency.get_response_sync("Transfer to AgentC agent", recipient_agent=agent_b)
  321. system_message = agency.thread_manager.get_all_messages()[1]
  322. assert system_message["role"] == "system", (
  323. f"Incorrect role, got: {system_message}, expected reminder system message"
  324. )
  325. assert system_message["content"] == "Custom reminder", (
  326. f"Incorrect content, got: {system_message}, expected 'Custom reminder'"
  327. )
  328. agency.thread_manager.clear()
  329. # Check no reminder handoff
  330. agency.get_response_sync("Transfer to AgentA agent", recipient_agent=agent_c)
  331. chat_history = agency.thread_manager.get_all_messages()
  332. for message in chat_history:
  333. if "role" in message:
  334. assert message["role"] != "system", f"Incorrect role, got: {message}, expected no system messages"