test_litellm_models.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
"""
Integration tests for litellm-patched agents (agents whose model is a ``LitellmModel``).

These tests verify that such agents can process user messages and
agent-to-agent messages (handoffs and ``send_message`` calls) without errors.
"""
  6. import importlib
  7. import os
  8. import pytest
  9. from agents import ModelSettings
  10. from agency_swarm import Agency, Agent
  11. from agency_swarm.tools.send_message import Handoff, SendMessage
  12. pytest.importorskip("litellm")
  13. LitellmModel = importlib.import_module("agents.extensions.models.litellm_model").LitellmModel
  14. @pytest.fixture
  15. def coordinator_agent():
  16. return Agent(
  17. name="Coordinator",
  18. instructions=(
  19. "For any user question about the user, call `transfer_to_DataAgent`. Always use the handoff tool to answer."
  20. ),
  21. model_settings=ModelSettings(tool_choice="required"),
  22. model=LitellmModel(model="openai/gpt-5.4-mini", api_key=os.getenv("OPENAI_API_KEY")),
  23. )
  24. @pytest.fixture
  25. def worker_agent():
  26. return Agent(
  27. name="Worker",
  28. instructions=(
  29. "For any user question about the user, use the `send_message` tool to ask DataAgent. "
  30. "Always use the tool to answer."
  31. ),
  32. model_settings=ModelSettings(tool_choice="required"),
  33. model=LitellmModel(model="openai/gpt-5.4-mini", api_key=os.getenv("OPENAI_API_KEY")),
  34. )
  35. @pytest.fixture
  36. def data_agent():
  37. return Agent(
  38. name="DataAgent",
  39. instructions="User name is John Doe. User age is 30. Answer with just the facts.",
  40. description="Has information about the user.",
  41. model=LitellmModel(model="openai/gpt-5.4-mini", api_key=os.getenv("OPENAI_API_KEY")),
  42. )
  43. @pytest.fixture
  44. def coordinator_worker_agency(coordinator_agent, worker_agent, data_agent) -> Agency:
  45. """Agency with coordinator->worker communication flow."""
  46. return Agency(
  47. coordinator_agent,
  48. worker_agent,
  49. communication_flows=[
  50. (coordinator_agent > data_agent, Handoff),
  51. (worker_agent > data_agent, SendMessage),
  52. ],
  53. shared_instructions="Test agency for agent-to-agent persistence verification.",
  54. )
  55. class TestLitellmModels:
  56. """Test suite for agent-to-agent conversation persistence."""
  57. @pytest.mark.asyncio
  58. async def test_agent_to_agent_messages(self, coordinator_worker_agency: Agency, worker_agent: Agent):
  59. """
  60. Verify handoff communication works with litellm-patched agents.
  61. Coordinator uses transfer_to_DataAgent and Worker uses send_message.
  62. """
  63. worker_response = await coordinator_worker_agency.get_response(
  64. message="What is my name and age?",
  65. recipient_agent="Worker",
  66. )
  67. processed_response = str(worker_response.final_output).lower()
  68. assert "john" in processed_response and "doe" in processed_response, "Response should contain the user's name"
  69. assert "30" in processed_response or "thirty" in processed_response, "Response should contain the user's age"
  70. coordinator_response = await coordinator_worker_agency.get_response(
  71. message="What is my name and age?",
  72. recipient_agent="Coordinator",
  73. )
  74. processed_response = str(coordinator_response.final_output).lower()
  75. assert "john" in processed_response and "doe" in processed_response, "Response should contain the user's name"
  76. assert "30" in processed_response or "thirty" in processed_response, "Response should contain the user's age"
  77. # Verify conversation history was created for both paths
  78. handoff_messages = coordinator_worker_agency.thread_manager.get_conversation_history("DataAgent", None)
  79. send_message_messages = coordinator_worker_agency.thread_manager.get_conversation_history(
  80. "DataAgent",
  81. "Worker",
  82. )
  83. handoff_data_agent_messages = [msg for msg in handoff_messages if msg.get("agent") == "DataAgent"]
  84. send_message_data_agent_messages = [msg for msg in send_message_messages if msg.get("agent") == "DataAgent"]
  85. assert len(handoff_data_agent_messages) > 0, "DataAgent should have messages after handoff"
  86. assert len(send_message_data_agent_messages) > 0, "DataAgent should have messages after send_message"
  87. # Verify tool calls were created by both agents
  88. all_messages = coordinator_worker_agency.thread_manager.get_all_messages()
  89. function_calls = [msg for msg in all_messages if msg.get("type") == "function_call"]
  90. worker_send_messages = [
  91. msg for msg in function_calls if msg.get("agent") == "Worker" and msg.get("name") == "send_message"
  92. ]
  93. coordinator_handoffs = [
  94. msg for msg in function_calls if msg.get("agent") == "Coordinator" and "transfer_to_" in msg.get("name", "")
  95. ]
  96. assert len(worker_send_messages) > 0, "Worker should have at least one send_message call"
  97. assert len(coordinator_handoffs) > 0, "Coordinator should have at least one handoff"