test_execution_streaming.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. from agency_swarm.agent.execution_streaming import prune_guardrail_messages
  2. def _build_message(
  3. *,
  4. role: str | None,
  5. message_origin: str | None = None,
  6. parent_run_id: str | None = None,
  7. agent_run_id: str | None = None,
  8. run_trace_id: str = "trace_guardrail",
  9. caller_agent: str | None = None,
  10. agent: str | None = None,
  11. extra: dict | None = None,
  12. ) -> dict:
  13. msg = {
  14. "role": role,
  15. "message_origin": message_origin,
  16. "parent_run_id": parent_run_id,
  17. "agent_run_id": agent_run_id,
  18. "run_trace_id": run_trace_id,
  19. "callerAgent": caller_agent,
  20. "agent": agent,
  21. "type": "message",
  22. }
  23. if extra:
  24. msg.update(extra)
  25. return msg
  26. def test_prune_guardrail_messages_parent_run_only_keeps_user_and_guidance() -> None:
  27. """
  28. Tree:
  29. CustomerSupportAgent (guardrail trips here)
  30. └── DatabaseAgent
  31. └── EmailAgent
  32. Guardrail fires before any delegation completes, so the history must collapse to the
  33. real user + the guardrail guidance from the root agent.
  34. """
  35. preserved_user = _build_message(role="user", parent_run_id=None, agent_run_id="agent_run_parent")
  36. guardrail_message = _build_message(
  37. role="assistant",
  38. message_origin="input_guardrail_message",
  39. parent_run_id=None,
  40. agent_run_id="agent_run_parent",
  41. )
  42. to_remove_assistant = _build_message(role="assistant", parent_run_id=None, agent_run_id="agent_run_parent")
  43. unrelated_other_trace = _build_message(role="assistant", run_trace_id="trace_other")
  44. all_messages = [
  45. preserved_user,
  46. guardrail_message,
  47. to_remove_assistant,
  48. unrelated_other_trace,
  49. ]
  50. pruned = prune_guardrail_messages(
  51. all_messages,
  52. initial_saved_count=1,
  53. run_trace_id="trace_guardrail",
  54. )
  55. assert pruned == [preserved_user, guardrail_message, unrelated_other_trace]
  56. def test_prune_guardrail_messages_child_run_keeps_trigger_input_and_guidance() -> None:
  57. """
  58. Tree:
  59. CustomerSupportAgent
  60. └── DatabaseAgent (guardrail fires here)
  61. The DatabaseAgent's user prompt plus its guidance must remain so the parent knows what
  62. to fix, but any generated assistant outputs/function calls are trimmed.
  63. """
  64. preserved_user = _build_message(role="user", parent_run_id=None, agent_run_id="agent_run_parent")
  65. forwarded_input = _build_message(
  66. role="user",
  67. parent_run_id="call_child",
  68. agent_run_id="agent_run_child",
  69. caller_agent="ParentAgent",
  70. )
  71. guardrail_message = _build_message(
  72. role="assistant",
  73. message_origin="input_guardrail_message",
  74. parent_run_id="call_child",
  75. agent_run_id="agent_run_child",
  76. )
  77. function_call = _build_message(
  78. role=None,
  79. parent_run_id="call_child",
  80. agent_run_id="agent_run_child",
  81. caller_agent="ParentAgent",
  82. extra={"type": "function_call"},
  83. )
  84. all_messages = [
  85. preserved_user,
  86. forwarded_input,
  87. guardrail_message,
  88. function_call,
  89. ]
  90. pruned = prune_guardrail_messages(
  91. all_messages,
  92. initial_saved_count=1,
  93. run_trace_id="trace_guardrail",
  94. )
  95. assert pruned == [preserved_user, forwarded_input, guardrail_message]
  96. def test_prune_guardrail_messages_preserves_other_traces() -> None:
  97. """
  98. Tree:
  99. CustomerSupportAgent (affected trace)
  100. A parallel trace stays untouched even when the guardrail triggers for the trace under
  101. inspection.
  102. """
  103. preserved_user = _build_message(role="user", parent_run_id=None, agent_run_id="agent_run_parent")
  104. guardrail_message = _build_message(
  105. role="assistant",
  106. message_origin="input_guardrail_message",
  107. parent_run_id=None,
  108. agent_run_id="agent_run_parent",
  109. )
  110. concurrent_message = _build_message(
  111. role="assistant",
  112. parent_run_id=None,
  113. agent_run_id="agent_run_other",
  114. run_trace_id="trace_other",
  115. )
  116. all_messages = [
  117. preserved_user,
  118. guardrail_message,
  119. concurrent_message,
  120. ]
  121. pruned = prune_guardrail_messages(
  122. all_messages,
  123. initial_saved_count=1,
  124. run_trace_id="trace_guardrail",
  125. )
  126. assert pruned == [preserved_user, guardrail_message, concurrent_message]
  127. def test_prune_guardrail_messages_drops_no_op_trace_descendants() -> None:
  128. """
  129. Tree:
  130. ParentAgent
  131. └── HelperAgent (run_trace_id=no-op)
  132. HelperAgent belongs to a no-op trace branch, so it is removed even if it emitted
  133. assistant outputs before the guardrail fired elsewhere.
  134. """
  135. preserved_user = _build_message(role="user", parent_run_id=None, agent_run_id="agent_run_parent")
  136. guardrail_message = _build_message(
  137. role="assistant",
  138. message_origin="input_guardrail_message",
  139. parent_run_id=None,
  140. agent_run_id="agent_run_parent",
  141. agent="ParentAgent",
  142. )
  143. helper_assistant = _build_message(
  144. role="assistant",
  145. parent_run_id="call_send_message",
  146. agent_run_id="agent_run_helper",
  147. run_trace_id="no-op",
  148. caller_agent="ParentAgent",
  149. agent="HelperAgent",
  150. )
  151. all_messages = [
  152. preserved_user,
  153. guardrail_message,
  154. helper_assistant,
  155. ]
  156. pruned = prune_guardrail_messages(
  157. all_messages,
  158. initial_saved_count=1,
  159. run_trace_id="trace_guardrail",
  160. )
  161. assert pruned == [preserved_user, guardrail_message]
  162. def test_prune_guardrail_messages_drops_nested_agent_user_and_errors() -> None:
  163. """
  164. Tree:
  165. CustomerSupportAgent
  166. └── DatabaseAgent
  167. └── EmailAgent (guardrail fires here, but parent guidance also added)
  168. Inter-agent *user* messages remain so the next retry has full context, while any
  169. assistant/system responses are scoped to the guardrail location.
  170. """
  171. real_user = _build_message(
  172. role="user",
  173. parent_run_id=None,
  174. agent_run_id="agent_run_parent",
  175. caller_agent=None,
  176. agent="CustomerSupportAgent",
  177. )
  178. db_user = _build_message(
  179. role="user",
  180. parent_run_id="call_db",
  181. agent_run_id="agent_run_db",
  182. caller_agent="CustomerSupportAgent",
  183. agent="DatabaseAgent",
  184. )
  185. email_user = _build_message(
  186. role="user",
  187. parent_run_id="call_email",
  188. agent_run_id="agent_run_email",
  189. caller_agent="DatabaseAgent",
  190. agent="EmailAgent",
  191. )
  192. email_guardrail_guidance = _build_message(
  193. role="assistant",
  194. message_origin="input_guardrail_message",
  195. parent_run_id="call_email",
  196. agent_run_id="agent_run_email",
  197. caller_agent="DatabaseAgent",
  198. agent="EmailAgent",
  199. )
  200. db_guardrail_guidance = _build_message(
  201. role="assistant",
  202. message_origin="input_guardrail_message",
  203. parent_run_id="call_db",
  204. agent_run_id="agent_run_db",
  205. caller_agent="CustomerSupportAgent",
  206. agent="DatabaseAgent",
  207. )
  208. top_guardrail_message = _build_message(
  209. role="assistant",
  210. message_origin="input_guardrail_message",
  211. parent_run_id=None,
  212. agent_run_id="agent_run_parent",
  213. caller_agent=None,
  214. agent="CustomerSupportAgent",
  215. )
  216. pruned = prune_guardrail_messages(
  217. [
  218. real_user,
  219. db_user,
  220. email_user,
  221. email_guardrail_guidance,
  222. db_guardrail_guidance,
  223. top_guardrail_message,
  224. ],
  225. initial_saved_count=1,
  226. run_trace_id="trace_guardrail",
  227. )
  228. assert pruned == [
  229. real_user,
  230. db_user,
  231. email_user,
  232. email_guardrail_guidance,
  233. db_guardrail_guidance,
  234. top_guardrail_message,
  235. ]
  236. def test_prune_guardrail_messages_preserves_parent_guidance_after_child_guardrail() -> None:
  237. """
  238. Tree:
  239. CustomerSupportAgent
  240. └── DatabaseAgent (guardrail fires here)
  241. Even when the child trips the guardrail, the parent agent receives its own guidance
  242. message that must remain in history so the user can see what to fix.
  243. """
  244. real_user = _build_message(
  245. role="user",
  246. parent_run_id=None,
  247. agent_run_id="agent_run_parent",
  248. caller_agent=None,
  249. agent="CustomerSupportAgent",
  250. )
  251. parent_prompt = _build_message(
  252. role="user",
  253. parent_run_id="call_db",
  254. agent_run_id="agent_run_db",
  255. caller_agent="CustomerSupportAgent",
  256. agent="DatabaseAgent",
  257. )
  258. db_guardrail_guidance = _build_message(
  259. role="assistant",
  260. message_origin="input_guardrail_message",
  261. parent_run_id="call_db",
  262. agent_run_id="agent_run_db",
  263. caller_agent="CustomerSupportAgent",
  264. agent="DatabaseAgent",
  265. )
  266. parent_guidance = _build_message(
  267. role="assistant",
  268. message_origin="input_guardrail_message",
  269. parent_run_id=None,
  270. agent_run_id="agent_run_parent",
  271. caller_agent=None,
  272. agent="CustomerSupportAgent",
  273. )
  274. pruned = prune_guardrail_messages(
  275. [real_user, parent_prompt, db_guardrail_guidance, parent_guidance],
  276. initial_saved_count=1,
  277. run_trace_id="trace_guardrail",
  278. )
  279. assert pruned == [real_user, parent_prompt, db_guardrail_guidance, parent_guidance]
  280. def test_prune_guardrail_messages_keeps_child_guardrail_guidance_for_parent() -> None:
  281. """
  282. Tree:
  283. CustomerSupportAgent
  284. └── DatabaseAgent (guardrail fires here)
  285. Parent agent still needs to see the child guardrail guidance (callerAgent is set),
  286. otherwise replays would lack actionable detail.
  287. """
  288. real_user = _build_message(
  289. role="user",
  290. parent_run_id=None,
  291. agent_run_id="agent_run_parent",
  292. caller_agent=None,
  293. agent="CustomerSupportAgent",
  294. )
  295. db_prompt = _build_message(
  296. role="user",
  297. parent_run_id="call_db",
  298. agent_run_id="agent_run_db",
  299. caller_agent="CustomerSupportAgent",
  300. agent="DatabaseAgent",
  301. )
  302. db_guardrail_message = _build_message(
  303. role="assistant",
  304. message_origin="input_guardrail_message",
  305. parent_run_id="call_db",
  306. agent_run_id="agent_run_db",
  307. caller_agent="CustomerSupportAgent",
  308. agent="DatabaseAgent",
  309. )
  310. pruned = prune_guardrail_messages(
  311. [real_user, db_prompt, db_guardrail_message],
  312. initial_saved_count=1,
  313. run_trace_id="trace_guardrail",
  314. )
  315. assert pruned == [real_user, db_prompt, db_guardrail_message]
  316. def test_prune_guardrail_messages_drops_descendants_after_guardrail() -> None:
  317. """
  318. Tree:
  319. CustomerSupportAgent
  320. └── DatabaseAgent
  321. └── EmailAgent (guardrail fires here)
  322. └── HelperAgent (spawned after guardrail)
  323. Any further delegations triggered after the guardrail trips are removed to keep the
  324. tree consistent with the halted execution.
  325. """
  326. real_user = _build_message(
  327. role="user",
  328. parent_run_id=None,
  329. agent_run_id="agent_run_parent",
  330. caller_agent=None,
  331. agent="CustomerSupportAgent",
  332. )
  333. db_user = _build_message(
  334. role="user",
  335. parent_run_id="call_db",
  336. agent_run_id="agent_run_db",
  337. caller_agent="CustomerSupportAgent",
  338. agent="DatabaseAgent",
  339. )
  340. email_guardrail_guidance = _build_message(
  341. role="assistant",
  342. message_origin="input_guardrail_message",
  343. parent_run_id="call_email",
  344. agent_run_id="agent_run_email",
  345. caller_agent="DatabaseAgent",
  346. agent="EmailAgent",
  347. )
  348. descendant_after_guardrail = _build_message(
  349. role="assistant",
  350. parent_run_id="call_followup",
  351. agent_run_id="agent_run_followup",
  352. caller_agent="EmailAgent",
  353. agent="HelperAgent",
  354. )
  355. parent_guidance = _build_message(
  356. role="assistant",
  357. message_origin="input_guardrail_message",
  358. parent_run_id=None,
  359. agent_run_id="agent_run_parent",
  360. caller_agent=None,
  361. agent="CustomerSupportAgent",
  362. )
  363. pruned = prune_guardrail_messages(
  364. [real_user, db_user, email_guardrail_guidance, descendant_after_guardrail, parent_guidance],
  365. initial_saved_count=1,
  366. run_trace_id="trace_guardrail",
  367. )
  368. assert pruned == [real_user, email_guardrail_guidance, parent_guidance]