| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606 |
- """
Deterministic streaming-order tests for multi-agent workflows with custom tools.
- """
- import logging
- import os
- from typing import Any
- import pytest
- from agents import ModelSettings, function_tool
- from agents.models.fake_id import FAKE_RESPONSES_ID
- from openai.types.shared import Reasoning
- from agency_swarm import Agency, Agent
- logger = logging.getLogger(__name__)
- def _assert_sanitized_history(messages: list[dict[str, Any]]) -> None:
- """Validate persisted conversation order matches sanitized tool semantics."""
- seen_ids: set[str] = set()
- for index, message in enumerate(messages):
- msg_type = message.get("type")
- msg_id = message.get("id")
- if isinstance(msg_id, str) and msg_id and msg_id != FAKE_RESPONSES_ID:
- assert msg_id not in seen_ids, f"Duplicate message id detected: {msg_id}"
- seen_ids.add(msg_id)
- if msg_type != "function_call":
- continue
- call_id = message.get("call_id")
- assert isinstance(call_id, str) and call_id, f"Missing call_id for function_call at index {index}"
- output_index = None
- for candidate in range(index + 1, len(messages)):
- if (
- messages[candidate].get("type") == "function_call_output"
- and messages[candidate].get("call_id") == call_id
- ):
- output_index = candidate
- break
- assert output_index is not None, f"No function_call_output found for call_id {call_id}"
- between = messages[index + 1 : output_index]
- assistants = [m for m in between if m.get("role") == "assistant"]
- assert not assistants, (
- f"Intermediate assistant message found between function_call and output for call_id {call_id}: {assistants}"
- )
- def _strip_optional_initial_message_output(
- flow: list[tuple[str, str, str | None]],
- agent_name: str,
- ) -> list[tuple[str, str, str | None]]:
- """Allow optional initial agent message_output_item after first tool_call."""
- if len(flow) >= 2 and flow[1] == ("message_output_item", agent_name, None):
- return [flow[0], *flow[2:]]
- return flow
def _normalize_optional_agent_message_outputs(
    flow: list[tuple[str, str, str | None]],
    agent_name: str,
) -> list[tuple[str, str, str | None]]:
    """Remove optional ``agent_name`` message items so tool order can be compared.

    First removes an optional ``message_output_item`` at index 1 (via
    ``_strip_optional_initial_message_output``). Then, if the flow ends with two
    consecutive ``("message_output_item", agent_name, None)`` entries, the last
    one is dropped so a single final message remains.
    """
    closing_message = ("message_output_item", agent_name, None)
    stripped = _strip_optional_initial_message_output(flow, agent_name)
    ends_with_duplicate = len(stripped) >= 2 and stripped[-2:] == [closing_message] * 2
    return stripped[:-1] if ends_with_duplicate else stripped
- # Additional tools for complex scenarios
@function_tool
def process_data(data: str) -> str:
    # Test tool: echoes the input back with a "PROCESSED:" prefix.
    # NOTE(review): intentionally left without a docstring - function_tool
    # presumably derives the tool description from it, which would change what
    # the model sees; confirm against the SDK before adding one.
    processed = f"PROCESSED:{data}"
    return processed
@function_tool
def validate_result(result: str) -> str:
    # Test tool: echoes the input back with a "VALID:" prefix.
    # NOTE(review): intentionally left without a docstring - function_tool
    # presumably derives the tool description from it, which would change what
    # the model sees; confirm against the SDK before adding one.
    validated = f"VALID:{result}"
    return validated
@function_tool
def combine_results(results: str) -> str:
    # Test tool: echoes the input back with a "COMBINED:" prefix.
    # NOTE(review): intentionally left without a docstring - function_tool
    # presumably derives the tool description from it, which would change what
    # the model sees; confirm against the SDK before adding one.
    combined = f"COMBINED:{results}"
    return combined
