test_output_guardrail_retries.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from unittest.mock import AsyncMock, MagicMock, patch
  2. import pytest
  3. from agents import RunErrorDetails
  4. from agency_swarm import Agent, GuardrailFunctionOutput, OutputGuardrailTripwireTriggered, ThreadManager
  5. from agency_swarm.agent.core import AgencyContext
  6. def _make_tripwire(
  7. agent_output: str,
  8. guidance: str,
  9. *,
  10. include_run_data: bool = True,
  11. ) -> OutputGuardrailTripwireTriggered:
  12. class _GuardrailObj:
  13. pass
  14. class _MockRunItem:
  15. def __init__(self, role: str, content: str):
  16. self.role = role
  17. self.content = content
  18. def to_input_item(self):
  19. return {"role": self.role, "content": self.content}
  20. guardrail_result = type(
  21. "_OutputGuardrailResult",
  22. (),
  23. {
  24. "agent_output": agent_output,
  25. "output": GuardrailFunctionOutput(output_info=guidance, tripwire_triggered=True),
  26. "guardrail": _GuardrailObj(),
  27. },
  28. )()
  29. # Create the exception with the guardrail_result
  30. exception = OutputGuardrailTripwireTriggered(guardrail_result)
  31. # Set the run_data on the exception - needed by _extract_guardrail_texts
  32. if include_run_data:
  33. exception.run_data = RunErrorDetails(
  34. input=[],
  35. new_items=[_MockRunItem("assistant", agent_output)],
  36. raw_responses=[],
  37. last_agent=None,
  38. context_wrapper=None,
  39. input_guardrail_results=[],
  40. output_guardrail_results=[],
  41. )
  42. return exception
  43. @pytest.mark.asyncio
  44. @patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
  45. async def test_output_guardrail_retries_update_history(mock_runner_run):
  46. agent = Agent(name="RetryAgent", instructions="Test", validation_attempts=1)
  47. # Prepare minimal agency context to capture messages
  48. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  49. # First attempt trips, second returns a minimal RunResult-like object
  50. mock_runner_run.side_effect = [
  51. _make_tripwire(agent_output="BAD OUTPUT", guidance="ERROR: fix format"),
  52. MagicMock(new_items=[], final_output="GOOD"),
  53. ]
  54. # Execute
  55. res = await agent.get_response(message="What is openai?", agency_context=ctx)
  56. assert getattr(res, "final_output", None) == "GOOD"
  57. # Validate conversation history contains initial user, appended assistant, appended user guidance
  58. all_msgs = ctx.thread_manager.get_all_messages()
  59. # Extract role and content for clarity
  60. trio = [(m.get("role"), m.get("content")) for m in all_msgs]
  61. # Expect at least 3 messages; find the last three
  62. assert ("user", "What is openai?") in trio
  63. assert ("assistant", "BAD OUTPUT") in trio
  64. assert ("system", "ERROR: fix format") in trio
  65. # The guidance system message should be classified as an output guardrail error
  66. sys_msgs = [m for m in all_msgs if m.get("role") == "system"]
  67. assert sys_msgs and sys_msgs[-1].get("message_origin") == "output_guardrail_error"
  68. @pytest.mark.asyncio
  69. @patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
  70. async def test_output_guardrail_retries_without_run_data(mock_runner_run):
  71. agent = Agent(name="RetryAgentNoRunData", instructions="Test", validation_attempts=1)
  72. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  73. mock_runner_run.side_effect = [
  74. _make_tripwire(agent_output="MALFORMED", guidance="Provide JSON", include_run_data=False),
  75. MagicMock(new_items=[], final_output="RECOVERED"),
  76. ]
  77. result = await agent.get_response(message="Fix this", agency_context=ctx)
  78. assert getattr(result, "final_output", None) == "RECOVERED"
  79. history = ctx.thread_manager.get_all_messages()
  80. contents = [(m.get("role"), m.get("content")) for m in history]
  81. assert ("assistant", "MALFORMED") in contents
  82. assert ("system", "Provide JSON") in contents
  83. class _DummyStream:
  84. def __init__(self, events):
  85. self._events = events
  86. async def stream_events(self):
  87. for ev in self._events:
  88. yield ev
  89. def cancel(self):
  90. pass
  91. class _SimpleEvent:
  92. def __init__(self, t: str):
  93. self.type = t
  94. @pytest.mark.asyncio
  95. @patch("agency_swarm.agent.execution_helpers.Runner.run_streamed")
  96. async def test_output_guardrail_retries_streaming(mock_run_streamed):
  97. agent = Agent(name="RetryStreamAgent", instructions="Test", validation_attempts=1)
  98. ctx = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})
  99. # First call raises; second returns a dummy stream with one event
  100. mock_run_streamed.side_effect = [
  101. _make_tripwire(agent_output="STREAM BAD", guidance="ERROR: needs header"),
  102. _DummyStream([_SimpleEvent("run_item_stream_event")]),
  103. ]
  104. # Collect streamed events
  105. received = []
  106. async for ev in agent.get_response_stream(message="Hello", agency_context=ctx):
  107. received.append(ev)
  108. assert received, "expected events from second attempt"
  109. # The guidance user message should be in history
  110. msgs = ctx.thread_manager.get_all_messages()
  111. roles_contents = [(m.get("role"), m.get("content")) for m in msgs]
  112. assert ("system", "ERROR: needs header") in roles_contents