| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227 |
- from ..utils import verbose_debug, VERBOSE_DEBUG
- import os
- import logging
- import warnings
- from collections.abc import AsyncIterator
- import pipmaster as pm
- import tiktoken
- # install specific modules
- if not pm.is_installed("openai"):
- pm.install("openai")
- from openai import (
- APIConnectionError,
- RateLimitError,
- APITimeoutError,
- InternalServerError,
- BadRequestError,
- )
- from tenacity import (
- retry,
- stop_after_attempt,
- wait_exponential,
- retry_if_exception_type,
- )
- from lightrag.utils import (
- wrap_embedding_func_with_attrs,
- safe_unicode_decode,
- logger,
- )
- from lightrag.api import __api_version__
- import numpy as np
- import base64
- from typing import Any, Union
- from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
# NOTE: this must run *before* the Langfuse detection below. Otherwise
# LANGFUSE_* keys that are defined only in .env are invisible to the check,
# and Langfuse tracing silently stays disabled.
load_dotenv(dotenv_path=".env", override=False)

# Try to import Langfuse for LLM observability (optional).
# Falls back to the standard OpenAI client when Langfuse is not installed
# (ImportError) or when LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY are not both set.
LANGFUSE_ENABLED = False
try:
    # Check if required Langfuse environment variables are set
    langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
    langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
    # Only enable Langfuse if both keys are configured
    if langfuse_public_key and langfuse_secret_key:
        from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]

        LANGFUSE_ENABLED = True
        logger.info("Langfuse observability enabled for OpenAI client")
    else:
        from openai import AsyncOpenAI

        logger.debug(
            "Langfuse environment variables not configured, using standard OpenAI client"
        )
except ImportError:
    from openai import AsyncOpenAI

    logger.debug("Langfuse not available, using standard OpenAI client")
class InvalidResponseError(Exception):
    """Raised when the OpenAI API hands back an unusable reply.

    "Unusable" means the response has no choices/message, or its final text
    content is empty. The exception is listed in the retry policy of
    ``openai_complete_if_cache``, so such replies are re-requested.
    """
class TransientBadRequestError(Exception):
    """Retryable stand-in for an HTTP 400 that is not really the caller's fault.

    The OpenAI API, or a proxy in front of it, occasionally answers with
    "We could not parse the JSON body of your request". This happens when the
    request body got corrupted or truncated in transit, and resending usually
    succeeds.

    Such errors are therefore re-raised as this type, which the retry policy
    accepts. Genuine 400s (bad parameters, content-policy rejections, etc.)
    still fail immediately.
    """
- def _validate_openai_response_format(response_format: Any | None) -> None:
- """Reject typed structured-output helpers; only wire-format dicts are supported."""
- if response_format is None or isinstance(response_format, dict):
- return
- raise TypeError(
- "openai_complete_if_cache only supports dict response_format payloads; "
- "typed/Pydantic response_format values are not supported."
- )
# Per-process memo of tiktoken encodings, keyed by model name.
_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}

# Whether embeddings are requested from the API base64-encoded. Base64 is the
# more compact wire format; set EMBEDDING_USE_BASE64=false for providers that
# do not support it (e.g. Yandex Cloud). Unset means enabled. Any value other
# than "true"/"1"/"yes" (compared case-insensitively) disables it.
EMBEDDING_USE_BASE64: bool = os.getenv("EMBEDDING_USE_BASE64", "true").lower() in {
    "true",
    "1",
    "yes",
}
def _get_tiktoken_encoding_for_model(model: str) -> Any:
    """Return the memoized tiktoken encoding used by ``model``.

    Model names tiktoken does not recognise fall back to ``cl100k_base``. The
    result is stored in ``_TIKTOKEN_ENCODING_CACHE``, so repeated lookups for
    the same model reuse the already-resolved encoding.

    Args:
        model: The model name to get encoding for.

    Returns:
        The tiktoken encoding for the model.
    """
    if model in _TIKTOKEN_ENCODING_CACHE:
        return _TIKTOKEN_ENCODING_CACHE[model]

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.debug(
            f"Encoding for model '{model}' not found, falling back to cl100k_base"
        )
        encoding = tiktoken.get_encoding("cl100k_base")

    _TIKTOKEN_ENCODING_CACHE[model] = encoding
    return encoding
