| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527 |
- """
- Integration tests for FastAPI logging middleware.
- Tests the actual logging behavior, request tracking, file operations,
- and HTTP middleware functionality.
- """
- import asyncio
- import json
- import logging
- import os
- import tempfile
- from contextlib import contextmanager
- from pathlib import Path
- from unittest.mock import MagicMock, patch
- import httpx
- import pytest
- os.environ.setdefault("OPENAI_AGENTS_DISABLE_TRACING", "1")
- from agents import ModelSettings
- from agents.tracing import set_tracing_disabled
- from agency_swarm import Agency, Agent, run_fastapi
- from agency_swarm.integrations.fastapi_utils.logging_middleware import (
- ConditionalFileHandler,
- ConsoleFormatter,
- FileFormatter,
- RequestTracker,
- get_log_id_from_headers,
- get_logs_endpoint_impl,
- log_to_file_context,
- request_id_context,
- setup_enhanced_logging,
- )
- @contextmanager
- def set_context(var, value):
- """Temporarily set a ContextVar value."""
- token = var.set(value)
- try:
- yield
- finally:
- var.reset(token)
- @pytest.fixture(autouse=True)
- def ensure_clean_logging_context():
- """Ensure logging ContextVars start clean and reset after each test."""
- # Detect leakage from prior tests before forcing a clean baseline.
- assert request_id_context.get() == ""
- assert log_to_file_context.get() is False
- request_token = request_id_context.set("")
- log_token = log_to_file_context.set(False)
- try:
- yield
- finally:
- request_id_context.reset(request_token)
- log_to_file_context.reset(log_token)
- @pytest.fixture
- def temp_logs_dir():
- """Create a temporary directory for test logs."""
- with tempfile.TemporaryDirectory() as temp_dir:
- yield temp_dir
- @pytest.fixture
- def agency_factory():
- """Create an agency factory for testing."""
- def create_agency(load_threads_callback=None):
- agent = Agent(
- name="LogTestAgent",
- instructions="You are a test agent for logging middleware testing.",
- model_settings=ModelSettings(temperature=0),
- )
- return Agency(
- agent,
- load_threads_callback=load_threads_callback,
- )
- return create_agency
- class TestConsoleFormatter:
- """Test console log formatting with request tracking."""
- def test_format_with_request_id(self):
- """Test that console formatter includes request ID when present."""
- formatter = ConsoleFormatter()
- record = logging.LogRecord(
- name="test", level=logging.INFO, pathname="test.py", lineno=42, msg="Test message", args=(), exc_info=None
- )
- record.funcName = "test_func"
- record.module = "test_module"
- # Set request ID in context
- with set_context(request_id_context, "req-123"):
- formatted = formatter.format(record)
- assert "[req-123]" in formatted
- assert "[INFO]" in formatted
- assert "test_module.test_func:42" in formatted
- assert "Test message" in formatted
- def test_format_without_request_id(self):
- """Test console formatting when no request ID is set."""
- formatter = ConsoleFormatter()
- record = logging.LogRecord(
- name="test",
- level=logging.WARNING,
- pathname="test.py",
- lineno=10,
- msg="Warning message",
- args=(),
- exc_info=None,
- )
- record.filename = "test.py"
- record.funcName = "test_function"
- record.module = "test_module"
- # Clear request ID context
- with set_context(request_id_context, ""):
- formatted = formatter.format(record)
- assert "[req-" not in formatted # No request ID prefix
- assert "[WARNING]" in formatted
- assert "test_module.test_function:10" in formatted
- assert "Warning message" in formatted
- def test_format_with_exception(self):
- """Test console formatting includes exception details."""
- formatter = ConsoleFormatter()
- try:
- raise ValueError("Test exception")
- except ValueError:
- import sys
- exc_info = sys.exc_info()
- record = logging.LogRecord(
- name="test",
- level=logging.ERROR,
- pathname="test.py",
- lineno=20,
- msg="Error occurred",
- args=(),
- exc_info=exc_info,
- )
- formatted = formatter.format(record)
- assert "Error occurred" in formatted
- assert "ValueError: Test exception" in formatted
- assert "Traceback" in formatted
- class TestFileFormatter:
- """Test JSON file log formatting."""
- def test_format_basic_log(self):
- """Test JSON formatting for basic log entries."""
- formatter = FileFormatter()
- record = logging.LogRecord(
- name="test",
- level=logging.INFO,
- pathname="test.py",
- lineno=100,
- msg="JSON test message",
- args=(),
- exc_info=None,
- )
- record.funcName = "json_func"
- record.filename = "test.py"
- formatted = formatter.format(record)
- log_entry = json.loads(formatted)
- assert log_entry["message"] == "JSON test message"
- assert log_entry["details"]["level"] == "INFO"
- assert log_entry["details"]["location"]["file"] == "test.py"
- assert log_entry["details"]["location"]["function"] == "json_func"
- assert log_entry["details"]["location"]["line"] == 100
- assert "timestamp" in log_entry["details"]
- def test_format_with_exception_info(self):
- """Test JSON formatting includes structured exception data."""
- formatter = FileFormatter()
- try:
- raise RuntimeError("JSON test exception")
- except RuntimeError:
- import sys
- exc_info = sys.exc_info()
- record = logging.LogRecord(
- name="test",
- level=logging.ERROR,
- pathname="test.py",
- lineno=50,
- msg="Exception in JSON",
- args=(),
- exc_info=exc_info,
- )
- formatted = formatter.format(record)
- log_entry = json.loads(formatted)
- assert "exception" in log_entry["details"]
- assert log_entry["details"]["exception"]["type"] == "RuntimeError"
- assert log_entry["details"]["exception"]["message"] == "JSON test exception"
- assert isinstance(log_entry["details"]["exception"]["traceback"], list)
- class TestConditionalFileHandler:
- """Test conditional file logging based on context."""
- def test_logs_when_enabled(self, temp_logs_dir):
- """Test that handler writes to file when context is enabled."""
- handler = ConditionalFileHandler(temp_logs_dir)
- handler.setFormatter(FileFormatter())
- record = logging.LogRecord(
- name="test",
- level=logging.INFO,
- pathname="test.py",
- lineno=1,
- msg="Should be logged",
- args=(),
- exc_info=None,
- )
- # Enable file logging and set request ID
- with set_context(log_to_file_context, True), set_context(request_id_context, "test-id-123"):
- handler.emit(record)
- # Check that log file was created
- log_file = Path(temp_logs_dir) / "test-id-123.jsonl"
- assert log_file.exists()
- # Verify content
- content = log_file.read_text(encoding="utf-8")
- log_entry = json.loads(content.strip())
- assert log_entry["message"] == "Should be logged"
- def test_skips_when_disabled(self, temp_logs_dir):
- """Test that handler doesn't write when file logging is disabled."""
- handler = ConditionalFileHandler(temp_logs_dir)
- handler.setFormatter(FileFormatter())
- record = logging.LogRecord(
- name="test",
- level=logging.INFO,
- pathname="test.py",
- lineno=1,
- msg="Should not be logged",
- args=(),
- exc_info=None,
- )
- # Disable file logging
- with set_context(log_to_file_context, False), set_context(request_id_context, "test-id-456"):
- handler.emit(record)
- # Check that no log file was created
- log_file = Path(temp_logs_dir) / "test-id-456.jsonl"
- assert not log_file.exists()
- def test_handles_write_errors_gracefully(self, temp_logs_dir):
- """Test that handler doesn't crash when file writing fails."""
- handler = ConditionalFileHandler(temp_logs_dir)
- handler.setFormatter(FileFormatter())
- record = logging.LogRecord(
- name="test",
- level=logging.INFO,
- pathname="test.py",
- lineno=1,
- msg="This should fail silently",
- args=(),
- exc_info=None,
- )
- with set_context(log_to_file_context, True), set_context(request_id_context, "error-test"):
- # Should not raise exception
- with patch("builtins.open", side_effect=OSError("write error")):
- handler.emit(record)
- class TestSetupEnhancedLogging:
- """Test the logging setup function."""
- def test_creates_logs_directory(self, temp_logs_dir):
- """Test that setup creates the logs directory."""
- non_existent_dir = os.path.join(temp_logs_dir, "new_logs")
- assert not os.path.exists(non_existent_dir)
- logger = setup_enhanced_logging(non_existent_dir)
- assert os.path.exists(non_existent_dir)
- assert isinstance(logger, logging.Logger)
- def test_configures_handlers_correctly(self, temp_logs_dir):
- """Test that setup configures console and file handlers."""
- logger = setup_enhanced_logging(temp_logs_dir)
- # Should have exactly 2 handlers
- assert len(logger.handlers) == 2
- # Check handler names and types
- handler_names = [h.name for h in logger.handlers]
- assert "custom_console" in handler_names
- assert "custom_file" in handler_names
- # Check formatters
- console_handler = next(h for h in logger.handlers if h.name == "custom_console")
- file_handler = next(h for h in logger.handlers if h.name == "custom_file")
- assert isinstance(console_handler.formatter, ConsoleFormatter)
- assert isinstance(file_handler.formatter, FileFormatter)
- class TestGetLogIdFromHeaders:
- """Test request header processing for log IDs."""
- def test_extracts_existing_log_id(self):
- """Test that function extracts log ID from headers when present."""
- mock_request = MagicMock()
- mock_request.headers.get.return_value = "custom-log-id-789"
- log_id, should_log = get_log_id_from_headers(mock_request)
- assert log_id == "custom-log-id-789"
- assert should_log is True
- def test_generates_new_log_id(self):
- """Test that function generates new log ID when header is missing."""
- mock_request = MagicMock()
- mock_request.headers.get.return_value = None
- log_id, should_log = get_log_id_from_headers(mock_request)
- assert len(log_id) == 8 # Should be 8-character UUID prefix
- assert should_log is False
- class TestRequestTracker:
- """Test the HTTP middleware for request tracking."""
- @pytest.mark.asyncio
- async def test_sets_context_variables(self):
- """Test that middleware sets request ID and logging context."""
- middleware = RequestTracker(MagicMock())
- mock_request = MagicMock()
- mock_request.headers.get.return_value = "middleware-test-id"
- async def mock_call_next(request):
- # Verify context is set during request processing
- assert request_id_context.get() == "middleware-test-id"
- assert log_to_file_context.get() is True
- return MagicMock()
- await middleware.dispatch(mock_request, mock_call_next)
- # Context variables should be reset after the request completes
- assert request_id_context.get() == ""
- assert log_to_file_context.get() is False
- @pytest.mark.asyncio
- async def test_resets_context_on_exception(self):
- """Middleware must reset context when downstream handlers fail."""
- middleware = RequestTracker(MagicMock())
- mock_request = MagicMock()
- mock_request.headers.get.return_value = "middleware-error-test"
- async def mock_call_next(request):
- assert request_id_context.get() == "middleware-error-test"
- assert log_to_file_context.get() is True
- raise RuntimeError("downstream failure")
- with pytest.raises(RuntimeError):
- await middleware.dispatch(mock_request, mock_call_next)
- assert request_id_context.get() == ""
- assert log_to_file_context.get() is False
- @pytest.mark.asyncio
- async def test_run_fastapi_logging_integration(self, agency_factory, temp_logs_dir):
- """Test logging middleware with actual run_fastapi method."""
- set_tracing_disabled(True)
- # Build FastAPI app with logging enabled
- app = run_fastapi(
- agencies={"test_agency": agency_factory},
- port=8099,
- logs_dir=temp_logs_dir,
- return_app=True,
- enable_agui=False,
- enable_logging=True, # Enable logging to test the middleware
- )
- transport = httpx.ASGITransport(app=app)
- try:
- # Make request with log ID header against in-process app
- async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
- response = await client.post(
- "/test_agency/get_response",
- json={"message": "Test logging middleware"},
- headers={"x-agency-log-id": "fastapi-integration-test"},
- timeout=30.0,
- )
- assert response.status_code == 200
- # Wait for log file to be written
- log_file = Path(temp_logs_dir) / "fastapi-integration-test.jsonl"
- for _ in range(20):
- if log_file.exists():
- break
- await asyncio.sleep(0.1)
- assert log_file.exists()
- # Verify log content
- content = log_file.read_text(encoding="utf-8")
- log_lines = [line for line in content.strip().split("\n") if line.strip()]
- assert len(log_lines) >= 1 # Should have at least some logs
- # Parse and verify log entries
- for line in log_lines:
- log_entry = json.loads(line)
- assert "message" in log_entry
- assert "details" in log_entry
- assert "timestamp" in log_entry["details"]
- finally:
- await transport.aclose()
- class TestGetLogsEndpointImpl:
- """Test the logs retrieval endpoint implementation."""
- @pytest.mark.asyncio
- async def test_retrieves_and_deletes_log_file(self, temp_logs_dir):
- """Test that endpoint returns logs and deletes the file."""
- # Create a test log file
- log_file = Path(temp_logs_dir) / "endpoint-test.jsonl"
- test_logs = [
- {"message": "Log entry 1", "details": {"level": "INFO"}},
- {"message": "Log entry 2", "details": {"level": "ERROR"}},
- ]
- with log_file.open("w", encoding="utf-8") as f:
- for log_entry in test_logs:
- f.write(json.dumps(log_entry) + "\n")
- # Call the endpoint
- response = await get_logs_endpoint_impl("endpoint-test", temp_logs_dir)
- assert response.status_code == 200
- assert response.media_type == "application/json"
- # Parse response content
- response_data = json.loads(response.body)
- assert len(response_data) == 2
- assert response_data[0]["message"] == "Log entry 1"
- assert response_data[1]["message"] == "Log entry 2"
- # Verify file was deleted
- assert not log_file.exists()
- @pytest.mark.asyncio
- async def test_returns_404_for_missing_file(self, temp_logs_dir):
- """Test that endpoint returns 404 for non-existent log files."""
- response = await get_logs_endpoint_impl("non-existent", temp_logs_dir)
- assert response.status_code == 404
- assert "Log file not found" in response.body.decode()
- @pytest.mark.asyncio
- async def test_returns_400_for_empty_log_id(self, temp_logs_dir):
- """Test that endpoint returns 400 for empty log ID."""
- response = await get_logs_endpoint_impl("", temp_logs_dir)
- assert response.status_code == 400
- assert "Log ID is required" in response.body.decode()
- @pytest.mark.asyncio
- async def test_handles_invalid_json_gracefully(self, temp_logs_dir):
- """Test that endpoint skips invalid JSON lines."""
- log_file = Path(temp_logs_dir) / "invalid-json-test.jsonl"
- with log_file.open("w", encoding="utf-8") as f:
- f.write('{"valid": "json"}\n')
- f.write("invalid json line\n")
- f.write('{"another": "valid"}\n')
- response = await get_logs_endpoint_impl("invalid-json-test", temp_logs_dir)
- assert response.status_code == 200
- response_data = json.loads(response.body)
- assert len(response_data) == 2 # Only valid JSON entries
- assert response_data[0]["valid"] == "json"
- assert response_data[1]["another"] == "valid"
- @pytest.mark.asyncio
- async def test_handles_file_system_errors(self):
- """Test that endpoint handles file system errors gracefully."""
- # Use invalid directory to trigger file system error
- with patch("os.path.exists", side_effect=OSError("File system error")):
- response = await get_logs_endpoint_impl("test-id", "/invalid/path")
- assert response.status_code == 500
- assert "Internal server error" in response.body.decode()
|