# test_conversation_starters_cache.py

  1. import asyncio
  2. import time
  3. from pathlib import Path
  4. import pytest
  5. from agents.exceptions import AgentsException
  6. from agency_swarm import Agency, Agent, function_tool
  7. from agency_swarm.agent.conversation_starters_cache import extract_final_output_text, load_cached_starter
  8. from tests.deterministic_model import DeterministicModel
  9. async def _wait_for_cache_files(cache_dir: Path, expected: int) -> list[Path]:
  10. deadline = time.monotonic() + 30.0
  11. while time.monotonic() < deadline:
  12. cache_files = list(cache_dir.glob("*.json"))
  13. if len(cache_files) >= expected:
  14. return cache_files
  15. await asyncio.sleep(0.1)
  16. return list(cache_dir.glob("*.json"))
  17. @pytest.mark.asyncio
  18. async def test_conversation_starter_cache_reuse_without_llm(tmp_path, monkeypatch):
  19. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  20. starter = "What is the weather in London?"
  21. model = DeterministicModel(default_response="Cached starter answer.")
  22. agent = Agent(
  23. name="StarterAgent",
  24. instructions="You are helpful.",
  25. model=model,
  26. conversation_starters=[starter],
  27. cache_conversation_starters=True,
  28. )
  29. agency = Agency(agent)
  30. assert agency.thread_manager.get_all_messages() == []
  31. cache_dir = Path(tmp_path) / "starter_cache"
  32. cache_files = await _wait_for_cache_files(cache_dir, 1)
  33. assert len(cache_files) == 1
  34. cached = load_cached_starter(
  35. agent.name,
  36. starter,
  37. expected_fingerprint=agent._conversation_starters_fingerprint,
  38. )
  39. assert cached is not None
  40. expected_text = extract_final_output_text(cached.items)
  41. assert expected_text
  42. monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
  43. result = await agency.get_response(starter)
  44. assert result.final_output == expected_text
  45. assert len(agency.thread_manager.get_all_messages()) >= 2
  46. model._default_response = "Live follow-up answer."
  47. follow_up = await agency.get_response(starter)
  48. assert follow_up.final_output == "Live follow-up answer."
  49. agent_cached = Agent(
  50. name="StarterAgent",
  51. instructions="You are helpful.",
  52. model="gpt-5.4-mini",
  53. conversation_starters=[starter],
  54. cache_conversation_starters=True,
  55. )
  56. Agency(agent_cached)
  57. @pytest.mark.asyncio
  58. async def test_quick_reply_cache_reuse_without_model_call(tmp_path, monkeypatch):
  59. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  60. quick_reply = "hi"
  61. model = DeterministicModel(default_response="Hello there.")
  62. agent = Agent(
  63. name="QuickReplyAgent",
  64. instructions="You are helpful.",
  65. model=model,
  66. quick_replies=[quick_reply],
  67. cache_conversation_starters=True,
  68. )
  69. agency = Agency(agent)
  70. cache_dir = Path(tmp_path) / "starter_cache"
  71. cache_files = await _wait_for_cache_files(cache_dir, 1)
  72. assert len(cache_files) == 1
  73. cached = load_cached_starter(
  74. agent.name,
  75. quick_reply,
  76. expected_fingerprint=agent._conversation_starters_fingerprint,
  77. )
  78. assert cached is not None
  79. expected_text = extract_final_output_text(cached.items)
  80. assert expected_text
  81. async def _fail_get_response(*_args, **_kwargs):
  82. raise RuntimeError("model should not be called for cached quick reply")
  83. monkeypatch.setattr(model, "get_response", _fail_get_response)
  84. result = await agency.get_response(quick_reply)
  85. assert result.final_output == expected_text
  86. with pytest.raises(AgentsException, match="Runner execution failed for agent QuickReplyAgent"):
  87. await agency.get_response(quick_reply)
  88. @function_tool
  89. def get_weather(location: str) -> str:
  90. return f"The weather in {location} is sunny, 22°C with light winds."
  91. @pytest.mark.asyncio
  92. async def test_conversation_starter_cache_reuse_stream_without_llm(tmp_path, monkeypatch):
  93. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  94. starter = "What is the weather in London?"
  95. model = DeterministicModel(default_response="Cached starter answer.")
  96. agent = Agent(
  97. name="StarterAgent",
  98. instructions="You are helpful.",
  99. model=model,
  100. conversation_starters=[starter],
  101. cache_conversation_starters=True,
  102. )
  103. agency = Agency(agent)
  104. cache_dir = Path(tmp_path) / "starter_cache"
  105. cache_files = await _wait_for_cache_files(cache_dir, 1)
  106. assert len(cache_files) == 1
  107. cached = load_cached_starter(
  108. agent.name,
  109. starter,
  110. expected_fingerprint=agent._conversation_starters_fingerprint,
  111. )
  112. assert cached is not None
  113. expected_text = extract_final_output_text(cached.items)
  114. assert expected_text
  115. monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
  116. stream = agency.get_response_stream(starter)
  117. async for _event in stream:
  118. pass
  119. final_result = stream.final_result
  120. assert final_result is not None
  121. assert final_result.final_output == expected_text
  122. @pytest.mark.asyncio
  123. async def test_conversation_starter_cache_skips_with_context_override(tmp_path, monkeypatch):
  124. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  125. starter = "What is the weather in London?"
  126. model = DeterministicModel(default_response="Cached starter answer.")
  127. agent = Agent(
  128. name="StarterAgent",
  129. instructions="You are helpful.",
  130. model=model,
  131. conversation_starters=[starter],
  132. cache_conversation_starters=True,
  133. )
  134. agency = Agency(agent)
  135. cache_dir = Path(tmp_path) / "starter_cache"
  136. cache_files = await _wait_for_cache_files(cache_dir, 1)
  137. assert len(cache_files) == 1
  138. cached = load_cached_starter(
  139. agent.name,
  140. starter,
  141. expected_fingerprint=agent._conversation_starters_fingerprint,
  142. )
  143. assert cached is not None
  144. expected_text = extract_final_output_text(cached.items)
  145. assert expected_text == "Cached starter answer."
  146. model._default_response = "Context override answer."
  147. result = await agency.get_response(starter, context_override={"user_id": "abc"})
  148. assert result.final_output == "Context override answer."
  149. @pytest.mark.asyncio
  150. async def test_conversation_starter_cache_skips_stream_with_context_override(tmp_path, monkeypatch):
  151. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  152. starter = "What is the weather in London?"
  153. model = DeterministicModel(default_response="Cached starter answer.")
  154. agent = Agent(
  155. name="StarterAgent",
  156. instructions="You are helpful.",
  157. model=model,
  158. conversation_starters=[starter],
  159. cache_conversation_starters=True,
  160. )
  161. agency = Agency(agent)
  162. cache_dir = Path(tmp_path) / "starter_cache"
  163. cache_files = await _wait_for_cache_files(cache_dir, 1)
  164. assert len(cache_files) == 1
  165. cached = load_cached_starter(
  166. agent.name,
  167. starter,
  168. expected_fingerprint=agent._conversation_starters_fingerprint,
  169. )
  170. assert cached is not None
  171. expected_text = extract_final_output_text(cached.items)
  172. assert expected_text == "Cached starter answer."
  173. model._default_response = "Context override answer."
  174. stream = agency.get_response_stream(starter, context_override={"user_id": "abc"})
  175. async for _event in stream:
  176. pass
  177. assert stream.final_output == "Context override answer."
  178. @pytest.mark.asyncio
  179. async def test_conversation_starter_cache_skips_on_shared_instructions_change(tmp_path, monkeypatch):
  180. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  181. starter = "What is the weather in London?"
  182. agent = Agent(
  183. name="StarterAgent",
  184. instructions="You are helpful.",
  185. model="gpt-5.4-mini",
  186. conversation_starters=[starter],
  187. cache_conversation_starters=True,
  188. )
  189. agency = Agency(agent, shared_instructions="Respond with ALPHA.")
  190. cache_dir = Path(tmp_path) / "starter_cache"
  191. cache_files = await _wait_for_cache_files(cache_dir, 1)
  192. assert len(cache_files) == 1
  193. monkeypatch.setenv("OPENAI_API_KEY", "sk-invalid")
  194. agency.shared_instructions = "Respond with BRAVO."
  195. with pytest.raises(AgentsException):
  196. await agency.get_response(starter)
  197. @pytest.mark.asyncio
  198. async def test_conversation_starter_cache_populates_for_agency_tools(tmp_path, monkeypatch):
  199. monkeypatch.setenv("AGENCY_SWARM_CHATS_DIR", str(tmp_path))
  200. starters = ["What is the weather in London?"]
  201. ceo = Agent(
  202. name="CEO",
  203. instructions="Always use send_message to ask the Worker for weather.",
  204. model="gpt-5.4-mini",
  205. conversation_starters=starters,
  206. cache_conversation_starters=True,
  207. )
  208. worker = Agent(
  209. name="Worker",
  210. instructions="Provide weather using get_weather.",
  211. tools=[get_weather],
  212. model="gpt-5.4-mini",
  213. )
  214. Agency(ceo, communication_flows=[(ceo > worker)], name="TerminalDemoAgency")
  215. cache_dir = Path(tmp_path) / "starter_cache"
  216. cache_files = sorted(await _wait_for_cache_files(cache_dir, len(starters)))
  217. assert len(cache_files) == len(starters)
  218. cached = load_cached_starter(
  219. ceo.name,
  220. starters[0],
  221. expected_fingerprint=ceo._conversation_starters_fingerprint,
  222. )
  223. assert cached is not None
  224. items = cached.items
  225. tool_call_index = next(
  226. idx
  227. for idx, item in enumerate(items)
  228. if isinstance(item, dict) and item.get("type") == "function_call" and item.get("agent") == ceo.name
  229. )
  230. worker_message_index = next(
  231. idx
  232. for idx, item in enumerate(items)
  233. if isinstance(item, dict) and item.get("type") == "message" and item.get("agent") == worker.name
  234. )
  235. assert tool_call_index < worker_message_index