test_function_tool_compat.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from dataclasses import replace
  2. from types import SimpleNamespace
  3. from typing import Any, cast
  4. import pytest
  5. from agents import (
  6. Agent as SDKAgent,
  7. FunctionTool,
  8. RunContextWrapper,
  9. function_tool as sdk_function_tool,
  10. tool_namespace,
  11. )
  12. from agents.tool_context import ToolContext
  13. from agency_swarm import Agent, function_tool
  @pytest.mark.asyncio
  async def test_sdk_function_tool_accepts_manual_run_context_wrapper() -> None:
      """An SDK-decorated tool attached to an Agent accepts a bare RunContextWrapper.

      The tool body must receive a ToolContext that shares the wrapper's context
      object. It must also report the tool's own name and the raw JSON arguments.
      """
      captured: list[ToolContext[dict[str, str]]] = []

      @sdk_function_tool
      async def sdk_context_tool(ctx: RunContextWrapper[dict[str, str]], value: str) -> str:
          tool_context = cast(ToolContext[dict[str, str]], ctx)
          captured.append(tool_context)
          return f"{tool_context.context['label']}:{value}:{tool_context.tool_name}:{tool_context.tool_arguments}"

      owning_agent = Agent(name="test", instructions="test", tools=[sdk_context_tool])
      wrapped = owning_agent.tools[0]
      assert isinstance(wrapped, FunctionTool)

      run_context = RunContextWrapper(context={"label": "agency"})
      invoke = cast(Any, wrapped.on_invoke_tool)
      output = await invoke(run_context, '{"value":"ping"}')

      assert output == 'agency:ping:sdk_context_tool:{"value":"ping"}'
      assert len(captured) == 1
      received = captured[0]
      assert isinstance(received, ToolContext)
      # Identity, not equality: the wrapper's context object itself must be forwarded.
      assert received.context is run_context.context
  @pytest.mark.asyncio
  async def test_agency_function_tool_accepts_keyword_input_and_manual_context() -> None:
      """agency_swarm tools accept both keyword and positional on_invoke_tool calls.

      The keyword call passes ``ctx=None`` and the positional call passes a plain
      dict. Both must reach the tool body as ToolContext instances that carry
      the given context value and the raw JSON arguments.
      """
      captured: list[ToolContext[dict[str, str] | None]] = []

      @function_tool
      async def agency_context_tool(ctx: RunContextWrapper[dict[str, str] | None], x: int) -> str:
          tool_context = cast(ToolContext[dict[str, str] | None], ctx)
          captured.append(tool_context)
          context_label = "none" if tool_context.context is None else tool_context.context["label"]
          return f"{context_label}:{x}:{tool_context.tool_name}:{tool_context.tool_arguments}"

      invoke = cast(Any, agency_context_tool.on_invoke_tool)
      # Call order matters: the captured contexts are compared positionally below.
      keyword_output = await invoke(ctx=None, input='{"x":1}')
      positional_output = await invoke({"label": "manual"}, '{"x":2}')

      assert keyword_output == 'none:1:agency_context_tool:{"x":1}'
      assert positional_output == 'manual:2:agency_context_tool:{"x":2}'
      assert len(captured) == 2
      assert all(isinstance(entry, ToolContext) for entry in captured)
      assert [entry.context for entry in captured] == [None, {"label": "manual"}]
      assert [entry.tool_arguments for entry in captured] == ['{"x":1}', '{"x":2}']
  @pytest.mark.asyncio
  async def test_agency_function_tool_rebuilds_forwarded_tool_context_for_new_tool() -> None:
      """A ToolContext built for one tool is rebuilt when it is passed to another tool.

      The callee must see a new ToolContext that shares the caller's context
      object. That ToolContext must carry the callee's name, the new arguments,
      and an ``agency_swarm_manual_``-prefixed call id.
      """
      captured: list[ToolContext[dict[str, str]]] = []

      @function_tool
      async def callee_tool(ctx: RunContextWrapper[dict[str, str]], value: str) -> str:
          tool_context = cast(ToolContext[dict[str, str]], ctx)
          captured.append(tool_context)
          return f"{tool_context.context['label']}:{value}:{tool_context.tool_name}:{tool_context.tool_arguments}"

      # The context as it would look while the *caller* tool is running.
      forwarded = ToolContext(
          context={"label": "agency"},
          tool_name="caller_tool",
          tool_call_id="caller_call",
          tool_arguments='{"value":"old"}',
      )
      invoke = cast(Any, callee_tool.on_invoke_tool)
      output = await invoke(forwarded, '{"value":"new"}')

      assert output == 'agency:new:callee_tool:{"value":"new"}'
      assert len(captured) == 1
      rebuilt = captured[0]
      assert rebuilt is not forwarded
      assert rebuilt.context is forwarded.context
      assert rebuilt.tool_name == "callee_tool"
      assert rebuilt.tool_arguments == '{"value":"new"}'
      assert rebuilt.tool_call_id == "agency_swarm_manual_callee_tool"
  @pytest.mark.asyncio
  async def test_copied_agency_function_tool_rebinds_manual_context_to_copy_name() -> None:
      """A ``dataclasses.replace`` copy of a tool reports the copy's name in its ToolContext.

      Both the tool name and the generated call id must use the copied name,
      not the name of the decorated function.
      """
      captured: list[ToolContext[dict[str, str]]] = []

      @function_tool
      async def original_tool(ctx: RunContextWrapper[dict[str, str]], value: str) -> str:
          tool_context = cast(ToolContext[dict[str, str]], ctx)
          captured.append(tool_context)
          return f"{tool_context.context['label']}:{value}:{tool_context.tool_name}:{tool_context.tool_call_id}"

      renamed = replace(original_tool, name="copied_tool")
      invoke = cast(Any, renamed.on_invoke_tool)
      run_context = RunContextWrapper(context={"label": "agency"})
      output = await invoke(run_context, '{"value":"new"}')

      assert output == "agency:new:copied_tool:agency_swarm_manual_copied_tool"
      assert len(captured) == 1
      received = captured[0]
      assert received.tool_name == "copied_tool"
      assert received.tool_call_id == "agency_swarm_manual_copied_tool"
  @pytest.mark.asyncio
  async def test_copied_agency_function_tool_rebinds_failure_handler_to_copy() -> None:
      """A copied tool's ``failure_error_function`` receives the copy's tool name."""

      def failure_message(ctx: ToolContext[None], error: Exception) -> str:
          # Echo the tool name so the assertion can tell which tool the handler saw.
          return f"{ctx.tool_name}:{error}"

      @function_tool(failure_error_function=failure_message)
      async def original_tool(ctx: RunContextWrapper[None], value: str) -> str:  # noqa: ARG001
          raise RuntimeError("boom")

      renamed = replace(original_tool, name="copied_tool")
      invoke = cast(Any, renamed.on_invoke_tool)
      output = await invoke(RunContextWrapper(context=None), '{"value":"new"}')

      assert output == "copied_tool:boom"
  def test_agency_function_tool_preserves_deferred_namespace_metadata() -> None:
      """Attaching a tool to an Agent keeps its ``defer_loading`` and namespace metadata."""

      @function_tool(defer_loading=True)
      def deferred_tool(value: str) -> str:
          return value

      namespaced_tools = tool_namespace(name="demo_namespace", description="Demo namespace", tools=[deferred_tool])
      owning_agent = Agent(name="test", instructions="test", tools=[namespaced_tools[0]])
      wrapped = owning_agent.tools[0]

      assert isinstance(wrapped, FunctionTool)
      assert wrapped.defer_loading is True
      # These private attributes are where tool_namespace() records its metadata.
      assert wrapped._tool_namespace == "demo_namespace"
      assert wrapped._tool_namespace_description == "Demo namespace"
  @pytest.mark.asyncio
  async def test_deferred_agency_function_tool_keeps_sdk_tool_context_metadata() -> None:
      """A deferred top-level tool receives the SDK's ToolContext object unchanged.

      The SDK-style context here uses the tool's own name as its namespace. The
      tool body must receive that exact object, including its real call id.
      """
      captured: list[ToolContext[dict[str, str]]] = []

      @function_tool(defer_loading=True)
      async def deferred_tool(ctx: RunContextWrapper[dict[str, str]], value: str) -> str:
          tool_context = cast(ToolContext[dict[str, str]], ctx)
          captured.append(tool_context)
          return f"{tool_context.tool_call_id}:{value}"

      sdk_context = ToolContext(
          context={"label": "agency"},
          tool_name="deferred_tool",
          tool_namespace="deferred_tool",
          tool_call_id="call_real",
          tool_arguments='{"value":"new"}',
      )
      invoke = cast(Any, deferred_tool.on_invoke_tool)
      output = await invoke(sdk_context, '{"value":"new"}')

      assert output == "call_real:new"
      assert len(captured) == 1
      # Identity check: the context must not be rebuilt for this case.
      assert captured[0] is sdk_context
  @pytest.mark.asyncio
  async def test_namespaced_agency_function_tool_manual_context_keeps_namespace() -> None:
      """A manually supplied context for a namespaced tool keeps the namespace identity.

      The rebuilt ToolContext must report three values:

      - the bare tool name;
      - the namespace;
      - the ``namespace.tool`` qualified name.
      """
      captured: list[ToolContext[dict[str, str]]] = []

      @function_tool
      async def namespaced_tool(ctx: RunContextWrapper[dict[str, str]], value: str) -> str:
          tool_context = cast(ToolContext[dict[str, str]], ctx)
          captured.append(tool_context)
          return f"{tool_context.qualified_tool_name}:{tool_context.context['label']}:{value}"

      namespaced_entry = tool_namespace(name="demo_namespace", description="Demo namespace", tools=[namespaced_tool])[0]
      owning_agent = Agent(name="test", instructions="test", tools=[namespaced_entry])
      wrapped = owning_agent.tools[0]
      assert isinstance(wrapped, FunctionTool)

      invoke = cast(Any, wrapped.on_invoke_tool)
      output = await invoke(RunContextWrapper(context={"label": "agency"}), '{"value":"new"}')

      assert output == "demo_namespace.namespaced_tool:agency:new"
      assert len(captured) == 1
      received = captured[0]
      assert received.tool_name == "namespaced_tool"
      assert received.tool_namespace == "demo_namespace"
      assert received.qualified_tool_name == "demo_namespace.namespaced_tool"
  def test_sdk_agent_tool_is_not_rewrapped_as_decorator_function_tool() -> None:
      """Tools produced by SDK ``Agent.as_tool()`` keep the SDK's own invoker.

      The agency Agent must not attach its manual-context compatibility
      wrapper to such tools.
      """
      inner_agent = SDKAgent(name="nested", instructions="Return the input.")
      agent_as_tool = inner_agent.as_tool(tool_name="nested_tool", tool_description="Nested tool")
      owning_agent = Agent(name="test", instructions="test", tools=[agent_as_tool])
      wrapped = owning_agent.tools[0]

      assert isinstance(wrapped, FunctionTool)
      assert wrapped._is_agent_tool is True
      # The compat flag must be missing (or False), and no original invoker may be stashed.
      assert getattr(wrapped, "_agency_swarm_manual_tool_context_compat", False) is False
      assert not hasattr(wrapped, "_agency_original_on_invoke_tool")
  @pytest.mark.asyncio
  async def test_manual_function_tool_receives_original_context_unchanged() -> None:
      """A hand-built FunctionTool callback receives the exact context it was called with.

      The agency Agent must not substitute a ToolContext for it.
      """
      received_contexts: list[object] = []

      async def record_and_echo(ctx: object, input_json: str) -> str:
          received_contexts.append(ctx)
          return input_json

      manual_tool = FunctionTool(
          name="manual_context_tool",
          description="manual context tool",
          params_json_schema={"type": "object", "properties": {}},
          on_invoke_tool=record_and_echo,
          strict_json_schema=False,
      )
      # A non-ToolContext sentinel, so any substitution is detectable.
      marker_context = SimpleNamespace(marker="original")

      owning_agent = Agent(name="test", instructions="test", tools=[manual_tool])
      wrapped = owning_agent.tools[0]
      assert isinstance(wrapped, FunctionTool)

      call_tool = cast(Any, wrapped.on_invoke_tool)
      output = await call_tool(marker_context, '{"value":"pong"}')

      assert output == '{"value":"pong"}'
      assert len(received_contexts) == 1
      assert received_contexts[0] is marker_context
      assert not isinstance(received_contexts[0], ToolContext)