| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751 |
- """
- Gemini LLM binding for LightRAG.
- This module provides asynchronous helpers that adapt Google's Gemini models
- to the same interface used by the rest of the LightRAG LLM bindings. The
- implementation mirrors the OpenAI helpers while relying on the official
- ``google-genai`` client under the hood.
- """
- from __future__ import annotations
- import os
- import warnings
- from collections.abc import AsyncIterator
- from functools import lru_cache
- from typing import Any
- import numpy as np
- from tenacity import (
- retry,
- stop_after_attempt,
- wait_exponential,
- retry_if_exception_type,
- )
- from lightrag.utils import (
- logger,
- remove_think_tags,
- safe_unicode_decode,
- wrap_embedding_func_with_attrs,
- )
- import pipmaster as pm
- # Install the Google Gemini client and its dependencies on demand
- if not pm.is_installed("google-genai"):
- pm.install("google-genai")
- if not pm.is_installed("google-api-core"):
- pm.install("google-api-core")
- from google import genai # type: ignore
- from google.genai import types # type: ignore
- from google.api_core import exceptions as google_api_exceptions # type: ignore
class InvalidResponseError(Exception):
    """Signal that Gemini answered without any usable text.

    ``gemini_complete_if_cache`` raises this for an empty non-streaming
    response, and its retry policy lists this type so the call is retried.
    """
- _DEFAULT_GEMINI_BASE_URLS = {
- "https://generativelanguage.googleapis.com",
- "https://generativelanguage.googleapis.com/",
- "https://generativelanguage.googleapis.com/v1beta",
- "https://generativelanguage.googleapis.com/v1beta/",
- "https://generativelanguage.googleapis.com/v1",
- "https://generativelanguage.googleapis.com/v1/",
- }
- def _normalize_gemini_base_url(base_url: str | None) -> str | None:
- """Treat Google's default Gemini API service roots as SDK defaults."""
- if not base_url:
- return None
- normalized = base_url.strip()
- if not normalized or normalized == "DEFAULT_GEMINI_ENDPOINT":
- return None
- if normalized.rstrip("/") in {
- service_root.rstrip("/") for service_root in _DEFAULT_GEMINI_BASE_URLS
- }:
- return None
- return normalized
@lru_cache(maxsize=8)
def _get_gemini_client(
    api_key: str, base_url: str | None, timeout: int | None = None
) -> genai.Client:
    """Build a Gemini client, reusing a cached one for identical arguments.

    The mode is chosen from the ``GOOGLE_GENAI_USE_VERTEXAI`` environment
    variable. When it equals ``"true"`` (case-insensitive) the client is
    configured for Vertex AI from ``GOOGLE_CLOUD_PROJECT`` and
    ``GOOGLE_CLOUD_LOCATION`` (default ``"us-central1"``) and *api_key* is
    not passed on. Otherwise the client is created with *api_key*.

    Note that results are cached on ``(api_key, base_url, timeout)`` only, so
    the environment variables are read when an entry is first created.

    Args:
        api_key: Google Gemini API key (not used in Vertex AI mode).
        base_url: Optional custom API endpoint; default Google service roots
            are ignored via ``_normalize_gemini_base_url``.
        timeout: Optional request timeout in milliseconds.

    Returns:
        genai.Client: Configured Gemini client instance.

    Raises:
        ValueError: In Vertex AI mode when ``GOOGLE_CLOUD_PROJECT`` is unset
            or empty.
    """
    effective_base_url = _normalize_gemini_base_url(base_url)
    vertex_mode = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

    client_args: dict[str, Any]
    if vertex_mode:
        # Vertex AI authenticates via project/location instead of an API key.
        project = os.getenv("GOOGLE_CLOUD_PROJECT")
        if not project:
            raise ValueError(
                "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI mode"
            )
        client_args = {"vertexai": True, "project": project}
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if location:
            client_args["location"] = location
    else:
        # Standard Gemini API mode.
        client_args = {"api_key": api_key}

    # Only attach http_options when there is something to override.
    http_settings: dict[str, Any] = {}
    if effective_base_url is not None:
        http_settings["base_url"] = effective_base_url
    if timeout is not None:
        http_settings["timeout"] = timeout

    if http_settings:
        try:
            client_args["http_options"] = types.HttpOptions(**http_settings)
        except Exception as exc:
            logger.error("Failed to apply custom Gemini http_options: %s", exc)
            raise

    return genai.Client(**client_args)
