| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
from __future__ import annotations

import json
from datetime import datetime
from types import SimpleNamespace
from typing import Any, ClassVar, cast

import pytest
from agents import Agent as SDKAgent, RunContextWrapper, function_tool as sdk_function_tool
from agents.tool_context import ToolContext
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from agency_swarm.integrations.fastapi_utils.tool_endpoints import make_tool_endpoint
from agency_swarm.tools import BaseTool
class TimestampModel(BaseModel):
    # Request model carrying a value that is not JSON-native (a datetime);
    # used to check the endpoint serializes it before invoking the tool.
    timestamp: datetime
class DummyTypedTool:
    """Duck-typed tool stub exposing ``openai_schema`` and an async ``on_invoke_tool``.

    Every ``input_json`` string handed to ``on_invoke_tool`` is appended to
    ``calls`` so a test can inspect exactly what the endpoint forwarded.
    """

    name: ClassVar[str] = "TimestampTool"
    # ClassVar: this dict is a class-level constant shared by every instance,
    # not per-instance state (ruff RUF012 flags unannotated mutable defaults).
    openai_schema: ClassVar[dict[str, Any]] = {
        "parameters": {
            "type": "object",
            "properties": {
                "timestamp": {
                    "type": "string",
                    "description": "ISO timestamp",
                }
            },
            "required": ["timestamp"],
        }
    }

    def __init__(self) -> None:
        # Per-instance log of raw JSON payloads received by on_invoke_tool.
        self.calls: list[str] = []

    async def on_invoke_tool(self, context: Any, input_json: str) -> str:
        """Record ``input_json`` and return the fixed result ``"ok"``.

        ``context`` is accepted for interface compatibility and ignored.
        """
        self.calls.append(input_json)
        return "ok"
def _fake_verify_token():
    """Token verifier passed to ``make_tool_endpoint``; returns the constant ``"token"``."""
    fixed_token = "token"
    return fixed_token
class _DummyRequest:
    """Minimal request double exposing only an async ``json()`` method.

    Args:
        payload: Object returned by ``json()``.
        explode: When true, ``json()`` raises ``ValueError("bad json")``
            instead of returning ``payload``.
    """

    def __init__(self, payload: dict[str, Any], *, explode: bool = False) -> None:
        self.payload = payload
        self.explode = explode

    async def json(self) -> dict[str, Any]:
        # Success path first; the failure branch simulates an unparsable body.
        if not self.explode:
            return self.payload
        raise ValueError("bad json")
class EchoTool(BaseTool):
    # Synchronous BaseTool whose run() hands back its ``text`` field unchanged.
    # (No class docstring: BaseTool may use it as the tool description.)
    text: str

    def run(self) -> str:
        echoed = self.text
        return echoed
class AsyncEchoTool(BaseTool):
    # Asynchronous BaseTool: run() is a coroutine returning ``text`` upper-cased.
    text: str

    async def run(self) -> str:
        shouted = self.text.upper()
        return shouted
class FailingTool(BaseTool):
    # BaseTool whose run() always raises, to exercise the endpoint's error path.
    text: str

    def run(self) -> str:
        failure = RuntimeError("run failed")
        raise failure
@pytest.mark.asyncio
async def test_make_tool_endpoint_serializes_non_json_types(monkeypatch: pytest.MonkeyPatch):
    """A request model holding a ``datetime`` reaches the tool as JSON text.

    ``build_request_model`` is patched to return ``TimestampModel``, so the
    handler is driven with a model whose ``timestamp`` field is a datetime.
    The tool must receive exactly one JSON string whose ``timestamp`` is the
    ISO-8601 text ``"2024-05-01T09:30:00Z"``.
    """
    tool = DummyTypedTool()

    def fake_build_request_model(*_, **__):
        return TimestampModel

    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
        fake_build_request_model,
    )
    handler = make_tool_endpoint(tool, verify_token=_fake_verify_token, context=None)
    request_data = TimestampModel(timestamp="2024-05-01T09:30:00Z")

    response = await handler(request_data=request_data, token="ignored")

    assert response == {"response": "ok"}
    # Pin the call count exactly (like the SDK tool tests do) rather than
    # only checking that at least one call happened.
    assert len(tool.calls) == 1, "on_invoke_tool should receive serialized payload exactly once"
    payload = json.loads(tool.calls[0])
    assert payload == {"timestamp": "2024-05-01T09:30:00Z"}
