| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- import asyncio
- import time
- from pathlib import Path
- import pytest
- from agents.exceptions import AgentsException
- from agency_swarm import Agency, Agent, function_tool
- from agency_swarm.agent.conversation_starters_cache import extract_final_output_text, load_cached_starter
- from tests.deterministic_model import DeterministicModel
- async def _wait_for_cache_files(cache_dir: Path, expected: int) -> list[Path]:
- deadline = time.monotonic() + 30.0
- while time.monotonic() < deadline:
- cache_files = list(cache_dir.glob("*.json"))
- if len(cache_files) >= expected:
- return cache_files
- await asyncio.sleep(0.1)
- return list(cache_dir.glob("*.json"))
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_reuse_without_llm(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starter = "What is the weather in London?"
- model = DeterministicModel(default_response="Cached starter answer.")
- agent = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model=model,
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- agency = Agency(agent)
- assert agency.thread_manager.get_all_messages() == []
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- cached = load_cached_starter(
- agent.name,
- starter,
- expected_fingerprint=agent._conversation_starters_fingerprint,
- )
- assert cached is not None
- expected_text = extract_final_output_text(cached.items)
- assert expected_text
- monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
- result = await agency.get_response(starter)
- assert result.final_output == expected_text
- assert len(agency.thread_manager.get_all_messages()) >= 2
- model._default_response = "Live follow-up answer."
- follow_up = await agency.get_response(starter)
- assert follow_up.final_output == "Live follow-up answer."
- agent_cached = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model="gpt-5.4-mini",
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- Agency(agent_cached)
- @pytest.mark.asyncio
- async def test_quick_reply_cache_reuse_without_model_call(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- quick_reply = "hi"
- model = DeterministicModel(default_response="Hello there.")
- agent = Agent(
- name="QuickReplyAgent",
- instructions="You are helpful.",
- model=model,
- quick_replies=[quick_reply],
- cache_conversation_starters=True,
- )
- agency = Agency(agent)
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- cached = load_cached_starter(
- agent.name,
- quick_reply,
- expected_fingerprint=agent._conversation_starters_fingerprint,
- )
- assert cached is not None
- expected_text = extract_final_output_text(cached.items)
- assert expected_text
- async def _fail_get_response(*_args, **_kwargs):
- raise RuntimeError("model should not be called for cached quick reply")
- monkeypatch.setattr(model, "get_response", _fail_get_response)
- result = await agency.get_response(quick_reply)
- assert result.final_output == expected_text
- with pytest.raises(AgentsException, match="Runner execution failed for agent QuickReplyAgent"):
- await agency.get_response(quick_reply)
- @function_tool
- def get_weather(location: str) -> str:
- return f"The weather in {location} is sunny, 22°C with light winds."
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_reuse_stream_without_llm(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starter = "What is the weather in London?"
- model = DeterministicModel(default_response="Cached starter answer.")
- agent = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model=model,
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- agency = Agency(agent)
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- cached = load_cached_starter(
- agent.name,
- starter,
- expected_fingerprint=agent._conversation_starters_fingerprint,
- )
- assert cached is not None
- expected_text = extract_final_output_text(cached.items)
- assert expected_text
- monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
- stream = agency.get_response_stream(starter)
- async for _event in stream:
- pass
- final_result = stream.final_result
- assert final_result is not None
- assert final_result.final_output == expected_text
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_skips_with_context_override(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starter = "What is the weather in London?"
- model = DeterministicModel(default_response="Cached starter answer.")
- agent = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model=model,
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- agency = Agency(agent)
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- cached = load_cached_starter(
- agent.name,
- starter,
- expected_fingerprint=agent._conversation_starters_fingerprint,
- )
- assert cached is not None
- expected_text = extract_final_output_text(cached.items)
- assert expected_text == "Cached starter answer."
- model._default_response = "Context override answer."
- result = await agency.get_response(starter, context_override={"user_id": "abc"})
- assert result.final_output == "Context override answer."
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_skips_stream_with_context_override(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starter = "What is the weather in London?"
- model = DeterministicModel(default_response="Cached starter answer.")
- agent = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model=model,
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- agency = Agency(agent)
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- cached = load_cached_starter(
- agent.name,
- starter,
- expected_fingerprint=agent._conversation_starters_fingerprint,
- )
- assert cached is not None
- expected_text = extract_final_output_text(cached.items)
- assert expected_text == "Cached starter answer."
- model._default_response = "Context override answer."
- stream = agency.get_response_stream(starter, context_override={"user_id": "abc"})
- async for _event in stream:
- pass
- assert stream.final_output == "Context override answer."
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_skips_on_shared_instructions_change(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starter = "What is the weather in London?"
- agent = Agent(
- name="StarterAgent",
- instructions="You are helpful.",
- model="gpt-5.4-mini",
- conversation_starters=[starter],
- cache_conversation_starters=True,
- )
- agency = Agency(agent, shared_instructions="Respond with ALPHA.")
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = await _wait_for_cache_files(cache_dir, 1)
- assert len(cache_files) == 1
- monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
- agency.shared_instructions = "Respond with BRAVO."
- with pytest.raises(AgentsException):
- await agency.get_response(starter)
- @pytest.mark.asyncio
- async def test_conversation_starter_cache_populates_for_agency_tools(tmp_path, monkeypatch):
- monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
- starters = ["What is the weather in London?"]
- ceo = Agent(
- name="CEO",
- instructions="Always use send_message to ask the Worker for weather.",
- model="gpt-5.4-mini",
- conversation_starters=starters,
- cache_conversation_starters=True,
- )
- worker = Agent(
- name="Worker",
- instructions="Provide weather using get_weather.",
- tools=[get_weather],
- model="gpt-5.4-mini",
- )
- Agency(ceo, communication_flows=[(ceo > worker)], name="TerminalDemoAgency")
- cache_dir = Path(tmp_path) / "starter_cache"
- cache_files = sorted(await _wait_for_cache_files(cache_dir, len(starters)))
- assert len(cache_files) == len(starters)
- cached = load_cached_starter(
- ceo.name,
- starters[0],
- expected_fingerprint=ceo._conversation_starters_fingerprint,
- )
- assert cached is not None
- items = cached.items
- tool_call_index = next(
- idx
- for idx, item in enumerate(items)
- if isinstance(item, dict) and item.get("type") == "function_call" and item.get("agent") == ceo.name
- )
- worker_message_index = next(
- idx
- for idx, item in enumerate(items)
- if isinstance(item, dict) and item.get("type") == "message" and item.get("agent") == worker.name
- )
- assert tool_call_index < worker_message_index
|