test_agency_context_sharing.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. """
  2. Integration test for agency context sharing between agents.
  3. This test verifies that agents can share data through the agency context,
  4. ensuring that changes made by one agent are visible to other agents.
  5. """
  6. import pytest
  7. from agents import ModelSettings, RunContextWrapper, function_tool
  8. from agency_swarm import Agency, Agent, MasterContext
  9. from tests.deterministic_model import DeterministicModel
  10. @function_tool
  11. async def store_data(ctx: RunContextWrapper[MasterContext], key: str, value: str) -> str:
  12. """Store data in the shared context."""
  13. context: MasterContext = ctx.context
  14. context.set(key, value)
  15. return f"Stored {key}={value}"
  16. @function_tool
  17. async def get_data(ctx: RunContextWrapper[MasterContext], key: str) -> str:
  18. """Get data from the shared context."""
  19. context: MasterContext = ctx.context
  20. value = context.get(key)
  21. return f"Value for {key}: {value}"
  22. @pytest.mark.asyncio
  23. async def test_context_sharing_between_agents():
  24. """Test that data stored by one agent is accessible to another agent."""
  25. # Create agents
  26. agent1 = Agent(
  27. name="Agent1",
  28. instructions="You store data in the context.",
  29. tools=[store_data],
  30. model=DeterministicModel(),
  31. model_settings=ModelSettings(tool_choice="required"),
  32. tool_use_behavior="stop_on_first_tool",
  33. )
  34. agent2 = Agent(
  35. name="Agent2",
  36. instructions="You retrieve and store data in the context.",
  37. tools=[get_data, store_data],
  38. model=DeterministicModel(),
  39. model_settings=ModelSettings(tool_choice="required"),
  40. tool_use_behavior="stop_on_first_tool",
  41. )
  42. # Create agency with both agents as entry points
  43. agency = Agency(
  44. agent1,
  45. agent2,
  46. communication_flows=[agent1 > agent2],
  47. user_context={"initial": "test"},
  48. )
  49. # Agent1 stores data
  50. response1 = await agency.get_response(
  51. "Store shared_key with value shared_value",
  52. recipient_agent=agent1,
  53. )
  54. tool_outputs_1 = [item.output for item in response1.new_items if hasattr(item, "output")]
  55. assert any("Stored shared_key=shared_value" in str(output) for output in tool_outputs_1)
  56. # Verify data is in agency context
  57. assert agency.user_context.get("shared_key") == "shared_value"
  58. assert agency.user_context.get("initial") == "test" # Original value preserved
  59. # Directly ask Agent2 to retrieve the data
  60. response2 = await agency.get_response(
  61. "Get the value for shared_key",
  62. recipient_agent=agent2,
  63. )
  64. tool_outputs_2 = [item.output for item in response2.new_items if hasattr(item, "output")]
  65. assert any("Value for shared_key: shared_value" in str(output) for output in tool_outputs_2)
  66. # Agent2 can also store data that's visible to the agency
  67. await agency.get_response(
  68. "Store agent2_key with value agent2_value",
  69. recipient_agent=agent2,
  70. )
  71. # Verify Agent2's data is in agency context
  72. assert agency.user_context.get("agent2_key") == "agent2_value"
  73. assert agency.user_context.get("shared_key") == "shared_value" # Previous data preserved
  74. # Retrieve Agent2's data directly from Agent2
  75. response4 = await agency.get_response(
  76. "Get the value for agent2_key",
  77. recipient_agent=agent2,
  78. )
  79. tool_outputs_4 = [item.output for item in response4.new_items if hasattr(item, "output")]
  80. assert any("Value for agent2_key: agent2_value" in str(output) for output in tool_outputs_4)