# Hardcoded expected flow for the default (non-Anthropic) parametrization of
# test_full_streaming_flow_hardcoded_sequence.
#
# Each entry is a (normalized stream item type, agent name, tool name) triple.
# The tool name is only populated for "tool_call_item" entries; it is None for
# every other item type.
#
# Starting with openai-agents 0.2.10, tool calls are emitted as soon as the
# model finalizes the tool call item (via ResponseOutputItemDoneEvent), so the
# semantic `tool_call_item` arrives before the agent's own message output.
# Preserve the deterministic order we now observe so that the tests confirm the
# integration keeps in step with SDK streaming semantics.
EXPECTED_FLOW_DEFAULT: list[tuple[str, str, str | None]] = [
    ("tool_call_item", "MainAgent", "get_market_data"),
    ("tool_call_output_item", "MainAgent", None),
    ("tool_call_item", "MainAgent", "send_message"),  # MainAgent delegates to SubAgent
    ("tool_call_item", "SubAgent", "analyze_risk"),
    ("tool_call_output_item", "SubAgent", None),
    ("message_output_item", "SubAgent", None),  # SubAgent's reply
    ("tool_call_output_item", "MainAgent", None),  # output of the send_message call
    ("message_output_item", "MainAgent", None),  # final response
]
# Model identifier handed to LitellmModel in the Anthropic parametrization of
# test_full_streaming_flow_hardcoded_sequence.
ANTHROPIC_MODEL_NAME = "anthropic/claude-sonnet-4-20250514"

# Expected (stream item type, agent name, tool name) triples for the Anthropic
# parametrization. Compared with EXPECTED_FLOW_DEFAULT, this list additionally
# contains message_output_item entries right after each tool_call_item.
# The test also accepts this flow without the MainAgent message at index 1.
EXPECTED_FLOW_ANTHROPIC: list[tuple[str, str, str | None]] = [
    ("tool_call_item", "MainAgent", "get_market_data"),
    ("message_output_item", "MainAgent", None),  # optional: may be absent from the stream
    ("tool_call_output_item", "MainAgent", None),
    ("tool_call_item", "MainAgent", "send_message"),
    ("tool_call_item", "SubAgent", "analyze_risk"),
    ("message_output_item", "SubAgent", None),
    ("tool_call_output_item", "SubAgent", None),
    ("message_output_item", "SubAgent", None),
    ("message_output_item", "MainAgent", None),
    ("tool_call_output_item", "MainAgent", None),
    ("message_output_item", "MainAgent", None),  # final response
]
@function_tool
def get_market_data(symbol: str) -> str:
    # Test tool: returns the same canned quote regardless of `symbol`.
    # NOTE(review): intentionally left without a docstring - function_tool
    # presumably derives the tool description from it, which would change what
    # the model sees; confirm against the SDK before adding one.
    canned_quote = "AAPL:PRICE=150"
    return canned_quote
