test_tool_endpoints.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. from __future__ import annotations
  2. import json
  3. from datetime import datetime
  4. from types import SimpleNamespace
  5. from typing import Any, cast
  6. import pytest
  7. from agents import Agent as SDKAgent, RunContextWrapper, function_tool as sdk_function_tool
  8. from agents.tool_context import ToolContext
  9. from fastapi.responses import JSONResponse
  10. from pydantic import BaseModel
  11. from agency_swarm.integrations.fastapi_utils.tool_endpoints import make_tool_endpoint
  12. from agency_swarm.tools import BaseTool
# Request model with a single ``datetime`` field. Substituted for the output of
# ``build_request_model`` so the endpoint receives a value that is not natively
# JSON-serializable.
class TimestampModel(BaseModel):
    timestamp: datetime
  15. class DummyTypedTool:
  16. name = "TimestampTool"
  17. openai_schema = {
  18. "parameters": {
  19. "type": "object",
  20. "properties": {
  21. "timestamp": {
  22. "type": "string",
  23. "description": "ISO timestamp",
  24. }
  25. },
  26. "required": ["timestamp"],
  27. }
  28. }
  29. def __init__(self) -> None:
  30. self.calls: list[str] = []
  31. async def on_invoke_tool(self, context, input_json: str):
  32. self.calls.append(input_json)
  33. return "ok"
  34. def _fake_verify_token():
  35. return "token"
  36. class _DummyRequest:
  37. def __init__(self, payload: dict[str, Any], *, explode: bool = False) -> None:
  38. self.payload = payload
  39. self.explode = explode
  40. async def json(self) -> dict[str, Any]:
  41. if self.explode:
  42. raise ValueError("bad json")
  43. return self.payload
  44. class EchoTool(BaseTool):
  45. text: str
  46. def run(self) -> str:
  47. return self.text
  48. class AsyncEchoTool(BaseTool):
  49. text: str
  50. async def run(self) -> str:
  51. return self.text.upper()
  52. class FailingTool(BaseTool):
  53. text: str
  54. def run(self) -> str:
  55. raise RuntimeError("run failed")
  56. @pytest.mark.asyncio
  57. async def test_make_tool_endpoint_serializes_non_json_types(monkeypatch):
  58. tool = DummyTypedTool()
  59. def fake_build_request_model(*_, **__):
  60. return TimestampModel
  61. monkeypatch.setattr(
  62. "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
  63. fake_build_request_model,
  64. )
  65. handler = make_tool_endpoint(tool, verify_token=_fake_verify_token, context=None)
  66. request_data = TimestampModel(timestamp="2024-05-01T09:30:00Z")
  67. response = await handler(request_data=request_data, token="ignored")
  68. assert response == {"response": "ok"}
  69. assert tool.calls, "on_invoke_tool should receive serialized payload"
  70. payload = json.loads(tool.calls[0])
  71. assert payload == {"timestamp": "2024-05-01T09:30:00Z"}
  72. @pytest.mark.asyncio
  73. async def test_make_tool_endpoint_generic_handler_for_sync_callable() -> None:
  74. def echo(value: str) -> str:
  75. return f"echo:{value}"
  76. handler = make_tool_endpoint(echo, verify_token=_fake_verify_token, context=None)
  77. response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored")
  78. assert response == {"response": "echo:ok"}
  79. @pytest.mark.asyncio
  80. async def test_make_tool_endpoint_generic_handler_for_async_callable() -> None:
  81. async def echo_async(value: str) -> str:
  82. return f"async:{value}"
  83. handler = make_tool_endpoint(echo_async, verify_token=_fake_verify_token, context=None)
  84. response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored")
  85. assert response == {"response": "async:ok"}
  86. @pytest.mark.asyncio
  87. async def test_make_tool_endpoint_generic_handler_returns_json_error() -> None:
  88. def failing_tool(value: str) -> str: # noqa: ARG001
  89. raise RuntimeError("tool failed")
  90. handler = make_tool_endpoint(failing_tool, verify_token=_fake_verify_token, context=None)
  91. response = await handler(request=_DummyRequest({"value": "ok"}), token="ignored")
  92. assert isinstance(response, JSONResponse)
  93. assert response.status_code == 500
  94. assert b"tool failed" in response.body
  95. @pytest.mark.asyncio
  96. async def test_make_tool_endpoint_generic_handler_handles_invalid_request_json() -> None:
  97. handler = make_tool_endpoint(lambda value: value, verify_token=_fake_verify_token, context=None)
  98. response = await handler(request=_DummyRequest({"value": "ok"}, explode=True), token="ignored")
  99. assert isinstance(response, JSONResponse)
  100. assert response.status_code == 500
  101. assert b"bad json" in response.body
  102. @pytest.mark.asyncio
  103. async def test_make_tool_endpoint_for_base_tool_sync_and_async_runs() -> None:
  104. sync_handler = make_tool_endpoint(EchoTool, verify_token=_fake_verify_token, context=None)
  105. async_handler = make_tool_endpoint(AsyncEchoTool, verify_token=_fake_verify_token, context=None)
  106. sync_response = await sync_handler(request_data=EchoTool(text="hi"), token="ignored")
  107. async_response = await async_handler(request_data=AsyncEchoTool(text="hi"), token="ignored")
  108. assert sync_response == {"response": "hi"}
  109. assert async_response == {"response": "HI"}
  110. @pytest.mark.asyncio
  111. async def test_make_tool_endpoint_for_base_tool_returns_json_error() -> None:
  112. handler = make_tool_endpoint(FailingTool, verify_token=_fake_verify_token, context=None)
  113. response = await handler(request_data=FailingTool(text="hi"), token="ignored")
  114. assert isinstance(response, JSONResponse)
  115. assert response.status_code == 500
  116. assert b"run failed" in response.body
