jina.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import os
  2. import pipmaster as pm # Pipmaster for dynamic library install
  3. # install specific modules
  4. if not pm.is_installed("aiohttp"):
  5. pm.install("aiohttp")
  6. if not pm.is_installed("tenacity"):
  7. pm.install("tenacity")
  8. import numpy as np
  9. import base64
  10. import aiohttp
  11. from tenacity import (
  12. retry,
  13. stop_after_attempt,
  14. wait_exponential,
  15. retry_if_exception_type,
  16. )
  17. from lightrag.utils import wrap_embedding_func_with_attrs, logger
  18. async def fetch_data(url, headers, data):
  19. async with aiohttp.ClientSession() as session:
  20. async with session.post(url, headers=headers, json=data) as response:
  21. if response.status != 200:
  22. error_text = await response.text()
  23. # Check if the error response is HTML (common for 502, 503, etc.)
  24. content_type = response.headers.get("content-type", "").lower()
  25. is_html_error = (
  26. error_text.strip().startswith("<!DOCTYPE html>")
  27. or "text/html" in content_type
  28. )
  29. if is_html_error:
  30. # Provide clean, user-friendly error messages for HTML error pages
  31. if response.status == 502:
  32. clean_error = "Bad Gateway (502) - Jina AI service temporarily unavailable. Please try again in a few minutes."
  33. elif response.status == 503:
  34. clean_error = "Service Unavailable (503) - Jina AI service is temporarily overloaded. Please try again later."
  35. elif response.status == 504:
  36. clean_error = "Gateway Timeout (504) - Jina AI service request timed out. Please try again."
  37. else:
  38. clean_error = f"HTTP {response.status} - Jina AI service error. Please try again later."
  39. else:
  40. # Use original error text if it's not HTML
  41. clean_error = error_text
  42. logger.error(f"Jina API error {response.status}: {clean_error}")
  43. raise aiohttp.ClientResponseError(
  44. request_info=response.request_info,
  45. history=response.history,
  46. status=response.status,
  47. message=f"Jina API error: {clean_error}",
  48. )
  49. response_json = await response.json()
  50. data_list = response_json.get("data", [])
  51. return data_list
  52. @wrap_embedding_func_with_attrs(
  53. embedding_dim=2048,
  54. max_token_size=8192,
  55. model_name="jina-embeddings-v4",
  56. supports_asymmetric=True,
  57. )
  58. @retry(
  59. stop=stop_after_attempt(3),
  60. wait=wait_exponential(multiplier=1, min=4, max=60),
  61. retry=(
  62. retry_if_exception_type(aiohttp.ClientError)
  63. | retry_if_exception_type(aiohttp.ClientResponseError)
  64. ),
  65. )
  66. async def jina_embed(
  67. texts: list[str],
  68. model: str = "jina-embeddings-v4",
  69. embedding_dim: int = 2048,
  70. late_chunking: bool = False,
  71. base_url: str = None,
  72. api_key: str = None,
  73. context: str | None = None,
  74. task: str | None = None,
  75. ) -> np.ndarray:
  76. """Generate embeddings for a list of texts using Jina AI's API.
  77. Args:
  78. texts: List of texts to embed.
  79. model: The Jina embedding model to use (default: jina-embeddings-v4).
  80. Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
  81. embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
  82. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
  83. Do NOT manually pass this parameter when calling the function directly.
  84. The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
  85. Manually passing a different value will trigger a warning and be ignored.
  86. When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
  87. late_chunking: Whether to use late chunking.
  88. base_url: Optional base URL for the Jina API.
  89. api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
  90. context: The embedding context - "query" for search queries, "document" for indexed content.
  91. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
  92. when supports_asymmetric=True. When ``task`` is left at its default of None,
  93. ``context`` drives the task selection.
  94. task: Embedding task mode. Default is None so that ``context`` (when present)
  95. picks the right Jina task:
  96. - "retrieval.query" for context="query"
  97. - "retrieval.passage" for context="document"
  98. - "text-matching" otherwise (true backward-compatible default)
  99. Any explicit non-None task value overrides context-based selection.
  100. Returns:
  101. A numpy array of embeddings, one per input text.
  102. Raises:
  103. aiohttp.ClientError: If there is a connection error with the Jina API.
  104. aiohttp.ClientResponseError: If the Jina API returns an error response.
  105. """
  106. if api_key:
  107. os.environ["JINA_API_KEY"] = api_key
  108. if "JINA_API_KEY" not in os.environ:
  109. raise ValueError("JINA_API_KEY environment variable is required")
  110. url = base_url or "https://api.jina.ai/v1/embeddings"
  111. headers = {
  112. "Content-Type": "application/json",
  113. "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
  114. }
  115. # Determine task based on context if not explicitly provided
  116. if task is None:
  117. if context == "query":
  118. task = "retrieval.query"
  119. elif context == "document":
  120. task = "retrieval.passage"
  121. else:
  122. task = "text-matching" # Default for backward compatibility
  123. data = {
  124. "model": model,
  125. "task": task,
  126. "dimensions": embedding_dim,
  127. "embedding_type": "base64",
  128. "input": texts,
  129. }
  130. # Only add optional parameters if they have non-default values
  131. if late_chunking:
  132. data["late_chunking"] = late_chunking
  133. logger.debug(
  134. f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
  135. )
  136. try:
  137. data_list = await fetch_data(url, headers, data)
  138. if not data_list:
  139. logger.error("Jina API returned empty data list")
  140. raise ValueError("Jina API returned empty data list")
  141. if len(data_list) != len(texts):
  142. logger.error(
  143. f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
  144. )
  145. raise ValueError(
  146. f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
  147. )
  148. embeddings = np.array(
  149. [
  150. np.frombuffer(base64.b64decode(dp["embedding"]), dtype=np.float32)
  151. for dp in data_list
  152. ]
  153. )
  154. logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
  155. return embeddings
  156. except Exception as e:
  157. logger.error(f"Jina embedding error: {e}")
  158. raise