| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422 |
- import pickle
- import pytest
- from agents.models.fake_id import FAKE_RESPONSES_ID
- from agency_swarm.utils.thread import ThreadManager
def test_thread_manager_initialization():
    """A default-constructed ThreadManager has no messages and no callbacks."""
    fresh = ThreadManager()

    # Nothing has been added yet, so the backing store must be empty.
    stored = fresh._store.messages
    assert len(stored) == 0

    # No callbacks were passed to the constructor.
    assert fresh._load_threads_callback is None
    assert fresh._save_threads_callback is None
@pytest.mark.parametrize(
    ("method", "messages"),
    [
        (
            "add_message",
            [
                dict(
                    role="user",
                    content="Hello",
                    agent="Agent1",
                    callerAgent=None,
                    timestamp=1234567890000,
                ),
            ],
        ),
        (
            "add_messages",
            [
                dict(
                    role="user",
                    content="Hello",
                    agent="Agent1",
                    callerAgent=None,
                    timestamp=1234567890000,
                ),
                dict(
                    role="assistant",
                    content="Hi there",
                    agent="Agent1",
                    callerAgent=None,
                    timestamp=1234567891000,
                ),
            ],
        ),
    ],
)
def test_add_messages(method: str, messages: list[dict]):
    """Messages added one at a time or as a batch end up in the store unchanged.

    ``method`` names the ThreadManager method under test; ``messages`` is the
    list the store is expected to contain afterwards.
    """
    manager = ThreadManager()
    add = getattr(manager, method)

    # The single-message API takes one dict; the batch API takes the whole list.
    if method == "add_message":
        add(messages[0])
    else:
        add(messages)

    stored = manager._store.messages
    assert len(stored) == len(messages)
    assert stored == messages
def test_duplicate_message_id_is_preserved():
    """Two messages sharing one id are both kept, in insertion order."""
    manager = ThreadManager()

    # Fields common to both entries; each entry then adds its own payload.
    shared = {"type": "message", "id": "msg-1", "role": "assistant"}
    initial = {**shared, "content": None, "timestamp": 1}
    updated = {
        **shared,
        "tool_calls": [
            {
                "id": "call-1",
                "type": "function",
                "function": {"name": "do_work", "arguments": "{}"},
            }
        ],
        "timestamp": 2,
    }

    for entry in (initial, updated):
        manager.add_message(entry)

    assert manager._store.messages == [initial, updated]
def test_function_call_output_duplicates_are_preserved():
    """Repeated function_call_output items for one call id are all retained."""
    manager = ThreadManager()

    # Same type and call_id, differing only in output text and timestamp.
    first_output, final_output = (
        {
            "type": "function_call_output",
            "call_id": "call-1",
            "output": text,
            "timestamp": stamp,
        }
        for stamp, text in enumerate(("intermediate", "final"), start=1)
    )

    manager.add_message(first_output)
    manager.add_message(final_output)

    expected = [first_output, final_output]
    assert manager._store.messages == expected
def test_function_call_output_unique_ids_are_preserved():
    """Outputs with different message ids but one call id are both appended."""
    manager = ThreadManager()

    first_output = dict(
        type="function_call_output",
        id="msg-1",
        call_id="call-1",
        output="placeholder",
        timestamp=1,
    )
    final_output = dict(
        type="function_call_output",
        id="msg-2",
        call_id="call-1",
        output="final",
        timestamp=2,
    )

    for item in (first_output, final_output):
        manager.add_message(item)

    assert manager._store.messages == [first_output, final_output]
def test_placeholder_messages_are_not_deduped():
    """Messages that share the FAKE_RESPONSES_ID placeholder id are all stored."""
    manager = ThreadManager()

    # Both messages carry the same placeholder id; only content/timestamp differ.
    first, second = (
        {
            "type": "message",
            "id": FAKE_RESPONSES_ID,
            "role": "assistant",
            "content": text,
            "timestamp": stamp,
        }
        for stamp, text in enumerate(("initial", "follow-up"), start=1)
    )

    manager.add_message(first)
    manager.add_message(second)

    assert manager._store.messages == [first, second]
