test_communication.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import json
  2. import uuid
  3. import pytest
  4. from agents import ModelSettings, RunResult
  5. from agency_swarm import Agency, Agent
  6. @pytest.fixture
  7. def planner_agent_instance():
  8. return Agent(
  9. name="Planner",
  10. description="Plans the work.",
  11. instructions=(
  12. "You are a Planner. You will receive a task. Determine the steps. "
  13. "Delegate the execution step to the Worker agent using the send_message tool. "
  14. "Ensure your message to the Worker clearly includes the full and exact task description you received. "
  15. "After receiving the final result, relay it verbatim to the user including all task identifiers."
  16. ),
  17. model_settings=ModelSettings(temperature=0.0),
  18. )
  19. @pytest.fixture
  20. def worker_agent_instance():
  21. return Agent(
  22. name="Worker",
  23. description="Does the work.",
  24. instructions=(
  25. "You are a Worker. You will receive execution instructions from the Planner including a task description. "
  26. "Perform the task (simulate by creating a result string like 'Work done for: [task description]'). "
  27. "Send the result string to the Reporter agent using the send_message tool. "
  28. "Ensure your message clearly references the specific task description you were given by the Planner."
  29. ),
  30. model_settings=ModelSettings(temperature=0.0),
  31. )
  32. @pytest.fixture
  33. def reporter_agent_instance():
  34. return Agent(
  35. name="Reporter",
  36. description="Reports the results.",
  37. instructions=(
  38. "You are a Reporter. You will receive results from the Worker, "
  39. "which should reference a specific task description. "
  40. "Format this into a final report string. "
  41. "Ensure your final report clearly identifies the specific task "
  42. "description that was processed along with the results."
  43. ),
  44. model_settings=ModelSettings(temperature=0.0),
  45. )
  46. @pytest.fixture
  47. def multi_agent_agency(planner_agent_instance, worker_agent_instance, reporter_agent_instance):
  48. agency = Agency(
  49. planner_agent_instance,
  50. communication_flows=[
  51. planner_agent_instance > worker_agent_instance,
  52. worker_agent_instance > reporter_agent_instance,
  53. ],
  54. shared_instructions="This is a test agency.",
  55. )
  56. return agency
  57. @pytest.mark.asyncio
  58. async def test_multi_agent_communication_flow(multi_agent_agency: Agency):
  59. """Proves end-to-end Planner→Worker→Reporter pipeline yields a final output with task context."""
  60. initial_task = f"Process test data batch {uuid.uuid4()}."
  61. print(f"\n--- Starting Integration Test --- TASK: {initial_task}")
  62. final_result: RunResult = await multi_agent_agency.get_response(message=initial_task)
  63. print(f"--- Integration Test Complete --- FINAL OUTPUT:\n{final_result.final_output}")
  64. assert final_result.final_output is not None
  65. assert isinstance(final_result.final_output, str)
  66. assert len(final_result.final_output) > 0
  67. task_id_part = initial_task.split(" ")[-1].split(".")[0]
  68. assert task_id_part in final_result.final_output
  69. print("--- Assertions Passed ---")
  70. @pytest.mark.asyncio
  71. async def test_context_preservation_in_agent_communication(multi_agent_agency: Agency):
  72. """Proves agent-to-agent thread isolation with correct caller/agent identifiers in flat storage."""
  73. initial_task = "Simple task for testing context preservation."
  74. print(f"\n--- Testing Context Preservation --- TASK: {initial_task}")
  75. # Execute the communication flow
  76. await multi_agent_agency.get_response(message=initial_task)
  77. # Direct verification - check actual messages in flat storage
  78. thread_manager = multi_agent_agency.thread_manager
  79. all_messages = thread_manager.get_all_messages()
  80. # Extract unique conversation pairs from messages
  81. conversation_pairs = set()
  82. for msg in all_messages:
  83. agent = msg.get("agent", "")
  84. caller = msg.get("callerAgent")
  85. if agent:
  86. # Convert None to "user" for display
  87. caller_name = "user" if caller is None else caller
  88. conversation_pairs.add(f"{caller_name}->{agent}")
  89. actual_conversations = list(conversation_pairs)
  90. print(f"--- Actual conversations created: {actual_conversations}")
  91. # Verify that we have agent-to-agent communication
  92. agent_to_agent_convs = [conv for conv in actual_conversations if "->" in conv and not conv.startswith("user->")]
  93. assert len(agent_to_agent_convs) > 0, (
  94. f"No agent-to-agent conversations found. Conversations: {actual_conversations}"
  95. )
  96. # Verify expected communication patterns exist
  97. expected_agent_patterns = ["Planner->Worker", "Worker->Reporter"]
  98. for pattern in expected_agent_patterns:
  99. if pattern in actual_conversations:
  100. print(f"✓ Found expected conversation pattern: {pattern}")
  101. # Verify the pattern follows structured format
  102. assert "->" in pattern, f"Conversation should be structured: {pattern}"
  103. # Verify sender and recipient are correctly formatted
  104. sender, recipient = pattern.split("->")
  105. assert sender in ["Planner", "Worker", "Reporter"], f"Invalid sender: {sender}"
  106. assert recipient in ["Planner", "Worker", "Reporter"], f"Invalid recipient: {recipient}"
  107. # Verify that user conversations also exist
  108. user_convs = [conv for conv in actual_conversations if conv.startswith("user->")]
  109. assert len(user_convs) > 0, f"No user conversations found. Conversations: {actual_conversations}"
  110. print("✓ Verified all conversations use proper identifiers")
  111. print("✓ Message isolation verified through flat storage")
  112. print("--- Context preservation test passed ---")
  113. @pytest.mark.asyncio
  114. async def test_non_blocking_parallel_agent_interactions(
  115. planner_agent_instance, worker_agent_instance, reporter_agent_instance
  116. ):
  117. """Proves Planner can initiate two distinct inter-agent sends without blocking; both complete."""
  118. # Create agency where Planner can talk to both Worker and Reporter directly
  119. agency = Agency(
  120. planner_agent_instance,
  121. communication_flows=[
  122. planner_agent_instance > worker_agent_instance,
  123. planner_agent_instance > reporter_agent_instance,
  124. ],
  125. shared_instructions="",
  126. )
  127. before_count = len(agency.thread_manager.get_all_messages())
  128. result: RunResult = await agency.get_response(
  129. message=(
  130. "Say hello to both agents at the same time in parallel. "
  131. "In THIS SAME assistant turn, EMIT EXACTLY TWO send_message TOOL CALLS BACK-TO-BACK: first to Worker, "
  132. "then to Reporter. DO NOT produce any assistant text in this turn. DO NOT wait for any tool result "
  133. "between these two calls. Each message must include the exact task description you received."
  134. )
  135. )
  136. assert result is not None and isinstance(result.final_output, str)
  137. all_messages = agency.thread_manager.get_all_messages()
  138. new_messages = all_messages[before_count:]
  139. call_indices = []
  140. output_indices = [] # Planner-only
  141. send_message_like_call_indices = []
  142. called_recipients: list[str] = []
  143. for idx, msg in enumerate(new_messages):
  144. msg_type = msg.get("type")
  145. if msg_type == "function_call":
  146. call_indices.append(idx)
  147. try:
  148. args = json.loads(msg.get("arguments", "{}"))
  149. except Exception:
  150. args = {}
  151. if isinstance(args, dict):
  152. if {"message", "additional_instructions"}.issubset(args.keys()):
  153. send_message_like_call_indices.append(idx)
  154. # Track recipient if present
  155. recipient = args.get("recipient_agent")
  156. if isinstance(recipient, str) and recipient:
  157. called_recipients.append(recipient)
  158. elif msg_type == "function_call_output":
  159. # Only consider outputs from the Planner (ignore sub-agent outputs)
  160. if msg.get("agent") == "Planner":
  161. output_indices.append(idx)
  162. # Ensure we see at least two inter-agent calls
  163. assert len(send_message_like_call_indices) >= 2, (
  164. f"Expected at least two inter-agent function_call items; found indices {send_message_like_call_indices}."
  165. )
  166. # Ensure calls target two different recipients (order-agnostic)
  167. assert len(set(called_recipients)) >= 2, (
  168. f"Expected calls to two distinct recipients; got recipients {called_recipients}"
  169. )
  170. # Ensure Planner produced at least two outputs (both calls completed)
  171. assert len(output_indices) >= 2, (
  172. f"Expected at least two Planner function_call_output items; got indices {output_indices}"
  173. )
  174. # both calls must occur before the second Planner output
  175. second_output_idx = sorted(output_indices)[1]
  176. send_message_like_call_indices.sort()
  177. assert send_message_like_call_indices[1] < second_output_idx, (
  178. f"Both inter-agent calls must occur before the second Planner output; "
  179. f"calls={send_message_like_call_indices}, second_output={second_output_idx}"
  180. )