| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- from unittest.mock import AsyncMock, MagicMock, patch
- import pytest
- from agents import RunErrorDetails
- from agency_swarm import Agent, GuardrailFunctionOutput, OutputGuardrailTripwireTriggered, ThreadManager
- from agency_swarm.agent.core import AgencyContext
def _make_tripwire(
    agent_output: str,
    guidance: str,
    *,
    include_run_data: bool = True,
) -> OutputGuardrailTripwireTriggered:
    """Build an ``OutputGuardrailTripwireTriggered`` to use as a mock side effect.

    Args:
        agent_output: Text stored on the guardrail result as ``agent_output`` and,
            when run data is attached, used as the content of the single
            assistant run item.
        guidance: Text placed in the guardrail output's ``output_info``.
        include_run_data: When true, a ``RunErrorDetails`` is assigned to the
            exception's ``run_data`` attribute; when false that assignment is skipped.

    Returns:
        The constructed exception instance (it is returned, not raised).
    """

    class _MockRunItem:
        # Minimal run item: carries a role/content pair and can render itself
        # as an input-item dict.
        def __init__(self, role: str, content: str):
            self.role = role
            self.content = content

        def to_input_item(self):
            return {"role": self.role, "content": self.content}

    class _GuardrailObj:
        # Attribute-free placeholder for the guardrail that tripped.
        pass

    # Assemble an ad-hoc result type whose class attributes carry the payload.
    result_attrs = {
        "agent_output": agent_output,
        "output": GuardrailFunctionOutput(output_info=guidance, tripwire_triggered=True),
        "guardrail": _GuardrailObj(),
    }
    result_type = type("_OutputGuardrailResult", (), result_attrs)
    tripwire = OutputGuardrailTripwireTriggered(result_type())

    if not include_run_data:
        return tripwire

    # run_data is what _extract_guardrail_texts relies on.
    assistant_item = _MockRunItem("assistant", agent_output)
    tripwire.run_data = RunErrorDetails(
        input=[],
        new_items=[assistant_item],
        raw_responses=[],
        last_agent=None,
        context_wrapper=None,
        input_guardrail_results=[],
        output_guardrail_results=[],
    )
    return tripwire
@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
async def test_output_guardrail_retries_update_history(mock_runner_run):
    """Non-streaming retry records the rejected output and guidance in history.

    The first mocked ``Runner.run`` call raises an output-guardrail tripwire and
    the second returns a stand-in result. The test checks the final output and
    that the thread history holds the user message, the rejected assistant
    output, and the guidance as a ``system`` message tagged
    ``output_guardrail_error``.
    """
    retry_agent = Agent(name="RetryAgent", instructions="Test", validation_attempts=1)

    # Minimal agency context; its thread manager captures the messages.
    context = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    # Attempt one trips the guardrail; attempt two yields a RunResult-like mock.
    tripped = _make_tripwire(agent_output="BAD OUTPUT", guidance="ERROR: fix format")
    succeeded = MagicMock(new_items=[], final_output="GOOD")
    mock_runner_run.side_effect = [tripped, succeeded]

    response = await retry_agent.get_response(message="What is openai?", agency_context=context)
    assert getattr(response, "final_output", None) == "GOOD"

    # History must contain the user prompt, the rejected assistant output,
    # and the guardrail guidance recorded under the system role.
    history = context.thread_manager.get_all_messages()
    role_content_pairs = [(msg.get("role"), msg.get("content")) for msg in history]
    for expected_pair in (
        ("user", "What is openai?"),
        ("assistant", "BAD OUTPUT"),
        ("system", "ERROR: fix format"),
    ):
        assert expected_pair in role_content_pairs

    # The most recent system message should be classified as an output guardrail error.
    system_messages = [msg for msg in history if msg.get("role") == "system"]
    assert system_messages
    assert system_messages[-1].get("message_origin") == "output_guardrail_error"
@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run", new_callable=AsyncMock)
async def test_output_guardrail_retries_without_run_data(mock_runner_run):
    """Retry still updates history when the tripwire carries no ``run_data``.

    The first mocked ``Runner.run`` call raises a tripwire built with
    ``include_run_data=False``; the second returns a stand-in result. The test
    checks the final output and that both the rejected assistant output and the
    guidance (as a ``system`` message) appear in the thread history.
    """
    retry_agent = Agent(name="RetryAgentNoRunData", instructions="Test", validation_attempts=1)
    context = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    tripped = _make_tripwire(agent_output="MALFORMED", guidance="Provide JSON", include_run_data=False)
    recovered = MagicMock(new_items=[], final_output="RECOVERED")
    mock_runner_run.side_effect = [tripped, recovered]

    response = await retry_agent.get_response(message="Fix this", agency_context=context)
    assert getattr(response, "final_output", None) == "RECOVERED"

    role_content_pairs = [
        (msg.get("role"), msg.get("content")) for msg in context.thread_manager.get_all_messages()
    ]
    for expected_pair in (("assistant", "MALFORMED"), ("system", "Provide JSON")):
        assert expected_pair in role_content_pairs
class _DummyStream:
    """Stand-in for a streamed run result that replays a fixed sequence of events."""

    def __init__(self, events):
        # Stored as given; replayed in order by stream_events().
        self._events = events

    async def stream_events(self):
        """Asynchronously yield each stored event in its original order."""
        for event in self._events:
            yield event

    def cancel(self):
        """Do nothing; present so callers can invoke ``cancel()`` on the stream."""
        return None
class _SimpleEvent:
    """Bare event object exposing only a ``type`` attribute."""

    def __init__(self, t: str) -> None:
        # The constructor argument becomes the event's type tag unchanged.
        self.type: str = t
@pytest.mark.asyncio
@patch("agency_swarm.agent.execution_helpers.Runner.run_streamed")
async def test_output_guardrail_retries_streaming(mock_run_streamed):
    """Streaming retry yields events from the second attempt and records guidance.

    The first mocked ``Runner.run_streamed`` call raises an output-guardrail
    tripwire; the second returns a dummy stream holding one event. The test
    checks that at least one event is received and that the guidance appears in
    the thread history as a ``system`` message.
    """
    stream_agent = Agent(name="RetryStreamAgent", instructions="Test", validation_attempts=1)
    context = AgencyContext(agency_instance=None, thread_manager=ThreadManager(), subagents={})

    # First call raises the tripwire; second hands back a one-event dummy stream.
    tripped = _make_tripwire(agent_output="STREAM BAD", guidance="ERROR: needs header")
    replay = _DummyStream([_SimpleEvent("run_item_stream_event")])
    mock_run_streamed.side_effect = [tripped, replay]

    # Drain the response stream into a list.
    received = [
        event async for event in stream_agent.get_response_stream(message="Hello", agency_context=context)
    ]
    assert received, "expected events from second attempt"

    # The guidance should be present in history under the system role.
    role_content_pairs = [
        (msg.get("role"), msg.get("content")) for msg in context.thread_manager.get_all_messages()
    ]
    assert ("system", "ERROR: needs header") in role_content_pairs
|