test_persistence.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import json
  2. from pathlib import Path
  3. from typing import Any
  4. import pytest
  5. from agency_swarm import Agency, Agent
  6. # --- File Persistence Setup ---
  7. @pytest.fixture(scope="function")
  8. def temp_persistence_dir(tmp_path):
  9. print(f"\nTEMP DIR: Created at {tmp_path}")
  10. yield tmp_path
  11. def file_save_callback(messages: list[dict[str, Any]], base_dir: Path):
  12. """Save flat message list to a JSON file."""
  13. print(f"\nFILE SAVE: Saving {len(messages)} messages to {base_dir}")
  14. try:
  15. file_path = base_dir / "messages.json"
  16. with open(file_path, "w") as f:
  17. json.dump(messages, f, indent=2)
  18. print(f"FILE SAVE: Successfully saved {file_path}")
  19. except Exception as e:
  20. print(f"FILE SAVE ERROR: Failed to save messages: {e}")
  21. import traceback
  22. traceback.print_exc()
  23. def file_load_callback(base_dir: Path) -> list[dict[str, Any]] | None:
  24. """Load flat message list from a JSON file."""
  25. file_path = base_dir / "messages.json"
  26. print(f"\nFILE LOAD: Attempting to load messages from {file_path}")
  27. if not file_path.exists():
  28. print("FILE LOAD: File not found.")
  29. return None
  30. try:
  31. with open(file_path) as f:
  32. messages = json.load(f)
  33. # Basic validation of loaded structure - should be a list
  34. if not isinstance(messages, list):
  35. print(f"FILE LOAD ERROR: Loaded data should be a list, got {type(messages)}.")
  36. return None
  37. print(f"FILE LOAD: Successfully loaded {len(messages)} messages")
  38. return messages
  39. except Exception as e:
  40. print(f"FILE LOAD ERROR: Failed to load messages: {e}")
  41. # Log traceback for detailed debugging
  42. import traceback
  43. traceback.print_exc()
  44. return None
  45. def file_save_callback_error(messages: list[dict[str, Any]], base_dir: Path):
  46. """Mock file save callback that raises an error."""
  47. if not messages:
  48. print("FILE SAVE ERROR (Intentional Fail): Messages list is empty.")
  49. raise ValueError("Cannot simulate save error for empty messages")
  50. file_path = base_dir / "messages_test.json"
  51. print(f"\nFILE SAVE ERROR: Intentionally failing at {file_path}")
  52. raise OSError("Simulated save error")
  53. def file_load_callback_error(base_dir: Path) -> list[dict[str, Any]] | None:
  54. """Mock file load callback that raises an error."""
  55. file_path = base_dir / "messages_test.json"
  56. print(f"\nFILE LOAD ERROR: Intentionally failing at {file_path}")
  57. raise OSError("Simulated load error")
  58. # --- Test Agent ---
  59. @pytest.fixture
  60. def persistence_agent():
  61. return Agent(
  62. name="PersistenceTester",
  63. instructions="Remember the secret code word I tell you. In the next turn, repeat the code word.",
  64. )
  65. def test_persistence_callbacks_must_be_callable(persistence_agent: Agent) -> None:
  66. with pytest.raises(TypeError, match="must be callable"):
  67. Agency(
  68. persistence_agent,
  69. load_threads_callback="invalid-load-callback",
  70. save_threads_callback=lambda _messages: None,
  71. )
  72. with pytest.raises(TypeError, match="must be callable"):
  73. Agency(
  74. persistence_agent,
  75. load_threads_callback=lambda: [],
  76. save_threads_callback="invalid-save-callback",
  77. )
  78. @pytest.fixture
  79. def file_persistence_callbacks(temp_persistence_dir):
  80. """Fixture to provide configured file callbacks that follow the correct interface."""
  81. def load_messages_for_chat(chat_id: str) -> list[dict[str, Any]]:
  82. """Load flat message list for a specific chat_id."""
  83. print(f"\nLOADING MESSAGES for chat_id: {chat_id}")
  84. file_path = temp_persistence_dir / f"messages_{chat_id}.json"
  85. if not file_path.exists():
  86. print("LOAD: No messages file found, returning empty list")
  87. return []
  88. try:
  89. with open(file_path) as f:
  90. messages = json.load(f)
  91. print(f"LOADED: {len(messages)} messages for chat_id {chat_id}")
  92. return messages if isinstance(messages, list) else []
  93. except Exception as e:
  94. print(f"ERROR loading {file_path}: {e}")
  95. return []
  96. def save_messages_for_chat(messages: list[dict[str, Any]], chat_id: str):
  97. """Save flat message list for a specific chat_id."""
  98. print(f"\nSAVING MESSAGES for chat_id: {chat_id} ({len(messages)} messages)")
  99. try:
  100. file_path = temp_persistence_dir / f"messages_{chat_id}.json"
  101. with open(file_path, "w") as f:
  102. json.dump(messages, f, indent=2)
  103. print(f"SAVED: {len(messages)} messages for chat {chat_id} to {file_path}")
  104. except Exception as e:
  105. print(f"SAVE ERROR for chat_id {chat_id}: {e}")
  106. import traceback
  107. traceback.print_exc()
  108. # Return the actual functions that take chat_id
  109. return load_messages_for_chat, save_messages_for_chat
  110. # --- Test Cases ---
  111. @pytest.mark.asyncio
  112. async def test_persistence_callbacks_called(temp_persistence_dir, persistence_agent, file_persistence_callbacks):
  113. """
  114. Test that save and load callbacks are invoked correctly with proper closure pattern.
  115. """
  116. chat_id = "test_chat_123"
  117. message1 = "First message for callback test."
  118. message2 = "Second message for callback test."
  119. # Expected file for flat message storage
  120. messages_file = temp_persistence_dir / f"messages_{chat_id}.json"
  121. # Get the actual callback functions
  122. load_messages_for_chat, save_messages_for_chat = file_persistence_callbacks
  123. # Define callbacks using closure pattern from deployment docs
  124. def load_messages():
  125. return load_messages_for_chat(chat_id)
  126. def save_messages(messages):
  127. save_messages_for_chat(messages, chat_id)
  128. # Initialize Agency with closure-based callbacks (NO parameters in lambda)
  129. agency = Agency(
  130. persistence_agent,
  131. load_threads_callback=lambda: load_messages(),
  132. save_threads_callback=lambda messages: save_messages(messages),
  133. )
  134. # Turn 1
  135. print(f"\n--- Callback Test Turn 1 --- MSG: {message1}")
  136. assert not messages_file.exists(), f"File {messages_file} should not exist before first run."
  137. await agency.get_response(message=message1)
  138. # Verify save succeeded by checking file existence
  139. assert messages_file.exists(), f"File {messages_file} should exist after first run."
  140. # Turn 2 - new agency instance should load previous history
  141. print(f"\n--- Callback Test Turn 2 --- MSG: {message2}")
  142. persistence_agent2 = Agent(
  143. name="PersistenceTester",
  144. instructions="Remember the secret code word I tell you. In the next turn, repeat the code word.",
  145. )
  146. # Same closure pattern for second agency
  147. agency2 = Agency(
  148. persistence_agent2,
  149. load_threads_callback=lambda: load_messages(),
  150. save_threads_callback=lambda messages: save_messages(messages),
  151. )
  152. await agency2.get_response(message=message2)
  153. # Verify file still exists and has more content
  154. assert messages_file.exists(), f"File {messages_file} should still exist after second run."
  155. with open(messages_file) as f:
  156. final_data = json.load(f)
  157. # Should have at least 2 user messages (turn 1 and turn 2)
  158. user_messages = [item for item in final_data if item.get("role") == "user"]
  159. assert len(user_messages) >= 2, f"Should have at least 2 user messages, got {len(user_messages)}"
  160. @pytest.mark.asyncio
  161. async def test_persistence_load_all_messages(temp_persistence_dir, file_persistence_callbacks):
  162. """
  163. Test that load callback returns all messages for a chat_id correctly.
  164. """
  165. chat_id = "load_messages_test_789"
  166. # Create test agents
  167. ceo = Agent(name="CEO", instructions="You are the CEO.")
  168. dev = Agent(name="Developer", instructions="You are the Developer.")
  169. # Get callback functions
  170. load_messages_for_chat, save_messages_for_chat = file_persistence_callbacks
  171. # Define callbacks using closure pattern
  172. def load_messages():
  173. return load_messages_for_chat(chat_id)
  174. def save_messages(messages):
  175. save_messages_for_chat(messages, chat_id)
  176. # Create agency with communication flow
  177. agency = Agency(
  178. ceo,
  179. communication_flows=[ceo > dev],
  180. load_threads_callback=lambda: load_messages(),
  181. save_threads_callback=lambda messages: save_messages(messages),
  182. )
  183. # Create messages with different agents
  184. await agency.get_response("CEO: Plan the project", recipient_agent="CEO")
  185. await agency.get_response("Developer: Code the project", recipient_agent="Developer")
  186. # Now test that load_messages returns ALL messages
  187. all_loaded_messages = load_messages()
  188. assert isinstance(all_loaded_messages, list), "Load callback should return a list"
  189. assert len(all_loaded_messages) >= 4, (
  190. f"Should load at least 4 messages (2 user + 2 assistant), got {len(all_loaded_messages)}"
  191. )
  192. # Check that we have messages from both agents
  193. ceo_messages = [msg for msg in all_loaded_messages if msg.get("agent") == "CEO"]
  194. dev_messages = [msg for msg in all_loaded_messages if msg.get("agent") == "Developer"]
  195. assert len(ceo_messages) > 0, "Should have messages for CEO agent"
  196. assert len(dev_messages) > 0, "Should have messages for Developer agent"
  197. # Verify each message has proper structure
  198. for msg in all_loaded_messages:
  199. assert isinstance(msg, dict), "Each message should be dict"
  200. assert "agent" in msg, "Message missing 'agent'"
  201. assert "timestamp" in msg, "Message missing 'timestamp'"
  202. print(
  203. f"✓ Successfully loaded {len(all_loaded_messages)} messages with agents: "
  204. f"{ {msg.get('agent') for msg in all_loaded_messages} }"
  205. )
  206. @pytest.mark.asyncio
  207. async def test_persistence_error_handling(temp_persistence_dir, persistence_agent, file_persistence_callbacks):
  208. """
  209. Test graceful error handling when persistence callbacks fail.
  210. """
  211. def load_with_error():
  212. """Load callback that raises an error."""
  213. raise OSError("Simulated load error")
  214. def save_with_error(messages):
  215. """Save callback that raises an error."""
  216. raise OSError("Simulated save error")
  217. # Test load error handling - should handle gracefully and continue
  218. agency_load_error = Agency(
  219. persistence_agent,
  220. load_threads_callback=lambda: load_with_error(),
  221. save_threads_callback=lambda messages: [], # No-op save
  222. )
  223. # Should handle load error gracefully and continue (not raise error)
  224. result = await agency_load_error.get_response("Test message despite load error")
  225. assert result is not None, "Should continue working despite load error"
  226. # Test save error handling - create separate agent instance
  227. persistence_agent2 = Agent(
  228. name="PersistenceTester",
  229. instructions="Remember the secret code word I tell you. In the next turn, repeat the code word.",
  230. )
  231. agency_save_error = Agency(
  232. persistence_agent2,
  233. load_threads_callback=lambda: [], # Return empty messages list
  234. save_threads_callback=lambda messages: save_with_error(messages),
  235. )
  236. # Should complete successfully despite save error
  237. result = await agency_save_error.get_response("Test message despite save error")
  238. assert result is not None, "Should continue working despite save error"
  239. print("✓ Error handling verified - system continues gracefully despite persistence errors")
  240. @pytest.mark.asyncio
  241. async def test_no_persistence_no_callbacks(persistence_agent, temp_persistence_dir):
  242. """
  243. Test that history is NOT persisted between Agency instances if no callbacks are provided.
  244. """
  245. message1 = "First message, should be forgotten."
  246. message2 = "Second message, load should not happen."
  247. # Agency Instance 1 - Turn 1 (No callbacks)
  248. print("\n--- No Persistence Test - Instance 1 - Turn 1 --- Creating Agency 1")
  249. agency1 = Agency(persistence_agent, load_threads_callback=None, save_threads_callback=None)
  250. print(f"--- No Persistence Test - Instance 1 - Turn 1 --- MSG: {message1}")
  251. await agency1.get_response(message=message1)
  252. # Check that no file was created (as no save callback was provided)
  253. assert len(list(temp_persistence_dir.glob("*.json"))) == 0, "No persistence files should exist"
  254. print("--- No Persistence Test - Verified no file created after Turn 1 ---")
  255. # Agency Instance 2 - Turn 2 (No callbacks)
  256. print("\n--- No Persistence Test - Instance 2 - Turn 2 --- Creating Agency 2")
  257. persistence_agent2 = Agent(
  258. name="PersistenceTester",
  259. instructions="Remember the secret code word I tell you. In the next turn, repeat the code word.",
  260. )
  261. agency2 = Agency(persistence_agent2, load_threads_callback=None, save_threads_callback=None)
  262. print(f"--- No Persistence Test - Instance 2 - Turn 2 --- MSG: {message2}")
  263. await agency2.get_response(message=message2)
  264. # Verify the messages in agency2 only contain message2, not message1
  265. messages_in_agency2 = agency2.thread_manager._store.messages
  266. assert messages_in_agency2 is not None
  267. found_message1 = any(
  268. item.get("role") == "user" and message1 in item.get("content", "") for item in messages_in_agency2
  269. )
  270. found_message2 = any(
  271. item.get("role") == "user" and message2 in item.get("content", "") for item in messages_in_agency2
  272. )
  273. assert not found_message1, f"Message '{message1}' (from instance 1) was unexpectedly found in instance 2."
  274. assert found_message2, f"Message '{message2}' not found in instance 2 messages."
  275. print("--- No Persistence Test - Verified message history in instance 2 ---")