def create_openai_async_client(
    api_key: str | None = None,
    base_url: str | None = None,
    use_azure: bool = False,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    timeout: int | None = None,
    client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.

    Every setting is resolved with the precedence
    ``explicit parameter > client_configs > environment variable / built-in default``.
    HTTP headers are merged rather than replaced: the built-in LightRAG headers
    come first, and ``client_configs["default_headers"]`` extends or overrides them.

    Args:
        api_key: API key.
            - Standard OpenAI: falls back to ``client_configs["api_key"]``,
              then the ``OPENAI_API_KEY`` environment variable.
            - Azure: falls back to ``client_configs["api_key"]``, then
              ``AZURE_OPENAI_API_KEY``, then ``LLM_BINDING_API_KEY``.
        base_url: Base URL of the API (the ``azure_endpoint`` for Azure).
            - Standard OpenAI: falls back to ``client_configs["base_url"]``,
              then ``OPENAI_API_BASE``, then ``https://api.openai.com/v1``.
        use_azure: Whether to create an Azure OpenAI client. Default is False.
        azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
        api_version: Azure OpenAI API version (only used when use_azure=True).
        timeout: Request timeout in seconds.
        client_configs: Additional constructor options for the client (proxy,
            headers, retry policy, ...). They override built-in defaults but
            are overridden by the explicit parameters above. The dict is not
            mutated.

    Returns:
        An AsyncOpenAI or AsyncAzureOpenAI client instance.

    Raises:
        KeyError: Standard OpenAI only, when no API key is available from the
            parameter, ``client_configs`` or ``OPENAI_API_KEY``.
    """
    client_configs = client_configs or {}

    if use_azure:
        from openai import AsyncAzureOpenAI

        if not api_key:
            api_key = (
                client_configs.get("api_key")
                or os.environ.get("AZURE_OPENAI_API_KEY")
                or os.environ.get("LLM_BINDING_API_KEY")
            )

        merged_configs: dict[str, Any] = {**client_configs, "api_key": api_key}
        # Explicit parameters override anything given in client_configs
        if base_url is not None:
            merged_configs["azure_endpoint"] = base_url
        if azure_deployment is not None:
            merged_configs["azure_deployment"] = azure_deployment
        if api_version is not None:
            merged_configs["api_version"] = api_version
        if timeout is not None:
            merged_configs["timeout"] = timeout
        return AsyncAzureOpenAI(**merged_configs)

    if not api_key:
        # KeyError here is intentional: no usable key was configured anywhere
        api_key = client_configs.get("api_key") or os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }
    dashscope_workspace_id = os.getenv("DASHSCOPE_WORKSPACE_ID", "").strip()
    if dashscope_workspace_id:
        default_headers["X-DashScope-Workspace"] = dashscope_workspace_id
    # Caller-supplied headers extend/override the built-in ones instead of
    # being silently discarded.
    default_headers.update(client_configs.get("default_headers") or {})

    if base_url is None:
        base_url = client_configs.get("base_url") or os.environ.get(
            "OPENAI_API_BASE", "https://api.openai.com/v1"
        )

    # Precedence: explicit params > client_configs > defaults
    merged_configs = {
        **client_configs,
        "default_headers": default_headers,
        "api_key": api_key,
        "base_url": base_url,
    }
    if timeout is not None:
        merged_configs["timeout"] = timeout
    return AsyncOpenAI(**merged_configs)
async def _close_openai_client_quietly(client: Any, context: str = "") -> None:
    """Await ``client.close()``, logging (never raising) any failure.

    ``context`` is spliced into the warning text, e.g. " in non-streaming finally block".
    """
    try:
        await client.close()
    except Exception as close_error:
        logger.warning(f"Failed to close OpenAI client{context}: {close_error}")


# TODO LengthFinishReasonError should not persist into LLM cache
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
        | retry_if_exception_type(InvalidResponseError)
        # Retry transient HTTP 5xx (OpenAI "500 server_error", proxy "upstream
        # connect error"). InternalServerError covers all status >= 500.
        | retry_if_exception_type(InternalServerError)
        # Retry transient "could not parse JSON body" 400s (see handler below).
        | retry_if_exception_type(TransientBadRequestError)
    ),
)
async def openai_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    stream: bool | None = None,
    timeout: int | None = None,
    keyword_extraction: bool = False,
    use_azure: bool = False,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    image_inputs: list[Any] | None = None,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.

    This function supports automatic integration of reasoning content from models that provide
    Chain of Thought capabilities. The reasoning content is integrated into the response
    using <think>...</think> tags.

    Structured output design note:
        - This adapter supports dict-based OpenAI response_format payloads,
          including ``{"type": "json_object"}`` and dict-form ``json_schema``.
        - Typed/Pydantic ``response_format`` helpers are rejected explicitly.
        - Structured responses are returned as raw text from ``message.content``
          and are not locally schema-validated here.
        - Supplying a ``response_format`` disables COT wrapping.
        - ``keyword_extraction`` / ``entity_extraction`` are deprecated; prefer
          ``response_format={"type": "json_object"}`` instead.

    Note on truncated output: when the model stops with
    ``finish_reason == "length"`` a warning is logged and the (possibly
    partial) raw content is returned unchanged, so upstream tolerant JSON
    parsing can still salvage it. Such best-effort payloads should not be
    persisted into the LLM cache, because later runs with a higher token budget
    could otherwise keep reusing incomplete data.

    Note on `reasoning_content`: This feature relies on a Deepseek Style `reasoning_content`
    in the API response, which may be provided by OpenAI-compatible endpoints that support
    Chain of Thought.

    COT Integration Rules:
        1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
        2. COT processing stops when regular content becomes available (the <think> block is closed).
        3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
        4. If both fields have content from the start, COT is never activated.
        5. For streaming: COT content is inserted into the content stream with <think> tags.
        6. For non-streaming: COT content is prepended to regular content with <think> tags.

    Args:
        model: The OpenAI model to use. For Azure, this can be the deployment name.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
        base_url: Optional base URL for the OpenAI API. For Azure, this should be the
            Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
        api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
            variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
        token_tracker: Optional token usage tracker; its ``add_usage(dict)`` is called
            with prompt/completion/total token counts when usage is reported.
        stream: Whether to stream the response. None (default) leaves the
            parameter unset, so the API's non-streaming default applies.
        timeout: Request timeout in seconds. Default is None.
        keyword_extraction: Deprecated compatibility shim. When True and no
            explicit ``response_format`` is supplied, it is mapped to
            ``{"type": "json_object"}``. Prefer passing ``response_format``
            directly. Default is False.
        use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
            When True, creates an AsyncAzureOpenAI client. Default is False.
        azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
            When given, it is sent as the request's model identifier.
        api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
            when use_azure=True.
        image_inputs: Optional images attached to the user message as
            base64 ``image_url`` parts.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - response_format: Structured output control forwarded to the OpenAI
              chat completions API. This adapter accepts dict payloads such
              as ``{"type": "json_object"}`` and dict-form ``json_schema``,
              but rejects typed/Pydantic response_format values.
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
              These are passed to the client constructor but are overridden by
              explicit parameters (api_key, base_url). Supports proxy configuration,
              custom headers, retry policies, etc.
            - messages: Full message list that replaces the one built from
              system_prompt / history_messages / prompt.
            - hashing_kv, entity_extraction: Consumed here, never sent to the API.

    Returns:
        The completed text (with integrated COT content if available), or an async iterator
        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.

    Raises:
        TypeError: If ``response_format`` is not None and not a dict.
        InvalidResponseError: If the response from OpenAI is invalid or empty.
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
        InternalServerError: On HTTP 5xx responses.
        TransientBadRequestError: On "could not parse JSON body" HTTP 400s.
        BadRequestError: On any other HTTP 400.
        (The retry policy above re-attempts the retryable errors up to 3 times
        before re-raising.)
    """
    if history_messages is None:
        history_messages = []

    # Set openai logger level to INFO when VERBOSE_DEBUG is off
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)

    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Deprecation shims: map legacy boolean flags to response_format only when
    # an explicit response_format was not supplied by the caller. Prefer passing
    # response_format directly.
    entity_extraction = kwargs.pop("entity_extraction", False)
    if entity_extraction and kwargs.get("response_format") is None:
        warnings.warn(
            "openai_complete_if_cache(entity_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs["response_format"] = {"type": "json_object"}
    if keyword_extraction and kwargs.get("response_format") is None:
        warnings.warn(
            "openai_complete_if_cache(keyword_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs["response_format"] = {"type": "json_object"}

    _validate_openai_response_format(kwargs.get("response_format"))
    # COT prepends/inserts <think> tags into the output, which would break a
    # structured (JSON) payload, so it is switched off for structured output.
    if kwargs.get("response_format") is not None:
        enable_cot = False

    # Create the OpenAI client (supports both OpenAI and Azure)
    openai_async_client = create_openai_async_client(
        api_key=api_key,
        base_url=base_url,
        use_azure=use_azure,
        azure_deployment=azure_deployment,
        api_version=api_version,
        timeout=timeout,
        client_configs=client_configs,
    )

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    if image_inputs:
        from lightrag.llm._vision_utils import normalize_image_inputs

        normalized_images = normalize_image_inputs(image_inputs)
        user_content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
        for img in normalized_images:
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{img.mime_type};base64,{img.base64_str}"
                    },
                }
            )
        messages.append({"role": "user", "content": user_content})
    else:
        messages.append({"role": "user", "content": prompt})

    logger.debug("===== Entering func of LLM =====")
    logger.debug(f"Model: {model} Base URL: {base_url}")
    logger.debug(f"Client Configs: {client_configs}")
    logger.debug(f"Additional kwargs: {kwargs}")
    logger.debug(f"Num of history messages: {len(history_messages)}")
    verbose_debug(f"System prompt: {system_prompt}")
    verbose_debug(f"Query: {prompt}")
    logger.debug("===== Sending Query to LLM =====")

    messages = kwargs.pop("messages", messages)

    # Add explicit parameters back to kwargs so they're passed to OpenAI API
    if stream is not None:
        kwargs["stream"] = stream
    if timeout is not None:
        kwargs["timeout"] = timeout

    # Determine the correct model identifier to use
    # For Azure OpenAI, we must use the deployment name instead of the model name
    api_model = azure_deployment if use_azure and azure_deployment else model

    try:
        # Single dispatch: create() covers the dict-based response_format
        # payloads used by this project. Typed/Pydantic helpers are rejected
        # above. Length-truncation is detected via finish_reason below (logged
        # as a warning) and the raw content is returned unchanged so upstream
        # tolerant JSON parsing can still salvage it.
        response = await openai_async_client.chat.completions.create(
            model=api_model, messages=messages, **kwargs
        )
    except APITimeoutError as e:
        logger.error(f"OpenAI API Timeout Error: {e}")
        await _close_openai_client_quietly(openai_async_client)
        raise
    except APIConnectionError as e:
        logger.error(f"OpenAI API Connection Error: {e}")
        await _close_openai_client_quietly(openai_async_client)
        raise
    except RateLimitError as e:
        logger.error(f"OpenAI API Rate Limit Error: {e}")
        await _close_openai_client_quietly(openai_async_client)
        raise
    except BadRequestError as e:
        # A "could not parse JSON body" 400 is transient (corrupted/truncated
        # request body in transit) and succeeds on retry; re-raise it as a
        # retryable type. Genuine 400s (bad params, content policy) fail fast.
        # Either way the client is closed before re-raising so non-transient
        # 400s do not leak httpx connections.
        await _close_openai_client_quietly(openai_async_client)
        # Heuristic: match on the provider's error wording. It can drift across
        # providers/proxies or localization, and a genuinely malformed request
        # body (e.g. invalid user-supplied JSON) could also surface this text —
        # in that case we simply retry 3x and still fail. We accept that
        # "retry too much" trade-off to recover the common transient case.
        if "could not parse" in str(e).lower():
            logger.warning(f"Transient JSON-parse 400 from OpenAI, will retry: {e}")
            raise TransientBadRequestError(str(e)) from e
        raise
    except Exception as e:
        body = getattr(e, "body", None)
        request_id = getattr(e, "request_id", None)
        req = getattr(e, "request", None)
        extra_parts = []
        if body:
            extra_parts.append(f"Response body: {body}")
        if request_id:
            extra_parts.append(f"Request ID: {request_id}")
        if req is not None:
            extra_parts.append(f"Request URL: {req.url}")
        extra = ("\n" + "\n".join(extra_parts)) if extra_parts else ""
        logger.error(
            f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}{extra}"
        )
        await _close_openai_client_quietly(openai_async_client)
        raise

    if hasattr(response, "__aiter__"):

        async def inner():
            final_chunk_usage = None

            # COT (Chain of Thought) state tracking
            cot_active = False
            cot_started = False
            initial_content_seen = False

            try:
                async for chunk in response:
                    # Check if this chunk has usage information (final chunk)
                    if hasattr(chunk, "usage") and chunk.usage:
                        final_chunk_usage = chunk.usage
                        logger.debug(
                            f"Received usage info in streaming chunk: {chunk.usage}"
                        )

                    # Check if choices exists and is not empty
                    if not hasattr(chunk, "choices") or not chunk.choices:
                        # Azure OpenAI sends content filter results in first chunk without choices
                        logger.debug(
                            f"Received chunk without choices (likely Azure content filter): {chunk}"
                        )
                        continue

                    choice = chunk.choices[0]
                    if getattr(choice, "finish_reason", None) == "length":
                        logger.warning(
                            "OpenAI streaming response was truncated "
                            "(finish_reason='length')"
                        )

                    # Check if delta exists
                    if not hasattr(choice, "delta"):
                        # This might be the final chunk, continue to check for usage
                        continue

                    delta = choice.delta
                    content = getattr(delta, "content", None)
                    reasoning_content = getattr(delta, "reasoning_content", "")

                    if not enable_cot:
                        # COT disabled, only process regular content
                        if content:
                            if r"\u" in content:
                                content = safe_unicode_decode(content.encode("utf-8"))
                            yield content
                        continue

                    if content:
                        # Regular content present: from now on COT can never
                        # (re)start. If both fields arrive together in the very
                        # first chunk, COT simply never activates (rule 4).
                        initial_content_seen = True
                        # Close a running <think> block before the answer text.
                        # (Previously an extra reset here dropped this close tag
                        # when the first content chunk also carried reasoning.)
                        if cot_active:
                            yield "</think>"
                            cot_active = False
                        if r"\u" in content:
                            content = safe_unicode_decode(content.encode("utf-8"))
                        yield content
                    elif reasoning_content:
                        # Only reasoning content is present; start COT if no
                        # regular content has been seen yet.
                        if not initial_content_seen and not cot_started:
                            yield "<think>"
                            cot_active = True
                            cot_started = True
                        if cot_active:
                            if r"\u" in reasoning_content:
                                reasoning_content = safe_unicode_decode(
                                    reasoning_content.encode("utf-8")
                                )
                            yield reasoning_content

                # Ensure COT is properly closed if still active after stream ends
                if enable_cot and cot_active:
                    yield "</think>"
                    cot_active = False

                # After streaming is complete, track token usage
                if token_tracker and final_chunk_usage:
                    # Use actual usage from the API
                    token_counts = {
                        "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
                        "completion_tokens": getattr(
                            final_chunk_usage, "completion_tokens", 0
                        ),
                        "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
                    }
                    token_tracker.add_usage(token_counts)
                    logger.debug(f"Streaming token usage (from API): {token_counts}")
                elif token_tracker:
                    logger.debug("No usage information available in streaming response")
            except Exception as e:
                # Close an open <think> block so the partial output stays well-formed
                if enable_cot and cot_active:
                    try:
                        yield "</think>"
                        cot_active = False
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close COT tag during exception handling: {close_error}"
                        )
                logger.error(f"Error in stream response: {str(e)}")
                # Resource cleanup happens exactly once, in the finally block below
                raise
            finally:
                # Never yield here: finally also runs on aclose()/cancellation
                # (e.g. the consumer stops early), where a yield raises
                # "async generator ignored GeneratorExit" and skips the cleanup.
                # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
                if callable(getattr(response, "aclose", None)):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response")
                    except (AttributeError, TypeError) as close_error:
                        # Some wrapper objects may report hasattr(aclose) but fail when called
                        # This is expected behavior for certain client wrappers
                        logger.debug(
                            f"Stream response cleanup not supported by client wrapper: {close_error}"
                        )
                    except Exception as close_error:
                        logger.warning(
                            f"Unexpected error during stream response cleanup: {close_error}"
                        )

                # The caller never sees the client, so release it here to
                # prevent resource leaks
                try:
                    await openai_async_client.close()
                    logger.debug(
                        "Successfully closed OpenAI client for streaming response"
                    )
                except Exception as client_close_error:
                    logger.warning(
                        f"Failed to close OpenAI client in streaming finally block: {client_close_error}"
                    )

        return inner()

    try:
        if (
            not response
            or not response.choices
            or not hasattr(response.choices[0], "message")
        ):
            logger.error("Invalid response from OpenAI API")
            raise InvalidResponseError("Invalid response from OpenAI API")

        choice = response.choices[0]
        if getattr(choice, "finish_reason", None) == "length":
            logger.warning(
                "OpenAI response was truncated (finish_reason='length'); "
                "returning partial content unchanged"
            )
        message = choice.message

        # Handle parsed responses (structured output via response_format)
        if hasattr(message, "parsed") and message.parsed is not None:
            # Serialize the parsed structured response to JSON
            final_content = message.parsed.model_dump_json()
            logger.debug("Using parsed structured response from API")
        else:
            content = getattr(message, "content", None)
            reasoning_content = getattr(message, "reasoning_content", "")
            final_content = content or ""

            # Reasoning is only surfaced when COT is on, reasoning is
            # non-blank, and the regular content is empty/blank (rules 1 and 3).
            include_reasoning = (
                enable_cot
                and bool(reasoning_content and reasoning_content.strip())
                and not final_content.strip()
            )
            if include_reasoning:
                if r"\u" in reasoning_content:
                    reasoning_content = safe_unicode_decode(
                        reasoning_content.encode("utf-8")
                    )
                final_content = f"<think>{reasoning_content}</think>{final_content}"

        # Validate final content
        if not final_content.strip():
            logger.error("Received empty content from OpenAI API")
            raise InvalidResponseError("Received empty content from OpenAI API")

        # Apply Unicode decoding to final content if needed
        if r"\u" in final_content:
            final_content = safe_unicode_decode(final_content.encode("utf-8"))

        if token_tracker and hasattr(response, "usage"):
            token_counts = {
                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                "completion_tokens": getattr(response.usage, "completion_tokens", 0),
                "total_tokens": getattr(response.usage, "total_tokens", 0),
            }
            token_tracker.add_usage(token_counts)

        logger.debug(f"Response content len: {len(final_content)}")
        verbose_debug(f"Response: {response}")

        return final_content
    finally:
        # Ensure client is closed exactly once in all cases for non-streaming responses
        await _close_openai_client_quietly(
            openai_async_client, " in non-streaming finally block"
        )
