test_mcp_integration.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
import asyncio
import logging
import os
import sys

import pytest
from agents import ModelSettings
from agents.mcp.server import MCPServerStdio
from dotenv import load_dotenv

from agency_swarm import Agency, Agent
from agency_swarm.tools.mcp_manager import LoopAffineAsyncProxy, PersistentMCPServerManager
  10. load_dotenv(override=True)
  11. logger = logging.getLogger(__name__)
  12. def _stdio_server_path() -> str:
  13. # Use the test stdio server script bundled in tests/data/scripts
  14. this_dir = os.path.dirname(os.path.abspath(__file__))
  15. return os.path.abspath(os.path.join(this_dir, "..", "..", "data", "scripts", "stdio_server.py"))
  16. def _agency_factory() -> Agency:
  17. stdio_server = MCPServerStdio(
  18. name="Test_STDIO_Server",
  19. params={
  20. "command": "python",
  21. "args": [_stdio_server_path()],
  22. },
  23. client_session_timeout_seconds=15,
  24. )
  25. agent = Agent(
  26. name="MCP StdIO Agent",
  27. model_settings=ModelSettings(temperature=0),
  28. mcp_servers=[stdio_server],
  29. )
  30. return Agency(
  31. agent,
  32. name="mcp_stdio_agency",
  33. user_context={"session_id": "mcp_stdio_session"},
  34. shared_instructions="Test MCP StdIO Integration",
  35. )
  36. @pytest.mark.asyncio
  37. async def test_mcp_stdio_get_response(caplog):
  38. agency = _agency_factory()
  39. with caplog.at_level(logging.ERROR):
  40. res = await agency.get_response("What tools do you have?")
  41. assert "greet" in res.final_output.lower() and "add" in res.final_output.lower()
  42. # ensure no MCP cleanup error logs were emitted
  43. err_msgs = [rec.getMessage() for rec in caplog.records]
  44. assert not any(
  45. ("Attempted to exit cancel scope in a different task than it was entered in" in msg)
  46. or ("Error cleaning up server:" in msg)
  47. for msg in err_msgs
  48. ), f"Found MCP cleanup error logs: {err_msgs}"
  49. @pytest.mark.asyncio
  50. async def test_mcp_proxy_enters_async_context_when_session_reset():
  51. manager = PersistentMCPServerManager()
  52. server = MCPServerStdio(
  53. name="Test_STDIO_Server_Context",
  54. params={
  55. "command": "python",
  56. "args": [_stdio_server_path()],
  57. },
  58. client_session_timeout_seconds=15,
  59. )
  60. await manager.ensure_connected(server)
  61. proxy = LoopAffineAsyncProxy(server, manager)
  62. server.session = None
  63. try:
  64. async with proxy as acquired:
  65. assert acquired is server
  66. tools = await proxy.list_tools()
  67. names = [tool.name for tool in tools]
  68. assert "greet" in names
  69. assert server.session is not None
  70. finally:
  71. await manager.shutdown()
  72. @pytest.mark.asyncio
  73. async def test_mcp_stdio_get_response_stream(caplog):
  74. agency = _agency_factory()
  75. saw_any_event = False
  76. saw_error = False
  77. async def _consume_stream():
  78. nonlocal saw_any_event, saw_error
  79. async for ev in agency.get_response_stream("What tools do you have?"):
  80. saw_any_event = True
  81. if isinstance(ev, dict) and ev.get("type") == "error":
  82. saw_error = True
  83. with caplog.at_level(logging.ERROR):
  84. try:
  85. await asyncio.wait_for(_consume_stream(), timeout=30)
  86. except asyncio.TimeoutError: # noqa: UP041
  87. pytest.fail("Streaming timed out; possible hang in MCP streaming handling")
  88. assert saw_any_event, "Expected at least one streaming event"
  89. assert not saw_error, "Received error event during MCP streaming"
  90. # ensure no MCP cleanup error logs were emitted
  91. err_msgs = [rec.getMessage() for rec in caplog.records]
  92. assert not any(
  93. ("Attempted to exit cancel scope in a different task than it was entered in" in msg)
  94. or ("Error cleaning up server:" in msg)
  95. for msg in err_msgs
  96. ), f"Found MCP cleanup error logs: {err_msgs}"