test_agency_initialization.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from collections.abc import AsyncIterator
  2. import pytest
  3. from agents import ModelSettings, Tool
  4. from agents.agent_output import AgentOutputSchemaBase
  5. from agents.handoffs import Handoff as RawSDKHandoff
  6. from agents.items import ModelResponse, TResponseInputItem, TResponseStreamEvent
  7. from agents.models.interface import Model, ModelTracing
  8. from openai.types.responses.response_prompt_param import ResponsePromptParam
  9. import agency_swarm
  10. from agency_swarm import Agency, Agent, Handoff, SDKHandoff
  11. from agency_swarm.agent.conversation_starters_cache import load_cached_starter
  12. from agency_swarm.tools import Handoff as ToolHandoff
  13. from agency_swarm.tools.send_message import SendMessage
  14. from tests.deterministic_model import DeterministicModel
  15. # --- Fixtures ---
  16. def _make_agent(name: str) -> Agent:
  17. return Agent(
  18. name=name,
  19. instructions="You are a test agent.",
  20. model=DeterministicModel(),
  21. model_settings=ModelSettings(temperature=0.0),
  22. )
  23. class _FailingModel(Model):
  24. def __init__(self, model: str = "test-failing") -> None:
  25. self.model = model
  26. async def get_response(
  27. self,
  28. system_instructions: str | None,
  29. input: str | list[TResponseInputItem],
  30. model_settings: ModelSettings,
  31. tools: list[Tool],
  32. output_schema: AgentOutputSchemaBase | None,
  33. handoffs: list[RawSDKHandoff],
  34. tracing: ModelTracing,
  35. *,
  36. previous_response_id: str | None,
  37. conversation_id: str | None,
  38. prompt: ResponsePromptParam | None,
  39. ) -> ModelResponse:
  40. raise RuntimeError("Warmup failure")
  41. def stream_response(
  42. self,
  43. system_instructions: str | None,
  44. input: str | list[TResponseInputItem],
  45. model_settings: ModelSettings,
  46. tools: list[Tool],
  47. output_schema: AgentOutputSchemaBase | None,
  48. handoffs: list[RawSDKHandoff],
  49. tracing: ModelTracing,
  50. *,
  51. previous_response_id: str | None,
  52. conversation_id: str | None,
  53. prompt: ResponsePromptParam | None,
  54. ) -> AsyncIterator[TResponseStreamEvent]:
  55. async def _stream() -> AsyncIterator[TResponseStreamEvent]:
  56. if False:
  57. yield {} # pragma: no cover
  58. return
  59. return _stream()
  60. @pytest.fixture
  61. def mock_agent():
  62. """Provides an Agent instance for testing."""
  63. return _make_agent("MockAgent")
  64. @pytest.fixture
  65. def mock_agent2():
  66. """Provides a second Agent instance for testing."""
  67. return _make_agent("MockAgent2")
  68. # --- Agency Initialization Tests ---
  69. def test_agency_minimal_initialization(mock_agent):
  70. """Test Agency initialization with minimal parameters."""
  71. agency = Agency(mock_agent)
  72. assert agency.agents == {"MockAgent": mock_agent}
  73. assert agency.shared_instructions is None or agency.shared_instructions == ""
  74. assert agency.persistence_hooks is None
  75. def test_agency_initialization_with_flows(mock_agent, mock_agent2):
  76. """Test Agency initialization with communication flows."""
  77. agency = Agency(mock_agent, communication_flows=[(mock_agent, mock_agent2)])
  78. assert agency.agents == {"MockAgent": mock_agent, "MockAgent2": mock_agent2}
  79. # Check that agents are properly registered
  80. assert len(agency.agents) == 2
  81. def test_agency_initialization_shared_instructions(mock_agent):
  82. """Test Agency initialization with shared instructions."""
  83. instructions_content = "These are shared instructions for all agents."
  84. agency = Agency(mock_agent, shared_instructions=instructions_content)
  85. assert agency.shared_instructions == instructions_content
  86. def test_agency_initialization_persistence_hooks(mock_agent):
  87. """Test Agency initialization with persistence hooks."""
  88. saved_messages = []
  89. def mock_load_cb():
  90. return []
  91. def mock_save_cb(messages):
  92. saved_messages.append(messages)
  93. agency = Agency(mock_agent, load_threads_callback=mock_load_cb, save_threads_callback=mock_save_cb)
  94. assert agency.persistence_hooks is not None
  95. # The callbacks are passed to ThreadManager and PersistenceHooks, not stored directly
  96. assert saved_messages == []
  97. def test_agency_duplicate_agent_names_forbidden():
  98. """Test that Agency raises ValueError when trying to register two agents with
  99. the same name but different instances."""
  100. # Create two different mock agents with the same name
  101. agent1 = _make_agent("DuplicateName")
  102. agent2 = _make_agent("DuplicateName")
  103. # Verify they are different instances
  104. assert id(agent1) != id(agent2)
  105. # Attempting to create an Agency with two agents having the same name should raise ValueError
  106. with pytest.raises(ValueError, match=r"Duplicate agent name 'DuplicateName' with different instances found"):
  107. Agency(agent1, agent2)
  108. # --- Shared Instruction File Loading Tests ---
  109. def test_agency_shared_instructions_file_loading(tmp_path):
  110. """Test that agency can load shared instructions from a file."""
  111. # Create shared instruction file
  112. shared_file = tmp_path / "shared_instructions.txt"
  113. shared_content = "All agents must follow these shared guidelines."
  114. shared_file.write_text(shared_content)
  115. # Create test agent
  116. agent = Agent(name="TestAgent", instructions="You are a test agent.", model="gpt-5.4-mini")
  117. # Create agency with shared instruction file
  118. agency = Agency(
  119. agent, # Entry point agent as positional argument
  120. shared_instructions=str(shared_file),
  121. )
  122. assert agency.shared_instructions == shared_content
  123. def test_agency_shared_instructions_string():
  124. """Test that agency accepts instruction strings that aren't files."""
  125. shared_text = "These are shared instructions as text"
  126. agent = Agent(name="TestAgent", instructions="Test agent instructions", model="gpt-5.4-mini")
  127. agency = Agency(
  128. agent, # Entry point agent as positional argument
  129. shared_instructions=shared_text,
  130. )
  131. # Should keep the text as-is since it's not a file
  132. assert agency.shared_instructions == shared_text
  133. def test_agency_shared_instructions_none():
  134. """Test agency with no shared instructions."""
  135. agent = Agent(name="TestAgent", instructions="Test agent", model="gpt-5.4-mini")
  136. agency = Agency(
  137. agent, # Entry point agent as positional argument
  138. shared_instructions=None,
  139. )
  140. assert agency.shared_instructions == ""
  141. def test_agency_rejects_global_model(mock_agent):
  142. """Global model parameter is not supported."""
  143. with pytest.raises(TypeError, match=r"unexpected keyword argument 'model'"):
  144. Agency(mock_agent, model="gpt-4o")
  145. class _CustomSendMessage(SendMessage):
  146. pass
  147. def test_agency_send_message_tool_class_does_not_mutate_agent(mock_agent):
  148. """Agency-level SendMessage fallback should not mutate Agent state."""
  149. sentinel = object()
  150. mock_agent.send_message_tool_class = sentinel
  151. Agency(mock_agent, send_message_tool_class=_CustomSendMessage)
  152. assert mock_agent.send_message_tool_class is sentinel
  153. def test_agency_warmup_failure_does_not_abort_initialization(tmp_path, monkeypatch) -> None:
  154. """Warmup failures should be best-effort during sync init."""
  155. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  156. agent = Agent(
  157. name="WarmupFailAgent",
  158. instructions="You are a test agent.",
  159. model=_FailingModel(),
  160. conversation_starters=["Hello"],
  161. cache_conversation_starters=True,
  162. )
  163. Agency(agent)
  164. def test_agency_warmup_supports_quick_replies_without_starter_cache_flag(tmp_path, monkeypatch) -> None:
  165. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  166. quick_reply = "hi"
  167. agent = Agent(
  168. name="QuickReplyWarmupAgent",
  169. instructions="You are a test agent.",
  170. model=DeterministicModel(default_response="hello"),
  171. quick_replies=[quick_reply],
  172. cache_conversation_starters=False,
  173. )
  174. Agency(agent)
  175. cached = load_cached_starter(
  176. agent.name,
  177. quick_reply,
  178. expected_fingerprint=agent._conversation_starters_fingerprint,
  179. )
  180. assert cached is not None
  181. def test_package_handoff_export_uses_framework_handoff(mock_agent, mock_agent2) -> None:
  182. """Top-level Handoff should configure Agency Swarm flow handoffs."""
  183. assert Handoff is ToolHandoff
  184. assert SDKHandoff is RawSDKHandoff
  185. assert not hasattr(agency_swarm, "AgentsHandoff")
  186. agency = Agency(mock_agent, communication_flows=[(mock_agent, mock_agent2, Handoff)])
  187. runtime_state = agency._agent_runtime_state[mock_agent.name]
  188. assert len(runtime_state.handoffs) == 1
  189. assert runtime_state.send_message_tools == {}