test_litellm_anthropic_nonstreaming.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """
  2. Non-streaming version of Anthropic message ordering test.
  3. Verifies correct message ordering in non-streaming mode.
  4. """
  5. import importlib
  6. import os
  7. import pytest
  8. from agents import ModelSettings
  9. from agents.models.fake_id import FAKE_RESPONSES_ID
  10. from agency_swarm import Agency, Agent, function_tool
  11. from agency_swarm.tools.send_message import Handoff
  12. litellm = pytest.importorskip("litellm")
  13. LitellmModel = importlib.import_module("agents.extensions.models.litellm_model").LitellmModel
  14. pytestmark = pytest.mark.skipif(
  15. not os.getenv("ANTHROPIC_API_KEY"),
  16. reason="ANTHROPIC_API_KEY required for Anthropic streaming test.",
  17. )
  18. @function_tool
  19. def get_user_id(args: str) -> str:
  20. """Returns user ID for testing."""
  21. return "User id is 1245725189"
  22. def _assert_valid_tool_call_pairs(messages: list[dict[str, object]]) -> None:
  23. function_calls = [msg for msg in messages if msg.get("type") == "function_call"]
  24. function_outputs = [msg for msg in messages if msg.get("type") == "function_call_output"]
  25. call_ids = [msg.get("call_id") for msg in function_calls]
  26. assert all(isinstance(call_id, str) and call_id for call_id in call_ids)
  27. assert len(call_ids) == len(set(call_ids))
  28. output_call_ids = [msg.get("call_id") for msg in function_outputs]
  29. assert all(isinstance(call_id, str) and call_id for call_id in output_call_ids)
  30. assert set(call_ids) <= set(output_call_ids)
  31. placeholder_items = [msg for msg in messages if msg.get("id") == FAKE_RESPONSES_ID]
  32. assert not placeholder_items, "Placeholder IDs should not persist in Anthropic/LiteLLM history"
  33. @pytest.fixture(scope="function")
  34. def litellm_anthropic_agency():
  35. coordinator = Agent(
  36. name="Coordinator",
  37. instructions="You are a coordinator agent.",
  38. model_settings=ModelSettings(temperature=0.0),
  39. model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
  40. tools=[get_user_id],
  41. )
  42. worker = Agent(
  43. name="Worker",
  44. instructions="You perform tasks.",
  45. model_settings=ModelSettings(temperature=0.0),
  46. model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
  47. )
  48. return Agency(
  49. coordinator,
  50. worker,
  51. communication_flows=[(coordinator > worker, Handoff)],
  52. shared_instructions="Test agency",
  53. )
  54. class TestLitellmAnthropicNonStreamingMessageOrdering:
  55. """Verify no intermediate assistant messages persist during tool execution (non-streaming mode)."""
  56. @pytest.mark.asyncio
  57. async def test_tool_usage_no_intermediate_messages(self, litellm_anthropic_agency: Agency):
  58. """Verify tool usage preserves correct message sequence in non-streaming mode."""
  59. litellm.modify_params = True
  60. # First turn with tool usage
  61. await litellm_anthropic_agency.get_response(message="get my id")
  62. # Verify message structure
  63. messages = litellm_anthropic_agency.thread_manager.get_all_messages()
  64. _assert_valid_tool_call_pairs(messages)
  65. # Find all function_call and function_call_output pairs
  66. for i, msg in enumerate(messages):
  67. if msg.get("type") == "function_call":
  68. # Find corresponding function_call_output
  69. call_id = msg.get("call_id")
  70. output_idx = None
  71. for j in range(i + 1, len(messages)):
  72. if messages[j].get("type") == "function_call_output" and messages[j].get("call_id") == call_id:
  73. output_idx = j
  74. break
  75. assert output_idx is not None, f"No function_call_output found for call_id {call_id}"
  76. # Check messages between function_call and function_call_output
  77. between = messages[i + 1 : output_idx]
  78. assistant_msgs = [m for m in between if m.get("role") == "assistant"]
  79. assert not assistant_msgs, (
  80. f"Found {len(assistant_msgs)} intermediate assistant message(s) "
  81. f"between function_call and function_call_output. This violates "
  82. f"Anthropic's requirement for consecutive tool_use/tool_result pairs."
  83. )
  84. # Second turn should succeed
  85. await litellm_anthropic_agency.get_response(message="hi")
  86. @pytest.mark.asyncio
  87. async def test_handoff_no_intermediate_messages(self, litellm_anthropic_agency: Agency):
  88. """Verify handoff preserves correct message sequence in non-streaming mode."""
  89. litellm.modify_params = True
  90. # First turn with handoff
  91. await litellm_anthropic_agency.get_response(message="transfer to worker", recipient_agent="Coordinator")
  92. # Verify no intermediate assistant messages between tool calls and outputs
  93. messages = litellm_anthropic_agency.thread_manager.get_all_messages()
  94. _assert_valid_tool_call_pairs(messages)
  95. for i, msg in enumerate(messages):
  96. if msg.get("type") == "function_call":
  97. call_id = msg.get("call_id")
  98. output_idx = None
  99. for j in range(i + 1, len(messages)):
  100. if messages[j].get("type") == "function_call_output" and messages[j].get("call_id") == call_id:
  101. output_idx = j
  102. break
  103. if output_idx is not None:
  104. between = messages[i + 1 : output_idx]
  105. assistant_msgs = [m for m in between if m.get("role") == "assistant"]
  106. assert not assistant_msgs, (
  107. "Found intermediate assistant message(s) during handoff "
  108. "that would violate Anthropic API requirements."
  109. )
  110. # Second turn should succeed
  111. await litellm_anthropic_agency.get_response(message="hi")