- def _ensure_api_key(api_key: str | None) -> str:
- # In Vertex AI mode, API key is not required
- use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
- if use_vertexai:
- # Return empty string for Vertex AI mode (not used)
- return ""
- key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
- if not key:
- raise ValueError(
- "Gemini API key not provided. "
- "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
- )
- return key
def _build_generation_config(
    base_config: dict[str, Any] | None,
    system_prompt: str | None,
    response_format: Any | None,
) -> types.GenerateContentConfig | None:
    """Assemble the ``GenerateContentConfig`` for one request.

    Args:
        base_config: Caller-supplied generation settings; copied, never
            mutated.
        system_prompt: Appended (after a newline) to any existing
            ``system_instruction`` in *base_config*, or used as the
            instruction when none is present.
        response_format: When not ``None``, ``response_mime_type`` defaults
            to ``"application/json"`` and, unless the caller already set
            ``response_json_schema``, the schema extracted by
            ``_normalize_gemini_response_schema`` is added.

    Returns:
        The config object, or ``None`` when no setting remains after entries
        equal to ``None`` or ``""`` have been dropped.
    """
    settings = dict(base_config or {})

    if system_prompt:
        existing_instruction = settings.get("system_instruction")
        settings["system_instruction"] = (
            f"{existing_instruction}\n{system_prompt}"
            if existing_instruction
            else system_prompt
        )

    # Map the OpenAI-style response_format onto Gemini's native fields
    # without overriding anything the caller configured explicitly.
    if response_format is not None:
        settings.setdefault("response_mime_type", "application/json")
        schema = _normalize_gemini_response_schema(response_format)
        if schema is not None and "response_json_schema" not in settings:
            settings["response_json_schema"] = schema

    # Unset (None) and empty-string entries are left out of the config.
    cleaned = {
        name: setting
        for name, setting in settings.items()
        if setting is not None and setting != ""
    }
    return types.GenerateContentConfig(**cleaned) if cleaned else None
- def _normalize_gemini_response_schema(response_format: Any) -> Any | None:
- """Extract a Gemini-compatible JSON schema from LightRAG/OpenAI inputs."""
- if response_format is None:
- return None
- if isinstance(response_format, dict):
- if response_format.get("type") == "json_object":
- return None
- if response_format.get("type") == "json_schema":
- json_schema = response_format.get("json_schema")
- if isinstance(json_schema, dict):
- schema = json_schema.get("schema")
- if isinstance(schema, dict):
- return schema
- return json_schema
- return response_format
- return response_format
- def _validate_gemini_response_format(response_format: Any | None) -> None:
- """Reject typed structured-output helpers; only dict payloads are supported."""
- if response_format is None or isinstance(response_format, dict):
- return
- raise TypeError(
- "gemini_complete_if_cache only supports dict response_format payloads; "
- "typed/Pydantic response_format values are not supported."
- )
- def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
- if not history_messages:
- return ""
- history_lines: list[str] = []
- for message in history_messages:
- role = message.get("role", "user")
- content = message.get("content", "")
- history_lines.append(f"[{role}] {content}")
- return "\n".join(history_lines)
- def _extract_response_text(
- response: Any, extract_thoughts: bool = False
- ) -> tuple[str, str]:
- """
- Extract text content from Gemini response, separating regular content from thoughts.
- Args:
- response: Gemini API response object
- extract_thoughts: Whether to extract thought content separately
- Returns:
- Tuple of (regular_text, thought_text)
- """
- candidates = getattr(response, "candidates", None)
- if not candidates:
- return ("", "")
- regular_parts: list[str] = []
- thought_parts: list[str] = []
- for candidate in candidates:
- if not getattr(candidate, "content", None):
- continue
- # Use 'or []' to handle None values from parts attribute
- for part in getattr(candidate.content, "parts", None) or []:
- text = getattr(part, "text", None)
- if not text:
- continue
- # Check if this part is thought content using the 'thought' attribute
- is_thought = getattr(part, "thought", False)
- if is_thought and extract_thoughts:
- thought_parts.append(text)
- elif not is_thought:
- regular_parts.append(text)
- return ("\n".join(regular_parts), "\n".join(thought_parts))
def _gemini_usage_to_token_counts(usage: Any) -> dict[str, int]:
    """Convert Gemini ``usage_metadata`` into the dict given to ``add_usage``.

    A counter that is missing or set to ``None`` is reported as ``0`` so the
    token tracker always receives integers.
    """
    return {
        "prompt_tokens": getattr(usage, "prompt_token_count", 0) or 0,
        "completion_tokens": getattr(usage, "candidates_token_count", 0) or 0,
        "total_tokens": getattr(usage, "total_token_count", 0) or 0,
    }


