test_agent_initialization.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. from types import SimpleNamespace
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from agents import FunctionTool, ModelSettings, StopAtTools, WebSearchTool
  5. from pydantic import BaseModel, Field
  6. from agency_swarm import Agent
  7. from agency_swarm.integrations.openclaw_model import build_openclaw_responses_model
class TaskOutput(BaseModel):
    # Structured-output model passed as ``output_type`` in the initialization tests.
    # All three fields are required (declared with ``...`` and no default).
    # NOTE: kept docstring-free on purpose; a class docstring would become the schema description.
    task_name: str = Field(..., description="Name of the task")
    status: str = Field(..., description="Status of the task")
    priority: int = Field(..., description="Priority level (1-5)")  # range is only stated in the description, not validated
class SimpleOutput(BaseModel):
    # Minimal single-field structured-output model used as ``output_type`` (and as a
    # rejected legacy ``response_format`` value) in the initialization tests.
    message: str = Field(..., description="Simple message")
  14. # --- Initialization Tests ---
  15. def test_agent_initialization_with_stop_at_tools_variants():
  16. """Agent should accept StopAtTools-typed and dict-compatible tool_use_behavior values."""
  17. cases = [
  18. StopAtTools(stop_at_tool_names=["ToolA", "ToolB"]),
  19. {"stop_at_tool_names": ["ToolC"]},
  20. ]
  21. for behavior in cases:
  22. agent = Agent(name="AgentStopAtTools", instructions="Test", tool_use_behavior=behavior)
  23. assert agent.tool_use_behavior == behavior
  24. def test_agent_initialization_core_configuration_variants():
  25. """Core initialization should preserve baseline defaults and explicit tool/model/output settings."""
  26. minimal = Agent(name="Agent1", instructions="Be helpful")
  27. assert minimal.name == "Agent1"
  28. assert minimal.instructions == "Be helpful"
  29. assert minimal.model == "gpt-5.4-mini"
  30. assert minimal.tools == []
  31. assert minimal.files_folder is None
  32. assert not hasattr(minimal, "response_validator")
  33. assert minimal.output_type is None
  34. tool = MagicMock(spec=FunctionTool)
  35. tool.name = "tool1"
  36. configured = Agent(
  37. name="ConfiguredAgent",
  38. instructions="Use tools",
  39. tools=[tool],
  40. model_settings=ModelSettings(
  41. temperature=0.3,
  42. max_tokens=16,
  43. top_p=0.5,
  44. ),
  45. output_type=SimpleOutput,
  46. )
  47. assert configured.tools == [tool]
  48. assert configured.model_settings.temperature == 0.3
  49. assert configured.model_settings.max_tokens == 16
  50. assert configured.model_settings.top_p == 0.5
  51. assert configured.output_type == SimpleOutput
  52. def test_agent_initialization_rejects_deprecated_kwargs() -> None:
  53. """Deprecated initialization kwargs should fail fast with clear errors."""
  54. cases: list[tuple[dict[str, object], str]] = [
  55. (
  56. {"temperature": 0.3, "max_prompt_tokens": 16},
  57. r"Deprecated Agent parameters are not supported",
  58. ),
  59. (
  60. {"reasoning_effort": "medium"},
  61. r"reasoning_effort",
  62. ),
  63. (
  64. {"truncation_strategy": "auto"},
  65. r"truncation_strategy",
  66. ),
  67. (
  68. {"response_format": {"type": "json_schema", "json_schema": {"name": "X", "schema": {}}}},
  69. r"response_format",
  70. ),
  71. (
  72. {"response_format": SimpleOutput},
  73. r"response_format",
  74. ),
  75. (
  76. {"max_prompt_tokens": 100, "max_completion_tokens": 150},
  77. r"max_prompt_tokens",
  78. ),
  79. (
  80. {
  81. "validation_attempts": 2,
  82. "id": "abc123",
  83. "tool_resources": {"vs": 1},
  84. "file_ids": ["f1"],
  85. "file_search": True,
  86. "refresh_from_id": "old",
  87. "send_message_tool_class": object,
  88. },
  89. r"Deprecated Agent parameters are not supported",
  90. ),
  91. ]
  92. for kwargs, message in cases:
  93. with pytest.raises(TypeError, match=message):
  94. Agent(name="DeprecatedKwargsAgent", instructions="Test", **kwargs)
  95. def test_agent_initialization_output_type_variants():
  96. """Explicit output types should be preserved while omitted output_type stays None."""
  97. assert Agent(name="TaskAgent", instructions="Task agent", output_type=TaskOutput).output_type == TaskOutput
  98. assert Agent(name="SimpleAgent", instructions="Simple agent", output_type=SimpleOutput).output_type == SimpleOutput
  99. assert Agent(name="BasicAgent", instructions="Basic agent").output_type is None
  100. def test_agent_initialization_guardrail_flag_aliases_and_failures() -> None:
  101. """Guardrail aliases should map consistently and fail fast for conflicts/legacy kwargs."""
  102. canonical_agent = Agent(
  103. name="AliasCanonicalAgent",
  104. instructions="Test",
  105. raise_input_guardrail_error=True,
  106. )
  107. assert canonical_agent.raise_input_guardrail_error is True
  108. with pytest.warns(DeprecationWarning, match=r"throw_input_guardrail_error"):
  109. deprecated_alias_agent = Agent(
  110. name="AliasDeprecatedAgent",
  111. instructions="Test",
  112. throw_input_guardrail_error=True,
  113. )
  114. assert deprecated_alias_agent.raise_input_guardrail_error is True
  115. with pytest.warns(DeprecationWarning, match=r"throw_input_guardrail_error"):
  116. matching_alias_agent = Agent(
  117. name="AliasMatchAgent",
  118. instructions="Test",
  119. raise_input_guardrail_error=False,
  120. throw_input_guardrail_error=False,
  121. )
  122. assert matching_alias_agent.raise_input_guardrail_error is False
  123. with pytest.raises(TypeError, match=r"Conflicting values for `raise_input_guardrail_error`"):
  124. Agent(
  125. name="AliasConflictAgent",
  126. instructions="Test",
  127. raise_input_guardrail_error=True,
  128. throw_input_guardrail_error=False,
  129. )
  130. with pytest.raises(TypeError) as exc_info:
  131. Agent(
  132. name="LegacyGuardrailAgent",
  133. instructions="Test",
  134. return_input_guardrail_errors=False,
  135. )
  136. error_message = str(exc_info.value)
  137. assert "return_input_guardrail_errors" in error_message
  138. assert "raise_input_guardrail_error" in error_message
  139. agent = Agent(
  140. name="AliasPropertyAgent",
  141. instructions="Test",
  142. raise_input_guardrail_error=True,
  143. )
  144. assert agent.throw_input_guardrail_error is True
  145. agent.throw_input_guardrail_error = False
  146. assert agent.raise_input_guardrail_error is False
  147. def test_agent_initialization_support_flags_override_defaults() -> None:
  148. """Capability flags should persist when a plain Agent overrides them."""
  149. agent = Agent(
  150. name="RestrictedAgent",
  151. instructions="Test",
  152. supports_outbound_communication=False,
  153. supports_framework_tool_wiring=False,
  154. )
  155. assert agent.supports_outbound_communication is False
  156. assert agent.supports_framework_tool_wiring is False
  157. def test_agent_initialization_skips_framework_tool_wiring_when_disabled(tmp_path) -> None:
  158. """Framework-managed tool folders should be ignored when tool wiring is disabled."""
  159. tools_dir = tmp_path / "tools"
  160. tools_dir.mkdir()
  161. (tools_dir / "loaded_tool.py").write_text(
  162. "from agents import function_tool\n\n@function_tool\ndef loaded_tool() -> str:\n return 'loaded'\n",
  163. encoding="utf-8",
  164. )
  165. agent = Agent(
  166. name="RestrictedAgent",
  167. instructions="Test",
  168. tools_folder=str(tools_dir),
  169. supports_framework_tool_wiring=False,
  170. )
  171. assert agent.tools == []
  172. def test_agent_initialization_keeps_explicit_files_folder_when_framework_tool_wiring_disabled(tmp_path) -> None:
  173. """Explicit files_folder support should survive even when framework-managed tool wiring is disabled."""
  174. files_dir = tmp_path / "docs_vs_abcdefghijklmnop"
  175. files_dir.mkdir()
  176. agent = Agent(
  177. name="RestrictedAgent",
  178. instructions="Test",
  179. files_folder=str(files_dir),
  180. supports_framework_tool_wiring=False,
  181. )
  182. assert agent._associated_vector_store_id == "vs_abcdefghijklmnop"
  183. assert [tool.__class__.__name__ for tool in agent.tools] == ["FileSearchTool"]
  184. def test_agent_initialization_converts_explicit_mcp_servers_when_framework_tool_wiring_disabled(
  185. monkeypatch: pytest.MonkeyPatch,
  186. ) -> None:
  187. converted: list[str] = []
  188. def _convert(agent: Agent) -> None:
  189. converted.append(agent.name)
  190. monkeypatch.setattr("agency_swarm.agent.core.convert_mcp_servers_to_tools", _convert)
  191. Agent(
  192. name="RestrictedAgent",
  193. instructions="Test",
  194. mcp_servers=[SimpleNamespace(name="demo")],
  195. supports_framework_tool_wiring=False,
  196. )
  197. assert converted == ["RestrictedAgent"]
  198. def test_agent_initialization_with_all_parameters():
  199. """Test Agent initialization with all parameters including output_type."""
  200. tool1 = MagicMock(spec=FunctionTool)
  201. tool1.name = "tool1"
  202. # TEST-ONLY SETUP: Create test directory to enable FileSearchTool auto-addition
  203. import tempfile
  204. from pathlib import Path
  205. from unittest.mock import PropertyMock, patch
  206. # Create a temporary test directory
  207. with tempfile.TemporaryDirectory(prefix="test_files_") as temp_dir_str:
  208. temp_dir = Path(temp_dir_str)
  209. test_file = temp_dir / "test.txt"
  210. test_file.write_text("test content for FileSearchTool")
  211. # Mock the OpenAI client to avoid API key requirement
  212. mock_vector_store = MagicMock()
  213. mock_vector_store.id = "test_vs_id"
  214. mock_client = MagicMock()
  215. mock_client.vector_stores.create.return_value = mock_vector_store
  216. uploaded_file = MagicMock()
  217. uploaded_file.id = "file-1234567890abcdef"
  218. uploaded_file.created_at = 1_735_689_600
  219. mock_client.files.create.return_value = uploaded_file
  220. vs_file = MagicMock()
  221. vs_file.status = "completed"
  222. mock_client.vector_stores.files.retrieve.return_value = vs_file
  223. # Prevent infinite pagination when syncing vector store files during init
  224. list_resp = MagicMock()
  225. list_resp.data = []
  226. list_resp.has_more = False
  227. list_resp.last_id = None
  228. mock_client.vector_stores.files.list.return_value = list_resp
  229. with patch.object(Agent, "client_sync", new_callable=PropertyMock) as mock_client_sync:
  230. mock_client_sync.return_value = mock_client
  231. agent = Agent(
  232. name="CompleteAgent",
  233. instructions="Complete agent with all params",
  234. model="gpt-5.4-mini",
  235. tools=[tool1],
  236. output_type=TaskOutput,
  237. files_folder=str(temp_dir), # Use temporary directory
  238. description="A complete test agent",
  239. )
  240. assert agent.name == "CompleteAgent"
  241. assert agent.instructions == "Complete agent with all params"
  242. assert agent.model == "gpt-5.4-mini"
  243. assert len(agent.tools) == 2
  244. assert agent.tools[0] == tool1
  245. assert agent.tools[1].__class__.__name__ == "FileSearchTool"
  246. # response_validator is completely removed
  247. assert not hasattr(agent, "response_validator")
  248. assert agent.output_type == TaskOutput
  249. assert str(temp_dir) in str(agent.files_folder) # Should contain the temp directory path
  250. assert agent.description == "A complete test agent"
  251. # --- Instruction File Loading Tests ---
  252. def test_agent_instruction_loading_variants(tmp_path):
  253. """Instruction inputs should support file paths while preserving plain strings."""
  254. # Create instruction file for absolute path test
  255. instruction_file = tmp_path / "agent_instructions.md"
  256. instruction_content = "You are a helpful assistant. Always be polite."
  257. instruction_file.write_text(instruction_content)
  258. # Absolute path
  259. agent = Agent(name="TestAgent", instructions=str(instruction_file), model="gpt-5.4-mini")
  260. assert agent.instructions == instruction_content
  261. # Relative path resolved from caller directory
  262. relative_agent = Agent(name="TestAgent", instructions="../data/files/instructions.md", model="gpt-5.4-mini")
  263. assert relative_agent.instructions == "Test instructions"
  264. instruction_text = "Direct instruction text, not a file path"
  265. agent = Agent(name="TestAgent", instructions=instruction_text, model="gpt-5.4-mini")
  266. assert agent.instructions == instruction_text
  267. def test_agent_initialization_model_settings_defaults_and_overrides():
  268. """Initialization should keep SDK defaults and preserve explicit settings overrides."""
  269. default_agent = Agent(name="TruncDefault", instructions="Test")
  270. assert default_agent.model_settings.truncation == "auto"
  271. explicit_agent = Agent(
  272. name="TruncDisabled",
  273. instructions="Test",
  274. model_settings=ModelSettings(truncation="disabled"),
  275. )
  276. assert explicit_agent.model_settings.truncation == "disabled"
  277. gpt5_agent = Agent(name="Gpt5", instructions="Test", model="gpt-5.4-mini")
  278. assert gpt5_agent.model_settings.reasoning is not None
  279. assert gpt5_agent.model_settings.reasoning.effort == "none"
  280. provider_prefixed_gpt5_agent = Agent(name="ProviderPrefixedGpt5", instructions="Test", model="openai/gpt-5.4-mini")
  281. assert provider_prefixed_gpt5_agent.model_settings.reasoning is None
  282. assert provider_prefixed_gpt5_agent.model_settings.verbosity is None
  283. @pytest.mark.parametrize("provider_model", ["openai/gpt-5.4-mini", "azure/gpt-5.4-mini"])
  284. def test_agent_initialization_model_objects_use_openclaw_default_settings_alias(
  285. monkeypatch: pytest.MonkeyPatch,
  286. provider_model: str,
  287. ) -> None:
  288. monkeypatch.setenv("OPENCLAW_PROVIDER_MODEL", provider_model)
  289. model = build_openclaw_responses_model(base_url="http://127.0.0.1:18789/v1", api_key="test-key")
  290. agent = Agent(name="UsageTrackedModel", instructions="Test", model=model)
  291. assert agent.model_settings.reasoning is not None
  292. assert agent.model_settings.reasoning.effort == "none"
  293. assert agent.model_settings.verbosity == "low"
  294. def test_agent_initialization_model_objects_preserve_explicit_openclaw_alias_defaults() -> None:
  295. model = build_openclaw_responses_model(
  296. model="openclaw:custom",
  297. base_url="http://127.0.0.1:18789/v1",
  298. api_key="test-key",
  299. )
  300. agent = Agent(name="UsageTrackedModel", instructions="Test", model=model)
  301. assert agent.model_settings.reasoning is None
  302. assert agent.model_settings.verbosity is None
  303. def test_agent_initialization_adapts_basetool_type():
  304. """Passing a BaseTool subclass should be adapted to a FunctionTool."""
  305. from pydantic import Field
  306. from agency_swarm.tools import BaseTool
  307. class _T(BaseTool):
  308. x: str = Field(..., description="x")
  309. def run(self):
  310. return self.x
  311. agent = Agent(name="ToolsAdapt", instructions="Test", tools=[_T])
  312. # tools should be adapted to FunctionTool instances
  313. from agents import FunctionTool
  314. assert len(agent.tools) == 1
  315. assert isinstance(agent.tools[0], FunctionTool)
  316. def test_agent_initialization_web_search_source_include_behavior() -> None:
  317. """Web-search source include should support init and add_tool behavior with merge and opt-out."""
  318. cases: list[tuple[Agent, bool, int, str | None]] = [
  319. (
  320. Agent(name="WebAgentDefault", instructions="Test", tools=[WebSearchTool()]),
  321. True,
  322. 1,
  323. None,
  324. ),
  325. (
  326. Agent(
  327. name="WebAgentNoSources",
  328. instructions="Test",
  329. tools=[WebSearchTool()],
  330. include_web_search_sources=False,
  331. ),
  332. False,
  333. 0,
  334. None,
  335. ),
  336. (
  337. Agent(
  338. name="WebAgentMergeSources",
  339. instructions="Test",
  340. tools=[WebSearchTool()],
  341. model_settings=ModelSettings(response_include=["message.output_text.logprobs"]),
  342. ),
  343. True,
  344. 1,
  345. "message.output_text.logprobs",
  346. ),
  347. (
  348. Agent(
  349. name="WebAgentDedupSources",
  350. instructions="Test",
  351. tools=[WebSearchTool()],
  352. model_settings=ModelSettings(response_include=["web_search_call.action.sources"]),
  353. ),
  354. True,
  355. 1,
  356. None,
  357. ),
  358. ]
  359. for agent, has_sources, count, extra_include in cases:
  360. includes = agent.model_settings.response_include or []
  361. assert ("web_search_call.action.sources" in includes) is has_sources
  362. assert includes.count("web_search_call.action.sources") == count
  363. if extra_include is not None:
  364. assert extra_include in includes
  365. add_tool_default = Agent(name="WebAgentAddTool", instructions="Test")
  366. assert (add_tool_default.model_settings.response_include or []) == []
  367. add_tool_default.add_tool(WebSearchTool())
  368. assert "web_search_call.action.sources" in (add_tool_default.model_settings.response_include or [])
  369. add_tool_opt_out = Agent(
  370. name="WebAgentAddToolNoSources",
  371. instructions="Test",
  372. include_web_search_sources=False,
  373. )
  374. add_tool_opt_out.add_tool(WebSearchTool())
  375. assert "web_search_call.action.sources" not in (add_tool_opt_out.model_settings.response_include or [])