@function_tool
def analyze_risk(data: str) -> str:
    # Test tool: returns the same canned verdict regardless of `data`.
    # NOTE(review): intentionally left without a docstring - function_tool
    # presumably derives the tool description from it, which would change what
    # the model sees; confirm against the SDK before adding one.
    canned_verdict = "RISK=LOW"
    return canned_verdict
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("use_anthropic", "expected_flow"),
    [
        (False, EXPECTED_FLOW_DEFAULT),
        pytest.param(
            True,
            EXPECTED_FLOW_ANTHROPIC,
            marks=pytest.mark.skipif(
                not os.getenv("ANTHROPIC_API_KEY"),
                reason="ANTHROPIC_API_KEY required for Anthropic test",
            ),
        ),
    ],
)
async def test_full_streaming_flow_hardcoded_sequence(
    use_anthropic: bool, expected_flow: list[tuple[str, str, str | None]]
) -> None:
    """Check the streamed item order of a MainAgent -> SubAgent run against a fixed flow.

    The run is streamed, every non-reasoning item is recorded as a
    ``(item type, agent name, tool name)`` triple, and the recorded list must
    equal ``expected_flow`` either as given or with the optional MainAgent
    message at index 1 removed. The messages persisted during the run are then
    checked for well-formed tool-call pairing and for the three expected tool
    calls.
    """
    if not use_anthropic:
        main_model = "gpt-5.4-mini"
        helper_model = "gpt-5.4-mini"
        # Force the first tool call of each agent and disable parallel tool calls.
        main_model_settings = ModelSettings(
            reasoning=Reasoning(effort="low"),
            tool_choice="get_market_data",
            parallel_tool_calls=False,
        )
        helper_model_settings = ModelSettings(
            reasoning=Reasoning(effort="low"),
            tool_choice="analyze_risk",
            parallel_tool_calls=False,
        )
        main_instructions = (
            "Complete the workflow in this exact order. "
            "First call get_market_data with symbol 'AAPL'. Do not send assistant text before this tool call. "
            "After get_market_data returns, call send_message to ask SubAgent to analyze the returned market data. "
            "Do not send assistant text between get_market_data and send_message. "
            "After SubAgent replies, send one brief final conclusion to the user."
        )
        user_message = (
            "Run the streaming-order proof now: call get_market_data for AAPL, then send_message to SubAgent "
            "with the returned market data, then provide the final conclusion after SubAgent responds."
        )
    else:
        pytest.importorskip("litellm", reason="litellm package is required for Anthropic test")
        import litellm
        from agents.extensions.models.litellm_model import LitellmModel

        litellm.modify_params = True
        main_model = LitellmModel(model=ANTHROPIC_MODEL_NAME)
        helper_model = LitellmModel(model=ANTHROPIC_MODEL_NAME)
        main_model_settings = None
        helper_model_settings = None
        main_instructions = (
            "First send a standalone 'ACK' message before any tool calls. "
            "Then call get_market_data('AAPL'). "
            "Then use the send_message tool to ask SubAgent to analyze the data and reply. "
            "Finally, respond to the user with a brief conclusion."
        )
        user_message = "Start."

    helper_instructions = (
        "When prompted by MainAgent, first call analyze_risk on the provided data. "
        "After analyze_risk returns, reply succinctly."
    )
    main = Agent(
        name="MainAgent",
        description="Coordinator",
        instructions=main_instructions,
        model=main_model,
        model_settings=main_model_settings,
        tools=[get_market_data],
    )
    helper = Agent(
        name="SubAgent",
        description="Risk analyzer",
        instructions=helper_instructions,
        model=helper_model,
        model_settings=helper_model_settings,
        tools=[analyze_risk],
    )
    agency = Agency(
        main,
        communication_flows=[main > helper],
        shared_instructions="",
    )

    # Remember how many messages exist so only this run's messages are inspected.
    before = len(agency.thread_manager.get_all_messages())

    # Collect stream as (type, agent, tool_name), ignoring reasoning items.
    stream_items: list[tuple[str, str, str | None]] = []
    async for event in agency.get_response_stream(message=user_message):
        item = getattr(event, "item", None)
        if item is None:
            continue
        item_type = getattr(item, "type", None)
        if item_type == "reasoning_item":
            continue
        agent_name = getattr(event, "agent", None)
        if not (isinstance(item_type, str) and isinstance(agent_name, str)):
            continue
        tool_name = None
        if item_type == "tool_call_item":
            tool_name = getattr(getattr(item, "raw_item", None), "name", None)
        stream_items.append((item_type, agent_name, tool_name))

    new_messages = agency.thread_manager.get_all_messages()[before:]

    # Keep only tool calls, tool outputs and assistant messages for the history check.
    comparable: list[dict[str, Any]] = [
        saved
        for saved in new_messages
        if saved.get("type") in {"function_call", "function_call_output"} or saved.get("role") == "assistant"
    ]

    expected_without_main_message = _strip_optional_initial_message_output(expected_flow, "MainAgent")
    assert stream_items in (expected_flow, expected_without_main_message), (
        "Stream flow mismatch:\n"
        f" got={stream_items}\n"
        f" exp={expected_flow}\n"
        f" exp_without_initial_message={expected_without_main_message}"
    )

    _assert_sanitized_history(comparable)
    for recording_agent, recorded_tool in (
        ("MainAgent", "get_market_data"),
        ("MainAgent", "send_message"),
        ("SubAgent", "analyze_risk"),
    ):
        _assert_tool_call_recorded(
            new_messages, recording_agent, recorded_tool, context="default streaming workflow"
        )
