test_litellm_placeholder_ids_integration.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. """
  2. Integration test verifying LiteLLM placeholder IDs are normalized before persistence.
  3. Requires live Anthropic access; skipped automatically when ANTHROPIC_API_KEY is not configured.
  4. """
  5. import importlib
  6. import os
  7. import pytest
  8. from agents.models.fake_id import FAKE_RESPONSES_ID
  9. from agency_swarm import Agency, Agent, ModelSettings, function_tool
  10. from agency_swarm.tools.send_message import Handoff
  11. litellm = pytest.importorskip("litellm")
  12. LitellmModel = importlib.import_module("agents.extensions.models.litellm_model").LitellmModel
  13. pytestmark = pytest.mark.skipif(
  14. not os.getenv("ANTHROPIC_API_KEY"),
  15. reason="ANTHROPIC_API_KEY is required for LiteLLM integration test.",
  16. )
  17. def _build_agency() -> Agency:
  18. @function_tool
  19. def get_user_id(args: str) -> str:
  20. return "User id is 1245725189"
  21. coordinator_agent = Agent(
  22. name="Coordinator",
  23. instructions=(
  24. "You are a coordinator agent. Your job is to receive tasks and delegate them either via "
  25. "When you receive a task, use the `send_message` tool and select 'Worker' as the recipient "
  26. "to ask the Worker agent to perform the task. Always include the full "
  27. "task details in your message. "
  28. "When delegating, only relay the exact task text and never include unrelated user information."
  29. ),
  30. model_settings=ModelSettings(temperature=0.0),
  31. model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
  32. tools=[get_user_id],
  33. )
  34. worker_agent = Agent(
  35. name="Worker",
  36. instructions="You perform tasks.",
  37. model_settings=ModelSettings(temperature=0.0),
  38. model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
  39. )
  40. data_agent = Agent(
  41. name="DataAgent",
  42. instructions="You are a DataAgent that provides information about the user. \
  43. User name is John Doe. User age is 30.",
  44. description="Has information about the user.",
  45. model_settings=ModelSettings(temperature=0.0),
  46. model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
  47. )
  48. return Agency(
  49. coordinator_agent,
  50. worker_agent,
  51. communication_flows=[
  52. (coordinator_agent > data_agent, Handoff),
  53. (worker_agent > data_agent, Handoff),
  54. ],
  55. shared_instructions="Test agency for agent-to-agent persistence verification.",
  56. )
  57. def test_litellm_placeholder_ids_are_not_persisted() -> None:
  58. litellm.modify_params = True
  59. agency = _build_agency()
  60. agency.get_response_sync(message="Say hi to data agent")
  61. agency.get_response_sync(message="Hello")
  62. messages = agency.thread_manager.get_all_messages()
  63. placeholder_items = [msg for msg in messages if msg.get("id") == FAKE_RESPONSES_ID]
  64. assert len(messages) >= 6, "Expected multiple conversation items after two turns"
  65. assert not placeholder_items, "Placeholder IDs should not be persisted after normalization"
  66. function_calls = [msg for msg in messages if msg.get("type") == "function_call"]
  67. function_outputs = [msg for msg in messages if msg.get("type") == "function_call_output"]
  68. function_call_ids = [msg.get("call_id") for msg in function_calls]
  69. output_call_ids = [msg.get("call_id") for msg in function_outputs]
  70. assert len(function_call_ids) >= 1, "Expected at least one function call in persisted history"
  71. assert all(isinstance(call_id, str) and call_id for call_id in function_call_ids)
  72. assert len(function_call_ids) == len(set(function_call_ids))
  73. assert set(function_call_ids) <= set(output_call_ids), "Each function call should have a matching output"
  74. for i, msg in enumerate(messages):
  75. if msg.get("type") != "function_call":
  76. continue
  77. call_id = msg.get("call_id")
  78. output_idx = None
  79. for j in range(i + 1, len(messages)):
  80. if messages[j].get("type") == "function_call_output" and messages[j].get("call_id") == call_id:
  81. output_idx = j
  82. break
  83. assert output_idx is not None, f"Missing function_call_output for call_id={call_id}"
  84. between = messages[i + 1 : output_idx]
  85. assistant_between = [item for item in between if item.get("role") == "assistant"]
  86. assert not assistant_between, "Tool call/results should remain consecutive without assistant inserts"