sse.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
"""Server-Sent Events (SSE) helpers."""
  2. import json
  3. from typing import AsyncGenerator, Any, Dict, Optional
  4. from fastapi.responses import StreamingResponse
  5. DEFAULT_SSE_HEADERS: Dict[str, str] = {
  6. "Cache-Control": "no-cache",
  7. "Connection": "keep-alive",
  8. "Content-Type": "text/event-stream",
  9. }
  10. def sse_response(
  11. generator: AsyncGenerator[str, Any],
  12. media_type: str = "text/event-stream",
  13. extra_headers: Optional[Dict[str, str]] = None,
  14. ) -> StreamingResponse:
  15. """
  16. 包装 SSE 异步生成器为 StreamingResponse,统一 headers 和 media_type。
  17. Args:
  18. generator: 异步生成器,yield 已经带好 "data: ..." 和 "\n\n" 的字符串
  19. media_type: 响应的 media_type,默认使用 text/event-stream
  20. extra_headers: 额外需要添加或覆盖的响应头
  21. """
  22. headers = DEFAULT_SSE_HEADERS.copy()
  23. if extra_headers:
  24. headers.update(extra_headers)
  25. return StreamingResponse(
  26. generator,
  27. media_type=media_type,
  28. headers=headers,
  29. )
  30. def sse_data(payload: Dict[str, Any]) -> str:
  31. """将 payload 包装为标准 SSE data 事件。"""
  32. return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
  33. def sse_chunk(chunk: str) -> str:
  34. """输出增量文本块。"""
  35. return sse_data({"chunk": chunk})
  36. def sse_progress(message: str) -> str:
  37. """输出进度事件。"""
  38. return sse_data({"type": "progress", "message": message})
  39. def sse_result(payload: Dict[str, Any]) -> str:
  40. """输出结果事件。"""
  41. return sse_data({"type": "result", **payload})
  42. def sse_error(message: str) -> str:
  43. """输出统一错误事件。"""
  44. return sse_data({"error": True, "message": message})
  45. def sse_done() -> str:
  46. """输出结束标记。"""
  47. return "data: [DONE]\n\n"