test_extra_params.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import json
  2. import pytest
  3. from agents import ModelSettings, RunContextWrapper
  4. from pydantic import BaseModel, Field
  5. from agency_swarm import Agent
  6. from agency_swarm.context import MasterContext
  7. from agency_swarm.tools.send_message import SendMessage
  8. from agency_swarm.utils.thread import ThreadManager
  9. from tests.deterministic_model import DeterministicModel
  10. def _make_stub_agent(name: str, response: str = "ack") -> Agent:
  11. return Agent(
  12. name=name,
  13. instructions="stub",
  14. model=DeterministicModel(default_response=response),
  15. model_settings=ModelSettings(temperature=0.0),
  16. )
class NumericContextParams(BaseModel):
    # Extra-params model attached to SendMessageWithContext. The schema test
    # expects `count` to surface as "integer" and `summary` as "string", and
    # both to be listed as required. Comments are used instead of a docstring
    # because pydantic would copy a class docstring into the JSON schema.
    count: int = Field(description="Count")
    summary: str = Field(description="Summary")
class InlineContext(BaseModel):
    # Nested model referenced through the string annotation "InlineContext" in
    # test_inline_fields_resolve_forward_annotations. It is defined at module
    # level so that the forward reference can be resolved from this module.
    value: int = Field(description="Value")
class SendMessageWithContext(SendMessage):
    # Attaches NumericContextParams as the extra-params model for the schema and
    # validation test.
    # NOTE(review): the other subclasses in this file use the `ExtraParams` hook
    # (SendMessageBad, MixedSendMessage). Confirm that `extra_params_model` is a
    # supported alias in SendMessage, or switch to `ExtraParams` for consistency.
    extra_params_model = NumericContextParams
class BadExtra(BaseModel):
    # Extra-params model whose JSON-schema generation always fails. It is used
    # to check that SendMessage stays usable without extra-field validation.
    foo: str

    @classmethod
    def model_json_schema(cls, *args, **kwargs): # type: ignore[override]
        # Always raise, to simulate a model that cannot produce a schema.
        raise ValueError("boom")
class SendMessageBad(SendMessage):
    # Binds the schema-failing BadExtra model through the `ExtraParams` hook.
    # The fallback test expects the tool to still be constructed, with `foo`
    # absent from its schema.
    ExtraParams = BadExtra
