main.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  import json
  import logging
  import os
  import re
  import sys
  import traceback
  import uuid
  from contextvars import ContextVar
  from datetime import datetime

  import uvicorn
  from agency_swarm import run_fastapi
  from dotenv import load_dotenv
  from fastapi import Request, Response
  from starlette.middleware.base import BaseHTTPMiddleware

  from agency_fin import create_agency
  load_dotenv()

  # Per-request logging state. RequestTracker sets both values for each request;
  # ConsoleFormatter, FileFormatter's handler and ConditionalFileHandler read them.
  #   request_id_context  -> id used to tag log lines ("" when unset)
  #   log_to_file_context -> whether this request's records go to a JSONL file
  request_id_context: ContextVar[str] = ContextVar("request_id", default="")
  log_to_file_context: ContextVar[bool] = ContextVar("log_to_file", default=False)

  # Per-request "<id>.jsonl" activity logs live in a folder beside this script;
  # it is created up front so the file handler can simply append.
  script_dir = os.path.split(os.path.abspath(__file__))[0]
  logs_dir = os.path.join(script_dir, "activity-logs")
  os.makedirs(logs_dir, exist_ok=True)
  class ConsoleFormatter(logging.Formatter):
      """Render a log record as human-readable console text.

      Layout of the first line: ``[<request id>] [<LEVEL>] <location> - <message>``.
      The request-id prefix is omitted when ``request_id_context`` is empty.
      If the record carries ``exc_info``, the formatted traceback is appended on
      the following lines; otherwise, for ERROR and above, the current call
      stack is appended under a ``CALL STACK`` banner.
      """

      def format(self, record):
          rid = request_id_context.get("")
          prefix = f"[{rid}] " if rid else ""

          if hasattr(record, "funcName") and hasattr(record, "module"):
              where = f"{record.module}.{record.funcName}:{record.lineno}"
          elif hasattr(record, "filename"):
              where = f"{record.filename}:{record.lineno}"
          else:
              where = "unknown"

          parts = [f"{prefix}[{record.levelname}] {where} - {record.getMessage()}"]

          if record.exc_info:
              parts.append(self.formatException(record.exc_info))
          elif record.levelno >= logging.ERROR:
              stack = traceback.format_stack()
              if len(stack) > 1:
                  # The last entry is this format() frame itself; leave it out.
                  banner = "-" * 40 + " CALL STACK " + "-" * 40
                  parts.append(banner + "\n" + "".join(stack[:-1]))

          return "\n".join(parts)
  class FileFormatter(logging.Formatter):
      """Render a log record as a single JSON object (one JSONL line).

      Output shape::

          {"message": ..., "details": {"timestamp": <ISO-8601, local time>,
                                       "level": ..., "location": {...},
                                       ["exception": {...}] | ["call_stack": [...]]}}

      ``exception`` is present when the record carries a real exception
      (type, message and traceback lines). Otherwise, for ERROR and above,
      ``call_stack`` holds the stripped lines of the current stack.
      """

      def format(self, record):
          log_entry = {
              "message": record.getMessage(),
              "details": {
                  "timestamp": datetime.fromtimestamp(record.created).isoformat(),
                  "level": record.levelname,
                  "location": {
                      "file": getattr(record, "filename", "unknown"),
                      "function": getattr(record, "funcName", "unknown"),
                      "line": getattr(record, "lineno", 0),
                  },
              },
          }
          details = log_entry["details"]

          # exc_info=True (or logger.exception) outside an ``except`` block yields
          # (None, None, None). Treat that as "no exception" instead of crashing
          # on None.__name__, which previously made the record vanish silently.
          exc_type = record.exc_info[0] if record.exc_info else None
          if exc_type is not None:
              details["exception"] = {
                  "type": exc_type.__name__,
                  "message": str(record.exc_info[1]),
                  "traceback": self.formatException(record.exc_info).split("\n"),
              }
          elif record.levelno >= logging.ERROR:
              current_traceback = traceback.format_stack()
              if len(current_traceback) > 1:
                  # Drop the last entry: it is this format() frame itself.
                  details["call_stack"] = [
                      line.strip() for line in current_traceback[:-1] if line.strip()
                  ]

          return json.dumps(log_entry, ensure_ascii=False)
  class ConditionalFileHandler(logging.Handler):
      """Append records to ``<logs_dir>/<request id>.jsonl`` for opted-in requests.

      A record is written only when ``log_to_file_context`` is truthy and
      ``request_id_context`` holds a non-empty id in the current context;
      otherwise this handler ignores it. Each record is formatted with the
      handler's formatter and appended as one line.
      """

      def __init__(self, level=logging.NOTSET):
          super().__init__(level)

      def emit(self, record):
          if not log_to_file_context.get(False):
              return
          request_id = request_id_context.get("")
          if not request_id:
              return
          try:
              log_dir = os.path.realpath(logs_dir)
              log_file = os.path.realpath(os.path.join(log_dir, f"{request_id}.jsonl"))
              # The id originates from a client-supplied header. Refuse any name
              # that would resolve outside logs_dir (e.g. "../x", an absolute
              # path, or a symlink pointing elsewhere). Don't log about it here:
              # logging from inside emit() would re-enter this handler.
              if os.path.dirname(log_file) != log_dir:
                  return
              formatted_message = self.format(record)
              with open(log_file, "a", encoding="utf-8") as f:
                  f.write(formatted_message + "\n")
          except Exception:
              # Stdlib convention for handler failures: report through
              # handleError (prints to stderr when logging.raiseExceptions is
              # true) instead of silently dropping the record.
              self.handleError(record)
  def setup_logging():
      """Install this module's console and per-request file handlers on the root logger.

      Handlers already attached to the root logger are detached first, so a
      repeated call leaves exactly these two handlers in place
      ("custom_console" -> ConsoleFormatter on a StreamHandler, and
      "custom_file" -> FileFormatter on a ConditionalFileHandler). The root
      level is set to INFO. Returns the root logger.
      """
      root = logging.getLogger()

      console = logging.StreamHandler()
      console.setFormatter(ConsoleFormatter())
      console.name = "custom_console"

      per_request_file = ConditionalFileHandler()
      per_request_file.setFormatter(FileFormatter())
      per_request_file.name = "custom_file"

      root.handlers.clear()
      for handler in (console, per_request_file):
          root.addHandler(handler)
      root.setLevel(logging.INFO)
      # The root logger has no parent, so this has no effect on it.
      root.propagate = False
      return root


  setup_logging()
  def handle_exception(exc_type, exc_value, exc_traceback):
      """``sys.excepthook`` replacement that routes uncaught exceptions into logging.

      ``KeyboardInterrupt`` (and subclasses) is passed to Python's default hook
      unchanged; any other uncaught exception is logged on the root logger at
      ERROR as "Uncaught exception" with its traceback attached.
      """
      exc_info = (exc_type, exc_value, exc_traceback)
      if not issubclass(exc_type, KeyboardInterrupt):
          logging.error("Uncaught exception", exc_info=exc_info)
          return
      sys.__excepthook__(*exc_info)


  sys.excepthook = handle_exception
  def get_log_id_from_headers(request: Request) -> tuple[str, bool]:
      """Pick the log id for a request and whether its logs are persisted to disk.

      If the client sent an ``x-agency-log-id`` header whose value is a safe
      file-name stem (1-128 characters from ``[A-Za-z0-9_.-]``), returns
      ``(header_value, True)``; the request's records are then appended to
      ``<logs_dir>/<id>.jsonl``. Otherwise returns the first 8 characters of a
      random UUID4 with ``False`` (console-only logging).

      The header is client-controlled and becomes part of a file path (and is
      echoed into every console log line), so values containing path
      separators, control characters or other unexpected bytes are ignored
      rather than trusted.
      """
      log_id = request.headers.get("x-agency-log-id")
      if log_id:
          if re.fullmatch(r"[A-Za-z0-9_.-]{1,128}", log_id):
              return log_id, True
          # Don't echo the rejected value: it may contain newlines/control chars.
          logging.warning("Ignoring invalid x-agency-log-id header")
      return str(uuid.uuid4())[:8], False
  class RequestTracker(BaseHTTPMiddleware):
      """Middleware that tags each request with a log id before calling downstream.

      The id, and whether that request's logs should also be written to a
      JSONL file, are stored in ``request_id_context`` / ``log_to_file_context``
      so logging code running in this request's context can read them.
      """

      async def dispatch(self, request: Request, call_next):
          log_id, persist_to_file = get_log_id_from_headers(request)
          request_id_context.set(log_id)
          log_to_file_context.set(persist_to_file)
          return await call_next(request)
  def _json_error_response(status_code, message):
      """Build an ``application/json`` response with body ``{"error": message}``."""
      return Response(
          status_code=status_code,
          content=json.dumps({"error": message}),
          media_type="application/json",
      )


  async def get_logs(request: Request):
      """Return, then delete, the JSONL activity log for a given log id.

      Expects a JSON object body ``{"log_id": "<id>"}``. Responses:

      * 200 - JSON array of the parsed entries. Blank or non-JSON lines are
        skipped. The log file is deleted afterwards, so each log is
        returned once.
      * 400 - body is not a JSON object, ``log_id`` is missing/empty, or it is
        not a safe file-name stem (1-128 chars from ``[A-Za-z0-9_.-]``).
      * 404 - no log file exists for that id.
      * 500 - any other failure; the exception is logged with its traceback.
      """
      try:
          try:
              data = await request.json()
          except ValueError:
              # Malformed JSON / undecodable body (json.JSONDecodeError and
              # UnicodeDecodeError both subclass ValueError) is a client error.
              data = None
          log_id = data.get("log_id") if isinstance(data, dict) else None
          if not log_id:
              return _json_error_response(400, "Log ID is required")

          # log_id is client-controlled and becomes a file path that is read and
          # then deleted. Only accept a plain file-name stem so it cannot reach
          # files outside logs_dir (e.g. "../../x" or an absolute path).
          log_id = str(log_id)
          if not re.fullmatch(r"[A-Za-z0-9_.-]{1,128}", log_id):
              return _json_error_response(400, "Invalid log ID")

          log_file = os.path.join(logs_dir, f"{log_id}.jsonl")
          # Open directly instead of exists()+open() to avoid a check/use race.
          try:
              f = open(log_file, encoding="utf-8")
          except FileNotFoundError:
              return _json_error_response(404, "Log file not found")

          log_entries = []
          with f:
              for line in f:
                  line = line.strip()
                  if not line:
                      continue
                  try:
                      log_entries.append(json.loads(line))
                  except json.JSONDecodeError:
                      # Tolerate partially written or corrupt lines.
                      continue

          try:
              os.remove(log_file)
          except FileNotFoundError:
              # A concurrent request already removed it; entries were still read.
              pass

          return Response(
              status_code=200,
              content=json.dumps(log_entries, ensure_ascii=False, indent=2),
              media_type="application/json",
          )
      except Exception:
          logging.exception("Failed to serve /get_logs request")
          return _json_error_response(500, "Internal server error")
  if __name__ == "__main__":
      # Build the FastAPI app for the agency; return_app=True presumably returns
      # the app instead of serving it, so routes/middleware can be added first.
      # NOTE(review): port=8080 is passed here but the server below listens on
      # 3088 -- presumably run_fastapi ignores its port when return_app=True;
      # confirm.
      agencies = {"my-agency": create_agency}
      app = run_fastapi(agencies=agencies, port=8080, return_app=True)

      # Re-install this module's logging handlers (presumably in case building
      # the app reconfigured the root logger).
      setup_logging()

      app.add_middleware(RequestTracker)
      app.add_route("/get_logs", get_logs, methods=["POST"])

      uvicorn.run(app, host="0.0.0.0", port=3088)