test_from_mcp_method.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from unittest.mock import AsyncMock, patch
  2. import pytest
  3. from agents import FunctionTool, ToolOutputImage
  4. from agents.run_context import RunContextWrapper
  5. from agency_swarm.tools.tool_factory import ToolFactory
  6. class _DummyServer:
  7. def __init__(self, name: str = "dummy_server") -> None:
  8. self.name = name
  9. self.connect_calls = 0
  10. self.cleanup_calls = 0
  11. self.session = None
  12. async def connect(self) -> None:
  13. self.connect_calls += 1
  14. self.session = object()
  15. async def cleanup(self) -> None:
  16. self.cleanup_calls += 1
  17. self.session = None
  18. @pytest.mark.asyncio
  19. @patch("agents.mcp.util.MCPUtil.get_function_tools", new_callable=AsyncMock)
  20. @patch("agency_swarm.tools.mcp_manager.default_mcp_manager")
  21. async def test_from_mcp_connects_once_and_reuses_connection(mock_manager, mock_get_function_tools: AsyncMock) -> None:
  22. server = _DummyServer()
  23. original_invoke = AsyncMock(return_value="payload")
  24. function_tool = FunctionTool(
  25. name="echo",
  26. description="test tool",
  27. params_json_schema={"type": "object", "properties": {}},
  28. on_invoke_tool=original_invoke,
  29. strict_json_schema=False,
  30. )
  31. mock_get_function_tools.return_value = [function_tool]
  32. mock_manager.register.side_effect = lambda srv: srv
  33. async def fake_ensure(srv):
  34. await srv.connect()
  35. mock_manager.ensure_connected = AsyncMock(side_effect=fake_ensure)
  36. mock_manager.get.return_value = server
  37. # Test that from_mcp returns FunctionTool instances
  38. tools = ToolFactory.from_mcp([server])
  39. assert len(tools) == 1
  40. assert server.connect_calls == 1
  41. assert server.cleanup_calls == 0
  42. # Verify the tool is wrapped with error handling but still delegates to original
  43. ctx = RunContextWrapper(context=None)
  44. result = await tools[0].on_invoke_tool(ctx, "{}")
  45. assert result == "payload"
  46. original_invoke.assert_called_once()
  47. @pytest.mark.asyncio
  48. @patch("agents.mcp.util.MCPUtil.get_function_tools", new_callable=AsyncMock)
  49. @patch("agency_swarm.tools.mcp_manager.default_mcp_manager")
  50. async def test_from_mcp_tools_are_invokable(mock_manager, mock_get_function_tools: AsyncMock) -> None:
  51. """Test that tools converted from MCP servers can be invoked correctly."""
  52. async def mock_invoke(ctx, input_json: str):
  53. return f"Echo: {input_json}"
  54. function_tool = FunctionTool(
  55. name="echo",
  56. description="test tool",
  57. params_json_schema={"type": "object", "properties": {"message": {"type": "string"}}},
  58. on_invoke_tool=mock_invoke,
  59. strict_json_schema=False,
  60. )
  61. mock_get_function_tools.return_value = [function_tool]
  62. server = _DummyServer()
  63. mock_manager.register.return_value = server
  64. mock_manager.ensure_connected = AsyncMock()
  65. mock_manager.get.return_value = server
  66. # Test that from_mcp returns FunctionTool instances
  67. tools = ToolFactory.from_mcp([server])
  68. assert len(tools) == 1
  69. tool = tools[0]
  70. # Verify tool properties
  71. assert tool.name == "echo"
  72. assert tool.description == "test tool"
  73. assert "message" in tool.params_json_schema["properties"]
  74. # Invoke the tool and verify it works
  75. ctx = RunContextWrapper(context=None)
  76. result = await tool.on_invoke_tool(ctx, '{"message": "hello"}')
  77. assert result == 'Echo: {"message": "hello"}'
  78. # Invoke again to verify tool can be called multiple times
  79. result2 = await tool.on_invoke_tool(ctx, '{"message": "world"}')
  80. assert result2 == 'Echo: {"message": "world"}'
  81. @pytest.mark.asyncio
  82. @patch("agents.mcp.util.MCPUtil.get_function_tools", new_callable=AsyncMock)
  83. @patch("agency_swarm.tools.mcp_manager.default_mcp_manager")
  84. async def test_from_mcp_function_tools_preserve_structured_outputs(
  85. mock_manager, mock_get_function_tools: AsyncMock
  86. ) -> None:
  87. """FunctionTool instances from MCP must preserve structured outputs like ToolOutputImage."""
  88. image_output = ToolOutputImage(image_url="https://example.com/sample.png")
  89. async def mock_invoke(ctx, input_json: str):
  90. return image_output
  91. function_tool = FunctionTool(
  92. name="structured",
  93. description="returns structured output",
  94. params_json_schema={"type": "object", "properties": {}},
  95. on_invoke_tool=mock_invoke,
  96. strict_json_schema=False,
  97. )
  98. mock_get_function_tools.return_value = [function_tool]
  99. server = _DummyServer()
  100. mock_manager.register.return_value = server
  101. mock_manager.ensure_connected = AsyncMock()
  102. mock_manager.get.return_value = server
  103. # Get FunctionTool instances from MCP
  104. tools = ToolFactory.from_mcp([server])
  105. assert len(tools) == 1
  106. tool = tools[0]
  107. # Verify the tool preserves structured outputs
  108. ctx = RunContextWrapper(context=None)
  109. result = await tool.on_invoke_tool(ctx, "{}")
  110. assert result is image_output
  111. @pytest.mark.asyncio
  112. @patch("agents.mcp.util.MCPUtil.get_function_tools", new_callable=AsyncMock)
  113. @patch("agency_swarm.tools.mcp_manager.default_mcp_manager")
  114. async def test_from_mcp_tools_catch_exceptions_and_return_error_strings(
  115. mock_manager, mock_get_function_tools: AsyncMock
  116. ) -> None:
  117. """MCP tools should catch exceptions and return error strings instead of propagating."""
  118. async def mock_invoke_that_raises(ctx, input_json: str):
  119. raise TimeoutError("Connection timed out after 5 seconds")
  120. function_tool = FunctionTool(
  121. name="failing_tool",
  122. description="a tool that fails",
  123. params_json_schema={"type": "object", "properties": {}},
  124. on_invoke_tool=mock_invoke_that_raises,
  125. strict_json_schema=False,
  126. )
  127. mock_get_function_tools.return_value = [function_tool]
  128. server = _DummyServer()
  129. mock_manager.register.return_value = server
  130. mock_manager.ensure_connected = AsyncMock()
  131. mock_manager.get.return_value = server
  132. # Get FunctionTool instances from MCP
  133. tools = ToolFactory.from_mcp([server])
  134. assert len(tools) == 1
  135. tool = tools[0]
  136. # Invoke the tool - should NOT raise, instead return error string
  137. ctx = RunContextWrapper(context=None)
  138. result = await tool.on_invoke_tool(ctx, "{}")
  139. # Verify error is returned as string (using SDK's default_tool_error_function format)
  140. assert isinstance(result, str)
  141. assert "error" in result.lower()
  142. assert "Connection timed out after 5 seconds" in result