test_logging_middleware.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. """
  2. Integration tests for FastAPI logging middleware.
  3. Tests the actual logging behavior, request tracking, file operations,
  4. and HTTP middleware functionality.
  5. """
  6. import asyncio
  7. import json
  8. import logging
  9. import os
  10. import tempfile
  11. from contextlib import contextmanager
  12. from pathlib import Path
  13. from unittest.mock import MagicMock, patch
  14. import httpx
  15. import pytest
  16. os.environ.setdefault("OPENAI_AGENTS_DISABLE_TRACING", "1")
  17. from agents import ModelSettings
  18. from agents.tracing import set_tracing_disabled
  19. from agency_swarm import Agency, Agent, run_fastapi
  20. from agency_swarm.integrations.fastapi_utils.logging_middleware import (
  21. ConditionalFileHandler,
  22. ConsoleFormatter,
  23. FileFormatter,
  24. RequestTracker,
  25. get_log_id_from_headers,
  26. get_logs_endpoint_impl,
  27. log_to_file_context,
  28. request_id_context,
  29. setup_enhanced_logging,
  30. )
  31. @contextmanager
  32. def set_context(var, value):
  33. """Temporarily set a ContextVar value."""
  34. token = var.set(value)
  35. try:
  36. yield
  37. finally:
  38. var.reset(token)
  39. @pytest.fixture(autouse=True)
  40. def ensure_clean_logging_context():
  41. """Ensure logging ContextVars start clean and reset after each test."""
  42. # Detect leakage from prior tests before forcing a clean baseline.
  43. assert request_id_context.get() == ""
  44. assert log_to_file_context.get() is False
  45. request_token = request_id_context.set("")
  46. log_token = log_to_file_context.set(False)
  47. try:
  48. yield
  49. finally:
  50. request_id_context.reset(request_token)
  51. log_to_file_context.reset(log_token)
  52. @pytest.fixture
  53. def temp_logs_dir():
  54. """Create a temporary directory for test logs."""
  55. with tempfile.TemporaryDirectory() as temp_dir:
  56. yield temp_dir
  57. @pytest.fixture
  58. def agency_factory():
  59. """Create an agency factory for testing."""
  60. def create_agency(load_threads_callback=None):
  61. agent = Agent(
  62. name="LogTestAgent",
  63. instructions="You are a test agent for logging middleware testing.",
  64. model_settings=ModelSettings(temperature=0),
  65. )
  66. return Agency(
  67. agent,
  68. load_threads_callback=load_threads_callback,
  69. )
  70. return create_agency
  71. class TestConsoleFormatter:
  72. """Test console log formatting with request tracking."""
  73. def test_format_with_request_id(self):
  74. """Test that console formatter includes request ID when present."""
  75. formatter = ConsoleFormatter()
  76. record = logging.LogRecord(
  77. name="test", level=logging.INFO, pathname="test.py", lineno=42, msg="Test message", args=(), exc_info=None
  78. )
  79. record.funcName = "test_func"
  80. record.module = "test_module"
  81. # Set request ID in context
  82. with set_context(request_id_context, "req-123"):
  83. formatted = formatter.format(record)
  84. assert "[req-123]" in formatted
  85. assert "[INFO]" in formatted
  86. assert "test_module.test_func:42" in formatted
  87. assert "Test message" in formatted
  88. def test_format_without_request_id(self):
  89. """Test console formatting when no request ID is set."""
  90. formatter = ConsoleFormatter()
  91. record = logging.LogRecord(
  92. name="test",
  93. level=logging.WARNING,
  94. pathname="test.py",
  95. lineno=10,
  96. msg="Warning message",
  97. args=(),
  98. exc_info=None,
  99. )
  100. record.filename = "test.py"
  101. record.funcName = "test_function"
  102. record.module = "test_module"
  103. # Clear request ID context
  104. with set_context(request_id_context, ""):
  105. formatted = formatter.format(record)
  106. assert "[req-" not in formatted # No request ID prefix
  107. assert "[WARNING]" in formatted
  108. assert "test_module.test_function:10" in formatted
  109. assert "Warning message" in formatted
  110. def test_format_with_exception(self):
  111. """Test console formatting includes exception details."""
  112. formatter = ConsoleFormatter()
  113. try:
  114. raise ValueError("Test exception")
  115. except ValueError:
  116. import sys
  117. exc_info = sys.exc_info()
  118. record = logging.LogRecord(
  119. name="test",
  120. level=logging.ERROR,
  121. pathname="test.py",
  122. lineno=20,
  123. msg="Error occurred",
  124. args=(),
  125. exc_info=exc_info,
  126. )
  127. formatted = formatter.format(record)
  128. assert "Error occurred" in formatted
  129. assert "ValueError: Test exception" in formatted
  130. assert "Traceback" in formatted
  131. class TestFileFormatter:
  132. """Test JSON file log formatting."""
  133. def test_format_basic_log(self):
  134. """Test JSON formatting for basic log entries."""
  135. formatter = FileFormatter()
  136. record = logging.LogRecord(
  137. name="test",
  138. level=logging.INFO,
  139. pathname="test.py",
  140. lineno=100,
  141. msg="JSON test message",
  142. args=(),
  143. exc_info=None,
  144. )
  145. record.funcName = "json_func"
  146. record.filename = "test.py"
  147. formatted = formatter.format(record)
  148. log_entry = json.loads(formatted)
  149. assert log_entry["message"] == "JSON test message"
  150. assert log_entry["details"]["level"] == "INFO"
  151. assert log_entry["details"]["location"]["file"] == "test.py"
  152. assert log_entry["details"]["location"]["function"] == "json_func"
  153. assert log_entry["details"]["location"]["line"] == 100
  154. assert "timestamp" in log_entry["details"]
  155. def test_format_with_exception_info(self):
  156. """Test JSON formatting includes structured exception data."""
  157. formatter = FileFormatter()
  158. try:
  159. raise RuntimeError("JSON test exception")
  160. except RuntimeError:
  161. import sys
  162. exc_info = sys.exc_info()
  163. record = logging.LogRecord(
  164. name="test",
  165. level=logging.ERROR,
  166. pathname="test.py",
  167. lineno=50,
  168. msg="Exception in JSON",
  169. args=(),
  170. exc_info=exc_info,
  171. )
  172. formatted = formatter.format(record)
  173. log_entry = json.loads(formatted)
  174. assert "exception" in log_entry["details"]
  175. assert log_entry["details"]["exception"]["type"] == "RuntimeError"
  176. assert log_entry["details"]["exception"]["message"] == "JSON test exception"
  177. assert isinstance(log_entry["details"]["exception"]["traceback"], list)
  178. class TestConditionalFileHandler:
  179. """Test conditional file logging based on context."""
  180. def test_logs_when_enabled(self, temp_logs_dir):
  181. """Test that handler writes to file when context is enabled."""
  182. handler = ConditionalFileHandler(temp_logs_dir)
  183. handler.setFormatter(FileFormatter())
  184. record = logging.LogRecord(
  185. name="test",
  186. level=logging.INFO,
  187. pathname="test.py",
  188. lineno=1,
  189. msg="Should be logged",
  190. args=(),
  191. exc_info=None,
  192. )
  193. # Enable file logging and set request ID
  194. with set_context(log_to_file_context, True), set_context(request_id_context, "test-id-123"):
  195. handler.emit(record)
  196. # Check that log file was created
  197. log_file = Path(temp_logs_dir) / "test-id-123.jsonl"
  198. assert log_file.exists()
  199. # Verify content
  200. content = log_file.read_text(encoding="utf-8")
  201. log_entry = json.loads(content.strip())
  202. assert log_entry["message"] == "Should be logged"
  203. def test_skips_when_disabled(self, temp_logs_dir):
  204. """Test that handler doesn't write when file logging is disabled."""
  205. handler = ConditionalFileHandler(temp_logs_dir)
  206. handler.setFormatter(FileFormatter())
  207. record = logging.LogRecord(
  208. name="test",
  209. level=logging.INFO,
  210. pathname="test.py",
  211. lineno=1,
  212. msg="Should not be logged",
  213. args=(),
  214. exc_info=None,
  215. )
  216. # Disable file logging
  217. with set_context(log_to_file_context, False), set_context(request_id_context, "test-id-456"):
  218. handler.emit(record)
  219. # Check that no log file was created
  220. log_file = Path(temp_logs_dir) / "test-id-456.jsonl"
  221. assert not log_file.exists()
  222. def test_handles_write_errors_gracefully(self, temp_logs_dir):
  223. """Test that handler doesn't crash when file writing fails."""
  224. handler = ConditionalFileHandler(temp_logs_dir)
  225. handler.setFormatter(FileFormatter())
  226. record = logging.LogRecord(
  227. name="test",
  228. level=logging.INFO,
  229. pathname="test.py",
  230. lineno=1,
  231. msg="This should fail silently",
  232. args=(),
  233. exc_info=None,
  234. )
  235. with set_context(log_to_file_context, True), set_context(request_id_context, "error-test"):
  236. # Should not raise exception
  237. with patch("builtins.open", side_effect=OSError("write error")):
  238. handler.emit(record)
  239. class TestSetupEnhancedLogging:
  240. """Test the logging setup function."""
  241. def test_creates_logs_directory(self, temp_logs_dir):
  242. """Test that setup creates the logs directory."""
  243. non_existent_dir = os.path.join(temp_logs_dir, "new_logs")
  244. assert not os.path.exists(non_existent_dir)
  245. logger = setup_enhanced_logging(non_existent_dir)
  246. assert os.path.exists(non_existent_dir)
  247. assert isinstance(logger, logging.Logger)
  248. def test_configures_handlers_correctly(self, temp_logs_dir):
  249. """Test that setup configures console and file handlers."""
  250. logger = setup_enhanced_logging(temp_logs_dir)
  251. # Should have exactly 2 handlers
  252. assert len(logger.handlers) == 2
  253. # Check handler names and types
  254. handler_names = [h.name for h in logger.handlers]
  255. assert "custom_console" in handler_names
  256. assert "custom_file" in handler_names
  257. # Check formatters
  258. console_handler = next(h for h in logger.handlers if h.name == "custom_console")
  259. file_handler = next(h for h in logger.handlers if h.name == "custom_file")
  260. assert isinstance(console_handler.formatter, ConsoleFormatter)
  261. assert isinstance(file_handler.formatter, FileFormatter)
  262. class TestGetLogIdFromHeaders:
  263. """Test request header processing for log IDs."""
  264. def test_extracts_existing_log_id(self):
  265. """Test that function extracts log ID from headers when present."""
  266. mock_request = MagicMock()
  267. mock_request.headers.get.return_value = "custom-log-id-789"
  268. log_id, should_log = get_log_id_from_headers(mock_request)
  269. assert log_id == "custom-log-id-789"
  270. assert should_log is True
  271. def test_generates_new_log_id(self):
  272. """Test that function generates new log ID when header is missing."""
  273. mock_request = MagicMock()
  274. mock_request.headers.get.return_value = None
  275. log_id, should_log = get_log_id_from_headers(mock_request)
  276. assert len(log_id) == 8 # Should be 8-character UUID prefix
  277. assert should_log is False
  278. class TestRequestTracker:
  279. """Test the HTTP middleware for request tracking."""
  280. @pytest.mark.asyncio
  281. async def test_sets_context_variables(self):
  282. """Test that middleware sets request ID and logging context."""
  283. middleware = RequestTracker(MagicMock())
  284. mock_request = MagicMock()
  285. mock_request.headers.get.return_value = "middleware-test-id"
  286. async def mock_call_next(request):
  287. # Verify context is set during request processing
  288. assert request_id_context.get() == "middleware-test-id"
  289. assert log_to_file_context.get() is True
  290. return MagicMock()
  291. await middleware.dispatch(mock_request, mock_call_next)
  292. # Context variables should be reset after the request completes
  293. assert request_id_context.get() == ""
  294. assert log_to_file_context.get() is False
  295. @pytest.mark.asyncio
  296. async def test_resets_context_on_exception(self):
  297. """Middleware must reset context when downstream handlers fail."""
  298. middleware = RequestTracker(MagicMock())
  299. mock_request = MagicMock()
  300. mock_request.headers.get.return_value = "middleware-error-test"
  301. async def mock_call_next(request):
  302. assert request_id_context.get() == "middleware-error-test"
  303. assert log_to_file_context.get() is True
  304. raise RuntimeError("downstream failure")
  305. with pytest.raises(RuntimeError):
  306. await middleware.dispatch(mock_request, mock_call_next)
  307. assert request_id_context.get() == ""
  308. assert log_to_file_context.get() is False
  309. @pytest.mark.asyncio
  310. async def test_run_fastapi_logging_integration(self, agency_factory, temp_logs_dir):
  311. """Test logging middleware with actual run_fastapi method."""
  312. set_tracing_disabled(True)
  313. # Build FastAPI app with logging enabled
  314. app = run_fastapi(
  315. agencies={"test_agency": agency_factory},
  316. port=8099,
  317. logs_dir=temp_logs_dir,
  318. return_app=True,
  319. enable_agui=False,
  320. enable_logging=True, # Enable logging to test the middleware
  321. )
  322. transport = httpx.ASGITransport(app=app)
  323. try:
  324. # Make request with log ID header against in-process app
  325. async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
  326. response = await client.post(
  327. "/test_agency/get_response",
  328. json={"message": "Test logging middleware"},
  329. headers={"x-agency-log-id": "fastapi-integration-test"},
  330. timeout=30.0,
  331. )
  332. assert response.status_code == 200
  333. # Wait for log file to be written
  334. log_file = Path(temp_logs_dir) / "fastapi-integration-test.jsonl"
  335. for _ in range(20):
  336. if log_file.exists():
  337. break
  338. await asyncio.sleep(0.1)
  339. assert log_file.exists()
  340. # Verify log content
  341. content = log_file.read_text(encoding="utf-8")
  342. log_lines = [line for line in content.strip().split("\n") if line.strip()]
  343. assert len(log_lines) >= 1 # Should have at least some logs
  344. # Parse and verify log entries
  345. for line in log_lines:
  346. log_entry = json.loads(line)
  347. assert "message" in log_entry
  348. assert "details" in log_entry
  349. assert "timestamp" in log_entry["details"]
  350. finally:
  351. await transport.aclose()
  352. class TestGetLogsEndpointImpl:
  353. """Test the logs retrieval endpoint implementation."""
  354. @pytest.mark.asyncio
  355. async def test_retrieves_and_deletes_log_file(self, temp_logs_dir):
  356. """Test that endpoint returns logs and deletes the file."""
  357. # Create a test log file
  358. log_file = Path(temp_logs_dir) / "endpoint-test.jsonl"
  359. test_logs = [
  360. {"message": "Log entry 1", "details": {"level": "INFO"}},
  361. {"message": "Log entry 2", "details": {"level": "ERROR"}},
  362. ]
  363. with log_file.open("w", encoding="utf-8") as f:
  364. for log_entry in test_logs:
  365. f.write(json.dumps(log_entry) + "\n")
  366. # Call the endpoint
  367. response = await get_logs_endpoint_impl("endpoint-test", temp_logs_dir)
  368. assert response.status_code == 200
  369. assert response.media_type == "application/json"
  370. # Parse response content
  371. response_data = json.loads(response.body)
  372. assert len(response_data) == 2
  373. assert response_data[0]["message"] == "Log entry 1"
  374. assert response_data[1]["message"] == "Log entry 2"
  375. # Verify file was deleted
  376. assert not log_file.exists()
  377. @pytest.mark.asyncio
  378. async def test_returns_404_for_missing_file(self, temp_logs_dir):
  379. """Test that endpoint returns 404 for non-existent log files."""
  380. response = await get_logs_endpoint_impl("non-existent", temp_logs_dir)
  381. assert response.status_code == 404
  382. assert "Log file not found" in response.body.decode()
  383. @pytest.mark.asyncio
  384. async def test_returns_400_for_empty_log_id(self, temp_logs_dir):
  385. """Test that endpoint returns 400 for empty log ID."""
  386. response = await get_logs_endpoint_impl("", temp_logs_dir)
  387. assert response.status_code == 400
  388. assert "Log ID is required" in response.body.decode()
  389. @pytest.mark.asyncio
  390. async def test_handles_invalid_json_gracefully(self, temp_logs_dir):
  391. """Test that endpoint skips invalid JSON lines."""
  392. log_file = Path(temp_logs_dir) / "invalid-json-test.jsonl"
  393. with log_file.open("w", encoding="utf-8") as f:
  394. f.write('{"valid": "json"}\n')
  395. f.write("invalid json line\n")
  396. f.write('{"another": "valid"}\n')
  397. response = await get_logs_endpoint_impl("invalid-json-test", temp_logs_dir)
  398. assert response.status_code == 200
  399. response_data = json.loads(response.body)
  400. assert len(response_data) == 2 # Only valid JSON entries
  401. assert response_data[0]["valid"] == "json"
  402. assert response_data[1]["another"] == "valid"
  403. @pytest.mark.asyncio
  404. async def test_handles_file_system_errors(self):
  405. """Test that endpoint handles file system errors gracefully."""
  406. # Use invalid directory to trigger file system error
  407. with patch("os.path.exists", side_effect=OSError("File system error")):
  408. response = await get_logs_endpoint_impl("test-id", "/invalid/path")
  409. assert response.status_code == 500
  410. assert "Internal server error" in response.body.decode()