@pytest.mark.asyncio
async def test_make_tool_endpoint_generic_handler_for_sync_callable() -> None:
    """A plain synchronous function is served through the generic request handler."""

    def echo(value: str) -> str:
        return "echo:{}".format(value)

    endpoint = make_tool_endpoint(echo, verify_token=_fake_verify_token, context=None)
    fake_request = _DummyRequest({"value": "ok"})

    result = await endpoint(request=fake_request, token="ignored")

    assert result == {"response": "echo:ok"}
@pytest.mark.asyncio
async def test_make_tool_endpoint_generic_handler_for_async_callable() -> None:
    """A coroutine function is awaited by the generic request handler."""

    async def echo_async(value: str) -> str:
        return "async:{}".format(value)

    endpoint = make_tool_endpoint(echo_async, verify_token=_fake_verify_token, context=None)
    fake_request = _DummyRequest({"value": "ok"})

    result = await endpoint(request=fake_request, token="ignored")

    assert result == {"response": "async:ok"}
@pytest.mark.asyncio
async def test_make_tool_endpoint_generic_handler_returns_json_error() -> None:
    """An exception from a generic callable becomes a 500 JSONResponse carrying its message."""

    def failing_tool(value: str) -> str:  # noqa: ARG001
        raise RuntimeError("tool failed")

    endpoint = make_tool_endpoint(failing_tool, verify_token=_fake_verify_token, context=None)
    fake_request = _DummyRequest({"value": "ok"})

    result = await endpoint(request=fake_request, token="ignored")

    assert isinstance(result, JSONResponse)
    assert result.status_code == 500
    assert b"tool failed" in result.body
@pytest.mark.asyncio
async def test_make_tool_endpoint_generic_handler_handles_invalid_request_json() -> None:
    """A request body that fails to parse yields a 500 JSONResponse with the parse error text."""
    endpoint = make_tool_endpoint(lambda value: value, verify_token=_fake_verify_token, context=None)
    # explode=True makes request.json() raise ValueError("bad json").
    broken_request = _DummyRequest({"value": "ok"}, explode=True)

    result = await endpoint(request=broken_request, token="ignored")

    assert isinstance(result, JSONResponse)
    assert result.status_code == 500
    assert b"bad json" in result.body
@pytest.mark.asyncio
async def test_make_tool_endpoint_for_base_tool_sync_and_async_runs() -> None:
    """BaseTool subclasses work whether ``run`` is a plain method or a coroutine."""
    sync_endpoint, async_endpoint = (
        make_tool_endpoint(tool_cls, verify_token=_fake_verify_token, context=None)
        for tool_cls in (EchoTool, AsyncEchoTool)
    )

    sync_result = await sync_endpoint(request_data=EchoTool(text="hi"), token="ignored")
    async_result = await async_endpoint(request_data=AsyncEchoTool(text="hi"), token="ignored")

    assert sync_result == {"response": "hi"}
    assert async_result == {"response": "HI"}
@pytest.mark.asyncio
async def test_make_tool_endpoint_for_base_tool_returns_json_error() -> None:
    """An exception raised in ``BaseTool.run`` surfaces as a 500 JSONResponse."""
    endpoint = make_tool_endpoint(FailingTool, verify_token=_fake_verify_token, context=None)

    result = await endpoint(request_data=FailingTool(text="hi"), token="ignored")

    assert isinstance(result, JSONResponse)
    assert result.status_code == 500
    assert b"run failed" in result.body
class ParamsModel(BaseModel):
    # Single-field request model; mirrors the ``value`` string property that
    # ParamsTool advertises in its params_json_schema.
    value: str
class AgentToolParamsModel(BaseModel):
    # Request model for the Agent.as_tool endpoint test. The field is named
    # ``input`` (shadowing the builtin) because the tool arguments are
    # expected to serialize as '{"input":"ok"}'.
    input: str
