tools.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """
  2. Tools Demo (BaseTool and @function_tool)
  3. Implement the same Add operation two ways: BaseTool and @function_tool, each
  4. with field and model validators.
  5. Run with: python examples/tools.py
  6. """
  7. import asyncio
  8. import os
  9. import sys
  10. from pydantic import BaseModel, Field, field_validator, model_validator
  11. # Add src to path for standalone example execution
  12. sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
  13. from agency_swarm import Agency, Agent, BaseTool, ModelSettings, function_tool # noqa: E402 # isort: skip
  14. # --- BaseTool pattern --- #
  class AddTool(BaseTool):
      """Add two non-negative integers.
      :param a: First addend (>= 0, <= 100)
      :param b: Second addend (>= 0, <= 100)
      :returns: Sum as a string. The sum must be <= 100.
      """

      # NOTE(review): the docstring and the Field descriptions are reproduced
      # verbatim; they are presumably surfaced to the model as part of the tool
      # schema -- confirm before rewording them.

      # The lower bound is declared on the fields; the upper bounds are enforced
      # by the validators below.
      a: int = Field(..., ge=0, description="First addend (>= 0)")
      b: int = Field(..., ge=0, description="Second addend (>= 0)")

      @field_validator("a", "b")
      @classmethod
      def cap_each_value(cls, value: int) -> int:
          # Per-field check applied to both addends: anything above 100 is rejected.
          if value <= 100:
              return value
          raise ValueError("each value must be <= 100")

      @model_validator(mode="after")
      def cap_sum(self) -> "AddTool":
          # Cross-field check: the two addends together may not exceed 100.
          total = self.a + self.b
          if total <= 100:
              return self
          raise ValueError("sum must be <= 100")

      class ToolConfig:
          strict: bool = True

      def run(self) -> str:
          # The result is handed back as text, not as an int.
          total = self.a + self.b
          return str(total)
  38. # --- @function_tool pattern --- #
  class AddArgs(BaseModel):
      # Argument model for add_numbers: two non-negative integer addends.
      # The lower bound is declared on the fields; the per-value and sum caps
      # are enforced by the validators below.
      a: int = Field(..., ge=0, description="First addend (>= 0)")
      b: int = Field(..., ge=0, description="Second addend (>= 0)")

      @field_validator("a", "b")
      @classmethod
      def cap_each_value(cls, value: int) -> int:
          # Per-field check applied to both addends: anything above 100 is rejected.
          if value <= 100:
              return value
          raise ValueError("each value must be <= 100")

      @model_validator(mode="after")
      def cap_sum(self) -> "AddArgs":
          # Cross-field check: the two addends together may not exceed 100.
          total = self.a + self.b
          if total <= 100:
              return self
          raise ValueError("sum must be <= 100")
  @function_tool
  def add_numbers(args: AddArgs) -> str:
      """Add two non-negative integers.
      :returns: Sum as a string. The sum must be <= 100.
      """
      # No range checks here: the limits live in AddArgs' validators, so a
      # constructed AddArgs instance has already passed them.
      total = args.a + args.b
      return str(total)
  def create_demo_agency() -> Agency:
      """Build an Agency around a single agent that exposes both Add tools."""
      # Adjacent literals concatenate into one instruction string.
      prompt = (
          "You can add integers using two tools: add_numbers (function tool) and AddTool (BaseTool). "
          "When asked to add numbers, use the specified tool. Respond strictly as either 'Result: <sum>' "
          "or 'Error: <reason>'."
      )
      settings = ModelSettings(temperature=0.0)
      demo_agent = Agent(
          name="ToolDemo",
          instructions=prompt,
          # One tool per pattern: the decorated function and the BaseTool subclass.
          tools=[add_numbers, AddTool],
          model_settings=settings,
      )
      return Agency(demo_agent)
  # Built once when the module is loaded; run_demo() sends its prompts through
  # this instance.
  agency: Agency = create_demo_agency()
  async def run_demo() -> None:
      """Send two prompts to the module-level agency and print each final output."""
      prompts = (
          # FunctionTool: valid inputs
          "Add 2 and 3 using add_numbers.",
          # BaseTool: invalid inputs (sum exceeds 100)
          "Add 70 and 50 using AddTool.",
      )
      # Awaited one at a time so the outputs are printed in prompt order.
      for prompt in prompts:
          reply = await agency.get_response(prompt)
          print(reply.final_output)
  if __name__ == "__main__":
      # Script entry point: run the async demo to completion on a fresh event loop.
      demo_coroutine = run_demo()
      asyncio.run(demo_coroutine)