test_thread_manager.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. import pickle
  2. import pytest
  3. from agents.models.fake_id import FAKE_RESPONSES_ID
  4. from agency_swarm.utils.thread import ThreadManager
  5. def test_thread_manager_initialization():
  6. """Tests that ThreadManager initializes with an empty message store."""
  7. manager = ThreadManager()
  8. assert len(manager._store.messages) == 0
  9. assert manager._load_threads_callback is None
  10. assert manager._save_threads_callback is None
  11. @pytest.mark.parametrize(
  12. "method,messages",
  13. [
  14. (
  15. "add_message",
  16. [
  17. {
  18. "role": "user",
  19. "content": "Hello",
  20. "agent": "Agent1",
  21. "callerAgent": None,
  22. "timestamp": 1234567890000,
  23. }
  24. ],
  25. ),
  26. (
  27. "add_messages",
  28. [
  29. {
  30. "role": "user",
  31. "content": "Hello",
  32. "agent": "Agent1",
  33. "callerAgent": None,
  34. "timestamp": 1234567890000,
  35. },
  36. {
  37. "role": "assistant",
  38. "content": "Hi there",
  39. "agent": "Agent1",
  40. "callerAgent": None,
  41. "timestamp": 1234567891000,
  42. },
  43. ],
  44. ),
  45. ],
  46. )
  47. def test_add_messages(method: str, messages: list[dict]):
  48. """Tests adding messages through both single and batch methods."""
  49. manager = ThreadManager()
  50. target = messages[0] if method == "add_message" else messages
  51. getattr(manager, method)(target)
  52. assert len(manager._store.messages) == len(messages)
  53. assert manager._store.messages == messages
  54. def test_duplicate_message_id_is_preserved():
  55. """Ensure ThreadManager leaves duplicate message ids intact."""
  56. manager = ThreadManager()
  57. initial = {
  58. "type": "message",
  59. "id": "msg-1",
  60. "role": "assistant",
  61. "content": None,
  62. "timestamp": 1,
  63. }
  64. updated = {
  65. "type": "message",
  66. "id": "msg-1",
  67. "role": "assistant",
  68. "tool_calls": [
  69. {
  70. "id": "call-1",
  71. "type": "function",
  72. "function": {"name": "do_work", "arguments": "{}"},
  73. }
  74. ],
  75. "timestamp": 2,
  76. }
  77. manager.add_message(initial)
  78. manager.add_message(updated)
  79. assert manager._store.messages == [initial, updated]
  80. def test_function_call_output_duplicates_are_preserved():
  81. """Ensure ThreadManager does not dedupe function_call_output entries."""
  82. manager = ThreadManager()
  83. first_output = {
  84. "type": "function_call_output",
  85. "call_id": "call-1",
  86. "output": "intermediate",
  87. "timestamp": 1,
  88. }
  89. final_output = {
  90. "type": "function_call_output",
  91. "call_id": "call-1",
  92. "output": "final",
  93. "timestamp": 2,
  94. }
  95. manager.add_message(first_output)
  96. manager.add_message(final_output)
  97. assert manager._store.messages == [first_output, final_output]
  98. def test_function_call_output_unique_ids_are_preserved():
  99. """Ensure distinct message ids for the same call id remain appended."""
  100. manager = ThreadManager()
  101. first_output = {
  102. "type": "function_call_output",
  103. "id": "msg-1",
  104. "call_id": "call-1",
  105. "output": "placeholder",
  106. "timestamp": 1,
  107. }
  108. final_output = {
  109. "type": "function_call_output",
  110. "id": "msg-2",
  111. "call_id": "call-1",
  112. "output": "final",
  113. "timestamp": 2,
  114. }
  115. manager.add_message(first_output)
  116. manager.add_message(final_output)
  117. assert manager._store.messages == [first_output, final_output]
  118. def test_placeholder_messages_are_not_deduped():
  119. manager = ThreadManager()
  120. first = {
  121. "type": "message",
  122. "id": FAKE_RESPONSES_ID,
  123. "role": "assistant",
  124. "content": "initial",
  125. "timestamp": 1,
  126. }
  127. second = {
  128. "type": "message",
  129. "id": FAKE_RESPONSES_ID,
  130. "role": "assistant",
  131. "content": "follow-up",
  132. "timestamp": 2,
  133. }
  134. manager.add_message(first)
  135. manager.add_message(second)
  136. assert manager._store.messages == [first, second]
  137. def test_placeholder_tool_messages_preserve_prior_calls():
  138. manager = ThreadManager()
  139. first_call = {
  140. "type": "function_call",
  141. "id": FAKE_RESPONSES_ID,
  142. "call_id": "call-1",
  143. "role": "assistant",
  144. "timestamp": 1,
  145. "tool_calls": [
  146. {
  147. "id": "call-1",
  148. "type": "function",
  149. "function": {"name": "get_user_id", "arguments": "{}"},
  150. }
  151. ],
  152. }
  153. first_output = {
  154. "type": "function_call_output",
  155. "id": FAKE_RESPONSES_ID,
  156. "call_id": "call-1",
  157. "output": "User id is 1245725189",
  158. "timestamp": 2,
  159. }
  160. manager.add_message(first_call)
  161. manager.add_message(first_output)
  162. second_call = {
  163. "type": "function_call",
  164. "id": FAKE_RESPONSES_ID,
  165. "call_id": "call-2",
  166. "role": "assistant",
  167. "timestamp": 3,
  168. "tool_calls": [
  169. {
  170. "id": "call-2",
  171. "type": "function",
  172. "function": {"name": "get_user_id", "arguments": "{}"},
  173. }
  174. ],
  175. }
  176. second_output = {
  177. "type": "function_call_output",
  178. "id": FAKE_RESPONSES_ID,
  179. "call_id": "call-2",
  180. "output": "Done",
  181. "timestamp": 4,
  182. }
  183. manager.add_message(second_call)
  184. manager.add_message(second_output)
  185. calls = [msg for msg in manager._store.messages if msg.get("type") == "function_call"]
  186. outputs = [msg for msg in manager._store.messages if msg.get("type") == "function_call_output"]
  187. assert {msg["call_id"] for msg in calls} == {"call-1", "call-2"}
  188. assert {msg["call_id"] for msg in outputs} == {"call-1", "call-2"}
  189. def test_user_thread_shared_across_agents():
  190. """Tests that all entry-point agents share the same user thread."""
  191. manager = ThreadManager()
  192. messages = [
  193. {"role": "user", "content": "Hello Agent1", "agent": "Agent1", "callerAgent": None, "timestamp": 1234567890000},
  194. {"role": "assistant", "content": "Hi user", "agent": "Agent1", "callerAgent": None, "timestamp": 1234567891000},
  195. {"role": "user", "content": "Hello Agent2", "agent": "Agent2", "callerAgent": None, "timestamp": 1234567892000},
  196. {
  197. "role": "assistant",
  198. "content": "Hi from Agent2",
  199. "agent": "Agent2",
  200. "callerAgent": None,
  201. "timestamp": 1234567893000,
  202. },
  203. ]
  204. manager.add_messages(messages)
  205. # Both agents should see the same combined conversation history
  206. agent1_history = manager.get_conversation_history("Agent1", None)
  207. agent2_history = manager.get_conversation_history("Agent2", None)
  208. assert agent1_history == messages
  209. assert agent2_history == messages
  210. assert agent1_history == agent2_history
  211. def test_get_all_messages():
  212. """Tests retrieving all messages from the thread manager."""
  213. manager = ThreadManager()
  214. messages = [
  215. {"role": "user", "content": "Message 1", "agent": "Agent1", "callerAgent": None, "timestamp": 1234567890000},
  216. {
  217. "role": "assistant",
  218. "content": "Response 1",
  219. "agent": "Agent1",
  220. "callerAgent": None,
  221. "timestamp": 1234567891000,
  222. },
  223. ]
  224. manager.add_messages(messages)
  225. all_messages = manager.get_all_messages()
  226. assert all_messages == messages
  227. # Verify it returns a copy, not the original list
  228. all_messages.append({"role": "user", "content": "Extra"})
  229. assert len(manager._store.messages) == 2 # Original should be unchanged
  230. def test_save_callback_triggered_on_add(mocker):
  231. """Tests that save callback is triggered when adding messages."""
  232. mock_save = mocker.MagicMock()
  233. manager = ThreadManager(save_threads_callback=mock_save)
  234. message = {"role": "user", "content": "Test", "agent": "Agent1", "callerAgent": None, "timestamp": 1234567890000}
  235. manager.add_message(message)
  236. mock_save.assert_called_once_with([message])
  237. def test_clear_persists_empty_message_store():
  238. """Ensure `clear()` persists the empty message store via the save callback."""
  239. captured: list[list[dict[str, object]]] = []
  240. manager = ThreadManager(save_threads_callback=lambda msgs: captured.append(list(msgs)))
  241. manager.add_message({"role": "user", "content": "seed"})
  242. captured.clear()
  243. manager.clear()
  244. assert captured == [[]]
  245. assert manager.get_all_messages() == []
  246. def test_load_callback_on_init(mocker):
  247. """Tests that load callback is called during initialization."""
  248. loaded_messages = [
  249. {
  250. "role": "user",
  251. "content": "Loaded message",
  252. "agent": "Agent1",
  253. "callerAgent": None,
  254. "timestamp": 1234567890000,
  255. }
  256. ]
  257. mock_load = mocker.MagicMock(return_value=loaded_messages)
  258. manager = ThreadManager(load_threads_callback=mock_load)
  259. mock_load.assert_called_once()
  260. assert manager._store.messages == loaded_messages
  261. def test_thread_manager_pickleable():
  262. """Tests that ThreadManager can be pickled and unpickled correctly."""
  263. # Create manager without callbacks (callbacks aren't pickleable)
  264. manager = ThreadManager()
  265. messages = [
  266. {"role": "user", "content": "Test message", "agent": "Agent1", "callerAgent": None, "timestamp": 1234567890000}
  267. ]
  268. manager.add_messages(messages)
  269. # Pickle and unpickle
  270. pickled_data = pickle.dumps(manager)
  271. unpickled_manager = pickle.loads(pickled_data)
  272. # Verify the data is preserved
  273. assert isinstance(unpickled_manager, ThreadManager)
  274. assert unpickled_manager._store.messages == messages
  275. def test_replace_messages_skips_save_callback():
  276. captured: list[list[dict[str, object]]] = []
  277. manager = ThreadManager(save_threads_callback=lambda msgs: captured.append(list(msgs)))
  278. manager.add_message({"role": "user", "content": "seed"})
  279. captured.clear()
  280. manager.replace_messages([{"role": "assistant", "content": "new"}])
  281. assert captured == []
  282. assert [msg["content"] for msg in manager.get_all_messages()] == ["new"]
  283. def test_thread_manager_allows_duplicate_ids_by_design():
  284. """Verify ThreadManager leaves duplicate items untouched (SDK handles dedupe)."""
  285. manager = ThreadManager()
  286. first = {
  287. "id": "msg-1",
  288. "type": "function_call",
  289. "call_id": "call-1",
  290. "role": "assistant",
  291. "timestamp": 1,
  292. }
  293. second = {
  294. "id": "msg-1",
  295. "type": "function_call",
  296. "call_id": "call-1",
  297. "role": "assistant",
  298. "timestamp": 2,
  299. }
  300. manager.add_message(first)
  301. manager.add_message(second)
  302. assert manager._store.messages == [first, second]
  303. def test_function_call_output_with_same_id_different_call_ids_should_not_dedupe():
  304. manager = ThreadManager()
  305. first_output = {
  306. "type": "function_call_output",
  307. "id": "msg-1",
  308. "call_id": "call-1",
  309. "output": "first tool result",
  310. "timestamp": 1,
  311. }
  312. second_output = {
  313. "type": "function_call_output",
  314. "id": "msg-1",
  315. "call_id": "call-2",
  316. "output": "second tool result",
  317. "timestamp": 2,
  318. }
  319. manager.add_message(first_output)
  320. manager.add_message(second_output)
  321. outputs = [msg for msg in manager._store.messages if msg.get("type") == "function_call_output"]
  322. assert len(outputs) == 2
  323. assert {msg["call_id"] for msg in outputs} == {"call-1", "call-2"}
  324. def test_messages_with_none_type_and_same_id_should_not_dedupe():
  325. manager = ThreadManager()
  326. first_message = {
  327. "type": None,
  328. "id": "msg-1",
  329. "role": "assistant",
  330. "content": "first",
  331. "timestamp": 1,
  332. }
  333. second_message = {
  334. "type": None,
  335. "id": "msg-1",
  336. "role": "assistant",
  337. "content": "second",
  338. "timestamp": 2,
  339. }
  340. manager.add_message(first_message)
  341. manager.add_message(second_message)
  342. assert len(manager._store.messages) == 2