# Expected flow for multiple sequential sub-agent calls, as
# (stream item type, agent name, tool name) triples.
#
# The Coordinator's instructions ask for an initial 'ACK', but
# test_multiple_sequential_subagent_calls normalizes the observed stream with
# _normalize_optional_agent_message_outputs before comparing, so an optional
# Coordinator message right after the first tool call is not listed here.
EXPECTED_FLOW_MULTIPLE_CALLS: list[tuple[str, str, str | None]] = [
    # The flow starts directly with the Coordinator's first tool call
    ("tool_call_item", "Coordinator", "get_market_data"),  # First data fetch
    ("tool_call_output_item", "Coordinator", None),
    # First sub-agent call
    ("tool_call_item", "Coordinator", "send_message"),  # Coordinator messages Worker
    ("tool_call_item", "Worker", "process_data"),  # Worker processes
    ("tool_call_output_item", "Worker", None),
    ("message_output_item", "Worker", None),  # Worker responds
    ("tool_call_output_item", "Coordinator", None),  # send_message completes
    # Second sub-agent call
    ("tool_call_item", "Coordinator", "send_message"),  # Coordinator messages Worker again
    ("tool_call_item", "Worker", "validate_result"),  # Worker validates
    ("tool_call_output_item", "Worker", None),
    ("message_output_item", "Worker", None),  # Worker responds again
    ("tool_call_output_item", "Coordinator", None),  # send_message completes
    ("message_output_item", "Coordinator", None),  # Final response
]
@pytest.mark.asyncio
async def test_multiple_sequential_subagent_calls() -> None:
    """Check the streamed item order when Coordinator messages Worker twice in a row.

    The run is streamed, every non-reasoning item is recorded as a
    ``(item type, agent name, tool name)`` triple, optional Coordinator message
    items are normalized away, and the result must equal
    ``EXPECTED_FLOW_MULTIPLE_CALLS``. The messages persisted during the run are
    then checked for well-formed tool-call pairing.
    """
    coordinator_instructions = (
        "First say 'ACK'. Then call get_market_data('TEST'). "
        "Then use send_message to ask Worker to process the data. "
        "After Worker responds, use send_message again to ask Worker to validate the result. "
        "Finally, respond with 'DONE'."
    )
    worker_instructions = (
        "When asked to process: use process_data tool and respond 'Processed'. "
        "When asked to validate: use validate_result tool and respond 'Validated'."
    )
    coordinator = Agent(
        name="Coordinator",
        description="Main coordinator",
        instructions=coordinator_instructions,
        model_settings=ModelSettings(temperature=0.0),
        tools=[get_market_data],
    )
    worker = Agent(
        name="Worker",
        description="Data processor",
        instructions=worker_instructions,
        model_settings=ModelSettings(temperature=0.0),
        tools=[process_data, validate_result],
    )
    agency = Agency(
        coordinator,
        communication_flows=[coordinator > worker],
        shared_instructions="",
    )

    # Remember how many messages exist so only this run's messages are inspected.
    before = len(agency.thread_manager.get_all_messages())

    # Collect stream events as (type, agent, tool_name), ignoring reasoning items.
    stream_items: list[tuple[str, str, str | None]] = []
    async for event in agency.get_response_stream(message="Execute multiple tasks."):
        item = getattr(event, "item", None)
        if item is None:
            continue
        item_type = getattr(item, "type", None)
        if item_type == "reasoning_item":
            continue
        agent_name = getattr(event, "agent", None)
        if not (isinstance(item_type, str) and isinstance(agent_name, str)):
            continue
        tool_name = None
        if item_type == "tool_call_item":
            tool_name = getattr(getattr(item, "raw_item", None), "name", None)
        stream_items.append((item_type, agent_name, tool_name))

    # Verify stream matches expected (optional Coordinator message items are tolerated).
    normalized = _normalize_optional_agent_message_outputs(stream_items, "Coordinator")
    assert normalized == EXPECTED_FLOW_MULTIPLE_CALLS, (
        f"Multiple calls stream mismatch:\n got={stream_items}\n exp={EXPECTED_FLOW_MULTIPLE_CALLS}"
    )

    # Verify saved messages: keep only tool calls, tool outputs and assistant messages.
    new_messages = agency.thread_manager.get_all_messages()[before:]
    comparable: list[dict[str, Any]] = [
        saved
        for saved in new_messages
        if saved.get("type") in {"function_call", "function_call_output"} or saved.get("role") == "assistant"
    ]
    _assert_sanitized_history(comparable)
