test_thread_isolation_persistence.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. """
  2. Thread Isolation Persistence Tests
  3. Tests that thread isolation is maintained across persistence operations
  4. using direct structural verification.
  5. """
  6. import json
  7. import uuid
  8. from pathlib import Path
  9. from typing import Any
  10. import pytest
  11. from agents import ModelSettings
  12. from agency_swarm import Agency, Agent
  13. @pytest.fixture
  14. def ceo_agent_instance():
  15. return Agent(
  16. name="CEO",
  17. description="Chief Executive Officer",
  18. instructions="You are the CEO. Remember information and delegate tasks.",
  19. model_settings=ModelSettings(temperature=0.0),
  20. )
  21. @pytest.fixture
  22. def developer_agent_instance():
  23. return Agent(
  24. name="Developer",
  25. description="Software Developer",
  26. instructions="You are a Developer. Remember technical details.",
  27. model_settings=ModelSettings(temperature=0.0),
  28. )
  29. @pytest.fixture(scope="function")
  30. def temp_persistence_dir(tmp_path):
  31. """Temporary directory for persistence testing."""
  32. yield tmp_path
  33. def file_save_callback(messages: list[dict[str, Any]], base_dir: Path):
  34. """Save flat message list to JSON file."""
  35. file_path = base_dir / "messages.json"
  36. with open(file_path, "w") as f:
  37. json.dump(messages, f, indent=2)
  38. # Also save individual conversation files for backward compatibility
  39. conversations = {}
  40. for msg in messages:
  41. agent = msg.get("agent", "")
  42. caller = msg.get("callerAgent", "user")
  43. thread_id = f"{caller}->{agent}"
  44. if thread_id not in conversations:
  45. conversations[thread_id] = []
  46. conversations[thread_id].append(msg)
  47. for thread_id, msgs in conversations.items():
  48. sanitized_thread_id = thread_id.replace("->", "_to_")
  49. file_path = base_dir / f"{sanitized_thread_id}.json"
  50. with open(file_path, "w") as f:
  51. json.dump({"items": msgs, "metadata": {}}, f, indent=2)
  52. def file_load_callback_all_messages(base_dir: Path) -> list[dict[str, Any]]:
  53. """Load flat message list from JSON file."""
  54. file_path = base_dir / "messages.json"
  55. if file_path.exists():
  56. try:
  57. with open(file_path) as f:
  58. messages = json.load(f)
  59. if isinstance(messages, list):
  60. return messages
  61. except Exception:
  62. pass
  63. # Fall back to loading from individual thread files (migration)
  64. messages = []
  65. for file_path in base_dir.glob("*.json"):
  66. if file_path.name == "messages.json":
  67. continue
  68. try:
  69. with open(file_path) as f:
  70. thread_dict = json.load(f)
  71. if isinstance(thread_dict.get("items"), list):
  72. messages.extend(thread_dict["items"])
  73. except Exception:
  74. continue
  75. return messages
  76. @pytest.fixture
  77. def file_persistence_callbacks(temp_persistence_dir):
  78. """Fixture to provide configured file callbacks."""
  79. def save_cb(messages):
  80. return file_save_callback(messages, temp_persistence_dir)
  81. def load_cb():
  82. return file_load_callback_all_messages(temp_persistence_dir)
  83. return load_cb, save_cb
  84. @pytest.mark.asyncio
  85. async def test_thread_persistence_shared_structural(
  86. file_persistence_callbacks, ceo_agent_instance, developer_agent_instance
  87. ):
  88. """Test that shared user thread is persisted and restored correctly."""
  89. load_cb, save_cb = file_persistence_callbacks
  90. test_id = uuid.uuid4().hex[:8]
  91. print(f"\n--- Thread Persistence Isolation Test {test_id} ---")
  92. # Create agency with persistence
  93. agency = Agency(
  94. ceo_agent_instance,
  95. communication_flows=[ceo_agent_instance > developer_agent_instance],
  96. shared_instructions="Persistence isolation test agency",
  97. load_threads_callback=load_cb,
  98. save_threads_callback=save_cb,
  99. )
  100. # Test data - use unique identifiers for precise verification
  101. ceo_info = f"CEOPROJECT{uuid.uuid4().hex[:8]}"
  102. dev_info = f"DEVPROJECT{uuid.uuid4().hex[:8]}"
  103. # Step 1: Create messages with unique information
  104. await agency.get_response(message=f"CEO project: {ceo_info}", recipient_agent="CEO")
  105. await agency.get_response(message=f"Developer project: {dev_info}", recipient_agent="Developer")
  106. # Step 2: Verify shared user thread before persistence
  107. thread_manager = agency.thread_manager
  108. ceo_messages = thread_manager.get_conversation_history("CEO", None)
  109. dev_messages = thread_manager.get_conversation_history("Developer", None)
  110. assert ceo_messages == dev_messages, "User thread should be shared before persistence"
  111. thread_content = str(ceo_messages).lower()
  112. assert ceo_info.lower() in thread_content, "User thread missing CEO info"
  113. assert dev_info.lower() in thread_content, "User thread missing Developer info"
  114. # Step 3: Verify saved data contains the full shared conversation
  115. all_saved_messages = load_cb()
  116. saved_content = str(all_saved_messages).lower()
  117. assert ceo_info.lower() in saved_content, "Saved data missing CEO info"
  118. assert dev_info.lower() in saved_content, "Saved data missing Developer info"
  119. # Step 4: Verify loaded messages match saved messages
  120. all_loaded_messages = load_cb()
  121. assert all_loaded_messages == all_saved_messages, "Loaded messages should match saved messages"
  122. print("✓ Shared user thread preserved in memory and persistence")
  123. @pytest.mark.asyncio
  124. async def test_persistence_thread_file_separation(
  125. file_persistence_callbacks, ceo_agent_instance, developer_agent_instance
  126. ):
  127. """
  128. Test that different threads are saved as separate files.
  129. Verifies file-level isolation of thread persistence.
  130. """
  131. load_cb, save_cb = file_persistence_callbacks
  132. print("\n--- Persistence File Separation Test ---")
  133. agency = Agency(
  134. ceo_agent_instance,
  135. communication_flows=[ceo_agent_instance > developer_agent_instance],
  136. shared_instructions="File separation test agency",
  137. load_threads_callback=load_cb,
  138. save_threads_callback=save_cb,
  139. )
  140. # Create threads
  141. await agency.get_response(message="CEO message", recipient_agent="CEO")
  142. await agency.get_response(message="Developer message", recipient_agent="Developer")
  143. # Verify messages exist
  144. all_messages = load_cb()
  145. ceo_messages = [msg for msg in all_messages if msg.get("agent") == "CEO" and msg.get("callerAgent") is None]
  146. dev_messages = [msg for msg in all_messages if msg.get("agent") == "Developer" and msg.get("callerAgent") is None]
  147. assert len(ceo_messages) > 0, "CEO messages should exist"
  148. assert len(dev_messages) > 0, "Developer messages should exist"
  149. # Verify content separation
  150. ceo_file_content = str(ceo_messages).lower()
  151. dev_file_content = str(dev_messages).lower()
  152. assert "ceo message" in ceo_file_content, "CEO file missing CEO content"
  153. assert "developer message" not in ceo_file_content, "CEO file contaminated with Developer content"
  154. assert "developer message" in dev_file_content, "Developer file missing Developer content"
  155. assert "ceo message" not in dev_file_content, "Developer file contaminated with CEO content"
  156. print("✓ Each conversation properly tracked")
  157. print("✓ Message-level content isolation verified")