# test_mcp_manager.py
from __future__ import annotations

import asyncio
import logging
from types import SimpleNamespace
from typing import Any
from unittest.mock import patch

import pytest

import agency_swarm.tools.mcp_manager as mcp_manager
from agency_swarm.tools.mcp_manager import LoopAffineAsyncProxy, PersistentMCPServerManager
  10. class _DummyServer:
  11. def __init__(self, name: str = "dummy") -> None:
  12. self.name = name
  13. self.session = None
  14. self.connect_calls = 0
  15. async def connect(self) -> None:
  16. self.connect_calls += 1
  17. self.session = object()
  18. async def cleanup(self) -> None:
  19. self.session = None
  20. class _SlowCleanupServer(_DummyServer):
  21. def __init__(self, delay: float) -> None:
  22. super().__init__()
  23. self.delay = delay
  24. async def cleanup(self) -> None:
  25. await asyncio.sleep(self.delay)
  26. await super().cleanup()
  27. class _FailingCleanupServer(_DummyServer):
  28. async def cleanup(self) -> None:
  29. raise RuntimeError("cleanup failed")
  30. class _AsyncContextServer(_DummyServer):
  31. def __init__(self) -> None:
  32. super().__init__()
  33. self.context_entered = 0
  34. self.context_exited = 0
  35. async def __aenter__(self) -> _AsyncContextServer:
  36. self.context_entered += 1
  37. return self
  38. async def __aexit__(self, exc_type, exc, tb) -> bool: # noqa: ANN001
  39. self.context_exited += 1
  40. return False
  41. class _SyncContextServer(_DummyServer):
  42. def __init__(self) -> None:
  43. super().__init__()
  44. self.context_entered = 0
  45. self.context_exited = 0
  46. self.value = 42
  47. def __aenter__(self) -> _SyncContextServer:
  48. self.context_entered += 1
  49. return self
  50. def __aexit__(self, exc_type, exc, tb) -> bool: # noqa: ANN001
  51. self.context_exited += 1
  52. return False
  53. async def ping(self, payload: str) -> str:
  54. return f"pong:{payload}"
  55. @pytest.mark.asyncio
  56. async def test_ensure_connected_reuses_driver_for_proxy() -> None:
  57. manager = PersistentMCPServerManager()
  58. server = _DummyServer()
  59. await manager.ensure_connected(server)
  60. proxy = LoopAffineAsyncProxy(server, manager)
  61. await manager.ensure_connected(proxy)
  62. try:
  63. assert len(manager._drivers) == 1
  64. finally:
  65. await manager.shutdown()
  66. @pytest.mark.asyncio
  67. async def test_reconnect_replaces_driver_and_resets_session() -> None:
  68. manager = PersistentMCPServerManager()
  69. server = _DummyServer()
  70. await manager.ensure_connected(server)
  71. old_state = manager._drivers[server]
  72. try:
  73. proxy = LoopAffineAsyncProxy(server, manager)
  74. await manager.reconnect(proxy)
  75. assert server.session is not None
  76. assert manager._drivers[server] is not old_state
  77. finally:
  78. await manager.shutdown()
  79. @pytest.mark.asyncio
  80. async def test_shutdown_handles_cleanup_timeout() -> None:
  81. manager = PersistentMCPServerManager()
  82. server = _SlowCleanupServer(delay=0.1)
  83. manager._timeouts["cleanup"] = 0.01
  84. await manager.ensure_connected(server)
  85. try:
  86. await manager.shutdown()
  87. except TimeoutError as exc: # pragma: no cover - current behavior under test
  88. pytest.fail(f"shutdown should not propagate TimeoutError: {exc}")
  89. assert manager._drivers == {}
  90. @pytest.mark.asyncio
  91. async def test_shutdown_logs_cleanup_exception(caplog: pytest.LogCaptureFixture) -> None:
  92. manager = PersistentMCPServerManager()
  93. server = _FailingCleanupServer()
  94. await manager.ensure_connected(server)
  95. with caplog.at_level(logging.WARNING):
  96. await manager.shutdown()
  97. assert "Error during MCP server 'dummy' shutdown" in caplog.text
  98. @pytest.mark.asyncio
  99. async def test_proxy_supports_async_context_manager() -> None:
  100. manager = PersistentMCPServerManager()
  101. server = _AsyncContextServer()
  102. await manager.ensure_connected(server)
  103. proxy = LoopAffineAsyncProxy(server, manager)
  104. try:
  105. async with proxy as acquired:
  106. assert acquired is server
  107. assert server.context_entered == 1
  108. assert server.context_exited == 0
  109. finally:
  110. await manager.shutdown()
  111. assert server.context_entered == 1
  112. assert server.context_exited == 1
  113. @pytest.mark.asyncio
  114. async def test_proxy_rejects_missing_context_methods() -> None:
  115. manager = PersistentMCPServerManager()
  116. proxy = LoopAffineAsyncProxy(_DummyServer(), manager)
  117. with pytest.raises(TypeError, match="does not support asynchronous context management"):
  118. await proxy.__aenter__()
  119. with pytest.raises(TypeError, match="does not support asynchronous context management"):
  120. await proxy.__aexit__(None, None, None)
  121. @pytest.mark.asyncio
  122. async def test_proxy_supports_sync_context_and_method_proxying() -> None:
  123. manager = PersistentMCPServerManager()
  124. server = _SyncContextServer()
  125. await manager.ensure_connected(server)
  126. proxy = LoopAffineAsyncProxy(server, manager)
  127. try:
  128. acquired = await proxy.__aenter__()
  129. assert acquired is server
  130. assert server.context_entered == 1
  131. assert await proxy.ping("ok") == "pong:ok"
  132. assert proxy.value == 42
  133. await proxy.__aexit__(None, None, None)
  134. assert server.context_exited == 1
  135. finally:
  136. await manager.shutdown()
  137. def test_register_get_all_and_mark_atexit() -> None:
  138. manager = PersistentMCPServerManager()
  139. named = _DummyServer(name="persisted")
  140. duplicate = _DummyServer(name="persisted")
  141. unnamed = _DummyServer(name="")
  142. assert manager.register(named) is named
  143. assert manager.register(duplicate) is named
  144. assert manager.register(unnamed) is unnamed
  145. assert manager.get("persisted") is named
  146. assert manager.get("missing") is None
  147. assert manager.all() == [named]
  148. assert manager.mark_atexit_registered() is True
  149. assert manager.mark_atexit_registered() is False
  150. def test_shutdown_sync_noop_when_lock_already_held() -> None:
  151. manager = PersistentMCPServerManager()
  152. assert manager._sync_shutdown_lock.acquire(blocking=False) is True
  153. try:
  154. manager.shutdown_sync()
  155. finally:
  156. manager._sync_shutdown_lock.release()
  157. def test_shutdown_sync_logs_non_loop_runtime_error(
  158. monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
  159. ) -> None:
  160. manager = PersistentMCPServerManager()
  161. def fake_run(coro: Any) -> None:
  162. coro.close()
  163. raise RuntimeError("other runtime")
  164. monkeypatch.setattr(mcp_manager.asyncio, "run", fake_run)
  165. with caplog.at_level(logging.WARNING):
  166. manager.shutdown_sync()
  167. assert "Error during persistent MCP manager shutdown: other runtime" in caplog.text
  168. def test_shutdown_sync_schedules_shutdown_when_loop_running(monkeypatch: pytest.MonkeyPatch) -> None:
  169. manager = PersistentMCPServerManager()
  170. scheduled: list[Any] = []
  171. def fake_run(coro: Any) -> None:
  172. coro.close()
  173. raise RuntimeError("asyncio.run() cannot be called from a running event loop")
  174. class _FakeLoop:
  175. def create_task(self, coro: Any) -> None:
  176. scheduled.append(coro)
  177. coro.close()
  178. monkeypatch.setattr(mcp_manager.asyncio, "run", fake_run)
  179. monkeypatch.setattr(mcp_manager.asyncio, "get_running_loop", lambda: _FakeLoop())
  180. manager.shutdown_sync()
  181. assert len(scheduled) == 1
  182. def test_shutdown_sync_logs_when_loop_lookup_fails(
  183. monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
  184. ) -> None:
  185. manager = PersistentMCPServerManager()
  186. def fake_run(coro: Any) -> None:
  187. coro.close()
  188. raise RuntimeError("asyncio.run() cannot be called from a running event loop")
  189. def fake_get_running_loop() -> Any:
  190. raise RuntimeError("no running loop")
  191. monkeypatch.setattr(mcp_manager.asyncio, "run", fake_run)
  192. monkeypatch.setattr(mcp_manager.asyncio, "get_running_loop", fake_get_running_loop)
  193. with caplog.at_level(logging.WARNING):
  194. manager.shutdown_sync()
  195. assert "Error during persistent MCP manager shutdown: no running loop" in caplog.text
  196. def test_shutdown_sync_logs_unexpected_exception(
  197. monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
  198. ) -> None:
  199. manager = PersistentMCPServerManager()
  200. def fake_run(coro: Any) -> None:
  201. coro.close()
  202. raise ValueError("boom")
  203. monkeypatch.setattr(mcp_manager.asyncio, "run", fake_run)
  204. with caplog.at_level(logging.WARNING):
  205. manager.shutdown_sync()
  206. assert "Error during persistent MCP manager shutdown: boom" in caplog.text
  207. @pytest.mark.asyncio
  208. async def test_attach_persistent_mcp_servers_handles_invalid_shapes(monkeypatch: pytest.MonkeyPatch) -> None:
  209. fake_manager = SimpleNamespace(
  210. get=lambda _name: None,
  211. register=lambda server: server,
  212. ensure_connected=lambda _server: asyncio.sleep(0),
  213. )
  214. monkeypatch.setattr(mcp_manager, "default_mcp_manager", fake_manager)
  215. await mcp_manager.attach_persistent_mcp_servers(SimpleNamespace(agents=None))
  216. await mcp_manager.attach_persistent_mcp_servers(
  217. SimpleNamespace(agents={"agent": SimpleNamespace(mcp_servers="not-a-list")})
  218. )
  219. @pytest.mark.asyncio
  220. async def test_attach_persistent_mcp_servers_rejects_missing_name(monkeypatch: pytest.MonkeyPatch) -> None:
  221. fake_manager = SimpleNamespace(
  222. get=lambda _name: None,
  223. register=lambda server: server,
  224. ensure_connected=lambda _server: asyncio.sleep(0),
  225. )
  226. monkeypatch.setattr(mcp_manager, "default_mcp_manager", fake_manager)
  227. agency = SimpleNamespace(agents={"agent": SimpleNamespace(mcp_servers=[SimpleNamespace(name="")])})
  228. with pytest.raises(ValueError, match="has no name provided"):
  229. await mcp_manager.attach_persistent_mcp_servers(agency)
  230. @pytest.mark.asyncio
  231. async def test_attach_persistent_mcp_servers_registers_and_connects(monkeypatch: pytest.MonkeyPatch) -> None:
  232. connected: list[Any] = []
  233. store: dict[str, Any] = {}
  234. class _FakeManager:
  235. def get(self, name: str) -> Any | None:
  236. return store.get(name)
  237. def register(self, server: Any) -> Any:
  238. store[server.name] = server
  239. return server
  240. async def ensure_connected(self, server: Any) -> None:
  241. connected.append(server)
  242. monkeypatch.setattr(mcp_manager, "default_mcp_manager", _FakeManager())
  243. server_a = _DummyServer(name="a")
  244. server_b = _DummyServer(name="b")
  245. agent = SimpleNamespace(mcp_servers=[server_a, server_b])
  246. agency = SimpleNamespace(agents={"agent": agent})
  247. await mcp_manager.attach_persistent_mcp_servers(agency)
  248. assert all(isinstance(server, LoopAffineAsyncProxy) for server in agent.mcp_servers)
  249. assert len(connected) == 2
  250. def test_register_and_connect_agent_servers_validates_inputs(monkeypatch: pytest.MonkeyPatch) -> None:
  251. ensure_driver_calls: list[Any] = []
  252. class _FakeManager:
  253. def get(self, _name: str) -> None:
  254. return None
  255. def register(self, server: Any) -> Any:
  256. return server
  257. def _ensure_driver(self, server: Any) -> None:
  258. ensure_driver_calls.append(server)
  259. monkeypatch.setattr(mcp_manager, "default_mcp_manager", _FakeManager())
  260. mcp_manager.register_and_connect_agent_servers(SimpleNamespace(mcp_servers=None))
  261. assert ensure_driver_calls == []
  262. with pytest.raises(ValueError, match="duplicate name"):
  263. mcp_manager.register_and_connect_agent_servers(
  264. SimpleNamespace(mcp_servers=[_DummyServer(name="same"), _DummyServer(name="same")])
  265. )
  266. with pytest.raises(ValueError, match="has no name provided"):
  267. mcp_manager.register_and_connect_agent_servers(SimpleNamespace(mcp_servers=[SimpleNamespace(name="")]))
  268. def test_register_and_connect_agent_servers_reuses_persistent_instances(monkeypatch: pytest.MonkeyPatch) -> None:
  269. existing = _DummyServer(name="existing")
  270. registered: list[Any] = []
  271. ensured: list[Any] = []
  272. class _FakeManager:
  273. def get(self, name: str) -> Any | None:
  274. if name == "existing":
  275. return existing
  276. return None
  277. def register(self, server: Any) -> Any:
  278. registered.append(server)
  279. return server
  280. def _ensure_driver(self, server: Any) -> None:
  281. ensured.append(server)
  282. monkeypatch.setattr(mcp_manager, "default_mcp_manager", _FakeManager())
  283. agent = SimpleNamespace(mcp_servers=[_DummyServer(name="existing"), _DummyServer(name="new")])
  284. mcp_manager.register_and_connect_agent_servers(agent)
  285. assert all(isinstance(server, LoopAffineAsyncProxy) for server in agent.mcp_servers)
  286. assert len(registered) == 1
  287. assert registered[0].name == "new"
  288. assert ensured == [existing, registered[0]]
  289. def test_convert_mcp_servers_to_tools(monkeypatch: pytest.MonkeyPatch) -> None:
  290. added_tools: list[str] = []
  291. agent = SimpleNamespace(
  292. mcp_servers=["server"],
  293. mcp_config={"convert_schemas_to_strict": True},
  294. add_tool=lambda tool: added_tools.append(tool),
  295. )
  296. with patch("agency_swarm.tools.tool_factory.ToolFactory.from_mcp", return_value=["a", "b"]) as mock_from_mcp:
  297. mcp_manager.convert_mcp_servers_to_tools(agent)
  298. assert mock_from_mcp.call_count == 1
  299. assert mock_from_mcp.call_args.kwargs == {
  300. "convert_schemas_to_strict": True,
  301. "context": None,
  302. "agent": agent,
  303. }
  304. assert added_tools == ["a", "b"]
  305. assert agent.mcp_servers == []