# test_send_message_extra_params.py
# Tests for SendMessage subclasses that declare extra parameters.

import json

import pytest
from agents import RunContextWrapper
from pydantic import BaseModel, Field

from agency_swarm import Agency, Agent, ModelSettings
from agency_swarm.context import MasterContext
from agency_swarm.tools.send_message import SendMessage
from agency_swarm.utils.thread import ThreadManager


class ExtraParams(BaseModel):
    # Extra-parameter model attached explicitly via
    # SendMessageWithContext.extra_params_model. The schema test expects both
    # fields to appear as required string properties of the send_message tool.
    # (Comments instead of a docstring: pydantic copies a class docstring into
    # the model's JSON-schema description, which would alter the schema.)
    key_moments: str = Field(description="Important context")
    decisions: str = Field(description="Decisions made")


class SendMessageWithContext(SendMessage):
    # Declares extra params explicitly through the `extra_params_model` class
    # attribute; the tests expect ExtraParams' fields to be added to the tool's
    # parameter schema and to be validated when the tool is invoked.
    extra_params_model = ExtraParams


class NestedSendMessage(SendMessage):
    # Declares extra params as a nested class named `ExtraParams` instead of
    # setting `extra_params_model`. Presumably SendMessage discovers the model
    # by this nested name -- keep the name as-is (TODO confirm in SendMessage).

    class ExtraParams(BaseModel):
        # The schema test expects `summary` as a required string property.
        summary: str = Field(description="Short summary")


  17. @pytest.mark.asyncio
  18. async def test_schema_includes_extra_params_for_explicit_and_nested_models():
  19. a = Agent(name="A", instructions="", model_settings=ModelSettings(temperature=0.0))
  20. b = Agent(name="B", instructions="", model_settings=ModelSettings(temperature=0.0))
  21. explicit_agency = Agency(a, communication_flows=[(a > b, SendMessageWithContext)])
  22. explicit_tool = next(iter(explicit_agency.get_agent_runtime_state("A").send_message_tools.values()))
  23. explicit_props = explicit_tool.params_json_schema.get("properties", {})
  24. explicit_required = explicit_tool.params_json_schema.get("required", [])
  25. assert "key_moments" in explicit_props and explicit_props["key_moments"]["type"] == "string"
  26. assert "decisions" in explicit_props and explicit_props["decisions"]["type"] == "string"
  27. assert "key_moments" in explicit_required and "decisions" in explicit_required
  28. nested_agency = Agency(a, communication_flows=[(a > b, NestedSendMessage)])
  29. nested_tool = next(iter(nested_agency.get_agent_runtime_state("A").send_message_tools.values()))
  30. nested_props = nested_tool.params_json_schema.get("properties", {})
  31. nested_required = nested_tool.params_json_schema.get("required", [])
  32. assert "summary" in nested_props and nested_props["summary"]["type"] == "string"
  33. assert "summary" in nested_required
  34. @pytest.mark.asyncio
  35. async def test_validation_of_extra_params_errors():
  36. a = Agent(
  37. name="A",
  38. instructions="Use send_message to talk to B and include fields.",
  39. model_settings=ModelSettings(temperature=0.0),
  40. )
  41. b = Agent(name="B", instructions="Reply with OK", model_settings=ModelSettings(temperature=0.0))
  42. agency = Agency(a, communication_flows=[(a > b, SendMessageWithContext)])
  43. runtime_state = agency.get_agent_runtime_state("A")
  44. send_tool = next(iter(runtime_state.send_message_tools.values()))
  45. args = {
  46. "recipient_agent": "B",
  47. "message": "hi",
  48. "additional_instructions": "",
  49. }
  50. wrapper = RunContextWrapper(context=MasterContext(thread_manager=ThreadManager(), agents={}))
  51. out = await send_tool.on_invoke_tool(wrapper, json.dumps(args))
  52. assert isinstance(out, str) and out.startswith("Error: Invalid extra parameters")