def test_placeholder_tool_messages_preserve_prior_calls():
    """Earlier call/output pairs remain stored when later pairs reuse the placeholder id."""

    def make_call(call_id, timestamp):
        # A function_call item whose id is the shared placeholder.
        return {
            "type": "function_call",
            "id": FAKE_RESPONSES_ID,
            "call_id": call_id,
            "role": "assistant",
            "timestamp": timestamp,
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {"name": "get_user_id", "arguments": "{}"},
                }
            ],
        }

    def make_output(call_id, output, timestamp):
        # The matching function_call_output item, also using the placeholder id.
        return {
            "type": "function_call_output",
            "id": FAKE_RESPONSES_ID,
            "call_id": call_id,
            "output": output,
            "timestamp": timestamp,
        }

    manager = ThreadManager()

    # Add order: first call, first output, second call, second output.
    sequence = (
        make_call("call-1", 1),
        make_output("call-1", "User id is 1245725189", 2),
        make_call("call-2", 3),
        make_output("call-2", "Done", 4),
    )
    for item in sequence:
        manager.add_message(item)

    call_ids = {
        entry["call_id"]
        for entry in manager._store.messages
        if entry.get("type") == "function_call"
    }
    output_ids = {
        entry["call_id"]
        for entry in manager._store.messages
        if entry.get("type") == "function_call_output"
    }

    assert call_ids == {"call-1", "call-2"}
    assert output_ids == {"call-1", "call-2"}
def test_user_thread_shared_across_agents():
    """Every entry-point agent sees the same user-thread history."""
    manager = ThreadManager()

    # (role, content, agent, timestamp) for each turn; callerAgent is always None.
    turns = (
        ("user", "Hello Agent1", "Agent1", 1234567890000),
        ("assistant", "Hi user", "Agent1", 1234567891000),
        ("user", "Hello Agent2", "Agent2", 1234567892000),
        ("assistant", "Hi from Agent2", "Agent2", 1234567893000),
    )
    messages = [
        {
            "role": role,
            "content": content,
            "agent": agent,
            "callerAgent": None,
            "timestamp": timestamp,
        }
        for role, content, agent, timestamp in turns
    ]
    manager.add_messages(messages)

    # Query the history from each agent's point of view with no caller agent.
    agent1_history, agent2_history = (
        manager.get_conversation_history(agent_name, None)
        for agent_name in ("Agent1", "Agent2")
    )

    # Both views must be the full combined conversation, and match each other.
    assert agent1_history == messages
    assert agent2_history == messages
    assert agent1_history == agent2_history
def test_get_all_messages():
    """get_all_messages returns every stored message as an independent list."""
    manager = ThreadManager()
    messages = [
        dict(
            role="user",
            content="Message 1",
            agent="Agent1",
            callerAgent=None,
            timestamp=1234567890000,
        ),
        dict(
            role="assistant",
            content="Response 1",
            agent="Agent1",
            callerAgent=None,
            timestamp=1234567891000,
        ),
    ]
    manager.add_messages(messages)

    snapshot = manager.get_all_messages()
    assert snapshot == messages

    # Mutating the returned list must not leak into the manager's own store.
    snapshot.append({"role": "user", "content": "Extra"})
    assert len(manager._store.messages) == 2
def test_save_callback_triggered_on_add(mocker):
    """Adding one message invokes the save callback exactly once with that message."""
    save_spy = mocker.MagicMock()
    manager = ThreadManager(save_threads_callback=save_spy)

    message = dict(
        role="user",
        content="Test",
        agent="Agent1",
        callerAgent=None,
        timestamp=1234567890000,
    )
    manager.add_message(message)

    # The callback receives the full message list, here a single entry.
    save_spy.assert_called_once_with([message])
