| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738 |
- from fastapi import APIRouter, HTTPException, Request
- from pydantic import BaseModel
- from typing import List, Dict, Any, Optional, Type
- from lightrag.utils import logger
- import time
- import json
- import re
- from enum import Enum
- from fastapi.responses import StreamingResponse
- import asyncio
- from lightrag import LightRAG, QueryParam
- from lightrag.utils import TiktokenTokenizer
- from lightrag.api.utils_api import get_combined_auth_dependency
- from fastapi import Depends
# Query mode selected by the query prefix ("bypass" is not a LightRAG query mode).
class SearchMode(str, Enum):
    """String-valued modes that a chat query prefix can select."""

    naive = "naive"
    local = "local"
    # Trailing underscore because ``global`` is a reserved word in Python.
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"
    bypass = "bypass"
    context = "context"
class OllamaMessage(BaseModel):
    """One message of an Ollama-style chat conversation."""

    # Speaker role; the chat handler in this module requires the last
    # message of a request to have role "user".
    role: str
    # Text of the message.
    content: str
    # Optional list of image strings; not read by the handlers in this module.
    images: Optional[list[str]] = None
class OllamaChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` route."""

    # Model name sent by the client.
    model: str
    # Conversation so far; the last entry is used as the query.
    messages: list[OllamaMessage]
    # Chat responses are streamed unless the client opts out.
    stream: bool = True
    # Free-form options; accepted but not read by the handlers in this module.
    options: Optional[dict[str, Any]] = None
    # Optional system prompt, forwarded when the LLM is called directly.
    system: Optional[str] = None
class OllamaChatResponse(BaseModel):
    """Minimal schema of a chat reply: model, timestamp, message and done flag."""

    model: str
    created_at: str
    # The reply message itself.
    message: OllamaMessage
    # Whether this reply closes the exchange.
    done: bool
class OllamaGenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` route."""

    # Model name sent by the client.
    model: str
    # Prompt text passed to the LLM.
    prompt: str
    # Optional system prompt.
    system: Optional[str] = None
    # Unlike chat requests, generate requests are not streamed by default.
    stream: bool = False
    # Free-form options; accepted but not read by the handlers in this module.
    options: Optional[dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
    """Schema of a ``/generate`` reply, including optional usage statistics.

    The statistics fields default to ``None`` so that they are genuinely
    optional: without an explicit default, Pydantic v2 treats an
    ``Optional[...]`` annotation as "required but nullable". This also matches
    how the other optional fields in this module are declared.
    """

    model: str
    created_at: str
    # Generated text.
    response: str
    # Whether generation has finished.
    done: bool
    context: Optional[List[int]] = None
    # Durations below are reported by the handlers in nanoseconds
    # (they are computed from time.time_ns()).
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    prompt_eval_duration: Optional[int] = None
    eval_count: Optional[int] = None
    eval_duration: Optional[int] = None
class OllamaVersionResponse(BaseModel):
    """Reply of the ``/version`` route."""

    # Version string reported to clients.
    version: str
class OllamaModelDetails(BaseModel):
    """Descriptive metadata nested inside an ``OllamaModel`` entry."""

    parent_model: str
    # Model file format (the ``/tags`` route reports "gguf").
    format: str
    family: str
    families: list[str]
    # Human-readable size label such as "13B".
    parameter_size: str
    # Quantization label such as "Q4_0".
    quantization_level: str
class OllamaModel(BaseModel):
    """One model entry as listed by the ``/tags`` route."""

    # The ``/tags`` route fills ``name`` and ``model`` with the same value.
    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    # Nested descriptive metadata.
    details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
    """Reply of the ``/tags`` route: the list of available models."""

    models: list[OllamaModel]
class OllamaRunningModelDetails(BaseModel):
    """Descriptive metadata nested inside an ``OllamaRunningModel`` entry.

    Declares the same fields as ``OllamaModelDetails``.
    """

    parent_model: str
    # Model file format (the ``/ps`` route reports "gguf").
    format: str
    family: str
    families: list[str]
    # Human-readable size label such as "7.2B".
    parameter_size: str
    # Quantization label such as "Q4_0".
    quantization_level: str
class OllamaRunningModel(BaseModel):
    """One model entry as listed by the ``/ps`` route."""

    # The ``/ps`` route fills ``name`` and ``model`` with the same value.
    name: str
    model: str
    size: int
    digest: str
    # Nested descriptive metadata.
    details: OllamaRunningModelDetails
    # Expiry timestamp as a string; the ``/ps`` route reports a fixed value.
    expires_at: str
    # The ``/ps`` route reports the same value here as in ``size``.
    size_vram: int
class OllamaPsResponse(BaseModel):
    """Reply of the ``/ps`` route: the list of running models."""

    models: list[OllamaRunningModel]
async def parse_request_body(
    request: Request, model_class: Type[BaseModel]
) -> BaseModel:
    """Build a ``model_class`` instance from the JSON body of ``request``.

    The body is treated as JSON whatever the Content-Type header says:
    ``application/json`` bodies are decoded by the request object itself, and
    every other content type (including ``application/octet-stream``) is read
    as raw bytes, decoded as UTF-8 and parsed with ``json.loads``.

    Args:
        request: The incoming FastAPI request.
        model_class: The Pydantic model class to instantiate from the body.

    Returns:
        An instance of ``model_class`` populated from the decoded body.

    Raises:
        HTTPException: Status 400 when the body is not valid JSON, or when
            decoding or model construction fails for any other reason.
    """
    media_type = request.headers.get("content-type", "").lower()

    try:
        if media_type.startswith("application/json"):
            # Let the request object decode the JSON body.
            payload = await request.json()
        else:
            # octet-stream and unknown content types: parse the bytes manually.
            raw_body = await request.body()
            payload = json.loads(raw_body.decode("utf-8"))

        return model_class(**payload)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Error parsing request body: {e!s}"
        )
def estimate_tokens(text: str) -> int:
    """Return the number of tokens ``TiktokenTokenizer`` produces for ``text``."""
    return len(TiktokenTokenizer().encode(text))
def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
    """Split a chat query into its text, search mode and optional user prompt.

    A query may start with a ``/mode`` prefix, optionally followed directly by
    a bracketed user prompt (``/mode[user prompt] query``). Without a
    recognised prefix the query is returned as-is with ``SearchMode.mix``.

    Returns:
        Tuple of (cleaned_query, search_mode, only_need_context, user_prompt).

    Examples:
        - "/local[use mermaid format for diagrams] query string"
          -> ("query string", SearchMode.local, False, "use mermaid format for diagrams")
        - "/[use mermaid format for diagrams] query string"
          -> ("query string", SearchMode.mix, False, "use mermaid format for diagrams")
        - "/local query string" -> ("query string", SearchMode.local, False, None)
    """
    user_prompt = None

    # Bracket form: "/<mode>[<user prompt>]<query>". DOTALL keeps multi-line
    # prompts and queries intact; without it "." stops at the first newline
    # and the rest of the query would be silently dropped.
    bracket_match = re.match(r"^/([a-z]*)\[(.*?)\](.*)", query, re.DOTALL)
    if bracket_match:
        mode_prefix, user_prompt, remaining_query = bracket_match.groups()
        remaining_query = remaining_query.lstrip()
        if not mode_prefix:
            # "/[prompt] query": no mode was requested, so use the default
            # mode and do not leave a stray "/ " in front of the query.
            return remaining_query.strip(), SearchMode.mix, False, user_prompt
        # Rebuild the query without the bracket part for prefix matching below.
        query = f"/{mode_prefix} {remaining_query}".strip()

    # Prefix -> (search mode, only_need_context). Plain mode prefixes require
    # a trailing space; the "context" variants do not.
    mode_map = {
        "/local ": (SearchMode.local, False),
        "/global ": (SearchMode.global_, False),
        "/naive ": (SearchMode.naive, False),
        "/hybrid ": (SearchMode.hybrid, False),
        "/mix ": (SearchMode.mix, False),
        "/bypass ": (SearchMode.bypass, False),
        "/context": (SearchMode.mix, True),
        "/localcontext": (SearchMode.local, True),
        "/globalcontext": (SearchMode.global_, True),
        "/hybridcontext": (SearchMode.hybrid, True),
        "/naivecontext": (SearchMode.naive, True),
        "/mixcontext": (SearchMode.mix, True),
    }

    for prefix, (mode, only_need_context) in mode_map.items():
        if query.startswith(prefix):
            # Drop the prefix and any whitespace that follows it.
            cleaned_query = query[len(prefix) :].lstrip()
            return cleaned_query, mode, only_need_context, user_prompt

    # No recognised prefix: default mode, query untouched.
    return query, SearchMode.mix, False, user_prompt
class OllamaAPI:
    """Expose a LightRAG instance through Ollama-style HTTP routes.

    The constructor builds an ``APIRouter`` (``self.router``) and registers
    the ``/version``, ``/tags``, ``/ps``, ``/generate`` and ``/chat`` routes
    on it. Streaming replies are sent as newline-delimited JSON.
    """

    def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
        """Store the configuration and register all routes.

        Args:
            rag: LightRAG instance used to answer queries.
            top_k: ``top_k`` value passed to every ``QueryParam``.
            api_key: Key handed to ``get_combined_auth_dependency``.
        """
        self.rag = rag
        self.ollama_server_infos = rag.ollama_server_infos
        self.top_k = top_k
        self.api_key = api_key
        self.router = APIRouter(tags=["ollama"])
        self.setup_routes()

    # ------------------------------------------------------------------
    # Private helpers shared by the /generate and /chat routes
    # ------------------------------------------------------------------

    @staticmethod
    def _ndjson_line(payload: Dict[str, Any]) -> str:
        """Serialize one payload as a single NDJSON line."""
        return f"{json.dumps(payload, ensure_ascii=False)}\n"

    def _envelope(self) -> Dict[str, Any]:
        """Return the model/created_at fields that start every payload."""
        return {
            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
        }

    @staticmethod
    def _generate_fields(text: str) -> Dict[str, Any]:
        """Return the content field of a /generate chunk."""
        return {"response": text}

    @staticmethod
    def _generate_final_fields(text: str) -> Dict[str, Any]:
        """Return the content and completion fields of a final /generate payload."""
        return {
            "response": text,
            "done": True,
            "done_reason": "stop",
            "context": [],
        }

    @staticmethod
    def _chat_fields(text: str) -> Dict[str, Any]:
        """Return the content field of a /chat chunk."""
        return {
            "message": {
                "role": "assistant",
                "content": text,
                "images": None,
            }
        }

    @staticmethod
    def _chat_final_fields(text: str) -> Dict[str, Any]:
        """Return the content and completion fields of a final /chat payload."""
        return {
            "message": {
                "role": "assistant",
                "content": text,
                "images": None,
            },
            "done_reason": "stop",
            "done": True,
        }

    @staticmethod
    def _usage_stats(
        start_time: int,
        first_chunk_time: int,
        last_chunk_time: int,
        prompt_tokens: int,
        response_text: str,
    ) -> Dict[str, Any]:
        """Return the token counts and nanosecond durations of a finished reply."""
        completion_tokens = estimate_tokens(response_text)
        return {
            "total_duration": last_chunk_time - start_time,
            "load_duration": 0,
            "prompt_eval_count": prompt_tokens,
            "prompt_eval_duration": first_chunk_time - start_time,
            "eval_count": completion_tokens,
            "eval_duration": last_chunk_time - first_chunk_time,
        }

    def _query_llm_kwargs(self, system: Optional[str]) -> Dict[str, Any]:
        """Return a copy of the query-role LLM kwargs, plus the system prompt if given."""
        role_kwargs = (
            dict(self.rag.role_llm_kwargs["query"])
            if self.rag.role_llm_kwargs["query"] is not None
            else dict(self.rag.llm_model_kwargs)
        )
        if system:
            role_kwargs["system_prompt"] = system
        return role_kwargs

    async def _call_llm_with_history(
        self, query: str, history: List[Dict[str, Any]], system: Optional[str], stream: bool
    ):
        """Send a chat query straight to the query-role LLM, skipping LightRAG retrieval."""
        role_kwargs = self._query_llm_kwargs(system)
        return await (self.rag.role_llm_funcs["query"])(
            query,
            stream=stream,
            history_messages=history,
            **role_kwargs,
        )

    @staticmethod
    def _ndjson_response(body_iterator) -> StreamingResponse:
        """Wrap an async iterator of NDJSON lines in a StreamingResponse."""
        return StreamingResponse(
            body_iterator,
            media_type="application/x-ndjson",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "application/x-ndjson",
                "X-Accel-Buffering": "no",  # Ensure proper handling of streaming responses in Nginx proxy
            },
        )

    async def _stream_ndjson(
        self, response, start_time: int, prompt_tokens: int, content_fields, final_fields
    ):
        """Yield NDJSON lines for an LLM reply, ending with a ``done`` payload.

        Args:
            response: Either the complete reply as a string or an async
                iterator of text chunks.
            start_time: ``time.time_ns()`` taken when request handling started.
            prompt_tokens: Token count of the prompt, reported in the stats.
            content_fields: Callable mapping a text chunk to the
                route-specific content field(s) of a payload.
            final_fields: Callable mapping a text to the route-specific
                content and completion fields of the final payload.
        """
        first_chunk_time = None
        last_chunk_time = time.time_ns()

        if isinstance(response, str):
            # A plain string is sent in two parts: the content, then the stats.
            first_chunk_time = start_time
            last_chunk_time = time.time_ns()
            total_response = response
            yield self._ndjson_line(
                {**self._envelope(), **content_fields(response), "done": False}
            )
        else:
            total_response = ""
            try:
                async for chunk in response:
                    if chunk:
                        if first_chunk_time is None:
                            first_chunk_time = time.time_ns()
                        last_chunk_time = time.time_ns()
                        total_response += chunk
                        yield self._ndjson_line(
                            {
                                **self._envelope(),
                                **content_fields(chunk),
                                "done": False,
                            }
                        )
            except (asyncio.CancelledError, Exception) as e:
                if isinstance(e, asyncio.CancelledError):
                    error_msg = "Stream was cancelled by server"
                else:
                    error_msg = f"Provider error: {str(e)}"
                logger.error(f"Stream error: {error_msg}")

                # Report the error to the client, then close the stream
                # with a bare "done" payload (no statistics).
                error_text = f"\n\nError: {error_msg}"
                yield self._ndjson_line(
                    {
                        **self._envelope(),
                        **content_fields(error_text),
                        "error": error_text,
                        "done": False,
                    }
                )
                yield self._ndjson_line(
                    {**self._envelope(), **content_fields(""), "done": True}
                )
                return

            # No non-empty chunk was received: fall back to the start time.
            if first_chunk_time is None:
                first_chunk_time = start_time

        yield self._ndjson_line(
            {
                **self._envelope(),
                **final_fields(""),
                **self._usage_stats(
                    start_time,
                    first_chunk_time,
                    last_chunk_time,
                    prompt_tokens,
                    total_response,
                ),
            }
        )

    # ------------------------------------------------------------------
    # Route registration
    # ------------------------------------------------------------------

    def setup_routes(self):
        """Register all routes on ``self.router`` behind the combined auth dependency."""
        # Create combined auth dependency for Ollama API routes
        combined_auth = get_combined_auth_dependency(self.api_key)

        @self.router.get("/version", dependencies=[Depends(combined_auth)])
        async def get_version():
            """Get Ollama version information"""
            return OllamaVersionResponse(version="0.9.3")

        @self.router.get("/tags", dependencies=[Depends(combined_auth)])
        async def get_tags():
            """Return available models acting as an Ollama server"""
            return OllamaTagResponse(
                models=[
                    {
                        "name": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "size": self.ollama_server_infos.LIGHTRAG_SIZE,
                        "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
                        "details": {
                            "parent_model": "",
                            "format": "gguf",
                            "family": self.ollama_server_infos.LIGHTRAG_NAME,
                            "families": [self.ollama_server_infos.LIGHTRAG_NAME],
                            "parameter_size": "13B",
                            "quantization_level": "Q4_0",
                        },
                    }
                ]
            )

        @self.router.get("/ps", dependencies=[Depends(combined_auth)])
        async def get_running_models():
            """List Running Models - returns currently running models"""
            return OllamaPsResponse(
                models=[
                    {
                        "name": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "size": self.ollama_server_infos.LIGHTRAG_SIZE,
                        "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
                        "details": {
                            "parent_model": "",
                            "format": "gguf",
                            "family": "llama",
                            "families": ["llama"],
                            "parameter_size": "7.2B",
                            "quantization_level": "Q4_0",
                        },
                        "expires_at": "2050-12-31T14:38:31.83753-07:00",
                        "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
                    }
                ]
            )

        @self.router.post(
            "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
        )
        async def generate(raw_request: Request):
            """Handle generate completion requests acting as an Ollama model
            For compatibility purpose, the request is not processed by LightRAG,
            and will be handled by underlying LLM model.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaGenerateRequest)

                query = request.prompt
                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(query)
                role_kwargs = self._query_llm_kwargs(request.system)

                if request.stream:
                    response = await (self.rag.role_llm_funcs["query"])(
                        query, stream=True, **role_kwargs
                    )
                    return self._ndjson_response(
                        self._stream_ndjson(
                            response,
                            start_time,
                            prompt_tokens,
                            self._generate_fields,
                            self._generate_final_fields,
                        )
                    )

                first_chunk_time = time.time_ns()
                response_text = await (self.rag.role_llm_funcs["query"])(
                    query, stream=False, **role_kwargs
                )
                last_chunk_time = time.time_ns()

                if not response_text:
                    response_text = "No response generated"

                return {
                    **self._envelope(),
                    **self._generate_final_fields(str(response_text)),
                    **self._usage_stats(
                        start_time,
                        first_chunk_time,
                        last_chunk_time,
                        prompt_tokens,
                        str(response_text),
                    ),
                }
            except HTTPException:
                # Keep deliberate HTTP errors (e.g. the 400 raised for a bad
                # request body) instead of converting them into a 500.
                raise
            except Exception as e:
                logger.error(f"Ollama generate error: {str(e)}", exc_info=True)
                raise HTTPException(status_code=500, detail=str(e))

        @self.router.post(
            "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
        )
        async def chat(raw_request: Request):
            """Process chat completion requests by acting as an Ollama model.
            Routes user queries through LightRAG by selecting query mode based on query prefix.
            Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaChatRequest)

                # Get all messages
                messages = request.messages
                if not messages:
                    raise HTTPException(status_code=400, detail="No messages provided")

                # Validate that the last message is from a user
                if messages[-1].role != "user":
                    raise HTTPException(
                        status_code=400, detail="Last message must be from user role"
                    )

                # The last message is the query; earlier ones are the history.
                query = messages[-1].content
                conversation_history = [
                    {"role": msg.role, "content": msg.content} for msg in messages[:-1]
                ]

                # Check for query prefix
                cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
                    query
                )

                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(cleaned_query)

                param_dict = {
                    "mode": mode.value,
                    "stream": request.stream,
                    "only_need_context": only_need_context,
                    "conversation_history": conversation_history,
                    "top_k": self.top_k,
                }
                # Only pass user_prompt when the query carried one.
                if user_prompt is not None:
                    param_dict["user_prompt"] = user_prompt
                query_param = QueryParam(**param_dict)

                if request.stream:
                    # "/bypass" queries go straight to the LLM.
                    if mode == SearchMode.bypass:
                        response = await self._call_llm_with_history(
                            cleaned_query,
                            conversation_history,
                            request.system,
                            stream=True,
                        )
                    else:
                        response = await self.rag.aquery(
                            cleaned_query, param=query_param
                        )
                    return self._ndjson_response(
                        self._stream_ndjson(
                            response,
                            start_time,
                            prompt_tokens,
                            self._chat_fields,
                            self._chat_final_fields,
                        )
                    )

                first_chunk_time = time.time_ns()

                # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
                match_result = re.search(
                    r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
                )
                if match_result or mode == SearchMode.bypass:
                    response_text = await self._call_llm_with_history(
                        cleaned_query,
                        conversation_history,
                        request.system,
                        stream=False,
                    )
                else:
                    response_text = await self.rag.aquery(
                        cleaned_query, param=query_param
                    )

                last_chunk_time = time.time_ns()

                if not response_text:
                    response_text = "No response generated"

                return {
                    **self._envelope(),
                    **self._chat_final_fields(str(response_text)),
                    **self._usage_stats(
                        start_time,
                        first_chunk_time,
                        last_chunk_time,
                        prompt_tokens,
                        str(response_text),
                    ),
                }
            except HTTPException:
                # Keep the 400s raised above (bad body, no messages, wrong
                # last role) instead of converting them into a 500.
                raise
            except Exception as e:
                logger.error(f"Ollama chat error: {str(e)}", exc_info=True)
                raise HTTPException(status_code=500, detail=str(e))
|