from __future__ import annotations import json from datetime import datetime from types import SimpleNamespace from typing import Any, cast import pytest from agents import Agent as SDKAgent, RunContextWrapper, function_tool as sdk_function_tool from agents.tool_context import ToolContext from fastapi.responses import JSONResponse from pydantic import BaseModel from agency_swarm.integrations.fastapi_utils.tool_endpoints import make_tool_endpoint from agency_swarm.tools import BaseTool class TimestampModel(BaseModel): timestamp: datetime class DummyTypedTool: name = "TimestampTool" openai_schema = { "parameters": { "type": "object", "properties": { "timestamp": { "type": "string", "description": "ISO timestamp", } }, "required": ["timestamp"], } } def __init__(self) -> None: self.calls: list[str] = [] async def on_invoke_tool(self, context, input_json: str): self.calls.append(input_json) return "ok" def _fake_verify_token(): return "token" class _DummyRequest: def __init__(self, payload: dict[str, Any], *, explode: bool = False) -> None: self.payload = payload self.explode = explode async def json(self) -> dict[str, Any]: if self.explode: raise ValueError("bad json") return self.payload class EchoTool(BaseTool): text: str def run(self) -> str: return self.text class AsyncEchoTool(BaseTool): text: str async def run(self) -> str: return self.text.upper() class FailingTool(BaseTool): text: str def run(self) -> str: raise RuntimeError("run failed") @pytest.mark.asyncio async def test_make_tool_endpoint_serializes_non_json_types(monkeypatch): tool = DummyTypedTool() def fake_build_request_model(*_, **__): return TimestampModel monkeypatch.setattr( "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model", fake_build_request_model, ) handler = make_tool_endpoint(tool, verify_token=_fake_verify_token, context=None) request_data = TimestampModel(timestamp="2024-05-01T09:30:00Z") response = await handler(request_data=request_data, token="ignored") assert response == {"response": "ok"} assert tool.calls, "on_invoke_tool should receive serialized payload" payload = json.loads(tool.calls[0]) assert payload == {"timestamp": "2024-05-01T09:30:00Z"} @pytest.mark.asyncio async def test_make_tool_endpoint_generic_handler_for_sync_callable() -> None: def echo(value: str) -> str: return f"echo:{value}" handler = make_tool_endpoint(echo, verify_token=_fake_verify_token, context=None) response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored") assert response == {"response": "echo:ok"} @pytest.mark.asyncio async def test_make_tool_endpoint_generic_handler_for_async_callable() -> None: async def echo_async(value: str) -> str: return f"async:{value}" handler = make_tool_endpoint(echo_async, verify_token=_fake_verify_token, context=None) response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored") assert response == {"response": "async:ok"} @pytest.mark.asyncio async def test_make_tool_endpoint_generic_handler_returns_json_error() -> None: def failing_tool(value: str) -> str: # noqa: ARG001 raise RuntimeError("tool failed") handler = make_tool_endpoint(failing_tool, verify_token=_fake_verify_token, context=None) response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored") assert isinstance(response, JSONResponse) assert response.status_code == 500 assert b"tool failed" in response.body @pytest.mark.asyncio async def test_make_tool_endpoint_generic_handler_handles_invalid_request_json() -> None: handler = make_tool_endpoint(lambda value: value, verify_token=_fake_verify_token, context=None) response = await handler(request=_DummyRequest({"value": "ok"}, explode=True), token="ignored") assert isinstance(response, JSONResponse) assert response.status_code == 500 assert b"bad json" in response.body @pytest.mark.asyncio async def test_make_tool_endpoint_for_base_tool_sync_and_async_runs() -> None: sync_handler = make_tool_endpoint(EchoTool, verify_token=_fake_verify_token, context=None) async_handler = make_tool_endpoint(AsyncEchoTool, verify_token=_fake_verify_token, context=None) sync_response = await sync_handler(request_data=EchoTool(text="hi"), token="ignored") async_response = await async_handler(request_data=AsyncEchoTool(text="hi"), token="ignored") assert sync_response == {"response": "hi"} assert async_response == {"response": "HI"} @pytest.mark.asyncio async def test_make_tool_endpoint_for_base_tool_returns_json_error() -> None: handler = make_tool_endpoint(FailingTool, verify_token=_fake_verify_token, context=None) response = await handler(request_data=FailingTool(text="hi"), token="ignored") assert isinstance(response, JSONResponse) assert response.status_code == 500 assert b"run failed" in response.body class ParamsModel(BaseModel): value: str class AgentToolParamsModel(BaseModel): input: str class ParamsTool: name = "ParamsTool" params_json_schema = {"type": "object", "properties": {"value": {"type": "string"}}, "required": ["value"]} strict_json_schema = True def __init__(self) -> None: self.calls: list[str] = [] async def __call__(self, value: str) -> str: self.calls.append(value) return value.upper() @pytest.mark.asyncio async def test_make_tool_endpoint_supports_params_json_schema(monkeypatch: pytest.MonkeyPatch) -> None: tool = ParamsTool() monkeypatch.setattr( "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model", lambda *_args, **_kwargs: ParamsModel, ) handler = make_tool_endpoint(tool, verify_token=_fake_verify_token, context=None) response = await handler(request_data=ParamsModel(value="ok"), token="ignored") assert response == {"response": "OK"} assert tool.calls == ["ok"] @pytest.mark.asyncio async def test_make_tool_endpoint_invokes_sdk_function_tool_with_manual_tool_context( monkeypatch: pytest.MonkeyPatch, ) -> None: seen_contexts: list[ToolContext[None]] = [] @sdk_function_tool async def endpoint_sdk_tool(ctx: RunContextWrapper[None], value: str) -> str: tool_context = cast(ToolContext[None], ctx) seen_contexts.append(tool_context) return f"{tool_context.tool_name}:{tool_context.tool_arguments}:{value}" monkeypatch.setattr( "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model", lambda *_args, **_kwargs: ParamsModel, ) handler = make_tool_endpoint(endpoint_sdk_tool, verify_token=_fake_verify_token, context=None) response = await handler(request_data=ParamsModel(value="ok"), token="ignored") assert response == {"response": 'endpoint_sdk_tool:{"value":"ok"}:ok'} assert len(seen_contexts) == 1 assert seen_contexts[0].tool_name == "endpoint_sdk_tool" assert seen_contexts[0].tool_arguments == '{"value":"ok"}' @pytest.mark.asyncio async def test_make_tool_endpoint_invokes_sdk_agent_tool_with_manual_tool_context( monkeypatch: pytest.MonkeyPatch, ) -> None: seen_contexts: list[ToolContext[None]] = [] async def fake_run(**kwargs: Any) -> SimpleNamespace: context = kwargs["context"] assert isinstance(context, ToolContext) seen_contexts.append(context) return SimpleNamespace(final_output="nested ok", new_items=[], interruptions=[]) monkeypatch.setattr("agents.Runner.run", fake_run) monkeypatch.setattr( "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model", lambda *_args, **_kwargs: AgentToolParamsModel, ) nested_agent = SDKAgent(name="nested", instructions="Return the input.") nested_tool = nested_agent.as_tool(tool_name="nested_tool", tool_description="Nested tool") handler = make_tool_endpoint(nested_tool, verify_token=_fake_verify_token, context=None) response = await handler(request_data=AgentToolParamsModel(input="ok"), token="ignored") assert response == {"response": "nested ok"} assert len(seen_contexts) == 1 assert seen_contexts[0].tool_name == "nested_tool" assert seen_contexts[0].tool_arguments == '{"input":"ok"}' assert seen_contexts[0].tool_call_id == "agency_swarm_manual_nested_tool"