conftest.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from unittest.mock import AsyncMock, MagicMock
  2. import pytest
  3. from dotenv import load_dotenv
  4. from agency_swarm import Agency, AgencyContext, Agent
  5. from agency_swarm.utils.thread import ThreadManager
  6. load_dotenv()
  7. @pytest.fixture
  8. def mock_thread_manager():
  9. """Provides a mocked ThreadManager instance with flat message storage."""
  10. manager = MagicMock(spec=ThreadManager)
  11. messages = []
  12. def add_message_side_effect(message):
  13. """Side effect for add_message to append to message list."""
  14. messages.append(message)
  15. def add_messages_side_effect(msgs):
  16. """Side effect for add_messages to extend message list."""
  17. messages.extend(msgs)
  18. def get_conversation_history_side_effect(agent, caller_agent=None):
  19. """Side effect for get_conversation_history to filter messages."""
  20. filtered = []
  21. for msg in messages:
  22. if (msg.get("agent") == agent and msg.get("callerAgent") == caller_agent) or (
  23. msg.get("callerAgent") == agent and msg.get("agent") == caller_agent
  24. ):
  25. filtered.append(msg)
  26. return filtered
  27. def get_all_messages_side_effect():
  28. """Side effect for get_all_messages to return all messages."""
  29. return messages.copy()
  30. def replace_messages_side_effect(new_messages):
  31. """Replace stored messages with the provided collection."""
  32. messages.clear()
  33. messages.extend(new_messages)
  34. manager.add_message.side_effect = add_message_side_effect
  35. manager.add_messages.side_effect = add_messages_side_effect
  36. manager.get_conversation_history.side_effect = get_conversation_history_side_effect
  37. manager.get_all_messages.side_effect = get_all_messages_side_effect
  38. manager.replace_messages.side_effect = replace_messages_side_effect
  39. # Legacy compatibility - these should not be used but may be called
  40. manager.get_thread = MagicMock()
  41. manager.add_item_and_save = MagicMock()
  42. manager.add_items_and_save = MagicMock()
  43. return manager
  44. @pytest.fixture
  45. def mock_agency_instance(mock_thread_manager):
  46. agency = MagicMock()
  47. agency.agents = {}
  48. agency.user_context = {}
  49. agency.thread_manager = mock_thread_manager
  50. return agency
  51. @pytest.fixture
  52. def minimal_agent(mock_thread_manager, mock_agency_instance):
  53. """Provides a minimal Agent instance for basic tests."""
  54. agent = Agent(name="TestAgent", instructions="Test instructions")
  55. # Create an agency and replace its thread manager with our mock
  56. agency = Agency(agent)
  57. agency.thread_manager = mock_thread_manager
  58. # Mock the agent's context creation to always return a context with our mock thread manager
  59. def mock_create_minimal_context():
  60. return AgencyContext(
  61. agency_instance=None,
  62. thread_manager=mock_thread_manager,
  63. subagents={},
  64. load_threads_callback=None,
  65. save_threads_callback=None,
  66. shared_instructions=None,
  67. )
  68. agent._create_minimal_context = mock_create_minimal_context
  69. return agent
  70. @pytest.fixture
  71. def mock_agent():
  72. """Provides a mocked Agent instance for testing."""
  73. agent = MagicMock(spec=Agent)
  74. agent.name = "MockAgent"
  75. agent.get_response = AsyncMock()
  76. # Create a proper async generator mock for get_response_stream
  77. async def default_stream(*args, **kwargs):
  78. yield {"event": "text", "data": "Mock response"}
  79. yield {"event": "done"}
  80. agent.get_response_stream = default_stream
  81. return agent
  82. @pytest.fixture
  83. def mock_agent2():
  84. """Provides a second mocked Agent instance for testing."""
  85. agent = MagicMock(spec=Agent)
  86. agent.name = "MockAgent2"
  87. agent.get_response = AsyncMock()
  88. # Create a proper async generator mock for get_response_stream
  89. async def default_stream(*args, **kwargs):
  90. yield {"event": "text", "data": "Mock response 2"}
  91. yield {"event": "done"}
  92. agent.get_response_stream = default_stream
  93. return agent