test_guardrail_validation.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from unittest.mock import AsyncMock, MagicMock, patch
  2. import pytest
  3. from agency_swarm import (
  4. GuardrailFunctionOutput,
  5. InputGuardrailTripwireTriggered,
  6. OutputGuardrailTripwireTriggered,
  7. ThreadManager,
  8. )
  9. from agency_swarm.agent.core import AgencyContext
  10. from agency_swarm.agent.execution_streaming import prune_guardrail_messages
  11. @pytest.mark.asyncio
  12. @patch("agents.Runner.run", new_callable=AsyncMock)
  13. async def test_output_guardrail_auto_retry(mock_runner_run, minimal_agent, mock_thread_manager):
  14. class _Out:
  15. output_info = "fix it"
  16. class _OutputGuardrailResult:
  17. output = _Out()
  18. guardrail = object()
  19. mock_runner_run.side_effect = [
  20. OutputGuardrailTripwireTriggered(_OutputGuardrailResult()),
  21. MagicMock(new_items=[], final_output="ok"),
  22. ]
  23. result = await minimal_agent.get_response("Task: Demo")
  24. assert result.final_output == "ok"
  25. assert mock_runner_run.call_count == 2
  26. second_input = mock_runner_run.call_args_list[1].kwargs["input"]
  27. assert second_input[-1]["content"] == "fix it"
  28. @pytest.mark.asyncio
  29. async def test_input_guardrail_no_retry_streaming(monkeypatch, minimal_agent):
  30. agent = minimal_agent
  31. # Ensure multiple attempts available to prove no retry happens
  32. agent.validation_attempts = 2
  33. agent.raise_input_guardrail_error = True
  34. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  35. calls = {"n": 0}
  36. def fake_run_streamed(**kwargs):
  37. calls["n"] += 1
  38. class _InRes:
  39. output = GuardrailFunctionOutput(
  40. output_info="Prefix your request with 'Task:'",
  41. tripwire_triggered=True,
  42. )
  43. guardrail = object()
  44. raise InputGuardrailTripwireTriggered(_InRes())
  45. monkeypatch.setattr(
  46. "agency_swarm.agent.execution_helpers.Runner.run_streamed",
  47. staticmethod(fake_run_streamed),
  48. )
  49. received: list[object] = []
  50. stream = agent.get_response_stream(message="Hello", agency_context=ctx)
  51. with pytest.raises(InputGuardrailTripwireTriggered):
  52. async for ev in stream:
  53. received.append(ev)
  54. # Should surface an error event and not retry
  55. assert any(isinstance(ev, dict) and ev.get("type") == "error" for ev in received)
  56. err = next(ev for ev in received if isinstance(ev, dict) and ev.get("type") == "error")
  57. assert "Task:" in err.get("content", "")
  58. assert calls["n"] == 1
  59. # Validate persisted guidance is marked as input_guardrail_error in streaming mode
  60. msgs = ctx.thread_manager.get_all_messages()
  61. sys_msgs = [m for m in msgs if m.get("role") == "system"]
  62. assert sys_msgs and sys_msgs[-1].get("message_origin") == "input_guardrail_error"
  63. @pytest.mark.asyncio
  64. @patch("agents.Runner.run", new_callable=AsyncMock)
  65. async def test_input_guardrail_returns_error_non_stream(mock_runner_run, minimal_agent, mock_thread_manager):
  66. agent = minimal_agent
  67. agent.raise_input_guardrail_error = False
  68. class _InRes:
  69. output = GuardrailFunctionOutput(
  70. output_info="Prefix your request with 'Task:'",
  71. tripwire_triggered=True,
  72. )
  73. guardrail = object()
  74. mock_runner_run.side_effect = InputGuardrailTripwireTriggered(_InRes())
  75. ctx = AgencyContext(agency_instance=None, thread_manager=mock_thread_manager, subagents={})
  76. res = await agent.get_response(message="Hello", agency_context=ctx)
  77. assert res.final_output == "Prefix your request with 'Task:'"
  78. msgs = ctx.thread_manager.get_all_messages()
  79. roles_contents = [(m.get("role"), m.get("content")) for m in msgs]
  80. assert ("assistant", "Prefix your request with 'Task:'") in roles_contents
  81. assistant_msgs = [m for m in msgs if m.get("role") == "assistant"]
  82. system_msgs = [m for m in msgs if m.get("role") == "system"]
  83. assert assistant_msgs and assistant_msgs[-1].get("message_origin") == "input_guardrail_message"
  84. assert not system_msgs
  85. @pytest.mark.asyncio
  86. @patch("agents.Runner.run", new_callable=AsyncMock)
  87. async def test_input_guardrail_error_no_assistant_messages(mock_runner_run, minimal_agent, mock_thread_manager):
  88. """When raise_input_guardrail_error=True, no assistant messages should persist."""
  89. agent = minimal_agent
  90. agent.raise_input_guardrail_error = True
  91. class _InRes:
  92. output = GuardrailFunctionOutput(
  93. output_info="Prefix your request with 'Task:'",
  94. tripwire_triggered=True,
  95. )
  96. guardrail = object()
  97. mock_runner_run.side_effect = InputGuardrailTripwireTriggered(_InRes())
  98. ctx = AgencyContext(agency_instance=None, thread_manager=mock_thread_manager, subagents={})
  99. with pytest.raises(InputGuardrailTripwireTriggered):
  100. await agent.get_response(message="Hello", agency_context=ctx)
  101. msgs = ctx.thread_manager.get_all_messages()
  102. # Should have exactly 2 messages: user input + system guardrail error
  103. assert len(msgs) == 2, f"Expected 2 messages (user + guardrail), got {len(msgs)}: {msgs}"
  104. # First message: user input
  105. assert msgs[0].get("role") == "user"
  106. assert msgs[0].get("content") == "Hello"
  107. # Second message: system guardrail error (not message)
  108. assert msgs[1].get("role") == "system"
  109. assert "Prefix your request with 'Task:'" in msgs[1].get("content", "")
  110. assert msgs[1].get("message_origin") == "input_guardrail_error"
  111. # Critical: NO assistant messages should be present
  112. assistant_msgs = [m for m in msgs if m.get("role") == "assistant"]
  113. assert len(assistant_msgs) == 0, f"Expected no assistant messages, but found {len(assistant_msgs)}"
  114. def test_prune_guardrail_messages_drops_subagent_history():
  115. """Input guardrail guidance must remove downstream agent chatter from history."""
  116. run_trace_id = "trace_123"
  117. messages = [
  118. {
  119. "role": "user",
  120. "content": "What is your support email address?",
  121. "agent": "CustomerSupportAgent",
  122. "callerAgent": None,
  123. "agent_run_id": "agent_run_parent",
  124. "run_trace_id": run_trace_id,
  125. },
  126. {
  127. "role": "user",
  128. "content": "Please provide the support email address.",
  129. "agent": "DatabaseAgent",
  130. "callerAgent": "CustomerSupportAgent",
  131. "agent_run_id": "agent_run_database",
  132. "run_trace_id": run_trace_id,
  133. },
  134. {
  135. "role": "user",
  136. "content": "Please provide the support email address.",
  137. "agent": "EmailAgent",
  138. "callerAgent": "DatabaseAgent",
  139. "agent_run_id": "agent_run_email",
  140. "run_trace_id": run_trace_id,
  141. },
  142. {
  143. "role": "system",
  144. "content": "Please, prefix your request with 'Support:'.",
  145. "message_origin": "input_guardrail_error",
  146. "agent": "EmailAgent",
  147. "callerAgent": "DatabaseAgent",
  148. "agent_run_id": "agent_run_email",
  149. "run_trace_id": run_trace_id,
  150. },
  151. {
  152. "role": "system",
  153. "content": "When chatting with this agent, provide your name (Alice).",
  154. "message_origin": "input_guardrail_error",
  155. "agent": "DatabaseAgent",
  156. "callerAgent": "CustomerSupportAgent",
  157. "agent_run_id": "agent_run_database",
  158. "run_trace_id": run_trace_id,
  159. },
  160. {
  161. "role": "assistant",
  162. "content": "Please, prefix your request with 'Support:' describing what you need.",
  163. "message_origin": "input_guardrail_message",
  164. "agent": "CustomerSupportAgent",
  165. "callerAgent": None,
  166. "agent_run_id": "agent_run_parent",
  167. "run_trace_id": run_trace_id,
  168. },
  169. ]
  170. cleaned = prune_guardrail_messages(
  171. messages,
  172. initial_saved_count=0,
  173. run_trace_id=run_trace_id,
  174. collapse_to_root=True,
  175. )
  176. assert cleaned == [messages[0], messages[-1]], cleaned
  177. assert all(msg.get("agent") == "CustomerSupportAgent" for msg in cleaned)
  178. assert [msg.get("role") for msg in cleaned] == ["user", "assistant"]
  179. @pytest.mark.asyncio
  180. async def test_input_guardrail_streaming_strict_prunes_and_raises(monkeypatch, minimal_agent):
  181. agent = minimal_agent
  182. agent.raise_input_guardrail_error = True
  183. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  184. pruned_history = [{"role": "assistant", "content": "kept"}]
  185. observed_calls: list[dict[str, object]] = []
  186. def fake_prune(messages, *, initial_saved_count, run_trace_id, collapse_to_root):
  187. observed_calls.append(
  188. {
  189. "initial_saved_count": initial_saved_count,
  190. "run_trace_id": run_trace_id,
  191. "collapse_to_root": collapse_to_root,
  192. "message_count": len(messages),
  193. }
  194. )
  195. return list(pruned_history)
  196. monkeypatch.setattr("agency_swarm.agent.execution_streaming.prune_guardrail_messages", fake_prune)
  197. class _GuardrailResult:
  198. def __init__(self):
  199. self.output = GuardrailFunctionOutput(
  200. output_info="Please, prefix your request with 'Support:' describing what you need.",
  201. tripwire_triggered=True,
  202. )
  203. self.guardrail = object()
  204. class _FakeRunResult:
  205. def __init__(self):
  206. self.guardrail_result = _GuardrailResult()
  207. self.input_guardrail_results = [self.guardrail_result]
  208. self.new_items = []
  209. self.raw_responses = []
  210. self.final_output = ""
  211. async def stream_events(self):
  212. raise InputGuardrailTripwireTriggered(self.guardrail_result)
  213. yield # pragma: no cover
  214. def cancel(self):
  215. return None
  216. monkeypatch.setattr("agents.Runner.run_streamed", staticmethod(lambda **_: _FakeRunResult()))
  217. stream = agent.get_response_stream(message="Hello", agency_context=ctx)
  218. with pytest.raises(InputGuardrailTripwireTriggered):
  219. async for _ in stream:
  220. pass
  221. assert observed_calls, "prune_guardrail_messages was not invoked"
  222. assert observed_calls[-1]["collapse_to_root"] is True
  223. assert ctx.thread_manager.get_all_messages() == pruned_history
  224. @pytest.mark.asyncio
  225. async def test_input_guardrail_streaming_friendly_prunes_and_streams(monkeypatch, minimal_agent):
  226. agent = minimal_agent
  227. agent.raise_input_guardrail_error = False
  228. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  229. pruned_history = [{"role": "assistant", "content": "friendly"}]
  230. prune_invocations = 0
  231. def fake_prune(messages, *, initial_saved_count, run_trace_id, collapse_to_root):
  232. nonlocal prune_invocations
  233. prune_invocations += 1
  234. return list(pruned_history)
  235. monkeypatch.setattr("agency_swarm.agent.execution_streaming.prune_guardrail_messages", fake_prune)
  236. class _GuardrailResult:
  237. def __init__(self):
  238. self.output = GuardrailFunctionOutput(
  239. output_info="Only support questions are allowed.",
  240. tripwire_triggered=True,
  241. )
  242. self.guardrail = object()
  243. class _FakeRunResult:
  244. def __init__(self):
  245. self.guardrail_result = _GuardrailResult()
  246. self.input_guardrail_results = [self.guardrail_result]
  247. self.new_items = []
  248. self.raw_responses = []
  249. self.final_output = ""
  250. async def stream_events(self):
  251. raise InputGuardrailTripwireTriggered(self.guardrail_result)
  252. yield # pragma: no cover
  253. def cancel(self):
  254. return None
  255. monkeypatch.setattr("agents.Runner.run_streamed", staticmethod(lambda **_: _FakeRunResult()))
  256. stream = agent.get_response_stream(message="Need help", agency_context=ctx)
  257. events: list[object] = []
  258. async for ev in stream:
  259. events.append(ev)
  260. assert prune_invocations == 1
  261. assert ctx.thread_manager.get_all_messages() == pruned_history
  262. assert any(getattr(ev, "name", None) == "message_output_created" for ev in events)