test_fastapi_dry_run.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. """Integration tests for DRY_RUN behavior in FastAPI integration."""
  2. import pytest
  3. pytest.importorskip("fastapi.testclient")
  4. from fastapi.testclient import TestClient
  5. from agency_swarm import Agency, Agent, function_tool, run_fastapi
  6. from agency_swarm.agent.file_manager import AgentFileManager
  7. @pytest.fixture
  8. def agency_factory_with_tool():
  9. """Provide an agency factory that defines a simple FunctionTool on the agent."""
  10. @function_tool
  11. def greet(name: str) -> str:
  12. """Greet a person by name."""
  13. return f"Hello, {name}"
  14. def create_agency(load_threads_callback=None, save_threads_callback=None):
  15. agent = Agent(
  16. name="TestAgent",
  17. instructions="Base",
  18. # Use a normal OpenAI model name here; this test only verifies endpoint
  19. # registration under DRY_RUN and does not invoke the model.
  20. model="gpt-4o-mini",
  21. tools=[greet],
  22. )
  23. return Agency(
  24. agent,
  25. load_threads_callback=load_threads_callback,
  26. save_threads_callback=save_threads_callback,
  27. )
  28. return create_agency
  29. def test_dry_run_metadata_includes_tools(monkeypatch, agency_factory_with_tool):
  30. """When DRY_RUN=1, endpoints are registered (not 404) without side effects."""
  31. # Enable DRY_RUN for the app lifecycle
  32. monkeypatch.setenv("DRY_RUN", "1")
  33. @function_tool
  34. def add_one(x: int) -> int:
  35. """Add one."""
  36. return x + 1
  37. app = run_fastapi(
  38. agencies={"test_agency": agency_factory_with_tool},
  39. tools=[add_one],
  40. return_app=True,
  41. app_token_env="", # disable auth for test
  42. enable_agui=False,
  43. )
  44. client = TestClient(app)
  45. # Metadata endpoint should exist and include tools
  46. res = client.get("/test_agency/get_metadata")
  47. assert res.status_code == 200
  48. data = res.json()
  49. # Verify at least one tool is present in the agent node's data
  50. nodes = data.get("nodes", [])
  51. assert isinstance(nodes, list) and nodes, "Expected nodes in metadata"
  52. # Find the agent node and inspect its tools list
  53. agent_nodes = [n for n in nodes if n.get("id") == "TestAgent"]
  54. assert agent_nodes, "Agent node 'TestAgent' should be present"
  55. agent_node = agent_nodes[0]
  56. tools_list = agent_node.get("data", {}).get("tools", [])
  57. assert isinstance(tools_list, list) and len(tools_list) >= 1, "Expected tools listed for agent in DRY_RUN"
  58. # get_response should be registered under DRY_RUN: 422 means validation ran (route exists), not 404.
  59. res_resp = client.post("/test_agency/get_response", json={})
  60. assert res_resp.status_code == 422
  61. # tool endpoints should be available in DRY_RUN
  62. tool_res = client.post("/tool/add_one", json={"x": 1})
  63. assert tool_res.status_code == 200
  64. assert tool_res.json()["response"] == 2
  65. def test_fastapi_setup_and_metadata_force_dry_run_for_files_folder(monkeypatch, tmp_path):
  66. files = tmp_path / "files"
  67. files.mkdir()
  68. (files / "report.pdf").write_text("report", encoding="utf-8")
  69. (files / "chart.png").write_bytes(b"png")
  70. def record(self):
  71. raise AssertionError("parse_files_folder_for_vs_id should not run during FastAPI setup or metadata")
  72. monkeypatch.delenv("DRY_RUN", raising=False)
  73. monkeypatch.setattr(AgentFileManager, "parse_files_folder_for_vs_id", record)
  74. def create_agency(load_threads_callback=None, save_threads_callback=None):
  75. agent = Agent(name="FileAgent", instructions="Test", files_folder=str(files))
  76. return Agency(
  77. agent,
  78. load_threads_callback=load_threads_callback,
  79. save_threads_callback=save_threads_callback,
  80. )
  81. app = run_fastapi(agencies={"test_agency": create_agency}, return_app=True, app_token_env="")
  82. client = TestClient(app)
  83. res = client.get("/test_agency/get_metadata")
  84. assert res.status_code == 200
  85. payload = res.json()
  86. node = next(n for n in payload["nodes"] if n["id"] == "FileAgent")
  87. assert {"file_search", "code_interpreter"} <= set(node["data"]["capabilities"])