test_file_handler_downloads.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """Unit tests for fastapi_utils download_file behavior."""
  2. import asyncio
  3. from pathlib import Path
  4. from unittest.mock import AsyncMock, MagicMock
  5. import pytest
  6. @pytest.mark.asyncio
  7. async def test_download_file_cleans_up_tmp_on_http_error(tmp_path: Path) -> None:
  8. """Temp file must be deleted and no fd must be leaked when the HTTP request fails."""
  9. import gc
  10. import agency_swarm.integrations.fastapi_utils.file_handler as fh
  11. response_obj = MagicMock()
  12. response_obj.raise_for_status = MagicMock(side_effect=Exception("HTTP 500"))
  13. response_obj.aiter_bytes = MagicMock()
  14. stream_cm = MagicMock()
  15. stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
  16. stream_cm.__aexit__ = AsyncMock(return_value=False)
  17. client_obj = MagicMock()
  18. client_obj.stream = MagicMock(return_value=stream_cm)
  19. client_cm = MagicMock()
  20. client_cm.__aenter__ = AsyncMock(return_value=client_obj)
  21. client_cm.__aexit__ = AsyncMock(return_value=False)
  22. original_client = fh.httpx.AsyncClient
  23. fh.httpx.AsyncClient = MagicMock(return_value=client_cm)
  24. try:
  25. with pytest.raises(Exception, match="HTTP 500"):
  26. await fh.download_file("https://example.com/file.pdf", "file.pdf", str(tmp_path))
  27. finally:
  28. fh.httpx.AsyncClient = original_client
  29. gc.collect()
  30. leftover = list(tmp_path.glob("*.tmp"))
  31. assert leftover == [], f"Temp file was not cleaned up: {leftover}"
  32. open_fds_in_dir = [fd for fd in range(3, 1024) if _fd_points_to_dir(fd, str(tmp_path))]
  33. assert open_fds_in_dir == [], f"Leaked file descriptors pointing to tmp_path: {open_fds_in_dir}"
  34. def _fd_points_to_dir(fd: int, directory: str) -> bool:
  35. try:
  36. import os
  37. stat = os.fstat(fd)
  38. path = Path(directory)
  39. for file_path in path.iterdir():
  40. try:
  41. file_stat = file_path.stat()
  42. if file_stat.st_ino == stat.st_ino and file_stat.st_dev == stat.st_dev:
  43. return True
  44. except OSError:
  45. pass
  46. except OSError:
  47. pass
  48. return False
  49. @pytest.mark.asyncio
  50. async def test_download_file_concurrent_same_base_name(tmp_path: Path) -> None:
  51. """Two concurrent downloads with the same base name must not collide on the temp file."""
  52. fake_content_1 = b"%PDF-1.4 content one"
  53. fake_content_2 = b"%PDF-1.4 content two"
  54. def make_http_mock(content: bytes) -> MagicMock:
  55. async def mock_aiter_bytes():
  56. yield content
  57. response_obj = MagicMock()
  58. response_obj.raise_for_status = MagicMock()
  59. response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())
  60. stream_cm = MagicMock()
  61. stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
  62. stream_cm.__aexit__ = AsyncMock(return_value=False)
  63. client_obj = MagicMock()
  64. client_obj.stream = MagicMock(return_value=stream_cm)
  65. client_cm = MagicMock()
  66. client_cm.__aenter__ = AsyncMock(return_value=client_obj)
  67. client_cm.__aexit__ = AsyncMock(return_value=False)
  68. return client_cm
  69. import agency_swarm.integrations.fastapi_utils.file_handler as fh
  70. call_count = 0
  71. contents = [fake_content_1, fake_content_2]
  72. original_client = fh.httpx.AsyncClient
  73. def patched_client(**kwargs):
  74. nonlocal call_count
  75. mock = make_http_mock(contents[call_count % 2])
  76. call_count += 1
  77. return mock
  78. fh.httpx.AsyncClient = patched_client
  79. try:
  80. result1, result2 = await asyncio.gather(
  81. fh.download_file("https://example.com/f1", "DASDA", str(tmp_path)),
  82. fh.download_file("https://example.com/f2", "DASDA.pdf", str(tmp_path)),
  83. )
  84. finally:
  85. fh.httpx.AsyncClient = original_client
  86. assert result1 != result2, "Each download must produce a unique output path"
  87. assert Path(result1).exists(), "First download result must exist"
  88. assert Path(result2).exists(), "Second download result must exist"
  89. contents_found = {Path(result1).read_bytes(), Path(result2).read_bytes()}
  90. assert contents_found == {fake_content_1, fake_content_2}, "Each download's content must be preserved"
  91. @pytest.mark.asyncio
  92. async def test_download_file_uses_shutil_move_for_cross_device_rename(tmp_path: Path) -> None:
  93. """download_file must use shutil.move instead of Path.replace."""
  94. import agency_swarm.integrations.fastapi_utils.file_handler as fh
  95. fake_content = b"%PDF-1.4 fake content"
  96. pdf_name = "DASDA.pdf"
  97. async def mock_aiter_bytes():
  98. yield fake_content
  99. response_obj = MagicMock()
  100. response_obj.raise_for_status = MagicMock()
  101. response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())
  102. stream_cm = MagicMock()
  103. stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
  104. stream_cm.__aexit__ = AsyncMock(return_value=False)
  105. client_obj = MagicMock()
  106. client_obj.stream = MagicMock(return_value=stream_cm)
  107. client_cm = MagicMock()
  108. client_cm.__aenter__ = AsyncMock(return_value=client_obj)
  109. client_cm.__aexit__ = AsyncMock(return_value=False)
  110. original_client = fh.httpx.AsyncClient
  111. original_move = fh.shutil.move
  112. move_was_called = []
  113. def tracking_move(src: str, dst: str):
  114. move_was_called.append((src, dst))
  115. return original_move(src, dst)
  116. fh.httpx.AsyncClient = MagicMock(return_value=client_cm)
  117. fh.shutil.move = tracking_move
  118. try:
  119. result = await fh.download_file("https://example.com/DASDA.pdf", pdf_name, str(tmp_path))
  120. finally:
  121. fh.httpx.AsyncClient = original_client
  122. fh.shutil.move = original_move
  123. assert Path(result).suffix == ".pdf"
  124. assert Path(result).parent == tmp_path
  125. assert Path(result).exists()
  126. assert Path(result).read_bytes() == fake_content
  127. assert len(move_was_called) == 1
  128. src, dst = move_was_called[0]
  129. assert src.endswith(".tmp")
  130. assert dst.endswith(".pdf")
  131. assert not list(tmp_path.glob("*.tmp")), ".tmp file should be removed after move"
  132. @pytest.mark.asyncio
  133. async def test_download_file_long_filename_does_not_crash(tmp_path: Path) -> None:
  134. """A filename with a ~250-char base must not crash mkstemp with a filesystem limit error."""
  135. import agency_swarm.integrations.fastapi_utils.file_handler as fh
  136. long_name = "A" * 250 + ".pdf"
  137. fake_content = b"%PDF-1.4 long name"
  138. async def mock_aiter_bytes():
  139. yield fake_content
  140. response_obj = MagicMock()
  141. response_obj.raise_for_status = MagicMock()
  142. response_obj.aiter_bytes = MagicMock(return_value=mock_aiter_bytes())
  143. stream_cm = MagicMock()
  144. stream_cm.__aenter__ = AsyncMock(return_value=response_obj)
  145. stream_cm.__aexit__ = AsyncMock(return_value=False)
  146. client_obj = MagicMock()
  147. client_obj.stream = MagicMock(return_value=stream_cm)
  148. client_cm = MagicMock()
  149. client_cm.__aenter__ = AsyncMock(return_value=client_obj)
  150. client_cm.__aexit__ = AsyncMock(return_value=False)
  151. original_client = fh.httpx.AsyncClient
  152. fh.httpx.AsyncClient = MagicMock(return_value=client_cm)
  153. try:
  154. result = await fh.download_file("https://example.com/long.pdf", long_name, str(tmp_path))
  155. finally:
  156. fh.httpx.AsyncClient = original_client
  157. assert Path(result).exists()
  158. assert Path(result).suffix == ".pdf"
  159. assert len(Path(result).name) <= 255