def test_clear_persists_empty_message_store():
    """`clear()` pushes the now-empty message list through the save callback."""
    captured: list[list[dict[str, object]]] = []

    def record(msgs):
        # Snapshot the list so later mutations cannot alter what was captured.
        captured.append(list(msgs))

    manager = ThreadManager(save_threads_callback=record)
    manager.add_message({"role": "user", "content": "seed"})

    # Drop whatever the seeding step recorded; only clear() is under test.
    captured.clear()
    manager.clear()

    assert captured == [[]]
    assert manager.get_all_messages() == []
def test_load_callback_on_init(mocker):
    """Constructing a ThreadManager calls the load callback once and stores its result."""
    loaded_messages = [
        dict(
            role="user",
            content="Loaded message",
            agent="Agent1",
            callerAgent=None,
            timestamp=1234567890000,
        ),
    ]
    load_spy = mocker.MagicMock(return_value=loaded_messages)

    manager = ThreadManager(load_threads_callback=load_spy)

    load_spy.assert_called_once()
    assert manager._store.messages == loaded_messages
def test_thread_manager_pickleable():
    """A callback-free ThreadManager survives a pickle round trip with its messages."""
    # No callbacks are configured here: callbacks aren't pickleable.
    original = ThreadManager()
    messages = [
        dict(
            role="user",
            content="Test message",
            agent="Agent1",
            callerAgent=None,
            timestamp=1234567890000,
        ),
    ]
    original.add_messages(messages)

    # Serialize and immediately restore (data is produced locally, so loads is safe).
    restored = pickle.loads(pickle.dumps(original))

    assert isinstance(restored, ThreadManager)
    assert restored._store.messages == messages
def test_replace_messages_skips_save_callback():
    """replace_messages swaps the stored messages without invoking the save callback."""
    captured: list[list[dict[str, object]]] = []

    def record(msgs):
        # Snapshot each list handed to the save callback.
        captured.append(list(msgs))

    manager = ThreadManager(save_threads_callback=record)
    manager.add_message({"role": "user", "content": "seed"})

    # Ignore anything recorded while seeding; only the replace call matters.
    captured.clear()
    manager.replace_messages([{"role": "assistant", "content": "new"}])

    assert captured == []
    contents = [entry["content"] for entry in manager.get_all_messages()]
    assert contents == ["new"]
def test_thread_manager_allows_duplicate_ids_by_design():
    """Items that differ only by timestamp are both stored (SDK handles dedupe)."""
    manager = ThreadManager()

    # Identical id, type, call_id and role; only the timestamp changes.
    first, second = (
        {
            "id": "msg-1",
            "type": "function_call",
            "call_id": "call-1",
            "role": "assistant",
            "timestamp": stamp,
        }
        for stamp in (1, 2)
    )

    manager.add_message(first)
    manager.add_message(second)

    assert manager._store.messages == [first, second]
def test_function_call_output_with_same_id_different_call_ids_should_not_dedupe():
    """Outputs sharing a message id but with different call ids are both kept."""
    manager = ThreadManager()

    first_output = dict(
        type="function_call_output",
        id="msg-1",
        call_id="call-1",
        output="first tool result",
        timestamp=1,
    )
    second_output = dict(
        type="function_call_output",
        id="msg-1",
        call_id="call-2",
        output="second tool result",
        timestamp=2,
    )
    for item in (first_output, second_output):
        manager.add_message(item)

    outputs = [
        entry
        for entry in manager._store.messages
        if entry.get("type") == "function_call_output"
    ]
    seen_call_ids = {entry["call_id"] for entry in outputs}

    assert len(outputs) == 2
    assert seen_call_ids == {"call-1", "call-2"}
def test_messages_with_none_type_and_same_id_should_not_dedupe():
    """Two messages with ``type`` set to None and the same id are both stored."""
    manager = ThreadManager()

    # Same id and an explicit None type; only content and timestamp differ.
    first_message, second_message = (
        {
            "type": None,
            "id": "msg-1",
            "role": "assistant",
            "content": text,
            "timestamp": stamp,
        }
        for stamp, text in enumerate(("first", "second"), start=1)
    )

    manager.add_message(first_message)
    manager.add_message(second_message)

    stored_count = len(manager._store.messages)
    assert stored_count == 2
|