# _response_test_helpers.py — helpers that build test agents backed by DeterministicModel.
  1. from typing import Any
  2. from agents import ModelSettings, RunConfig, RunHooks, RunResult, TResponseInputItem
  3. from agency_swarm import Agent
  4. from agency_swarm.agent.context_types import AgencyContext
  5. from tests.deterministic_model import DeterministicModel
  6. def _make_agent(name: str, response_text: str = "Test response") -> Agent:
  7. return Agent(
  8. name=name,
  9. instructions="You are a test agent.",
  10. model=DeterministicModel(default_response=response_text),
  11. model_settings=ModelSettings(temperature=0.0),
  12. )
  13. class CapturingAgent(Agent):
  14. def __init__(self, name: str, response_text: str = "Test response") -> None:
  15. super().__init__(
  16. name=name,
  17. instructions="You are a test agent.",
  18. model=DeterministicModel(default_response=response_text),
  19. model_settings=ModelSettings(temperature=0.0),
  20. )
  21. self.last_context_override: dict[str, Any] | None = None
  22. self.last_hooks_override: RunHooks | None = None
  23. self.last_agency_context: AgencyContext | None = None
  24. self.last_message: str | list[TResponseInputItem] | None = None
  25. async def get_response(
  26. self,
  27. message: str | list[TResponseInputItem],
  28. sender_name: str | None = None,
  29. context_override: dict[str, Any] | None = None,
  30. hooks_override: RunHooks | None = None,
  31. run_config_override: RunConfig | None = None,
  32. file_ids: list[str] | None = None,
  33. additional_instructions: str | None = None,
  34. agency_context: AgencyContext | None = None,
  35. **kwargs: Any,
  36. ) -> RunResult:
  37. self.last_message = message
  38. self.last_context_override = context_override
  39. self.last_hooks_override = hooks_override
  40. self.last_agency_context = agency_context
  41. return await super().get_response(
  42. message=message,
  43. sender_name=sender_name,
  44. context_override=context_override,
  45. hooks_override=hooks_override,
  46. run_config_override=run_config_override,
  47. file_ids=file_ids,
  48. additional_instructions=additional_instructions,
  49. agency_context=agency_context,
  50. **kwargs,
  51. )
  52. def get_response_stream( # type: ignore[override]
  53. self,
  54. message: str | list[dict[str, Any]],
  55. sender_name: str | None = None,
  56. context_override: dict[str, Any] | None = None,
  57. hooks_override: RunHooks | None = None,
  58. run_config_override: RunConfig | None = None,
  59. file_ids: list[str] | None = None,
  60. additional_instructions: str | None = None,
  61. agency_context: AgencyContext | None = None,
  62. parent_run_id: str | None = None,
  63. **kwargs: Any,
  64. ):
  65. self.last_message = message
  66. self.last_context_override = context_override
  67. self.last_hooks_override = hooks_override
  68. self.last_agency_context = agency_context
  69. return super().get_response_stream(
  70. message=message,
  71. sender_name=sender_name,
  72. context_override=context_override,
  73. hooks_override=hooks_override,
  74. run_config_override=run_config_override,
  75. file_ids=file_ids,
  76. additional_instructions=additional_instructions,
  77. agency_context=agency_context,
  78. parent_run_id=parent_run_id,
  79. **kwargs,
  80. )