test_tool_system.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import json
  2. from pathlib import Path
  3. from types import SimpleNamespace
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from agents import RunContextWrapper
  7. from agents.tool import _get_function_tool_invoke_context
  8. from agents.tool_context import ToolContext
  9. from pydantic import Field
  10. from agency_swarm import Agent, BaseTool, GuardrailFunctionOutput, InputGuardrailTripwireTriggered
  11. from agency_swarm.context import MasterContext
  12. from agency_swarm.tools.send_message import SendMessage
  13. from agency_swarm.utils.thread import ThreadManager
  # --- Fixtures ---
  15. class _FakeStream:
  16. def __init__(self, events: list[SimpleNamespace], final_result: object | None = None) -> None:
  17. self._events = events
  18. self._final_result = final_result
  19. def __aiter__(self):
  20. async def _events():
  21. for event in self._events:
  22. yield event
  23. return _events()
  24. @property
  25. def final_result(self):
  26. return self._final_result
  27. class _FakeAgent:
  28. def __init__(self, name: str, stream_events: list[SimpleNamespace] | None = None) -> None:
  29. self.name = name
  30. self.model = "gpt-5.4-mini"
  31. self.raise_input_guardrail_error = False
  32. self._stream_events = stream_events or []
  33. def get_response_stream(self, **_kwargs):
  34. return _FakeStream(self._stream_events)
  35. async def get_response(self, **_kwargs):
  36. return "ok"
  class _ErroringStreamAgent(_FakeAgent):
      """Fake agent whose ``get_response_stream`` always raises ``RuntimeError``."""

      def __init__(self, name: str, error_text: str) -> None:
          super().__init__(name)
          self._error_text = error_text

      def get_response_stream(self, **_kwargs):
          """Raise ``RuntimeError`` carrying the configured text; keyword arguments are ignored."""
          failure = RuntimeError(self._error_text)
          raise failure
  class _GuardrailTripwireAgent(_FakeAgent):
      """Fake agent whose ``get_response`` raises a preconfigured guardrail exception."""

      def __init__(self, name: str, exc: InputGuardrailTripwireTriggered) -> None:
          super().__init__(name)
          self._exc = exc

      async def get_response(self, **_kwargs):
          """Raise the exception supplied at construction; keyword arguments are ignored."""
          pending_error = self._exc
          raise pending_error
  @pytest.fixture
  def mock_sender_agent():
      """Provide a fake agent named "SenderAgent" with no stream events configured."""
      sender = _FakeAgent("SenderAgent")
      return sender
  @pytest.fixture
  def mock_recipient_agent():
      """Provide a fake "RecipientAgent" whose stream replays a single message output item."""
      # Built inside-out: text part -> raw item -> output item -> stream event.
      reply_part = SimpleNamespace(text="Response from recipient")
      raw_message = SimpleNamespace(content=[reply_part])
      output_item = SimpleNamespace(type="message_output_item", raw_item=raw_message)
      stream_event = SimpleNamespace(item=output_item)
      return _FakeAgent("RecipientAgent", stream_events=[stream_event])
  @pytest.fixture
  def mock_master_context():
      """Provide a MasterContext with a real ThreadManager, no agents, and one user-context entry."""
      context_kwargs = {
          "thread_manager": ThreadManager(),
          "agents": {},
          "user_context": {"user_key": "user_value"},
      }
      return MasterContext(**context_kwargs)
  @pytest.fixture
  def mock_run_context_wrapper(mock_master_context):
      """Wrap the ``mock_master_context`` fixture in a RunContextWrapper."""
      run_context = RunContextWrapper(context=mock_master_context)
      return run_context
  @pytest.fixture
  def mock_context(mock_sender_agent, mock_recipient_agent):
      """Provide a streaming MasterContext holding both fake agents and a mocked ThreadManager."""
      fake_threads = MagicMock(spec=ThreadManager)
      fake_threads.get_thread = MagicMock(return_value=MagicMock())
      # add_items_and_save is replaced with an AsyncMock so it can be awaited.
      fake_threads.add_items_and_save = AsyncMock()

      registered_agents = {
          "SenderAgent": mock_sender_agent,
          "RecipientAgent": mock_recipient_agent,
      }
      return MasterContext(
          thread_manager=fake_threads,
          agents=registered_agents,
          user_context={"user_key": "user_val"},
          agent_runtime_state={},
          shared_instructions=None,
          _is_streaming=True,
      )
  @pytest.fixture
  def mock_wrapper(mock_context, mock_sender_agent):
      """Provide a RunContextWrapper over ``mock_context`` with mocked hooks and the sender as its agent."""
      run_context = RunContextWrapper(context=mock_context)
      # hooks and agent are attached after construction as plain attributes.
      run_context.hooks = MagicMock()
      run_context.agent = mock_sender_agent
      return run_context
  @pytest.fixture
  def specific_send_message_tool(mock_sender_agent, mock_recipient_agent):
      """Provide a SendMessage instance so tests can call its on_invoke_tool method directly."""
      # Recipients are keyed by the lower-cased agent name.
      recipient_key = mock_recipient_agent.name.lower()
      return SendMessage(
          sender_agent=mock_sender_agent,
          recipients={recipient_key: mock_recipient_agent},
      )
  @pytest.fixture
  def base_tool():
      """Provide a BaseTool subclass (the class, not an instance) that echoes its ``input`` field."""

      class TestTool(BaseTool):
          input: str = Field(description="The input to the tool")

          class ToolConfig:
              strict = True

          def run(self):
              received = self.input
              print(f"Running TestTool with input: {received}")
              return received

      return TestTool
  # --- Test Cases ---
  def test_send_message_advertises_tool_context_to_agents_sdk(
      specific_send_message_tool,
      mock_master_context,
  ) -> None:
      """The SDK's context selector must hand SendMessage the ToolContext itself, call metadata included."""
      call_metadata = {
          "tool_name": "send_message",
          "tool_call_id": "call_send_message",
          "tool_arguments": "{}",
      }
      original_context = ToolContext(context=mock_master_context, **call_metadata)

      chosen = _get_function_tool_invoke_context(specific_send_message_tool, original_context)

      # Identity, not equality: the very same ToolContext object must be selected.
      assert chosen is original_context
  @pytest.mark.asyncio
  async def test_send_message_success(specific_send_message_tool, mock_wrapper, mock_recipient_agent, mock_context):
      """A well-formed call returns the text carried by the recipient's streamed message item."""
      payload = json.dumps(
          {
              "recipient_agent": mock_recipient_agent.name,
              "message": "Test message",
              "additional_instructions": "Additional instructions for test.",
          }
      )

      reply = await specific_send_message_tool.on_invoke_tool(wrapper=mock_wrapper, arguments_json_string=payload)

      # The expected text is the one replayed by the recipient fixture's get_response_stream,
      # which is the method SendMessage uses here.
      assert reply == "Response from recipient"
  @pytest.mark.asyncio
  async def test_send_message_invalid_json(specific_send_message_tool, mock_wrapper):
      """Malformed JSON arguments produce an error string and exactly one logged error."""
      tool_name = specific_send_message_tool.name
      broken_payload = "{invalid json string"

      with patch("agency_swarm.tools.send_message.logger") as logger_double:
          outcome = await specific_send_message_tool.on_invoke_tool(
              wrapper=mock_wrapper, arguments_json_string=broken_payload
          )

      assert outcome == f"Error: Invalid arguments format for tool {tool_name}. Expected a valid JSON string."
      logger_double.error.assert_called_once()
  @pytest.mark.asyncio
  async def test_send_message_missing_required_param(specific_send_message_tool, mock_wrapper):
      """Omitting 'message' yields an error string and exactly one matching error log call."""
      tool_name = specific_send_message_tool.name
      # Only the recipient is supplied; the required "message" key is deliberately left out.
      args_json_missing_message = json.dumps({"recipient_agent": "RecipientAgent"})

      with patch("agency_swarm.tools.send_message.logger") as mock_module_logger:
          result = await specific_send_message_tool.on_invoke_tool(
              wrapper=mock_wrapper, arguments_json_string=args_json_missing_message
          )

      assert result == f"Error: Missing required parameter 'message' for tool {tool_name}."
      mock_module_logger.error.assert_called_once_with(
          f"Tool '{tool_name}' invoked without 'message' parameter."
      )
  @pytest.mark.asyncio
  async def test_send_message_target_agent_error(mock_wrapper):
      """A recipient whose stream call raises is reported as an error string and logged once."""
      failure_reason = "Target agent failed"
      failing_recipient = _ErroringStreamAgent("RecipientAgent", failure_reason)
      tool = SendMessage(
          sender_agent=mock_wrapper.agent,
          recipients={"recipientagent": failing_recipient},
      )
      payload = json.dumps(
          {
              "recipient_agent": "RecipientAgent",
              "message": "Test message",
              "additional_instructions": "",
          }
      )

      with patch("agency_swarm.tools.send_message.logger") as logger_double:
          outcome = await tool.on_invoke_tool(wrapper=mock_wrapper, arguments_json_string=payload)

      assert outcome == f"Error: Failed to get response from agent 'RecipientAgent'. Reason: {failure_reason}"
      logger_double.error.assert_called_once()
  @pytest.mark.asyncio
  async def test_send_message_input_guardrail_returns_error(mock_sender_agent, mock_wrapper):
      """With streaming off, a recipient that trips an input guardrail makes the tool return the output_info."""

      class _TrippedResult:
          # Carries the two attributes this test sets on a guardrail result: output and guardrail.
          output = GuardrailFunctionOutput(
              output_info="Prefix your request with 'Task:'",
              tripwire_triggered=True,
          )
          guardrail = object()

      tripwire = InputGuardrailTripwireTriggered(_TrippedResult())
      recipient = _GuardrailTripwireAgent("RecipientAgent", tripwire)

      shared_context = mock_wrapper.context
      shared_context.agents = {"SenderAgent": mock_sender_agent, "RecipientAgent": recipient}
      # The fixture defaults to streaming; switch it off so get_response (which raises) is the path taken.
      shared_context._is_streaming = False

      tool = SendMessage(sender_agent=mock_sender_agent, recipients={recipient.name.lower(): recipient})
      payload = json.dumps(
          {
              "recipient_agent": recipient.name,
              "message": "Hello",
              "additional_instructions": "",
          }
      )

      outcome = await tool.on_invoke_tool(wrapper=mock_wrapper, arguments_json_string=payload)

      assert outcome == "Prefix your request with 'Task:'"
  @pytest.mark.asyncio
  async def test_base_tool(base_tool):
      """
      Verify that a BaseTool subclass runs through the on_invoke_tool method of its adapted FunctionTool.
      """
      from agency_swarm.tools import ToolFactory

      adapted = ToolFactory.adapt_base_tool(base_tool)
      payload = json.dumps({"input": "hello"})

      # No run context is supplied (None); the tool should simply echo its input.
      assert await adapted.on_invoke_tool(None, payload) == "hello"
  @pytest.mark.asyncio
  async def test_schema_conversion():
      """An agent given ``schemas_folder`` exposes a tool named "createPaste"."""
      agent = Agent(name="test", instructions="test", schemas_folder="tests/data/schemas")
      loaded_names = {tool.name for tool in agent.tools}
      assert "createPaste" in loaded_names
  def test_tools_folder_autoload():
      """An absolute ``tools_folder`` path loads both sample tools."""
      absolute_folder = str(Path("tests/data/tools").resolve())
      agent = Agent(name="test", instructions="test", tools_folder=absolute_folder)

      loaded_names = {tool.name for tool in agent.tools}
      assert "ExampleTool1" in loaded_names
      assert "sample_tool" in loaded_names
  def test_relative_tools_folder_is_class_local():
      """A relative ``tools_folder`` ("../data/tools") still loads both sample tools."""
      agent = Agent(name="test", instructions="test", tools_folder="../data/tools")
      loaded_names = {tool.name for tool in agent.tools}
      assert {"ExampleTool1", "sample_tool"}.issubset(loaded_names)
  def test_tools_folder_edge_cases(tmp_path):
      """Test tools_folder handles edge cases correctly.

      The folder holds three files that should be ignored (an underscore-prefixed
      module, a non-Python file, and a module with invalid syntax) plus one valid
      tool; only the valid tool may end up on the agent.
      """
      tools_dir = tmp_path / "tools"
      tools_dir.mkdir()

      # Create files that should be ignored
      (tools_dir / "_private_tool.py").write_text("# Should be ignored")
      (tools_dir / "readme.txt").write_text("Not a Python file")
      (tools_dir / "invalid_tool.py").write_text("invalid python syntax !")

      # Create valid tool. The source is spelled out line by line so the function
      # body keeps its indentation; without it the module itself would not parse.
      valid_tool_source = (
          "\n"
          "from agents import function_tool\n"
          "\n"
          "@function_tool\n"
          "def valid_tool() -> str:\n"
          '    return "works"\n'
      )
      (tools_dir / "valid_tool.py").write_text(valid_tool_source)

      agent = Agent(name="test", instructions="test", tools_folder=str(tools_dir))
      tool_names = [tool.name for tool in agent.tools]

      # Only valid_tool should be loaded
      assert "valid_tool" in tool_names
      assert "_private_tool" not in tool_names
      assert len(tool_names) == 1
  @pytest.mark.asyncio
  async def test_tools_folder_supports_relative_imports(tmp_path):
      """Tools that use relative imports should load correctly from tools_folder."""
      tools_dir = tmp_path / "tools"
      tools_dir.mkdir()

      # Helper module imported relatively by the tool
      (tools_dir / "helpers.py").write_text("def greet(name: str) -> str:\n    return f'hello {name}'\n")

      # Tool that relies on relative import. The class body and the method body need
      # distinct indentation levels (4 and 8 spaces) for the generated module to parse.
      (tools_dir / "RelativeTool.py").write_text(
          "from pydantic import Field\n"
          "from agency_swarm.tools import BaseTool\n"
          "from .helpers import greet\n\n"
          "class RelativeTool(BaseTool):\n"
          "    name: str = Field(description='Name to greet')\n\n"
          "    def run(self):\n"
          "        return greet(self.name)\n"
      )

      agent = Agent(name="test", instructions="test", tools_folder=str(tools_dir))
      tool = next(t for t in agent.tools if t.name == "RelativeTool")

      # Invoke without a run context; the result comes from the relatively imported helper.
      result = await tool.on_invoke_tool(None, json.dumps({"name": "Ada"}))
      assert result == "hello Ada"
  @pytest.mark.parametrize("folder", [None, "/nonexistent/path"])
  def test_tools_folder_missing(folder: str | None):
      """An absent (None) or nonexistent tools_folder leaves the agent with an empty tool list."""
      agent = Agent(name="test", instructions="test", tools_folder=folder)
      loaded_tools = agent.tools
      assert loaded_tools == []
  @pytest.mark.asyncio
  async def test_shared_state_property(mock_run_context_wrapper):
      """``_shared_state`` access raises AttributeError, while ``context`` resolves to the wrapped context."""

      class TestTool(BaseTool):
          def run(self):
              return "ok"

      instance = TestTool()
      instance._context = mock_run_context_wrapper

      with pytest.raises(AttributeError, match=r"_shared_state"):
          instance._shared_state  # noqa: B018 - the attribute access itself is under test

      assert instance.context is mock_run_context_wrapper.context
  # --- one_call_at_a_time Tests ---
  def test_base_tool_one_call_at_a_time_config():
      """Test that BaseTool ToolConfig supports one_call_at_a_time parameter."""

      class OneCallTool(BaseTool):
          input: str = Field(description="Tool input")

          class ToolConfig:
              one_call_at_a_time = True

          def run(self):
              return f"processed: {self.input}"

      class NormalTool(BaseTool):
          input: str = Field(description="Tool input")

          def run(self):
              return f"processed: {self.input}"

      # The explicitly configured tool exposes the flag, set to True.
      flagged_config = OneCallTool.ToolConfig
      assert hasattr(flagged_config, "one_call_at_a_time")
      assert flagged_config.one_call_at_a_time is True

      # A tool that declares no ToolConfig of its own must not report the flag as
      # enabled: the attribute is either absent or exactly False.
      assert getattr(NormalTool.ToolConfig, "one_call_at_a_time", False) is False
  @pytest.mark.asyncio
  async def test_base_tool_one_call_propagation():
      """Test that one_call_at_a_time is propagated from BaseTool to FunctionTool."""
      from agency_swarm.tools import ToolFactory

      class OneCallTool(BaseTool):
          input: str = Field(description="Tool input")

          class ToolConfig:
              one_call_at_a_time = True
              strict = False

          def run(self):
              return f"sequential: {self.input}"

      adapted = ToolFactory.adapt_base_tool(OneCallTool)

      # The adapted FunctionTool must carry the flag over from ToolConfig.
      assert hasattr(adapted, "one_call_at_a_time")
      assert adapted.one_call_at_a_time is True
  @pytest.mark.asyncio
  async def test_base_tool_normal_tool_no_one_call():
      """Test that normal tools don't have one_call_at_a_time set."""
      from agency_swarm.tools import ToolFactory

      class NormalTool(BaseTool):
          input: str = Field(description="Tool input")

          def run(self):
              return f"normal: {self.input}"

      adapted = ToolFactory.adapt_base_tool(NormalTool)

      # A missing attribute counts as False, so both "unset" and "False" pass.
      assert getattr(adapted, "one_call_at_a_time", False) is False
  def test_agent_has_concurrency_manager():
      """Test that Agent instances have a tool concurrency manager."""
      from agency_swarm.tools.concurrency import ToolConcurrencyManager

      agent = Agent(name="test", instructions="test")

      assert hasattr(agent, "tool_concurrency_manager")
      manager = agent.tool_concurrency_manager
      assert manager is not None
      # It must also be the right type, not just any non-None object.
      assert isinstance(manager, ToolConcurrencyManager)
  def test_agent_concurrency_manager_independence():
      """Test that different agents have independent concurrency managers."""
      first_manager = Agent(name="agent1", instructions="test").tool_concurrency_manager
      second_manager = Agent(name="agent2", instructions="test").tool_concurrency_manager

      # Each agent owns its own manager instance.
      assert first_manager is not second_manager

      # Locking through the first manager must not show up on the second.
      first_manager.acquire_lock("tool1")
      first_busy, first_owner = first_manager.is_lock_active()
      second_busy, second_owner = second_manager.is_lock_active()

      assert first_busy is True
      assert first_owner == "tool1"
      assert second_busy is False
      assert second_owner is None
  # TODO: Add tests for response validation aspects
  # TODO: Add tests for context/hooks propagation (more complex, might need integration tests)
  # TODO: Add parameterized tests for various message inputs (empty, long, special chars)
  # TODO: Add tests for specific schema validation failures (if FunctionTool provides hooks)