async def openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Complete ``prompt`` with the model configured on the LightRAG instance.

    The model name is read from
    ``kwargs["hashing_kv"].global_config["llm_model_name"]``. Everything else,
    including ``hashing_kv`` itself, is forwarded to ``openai_complete_if_cache``.
    ``entity_extraction`` is a named parameter, so it can never also arrive
    inside ``**kwargs``.
    """
    configured_model = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await openai_complete_if_cache(
        configured_model,
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
async def gpt_4o_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> str:
    """Complete ``prompt`` with the fixed ``gpt-4o`` model.

    All arguments are forwarded to ``openai_complete_if_cache``; a missing
    ``history_messages`` becomes an empty list.
    """
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        enable_cot=enable_cot,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
async def gpt_4o_mini_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Complete ``prompt`` with the fixed ``"gpt-4o-mini"`` model.

    Thin wrapper over ``openai_complete_if_cache``. Extra keyword arguments
    (for example a streaming flag or a token tracker) are forwarded unchanged.

    Args:
        prompt: User prompt to complete.
        system_prompt: Optional system prompt.
        history_messages: Prior chat messages; ``None`` becomes a fresh empty list.
        enable_cot: Forwarded unchanged.
        keyword_extraction: Forwarded unchanged.
        entity_extraction: Forwarded unchanged. Because it is a named parameter,
            it can never also appear in ``kwargs``.
        **kwargs: Forwarded unchanged.

    Returns:
        The completion text, or an async iterator of text chunks when the
        forwarded arguments request streaming. The annotation matches
        ``openai_complete``, which wraps the same call.
    """
    if history_messages is None:
        history_messages = []
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
async def nvidia_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Complete ``prompt`` with NVIDIA's hosted Llama-3.1 Nemotron 70B model.

    Thin wrapper over ``openai_complete_if_cache`` that targets the
    OpenAI-compatible endpoint ``https://integrate.api.nvidia.com/v1``.

    Args:
        prompt: User prompt to complete.
        system_prompt: Optional system prompt.
        history_messages: Prior chat messages; ``None`` becomes a fresh empty list.
        enable_cot: Forwarded unchanged.
        keyword_extraction: Forwarded unchanged.
        entity_extraction: Forwarded unchanged. Because it is a named parameter,
            it can never also appear in ``kwargs``.
        **kwargs: Forwarded unchanged, except ``base_url``. A truthy
            ``base_url`` replaces the default NVIDIA endpoint (e.g. a
            self-hosted deployment); a falsy one falls back to the default.

    Returns:
        The completion text, or an async iterator of text chunks when the
        forwarded arguments request streaming.
    """
    if history_messages is None:
        history_messages = []
    # Take base_url out of kwargs. Passing it alongside an explicit base_url=
    # would raise TypeError (duplicate keyword argument).
    base_url = kwargs.pop("base_url", None) or "https://integrate.api.nvidia.com/v1"
    return await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        base_url=base_url,
        **kwargs,
    )
@wrap_embedding_func_with_attrs(
    embedding_dim=1536,
    max_token_size=8192,
    model_name="text-embedding-3-small",
    supports_asymmetric=True,
)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
        # Retry transient HTTP 5xx (OpenAI 500 / proxy upstream errors).
        | retry_if_exception_type(InternalServerError)
    ),
)
async def openai_embed(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str | None = None,
    api_key: str | None = None,
    embedding_dim: int | None = None,
    max_token_size: int | None = None,
    client_configs: dict[str, Any] | None = None,
    token_tracker: Any | None = None,
    use_azure: bool = False,
    azure_deployment: str | None = None,
    api_version: str | None = None,
    context: str = "document",
    query_prefix: str | None = None,
    document_prefix: str | None = None,
) -> np.ndarray:
    """Embed ``texts`` via an OpenAI-compatible or Azure OpenAI embeddings endpoint.

    Processing steps:

    1. Prefixing. With ``context == "query"`` a non-empty ``query_prefix`` is
       prepended to every text. With ``context == "document"`` a non-empty
       ``document_prefix`` is prepended instead. Any other ``context`` value
       adds no prefix.
    2. Truncation. When ``max_token_size`` is a positive integer, each
       non-empty text is tokenized with the tiktoken encoding chosen for
       ``model`` and cut to at most ``max_token_size`` tokens. The prefix counts
       toward the limit. ``None`` or ``0`` disables truncation.
    3. Request. A single ``embeddings.create`` call is sent for the whole
       batch, and each returned vector is decoded to ``float32``.

    Args:
        texts: Texts to embed.
        model: Embedding model name (e.g. ``"text-embedding-3-small"``). With
            Azure and no ``azure_deployment``, it is used as the deployment.
        base_url: Optional API base URL. Use the standard OpenAI endpoint, or
            an Azure resource endpoint such as
            ``https://your-resource.openai.azure.com/``.
        api_key: Optional API key. ``None`` leaves resolution to
            ``create_openai_async_client``, which is documented to fall back to
            ``OPENAI_API_KEY`` (or ``AZURE_EMBEDDING_API_KEY`` for Azure).
        embedding_dim: Output dimension. When not ``None`` it is sent as the
            API ``dimensions`` parameter. The EmbeddingFunc wrapper injects it
            from the decorator's ``embedding_dim``; do not pass it by hand.
        max_token_size: Per-text token limit for step 2. The EmbeddingFunc
            wrapper injects it when this signature accepts the parameter.
        client_configs: Extra AsyncOpenAI / AsyncAzureOpenAI client options
            (proxy, headers, retry policy, ...). Per the client factory's
            contract, explicit ``api_key`` / ``base_url`` take precedence.
        token_tracker: Optional tracker. Its ``add_usage`` receives the
            ``prompt_tokens`` and ``total_tokens`` counts reported by the API.
        use_azure: Build an Azure OpenAI client instead of a standard one.
        azure_deployment: Azure deployment name. Together with ``use_azure``
            it replaces ``model`` in the request. Otherwise the factory is
            documented to fall back to ``AZURE_EMBEDDING_DEPLOYMENT``.
        api_version: Azure OpenAI API version, used only with ``use_azure``.
            The factory is documented to fall back to
            ``AZURE_EMBEDDING_API_VERSION``.
        context: ``"query"`` or ``"document"``; selects the prefix for step 1.
            The wrapper injects it because the decorator sets
            ``supports_asymmetric=True``.
        query_prefix: Prefix for query texts, e.g. ``"search_query: "``.
        document_prefix: Prefix for document texts, e.g. ``"search_document: "``.

    Returns:
        ``np.array`` stacking one float32 vector per item of ``response.data``,
        in response order.

    Raises:
        RateLimitError, APIConnectionError, APITimeoutError, InternalServerError:
            These are retried by ``@retry``: up to 3 attempts, with exponential
            backoff between 4 and 60 seconds. ``reraise`` is not enabled, so
            once attempts are exhausted tenacity raises ``tenacity.RetryError``
            wrapping the last error.
        Other API errors propagate immediately without retry.
    """
    # Step 1: asymmetric-embedding prefix (query vs. document).
    if context == "query":
        prefix = query_prefix
    elif context == "document":
        prefix = document_prefix
    else:
        prefix = None
    if prefix:
        texts = [prefix + text for text in texts]

    # Step 2: per-text truncation. Empty texts pass through untouched.
    if (max_token_size or 0) > 0:
        encoding = _get_tiktoken_encoding_for_model(model)
        fitted: list[str] = []
        cut = 0
        for text in texts:
            token_ids = encoding.encode(text) if text else None
            if token_ids is not None and len(token_ids) > max_token_size:
                fitted.append(encoding.decode(token_ids[:max_token_size]))
                cut += 1
                logger.debug(
                    f"Text truncated from {len(token_ids)} to {max_token_size} tokens"
                )
            else:
                fitted.append(text)
        if cut:
            logger.info(
                f"Truncated {cut}/{len(texts)} texts to fit token limit ({max_token_size})"
            )
        texts = fitted

    # Step 3: one batched request; the client is closed when the block exits.
    client = create_openai_async_client(
        api_key=api_key,
        base_url=base_url,
        use_azure=use_azure,
        azure_deployment=azure_deployment,
        api_version=api_version,
        client_configs=client_configs,
    )
    async with client:
        request = {
            # Azure routes by deployment name, so it replaces the model name when given.
            "model": azure_deployment if use_azure and azure_deployment else model,
            "input": texts,
            # The OpenAI client defaults to base64, and some providers (e.g. Yandex)
            # reject it, so the format is always stated explicitly.
            "encoding_format": "base64" if EMBEDDING_USE_BASE64 else "float",
        }
        if embedding_dim is not None:
            request["dimensions"] = embedding_dim
        response = await client.embeddings.create(**request)

        if token_tracker and hasattr(response, "usage"):
            usage = response.usage
            token_tracker.add_usage(
                {
                    "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                    "total_tokens": getattr(usage, "total_tokens", 0),
                }
            )

        def _to_vector(raw):
            # "float" responses arrive as lists; base64 payloads are packed float32 bytes.
            if isinstance(raw, list):
                return np.array(raw, dtype=np.float32)
            return np.frombuffer(base64.b64decode(raw), dtype=np.float32)

        return np.array([_to_vector(item.embedding) for item in response.data])
- # Azure OpenAI wrapper functions for backward compatibility
- async def azure_openai_complete_if_cache(
- model,
- prompt,
- system_prompt: str | None = None,
- history_messages: list[dict[str, Any]] | None = None,
- enable_cot: bool = False,
- base_url: str | None = None,
- api_key: str | None = None,
- token_tracker: Any | None = None,
- stream: bool | None = None,
- timeout: int | None = None,
- api_version: str | None = None,
- keyword_extraction: bool = False,
- **kwargs,
- ):
- """Azure OpenAI completion wrapper function.
- This function provides backward compatibility by wrapping the unified
- openai_complete_if_cache implementation with Azure-specific parameter handling.
- All parameters from the underlying openai_complete_if_cache are exposed to ensure
- full feature parity and API consistency.
- """
- # Handle Azure-specific environment variables and parameters
- deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
- base_url = (
- base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
- )
- api_key = (
- api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
- )
- api_version = (
- api_version
- or os.getenv("AZURE_OPENAI_API_VERSION")
- or os.getenv("OPENAI_API_VERSION")
- or "2024-08-01-preview"
- )
- # Call the unified implementation with Azure-specific parameters
- return await openai_complete_if_cache(
- model=deployment,
- prompt=prompt,
- system_prompt=system_prompt,
- history_messages=history_messages,
- enable_cot=enable_cot,
- base_url=base_url,
- api_key=api_key,
- token_tracker=token_tracker,
- stream=stream,
- timeout=timeout,
- use_azure=True,
- azure_deployment=deployment,
- api_version=api_version,
- keyword_extraction=keyword_extraction,
- **kwargs,
- )
async def azure_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Azure OpenAI completion entry point (backward-compatible wrapper).

    Passes ``os.getenv("LLM_MODEL", "gpt-4o-mini")`` as the model to
    ``azure_openai_complete_if_cache``. That function gives the
    ``AZURE_OPENAI_DEPLOYMENT`` environment variable precedence over the model
    argument when picking the deployment.

    Args:
        prompt: User prompt to complete.
        system_prompt: Optional system prompt.
        history_messages: Prior chat messages; ``None`` becomes a fresh empty list.
        keyword_extraction: Forwarded unchanged.
        entity_extraction: Forwarded unchanged (through ``**kwargs`` of the
            Azure wrapper). Because it is a named parameter here, it can never
            also appear in ``kwargs``.
        **kwargs: Forwarded unchanged.

    Returns:
        The completion text, or an async iterator of text chunks when the
        forwarded arguments request streaming.
    """
    if history_messages is None:
        history_messages = []
    return await azure_openai_complete_if_cache(
        os.getenv("LLM_MODEL", "gpt-4o-mini"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
@wrap_embedding_func_with_attrs(
    embedding_dim=1536,
    max_token_size=8192,
    model_name="my-text-embedding-3-large-deployment",
    supports_asymmetric=True,
)
async def azure_openai_embed(
    texts: list[str],
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    embedding_dim: int | None = None,
    token_tracker: Any | None = None,
    client_configs: dict[str, Any] | None = None,
    api_version: str | None = None,
    context: str = "document",
    query_prefix: str | None = None,
    document_prefix: str | None = None,
    max_token_size: int | None = None,
) -> np.ndarray:
    """Azure OpenAI embedding wrapper around the unified ``openai_embed``.

    Settings come from the first truthy candidate, in order:

    * deployment: ``AZURE_EMBEDDING_DEPLOYMENT`` env var, then ``model``, then
      ``EMBEDDING_MODEL`` (default ``"text-embedding-3-small"``). The
      environment variable wins over the ``model`` argument.
    * endpoint: ``base_url``, ``AZURE_EMBEDDING_ENDPOINT``,
      ``EMBEDDING_BINDING_HOST``.
    * API key: ``api_key``, ``AZURE_EMBEDDING_API_KEY``,
      ``EMBEDDING_BINDING_API_KEY``.
    * API version: ``api_version``, ``AZURE_EMBEDDING_API_VERSION``,
      ``AZURE_OPENAI_API_VERSION``, ``OPENAI_API_VERSION``, then
      ``"2024-08-01-preview"``.

    The wrapper also accepts and forwards ``max_token_size``. The EmbeddingFunc
    wrapper only injects that value when the wrapped signature declares the
    parameter. Without it, the decorator's ``max_token_size=8192`` never
    reached ``openai_embed``, so over-long texts were sent untruncated. With
    it, truncation matches ``openai_embed``. The parameter comes last, so
    existing positional callers are unaffected. Passing ``None`` or ``0``
    disables truncation.

    IMPORTANT - Decorator Usage:
    1. This function is decorated with @wrap_embedding_func_with_attrs to provide
       the EmbeddingFunc interface for users who need to access embedding_dim
       and other attributes.
    2. This function does NOT use a @retry decorator, to avoid double-wrapping:
       openai_embed.func is the retry-wrapped implementation.
    3. This function calls openai_embed.func (the unwrapped function) instead of
       openai_embed (the EmbeddingFunc instance) to avoid double decoration issues:
       ✅ Correct: await openai_embed.func(...)  # Calls unwrapped function with retry
       ❌ Wrong:   await openai_embed(...)       # Would cause double EmbeddingFunc wrapping
       Double decoration causes:
       - Double injection of embedding_dim parameter
       - Incorrect parameter passing to the underlying implementation
       - Runtime errors due to parameter conflicts

    The call chain with correct implementation:
    azure_openai_embed(texts)
      → EmbeddingFunc.__call__(texts)  # azure's decorator
        → azure_openai_embed_impl(texts, embedding_dim=1536, max_token_size=8192)
          → openai_embed.func(texts, ...)
            → @retry_wrapper(texts, ...)  # openai's retry (only one layer)
              → openai_embed_impl(texts, ...)
                → actual embedding computation

    Returns:
        The array produced by ``openai_embed``: one float32 vector per input text.
    """
    # Handle Azure-specific environment variables and parameters
    deployment = (
        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
        or model
        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    )
    base_url = (
        base_url
        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
        or os.getenv("EMBEDDING_BINDING_HOST")
    )
    api_key = (
        api_key
        or os.getenv("AZURE_EMBEDDING_API_KEY")
        or os.getenv("EMBEDDING_BINDING_API_KEY")
    )
    api_version = (
        api_version
        or os.getenv("AZURE_EMBEDDING_API_VERSION")
        or os.getenv("AZURE_OPENAI_API_VERSION")
        or os.getenv("OPENAI_API_VERSION")
        or "2024-08-01-preview"
    )
    # CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration.
    # openai_embed is an EmbeddingFunc instance; .func is the underlying function.
    return await openai_embed.func(
        texts=texts,
        model=deployment,
        base_url=base_url,
        api_key=api_key,
        embedding_dim=embedding_dim,
        max_token_size=max_token_size,
        token_tracker=token_tracker,
        client_configs=client_configs,
        use_azure=True,
        azure_deployment=deployment,
        api_version=api_version,
        context=context,
        query_prefix=query_prefix,
        document_prefix=document_prefix,
    )
|