# test_mcp_config_strict_mode.py (2.0 KB)
  1. """Test that mcp_config convert_schemas_to_strict is respected."""
  2. from unittest.mock import AsyncMock, patch
  3. import pytest
  4. from agents import FunctionTool
  5. class _DummyServer:
  6. def __init__(self) -> None:
  7. self.name = "test_server"
  8. @pytest.mark.asyncio
  9. @pytest.mark.parametrize(
  10. ("mcp_config", "expected_strict"),
  11. [
  12. pytest.param({"convert_schemas_to_strict": False}, False, id="explicit-false"),
  13. pytest.param({"convert_schemas_to_strict": True}, True, id="explicit-true"),
  14. pytest.param(None, False, id="default-missing"),
  15. pytest.param({}, False, id="default-empty"),
  16. ],
  17. )
  18. @patch("agents.mcp.util.MCPUtil.get_function_tools")
  19. @patch("agency_swarm.tools.mcp_converter.default_mcp_manager")
  20. async def test_mcp_config_convert_schemas_to_strict_is_propagated(
  21. mock_manager,
  22. mock_get_function_tools: AsyncMock,
  23. mcp_config: dict[str, bool] | None,
  24. expected_strict: bool,
  25. ) -> None:
  26. from agency_swarm import Agent
  27. test_tool = FunctionTool(
  28. name="test_tool",
  29. description="Test tool",
  30. params_json_schema={"type": "object", "properties": {}},
  31. on_invoke_tool=AsyncMock(return_value="test"),
  32. strict_json_schema=expected_strict,
  33. )
  34. observed_convert_values: list[bool] = []
  35. async def capture_convert_schemas_to_strict(server, strict, context, agent):
  36. observed_convert_values.append(strict)
  37. return [test_tool]
  38. mock_get_function_tools.side_effect = capture_convert_schemas_to_strict
  39. server = _DummyServer()
  40. mock_manager.get.return_value = None
  41. mock_manager.register.side_effect = lambda srv: srv
  42. mock_manager._ensure_driver.return_value = None
  43. agent_kwargs = {
  44. "name": "TestAgent",
  45. "mcp_servers": [server],
  46. }
  47. if mcp_config is not None:
  48. agent_kwargs["mcp_config"] = mcp_config
  49. Agent(**agent_kwargs)
  50. assert len(observed_convert_values) == 1
  51. observed_value = observed_convert_values[0]
  52. assert type(observed_value) is bool
  53. assert observed_value is expected_strict