lollms.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import sys
  2. import warnings
  3. if sys.version_info < (3, 9):
  4. from typing import AsyncIterator
  5. else:
  6. from collections.abc import AsyncIterator
  7. import pipmaster as pm # Pipmaster for dynamic library install
  8. if not pm.is_installed("aiohttp"):
  9. pm.install("aiohttp")
  10. import aiohttp
  11. from tenacity import (
  12. retry,
  13. stop_after_attempt,
  14. wait_exponential,
  15. retry_if_exception_type,
  16. )
  17. from lightrag.exceptions import (
  18. APIConnectionError,
  19. RateLimitError,
  20. APITimeoutError,
  21. )
  22. from typing import Any, List, Union
  23. import numpy as np
  24. from lightrag.utils import (
  25. wrap_embedding_func_with_attrs,
  26. )
  27. @retry(
  28. stop=stop_after_attempt(3),
  29. wait=wait_exponential(multiplier=1, min=4, max=10),
  30. retry=retry_if_exception_type(
  31. (RateLimitError, APIConnectionError, APITimeoutError)
  32. ),
  33. )
  34. async def lollms_model_if_cache(
  35. model,
  36. prompt,
  37. system_prompt=None,
  38. history_messages=[],
  39. enable_cot: bool = False,
  40. base_url="http://localhost:9600",
  41. image_inputs: list[Any] | None = None,
  42. **kwargs,
  43. ) -> Union[str, AsyncIterator[str]]:
  44. """Client implementation for lollms generation.
  45. Structured output note:
  46. - This adapter does not support OpenAI-style ``response_format`` JSON mode.
  47. - If callers pass ``response_format``, it is stripped before the request.
  48. - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
  49. accepted only as compatibility shims; they emit warnings and are ignored.
  50. Vision note:
  51. - lollms does not support image inputs. Passing a non-empty
  52. ``image_inputs`` raises :class:`NotImplementedError`.
  53. """
  54. if image_inputs:
  55. raise NotImplementedError(
  56. "lollms binding does not support image_inputs; configure a "
  57. "vision-capable VLM provider (openai/azure_openai/gemini/bedrock/"
  58. "ollama/anthropic) for VLM_LLM_BINDING."
  59. )
  60. if enable_cot:
  61. from lightrag.utils import logger
  62. logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
  63. # lollms has no JSON mode; drop response_format and warn when legacy
  64. # boolean shim flags are set.
  65. if kwargs.pop("keyword_extraction", False):
  66. warnings.warn(
  67. "lollms_model_if_cache(keyword_extraction=True) is deprecated; "
  68. "pass response_format={'type': 'json_object'} instead.",
  69. DeprecationWarning,
  70. stacklevel=2,
  71. )
  72. if kwargs.pop("entity_extraction", False):
  73. warnings.warn(
  74. "lollms_model_if_cache(entity_extraction=True) is deprecated; "
  75. "pass response_format={'type': 'json_object'} instead.",
  76. DeprecationWarning,
  77. stacklevel=2,
  78. )
  79. kwargs.pop("response_format", None)
  80. stream = True if kwargs.get("stream") else False
  81. api_key = kwargs.pop("api_key", None)
  82. headers = (
  83. {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
  84. if api_key
  85. else {"Content-Type": "application/json"}
  86. )
  87. # Extract lollms specific parameters
  88. request_data = {
  89. "prompt": prompt,
  90. "model_name": model,
  91. "personality": kwargs.get("personality", -1),
  92. "n_predict": kwargs.get("n_predict", None),
  93. "stream": stream,
  94. "temperature": kwargs.get("temperature", 1.0),
  95. "top_k": kwargs.get("top_k", 50),
  96. "top_p": kwargs.get("top_p", 0.95),
  97. "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
  98. "repeat_last_n": kwargs.get("repeat_last_n", 40),
  99. "seed": kwargs.get("seed", None),
  100. "n_threads": kwargs.get("n_threads", 8),
  101. }
  102. # Prepare the full prompt including history
  103. full_prompt = ""
  104. if system_prompt:
  105. full_prompt += f"{system_prompt}\n"
  106. for msg in history_messages:
  107. full_prompt += f"{msg['role']}: {msg['content']}\n"
  108. full_prompt += prompt
  109. request_data["prompt"] = full_prompt
  110. timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
  111. async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
  112. if stream:
  113. async def inner():
  114. async with session.post(
  115. f"{base_url}/lollms_generate", json=request_data
  116. ) as response:
  117. async for line in response.content:
  118. yield line.decode().strip()
  119. return inner()
  120. else:
  121. async with session.post(
  122. f"{base_url}/lollms_generate", json=request_data
  123. ) as response:
  124. return await response.text()
  125. async def lollms_model_complete(
  126. prompt,
  127. system_prompt=None,
  128. history_messages=[],
  129. enable_cot: bool = False,
  130. keyword_extraction=False,
  131. entity_extraction=False,
  132. **kwargs,
  133. ) -> Union[str, AsyncIterator[str]]:
  134. """Complete function for lollms model generation."""
  135. # Forward legacy extraction flags as kwargs so lollms_model_if_cache can
  136. # emit a single DeprecationWarning with the correct stack frame.
  137. if keyword_extraction:
  138. kwargs.setdefault("keyword_extraction", True)
  139. if entity_extraction:
  140. kwargs.setdefault("entity_extraction", True)
  141. model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
  142. return await lollms_model_if_cache(
  143. model_name,
  144. prompt,
  145. system_prompt=system_prompt,
  146. history_messages=history_messages,
  147. enable_cot=enable_cot,
  148. **kwargs,
  149. )
  150. @wrap_embedding_func_with_attrs(
  151. embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
  152. )
  153. async def lollms_embed(
  154. texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
  155. ) -> np.ndarray:
  156. """
  157. Generate embeddings for a list of texts using lollms server.
  158. Args:
  159. texts: List of strings to embed
  160. embed_model: Model name (not used directly as lollms uses configured vectorizer)
  161. base_url: URL of the lollms server
  162. **kwargs: Additional arguments passed to the request
  163. Returns:
  164. np.ndarray: Array of embeddings
  165. """
  166. api_key = kwargs.pop("api_key", None)
  167. headers = (
  168. {"Content-Type": "application/json", "Authorization": api_key}
  169. if api_key
  170. else {"Content-Type": "application/json"}
  171. )
  172. async with aiohttp.ClientSession(headers=headers) as session:
  173. embeddings = []
  174. for text in texts:
  175. request_data = {"text": text}
  176. async with session.post(
  177. f"{base_url}/lollms_embed",
  178. json=request_data,
  179. ) as response:
  180. result = await response.json()
  181. embeddings.append(result["vector"])
  182. return np.array(embeddings)