# Request model with a single string ``value`` field; returned by the patched
# ``build_request_model`` in the params_json_schema and SDK function_tool tests.
class ParamsModel(BaseModel):
    value: str
# Request model for the ``Agent.as_tool`` test; its single string field is named
# ``input`` to match the arguments that test expects ('{"input":"ok"}').
class AgentToolParamsModel(BaseModel):
    input: str
  121. class ParamsTool:
  122. name = "ParamsTool"
  123. params_json_schema = {"type": "object", "properties": {"value": {"type": "string"}}, "required": ["value"]}
  124. strict_json_schema = True
  125. def __init__(self) -> None:
  126. self.calls: list[str] = []
  127. async def __call__(self, value: str) -> str:
  128. self.calls.append(value)
  129. return value.upper()
  130. @pytest.mark.asyncio
  131. async def test_make_tool_endpoint_supports_params_json_schema(monkeypatch: pytest.MonkeyPatch) -> None:
  132. tool = ParamsTool()
  133. monkeypatch.setattr(
  134. "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
  135. lambda *_args, **_kwargs: ParamsModel,
  136. )
  137. handler = make_tool_endpoint(tool, verify_token=_fake_verify_token, context=None)
  138. response = await handler(request_data=ParamsModel(value="ok"), token="ignored")
  139. assert response == {"response": "OK"}
  140. assert tool.calls == ["ok"]
  141. @pytest.mark.asyncio
  142. async def test_make_tool_endpoint_invokes_sdk_function_tool_with_manual_tool_context(
  143. monkeypatch: pytest.MonkeyPatch,
  144. ) -> None:
  145. seen_contexts: list[ToolContext[None]] = []
  146. @sdk_function_tool
  147. async def endpoint_sdk_tool(ctx: RunContextWrapper[None], value: str) -> str:
  148. tool_context = cast(ToolContext[None], ctx)
  149. seen_contexts.append(tool_context)
  150. return f"{tool_context.tool_name}:{tool_context.tool_arguments}:{value}"
  151. monkeypatch.setattr(
  152. "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
  153. lambda *_args, **_kwargs: ParamsModel,
  154. )
  155. handler = make_tool_endpoint(endpoint_sdk_tool, verify_token=_fake_verify_token, context=None)
  156. response = await handler(request_data=ParamsModel(value="ok"), token="ignored")
  157. assert response == {"response": 'endpoint_sdk_tool:{"value":"ok"}:ok'}
  158. assert len(seen_contexts) == 1
  159. assert seen_contexts[0].tool_name == "endpoint_sdk_tool"
  160. assert seen_contexts[0].tool_arguments == '{"value":"ok"}'
  161. @pytest.mark.asyncio
  162. async def test_make_tool_endpoint_invokes_sdk_agent_tool_with_manual_tool_context(
  163. monkeypatch: pytest.MonkeyPatch,
  164. ) -> None:
  165. seen_contexts: list[ToolContext[None]] = []
  166. async def fake_run(**kwargs: Any) -> SimpleNamespace:
  167. context = kwargs["context"]
  168. assert isinstance(context, ToolContext)
  169. seen_contexts.append(context)
  170. return SimpleNamespace(final_output="nested ok", new_items=[], interruptions=[])
  171. monkeypatch.setattr("agents.Runner.run", fake_run)
  172. monkeypatch.setattr(
  173. "agency_swarm.integrations.fastapi_utils.tool_endpoints.build_request_model",
  174. lambda *_args, **_kwargs: AgentToolParamsModel,
  175. )
  176. nested_agent = SDKAgent(name="nested", instructions="Return the input.")
  177. nested_tool = nested_agent.as_tool(tool_name="nested_tool", tool_description="Nested tool")
  178. handler = make_tool_endpoint(nested_tool, verify_token=_fake_verify_token, context=None)
  179. response = await handler(request_data=AgentToolParamsModel(input="ok"), token="ignored")
  180. assert response == {"response": "nested ok"}
  181. assert len(seen_contexts) == 1
  182. assert seen_contexts[0].tool_name == "nested_tool"
  183. assert seen_contexts[0].tool_arguments == '{"input":"ok"}'
  184. assert seen_contexts[0].tool_call_id == "agency_swarm_manual_nested_tool"