# Source file: ollama.py (12 KB)
  1. from collections.abc import AsyncIterator
  2. import os
  3. import re
  4. import warnings
  5. import pipmaster as pm
  6. # install specific modules
  7. if not pm.is_installed("ollama"):
  8. pm.install("ollama")
  9. import ollama
  10. from tenacity import (
  11. retry,
  12. stop_after_attempt,
  13. wait_exponential,
  14. retry_if_exception_type,
  15. )
  16. from lightrag.exceptions import (
  17. APIConnectionError,
  18. RateLimitError,
  19. APITimeoutError,
  20. )
  21. from lightrag.api import __api_version__
  22. import numpy as np
  23. from typing import Any, Optional, Union
  24. from lightrag.utils import (
  25. wrap_embedding_func_with_attrs,
  26. logger,
  27. )
  28. _OLLAMA_CLOUD_HOST = "https://ollama.com"
  29. _CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
  30. def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
  31. if host:
  32. return host
  33. try:
  34. model_name_str = str(model) if model is not None else ""
  35. except (TypeError, ValueError, AttributeError) as e:
  36. logger.warning(f"Failed to convert model to string: {e}, using empty string")
  37. model_name_str = ""
  38. if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
  39. logger.debug(
  40. f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
  41. )
  42. return _OLLAMA_CLOUD_HOST
  43. return host
  44. def _normalize_ollama_response_format(kwargs: dict) -> None:
  45. """Translate OpenAI-style response_format into Ollama's native format field.
  46. Precedence: an explicit ``format`` value (Ollama's native field) wins over
  47. ``response_format`` — if ``format`` is already set, ``response_format`` is
  48. dropped silently. Otherwise, ``{"type": "json_object"}`` maps to
  49. ``format="json"`` and any other payload is passed through unchanged so
  50. callers can supply JSON schemas directly.
  51. """
  52. response_format = kwargs.pop("response_format", None)
  53. if kwargs.get("format") is not None or response_format is None:
  54. return
  55. if isinstance(response_format, dict):
  56. if response_format.get("type") == "json_object":
  57. kwargs["format"] = "json"
  58. return
  59. if response_format.get("type") == "json_schema":
  60. json_schema = response_format.get("json_schema")
  61. if isinstance(json_schema, dict):
  62. kwargs["format"] = json_schema.get("schema", json_schema)
  63. return
  64. # Fall back to passing through schema-like payloads for native Ollama support.
  65. kwargs["format"] = response_format
  66. @retry(
  67. stop=stop_after_attempt(3),
  68. wait=wait_exponential(multiplier=1, min=4, max=10),
  69. retry=retry_if_exception_type(
  70. (RateLimitError, APIConnectionError, APITimeoutError)
  71. ),
  72. )
  73. async def _ollama_model_if_cache(
  74. model,
  75. prompt,
  76. system_prompt=None,
  77. history_messages=[],
  78. enable_cot: bool = False,
  79. image_inputs: list[Any] | None = None,
  80. **kwargs,
  81. ) -> Union[str, AsyncIterator[str]]:
  82. """Call Ollama chat API with OpenAI-style structured-output compatibility.
  83. Structured output note:
  84. - This adapter accepts OpenAI-style ``response_format`` and translates it
  85. to Ollama's native ``format`` field.
  86. - ``response_format={"type": "json_object"}`` maps to ``format="json"``.
  87. - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
  88. compatibility shims; when no explicit ``response_format`` is supplied,
  89. they are mapped to ``{"type": "json_object"}``.
  90. """
  91. if enable_cot:
  92. logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
  93. stream = True if kwargs.get("stream") else False
  94. kwargs.pop("max_tokens", None)
  95. # Deprecation shims: map legacy boolean flags to response_format only when
  96. # an explicit response_format was not supplied by the caller.
  97. if kwargs.get("response_format") is None:
  98. if kwargs.pop("entity_extraction", False):
  99. warnings.warn(
  100. "_ollama_model_if_cache(entity_extraction=True) is deprecated; "
  101. "pass response_format={'type': 'json_object'} instead.",
  102. DeprecationWarning,
  103. stacklevel=2,
  104. )
  105. kwargs["response_format"] = {"type": "json_object"}
  106. elif kwargs.pop("keyword_extraction", False):
  107. warnings.warn(
  108. "_ollama_model_if_cache(keyword_extraction=True) is deprecated; "
  109. "pass response_format={'type': 'json_object'} instead.",
  110. DeprecationWarning,
  111. stacklevel=2,
  112. )
  113. kwargs["response_format"] = {"type": "json_object"}
  114. else:
  115. # response_format was supplied explicitly; drop legacy flags silently.
  116. kwargs.pop("entity_extraction", None)
  117. kwargs.pop("keyword_extraction", None)
  118. _normalize_ollama_response_format(kwargs)
  119. host = kwargs.pop("host", None)
  120. timeout = kwargs.pop("timeout", None)
  121. if timeout == 0:
  122. timeout = None
  123. kwargs.pop("hashing_kv", None)
  124. api_key = kwargs.pop("api_key", None)
  125. # fallback to environment variable when not provided explicitly
  126. if not api_key:
  127. api_key = os.getenv("OLLAMA_API_KEY")
  128. headers = {
  129. "Content-Type": "application/json",
  130. "User-Agent": f"LightRAG/{__api_version__}",
  131. }
  132. if api_key:
  133. headers["Authorization"] = f"Bearer {api_key}"
  134. host = _coerce_host_for_cloud_model(host, model)
  135. ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
  136. try:
  137. messages = []
  138. if system_prompt:
  139. messages.append({"role": "system", "content": system_prompt})
  140. messages.extend(history_messages)
  141. user_message: dict[str, Any] = {"role": "user", "content": prompt}
  142. if image_inputs:
  143. from lightrag.llm._vision_utils import normalize_image_inputs
  144. normalized_images = normalize_image_inputs(image_inputs)
  145. user_message["images"] = [img.base64_str for img in normalized_images]
  146. messages.append(user_message)
  147. response = await ollama_client.chat(model=model, messages=messages, **kwargs)
  148. if stream:
  149. """cannot cache stream response and process reasoning"""
  150. async def inner():
  151. try:
  152. async for chunk in response:
  153. yield chunk["message"]["content"]
  154. except Exception as e:
  155. logger.error(f"Error in stream response: {str(e)}")
  156. raise
  157. finally:
  158. try:
  159. await ollama_client._client.aclose()
  160. logger.debug("Successfully closed Ollama client for streaming")
  161. except Exception as close_error:
  162. logger.warning(f"Failed to close Ollama client: {close_error}")
  163. return inner()
  164. else:
  165. model_response = response["message"]["content"]
  166. """
  167. If the model also wraps its thoughts in a specific tag,
  168. this information is not needed for the final
  169. response and can simply be trimmed.
  170. """
  171. return model_response
  172. except Exception as e:
  173. try:
  174. await ollama_client._client.aclose()
  175. logger.debug("Successfully closed Ollama client after exception")
  176. except Exception as close_error:
  177. logger.warning(
  178. f"Failed to close Ollama client after exception: {close_error}"
  179. )
  180. raise e
  181. finally:
  182. if not stream:
  183. try:
  184. await ollama_client._client.aclose()
  185. logger.debug(
  186. "Successfully closed Ollama client for non-streaming response"
  187. )
  188. except Exception as close_error:
  189. logger.warning(
  190. f"Failed to close Ollama client in finally block: {close_error}"
  191. )
  192. async def ollama_model_complete(
  193. prompt,
  194. system_prompt=None,
  195. history_messages=[],
  196. enable_cot: bool = False,
  197. keyword_extraction=False,
  198. entity_extraction=False,
  199. **kwargs,
  200. ) -> Union[str, AsyncIterator[str]]:
  201. # Forward legacy extraction flags as kwargs so _ollama_model_if_cache can
  202. # emit a single DeprecationWarning with the correct stack frame.
  203. if keyword_extraction:
  204. kwargs.setdefault("keyword_extraction", True)
  205. if entity_extraction:
  206. kwargs.setdefault("entity_extraction", True)
  207. model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
  208. return await _ollama_model_if_cache(
  209. model_name,
  210. prompt,
  211. system_prompt=system_prompt,
  212. history_messages=history_messages,
  213. enable_cot=enable_cot,
  214. **kwargs,
  215. )
  216. @wrap_embedding_func_with_attrs(
  217. embedding_dim=1024,
  218. max_token_size=8192,
  219. model_name="bge-m3:latest",
  220. supports_asymmetric=True,
  221. )
  222. async def ollama_embed(
  223. texts: list[str],
  224. embed_model: str = "bge-m3:latest",
  225. max_token_size: int | None = None,
  226. context: str = "document",
  227. query_prefix: str | None = None,
  228. document_prefix: str | None = None,
  229. **kwargs,
  230. ) -> np.ndarray:
  231. """Generate embeddings using Ollama's API.
  232. Args:
  233. texts: List of texts to embed.
  234. embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
  235. max_token_size: Maximum tokens per text. This parameter is automatically
  236. injected by the EmbeddingFunc wrapper when the underlying function
  237. signature supports it (via inspect.signature check). Ollama will
  238. automatically truncate texts exceeding the model's context length
  239. (num_ctx), so no client-side truncation is needed.
  240. context: The embedding context - "query" for search queries, "document" for indexed content.
  241. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
  242. when supports_asymmetric=True. Default is "document".
  243. query_prefix: Optional prefix to prepend to texts when context="query" (e.g., "search_query: ").
  244. document_prefix: Optional prefix to prepend to texts when context="document" (e.g., "search_document: ").
  245. **kwargs: Additional arguments passed to the Ollama client.
  246. Returns:
  247. A numpy array of embeddings, one per input text.
  248. Note:
  249. - Ollama API automatically truncates texts exceeding the model's context length
  250. - The max_token_size parameter is received but not used for client-side truncation
  251. """
  252. # Apply context-based prefixes if provided
  253. if context == "query" and query_prefix:
  254. texts = [query_prefix + text for text in texts]
  255. elif context == "document" and document_prefix:
  256. texts = [document_prefix + text for text in texts]
  257. # Note: max_token_size is received but not used for client-side truncation.
  258. # Ollama API handles truncation automatically based on the model's num_ctx setting.
  259. _ = max_token_size # Acknowledge parameter to avoid unused variable warning
  260. api_key = kwargs.pop("api_key", None)
  261. if not api_key:
  262. api_key = os.getenv("OLLAMA_API_KEY")
  263. headers = {
  264. "Content-Type": "application/json",
  265. "User-Agent": f"LightRAG/{__api_version__}",
  266. }
  267. if api_key:
  268. headers["Authorization"] = f"Bearer {api_key}"
  269. host = kwargs.pop("host", None)
  270. timeout = kwargs.pop("timeout", None)
  271. host = _coerce_host_for_cloud_model(host, embed_model)
  272. ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
  273. try:
  274. options = kwargs.pop("options", {})
  275. data = await ollama_client.embed(
  276. model=embed_model, input=texts, options=options
  277. )
  278. return np.array(data["embeddings"])
  279. except Exception as e:
  280. logger.error(f"Error in ollama_embed: {str(e)}")
  281. try:
  282. await ollama_client._client.aclose()
  283. logger.debug("Successfully closed Ollama client after exception in embed")
  284. except Exception as close_error:
  285. logger.warning(
  286. f"Failed to close Ollama client after exception in embed: {close_error}"
  287. )
  288. raise e
  289. finally:
  290. try:
  291. await ollama_client._client.aclose()
  292. logger.debug("Successfully closed Ollama client after embed")
  293. except Exception as close_error:
  294. logger.warning(f"Failed to close Ollama client after embed: {close_error}")