test_multi_agency_support.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """
  2. Tests for multi-agency agent support.
  3. This module tests the ability for a single agent to be registered
  4. in multiple agencies without context leakage between them.
  5. """
  6. import asyncio
  7. from typing import Literal
  8. import pytest
  9. from agents import RunResult
  10. from agents.items import ToolCallOutputItem
  11. from pydantic import Field
  12. from agency_swarm import Agency, Agent
  13. from agency_swarm.tools import BaseTool
  class SharedStateTool(BaseTool):
      """Write or read ``test_value`` through the tool's context.

      The tests in this module call it through different agencies to check
      that a value stored via one agency is not visible via another.
      """

      value: str = Field(..., description="Value to store or retrieve")
      action: Literal["set", "get"] = Field(..., description="Either 'set' or 'get'")

      def run(self):
          """Carry out ``action`` and return the same message that is printed."""
          print(f"Shared state: {self.context}")

          if self.action == "set":
              self.context.set("test_value", self.value)
              message = f"Set test_value to: {self.value}"
              print(message)
              return message

          if self.action == "get":
              # "NOT_SET" is reported when nothing has been stored yet.
              current = self.context.get("test_value", "NOT_SET")
              message = f"Current test_value: {current}"
              print(message)
              return message
  def _tool_outputs(response: RunResult) -> list[str]:
      """Return ``str(item.output)`` for every ``ToolCallOutputItem`` in ``response.new_items``."""
      collected: list[str] = []
      for entry in response.new_items:
          if isinstance(entry, ToolCallOutputItem):
              collected.append(str(entry.output))
      return collected
  def _assert_tool_output_contains(response: RunResult, expected: str) -> None:
      """Fail unless at least one tool output of ``response`` contains ``expected``."""
      outputs = _tool_outputs(response)
      matching = [text for text in outputs if expected in text]
      assert matching, f"Expected {expected!r} in tool outputs: {outputs}"
  def _assert_tool_output_excludes(response: RunResult, unexpected: str) -> None:
      """Fail if any tool output of ``response`` contains ``unexpected``."""
      outputs = _tool_outputs(response)
      leaked = any(unexpected in text for text in outputs)
      assert not leaked, f"Did not expect {unexpected!r} in tool outputs: {outputs}"
  @pytest.fixture
  def shared_agent():
      """Provide the single ``SharedAgent`` instance that both agencies register."""
      agent = Agent(
          name="SharedAgent",
          instructions="You are a shared agent that can set and get values using the SharedStateTool.",
          tools=[SharedStateTool],
      )
      return agent
  @pytest.fixture
  def agency1(shared_agent):
      """Build ``Agency1`` from the shared agent and its own ``Assistant1``."""
      helper = Agent(name="Assistant1", instructions="You are Assistant1 in Agency1")
      flows = [shared_agent > helper]
      context = {"agency_name": "Agency1", "test_data": "agency1_data"}
      return Agency(
          shared_agent,
          helper,
          communication_flows=flows,
          name="Agency1",
          user_context=context,
      )
  @pytest.fixture
  def agency2(shared_agent):
      """Build ``Agency2`` around the same shared agent, paired with ``Assistant2``."""
      helper = Agent(name="Assistant2", instructions="You are Assistant2 in Agency2")
      flows = [shared_agent > helper]
      context = {"agency_name": "Agency2", "test_data": "agency2_data"}
      return Agency(
          shared_agent,
          helper,
          communication_flows=flows,
          name="Agency2",
          user_context=context,
      )
  class TestMultiAgencySupport:
      """Test cases for registering one agent in several agencies without context leakage."""

      @pytest.mark.asyncio
      async def test_agent_can_be_registered_in_multiple_agencies(self, shared_agent, agency1, agency2):
          """Both agencies hold the very same agent object but separate per-agency contexts."""
          # The agent is registered under its name in both agencies.
          assert shared_agent.name in agency1.agents
          assert shared_agent.name in agency2.agents

          # Identity check: both agencies reference the same instance, not a copy.
          assert agency1.agents["SharedAgent"] is shared_agent
          assert agency2.agents["SharedAgent"] is shared_agent

          # Each agency keeps its own context for the shared agent.
          context1 = agency1.get_agent_context("SharedAgent")
          context2 = agency2.get_agent_context("SharedAgent")

          assert context1.agency_instance is agency1
          assert context2.agency_instance is agency2
          assert context1.thread_manager is not context2.thread_manager

      @pytest.mark.asyncio
      async def test_thread_manager_isolation(self, shared_agent, agency1, agency2):
          """Each agency records its conversation in its own thread manager."""
          await agency1.get_response("Set test_value to 'agency1_value' using the SharedStateTool")
          await agency2.get_response("Set test_value to 'agency2_value' using the SharedStateTool")

          context1 = agency1.get_agent_context("SharedAgent")
          context2 = agency2.get_agent_context("SharedAgent")
          assert context1.thread_manager is not context2.thread_manager

          messages1 = context1.thread_manager.get_all_messages()
          messages2 = context2.thread_manager.get_all_messages()

          # Each history must contain at least one message attributed to SharedAgent.
          agency1_messages = [m for m in messages1 if m.get("agent") == "SharedAgent"]
          agency2_messages = [m for m in messages2 if m.get("agent") == "SharedAgent"]
          assert len(agency1_messages) > 0
          assert len(agency2_messages) > 0

          # The two histories must not be the same.
          assert messages1 != messages2

      @pytest.mark.asyncio
      async def test_context_isolation_between_agencies(self, shared_agent, agency1, agency2):
          """A value stored through one agency is not visible through the other."""
          await agency1.get_response("Use SharedStateTool to set test_value to 'agency1_secret'")

          # Reading through agency2 must not reveal agency1's value.
          response2 = await agency2.get_response("Use SharedStateTool to get the current test_value")
          _assert_tool_output_excludes(response2, "agency1_secret")

          # Writing through agency2 must not overwrite agency1's value.
          await agency2.get_response("Use SharedStateTool to set test_value to 'agency2_secret'")

          response1 = await agency1.get_response("Use SharedStateTool to get the current test_value")
          _assert_tool_output_contains(response1, "agency1_secret")
          _assert_tool_output_excludes(response1, "agency2_secret")

      @pytest.mark.asyncio
      async def test_subagent_registration_isolation(self, shared_agent, agency1, agency2):
          """The shared agent's subagents differ per agency."""
          subagents1 = agency1.get_agent_context("SharedAgent").subagents
          subagents2 = agency2.get_agent_context("SharedAgent").subagents

          assert "Assistant1" in subagents1
          assert "Assistant1" not in subagents2
          assert "Assistant2" in subagents2
          assert "Assistant2" not in subagents1

      @pytest.mark.asyncio
      async def test_user_context_isolation(self, shared_agent, agency1, agency2):
          """Each agency keeps the ``user_context`` it was constructed with."""
          assert agency1.user_context["agency_name"] == "Agency1"
          assert agency1.user_context["test_data"] == "agency1_data"
          assert agency2.user_context["agency_name"] == "Agency2"
          assert agency2.user_context["test_data"] == "agency2_data"

          assert agency1.user_context != agency2.user_context

      @pytest.mark.asyncio
      async def test_concurrent_agency_operations(self, shared_agent, agency1, agency2):
          """Concurrent runs of the shared agent in two agencies each keep their own value."""
          # Uses the module-level ``asyncio`` import; no function-local re-import is needed.
          task1 = asyncio.create_task(agency1.get_response("Use SharedStateTool to set test_value to 'concurrent1'"))
          task2 = asyncio.create_task(agency2.get_response("Use SharedStateTool to set test_value to 'concurrent2'"))

          response1, response2 = await asyncio.gather(task1, task2)

          # Both runs must finish with a final output.
          assert response1.final_output is not None
          assert response2.final_output is not None

          # Check the stored values directly rather than relying on live-model phrasing.
          assert agency1.user_context["test_value"] == "concurrent1"
          assert agency2.user_context["test_value"] == "concurrent2"

      @pytest.mark.asyncio
      async def test_streaming_context_isolation(self, shared_agent, agency1, agency2):
          """Values set via streaming runs stay separate per agency."""
          # Drain each stream completely, keeping the events to prove something was emitted.
          events1 = [
              event
              async for event in agency1.get_response_stream("Use SharedStateTool to set test_value to 'stream1'")
          ]
          events2 = [
              event
              async for event in agency2.get_response_stream("Use SharedStateTool to set test_value to 'stream2'")
          ]

          assert len(events1) > 0
          assert len(events2) > 0

          # Read the values back with non-streaming calls.
          response1 = await agency1.get_response("Use SharedStateTool to get the current test_value")
          response2 = await agency2.get_response("Use SharedStateTool to get the current test_value")

          outputs1 = _tool_outputs(response1)
          outputs2 = _tool_outputs(response2)
          assert any("stream1" in output for output in outputs1), f"Expected stream1 in tool outputs: {outputs1}"
          assert any("stream2" in output for output in outputs2), f"Expected stream2 in tool outputs: {outputs2}"
          assert outputs1 != outputs2
  class TestStatelessContextPassing:
      """Checks for stateless context passing between agencies that share an agent."""

      @pytest.mark.asyncio
      async def test_context_isolation_during_concurrent_execution(self, shared_agent, agency1, agency2):
          """Run one request per agency at the same time and expect both to finish."""
          # Contexts are passed statelessly, so the two runs should not interfere.
          pending = [
              asyncio.create_task(agency1.get_response("Hello from Agency1")),
              asyncio.create_task(agency2.get_response("Hello from Agency2")),
          ]
          first_result, second_result = await asyncio.gather(*pending)

          assert first_result.final_output is not None
          assert second_result.final_output is not None

      def test_context_factory_validation(self, shared_agent, agency1, agency2):
          """Each agency returns its own context object; unknown agent names are rejected."""
          ctx_one = agency1.get_agent_context("SharedAgent")
          ctx_two = agency2.get_agent_context("SharedAgent")

          # Separate instances, each bound to the agency that produced it.
          assert ctx_one is not ctx_two
          assert ctx_one.agency_instance is agency1
          assert ctx_two.agency_instance is agency2

          # Looking up an agent that was never registered must raise ValueError.
          with pytest.raises(ValueError, match="No context found for agent"):
              agency1.get_agent_context("NonexistentAgent")