test_utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. """
  2. Tests for multimodal helper utilities.
  3. """
  4. import base64
  5. from pathlib import Path
  6. from unittest.mock import Mock
  7. import httpx
  8. import pytest
  9. from agency_swarm import ToolOutputFileContent, ToolOutputImage
  10. from agency_swarm.tools.utils import (
  11. tool_output_file_from_path,
  12. tool_output_file_from_url,
  13. tool_output_image_from_path,
  14. )
  15. def _write_png(tmp_path: Path) -> Path:
  16. """Write a 1x1 PNG to disk for image helper tests."""
  17. png_bytes = base64.b64decode(
  18. "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8Xw8AAhABgAP7R7AAAAAASUVORK5CYII="
  19. )
  20. image_path = tmp_path / "pixel.png"
  21. image_path.write_bytes(png_bytes)
  22. return image_path
  23. def test_tool_output_image_from_path_returns_data_url(tmp_path):
  24. image_path = _write_png(tmp_path)
  25. result = tool_output_image_from_path(image_path, detail="high")
  26. assert isinstance(result, ToolOutputImage)
  27. assert result.detail == "high"
  28. assert result.image_url.startswith("data:image/png;base64,")
  29. encoded = result.image_url.split(",", 1)[1]
  30. assert base64.b64decode(encoded) == image_path.read_bytes()
  31. def test_tool_output_image_from_path_rejects_unknown_type(tmp_path):
  32. image_path = tmp_path / "pixel"
  33. image_path.write_bytes(_write_png(tmp_path).read_bytes())
  34. with pytest.raises(ValueError, match="Unable to determine MIME type"):
  35. tool_output_image_from_path(image_path)
  36. def test_tool_output_file_from_path_embeds_file_data(tmp_path):
  37. file_path = tmp_path / "document.pdf"
  38. file_path.write_text("sample pdf content", encoding="utf-8")
  39. result = tool_output_file_from_path(file_path)
  40. assert isinstance(result, ToolOutputFileContent)
  41. assert result.filename == "document.pdf"
  42. assert result.file_data is not None
  43. assert result.file_data.startswith("data:application/pdf;base64,")
  44. encoded = result.file_data.split(",", 1)[1]
  45. assert base64.b64decode(encoded.encode("utf-8")).decode("utf-8") == "sample pdf content"
  46. def test_tool_output_file_from_path_rejects_non_pdf(tmp_path):
  47. file_path = tmp_path / "document.txt"
  48. file_path.write_text("not a pdf", encoding="utf-8")
  49. with pytest.raises(ValueError, match="Only PDF files are supported."):
  50. tool_output_file_from_path(file_path)
  51. def test_tool_output_file_from_url_returns_remote_reference():
  52. result = tool_output_file_from_url("https://example.com/archive.zip")
  53. assert isinstance(result, ToolOutputFileContent)
  54. assert result.file_url == "https://example.com/archive.zip"
  55. assert result.filename is None
  56. def test_tool_output_file_from_url_keeps_remote_pdf_when_content_type_is_pdf(monkeypatch):
  57. def _fake_head(url: str, *, follow_redirects: bool, timeout: float) -> httpx.Response:
  58. assert url == "https://1.1.1.1/doc.pdf"
  59. request = httpx.Request("HEAD", url)
  60. return httpx.Response(200, headers={"content-type": "application/pdf"}, request=request)
  61. monkeypatch.setattr("agency_swarm.tools.utils.httpx.head", _fake_head)
  62. result = tool_output_file_from_url("https://1.1.1.1/doc.pdf")
  63. assert result.file_url == "https://1.1.1.1/doc.pdf"
  64. assert result.file_data is None
  65. assert result.filename is None
  66. def test_tool_output_file_from_url_falls_back_to_data_url_for_pdf_served_as_octet_stream(monkeypatch):
  67. pdf_bytes = b"%PDF-1.4 test-pdf-data"
  68. def _fake_head(url: str, *, follow_redirects: bool, timeout: float) -> httpx.Response:
  69. assert url == "https://1.1.1.1/doc.pdf"
  70. request = httpx.Request("HEAD", url)
  71. return httpx.Response(200, headers={"content-type": "application/octet-stream"}, request=request)
  72. class _StreamResponse:
  73. status_code = 200
  74. headers: dict[str, str] = {}
  75. def __init__(self) -> None:
  76. self.request = httpx.Request("GET", "https://1.1.1.1/doc.pdf")
  77. def __enter__(self) -> "_StreamResponse":
  78. return self
  79. def __exit__(self, exc_type, exc, tb) -> bool:
  80. return False
  81. def raise_for_status(self) -> None:
  82. return None
  83. def iter_bytes(self):
  84. yield pdf_bytes
  85. def _fake_stream(method: str, url: str, *, follow_redirects: bool, timeout: float) -> _StreamResponse:
  86. assert method == "GET"
  87. assert url == "https://1.1.1.1/doc.pdf"
  88. return _StreamResponse()
  89. monkeypatch.setattr("agency_swarm.tools.utils.httpx.head", _fake_head)
  90. monkeypatch.setattr("agency_swarm.tools.utils.httpx.stream", _fake_stream)
  91. result = tool_output_file_from_url("https://1.1.1.1/doc.pdf")
  92. assert result.file_url is None
  93. assert result.filename == "doc.pdf"
  94. assert result.file_data is not None
  95. assert result.file_data.startswith("data:application/pdf;base64,")
  96. encoded = result.file_data.split(",", 1)[1]
  97. assert base64.b64decode(encoded) == pdf_bytes
  98. def test_tool_output_file_from_url_skips_local_fetch_for_unsafe_host(monkeypatch):
  99. head_mock = Mock()
  100. monkeypatch.setattr("agency_swarm.tools.utils.httpx.head", head_mock)
  101. result = tool_output_file_from_url("http://127.0.0.1/doc.pdf")
  102. assert result.file_url == "http://127.0.0.1/doc.pdf"
  103. assert result.file_data is None
  104. head_mock.assert_not_called()
  105. def test_tool_output_file_from_url_falls_back_to_file_url_when_pdf_exceeds_inline_limit(monkeypatch):
  106. def _fake_head(url: str, *, follow_redirects: bool, timeout: float) -> httpx.Response:
  107. request = httpx.Request("HEAD", url)
  108. return httpx.Response(200, headers={"content-type": "application/octet-stream"}, request=request)
  109. class _StreamResponse:
  110. status_code = 200
  111. headers: dict[str, str] = {}
  112. def __init__(self) -> None:
  113. self.request = httpx.Request("GET", "https://1.1.1.1/doc.pdf")
  114. def __enter__(self) -> "_StreamResponse":
  115. return self
  116. def __exit__(self, exc_type, exc, tb) -> bool:
  117. return False
  118. def raise_for_status(self) -> None:
  119. return None
  120. def iter_bytes(self):
  121. yield b"12345"
  122. yield b"67890"
  123. def _fake_stream(method: str, url: str, *, follow_redirects: bool, timeout: float) -> _StreamResponse:
  124. return _StreamResponse()
  125. monkeypatch.setattr("agency_swarm.tools.utils.httpx.head", _fake_head)
  126. monkeypatch.setattr("agency_swarm.tools.utils.httpx.stream", _fake_stream)
  127. monkeypatch.setattr("agency_swarm.tools.utils.MAX_INLINE_PDF_BYTES", 6)
  128. result = tool_output_file_from_url("https://1.1.1.1/doc.pdf")
  129. assert result.file_url == "https://1.1.1.1/doc.pdf"
  130. assert result.file_data is None
  131. def test_tool_output_file_from_url_preserves_file_url_for_invalid_port():
  132. result = tool_output_file_from_url("https://example.com:abc/doc.pdf")
  133. assert result.file_url == "https://example.com:abc/doc.pdf"
  134. assert result.file_data is None
  135. def test_tool_output_file_from_url_preserves_file_url_for_invalid_ipv6_host():
  136. result = tool_output_file_from_url("https://[::1/doc.pdf")
  137. assert result.file_url == "https://[::1/doc.pdf"
  138. assert result.file_data is None
  139. def test_tool_output_file_from_url_blocks_unsafe_redirect_targets(monkeypatch):
  140. def _fake_head(url: str, *, follow_redirects: bool, timeout: float) -> httpx.Response:
  141. request = httpx.Request("HEAD", url)
  142. return httpx.Response(200, headers={"content-type": "application/octet-stream"}, request=request)
  143. class _StreamResponse:
  144. status_code = 302
  145. def __init__(self) -> None:
  146. self.request = httpx.Request("GET", "https://1.1.1.1/doc.pdf")
  147. self.headers = {"location": "http://127.0.0.1/secret.pdf"}
  148. def __enter__(self) -> "_StreamResponse":
  149. return self
  150. def __exit__(self, exc_type, exc, tb) -> bool:
  151. return False
  152. def raise_for_status(self) -> None:
  153. return None
  154. def iter_bytes(self):
  155. if False:
  156. yield b""
  157. def _fake_stream(method: str, url: str, *, follow_redirects: bool, timeout: float) -> _StreamResponse:
  158. return _StreamResponse()
  159. monkeypatch.setattr("agency_swarm.tools.utils.httpx.head", _fake_head)
  160. monkeypatch.setattr("agency_swarm.tools.utils.httpx.stream", _fake_stream)
  161. result = tool_output_file_from_url("https://1.1.1.1/doc.pdf")
  162. assert result.file_url == "https://1.1.1.1/doc.pdf"
  163. assert result.file_data is None