test_context_persistence.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
"""
Integration tests for context persistence across agent calls.

These tests verify that modifications to user_context are preserved
between different agent invocations within the same agency, and that a
per-call context_override does not modify the agency's user_context.
"""
  6. import pytest
  7. from pydantic import Field
  8. from agency_swarm import Agency, Agent
  9. from agency_swarm.tools import BaseTool
  10. class StoreValueTool(BaseTool):
  11. """Store a value in the shared context."""
  12. key: str = Field(..., description="Key to store")
  13. value: str = Field(..., description="Value to store")
  14. def run(self):
  15. if self.context:
  16. self.context.set(self.key, self.value)
  17. return f"Stored {self.key}={self.value}"
  18. return "No context available"
  19. class ReadValueTool(BaseTool):
  20. """Read a value from the shared context."""
  21. key: str = Field(..., description="Key to read")
  22. def run(self):
  23. if self.context:
  24. value = self.context.get(self.key, "not_found")
  25. return f"Value for {self.key}: {value}"
  26. return "No context available"
  27. @pytest.mark.asyncio
  28. async def test_context_persistence_between_calls():
  29. """Test that context changes persist between separate agent calls."""
  30. # Create agent with both tools
  31. agent = Agent(
  32. name="ContextAgent",
  33. instructions="You store and retrieve data using the provided tools.",
  34. tools=[StoreValueTool, ReadValueTool],
  35. model="gpt-5.4-mini",
  36. )
  37. # Create agency with initial context
  38. agency = Agency(
  39. agent,
  40. user_context={"initial": "value"},
  41. )
  42. # First call: Store a value
  43. response1 = await agency.get_response("Store the value 'test_data' with key 'stored_key' using StoreValueTool")
  44. # Verify the tool was called
  45. tool_outputs = [item.output for item in response1.new_items if hasattr(item, "output")]
  46. assert any("Stored stored_key=test_data" in str(output) for output in tool_outputs)
  47. # Second call: Read the value back
  48. response2 = await agency.get_response("Read the value for key 'stored_key' using ReadValueTool")
  49. # Verify the value was persisted
  50. tool_outputs2 = [item.output for item in response2.new_items if hasattr(item, "output")]
  51. assert any("Value for stored_key: test_data" in str(output) for output in tool_outputs2)
  52. # Verify agency context was updated
  53. assert agency.user_context.get("stored_key") == "test_data"
  54. assert agency.user_context.get("initial") == "value" # Original value still there
  55. @pytest.mark.asyncio
  56. async def test_context_override_does_not_affect_agency():
  57. """Test that context_override doesn't modify the agency's user_context."""
  58. agent = Agent(
  59. name="TestAgent",
  60. instructions="You read data using ReadValueTool.",
  61. tools=[ReadValueTool],
  62. model="gpt-5.4-mini",
  63. )
  64. agency = Agency(
  65. agent,
  66. user_context={"agency_key": "agency_value"},
  67. )
  68. # Call with context override
  69. response = await agency.get_response(
  70. "Read the value for key 'override_key' using ReadValueTool", context_override={"override_key": "override_value"}
  71. )
  72. # Verify the override was used in the call
  73. tool_outputs = [item.output for item in response.new_items if hasattr(item, "output")]
  74. assert any("Value for override_key: override_value" in str(output) for output in tool_outputs)
  75. # Verify agency context was NOT modified
  76. assert "override_key" not in agency.user_context
  77. assert agency.user_context == {"agency_key": "agency_value"}