| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- import json
- from copy import deepcopy
- import pytest
- from agents import FunctionTool
- from agents.exceptions import ModelBehaviorError
- from pydantic import BaseModel
- from agency_swarm.tools.base_tool import BaseTool
- from agency_swarm.tools.tool_factory import ToolFactory
class TestFromOpenapiSchema:
    """Tests for ``ToolFactory.from_openapi_schema``."""

    _SERVER_URL = "https://api.test.com"

    @classmethod
    def _ticket_spec(cls, operation):
        """Return a minimal OpenAPI 3.1.0 document exposing *operation* as ``POST /tickets``.

        A fresh dict is built on every call, so tests never share (or mutate) one spec.
        """
        return {
            "openapi": "3.1.0",
            "servers": [{"url": cls._SERVER_URL}],
            "paths": {"/tickets": {"post": operation}},
        }

    def test_converts_simple_openapi_schema(self):
        """One POST operation yields exactly one FunctionTool named after its operationId."""
        priority_param = {
            "name": "priority",
            "in": "query",
            "schema": {"type": "string"},
            "required": False,
        }
        spec = self._ticket_spec(
            {
                "operationId": "create_ticket",
                "description": "Create a support ticket",
                "parameters": [priority_param],
            }
        )

        generated = ToolFactory.from_openapi_schema(spec, strict=False)

        assert len(generated) == 1
        only_tool = generated[0]
        assert isinstance(only_tool, FunctionTool)
        assert only_tool.name == "create_ticket"
        # Compare case-insensitively so capitalization of the description is irrelevant.
        lowered_description = only_tool.description.lower()
        assert "support ticket" in lowered_description

    @pytest.mark.asyncio
    async def test_validation_errors_raise_model_behavior_error(self):
        """Invoking the tool with a body that lacks a required property raises ModelBehaviorError."""
        body_schema = {
            "type": "object",
            "properties": {"message": {"type": "string"}},
            "required": ["message"],
        }
        spec = self._ticket_spec(
            {
                "operationId": "create_ticket",
                "requestBody": {
                    "required": True,
                    "content": {"application/json": {"schema": body_schema}},
                },
            }
        )
        tool = ToolFactory.from_openapi_schema(spec, strict=False)[0]
        # "message" is listed as required by body_schema, so an empty request body must be rejected.
        payload = json.dumps({"requestBody": {}})

        with pytest.raises(ModelBehaviorError, match="Invalid JSON input in request body"):
            await tool.on_invoke_tool(None, payload)
class TestGetOpenapiSchema:
    """Tests for ``ToolFactory.get_openapi_schema``."""

    _SERVER_URL = "https://api.test.com"

    def test_generates_schema_for_base_tools(self):
        """A BaseTool subclass gets a bearer-secured POST path with body, 200/422 responses and components."""

        # The class name is significant: it becomes the operationId, path segment and component name.
        class TestTool(BaseTool):
            input_field: str

            def run(self):
                return self.input_field

        document = json.loads(ToolFactory.get_openapi_schema([TestTool], self._SERVER_URL))

        assert document["info"]["title"] == "Agent Tools"
        paths = document["paths"]
        assert "/tool/TestTool" in paths
        operation = paths["/tool/TestTool"]["post"]
        assert operation["operationId"] == "TestTool"
        assert operation["requestBody"]["required"] is True
        responses = operation["responses"]
        assert responses["200"]["description"] == "Tool executed successfully"
        error_ref = responses["422"]["content"]["application/json"]["schema"]["$ref"]
        assert error_ref == "#/components/schemas/HTTPValidationError"
        assert operation["security"] == [{"HTTPBearer": []}]
        component_names = set(document["components"]["schemas"])
        assert component_names == {"HTTPValidationError", "TestTool", "ValidationError"}

    def test_generates_schema_for_function_tool(self):
        """A FunctionTool whose params schema has no properties gets no requestBody and no 422 response."""

        async def noop_invoke(ctx, input_json: str):
            return "ok"

        tool = FunctionTool(
            name="dummy_tool",
            description="Dummy tool",
            params_json_schema={"type": "object", "properties": {}},
            on_invoke_tool=noop_invoke,
        )

        document = json.loads(ToolFactory.get_openapi_schema([tool], self._SERVER_URL))

        paths = document["paths"]
        assert "/tool/dummy_tool" in paths
        operation = paths["/tool/dummy_tool"]["post"]
        assert operation["operationId"] == "dummy_tool"
        assert operation["security"] == [{"HTTPBearer": []}]
        # With an empty "properties" object the generator is expected to emit neither of these.
        assert "requestBody" not in operation
        assert "422" not in operation["responses"]

    def test_get_openapi_schema_preserves_function_tool_defs(self):
        """Building the OpenAPI document must leave the adapted tool's params_json_schema untouched."""

        class Address(BaseModel):
            street: str
            zip_code: int

        class AddressListTool(BaseTool):
            addresses: list[Address]

            def run(self):
                return ",".join(entry.street for entry in self.addresses)

        adapted = ToolFactory.adapt_base_tool(AddressListTool)
        # Deep-copy so nested parts (presumably the $defs for Address) are compared too, not aliased.
        snapshot = deepcopy(adapted.params_json_schema)

        ToolFactory.get_openapi_schema([adapted], self._SERVER_URL)

        assert adapted.params_json_schema == snapshot

    def test_union_function_tool_omits_request_body(self):
        """A tool whose nested model has a union-typed field gets no requestBody, no 422, no components."""

        class Contact(BaseModel):
            identifier: str | int

        class ContainerTool(BaseTool):
            contact: Contact

            def run(self):
                return str(self.contact.identifier)

        adapted = ToolFactory.adapt_base_tool(ContainerTool)
        document = json.loads(ToolFactory.get_openapi_schema([adapted], self._SERVER_URL))
        operation = document["paths"]["/tool/ContainerTool"]["post"]

        assert "requestBody" not in operation
        assert "422" not in operation["responses"]
        assert document["components"]["schemas"] == {}
|