# gemini.py (28 KB)
  1. """
  2. Gemini LLM binding for LightRAG.
  3. This module provides asynchronous helpers that adapt Google's Gemini models
  4. to the same interface used by the rest of the LightRAG LLM bindings. The
  5. implementation mirrors the OpenAI helpers while relying on the official
  6. ``google-genai`` client under the hood.
  7. """
  8. from __future__ import annotations
  9. import os
  10. import warnings
  11. from collections.abc import AsyncIterator
  12. from functools import lru_cache
  13. from typing import Any
  14. import numpy as np
  15. from tenacity import (
  16. retry,
  17. stop_after_attempt,
  18. wait_exponential,
  19. retry_if_exception_type,
  20. )
  21. from lightrag.utils import (
  22. logger,
  23. remove_think_tags,
  24. safe_unicode_decode,
  25. wrap_embedding_func_with_attrs,
  26. )
  27. import pipmaster as pm
  28. # Install the Google Gemini client and its dependencies on demand
  29. if not pm.is_installed("google-genai"):
  30. pm.install("google-genai")
  31. if not pm.is_installed("google-api-core"):
  32. pm.install("google-api-core")
  33. from google import genai # type: ignore
  34. from google.genai import types # type: ignore
  35. from google.api_core import exceptions as google_api_exceptions # type: ignore
  class InvalidResponseError(Exception):
      """Signal that Gemini returned a response without any usable text.

      The completion helper raises this error for empty responses, and its
      retry policy lists this type so that such requests are attempted again.
      """
  39. _DEFAULT_GEMINI_BASE_URLS = {
  40. "https://generativelanguage.googleapis.com",
  41. "https://generativelanguage.googleapis.com/",
  42. "https://generativelanguage.googleapis.com/v1beta",
  43. "https://generativelanguage.googleapis.com/v1beta/",
  44. "https://generativelanguage.googleapis.com/v1",
  45. "https://generativelanguage.googleapis.com/v1/",
  46. }
  47. def _normalize_gemini_base_url(base_url: str | None) -> str | None:
  48. """Treat Google's default Gemini API service roots as SDK defaults."""
  49. if not base_url:
  50. return None
  51. normalized = base_url.strip()
  52. if not normalized or normalized == "DEFAULT_GEMINI_ENDPOINT":
  53. return None
  54. if normalized.rstrip("/") in {
  55. service_root.rstrip("/") for service_root in _DEFAULT_GEMINI_BASE_URLS
  56. }:
  57. return None
  58. return normalized
  @lru_cache(maxsize=8)
  def _get_gemini_client(
      api_key: str, base_url: str | None, timeout: int | None = None
  ) -> genai.Client:
      """Build a Gemini client, reusing a cached one for repeated arguments.

      Up to eight clients are cached, keyed by the exact argument tuple. The
      Vertex AI environment variables are read only when a client is actually
      constructed, not on cache hits.

      Args:
          api_key: Google Gemini API key (not used in Vertex AI mode).
          base_url: Optional custom API endpoint; default Google service roots
              are normalized away so the SDK default is used.
          timeout: Optional request timeout in milliseconds.

      Returns:
          genai.Client: Configured Gemini client instance.

      Raises:
          ValueError: If ``GOOGLE_GENAI_USE_VERTEXAI`` is ``true`` but
              ``GOOGLE_CLOUD_PROJECT`` is not set.
      """
      endpoint = _normalize_gemini_base_url(base_url)
      vertex_mode = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

      settings: dict[str, Any]
      if vertex_mode:
          # Vertex AI mode authenticates through project/location, NOT api_key.
          project = os.getenv("GOOGLE_CLOUD_PROJECT")
          if not project:
              raise ValueError(
                  "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI mode"
              )
          settings = {"vertexai": True, "project": project}
          location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
          if location:
              settings["location"] = location
      else:
          # Standard Gemini API mode authenticates with the API key.
          settings = {"api_key": api_key}

      if endpoint is not None or timeout is not None:
          # Only forward the HTTP options that were actually supplied.
          http_fields = {
              field: value
              for field, value in (("base_url", endpoint), ("timeout", timeout))
              if value is not None
          }
          try:
              settings["http_options"] = types.HttpOptions(**http_fields)
          except Exception as e:
              logger.error("Failed to apply custom Gemini http_options: %s", e)
              raise

      return genai.Client(**settings)
  104. def _ensure_api_key(api_key: str | None) -> str:
  105. # In Vertex AI mode, API key is not required
  106. use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
  107. if use_vertexai:
  108. # Return empty string for Vertex AI mode (not used)
  109. return ""
  110. key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY")
  111. if not key:
  112. raise ValueError(
  113. "Gemini API key not provided. "
  114. "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment."
  115. )
  116. return key
  def _build_generation_config(
      base_config: dict[str, Any] | None,
      system_prompt: str | None,
      response_format: Any | None,
  ) -> types.GenerateContentConfig | None:
      """Assemble the Gemini generation config for one request.

      Args:
          base_config: Caller-supplied generation settings; copied, never
              mutated.
          system_prompt: Appended on a new line to any existing
              ``system_instruction``, or used as the instruction itself.
          response_format: Structured-output request. When given, JSON output
              is requested and a schema is attached, unless the caller already
              set the corresponding fields in ``base_config``.

      Returns:
          A ``GenerateContentConfig``, or ``None`` when no setting remains after
          dropping ``None`` and empty-string values.
      """
      settings = dict(base_config or {})

      if system_prompt:
          existing_instruction = settings.get("system_instruction")
          settings["system_instruction"] = (
              f"{existing_instruction}\n{system_prompt}"
              if existing_instruction
              else system_prompt
          )

      # Translate response_format to Gemini's native generation config fields,
      # never overriding values the caller set explicitly.
      if response_format is not None:
          if "response_mime_type" not in settings:
              settings["response_mime_type"] = "application/json"
          if "response_json_schema" not in settings:
              schema = _normalize_gemini_response_schema(response_format)
              if schema is not None:
                  settings["response_json_schema"] = schema

      # Drop entries explicitly set to None or "" to avoid type errors.
      cleaned = {
          name: setting
          for name, setting in settings.items()
          if setting is not None and setting != ""
      }
      return types.GenerateContentConfig(**cleaned) if cleaned else None
  145. def _normalize_gemini_response_schema(response_format: Any) -> Any | None:
  146. """Extract a Gemini-compatible JSON schema from LightRAG/OpenAI inputs."""
  147. if response_format is None:
  148. return None
  149. if isinstance(response_format, dict):
  150. if response_format.get("type") == "json_object":
  151. return None
  152. if response_format.get("type") == "json_schema":
  153. json_schema = response_format.get("json_schema")
  154. if isinstance(json_schema, dict):
  155. schema = json_schema.get("schema")
  156. if isinstance(schema, dict):
  157. return schema
  158. return json_schema
  159. return response_format
  160. return response_format
  161. def _validate_gemini_response_format(response_format: Any | None) -> None:
  162. """Reject typed structured-output helpers; only dict payloads are supported."""
  163. if response_format is None or isinstance(response_format, dict):
  164. return
  165. raise TypeError(
  166. "gemini_complete_if_cache only supports dict response_format payloads; "
  167. "typed/Pydantic response_format values are not supported."
  168. )
  169. def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str:
  170. if not history_messages:
  171. return ""
  172. history_lines: list[str] = []
  173. for message in history_messages:
  174. role = message.get("role", "user")
  175. content = message.get("content", "")
  176. history_lines.append(f"[{role}] {content}")
  177. return "\n".join(history_lines)
  178. def _extract_response_text(
  179. response: Any, extract_thoughts: bool = False
  180. ) -> tuple[str, str]:
  181. """
  182. Extract text content from Gemini response, separating regular content from thoughts.
  183. Args:
  184. response: Gemini API response object
  185. extract_thoughts: Whether to extract thought content separately
  186. Returns:
  187. Tuple of (regular_text, thought_text)
  188. """
  189. candidates = getattr(response, "candidates", None)
  190. if not candidates:
  191. return ("", "")
  192. regular_parts: list[str] = []
  193. thought_parts: list[str] = []
  194. for candidate in candidates:
  195. if not getattr(candidate, "content", None):
  196. continue
  197. # Use 'or []' to handle None values from parts attribute
  198. for part in getattr(candidate.content, "parts", None) or []:
  199. text = getattr(part, "text", None)
  200. if not text:
  201. continue
  202. # Check if this part is thought content using the 'thought' attribute
  203. is_thought = getattr(part, "thought", False)
  204. if is_thought and extract_thoughts:
  205. thought_parts.append(text)
  206. elif not is_thought:
  207. regular_parts.append(text)
  208. return ("\n".join(regular_parts), "\n".join(thought_parts))
  def _decode_unicode_escapes(text: str) -> str:
      """Pass ``text`` through ``safe_unicode_decode`` only if it holds a backslash-u marker."""
      if "\\u" in text:
          return safe_unicode_decode(text.encode("utf-8"))
      return text


  def _record_gemini_usage(token_tracker: Any | None, usage_metadata: Any | None) -> None:
      """Report Gemini usage metadata to ``token_tracker`` when both are truthy."""
      if token_tracker and usage_metadata:
          token_tracker.add_usage(
              {
                  "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0),
                  "completion_tokens": getattr(
                      usage_metadata, "candidates_token_count", 0
                  ),
                  "total_tokens": getattr(usage_metadata, "total_token_count", 0),
              }
          )


  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=60),
      retry=(
          retry_if_exception_type(google_api_exceptions.InternalServerError)
          | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
          | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
          | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
          | retry_if_exception_type(google_api_exceptions.BadGateway)
          | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
          | retry_if_exception_type(google_api_exceptions.Aborted)
          | retry_if_exception_type(google_api_exceptions.Unknown)
          | retry_if_exception_type(InvalidResponseError)
      ),
  )
  async def gemini_complete_if_cache(
      model: str,
      prompt: str,
      system_prompt: str | None = None,
      history_messages: list[dict[str, Any]] | None = None,
      enable_cot: bool = False,
      base_url: str | None = None,
      api_key: str | None = None,
      token_tracker: Any | None = None,
      stream: bool | None = None,
      response_format: Any | None = None,
      keyword_extraction: bool = False,
      entity_extraction: bool = False,
      generation_config: dict[str, Any] | None = None,
      timeout: int | None = None,
      image_inputs: list[Any] | None = None,
      **_: Any,
  ) -> str | AsyncIterator[str]:
      """
      Complete a prompt using Gemini's API with Chain of Thought (COT) support.

      History and the current prompt are flattened into a single text block of
      ``[role] content`` lines ending with ``[user] <prompt>``.

      Structured output note:
          - This adapter accepts OpenAI-style ``response_format`` and translates it
            to Gemini's native generation config fields.
          - ``response_format={"type": "json_object"}`` maps to
            ``response_mime_type="application/json"``.
          - Dict-form ``json_schema`` payloads map to
            ``response_mime_type="application/json"`` plus
            ``response_json_schema=<schema>``.
          - Typed/Pydantic ``response_format`` helpers are rejected explicitly.
          - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
            compatibility shims; when no explicit ``response_format`` is supplied,
            they are mapped to ``{"type": "json_object"}`` and a
            ``DeprecationWarning`` is emitted.
          - Whenever a ``response_format`` is in effect, ``enable_cot`` is forced
            to ``False``.

      COT Integration:
          - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
          - When enable_cot=False: Thought content is filtered out, only regular content returned
          - Thought content is identified by the 'thought' attribute on response parts
          - Requires thinking_config to be enabled in generation_config for API to return thoughts

      Args:
          model: The Gemini model to use.
          prompt: The prompt to complete.
          system_prompt: Optional system prompt to include.
          history_messages: Optional list of previous messages in the conversation.
          enable_cot: Whether to include Chain of Thought content in the response.
          base_url: Optional custom API endpoint.
          api_key: Optional Gemini API key. If None, uses environment variable.
          token_tracker: Optional token usage tracker for monitoring API usage.
          stream: Whether to stream the response.
          response_format: OpenAI-style structured output control translated to
              Gemini generation config (see the note above).
          keyword_extraction: Deprecated; see the structured output note.
          entity_extraction: Deprecated; see the structured output note.
          generation_config: Optional generation configuration dict.
          timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
          image_inputs: Optional images; they are normalized with
              ``normalize_image_inputs`` and sent as byte parts after the prompt text.
          **_: Additional keyword arguments (for example ``hashing_kv`` passed for
              interface parity with other bindings) are accepted and ignored.

      Returns:
          The completed text, or an async iterator of text chunks if streaming.
          In streaming mode with enable_cot=True, thought content is emitted
          between ``<think>`` and ``</think>`` markers. The non-streaming text is
          passed through ``remove_think_tags`` before being returned.

      Raises:
          InvalidResponseError: If a non-streaming response contains no text
              (this type is part of the retry policy).
          TypeError: If ``response_format`` is neither ``None`` nor a dict.
          ValueError: If API key is not provided or configured.

      Note:
          In streaming mode this coroutine returns the generator before any
          request is sent, so errors raised while iterating the stream reach the
          consumer directly rather than passing through the retry decorator.
      """
      key = _ensure_api_key(api_key)
      # Convert timeout from seconds to milliseconds for Gemini API
      timeout_ms = timeout * 1000 if timeout else None
      client = _get_gemini_client(key, base_url, timeout_ms)

      # Deprecation shims: map legacy boolean flags to response_format only when
      # an explicit response_format was not supplied. entity_extraction wins
      # when both flags are set.
      if response_format is None:
          legacy_flag = None
          if entity_extraction:
              legacy_flag = "entity_extraction"
          elif keyword_extraction:
              legacy_flag = "keyword_extraction"
          if legacy_flag is not None:
              warnings.warn(
                  f"gemini_complete_if_cache({legacy_flag}=True) is deprecated; "
                  "pass response_format={'type': 'json_object'} instead.",
                  DeprecationWarning,
                  stacklevel=2,
              )
              response_format = {"type": "json_object"}

      _validate_gemini_response_format(response_format)
      if response_format is not None:
          # Structured output and <think> blocks do not mix.
          enable_cot = False

      history_block = _format_history_messages(history_messages)
      prompt_sections = [history_block] if history_block else []
      prompt_sections.append(f"[user] {prompt}")
      combined_prompt = "\n".join(prompt_sections)

      config_obj = _build_generation_config(
          generation_config,
          system_prompt=system_prompt,
          response_format=response_format,
      )

      if image_inputs:
          # Imported lazily so the vision helpers are only needed when used.
          from lightrag.llm._vision_utils import normalize_image_inputs

          normalized_images = normalize_image_inputs(image_inputs)
          parts: list[Any] = [combined_prompt]
          parts.extend(
              types.Part.from_bytes(data=img.raw_bytes, mime_type=img.mime_type)
              for img in normalized_images
          )
          contents: list[Any] = [parts]
      else:
          contents = [combined_prompt]

      request_kwargs: dict[str, Any] = {
          "model": model,
          "contents": contents,
      }
      if config_obj is not None:
          request_kwargs["config"] = config_obj

      if stream:

          async def _async_stream() -> AsyncIterator[str]:
              # COT state tracking for streaming
              cot_active = False  # a <think> block is currently open
              cot_started = False  # a <think> block has been opened at some point
              initial_content_seen = False  # regular text has started
              usage_metadata = None

              try:
                  # Use native async streaming from genai SDK
                  # Note: generate_content_stream returns Awaitable[AsyncIterator], need to await first
                  stream_iter = await client.aio.models.generate_content_stream(
                      **request_kwargs
                  )
                  async for chunk in stream_iter:
                      # Keep the most recent usage metadata seen on any chunk.
                      usage = getattr(chunk, "usage_metadata", None)
                      if usage is not None:
                          usage_metadata = usage

                      # Extract both regular and thought content
                      regular_text, thought_text = _extract_response_text(
                          chunk, extract_thoughts=True
                      )

                      if not enable_cot:
                          # COT disabled - only yield regular content
                          if regular_text:
                              yield _decode_unicode_escapes(regular_text)
                          continue

                      # Process regular content
                      if regular_text:
                          initial_content_seen = True

                          # Close COT section if it was active
                          if cot_active:
                              yield "</think>"
                              cot_active = False

                          yield _decode_unicode_escapes(regular_text)

                      # Process thought content
                      if thought_text:
                          # Thoughts only open a COT section before any regular
                          # text has been seen, and only once per stream.
                          if not initial_content_seen and not cot_started:
                              yield "<think>"
                              cot_active = True
                              cot_started = True

                          # Yield thought content if COT is active
                          if cot_active:
                              yield _decode_unicode_escapes(thought_text)

                  # Ensure COT is properly closed if still active
                  if cot_active:
                      yield "</think>"
                      cot_active = False

              except Exception:
                  # Best effort: try to close the COT tag before re-raising
                  if cot_active:
                      try:
                          yield "</think>"
                      except Exception:
                          pass
                  raise
              finally:
                  # Track token usage after streaming completes
                  _record_gemini_usage(token_tracker, usage_metadata)

          return _async_stream()

      # Non-streaming: use native async client
      response = await client.aio.models.generate_content(**request_kwargs)

      # Extract both regular text and thought text
      regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)

      # Apply COT filtering logic based on enable_cot parameter
      if enable_cot and thought_text and thought_text.strip():
          if not regular_text or regular_text.strip() == "":
              # Only thought content available
              final_text = f"<think>{thought_text}</think>"
          else:
              # Both content types present: prepend thought to regular content
              final_text = f"<think>{thought_text}</think>{regular_text}"
      else:
          # COT disabled or no thought content: regular content only
          final_text = regular_text or ""

      if not final_text:
          raise InvalidResponseError("Gemini response did not contain any text content.")

      final_text = _decode_unicode_escapes(final_text)

      # NOTE(review): remove_think_tags runs unconditionally, i.e. also on text
      # that just had a <think>...</think> block prepended for enable_cot=True.
      # If that helper strips a leading think block (as its name suggests), the
      # non-streaming COT output is discarded here - confirm this is intended.
      final_text = remove_think_tags(final_text)

      _record_gemini_usage(token_tracker, getattr(response, "usage_metadata", None))

      logger.debug("Gemini response length: %s", len(final_text))
      return final_text
  async def gemini_model_complete(
      prompt: str,
      system_prompt: str | None = None,
      history_messages: list[dict[str, Any]] | None = None,
      response_format: Any | None = None,
      keyword_extraction: bool = False,
      entity_extraction: bool = False,
      **kwargs: Any,
  ) -> str | AsyncIterator[str]:
      """Complete ``prompt`` with the Gemini model named in the configuration.

      The model name is taken from
      ``kwargs["hashing_kv"].global_config["llm_model_name"]`` when available,
      otherwise from a ``model_name`` keyword argument. All remaining keyword
      arguments are forwarded to ``gemini_complete_if_cache``.

      Args:
          prompt: The prompt to complete.
          system_prompt: Optional system prompt.
          history_messages: Optional previous conversation messages.
          response_format: Optional structured-output request, forwarded as is.
          keyword_extraction: Legacy flag, forwarded as is.
          entity_extraction: Legacy flag, forwarded as is.
          **kwargs: Extra options for ``gemini_complete_if_cache``; may carry
              ``hashing_kv`` and/or ``model_name``.

      Returns:
          Whatever ``gemini_complete_if_cache`` returns: the completed text or an
          async iterator of chunks when streaming.

      Raises:
          ValueError: If no model name can be determined.
      """
      # entity_extraction is a named parameter, so it can never show up inside
      # **kwargs; no extra lookup there is needed.
      hashing_kv = kwargs.get("hashing_kv")
      model_name = None
      if hashing_kv is not None:
          model_name = hashing_kv.global_config.get("llm_model_name")
      if model_name is None:
          # Fall back to an explicit keyword; popped so it is not forwarded.
          model_name = kwargs.pop("model_name", None)
      if model_name is None:
          raise ValueError("Gemini model name not provided in configuration.")

      return await gemini_complete_if_cache(
          model_name,
          prompt,
          system_prompt=system_prompt,
          history_messages=history_messages,
          response_format=response_format,
          keyword_extraction=keyword_extraction,
          entity_extraction=entity_extraction,
          **kwargs,
      )
  @wrap_embedding_func_with_attrs(
      embedding_dim=1536,
      max_token_size=2048,
      model_name="gemini-embedding-001",
      supports_asymmetric=True,
  )
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=60),
      retry=(
          retry_if_exception_type(google_api_exceptions.InternalServerError)
          | retry_if_exception_type(google_api_exceptions.ServiceUnavailable)
          | retry_if_exception_type(google_api_exceptions.ResourceExhausted)
          | retry_if_exception_type(google_api_exceptions.GatewayTimeout)
          | retry_if_exception_type(google_api_exceptions.BadGateway)
          | retry_if_exception_type(google_api_exceptions.DeadlineExceeded)
          | retry_if_exception_type(google_api_exceptions.Aborted)
          | retry_if_exception_type(google_api_exceptions.Unknown)
      ),
  )
  async def gemini_embed(
      texts: list[str],
      model: str = "gemini-embedding-001",
      base_url: str | None = None,
      api_key: str | None = None,
      embedding_dim: int | None = None,
      max_token_size: int | None = None,
      task_type: str | None = None,
      timeout: int | None = None,
      token_tracker: Any | None = None,
      context: str = "document",
  ) -> np.ndarray:
      """Generate embeddings for a list of texts using Gemini's API.

      Args:
          texts: List of texts to embed.
          model: The Gemini embedding model to use. Default is "gemini-embedding-001".
          base_url: Optional custom API endpoint.
          api_key: Optional Gemini API key. If None, uses environment variables.
          embedding_dim: Optional output dimensionality sent to the API. Per the
              original author's notes this value is normally injected by the
              EmbeddingFunc wrapper and should not be passed manually
              (supported range 128-3072; recommended 768, 1536, 3072).
          max_token_size: Accepted for wrapper compatibility but not used here;
              no client-side truncation is performed (the original notes state
              that the Gemini API truncates over-long texts itself).
          task_type: Embedding task type. When None it is derived from
              ``context``: "RETRIEVAL_QUERY" for "query", otherwise
              "RETRIEVAL_DOCUMENT".
          timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
          token_tracker: Optional token usage tracker for monitoring API usage.
          context: "query" for search queries, "document" for indexed content;
              only consulted when ``task_type`` is None.

      Returns:
          A float32 numpy array with one embedding per returned vector. When
          ``embedding_dim`` is set and below 3072 each vector is L2-normalized
          (zero vectors are left as they are); otherwise vectors are returned as
          received, on the original author's note that 3072-dimensional
          embeddings already arrive normalized.

      Raises:
          ValueError: If API key is not provided or configured.
          RuntimeError: If the response from Gemini contains no embeddings.
      """
      # max_token_size is received but deliberately unused (see docstring).
      _ = max_token_size

      key = _ensure_api_key(api_key)
      # Convert timeout from seconds to milliseconds for Gemini API
      timeout_ms = timeout * 1000 if timeout else None
      client = _get_gemini_client(key, base_url, timeout_ms)

      # Resolve the task type; anything other than "query" is treated as a
      # document for backward compatibility.
      if task_type is None:
          task_type = "RETRIEVAL_QUERY" if context == "query" else "RETRIEVAL_DOCUMENT"

      # task_type is always present, so a config object is always built.
      config_kwargs: dict[str, Any] = {"task_type": task_type}
      if embedding_dim is not None:
          config_kwargs["output_dimensionality"] = embedding_dim
      config_obj = types.EmbedContentConfig(**config_kwargs)

      # Use native async client for embedding
      response = await client.aio.models.embed_content(
          model=model,
          contents=texts,
          config=config_obj,
      )

      returned_embeddings = getattr(response, "embeddings", None)
      if not returned_embeddings:
          raise RuntimeError("Gemini response did not contain embeddings.")

      # Convert embeddings to numpy array
      embeddings = np.array(
          [np.array(item.values, dtype=np.float32) for item in returned_embeddings]
      )

      # Apply L2 normalization for dimensions < 3072
      if embedding_dim and embedding_dim < 3072:
          norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
          # Avoid division by zero
          norms = np.where(norms == 0, 1, norms)
          embeddings = embeddings / norms
          logger.debug(
              f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}"
          )

      # Track token usage if tracker is provided
      # Note: Gemini embedding API may not provide usage metadata
      if token_tracker and hasattr(response, "usage_metadata"):
          usage = response.usage_metadata
          token_tracker.add_usage(
              {
                  "prompt_tokens": getattr(usage, "prompt_token_count", 0),
                  "total_tokens": getattr(usage, "total_token_count", 0),
              }
          )

      logger.debug(
          f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
      )
      return embeddings
  # Public API of this module; helpers with a leading underscore stay private.
  __all__ = [
      "gemini_complete_if_cache",
      "gemini_model_complete",
      "gemini_embed",
  ]