test_tool_factory_openapi.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import json
  2. from copy import deepcopy
  3. import pytest
  4. from agents import FunctionTool
  5. from agents.exceptions import ModelBehaviorError
  6. from pydantic import BaseModel
  7. from agency_swarm.tools.base_tool import BaseTool
  8. from agency_swarm.tools.tool_factory import ToolFactory
  9. class TestFromOpenapiSchema:
  10. def test_converts_simple_openapi_schema(self):
  11. schema = {
  12. "openapi": "3.1.0",
  13. "servers": [{"url": "https://api.test.com"}],
  14. "paths": {
  15. "/tickets": {
  16. "post": {
  17. "operationId": "create_ticket",
  18. "description": "Create a support ticket",
  19. "parameters": [
  20. {
  21. "name": "priority",
  22. "in": "query",
  23. "schema": {"type": "string"},
  24. "required": False,
  25. }
  26. ],
  27. }
  28. }
  29. },
  30. }
  31. tools = ToolFactory.from_openapi_schema(schema, strict=False)
  32. assert len(tools) == 1
  33. tool = tools[0]
  34. assert isinstance(tool, FunctionTool)
  35. assert tool.name == "create_ticket"
  36. assert "support ticket" in tool.description.lower()
  37. @pytest.mark.asyncio
  38. async def test_validation_errors_raise_model_behavior_error(self):
  39. schema = {
  40. "openapi": "3.1.0",
  41. "servers": [{"url": "https://api.test.com"}],
  42. "paths": {
  43. "/tickets": {
  44. "post": {
  45. "operationId": "create_ticket",
  46. "requestBody": {
  47. "required": True,
  48. "content": {
  49. "application/json": {
  50. "schema": {
  51. "type": "object",
  52. "properties": {"message": {"type": "string"}},
  53. "required": ["message"],
  54. }
  55. }
  56. },
  57. },
  58. }
  59. }
  60. },
  61. }
  62. tools = ToolFactory.from_openapi_schema(schema, strict=False)
  63. tool = tools[0]
  64. with pytest.raises(ModelBehaviorError, match="Invalid JSON input in request body"):
  65. await tool.on_invoke_tool(None, json.dumps({"requestBody": {}}))
  66. class TestGetOpenapiSchema:
  67. def test_generates_schema_for_base_tools(self):
  68. class TestTool(BaseTool):
  69. input_field: str
  70. def run(self):
  71. return self.input_field
  72. result_json = ToolFactory.get_openapi_schema([TestTool], "https://api.test.com")
  73. result = json.loads(result_json)
  74. assert result["info"]["title"] == "Agent Tools"
  75. assert "/tool/TestTool" in result["paths"]
  76. post_schema = result["paths"]["/tool/TestTool"]["post"]
  77. assert post_schema["operationId"] == "TestTool"
  78. assert post_schema["requestBody"]["required"] is True
  79. assert post_schema["responses"]["200"]["description"] == "Tool executed successfully"
  80. assert (
  81. post_schema["responses"]["422"]["content"]["application/json"]["schema"]["$ref"]
  82. == "#/components/schemas/HTTPValidationError"
  83. )
  84. assert post_schema["security"] == [{"HTTPBearer": []}]
  85. assert set(result["components"]["schemas"].keys()) == {"HTTPValidationError", "TestTool", "ValidationError"}
  86. def test_generates_schema_for_function_tool(self):
  87. async def dummy_tool(ctx, input_json: str):
  88. return "ok"
  89. function_tool = FunctionTool(
  90. name="dummy_tool",
  91. description="Dummy tool",
  92. params_json_schema={"type": "object", "properties": {}},
  93. on_invoke_tool=dummy_tool,
  94. )
  95. result_json = ToolFactory.get_openapi_schema([function_tool], "https://api.test.com")
  96. result = json.loads(result_json)
  97. assert "/tool/dummy_tool" in result["paths"]
  98. post_schema = result["paths"]["/tool/dummy_tool"]["post"]
  99. assert post_schema["operationId"] == "dummy_tool"
  100. assert post_schema["security"] == [{"HTTPBearer": []}]
  101. assert "requestBody" not in post_schema
  102. assert "422" not in post_schema["responses"]
  103. def test_get_openapi_schema_preserves_function_tool_defs(self):
  104. class Address(BaseModel):
  105. street: str
  106. zip_code: int
  107. class AddressListTool(BaseTool):
  108. addresses: list[Address]
  109. def run(self):
  110. return ",".join(addr.street for addr in self.addresses)
  111. function_tool = ToolFactory.adapt_base_tool(AddressListTool)
  112. original_schema = deepcopy(function_tool.params_json_schema)
  113. ToolFactory.get_openapi_schema([function_tool], "https://api.test.com")
  114. assert function_tool.params_json_schema == original_schema
  115. def test_union_function_tool_omits_request_body(self):
  116. class Contact(BaseModel):
  117. identifier: str | int
  118. class ContainerTool(BaseTool):
  119. contact: Contact
  120. def run(self):
  121. return str(self.contact.identifier)
  122. function_tool = ToolFactory.adapt_base_tool(ContainerTool)
  123. result = json.loads(ToolFactory.get_openapi_schema([function_tool], "https://api.test.com"))
  124. post_schema = result["paths"]["/tool/ContainerTool"]["post"]
  125. assert "requestBody" not in post_schema
  126. assert "422" not in post_schema["responses"]
  127. assert result["components"]["schemas"] == {}