| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- """Unit tests for fastapi_utils download_file behavior."""
- import asyncio
- from pathlib import Path
- from unittest.mock import AsyncMock, MagicMock
- import pytest
@pytest.mark.asyncio
async def test_download_file_cleans_up_tmp_on_http_error(tmp_path: Path) -> None:
    """Temp file must be deleted and no fd must be leaked when the HTTP request fails.

    ``httpx.AsyncClient`` (as seen through the ``file_handler`` module) is
    replaced by a mock whose response raises ``Exception("HTTP 500")`` from
    ``raise_for_status``.  The test then checks that:

    * ``download_file`` propagates that exception,
    * no ``*.tmp`` file is left in ``tmp_path``,
    * no descriptor in the range 3..1023 is open on a file that still
      exists in ``tmp_path``.
    """
    import gc
    from unittest.mock import patch

    import agency_swarm.integrations.fastapi_utils.file_handler as fh

    # Response whose status check fails; the body iterator is a bare mock.
    response_obj = MagicMock()
    response_obj.raise_for_status = MagicMock(side_effect=Exception("HTTP 500"))
    response_obj.aiter_bytes = MagicMock()

    # Async context manager returned by `client.stream(...)`.
    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
    stream_cm.__aexit__ = AsyncMock(return_value=False)  # do not swallow errors

    client_obj = MagicMock()
    client_obj.stream = MagicMock(return_value=stream_cm)

    # Async context manager returned by `httpx.AsyncClient(...)`.
    client_cm = MagicMock()
    client_cm.__aenter__ = AsyncMock(return_value=client_obj)
    client_cm.__aexit__ = AsyncMock(return_value=False)

    # patch.object restores the real AsyncClient on exit, even if the
    # expectation below fails.
    with patch.object(fh.httpx, "AsyncClient", MagicMock(return_value=client_cm)), pytest.raises(
        Exception, match="HTTP 500"
    ):
        await fh.download_file("https://example.com/file.pdf", "file.pdf", str(tmp_path))

    # Collect unreachable objects first so descriptors held only by garbage
    # are closed before the leak scan.
    gc.collect()

    leftover = list(tmp_path.glob("*.tmp"))
    assert leftover == [], f"Temp file was not cleaned up: {leftover}"

    # NOTE(review): _fd_points_to_dir matches descriptors against files that
    # are still listed in tmp_path, so a descriptor left open on an already
    # unlinked temp file would not be reported here.
    open_fds_in_dir = [fd for fd in range(3, 1024) if _fd_points_to_dir(fd, str(tmp_path))]
    assert open_fds_in_dir == [], f"Leaked file descriptors pointing to tmp_path: {open_fds_in_dir}"
def _fd_points_to_dir(fd: int, directory: str) -> bool:
    """Return True if ``fd`` is open on a file that is a direct entry of ``directory``.

    The descriptor's ``(st_dev, st_ino)`` pair from ``os.fstat`` is compared
    with the ``stat()`` result of every entry yielded by
    ``Path(directory).iterdir()``.

    Args:
        fd: File descriptor number to inspect; it may be closed or invalid.
        directory: Directory whose immediate entries are compared against.

    Returns:
        True on the first entry whose device and inode match the descriptor.
        False if nothing matches, or if ``os.fstat`` or listing the directory
        raises ``OSError``.
    """
    import os

    try:
        fd_info = os.fstat(fd)
        wanted = (fd_info.st_dev, fd_info.st_ino)
        for entry in Path(directory).iterdir():
            try:
                entry_info = entry.stat()
            except OSError:
                # Entry could not be stat'ed; skip it.
                continue
            if (entry_info.st_dev, entry_info.st_ino) == wanted:
                return True
    except OSError:
        # fstat failed on the descriptor, or the directory could not be listed.
        return False
    return False
@pytest.mark.asyncio
async def test_download_file_concurrent_same_base_name(tmp_path: Path) -> None:
    """Two concurrent downloads with the same base name must not collide on the temp file.

    ``download_file`` is awaited twice via ``asyncio.gather`` with the names
    ``"DASDA"`` and ``"DASDA.pdf"`` and the same target directory.  Each
    construction of the patched ``AsyncClient`` serves a different payload.
    The test asserts that the two returned paths differ, that both exist, and
    that between them they hold both payloads.
    """
    from unittest.mock import patch

    import agency_swarm.integrations.fastapi_utils.file_handler as fh

    fake_content_1 = b"%PDF-1.4 content one"
    fake_content_2 = b"%PDF-1.4 content two"
    contents = [fake_content_1, fake_content_2]

    def make_http_mock(content: bytes) -> MagicMock:
        """Build a mock client context manager whose response streams ``content`` once."""

        async def mock_aiter_bytes():
            yield content

        response_obj = MagicMock()
        response_obj.raise_for_status = MagicMock()
        response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())

        stream_cm = MagicMock()
        stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
        stream_cm.__aexit__ = AsyncMock(return_value=False)

        client_obj = MagicMock()
        client_obj.stream = MagicMock(return_value=stream_cm)

        client_cm = MagicMock()
        client_cm.__aenter__ = AsyncMock(return_value=client_obj)
        client_cm.__aexit__ = AsyncMock(return_value=False)
        return client_cm

    call_count = 0

    def patched_client(*args, **kwargs):
        """Stand-in for ``AsyncClient``: hand out the payloads in rotation."""
        nonlocal call_count
        mock = make_http_mock(contents[call_count % len(contents)])
        call_count += 1
        return mock

    # patch.object restores the real AsyncClient on exit, even on failure.
    with patch.object(fh.httpx, "AsyncClient", patched_client):
        result1, result2 = await asyncio.gather(
            fh.download_file("https://example.com/f1", "DASDA", str(tmp_path)),
            fh.download_file("https://example.com/f2", "DASDA.pdf", str(tmp_path)),
        )

    assert result1 != result2, "Each download must produce a unique output path"
    assert Path(result1).exists(), "First download result must exist"
    assert Path(result2).exists(), "Second download result must exist"

    # Which download receives which payload depends on call order, so compare as sets.
    contents_found = {Path(result1).read_bytes(), Path(result2).read_bytes()}
    assert contents_found == {fake_content_1, fake_content_2}, "Each download's content must be preserved"
