test_agent_flow_integration.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. """
  2. Unit tests for AgentFlow integration with Agency class.
  3. Tests the parsing and handling of AgentFlow objects in communication_flows.
  4. """
  5. import pytest
  6. from agency_swarm import Agency, Agent
  7. from agency_swarm.tools.send_message import Handoff, SendMessage, SendMessageHandoff
  8. class CustomSendMessage(SendMessage):
  9. """Custom send message tool for testing."""
  10. pass
  11. # --- Agency Integration Tests ---
  12. def test_agency_with_mixed_communication_flows():
  13. """Test Agency with mixed communication flow formats."""
  14. agent1 = Agent(name="Agent1", instructions="Test agent 1", model="gpt-5.4-mini")
  15. agent2 = Agent(name="Agent2", instructions="Test agent 2", model="gpt-5.4-mini")
  16. agent3 = Agent(name="Agent3", instructions="Test agent 3", model="gpt-5.4-mini")
  17. agent4 = Agent(name="Agent4", instructions="Test agent 4", model="gpt-5.4-mini")
  18. agency = Agency(
  19. agent1,
  20. communication_flows=[
  21. (agent1 > agent2 > agent3, CustomSendMessage), # Chain with tool
  22. (agent1, agent4), # Basic pair
  23. (agent2, agent4, Handoff), # Pair with tool
  24. ],
  25. )
  26. assert len(agency.agents) == 4
  27. # Check tool mappings
  28. tool_mappings = agency._communication_tool_classes
  29. assert tool_mappings[("Agent1", "Agent2")] == CustomSendMessage
  30. assert tool_mappings[("Agent2", "Agent3")] == CustomSendMessage
  31. assert tool_mappings[("Agent2", "Agent4")] == Handoff
  32. def test_agency_with_mixed_communication_flows_reverse():
  33. """Test Agency with reverse communication flow."""
  34. agent1 = Agent(name="Agent1", instructions="Test agent 1", model="gpt-5.4-mini")
  35. agent2 = Agent(name="Agent2", instructions="Test agent 2", model="gpt-5.4-mini")
  36. agent3 = Agent(name="Agent3", instructions="Test agent 3", model="gpt-5.4-mini")
  37. agency = Agency(
  38. agent1,
  39. communication_flows=[
  40. (agent3 < agent2 < agent1, CustomSendMessage), # Chain with tool
  41. ],
  42. )
  43. assert len(agency.agents) == 3
  44. # Check tool mappings
  45. tool_mappings = agency._communication_tool_classes
  46. assert tool_mappings[("Agent1", "Agent2")] == CustomSendMessage
  47. assert tool_mappings[("Agent2", "Agent3")] == CustomSendMessage
  48. def test_duplicate_flow_detection_with_chains():
  49. """Test that duplicate flows are detected with AgentFlow chains."""
  50. agent1 = Agent(name="Agent1", instructions="Test agent 1", model="gpt-5.4-mini")
  51. agent2 = Agent(name="Agent2", instructions="Test agent 2", model="gpt-5.4-mini")
  52. agent3 = Agent(name="Agent3", instructions="Test agent 3", model="gpt-5.4-mini")
  53. with pytest.raises(ValueError, match="Duplicate communication flow detected"):
  54. Agency(
  55. agent1,
  56. communication_flows=[
  57. (agent1 > agent2 > agent3, CustomSendMessage), # Creates agent1->agent2
  58. (agent1, agent2), # Duplicate agent1->agent2
  59. ],
  60. )
  61. def test_agent_flow_with_handoff_tool():
  62. """Test that Handoff works with AgentFlow."""
  63. agent1 = Agent(name="Agent1", instructions="Test agent 1", model="gpt-5.4-mini")
  64. agent2 = Agent(name="Agent2", instructions="Test agent 2", model="gpt-5.4-mini")
  65. agent3 = Agent(name="Agent3", instructions="Test agent 3", model="gpt-5.4-mini")
  66. # This should work without errors
  67. agency = Agency(
  68. agent1,
  69. communication_flows=[
  70. (agent1 > agent2 > agent3, Handoff),
  71. ],
  72. )
  73. assert len(agency.agents) == 3
  74. runtime_state1 = agency.get_agent_runtime_state("Agent1")
  75. runtime_state2 = agency.get_agent_runtime_state("Agent2")
  76. handoff_names_1 = [handoff.tool_name for handoff in runtime_state1.handoffs]
  77. handoff_names_2 = [handoff.tool_name for handoff in runtime_state2.handoffs]
  78. assert "transfer_to_Agent2" in handoff_names_1
  79. assert "transfer_to_Agent3" in handoff_names_2
  80. assert not agency.get_agent_runtime_state("Agent3").handoffs
  81. def test_send_message_handoff_name_is_deprecated() -> None:
  82. agent1 = Agent(name="Agent1", instructions="Test agent 1", model="gpt-5.4-mini")
  83. agent2 = Agent(name="Agent2", instructions="Test agent 2", model="gpt-5.4-mini")
  84. with pytest.deprecated_call(match=r"SendMessageHandoff is deprecated; use Handoff instead\."):
  85. Agency(
  86. agent1,
  87. communication_flows=[
  88. (agent1 > agent2, SendMessageHandoff),
  89. ],
  90. )