custom_persistence.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. """
  2. Agency Persistence Example
  3. This example demonstrates how to persist thread data between
  4. different sessions using callback functions.
  5. """
  6. import asyncio
  7. import json
  8. import logging
  9. import os
  10. import shutil
  11. import sys
  12. import tempfile
  13. from pathlib import Path
  14. from typing import Any
  15. from agency_swarm import ModelSettings
  16. logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
  17. script_dir = Path(__file__).parent
  18. project_root = script_dir.parent.parent
  19. if str(project_root) not in sys.path:
  20. sys.path.insert(0, str(project_root / "src"))
  21. from agency_swarm import Agency, Agent # noqa: E402
  22. PERSISTENCE_DIR = Path(tempfile.mkdtemp(prefix="thread_persistence_"))
  23. def save_threads(messages: list[dict[str, Any]]):
  24. """
  25. Save all messages to a file.
  26. Args:
  27. messages: Flat list of all messages with agent/callerAgent metadata.
  28. Each message contains:
  29. - agent: The recipient agent name
  30. - callerAgent: The sender agent name (None for user)
  31. - timestamp: Message timestamp in milliseconds
  32. - Plus all standard OpenAI message fields
  33. Messages from all conversations are stored in a single flat list.
  34. Note: In production, you would typically use a closure to capture chat_id
  35. or use a database with user/session context for saving messages.
  36. """
  37. file_path = PERSISTENCE_DIR / "thread_data.json"
  38. with open(file_path, "w") as f:
  39. json.dump(messages, f, indent=2)
  40. def load_threads(chat_id: str) -> list[dict[str, Any]]:
  41. """
  42. Load all messages from file for a specific chat session.
  43. Args:
  44. chat_id: The chat session identifier to load messages for.
  45. Returns:
  46. Flat list of all messages with agent/callerAgent metadata.
  47. Returns empty list if no data exists.
  48. Note: This demonstrates the correct callback signature where the load_threads
  49. function accepts a chat_id parameter, which is passed via lambda closure.
  50. """
  51. # In this demo, we use a simple file for simplicity, but in production
  52. # you would typically use the chat_id to load session-specific data from a database
  53. file_path = PERSISTENCE_DIR / "thread_data.json"
  54. print(f"Loading messages for chat_id: {chat_id}")
  55. if not file_path.exists():
  56. print("No existing message data file found - starting with empty messages")
  57. return []
  58. with open(file_path) as f:
  59. messages: list[dict[str, Any]] = json.load(f)
  60. return messages
  61. # Initialize all agents and agencies at the top
  62. assistant_agent = Agent(
  63. name="AssistantAgent",
  64. instructions="You are a helpful assistant. Answer questions and help users with their tasks.",
  65. tools=[],
  66. model_settings=ModelSettings(temperature=0.0), # Deterministic responses
  67. )
  68. # Define chat_id for demonstration - in production, this would come from your session management
  69. chat_id = "demo_session"
  70. # --- Create Agency Instance (v1.x Pattern) ---
  71. agency = Agency(
  72. assistant_agent, # AssistantAgent is the entry point (positional argument)
  73. shared_instructions="Be helpful and concise in your responses.",
  74. load_threads_callback=lambda: load_threads(chat_id),
  75. save_threads_callback=lambda messages: save_threads(messages),
  76. )
  77. # Don't create the second agency here - we'll create it after the first run
  78. TEST_INFO = "blue and lucky number is 77"
  79. async def run_persistent_conversation():
  80. """
  81. Demonstrates thread isolation and persistence in Agency Swarm v1.x.
  82. Key concepts demonstrated:
  83. 1. Thread isolation: Each communication flow gets its own thread
  84. 2. Thread identifiers: Follow "sender->recipient" format
  85. 3. Persistence: Complete thread state is saved and restored
  86. 4. Correct callback signatures: load() -> all_threads, save(all_threads) -> None
  87. """
  88. user_message_1 = f"Hello. Please remember that my favorite color is {TEST_INFO}. I'll ask you about it later."
  89. print(f"\n--- Turn 1: --- \nSending message to assistant: {user_message_1}")
  90. response1 = await agency.get_response(message=user_message_1)
  91. print(f"Response from AssistantAgent: {response1.final_output}")
  92. await asyncio.sleep(1)
  93. # Simulate application restart by creating a new agency
  94. print("\n--- Simulating Application Restart ---")
  95. print("Creating new agency instance that will share the same thread...")
  96. # Create a second agent instance for the reloaded agency (to avoid agent reuse)
  97. assistant_agent_reloaded = Agent(
  98. name="AssistantAgent",
  99. instructions="You are a helpful assistant. Answer questions and help users with their tasks.",
  100. tools=[],
  101. model_settings=ModelSettings(temperature=0.0), # Deterministic responses
  102. )
  103. agency_reloaded = Agency(
  104. assistant_agent_reloaded, # Use NEW agent instance to prevent reuse error
  105. shared_instructions="Be helpful and concise in your responses.",
  106. load_threads_callback=lambda: load_threads(chat_id),
  107. save_threads_callback=lambda messages: save_threads(messages),
  108. )
  109. user_message_2 = "What was my favorite color and lucky number I told you earlier?"
  110. print(f"\n--- Turn 2: --- \nSending message to assistant: {user_message_2}")
  111. response2 = await agency_reloaded.get_response(message=user_message_2)
  112. print(f"Response from Reloaded AssistantAgent: {response2.final_output}")
  113. # Test result
  114. if response2.final_output and "blue" in response2.final_output.lower() and "77" in response2.final_output.lower():
  115. print(f"\n✅ SUCCESS: AssistantAgent remembered the information ('{TEST_INFO}')!")
  116. print("Demo completed successfully.")
  117. else:
  118. print(f"\n❌ FAILURE: AssistantAgent did NOT remember the information ('{TEST_INFO}').")
  119. print(f"Agent's response: {response2.final_output}")
  120. # Cleanup
  121. if PERSISTENCE_DIR.exists():
  122. shutil.rmtree(PERSISTENCE_DIR)
  123. print(f"\nTemporary persistence directory {PERSISTENCE_DIR} cleaned up.")
  124. if __name__ == "__main__":
  125. print("\n=== Agency Swarm v1.x Thread Isolation & Persistence Demo ===")
  126. if os.name == "nt":
  127. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  128. asyncio.run(run_persistent_conversation())