# outline.py
"""Outline (table-of-contents) API routes."""
import asyncio
import contextlib
import logging

from fastapi import APIRouter, HTTPException

from ..models.schemas import OutlineRequest, OutlineResponse
from ..services.outline_service import OutlineService
from ..utils.errors import AppError
from ..utils.sse import (
    sse_done,
    sse_error,
    sse_progress,
    sse_response,
    sse_result,
)
  15. logger = logging.getLogger(__name__)
  16. router = APIRouter(prefix="/api/outline", tags=["目录管理"])
  17. @router.post("/generate", response_model=OutlineResponse)
  18. async def generate_outline(request: OutlineRequest):
  19. """生成完整目录结构。"""
  20. try:
  21. outline_service = OutlineService()
  22. return await outline_service.generate_outline(
  23. overview=request.overview,
  24. requirements=request.requirements,
  25. mode=request.mode,
  26. uploaded_expand=bool(request.uploaded_expand),
  27. old_outline=request.old_outline,
  28. )
  29. except AppError as exc:
  30. raise HTTPException(status_code=exc.status_code, detail=exc.message) from exc
  31. except Exception as exc:
  32. logger.exception("目录生成失败")
  33. raise HTTPException(status_code=500, detail=f"目录生成失败: {exc}") from exc
  34. @router.post("/generate-stream")
  35. async def generate_outline_stream(request: OutlineRequest):
  36. """流式生成目录结构。"""
  37. try:
  38. outline_service = OutlineService()
  39. except AppError as exc:
  40. raise HTTPException(status_code=exc.status_code, detail=exc.message) from exc
  41. async def generate():
  42. queue: asyncio.Queue[str | None] = asyncio.Queue()
  43. client_disconnected = False
  44. async def progress_callback(message: str) -> None:
  45. await queue.put(sse_progress(message))
  46. async def run_workflow() -> None:
  47. try:
  48. outline = await outline_service.generate_outline(
  49. overview=request.overview,
  50. requirements=request.requirements,
  51. mode=request.mode,
  52. uploaded_expand=bool(request.uploaded_expand),
  53. old_outline=request.old_outline,
  54. progress_callback=progress_callback,
  55. )
  56. await queue.put(sse_result({"outline": outline}))
  57. except AppError as exc:
  58. await queue.put(sse_error(exc.message))
  59. except Exception:
  60. logger.exception("目录流式生成失败")
  61. await queue.put(sse_error("目录生成失败,请稍后重试"))
  62. finally:
  63. await queue.put(None)
  64. task = asyncio.create_task(run_workflow())
  65. try:
  66. while True:
  67. event = await queue.get()
  68. if event is None:
  69. break
  70. yield event
  71. except asyncio.CancelledError:
  72. client_disconnected = True
  73. raise
  74. finally:
  75. if not task.done():
  76. task.cancel()
  77. try:
  78. await task
  79. except asyncio.CancelledError:
  80. pass
  81. finally:
  82. if not client_disconnected:
  83. yield sse_done()
  84. return sse_response(generate())