@pytest.mark.asyncio
async def test_download_file_uses_shutil_move_for_cross_device_rename(tmp_path: Path) -> None:
    """download_file must use shutil.move instead of Path.replace.

    ``shutil.move`` (as seen through the ``file_handler`` module) is wrapped
    by a spy that records each call and then delegates to the real function.
    The test asserts that exactly one move happened, from a ``.tmp`` source
    to a ``.pdf`` destination, that the result is in ``tmp_path`` with the
    downloaded bytes, and that no ``*.tmp`` file remains.
    """
    from unittest.mock import patch

    import agency_swarm.integrations.fastapi_utils.file_handler as fh

    fake_content = b"%PDF-1.4 fake content"
    pdf_name = "DASDA.pdf"

    async def mock_aiter_bytes():
        yield fake_content

    response_obj = MagicMock()
    response_obj.raise_for_status = MagicMock()
    response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())

    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
    stream_cm.__aexit__ = AsyncMock(return_value=False)

    client_obj = MagicMock()
    client_obj.stream = MagicMock(return_value=stream_cm)

    client_cm = MagicMock()
    client_cm.__aenter__ = AsyncMock(return_value=client_obj)
    client_cm.__aexit__ = AsyncMock(return_value=False)

    # Capture the real function before patching so the spy can delegate to it.
    original_move = fh.shutil.move
    move_was_called: list[tuple[str, str]] = []

    def tracking_move(src, dst, *args, **kwargs):
        """Record the call, then perform the real move."""
        # Store as str so the suffix checks below also work for path-like arguments.
        move_was_called.append((str(src), str(dst)))
        return original_move(src, dst, *args, **kwargs)

    # Both patches are undone on exit, even if download_file raises.
    with patch.object(fh.httpx, "AsyncClient", MagicMock(return_value=client_cm)), patch.object(
        fh.shutil, "move", tracking_move
    ):
        result = await fh.download_file("https://example.com/DASDA.pdf", pdf_name, str(tmp_path))

    assert Path(result).suffix == ".pdf"
    assert Path(result).parent == tmp_path
    assert Path(result).exists()
    assert Path(result).read_bytes() == fake_content

    assert len(move_was_called) == 1
    src, dst = move_was_called[0]
    assert src.endswith(".tmp")
    assert dst.endswith(".pdf")
    assert not list(tmp_path.glob("*.tmp")), ".tmp file should be removed after move"
@pytest.mark.asyncio
async def test_download_file_long_filename_does_not_crash(tmp_path: Path) -> None:
    """A filename with a ~250-char base must not crash mkstemp with a filesystem limit error.

    ``download_file`` is called with a name made of 250 ``"A"`` characters
    plus ``".pdf"`` while ``httpx.AsyncClient`` is mocked to stream a small
    payload.  The test asserts that the returned path exists, keeps the
    ``.pdf`` suffix, and has a file name of at most 255 characters.
    """
    from unittest.mock import patch

    import agency_swarm.integrations.fastapi_utils.file_handler as fh

    long_name = "A" * 250 + ".pdf"
    fake_content = b"%PDF-1.4 long name"

    async def mock_aiter_bytes():
        yield fake_content

    response_obj = MagicMock()
    response_obj.raise_for_status = MagicMock()
    response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())

    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
    stream_cm.__aexit__ = AsyncMock(return_value=False)

    client_obj = MagicMock()
    client_obj.stream = MagicMock(return_value=stream_cm)

    client_cm = MagicMock()
    client_cm.__aenter__ = AsyncMock(return_value=client_obj)
    client_cm.__aexit__ = AsyncMock(return_value=False)

    # patch.object restores the real AsyncClient on exit, even on failure.
    with patch.object(fh.httpx, "AsyncClient", MagicMock(return_value=client_cm)):
        result = await fh.download_file("https://example.com/long.pdf", long_name, str(tmp_path))

    assert Path(result).exists()
    assert Path(result).suffix == ".pdf"
    assert len(Path(result).name) <= 255
|