test_tools_utils.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import json
  2. from unittest.mock import AsyncMock, MagicMock, patch
  3. import pytest
  4. from agency_swarm.tools.utils import from_openapi_schema, validate_openapi_spec
  5. @pytest.fixture
  6. def base_spec():
  7. return {"servers": [{"url": "https://api.example.com"}], "paths": {}}
  8. @pytest.fixture
  9. def mock_tool_setup():
  10. with patch("agency_swarm.tools.utils.FunctionTool") as mock_func:
  11. tool = MagicMock()
  12. mock_func.return_value = tool
  13. yield mock_func, tool
  14. class TestFromOpenAPISchema:
  15. def test_basic_schema_conversion(self, base_spec, mock_tool_setup):
  16. mock_func, tool = mock_tool_setup
  17. base_spec["paths"]["/users"] = {"get": {"operationId": "getUsers", "description": "Get users"}}
  18. tools = from_openapi_schema(base_spec)
  19. assert len(tools) == 1 and tools[0] == tool
  20. args = mock_func.call_args.kwargs
  21. assert args["name"] == "getUsers" and args["description"] == "Get users"
  22. def test_string_input_with_body(self, base_spec, mock_tool_setup):
  23. mock_func, _ = mock_tool_setup
  24. base_spec["paths"]["/test"] = {
  25. "post": {
  26. "operationId": "create",
  27. "description": "Create",
  28. "requestBody": {"content": {"application/json": {"schema": {"type": "object"}}}},
  29. }
  30. }
  31. from_openapi_schema(json.dumps(base_spec))
  32. schema = mock_func.call_args.kwargs["params_json_schema"]
  33. assert "requestBody" in schema["properties"]
  34. @pytest.mark.parametrize("strict,expected", [(True, True), (False, False)])
  35. def test_strict_mode(self, base_spec, mock_tool_setup, strict, expected):
  36. mock_func, _ = mock_tool_setup
  37. base_spec["paths"]["/test"] = {"get": {"operationId": "test", "description": "Test"}}
  38. with patch("agency_swarm.tools.utils.ensure_strict_json_schema"):
  39. from_openapi_schema(base_spec, strict=strict)
  40. assert mock_func.call_args.kwargs["strict_json_schema"] == expected
  41. def test_parameter_handling(self, base_spec, mock_tool_setup):
  42. mock_func, _ = mock_tool_setup
  43. base_spec["paths"]["/test"] = {
  44. "get": {
  45. "operationId": "test",
  46. "description": "Test",
  47. "parameters": [
  48. {"name": "legacy", "type": "string", "required": False},
  49. {"name": "new", "schema": {"type": "integer"}, "required": True},
  50. ],
  51. }
  52. }
  53. from_openapi_schema(base_spec)
  54. schema = mock_func.call_args.kwargs["params_json_schema"]
  55. params = schema["properties"]["parameters"]
  56. assert "legacy" in params["properties"] and "new" in params["properties"]
  57. assert "new" in params["required"] and "legacy" not in params["required"]
  58. @pytest.mark.asyncio
  59. async def test_invoke_get_request(self, base_spec, mock_tool_setup):
  60. mock_func, _ = mock_tool_setup
  61. base_spec["paths"]["/users/{id}"] = {
  62. "get": {
  63. "operationId": "getUser",
  64. "description": "Get user",
  65. "parameters": [{"name": "id", "schema": {"type": "string"}, "required": True}],
  66. }
  67. }
  68. with patch("agency_swarm.tools.utils.httpx.AsyncClient") as mock_client_cls:
  69. client = AsyncMock()
  70. response = MagicMock()
  71. response.json.return_value = {"id": "123"}
  72. client.request.return_value = response
  73. mock_client_cls.return_value.__aenter__.return_value = client
  74. from_openapi_schema(base_spec)
  75. invoke_func = mock_func.call_args.kwargs["on_invoke_tool"]
  76. result = await invoke_func(MagicMock(), json.dumps({"parameters": {"id": "123"}}))
  77. client.request.assert_called_once_with(
  78. "GET", "https://api.example.com/users/123", params={}, json=None, headers={}
  79. )
  80. assert result == {"id": "123"}
  81. @pytest.mark.asyncio
  82. async def test_invoke_post_request(self, base_spec, mock_tool_setup):
  83. mock_func, _ = mock_tool_setup
  84. base_spec["paths"]["/users"] = {
  85. "post": {
  86. "operationId": "createUser",
  87. "description": "Create user",
  88. "requestBody": {"content": {"application/json": {"schema": {"type": "object"}}}},
  89. }
  90. }
  91. with patch("agency_swarm.tools.utils.httpx.AsyncClient") as mock_client_cls:
  92. client = AsyncMock()
  93. mock_response = MagicMock()
  94. mock_response.json.return_value = {"id": "456"}
  95. client.request.return_value = mock_response
  96. mock_client_cls.return_value.__aenter__.return_value = client
  97. from_openapi_schema(base_spec)
  98. invoke_func = mock_func.call_args.kwargs["on_invoke_tool"]
  99. await invoke_func(MagicMock(), json.dumps({"requestBody": {"name": "test"}}))
  100. client.request.assert_called_once_with(
  101. "POST", "https://api.example.com/users", params={}, json={"name": "test"}, headers={}
  102. )
  103. @pytest.mark.asyncio
  104. async def test_non_json_response(self, base_spec, mock_tool_setup):
  105. mock_func, _ = mock_tool_setup
  106. base_spec["paths"]["/text"] = {"get": {"operationId": "getText", "description": "Get text"}}
  107. with patch("agency_swarm.tools.utils.httpx.AsyncClient") as mock_client_cls:
  108. client = AsyncMock()
  109. response = MagicMock()
  110. response.json.side_effect = Exception("Not JSON")
  111. response.text = "plain text"
  112. client.request.return_value = response
  113. mock_client_cls.return_value.__aenter__.return_value = client
  114. from_openapi_schema(base_spec)
  115. invoke_func = mock_func.call_args.kwargs["on_invoke_tool"]
  116. result = await invoke_func(MagicMock(), json.dumps({"parameters": {}}))
  117. assert result == "plain text"
  118. def test_multiple_operations(self, base_spec, mock_tool_setup):
  119. mock_func, _ = mock_tool_setup
  120. base_spec["paths"] = {
  121. "/users": {
  122. "get": {"operationId": "getUsers", "description": "Get users"},
  123. "post": {"operationId": "createUser", "description": "Create user"},
  124. },
  125. "/posts": {"get": {"operationId": "getPosts", "description": "Get posts"}},
  126. }
  127. tools = from_openapi_schema(base_spec)
  128. assert len(tools) == 3 and mock_func.call_count == 3
  129. class TestValidateOpenAPISpec:
  130. @pytest.mark.parametrize(
  131. "spec,should_pass",
  132. [
  133. ({"paths": {"/users": {"get": {"operationId": "getUsers", "description": "Get users"}}}}, True),
  134. ({"info": {"title": "API"}}, False), # Missing paths
  135. ({"paths": {"/users": "invalid"}}, False), # Invalid path item
  136. ({"paths": {"/users": {"get": {"description": "Get users"}}}}, False), # Missing operationId
  137. ({"paths": {"/users": {"get": {"operationId": "getUsers"}}}}, False), # Missing description
  138. ],
  139. )
  140. def test_validation(self, spec, should_pass):
  141. if should_pass:
  142. result = validate_openapi_spec(json.dumps(spec))
  143. assert result == spec
  144. else:
  145. with pytest.raises(ValueError):
  146. validate_openapi_spec(json.dumps(spec))