class ParamsTool:
    """Awaitable callable stub exposing ``params_json_schema``/``strict_json_schema``.

    Each call records ``value`` in ``calls`` and returns it upper-cased.
    """

    name: ClassVar[str] = "ParamsTool"
    # ClassVar: a class-level constant shared by every instance, not
    # per-instance state (ruff RUF012 flags unannotated mutable defaults).
    params_json_schema: ClassVar[dict[str, Any]] = {
        "type": "object",
        "properties": {"value": {"type": "string"}},
        "required": ["value"],
    }
    strict_json_schema: ClassVar[bool] = True

    def __init__(self) -> None:
        # Per-instance log of the ``value`` arguments received.
        self.calls: list[str] = []

    async def __call__(self, value: str) -> str:
        self.calls.append(value)
        return value.upper()
@pytest.mark.asyncio
async def test_make_tool_endpoint_supports_params_json_schema(monkeypatch: pytest.MonkeyPatch) -> None:
    """A tool object advertising ``params_json_schema`` is invoked with the request's ``value``."""
    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
        lambda *_args, **_kwargs: ParamsModel,
    )
    params_tool = ParamsTool()
    endpoint = make_tool_endpoint(params_tool, verify_token=_fake_verify_token, context=None)

    result = await endpoint(request_data=ParamsModel(value="ok"), token="ignored")

    assert result == {"response": "OK"}
    assert params_tool.calls == ["ok"]
@pytest.mark.asyncio
async def test_make_tool_endpoint_invokes_sdk_function_tool_with_manual_tool_context(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An SDK ``function_tool`` receives a ``ToolContext`` constructed by the endpoint.

    The tool echoes the context's ``tool_name`` and ``tool_arguments`` into its
    result, so the response itself shows which context it was invoked with.
    """
    captured: list[ToolContext[None]] = []

    # Deliberately no docstring: function_tool derives the tool description
    # from it, and the parameter names here define the tool's JSON schema.
    @sdk_function_tool
    async def endpoint_sdk_tool(ctx: RunContextWrapper[None], value: str) -> str:
        manual_ctx = cast(ToolContext[None], ctx)
        captured.append(manual_ctx)
        return "{}:{}:{}".format(manual_ctx.tool_name, manual_ctx.tool_arguments, value)

    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
        lambda *_args, **_kwargs: ParamsModel,
    )
    endpoint = make_tool_endpoint(endpoint_sdk_tool, verify_token=_fake_verify_token, context=None)

    result = await endpoint(request_data=ParamsModel(value="ok"), token="ignored")

    assert result == {"response": 'endpoint_sdk_tool:{"value":"ok"}:ok'}
    assert len(captured) == 1
    only_ctx = captured[0]
    assert only_ctx.tool_name == "endpoint_sdk_tool"
    assert only_ctx.tool_arguments == '{"value":"ok"}'
@pytest.mark.asyncio
async def test_make_tool_endpoint_invokes_sdk_agent_tool_with_manual_tool_context(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An ``Agent.as_tool`` tool runs its nested agent with an endpoint-built ``ToolContext``.

    ``agents.Runner.run`` is replaced by ``fake_run``, which captures the
    ``context`` keyword it receives and returns a canned result; its
    ``final_output`` is expected as the endpoint response.
    """
    captured: list[ToolContext[None]] = []

    async def fake_run(**kwargs: Any) -> SimpleNamespace:
        received = kwargs["context"]
        assert isinstance(received, ToolContext)
        captured.append(received)
        return SimpleNamespace(final_output="nested ok", new_items=[], interruptions=[])

    monkeypatch.setattr("agents.Runner.run", fake_run)
    monkeypatch.setattr(
        "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
        lambda *_args, **_kwargs: AgentToolParamsModel,
    )

    inner_agent = SDKAgent(name="nested", instructions="Return the input.")
    agent_tool = inner_agent.as_tool(tool_name="nested_tool", tool_description="Nested tool")
    endpoint = make_tool_endpoint(agent_tool, verify_token=_fake_verify_token, context=None)

    result = await endpoint(request_data=AgentToolParamsModel(input="ok"), token="ignored")

    assert result == {"response": "nested ok"}
    assert len(captured) == 1
    only_ctx = captured[0]
    assert only_ctx.tool_name == "nested_tool"
    assert only_ctx.tool_arguments == '{"input":"ok"}'
    # Call id expected for a tool invocation synthesized by the endpoint
    # rather than issued by a model.
    assert only_ctx.tool_call_id == "agency_swarm_manual_nested_tool"
|