# Expected flow for nested delegation (A->B->C) based on actual execution, as
# (stream item type, agent name, tool name) triples.
#
# NOTE(review): nothing in the visible part of this file references this
# constant - test_nested_delegation_streaming checks a required subsequence
# instead. That subsequence attributes analyze_risk to "AgentC", whereas this
# list attributes it to "AgentB"; confirm which attribution is current before
# reusing this list.
EXPECTED_FLOW_NESTED: list[tuple[str, str, str | None]] = [
    ("message_output_item", "AgentA", None),
    ("tool_call_item", "AgentA", "send_message"),  # A delegates to B
    ("tool_call_item", "AgentB", "send_message"),  # B delegates to C
    ("tool_call_item", "AgentB", "analyze_risk"),  # C's tool call, listed here under AgentB
    ("tool_call_output_item", "AgentB", None),
    ("message_output_item", "AgentB", None),
    ("tool_call_output_item", "AgentB", None),
    ("tool_call_item", "AgentB", "process_data"),  # B processes
    ("tool_call_output_item", "AgentB", None),
    ("message_output_item", "AgentB", None),
    ("tool_call_output_item", "AgentA", None),
    ("message_output_item", "AgentA", None),  # Final response
]
@pytest.mark.asyncio
async def test_nested_delegation_streaming() -> None:
    """Proves nested A→B→C delegation appears in stream and AgentA completes after sub-chain.

    The run is streamed and every non-reasoning item is recorded as a
    ``(item type, agent name, tool name)`` triple. The recorded list must
    contain, in order (other items may be interleaved): AgentA's send_message,
    AgentB's send_message, AgentC's analyze_risk, AgentA's tool output and
    AgentA's message. The messages persisted during the run are then checked
    for well-formed tool-call pairing.
    """
    agent_a = Agent(
        name="AgentA",
        description="Top-level coordinator",
        instructions=(
            "First say 'ACK'. "
            "Then use send_message to ask AgentB to process and analyze data. "
            "Finally respond with 'Complete'."
        ),
        model="gpt-5.4-mini",
        tools=[],
    )
    agent_b = Agent(
        name="AgentB",
        description="Middle processor",
        instructions=(
            "When asked by AgentA: "
            "First use send_message to ask AgentC to analyze risk. "
            "Then use process_data tool with the response. "
            "Finally respond 'Processed'."
        ),
        model="gpt-5.4-mini",
        model_settings=ModelSettings(tool_choice="required"),
        tools=[process_data],
    )
    agent_c = Agent(
        name="AgentC",
        description="Risk analyzer",
        instructions="When asked: use analyze_risk tool and respond 'Risk analyzed'.",
        model="gpt-5.4-mini",
        model_settings=ModelSettings(tool_choice="required"),
        tools=[analyze_risk],
    )
    agency = Agency(
        agent_a,
        communication_flows=[agent_a > agent_b, agent_b > agent_c],
        shared_instructions="",
    )

    # Remember how many messages exist so only this run's messages are inspected.
    before = len(agency.thread_manager.get_all_messages())

    # Collect stream events as (type, agent, tool_name), ignoring reasoning items.
    stream_items: list[tuple[str, str, str | None]] = []
    async for event in agency.get_response_stream(message="Start nested delegation."):
        item = getattr(event, "item", None)
        if item is None:
            continue
        evt_type = getattr(item, "type", None)
        if evt_type == "reasoning_item":
            continue
        agent_name = getattr(event, "agent", None)
        tool_name = None
        if evt_type == "tool_call_item":
            tool_name = getattr(getattr(item, "raw_item", None), "name", None)
        if isinstance(evt_type, str) and isinstance(agent_name, str):
            stream_items.append((evt_type, agent_name, tool_name))

    # Verify stream contains the required sequence in order and AgentC performs analyze_risk
    required_seq = [
        ("tool_call_item", "AgentA", "send_message"),
        ("tool_call_item", "AgentB", "send_message"),
        ("tool_call_item", "AgentC", "analyze_risk"),
        ("tool_call_output_item", "AgentA", None),
        ("message_output_item", "AgentA", None),
    ]

    def is_subsequence(needles: list[tuple[str, str, str | None]], haystack: list[tuple[str, str, str | None]]) -> bool:
        # `needle in remaining` advances the shared iterator past the match, so
        # each needle has to be found after the previous one.
        remaining = iter(haystack)
        return all(needle in remaining for needle in needles)

    assert is_subsequence(required_seq, stream_items), (
        f"Nested delegation stream mismatch (required subsequence not found):\n got={stream_items}\n req={required_seq}"
    )

    # Verify saved messages: keep only tool calls, tool outputs and assistant messages.
    new_messages = agency.thread_manager.get_all_messages()[before:]
    comparable: list[dict[str, Any]] = [
        saved
        for saved in new_messages
        if saved.get("type") in {"function_call", "function_call_output"} or saved.get("role") == "assistant"
    ]
    _assert_sanitized_history(comparable)