# --- Inline field declarations (extra params declared without a nested ExtraParams class) ---
class SendMessageInline(SendMessage):
    """Inline extra params without nested class."""

    # Expected to become the tool's name (asserted as tool.name in
    # test_inline_fields_merged_into_schema).
    tool_name = "send_message_inline"
    # Annotated attributes with a Field(...) default are expected to be merged
    # into the tool schema as required "string" properties.
    priority: str = Field(description="Task priority level")
    context_summary: str = Field(description="Summary of context")
  37. def _wrapper_with_recipient(recipient: Agent) -> RunContextWrapper[MasterContext]:
  38. ctx = MasterContext(
  39. thread_manager=ThreadManager(),
  40. agents={"B": recipient},
  41. user_context={},
  42. agent_runtime_state={},
  43. shared_instructions=None,
  44. )
  45. return RunContextWrapper(context=ctx)
  46. @pytest.mark.asyncio
  47. async def test_send_message_extra_params_schema_validation_and_success() -> None:
  48. """Extra params should be merged into schema, validate input, and still allow successful sends."""
  49. sender = _make_stub_agent("Sender")
  50. recipient = _make_stub_agent("Recipient")
  51. tool = SendMessageWithContext(sender, recipients={"B": recipient})
  52. properties = tool.params_json_schema["properties"]
  53. required = tool.params_json_schema["required"]
  54. assert properties["count"]["type"] == "integer"
  55. assert properties["summary"]["type"] == "string"
  56. assert "count" in required and "summary" in required
  57. valid_args = json.dumps(
  58. {
  59. "recipient_agent": "B",
  60. "message": "msg",
  61. "additional_instructions": "",
  62. "count": 1,
  63. "summary": "ok",
  64. }
  65. )
  66. invalid_args = json.dumps(
  67. {
  68. "recipient_agent": "B",
  69. "message": "msg",
  70. "additional_instructions": "",
  71. "count": "bad",
  72. "summary": "ok",
  73. }
  74. )
  75. wrapper = _wrapper_with_recipient(recipient)
  76. assert await tool.on_invoke_tool(wrapper, valid_args) == "ack"
  77. invalid_result = await tool.on_invoke_tool(wrapper, invalid_args)
  78. assert isinstance(invalid_result, str) and invalid_result.startswith("Error: Invalid extra parameters")
  79. @pytest.mark.asyncio
  80. async def test_send_message_bad_extra_params_model_falls_back_without_validation() -> None:
  81. """Schema generation failures in ExtraParams should keep tool usable without extra field validation."""
  82. sender = _make_stub_agent("Sender")
  83. recipient = _make_stub_agent("Recipient")
  84. tool = SendMessageBad(sender, recipients={"B": recipient})
  85. assert "foo" not in tool.params_json_schema["properties"]
  86. wrapper = _wrapper_with_recipient(recipient)
  87. base_args = {
  88. "recipient_agent": "B",
  89. "message": "m",
  90. "additional_instructions": "",
  91. }
  92. result_no_extra = await tool.on_invoke_tool(wrapper, json.dumps(base_args))
  93. result_unknown_extra = await tool.on_invoke_tool(wrapper, json.dumps({**base_args, "foo": "x"}))
  94. assert isinstance(result_no_extra, str) and not result_no_extra.startswith("Error: Invalid extra parameters")
  95. assert isinstance(result_unknown_extra, str) and not result_unknown_extra.startswith(
  96. "Error: Invalid extra parameters"
  97. )
  98. def test_inline_fields_merged_into_schema() -> None:
  99. """Inline field declarations should be auto-discovered and merged into the schema."""
  100. sender = _make_stub_agent("Sender")
  101. recipient = _make_stub_agent("Recipient")
  102. tool = SendMessageInline(sender, recipients={"B": recipient})
  103. properties = tool.params_json_schema["properties"]
  104. required = tool.params_json_schema["required"]
  105. assert "priority" in properties
  106. assert properties["priority"]["type"] == "string"
  107. assert "context_summary" in properties
  108. assert properties["context_summary"]["type"] == "string"
  109. assert "priority" in required
  110. assert "context_summary" in required
  111. # Verify tool_name was applied
  112. assert tool.name == "send_message_inline"
  113. def test_tool_name_class_attribute() -> None:
  114. """tool_name class attribute should set the tool name without __init__ override."""
  115. class MySendMessage(SendMessage):
  116. tool_name = "send_message_custom"
  117. sender = _make_stub_agent("Sender")
  118. tool = MySendMessage(sender)
  119. assert tool.name == "send_message_custom"
  120. def test_tool_name_explicit_name_takes_precedence() -> None:
  121. """Explicit name= kwarg should override tool_name class attribute."""
  122. class MySendMessage(SendMessage):
  123. tool_name = "send_message_custom"
  124. sender = _make_stub_agent("Sender")
  125. tool = MySendMessage(sender, name="send_message_override")
  126. assert tool.name == "send_message_override"
  127. def test_inline_fields_not_picked_up_when_extra_params_exists() -> None:
  128. """ExtraParams nested class should take priority over inline fields."""
  129. class MixedSendMessage(SendMessage):
  130. class ExtraParams(BaseModel):
  131. from_nested: str = Field(description="From nested")
  132. # This should be ignored because ExtraParams is present
  133. inline_field: str = Field(description="Should be ignored")
  134. sender = _make_stub_agent("Sender")
  135. tool = MixedSendMessage(sender)
  136. properties = tool.params_json_schema["properties"]
  137. assert "from_nested" in properties
  138. assert "inline_field" not in properties
  139. def test_bare_annotations_are_not_treated_as_inline_fields() -> None:
  140. """Bare subclass annotations should remain internal typing hints."""
  141. class BareAnnotationSendMessage(SendMessage):
  142. internal_note: str
  143. sender = _make_stub_agent("Sender")
  144. tool = BareAnnotationSendMessage(sender)
  145. assert "internal_note" not in tool.params_json_schema["properties"]
  146. def test_inline_fields_resolve_forward_annotations() -> None:
  147. """Inline fields should resolve future-style annotations from the subclass module."""
  148. class FutureStyleAnnotation(SendMessage):
  149. context: "InlineContext" = Field(description="Structured context")
  150. sender = _make_stub_agent("Sender")
  151. tool = FutureStyleAnnotation(sender)
  152. assert "context" in tool.params_json_schema["properties"]
  153. assert tool._extra_params_model is not None
  154. assert tool._extra_params_model.model_fields["context"].annotation is InlineContext
  155. @pytest.mark.asyncio
  156. async def test_inline_fields_validation() -> None:
  157. """Inline fields should be validated like ExtraParams fields."""
  158. sender = _make_stub_agent("Sender")
  159. recipient = _make_stub_agent("Recipient")
  160. class StrictInline(SendMessage):
  161. count: int = Field(description="Must be integer")
  162. tool = StrictInline(sender, recipients={"Recipient": recipient})
  163. wrapper = _wrapper_with_recipient(recipient)
  164. invalid_args = json.dumps(
  165. {
  166. "recipient_agent": "Recipient",
  167. "message": "test",
  168. "additional_instructions": "",
  169. "count": "not_a_number",
  170. }
  171. )
  172. result = await tool.on_invoke_tool(wrapper, invalid_args)
  173. assert isinstance(result, str) and result.startswith("Error: Invalid extra parameters")