def _decode_gemini_text(text: str) -> str:
    """Run ``safe_unicode_decode`` on *text* only if it contains a literal ``\\u``."""
    if "\\u" in text:
        return safe_unicode_decode(text.encode("utf-8"))
    return text


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(google_api_exceptions.InternalServerError)
        | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
        | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
        | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
        | retry_if_exception_type(google_api_exceptions.BadGateway)
        | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
        | retry_if_exception_type(google_api_exceptions.Aborted)
        | retry_if_exception_type(google_api_exceptions.Unknown)
        | retry_if_exception_type(InvalidResponseError)
    ),
)
async def gemini_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    stream: bool | None = None,
    response_format: Any | None = None,
    keyword_extraction: bool = False,
    entity_extraction: bool = False,
    generation_config: dict[str, Any] | None = None,
    timeout: int | None = None,
    image_inputs: list[Any] | None = None,
    **_: Any,
) -> str | AsyncIterator[str]:
    """Complete a prompt with a Gemini model, with optional streaming and COT.

    The conversation is flattened into a single text prompt: each history
    message becomes a ``[role] content`` line and the current prompt is
    appended as ``[user] <prompt>``. ``system_prompt`` is sent through the
    generation config as ``system_instruction``.

    Structured output:
        - OpenAI-style ``response_format`` is translated to Gemini's native
          generation config fields.
        - ``response_format={"type": "json_object"}`` maps to
          ``response_mime_type="application/json"``.
        - Dict-form ``json_schema`` payloads map to
          ``response_mime_type="application/json"`` plus
          ``response_json_schema=<schema>``.
        - Typed/Pydantic ``response_format`` values are rejected.
        - Whenever a ``response_format`` is in effect, ``enable_cot`` is
          forced to ``False``.

    COT integration:
        - Thought content is identified by the ``thought`` attribute on
          response parts; the API only returns it when ``thinking_config``
          is enabled in ``generation_config``.
        - ``enable_cot=True``: thought content is wrapped in
          ``<think>...</think>``. While streaming, thoughts are only emitted
          before the first regular content arrives.
        - ``enable_cot=False``: thought content is dropped.

    Args:
        model: The Gemini model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous conversation messages.
        enable_cot: Whether to include Chain of Thought content.
        base_url: Optional custom API endpoint.
        api_key: Optional Gemini API key. If None, environment variables are
            consulted (see ``_ensure_api_key``).
        token_tracker: Optional object whose ``add_usage(dict)`` receives
            ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``.
        stream: When truthy, return an async iterator of text chunks.
        response_format: ``None`` or a dict, see "Structured output".
        keyword_extraction: Deprecated. When no ``response_format`` is given
            it is mapped to ``{"type": "json_object"}`` with a
            ``DeprecationWarning``.
        entity_extraction: Deprecated, handled like ``keyword_extraction``
            and checked first.
        generation_config: Optional generation configuration dict.
        timeout: Request timeout in seconds (converted to milliseconds for
            the client). ``None`` or ``0`` means no timeout is configured.
        image_inputs: Optional images; they are normalized with
            ``normalize_image_inputs`` and sent as byte parts after the text.
        **_: Additional keyword arguments (e.g. ``hashing_kv`` passed for
            interface parity with other bindings) are ignored.

    Returns:
        The completed text, or an async iterator of text chunks when
        streaming. In streaming mode the request is issued when the iterator
        is consumed, so API errors surface during iteration.

    Raises:
        InvalidResponseError: When a non-streaming response contains no text
            (this type is part of the retry policy above).
        TypeError: If ``response_format`` is neither ``None`` nor a dict.
        ValueError: If no API key is provided or configured.
    """
    key = _ensure_api_key(api_key)
    # The client expects milliseconds.
    timeout_ms = timeout * 1000 if timeout else None
    client = _get_gemini_client(key, base_url, timeout_ms)

    # Deprecation shims: legacy boolean flags only apply when the caller did
    # not pass an explicit response_format.
    if response_format is None:
        if entity_extraction:
            warnings.warn(
                "gemini_complete_if_cache(entity_extraction=True) is deprecated; "
                "pass response_format={'type': 'json_object'} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            response_format = {"type": "json_object"}
        elif keyword_extraction:
            warnings.warn(
                "gemini_complete_if_cache(keyword_extraction=True) is deprecated; "
                "pass response_format={'type': 'json_object'} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            response_format = {"type": "json_object"}

    _validate_gemini_response_format(response_format)

    # Structured output and <think> blocks do not mix.
    if response_format is not None:
        enable_cot = False

    history_block = _format_history_messages(history_messages)
    prompt_sections = []
    if history_block:
        prompt_sections.append(history_block)
    prompt_sections.append(f"[user] {prompt}")
    combined_prompt = "\n".join(prompt_sections)

    config_obj = _build_generation_config(
        generation_config,
        system_prompt=system_prompt,
        response_format=response_format,
    )

    if image_inputs:
        from lightrag.llm._vision_utils import normalize_image_inputs

        normalized_images = normalize_image_inputs(image_inputs)
        parts: list[Any] = [combined_prompt]
        parts.extend(
            types.Part.from_bytes(data=img.raw_bytes, mime_type=img.mime_type)
            for img in normalized_images
        )
        contents: list[Any] = [parts]
    else:
        contents = [combined_prompt]

    request_kwargs: dict[str, Any] = {
        "model": model,
        "contents": contents,
    }
    if config_obj is not None:
        request_kwargs["config"] = config_obj

    if stream:

        async def _async_stream() -> AsyncIterator[str]:
            # COT state tracking for streaming
            cot_active = False
            cot_started = False
            initial_content_seen = False
            usage_metadata = None

            try:
                # generate_content_stream returns an awaitable that resolves
                # to the async iterator, so it has to be awaited first.
                stream_iter = await client.aio.models.generate_content_stream(
                    **request_kwargs
                )

                async for chunk in stream_iter:
                    # Keep the most recent usage report for the tracker.
                    usage = getattr(chunk, "usage_metadata", None)
                    if usage is not None:
                        usage_metadata = usage

                    regular_text, thought_text = _extract_response_text(
                        chunk, extract_thoughts=True
                    )

                    if not enable_cot:
                        # COT disabled - only regular content is forwarded.
                        if regular_text:
                            yield _decode_gemini_text(regular_text)
                        continue

                    if regular_text:
                        initial_content_seen = True

                        # The answer has started: close an open COT section.
                        if cot_active:
                            yield "</think>"
                            cot_active = False

                        yield _decode_gemini_text(regular_text)

                    if thought_text:
                        # A COT section is opened at most once, and only
                        # before any regular content has been seen.
                        if not initial_content_seen and not cot_started:
                            yield "<think>"
                            cot_active = True
                            cot_started = True

                        # Thoughts outside an open COT section are dropped.
                        if cot_active:
                            yield _decode_gemini_text(thought_text)

                # Ensure COT is properly closed if still active
                if cot_active:
                    yield "</think>"
                    cot_active = False

            except Exception:
                # Best effort: close the COT tag before re-raising.
                if cot_active:
                    try:
                        yield "</think>"
                    except Exception:
                        pass
                raise

            finally:
                # Track token usage once streaming ends (normally or not).
                if token_tracker and usage_metadata:
                    token_tracker.add_usage(
                        _gemini_usage_to_token_counts(usage_metadata)
                    )

        return _async_stream()

    # Non-streaming: use native async client
    response = await client.aio.models.generate_content(**request_kwargs)

    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)

    if enable_cot and thought_text and thought_text.strip():
        if regular_text and regular_text.strip():
            # Both content types present: prepend thought to regular content
            final_text = f"<think>{thought_text}</think>{regular_text}"
        else:
            # Only thought content available
            final_text = f"<think>{thought_text}</think>"
    else:
        # COT disabled or no thoughts returned: regular content only
        final_text = regular_text or ""

    if not final_text:
        raise InvalidResponseError("Gemini response did not contain any text content.")

    final_text = _decode_gemini_text(final_text)

    # NOTE(review): remove_think_tags runs unconditionally, i.e. also after
    # the <think> block was assembled for enable_cot=True. If that helper
    # strips <think>...</think> sections, the COT content added above is
    # discarded again - confirm this is intended.
    final_text = remove_think_tags(final_text)

    usage = getattr(response, "usage_metadata", None)
    if token_tracker and usage:
        token_tracker.add_usage(_gemini_usage_to_token_counts(usage))

    logger.debug("Gemini response length: %s", len(final_text))
    return final_text
async def gemini_model_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    response_format: Any | None = None,
    keyword_extraction: bool = False,
    entity_extraction: bool = False,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    """Resolve the model name and delegate to ``gemini_complete_if_cache``.

    The model name is taken from ``hashing_kv.global_config["llm_model_name"]``
    when a ``hashing_kv`` keyword argument is supplied and provides one;
    otherwise from a ``model_name`` keyword argument (which is then removed
    from ``kwargs``).

    Args:
        prompt: The prompt to complete.
        system_prompt: Optional system prompt.
        history_messages: Optional previous conversation messages.
        response_format: Structured-output control, forwarded unchanged.
        keyword_extraction: Deprecated flag, forwarded unchanged.
        entity_extraction: Deprecated flag, forwarded unchanged.
        **kwargs: Forwarded to ``gemini_complete_if_cache`` (including
            ``hashing_kv``, which that function ignores).

    Returns:
        Whatever ``gemini_complete_if_cache`` returns: the completed text or
        an async iterator of chunks when streaming.

    Raises:
        ValueError: If no model name can be determined.
    """
    # `entity_extraction` is a named parameter, so it can never show up in
    # **kwargs; no extra lookup there is needed.
    hashing_kv = kwargs.get("hashing_kv")

    model_name = None
    if hashing_kv is not None:
        model_name = hashing_kv.global_config.get("llm_model_name")
    if model_name is None:
        model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("Gemini model name not provided in configuration.")

    return await gemini_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        response_format=response_format,
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
@wrap_embedding_func_with_attrs(
    embedding_dim=1536,
    max_token_size=2048,
    model_name="gemini-embedding-001",
    supports_asymmetric=True,
)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(google_api_exceptions.InternalServerError)
        | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
        | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
        | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
        | retry_if_exception_type(google_api_exceptions.BadGateway)
        | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
        | retry_if_exception_type(google_api_exceptions.Aborted)
        | retry_if_exception_type(google_api_exceptions.Unknown)
    ),
)
async def gemini_embed(
    texts: list[str],
    model: str = "gemini-embedding-001",
    base_url: str | None = None,
    api_key: str | None = None,
    embedding_dim: int | None = None,
    max_token_size: int | None = None,
    task_type: str | None = None,
    timeout: int | None = None,
    token_tracker: Any | None = None,
    context: str = "document",
) -> np.ndarray:
    """Generate embeddings for a list of texts using Gemini's API.

    Args:
        texts: List of texts to embed.
        model: The Gemini embedding model to use.
        base_url: Optional custom API endpoint.
        api_key: Optional Gemini API key. If None, environment variables are
            consulted (see ``_ensure_api_key``).
        embedding_dim: Optional output dimensionality, sent to the API as
            ``output_dimensionality``. Per the original notes this value is
            normally injected by the EmbeddingFunc wrapper rather than
            passed by hand.
        max_token_size: Accepted for interface compatibility and not used
            here; no client-side truncation is performed (the original notes
            expect the API to truncate over-long inputs itself).
        task_type: Embedding task type. When ``None`` it is derived from
            *context*: ``"RETRIEVAL_QUERY"`` for ``context="query"``,
            ``"RETRIEVAL_DOCUMENT"`` for anything else.
        timeout: Request timeout in seconds (converted to milliseconds for
            the client). ``None`` or ``0`` means no timeout is configured.
        token_tracker: Optional object whose ``add_usage(dict)`` receives
            ``prompt_tokens`` / ``total_tokens`` when the response carries
            usage metadata.
        context: ``"query"`` for search queries, ``"document"`` for indexed
            content; only consulted when *task_type* is ``None``.

    Returns:
        A numpy array with one float32 embedding row per returned embedding.
        When ``embedding_dim`` is set and below 3072 the rows are
        L2-normalized (all-zero rows are left untouched).

    Raises:
        ValueError: If no API key is provided or configured.
        RuntimeError: If the response contains no embeddings.
    """
    # Acknowledge the parameter; truncation is not done client-side.
    _ = max_token_size

    key = _ensure_api_key(api_key)
    # The client expects milliseconds.
    timeout_ms = timeout * 1000 if timeout else None
    client = _get_gemini_client(key, base_url, timeout_ms)

    if task_type is None:
        # Unknown contexts fall back to the document task type.
        task_type = "RETRIEVAL_QUERY" if context == "query" else "RETRIEVAL_DOCUMENT"

    # task_type is always present, so a config object is always sent.
    config_kwargs: dict[str, Any] = {"task_type": task_type}
    if embedding_dim is not None:
        config_kwargs["output_dimensionality"] = embedding_dim

    response = await client.aio.models.embed_content(
        model=model,
        contents=texts,
        config=types.EmbedContentConfig(**config_kwargs),
    )

    if not getattr(response, "embeddings", None):
        raise RuntimeError("Gemini response did not contain embeddings.")

    embeddings = np.array(
        [np.array(e.values, dtype=np.float32) for e in response.embeddings]
    )

    # Reduced-dimension vectors are re-normalized to unit length here;
    # per the original notes, full 3072-dim vectors are assumed to arrive
    # already normalized.
    if embedding_dim and embedding_dim < 3072:
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Avoid division by zero for all-zero vectors.
        norms = np.where(norms == 0, 1, norms)
        embeddings = embeddings / norms
        logger.debug(
            "Applied L2 normalization to %s embeddings of dimension %s",
            len(embeddings),
            embedding_dim,
        )

    # The embedding response may carry no usage metadata at all; only report
    # real data, and never forward None counters to the tracker.
    usage = getattr(response, "usage_metadata", None)
    if token_tracker and usage:
        token_tracker.add_usage(
            {
                "prompt_tokens": getattr(usage, "prompt_token_count", 0) or 0,
                "total_tokens": getattr(usage, "total_token_count", 0) or 0,
            }
        )

    logger.debug(
        "Generated %s Gemini embeddings with dimension %s",
        len(embeddings),
        embeddings.shape[1],
    )
    return embeddings
# Public API of this module: the completion helpers and the embedding function.
__all__ = [
    "gemini_complete_if_cache",
    "gemini_model_complete",
    "gemini_embed",
]
|