# test_tool_concurrency_integration.py
  1. """Integration tests for one_call_at_a_time tool concurrency."""
  2. import pytest
  3. from pydantic import BaseModel, Field
  4. from agency_swarm import Agent, BaseTool
  5. class TestToolConcurrencyEndToEnd:
  6. """End-to-end integration test for one_call_at_a_time functionality."""
  7. class ToolExecutionReport(BaseModel):
  8. """Output type that captures any tool execution errors."""
  9. sequential_tool_result: str = Field(description="Result from the sequential tool")
  10. parallel_tool_result: str = Field(description="Result from the parallel tool")
  11. errors_encountered: list[str] = Field(
  12. default_factory=list,
  13. description="List of any errors or concurrency violations that occurred "
  14. "containing exact error messages that you've received",
  15. )
  16. @pytest.mark.asyncio
  17. async def test_agent_enforces_tool_concurrency(self):
  18. """Test that agent properly enforces one_call_at_a_time using structured output."""
  19. class SequentialTool(BaseTool):
  20. """A tool that must run sequentially and takes time."""
  21. duration: float = Field(description="How long to process in seconds")
  22. class ToolConfig:
  23. one_call_at_a_time = True
  24. strict = False
  25. def run(self):
  26. import time
  27. time.sleep(self.duration)
  28. return f"SequentialTool completed processing for {self.duration} seconds"
  29. class ParallelTool(BaseTool):
  30. """A tool that can run in parallel."""
  31. message: str = Field(description="Message to process")
  32. class ToolConfig:
  33. strict = False
  34. def run(self):
  35. return f"ParallelTool processed: {self.message}"
  36. # Create agent with structured output for response validation
  37. agent = Agent(
  38. name="ConcurrencyTestAgent",
  39. instructions="""You are a test agent with two tools: SequentialTool and ParallelTool.
  40. When asked to use both tools simultaneously:
  41. 1. Try to call SequentialTool with duration=2
  42. 2. Try to call ParallelTool with message="test_parallel"
  43. 3. Report the results and any errors that occur
  44. IMPORTANT: Always call both tools in a single response, not sequentially.
  45. If you encounter tool concurrency violations, include them in the errors_encountered list.""",
  46. tools=[SequentialTool, ParallelTool],
  47. output_type=self.ToolExecutionReport,
  48. model="gpt-5.4-mini",
  49. )
  50. # Ask agent to use both tools simultaneously
  51. response = await agent.get_response(
  52. "Please use both SequentialTool and ParallelTool at the same time. "
  53. "Call SequentialTool with duration 1 and ParallelTool with message 'test_parallel'. "
  54. "Report any concurrency violations or errors in the structured output."
  55. )
  56. # Verify the structured output
  57. output = response.final_output
  58. assert isinstance(output, self.ToolExecutionReport)
  59. # Check tool outputs directly for concurrency violations; avoid relying on summary wording
  60. tool_outputs = [str(item.output) for item in response.new_items if hasattr(item, "output")]
  61. concurrency_errors = [out for out in tool_outputs if "concurrency violation" in out.lower()]
  62. assert len(concurrency_errors) > 0, f"Expected concurrency violation, but got tool outputs: {tool_outputs}"
  63. # At least one tool should have completed successfully
  64. success_markers = (
  65. f"{SequentialTool.__name__} completed processing".lower(),
  66. f"{ParallelTool.__name__} processed".lower(),
  67. )
  68. successful_results = [
  69. output
  70. for output in tool_outputs
  71. if any(marker in output.lower() for marker in success_markers) and "error" not in output.lower()
  72. ]
  73. assert len(successful_results) > 0, "At least one tool should have completed successfully"
  74. class TestFunctionToolConcurrency:
  75. """Test concurrency with @function_tool decorated tools."""
  76. def test_function_tool_tools_folder_integration(self, tmp_path):
  77. """Test that function tools from tools_folder get proper concurrency handling."""
  78. # Create a tools folder with function tools
  79. tools_dir = tmp_path / "tools"
  80. tools_dir.mkdir()
  81. # Create function tool file
  82. tool_file = tools_dir / "concurrency_tool.py"
  83. tool_file.write_text("""
  84. from agents import function_tool
  85. import time
  86. @function_tool
  87. def sequential_tool(duration: float) -> str:
  88. '''A tool that must run sequentially.'''
  89. time.sleep(duration)
  90. return f"Sequential tool completed after {duration}s"
  91. # Set one_call_at_a_time attribute
  92. sequential_tool.one_call_at_a_time = True
  93. @function_tool
  94. def parallel_tool(message: str) -> str:
  95. '''A tool that can run in parallel.'''
  96. return f"Parallel tool: {message}"
  97. """)
  98. # Create agent with tools_folder
  99. agent = Agent(
  100. name="ToolsFolderAgent",
  101. instructions="Test tools folder integration.",
  102. tools_folder=str(tools_dir),
  103. model="gpt-5.4-mini",
  104. )
  105. # Should have loaded both tools
  106. tool_names = [tool.name for tool in agent.tools]
  107. assert "sequential_tool" in tool_names
  108. assert "parallel_tool" in tool_names
  109. # Find the tools
  110. sequential_tool = next(t for t in agent.tools if t.name == "sequential_tool")
  111. parallel_tool = next(t for t in agent.tools if t.name == "parallel_tool")
  112. # Sequential tool should have one_call_at_a_time
  113. assert getattr(sequential_tool, "one_call_at_a_time", False) is True
  114. assert getattr(sequential_tool, "_one_call_guard_installed", False) is True
  115. # Parallel tool should not
  116. assert getattr(parallel_tool, "one_call_at_a_time", False) is False
  117. assert getattr(parallel_tool, "_one_call_guard_installed", False) is True