test_override_policy.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. """Unit tests for FastAPI request override policy helpers."""
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from types import SimpleNamespace
  5. from typing import cast
  6. from openai import AsyncOpenAI
  7. from agency_swarm.integrations.fastapi_utils.override_policy import (
  8. RequestOverridePolicy,
  9. get_allowed_dirs_for_metadata,
  10. )
  11. from agency_swarm.integrations.fastapi_utils.request_models import ClientConfig
  12. def test_request_override_policy_flags() -> None:
  13. policy = RequestOverridePolicy(ClientConfig(default_headers={"x-request-id": "req-1"}))
  14. assert policy.has_client_overrides is True
  15. assert policy.has_openai_overrides is True
  16. model_only = RequestOverridePolicy(ClientConfig(model="gpt-4o-mini"))
  17. assert model_only.has_client_overrides is True
  18. assert model_only.has_openai_overrides is False
  19. litellm_cfg = cast(
  20. ClientConfig,
  21. SimpleNamespace(
  22. base_url=None,
  23. api_key=None,
  24. default_headers=None,
  25. litellm_keys={"anthropic": "sk-ant"},
  26. ),
  27. )
  28. litellm_only = RequestOverridePolicy(litellm_cfg)
  29. assert litellm_only.has_client_overrides is True
  30. assert litellm_only.has_openai_overrides is False
  31. empty = RequestOverridePolicy(None)
  32. assert empty.has_client_overrides is False
  33. assert empty.has_openai_overrides is False
  34. def test_get_allowed_dirs_for_metadata_returns_absolute_paths(tmp_path) -> None:
  35. allowed = tmp_path / "uploads"
  36. allowed.mkdir(parents=True, exist_ok=True)
  37. file_entry = tmp_path / "not-a-dir.txt"
  38. file_entry.write_text("x", encoding="utf-8")
  39. missing_entry = tmp_path / "missing"
  40. tilde_entry = Path("~") / "custom"
  41. visible = get_allowed_dirs_for_metadata(
  42. [
  43. str(allowed),
  44. str(file_entry),
  45. str(missing_entry),
  46. tilde_entry,
  47. ]
  48. )
  49. assert visible == [
  50. str(allowed.expanduser().resolve()),
  51. str(file_entry.expanduser().resolve()),
  52. str(missing_entry.expanduser().resolve()),
  53. str(tilde_entry.expanduser().resolve()),
  54. ]
  55. assert all(Path(p).is_absolute() for p in visible)
  56. def test_build_file_upload_client_uses_selected_agent_client() -> None:
  57. model = SimpleNamespace(
  58. openai_client=AsyncOpenAI(
  59. api_key="sk-agent",
  60. base_url="https://api.agent.test/v1",
  61. default_headers={"x-agency-id": "agency-1"},
  62. )
  63. )
  64. agent = SimpleNamespace(model=model)
  65. agency = SimpleNamespace(
  66. agents={"Recipient": agent},
  67. entry_points=[SimpleNamespace(name="Recipient")],
  68. )
  69. policy = RequestOverridePolicy(ClientConfig(default_headers={"x-request-id": "req-1"}))
  70. client = policy.build_file_upload_client(agency, recipient_agent="Recipient")
  71. assert client is not None
  72. assert client.api_key == "sk-agent"
  73. headers = dict(client.default_headers or {})
  74. assert headers["x-agency-id"] == "agency-1"
  75. assert headers["x-request-id"] == "req-1"
  76. def test_build_file_upload_client_headers_only_without_baseline_returns_none(monkeypatch) -> None:
  77. agency = SimpleNamespace(
  78. agents={"A": SimpleNamespace(model="gpt-4o-mini")},
  79. entry_points=[SimpleNamespace(name="A")],
  80. )
  81. policy = RequestOverridePolicy(ClientConfig(default_headers={"x-request-id": "req-1"}))
  82. monkeypatch.setattr(
  83. "agency_swarm.integrations.fastapi_utils.override_policy.get_default_openai_client", lambda: None
  84. )
  85. client = policy.build_file_upload_client(agency, recipient_agent="A")
  86. assert client is None