test_agency_responses_recipient_reminders.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. import pytest
  2. from agency_swarm import Agency
  3. from tests.test_agency_modules._response_test_helpers import CapturingAgent
  4. @pytest.mark.asyncio
  5. async def test_agency_get_response_adds_recipient_switch_reminder_after_handoff() -> None:
  6. """Adds recipient_reminder when previous user call used handoff and target agent changed."""
  7. agent_a = CapturingAgent("AgentA")
  8. agent_b = CapturingAgent("AgentB")
  9. agency = Agency(agent_a, agent_b)
  10. agency.thread_manager.add_messages(
  11. [
  12. {"role": "user", "content": "previous", "agent": "AgentB", "callerAgent": None},
  13. {
  14. "role": "system",
  15. "content": "Transfer completed. You are AgentB. Please continue the task.",
  16. "agent": "AgentB",
  17. "callerAgent": None,
  18. "message_origin": "handoff_reminder",
  19. },
  20. {"role": "assistant", "content": "done", "agent": "AgentB", "callerAgent": None},
  21. ]
  22. )
  23. await agency.get_response("new request", recipient_agent="AgentA")
  24. assert isinstance(agent_a.last_message, list)
  25. assert agent_a.last_message[0]["message_origin"] == "recipient_reminder"
  26. assert (
  27. agent_a.last_message[0]["content"]
  28. == 'User has switched recipient agent. You are "AgentA". Please continue the task.'
  29. )
  30. assert agent_a.last_message[1] == {"role": "user", "content": "new request"}
  31. @pytest.mark.asyncio
  32. async def test_agency_get_response_skips_recipient_switch_reminder_without_switch_or_handoff() -> None:
  33. """Does not add recipient_reminder unless both handoff-use and recipient-switch are true."""
  34. agent_a = CapturingAgent("AgentA")
  35. agent_b = CapturingAgent("AgentB")
  36. agency = Agency(agent_a, agent_b)
  37. agency.thread_manager.add_messages(
  38. [
  39. {"role": "user", "content": "previous", "agent": "AgentA", "callerAgent": None},
  40. {"role": "assistant", "content": "done", "agent": "AgentA", "callerAgent": None},
  41. ]
  42. )
  43. await agency.get_response("new request", recipient_agent="AgentB")
  44. assert isinstance(agent_b.last_message, str)
  45. agency.thread_manager.clear()
  46. agency.thread_manager.add_messages(
  47. [
  48. {"role": "user", "content": "previous", "agent": "AgentA", "callerAgent": None},
  49. {
  50. "role": "system",
  51. "content": "Transfer completed. You are AgentA. Please continue the task.",
  52. "agent": "AgentA",
  53. "callerAgent": None,
  54. "message_origin": "handoff_reminder",
  55. },
  56. {"role": "assistant", "content": "done", "agent": "AgentA", "callerAgent": None},
  57. ]
  58. )
  59. await agency.get_response("same target", recipient_agent="AgentA")
  60. assert isinstance(agent_a.last_message, str)
  61. @pytest.mark.asyncio
  62. async def test_agency_get_response_adds_reminder_after_repeated_manual_switches() -> None:
  63. """Refreshes the active control reminder on each manual recipient switch."""
  64. agent_a = CapturingAgent("AgentA")
  65. agent_b = CapturingAgent("AgentB")
  66. agent_c = CapturingAgent("AgentC")
  67. agency = Agency(agent_a, agent_b, agent_c)
  68. agency.thread_manager.add_messages(
  69. [
  70. {"role": "user", "content": "previous", "agent": "AgentB", "callerAgent": None},
  71. {
  72. "role": "system",
  73. "content": "Transfer completed. You are AgentB. Please continue the task.",
  74. "agent": "AgentB",
  75. "callerAgent": None,
  76. "message_origin": "handoff_reminder",
  77. },
  78. {"role": "assistant", "content": "done", "agent": "AgentB", "callerAgent": None},
  79. ]
  80. )
  81. await agency.get_response("switch to AgentA", recipient_agent="AgentA")
  82. assert isinstance(agent_a.last_message, list)
  83. assert agent_a.last_message[0]["message_origin"] == "recipient_reminder"
  84. await agency.get_response("switch to AgentC", recipient_agent="AgentC")
  85. assert isinstance(agent_c.last_message, list)
  86. assert agent_c.last_message[0]["message_origin"] == "recipient_reminder"
  87. assert (
  88. agent_c.last_message[0]["content"]
  89. == 'User has switched recipient agent. You are "AgentC". Please continue the task.'
  90. )
  91. assert agent_c.last_message[-1] == {"role": "user", "content": "switch to AgentC"}
  92. @pytest.mark.asyncio
  93. async def test_agency_get_response_adds_reminder_after_structured_switch_turn() -> None:
  94. """Structured user inputs should keep reminder chaining on later switches."""
  95. agent_a = CapturingAgent("AgentA")
  96. agent_b = CapturingAgent("AgentB")
  97. agent_c = CapturingAgent("AgentC")
  98. agency = Agency(agent_a, agent_b, agent_c)
  99. agency.thread_manager.add_messages(
  100. [
  101. {"role": "user", "content": "previous", "agent": "AgentB", "callerAgent": None},
  102. {
  103. "role": "system",
  104. "content": "Transfer completed. You are AgentB. Please continue the task.",
  105. "agent": "AgentB",
  106. "callerAgent": None,
  107. "message_origin": "handoff_reminder",
  108. },
  109. {"role": "assistant", "content": "done", "agent": "AgentB", "callerAgent": None},
  110. ]
  111. )
  112. await agency.get_response(
  113. [
  114. {"role": "user", "content": "switch to AgentA"},
  115. {"role": "user", "content": "keep the same structured payload"},
  116. ],
  117. recipient_agent="AgentA",
  118. )
  119. await agency.get_response("switch to AgentC", recipient_agent="AgentC")
  120. assert isinstance(agent_c.last_message, list)
  121. assert agent_c.last_message[0]["message_origin"] == "recipient_reminder"
  122. assert agent_c.last_message[-1] == {"role": "user", "content": "switch to AgentC"}
  123. @pytest.mark.asyncio
  124. async def test_agency_get_response_adds_reminder_after_split_run_handoff_turn() -> None:
  125. """Split top-level run ids should still preserve handoff reminder chaining."""
  126. agent_a = CapturingAgent("AgentA")
  127. agent_b = CapturingAgent("AgentB")
  128. agent_c = CapturingAgent("AgentC")
  129. agency = Agency(agent_a, agent_b, agent_c)
  130. agency.thread_manager.add_messages(
  131. [
  132. {
  133. "role": "user",
  134. "content": "previous",
  135. "agent": "AgentA",
  136. "callerAgent": None,
  137. "agent_run_id": "top-run",
  138. },
  139. {
  140. "role": "system",
  141. "content": "Transfer completed. You are AgentB. Please continue the task.",
  142. "agent": "AgentA",
  143. "callerAgent": None,
  144. "agent_run_id": "top-run",
  145. "message_origin": "handoff_reminder",
  146. },
  147. {
  148. "role": "assistant",
  149. "content": "done",
  150. "agent": "AgentB",
  151. "callerAgent": None,
  152. "agent_run_id": "handoff-run",
  153. "parent_run_id": "top-run",
  154. },
  155. ]
  156. )
  157. await agency.get_response("switch to AgentC", recipient_agent="AgentC")
  158. assert isinstance(agent_c.last_message, list)
  159. assert agent_c.last_message[0]["message_origin"] == "recipient_reminder"
  160. assert agent_c.last_message[-1] == {"role": "user", "content": "switch to AgentC"}
  161. @pytest.mark.asyncio
  162. async def test_agency_get_response_adds_reminder_after_interrupted_handoff_turn() -> None:
  163. """Manual switches should still refresh control when the prior handoff never answered."""
  164. agent_a = CapturingAgent("AgentA")
  165. agent_b = CapturingAgent("AgentB")
  166. agency = Agency(agent_a, agent_b)
  167. agency.thread_manager.add_messages(
  168. [
  169. {"role": "user", "content": "previous", "agent": "AgentA", "callerAgent": None},
  170. {
  171. "role": "system",
  172. "content": "Transfer completed. You are AgentB. Please continue the task.",
  173. "agent": "AgentA",
  174. "callerAgent": None,
  175. "message_origin": "handoff_reminder",
  176. },
  177. ]
  178. )
  179. await agency.get_response("switch to AgentB", recipient_agent="AgentB")
  180. assert isinstance(agent_b.last_message, list)
  181. assert agent_b.last_message[0]["message_origin"] == "recipient_reminder"
  182. assert agent_b.last_message[-1] == {"role": "user", "content": "switch to AgentB"}
  183. @pytest.mark.asyncio
  184. async def test_agency_get_response_adds_reminder_when_interrupted_handoff_switches_back() -> None:
  185. """Interrupted handoffs should refresh control even when switching back to the original agent."""
  186. agent_a = CapturingAgent("AgentA")
  187. agent_b = CapturingAgent("AgentB")
  188. agency = Agency(agent_a, agent_b)
  189. agency.thread_manager.add_messages(
  190. [
  191. {"role": "user", "content": "previous", "agent": "AgentA", "callerAgent": None},
  192. {
  193. "role": "system",
  194. "content": "Transfer completed. You are AgentB. Please continue the task.",
  195. "agent": "AgentA",
  196. "callerAgent": None,
  197. "message_origin": "handoff_reminder",
  198. },
  199. ]
  200. )
  201. await agency.get_response("switch back to AgentA", recipient_agent="AgentA")
  202. assert isinstance(agent_a.last_message, list)
  203. assert agent_a.last_message[0]["message_origin"] == "recipient_reminder"
  204. assert agent_a.last_message[-1] == {"role": "user", "content": "switch back to AgentA"}
  205. @pytest.mark.asyncio
  206. async def test_agency_get_response_accepts_legacy_top_level_reminder_in_split_run_turn() -> None:
  207. """Split-run histories should still honor top-level reminders that predate run ids."""
  208. agent_a = CapturingAgent("AgentA")
  209. agent_b = CapturingAgent("AgentB")
  210. agency = Agency(agent_a, agent_b)
  211. agency.thread_manager.add_messages(
  212. [
  213. {
  214. "role": "user",
  215. "content": "top-level request",
  216. "agent": "AgentA",
  217. "callerAgent": None,
  218. "agent_run_id": "top-run",
  219. },
  220. {
  221. "role": "system",
  222. "content": "Transfer completed. You are AgentB. Please continue the task.",
  223. "agent": "AgentB",
  224. "callerAgent": None,
  225. "message_origin": "handoff_reminder",
  226. },
  227. {
  228. "role": "assistant",
  229. "content": "done",
  230. "agent": "AgentB",
  231. "callerAgent": None,
  232. "agent_run_id": "handoff-run",
  233. "parent_run_id": "top-run",
  234. },
  235. ]
  236. )
  237. await agency.get_response("switch to AgentA", recipient_agent="AgentA")
  238. assert isinstance(agent_a.last_message, list)
  239. assert agent_a.last_message[0]["message_origin"] == "recipient_reminder"
  240. assert agent_a.last_message[-1] == {"role": "user", "content": "switch to AgentA"}
  241. @pytest.mark.asyncio
  242. async def test_agency_get_response_stream_adds_recipient_switch_reminder_after_handoff() -> None:
  243. """Streaming path should prepend recipient_reminder under the same conditions."""
  244. agent_a = CapturingAgent("AgentA")
  245. agent_b = CapturingAgent("AgentB")
  246. agency = Agency(agent_a, agent_b)
  247. agency.thread_manager.add_messages(
  248. [
  249. {"role": "user", "content": "previous", "agent": "AgentB", "callerAgent": None},
  250. {
  251. "role": "system",
  252. "content": "Transfer completed. You are AgentB. Please continue the task.",
  253. "agent": "AgentB",
  254. "callerAgent": None,
  255. "message_origin": "handoff_reminder",
  256. },
  257. {"role": "assistant", "content": "done", "agent": "AgentB", "callerAgent": None},
  258. ]
  259. )
  260. stream = agency.get_response_stream("new request", recipient_agent="AgentA")
  261. async for _event in stream:
  262. pass
  263. assert isinstance(agent_a.last_message, list)
  264. assert agent_a.last_message[0]["message_origin"] == "recipient_reminder"
  265. @pytest.mark.asyncio
  266. async def test_agency_get_response_stream_keeps_empty_input_guard_when_reminder_would_apply() -> None:
  267. """Streaming empty-input validation should still win over reminder injection."""
  268. agent_a = CapturingAgent("AgentA")
  269. agent_b = CapturingAgent("AgentB")
  270. agency = Agency(agent_a, agent_b)
  271. agency.thread_manager.add_messages(
  272. [
  273. {"role": "user", "content": "previous", "agent": "AgentB", "callerAgent": None},
  274. {
  275. "role": "system",
  276. "content": "Transfer completed. You are AgentB. Please continue the task.",
  277. "agent": "AgentB",
  278. "callerAgent": "AgentA",
  279. "message_origin": "handoff_reminder",
  280. },
  281. {"role": "assistant", "content": "done", "agent": "AgentB", "callerAgent": None},
  282. ]
  283. )
  284. events = []
  285. stream = agency.get_response_stream(" ", recipient_agent="AgentA")
  286. async for event in stream:
  287. events.append(event)
  288. assert stream.final_result is None
  289. assert events == [{"type": "error", "content": "message cannot be empty"}]
  290. assert agent_a.last_message == " "
  291. @pytest.mark.asyncio
  292. async def test_agency_get_response_ignores_descendant_handoff_reminders_from_other_runs() -> None:
  293. """Child-run handoff reminders should not trigger user-thread recipient reminders."""
  294. agent_a = CapturingAgent("AgentA")
  295. agent_b = CapturingAgent("AgentB")
  296. agency = Agency(agent_a, agent_b)
  297. agency.thread_manager.add_messages(
  298. [
  299. {
  300. "role": "user",
  301. "content": "top-level request",
  302. "agent": "AgentA",
  303. "callerAgent": None,
  304. "agent_run_id": "top-run",
  305. },
  306. {
  307. "role": "assistant",
  308. "content": "top-level response",
  309. "agent": "AgentA",
  310. "callerAgent": None,
  311. "agent_run_id": "top-run",
  312. },
  313. {
  314. "role": "system",
  315. "content": "Transfer completed. You are Specialist. Please continue the task.",
  316. "agent": "Specialist",
  317. "callerAgent": "AgentA",
  318. "agent_run_id": "child-run",
  319. "message_origin": "handoff_reminder",
  320. },
  321. {
  322. "role": "assistant",
  323. "content": "child response",
  324. "agent": "Specialist",
  325. "callerAgent": "AgentA",
  326. "agent_run_id": "child-run",
  327. },
  328. ]
  329. )
  330. await agency.get_response("new top-level switch", recipient_agent="AgentB")
  331. assert agent_b.last_message == "new top-level switch"
  332. @pytest.mark.asyncio
  333. async def test_agency_get_response_ignores_legacy_child_handoff_reminders_without_run_ids() -> None:
  334. """Legacy histories without run ids should only trust user-thread reminders."""
  335. agent_a = CapturingAgent("AgentA")
  336. agent_b = CapturingAgent("AgentB")
  337. agency = Agency(agent_a, agent_b)
  338. agency.thread_manager.add_messages(
  339. [
  340. {"role": "user", "content": "top-level request", "agent": "AgentA", "callerAgent": None},
  341. {"role": "assistant", "content": "top-level response", "agent": "AgentA", "callerAgent": None},
  342. {
  343. "role": "system",
  344. "content": "Transfer completed. You are Specialist. Please continue the task.",
  345. "agent": "Specialist",
  346. "callerAgent": "AgentA",
  347. "message_origin": "handoff_reminder",
  348. },
  349. {"role": "assistant", "content": "child response", "agent": "Specialist", "callerAgent": "AgentA"},
  350. ]
  351. )
  352. await agency.get_response("new legacy switch", recipient_agent="AgentB")
  353. assert agent_b.last_message == "new legacy switch"