llama_index_impl.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import warnings
  2. import pipmaster as pm
  3. from llama_index.core.llms import (
  4. ChatMessage,
  5. MessageRole,
  6. ChatResponse,
  7. )
  8. from typing import Any, List, Optional
  9. from lightrag.utils import logger
  10. # Install required dependencies
  11. if not pm.is_installed("llama-index"):
  12. pm.install("llama-index")
  13. from llama_index.core.embeddings import BaseEmbedding
  14. from llama_index.core.settings import Settings as LlamaIndexSettings
  15. from tenacity import (
  16. retry,
  17. stop_after_attempt,
  18. wait_exponential,
  19. retry_if_exception_type,
  20. )
  21. from lightrag.utils import (
  22. wrap_embedding_func_with_attrs,
  23. )
  24. from lightrag.exceptions import (
  25. APIConnectionError,
  26. RateLimitError,
  27. APITimeoutError,
  28. )
  29. import numpy as np
  30. def configure_llama_index(settings: Any = None, **kwargs):
  31. """
  32. Configure LlamaIndex settings.
  33. Args:
  34. settings: LlamaIndex Settings instance. If None, uses default settings.
  35. **kwargs: Additional settings to override/configure
  36. """
  37. if settings is None:
  38. settings = LlamaIndexSettings()
  39. # Update settings with any provided kwargs
  40. for key, value in kwargs.items():
  41. if hasattr(settings, key):
  42. setattr(settings, key, value)
  43. else:
  44. logger.warning(f"Unknown LlamaIndex setting: {key}")
  45. # Set as global settings
  46. LlamaIndexSettings.set_global(settings)
  47. return settings
  48. def format_chat_messages(messages):
  49. """Format chat messages into LlamaIndex format."""
  50. formatted_messages = []
  51. for msg in messages:
  52. role = msg.get("role", "user")
  53. content = msg.get("content", "")
  54. if role == "system":
  55. formatted_messages.append(
  56. ChatMessage(role=MessageRole.SYSTEM, content=content)
  57. )
  58. elif role == "assistant":
  59. formatted_messages.append(
  60. ChatMessage(role=MessageRole.ASSISTANT, content=content)
  61. )
  62. elif role == "user":
  63. formatted_messages.append(
  64. ChatMessage(role=MessageRole.USER, content=content)
  65. )
  66. else:
  67. logger.warning(f"Unknown role {role}, treating as user message")
  68. formatted_messages.append(
  69. ChatMessage(role=MessageRole.USER, content=content)
  70. )
  71. return formatted_messages
  72. @retry(
  73. stop=stop_after_attempt(3),
  74. wait=wait_exponential(multiplier=1, min=4, max=60),
  75. retry=retry_if_exception_type(
  76. (RateLimitError, APIConnectionError, APITimeoutError)
  77. ),
  78. )
  79. async def llama_index_complete_if_cache(
  80. model: str,
  81. prompt: str,
  82. system_prompt: Optional[str] = None,
  83. history_messages: List[dict] = [],
  84. enable_cot: bool = False,
  85. chat_kwargs={},
  86. ) -> str:
  87. """Complete the prompt using LlamaIndex."""
  88. if enable_cot:
  89. logger.debug(
  90. "enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
  91. )
  92. try:
  93. # Format messages for chat
  94. formatted_messages = []
  95. # Add system message if provided
  96. if system_prompt:
  97. formatted_messages.append(
  98. ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
  99. )
  100. # Add history messages
  101. for msg in history_messages:
  102. formatted_messages.append(
  103. ChatMessage(
  104. role=MessageRole.USER
  105. if msg["role"] == "user"
  106. else MessageRole.ASSISTANT,
  107. content=msg["content"],
  108. )
  109. )
  110. # Add current prompt
  111. formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
  112. response: ChatResponse = await model.achat(
  113. messages=formatted_messages, **chat_kwargs
  114. )
  115. # In newer versions, the response is in message.content
  116. content = response.message.content
  117. return content
  118. except Exception as e:
  119. logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
  120. raise
  121. async def llama_index_complete(
  122. prompt,
  123. system_prompt=None,
  124. history_messages=None,
  125. enable_cot: bool = False,
  126. keyword_extraction=False,
  127. entity_extraction=False,
  128. settings: Any = None,
  129. **kwargs,
  130. ) -> str:
  131. """
  132. Main completion function for LlamaIndex.
  133. Args:
  134. prompt: Input prompt
  135. system_prompt: Optional system prompt
  136. history_messages: Optional chat history
  137. keyword_extraction: Deprecated compatibility shim. Emits a warning and
  138. is ignored.
  139. entity_extraction: Deprecated compatibility shim. Emits a warning and
  140. is ignored.
  141. settings: Optional LlamaIndex settings
  142. **kwargs: Additional arguments. ``response_format`` is not supported by
  143. this adapter and is stripped before calling LlamaIndex.
  144. Structured output note:
  145. - This adapter does not support OpenAI-style ``response_format`` JSON mode.
  146. - If callers pass ``response_format``, it is stripped before generation.
  147. """
  148. if history_messages is None:
  149. history_messages = []
  150. # LlamaIndex adapters have no JSON mode; drop response_format and warn
  151. # when legacy boolean shim flags are set.
  152. if kwargs.pop("keyword_extraction", False) or keyword_extraction:
  153. warnings.warn(
  154. "llama_index_complete(keyword_extraction=True) is deprecated; "
  155. "pass response_format={'type': 'json_object'} instead.",
  156. DeprecationWarning,
  157. stacklevel=2,
  158. )
  159. if kwargs.pop("entity_extraction", False) or entity_extraction:
  160. warnings.warn(
  161. "llama_index_complete(entity_extraction=True) is deprecated; "
  162. "pass response_format={'type': 'json_object'} instead.",
  163. DeprecationWarning,
  164. stacklevel=2,
  165. )
  166. kwargs.pop("response_format", None)
  167. result = await llama_index_complete_if_cache(
  168. kwargs.get("llm_instance"),
  169. prompt,
  170. system_prompt=system_prompt,
  171. history_messages=history_messages,
  172. enable_cot=enable_cot,
  173. **kwargs,
  174. )
  175. return result
  176. @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
  177. @retry(
  178. stop=stop_after_attempt(3),
  179. wait=wait_exponential(multiplier=1, min=4, max=60),
  180. retry=retry_if_exception_type(
  181. (RateLimitError, APIConnectionError, APITimeoutError)
  182. ),
  183. )
  184. async def llama_index_embed(
  185. texts: list[str],
  186. embed_model: BaseEmbedding = None,
  187. settings: Any = None,
  188. **kwargs,
  189. ) -> np.ndarray:
  190. """
  191. Generate embeddings using LlamaIndex
  192. Args:
  193. texts: List of texts to embed
  194. embed_model: LlamaIndex embedding model
  195. settings: Optional LlamaIndex settings
  196. **kwargs: Additional arguments
  197. """
  198. if settings:
  199. configure_llama_index(settings)
  200. if embed_model is None:
  201. raise ValueError("embed_model must be provided")
  202. # Use _get_text_embeddings for batch processing
  203. embeddings = embed_model._get_text_embeddings(texts)
  204. return np.array(embeddings)