- # Helper to confirm specific tool calls were persisted for an agent
- def _assert_tool_call_recorded(
- messages: list[dict[str, Any]], agent_name: str, tool_name: str, *, context: str
- ) -> None:
- for message in messages:
- if message.get("type") != "function_call":
- continue
- if message.get("name") != tool_name:
- continue
- recorded_agent = message.get("agent") or message.get("callerAgent")
- if str(recorded_agent) == agent_name:
- return
- raise AssertionError(f"Expected {context}: agent '{agent_name}' did not record function_call '{tool_name}'")
# Expected flow for sub-agent calls to two different agents, as
# (stream item type, agent name, tool name) triples.
# Despite the "parallel" name, the two send_message calls are listed one after
# the other: the Orchestrator is instructed to wait for ProcessorA's response
# before messaging ProcessorB.
# NOTE: No ACK message expected - we don't instruct the agent to emit one,
# keeping the expected flow strict and deterministic.
EXPECTED_FLOW_PARALLEL: list[tuple[str, str, str | None]] = [
    ("tool_call_item", "Orchestrator", "get_market_data"),  # Initial data fetch is the first streamed item
    ("tool_call_output_item", "Orchestrator", None),
    ("tool_call_item", "Orchestrator", "send_message"),  # First sub-agent call
    ("tool_call_item", "ProcessorA", "process_data"),  # ProcessorA works
    ("tool_call_output_item", "ProcessorA", None),
    ("message_output_item", "ProcessorA", None),
    ("tool_call_output_item", "Orchestrator", None),  # First send_message completes
    ("tool_call_item", "Orchestrator", "send_message"),  # Second sub-agent call
    ("tool_call_item", "ProcessorB", "validate_result"),  # ProcessorB works
    ("tool_call_output_item", "ProcessorB", None),
    ("message_output_item", "ProcessorB", None),
    ("tool_call_output_item", "Orchestrator", None),  # Second send_message completes
    ("tool_call_item", "Orchestrator", "combine_results"),
    ("tool_call_output_item", "Orchestrator", None),
    ("message_output_item", "Orchestrator", None),  # Final response
]
@pytest.mark.asyncio
async def test_parallel_subagent_calls() -> None:
    """Check the streamed item order when Orchestrator messages two different sub-agents.

    The run is streamed, every non-reasoning item is recorded as a
    ``(item type, agent name, tool name)`` triple, optional Orchestrator message
    items are normalized away, and the result must equal
    ``EXPECTED_FLOW_PARALLEL``. The messages persisted during the run must
    contain each processor's tool call and keep tool calls well-formed.
    """
    processor_settings = {"temperature": 0.0, "tool_choice": "required"}
    orchestrator = Agent(
        name="Orchestrator",
        description="Main orchestrator",
        instructions=(
            "Call get_market_data('DATA'). "
            "Then use send_message to ask ProcessorA to process the data. "
            "After ProcessorA responds, use send_message to ask ProcessorB to validate. "
            "Finally, use combine_results tool and respond 'All done'."
        ),
        model_settings=ModelSettings(temperature=0.0),
        tools=[get_market_data, combine_results],
    )
    processor_a = Agent(
        name="ProcessorA",
        description="Data processor",
        instructions="When asked: use process_data tool and respond 'ProcessorA complete'.",
        model_settings=ModelSettings(**processor_settings),
        tools=[process_data],
    )
    processor_b = Agent(
        name="ProcessorB",
        description="Result validator",
        instructions="When asked: use validate_result tool and respond 'ProcessorB complete'.",
        model_settings=ModelSettings(**processor_settings),
        tools=[validate_result],
    )
    agency = Agency(
        orchestrator,
        communication_flows=[orchestrator > processor_a, orchestrator > processor_b],
        shared_instructions="",
    )

    # Remember how many messages exist so only this run's messages are inspected.
    before = len(agency.thread_manager.get_all_messages())

    # Collect stream events as (type, agent, tool_name), ignoring reasoning items.
    stream_items: list[tuple[str, str, str | None]] = []
    async for event in agency.get_response_stream(message="Coordinate parallel work."):
        item = getattr(event, "item", None)
        if item is None:
            continue
        item_type = getattr(item, "type", None)
        if item_type == "reasoning_item":
            continue
        agent_name = getattr(event, "agent", None)
        if not (isinstance(item_type, str) and isinstance(agent_name, str)):
            continue
        tool_name = None
        if item_type == "tool_call_item":
            tool_name = getattr(getattr(item, "raw_item", None), "name", None)
        stream_items.append((item_type, agent_name, tool_name))

    # Verify stream matches expected (optional Orchestrator message items are tolerated).
    normalized = _normalize_optional_agent_message_outputs(stream_items, "Orchestrator")
    if normalized != EXPECTED_FLOW_PARALLEL:
        # Log the raw stream before the assertion below fails.
        logger.error(
            "Parallel sub-agent stream mismatch",
            extra={
                "got": stream_items,
                "expected": EXPECTED_FLOW_PARALLEL,
            },
        )
    assert normalized == EXPECTED_FLOW_PARALLEL, (
        f"Parallel calls stream mismatch:\n got={stream_items}\n exp={EXPECTED_FLOW_PARALLEL}"
    )

    # Verify saved messages: keep only tool calls, tool outputs and assistant messages.
    new_messages = agency.thread_manager.get_all_messages()[before:]
    comparable: list[dict[str, Any]] = [
        saved
        for saved in new_messages
        if saved.get("type") in {"function_call", "function_call_output"} or saved.get("role") == "assistant"
    ]
    for recording_agent, recorded_tool in (
        ("ProcessorA", "process_data"),
        ("ProcessorB", "validate_result"),
    ):
        _assert_tool_call_recorded(new_messages, recording_agent, recorded_tool, context="parallel workflow")
    _assert_sanitized_history(comparable)
|