| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- import warnings
- import pipmaster as pm
- from llama_index.core.llms import (
- ChatMessage,
- MessageRole,
- ChatResponse,
- )
- from typing import Any, List, Optional
- from lightrag.utils import logger
- # Install required dependencies
- if not pm.is_installed("llama-index"):
- pm.install("llama-index")
- from llama_index.core.embeddings import BaseEmbedding
- from llama_index.core.settings import Settings as LlamaIndexSettings
- from tenacity import (
- retry,
- stop_after_attempt,
- wait_exponential,
- retry_if_exception_type,
- )
- from lightrag.utils import (
- wrap_embedding_func_with_attrs,
- )
- from lightrag.exceptions import (
- APIConnectionError,
- RateLimitError,
- APITimeoutError,
- )
- import numpy as np
def configure_llama_index(settings: Any = None, **kwargs):
    """
    Apply keyword overrides to a LlamaIndex settings object and register it globally.

    Args:
        settings: LlamaIndex Settings instance. When None, a new object is
            obtained by calling ``LlamaIndexSettings()``.
        **kwargs: Attribute overrides. A key is applied only when the settings
            object already exposes an attribute of that name; any other key is
            skipped and reported with a warning.

    Returns:
        The settings object that was passed to ``LlamaIndexSettings.set_global``.
    """
    # NOTE(review): this relies on the imported ``Settings`` object being
    # callable and exposing ``set_global`` -- confirm against the installed
    # llama-index version.
    active = LlamaIndexSettings() if settings is None else settings

    for key, override in kwargs.items():
        if not hasattr(active, key):
            # Unknown names are reported rather than silently attached.
            logger.warning(f"Unknown LlamaIndex setting: {key}")
            continue
        setattr(active, key, override)

    # Publish the (possibly updated) object as the process-wide settings.
    LlamaIndexSettings.set_global(active)
    return active
def format_chat_messages(messages):
    """
    Convert role/content dicts into LlamaIndex ``ChatMessage`` objects.

    Each entry is read with ``.get``: a missing ``"role"`` defaults to
    ``"user"`` and a missing ``"content"`` defaults to ``""``. Roles other
    than ``"system"``, ``"assistant"`` and ``"user"`` are logged with a
    warning and converted as user messages.

    Returns:
        A list of ``ChatMessage`` objects in the same order as ``messages``.
    """
    converted = []
    for entry in messages:
        role = entry.get("role", "user")
        text = entry.get("content", "")

        # Resolve the target role first so there is a single append site.
        if role == "system":
            target = MessageRole.SYSTEM
        elif role == "assistant":
            target = MessageRole.ASSISTANT
        elif role == "user":
            target = MessageRole.USER
        else:
            logger.warning(f"Unknown role {role}, treating as user message")
            target = MessageRole.USER

        converted.append(ChatMessage(role=target, content=text))
    return converted
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_complete_if_cache(
    model: Any,
    prompt: str,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[dict]] = None,
    enable_cot: bool = False,
    chat_kwargs: Optional[dict] = None,
) -> str:
    """
    Complete the prompt using a LlamaIndex LLM.

    The call is retried (up to 3 attempts, exponential backoff) when it raises
    RateLimitError, APIConnectionError or APITimeoutError.

    Args:
        model: LLM object exposing an async ``achat(messages=..., **kwargs)``
            method.
        prompt: Current user prompt; always sent as the last message.
        system_prompt: Optional system message placed first when non-empty.
        history_messages: Optional prior turns as ``{"role", "content"}``
            dicts. Converted with ``format_chat_messages`` so system,
            assistant and user roles are each preserved.
        enable_cot: Not supported by this adapter; only logged and ignored.
        chat_kwargs: Optional extra keyword arguments forwarded to
            ``model.achat``.

    Returns:
        ``response.message.content`` of the chat response.

    Raises:
        Exception: Any error raised while building messages or calling the
            model is logged and re-raised unchanged.
    """
    if enable_cot:
        logger.debug(
            "enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
        )

    try:
        formatted_messages = []

        # System message goes first when provided.
        if system_prompt:
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
            )

        # Reuse the shared converter so history roles are mapped the same way
        # everywhere in this module (None defaults avoid shared mutable state).
        formatted_messages.extend(format_chat_messages(history_messages or []))

        # The current prompt is always the final user turn.
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(
            messages=formatted_messages, **(chat_kwargs or {})
        )

        # In newer versions, the response is in message.content
        return response.message.content

    except Exception as e:
        logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
        raise
async def llama_index_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    keyword_extraction=False,
    entity_extraction=False,
    settings: Any = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex.

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        enable_cot: Forwarded to ``llama_index_complete_if_cache``.
        keyword_extraction: Deprecated compatibility shim. Emits a warning and
            is ignored.
        entity_extraction: Deprecated compatibility shim. Emits a warning and
            is ignored.
        settings: Optional LlamaIndex settings (accepted for compatibility;
            not used by this function).
        **kwargs: Additional arguments. Recognised keys:
            ``llm_instance`` (required) - the LlamaIndex LLM to call;
            ``chat_kwargs`` - dict forwarded to the LLM's chat call;
            ``response_format`` - not supported by this adapter and stripped.
            Any other key is dropped (logged at debug level) because the
            underlying completion helper does not accept arbitrary keywords.

    Returns:
        The completion text returned by ``llama_index_complete_if_cache``.

    Raises:
        ValueError: If ``llm_instance`` is missing or None.

    Structured output note:
        - This adapter does not support OpenAI-style ``response_format`` JSON mode.
        - If callers pass ``response_format``, it is stripped before generation.
    """
    if history_messages is None:
        history_messages = []

    # LlamaIndex adapters have no JSON mode; warn when the legacy boolean shim
    # flags are set. (They are named parameters, so they can never be in kwargs.)
    if keyword_extraction:
        warnings.warn(
            "llama_index_complete(keyword_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    if entity_extraction:
        warnings.warn(
            "llama_index_complete(entity_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    kwargs.pop("response_format", None)

    # Take the model out of kwargs: it is passed positionally below, and the
    # helper would reject it as an unexpected keyword otherwise.
    llm_instance = kwargs.pop("llm_instance", None)
    if llm_instance is None:
        raise ValueError("llm_instance must be provided")

    chat_kwargs = kwargs.pop("chat_kwargs", None) or {}

    if kwargs:
        # The helper has a fixed signature; forwarding unknown keys would
        # raise TypeError, so they are dropped here.
        logger.debug(
            "llama_index_complete ignoring unsupported kwargs: %s", sorted(kwargs)
        )

    return await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        chat_kwargs=chat_kwargs,
    )
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding = None,
    settings: Any = None,
    **kwargs,
) -> np.ndarray:
    """
    Embed a batch of texts with a LlamaIndex embedding model.

    The call is retried (up to 3 attempts, exponential backoff) when it raises
    RateLimitError, APIConnectionError or APITimeoutError.

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model (required despite the None default)
        settings: Optional LlamaIndex settings; when truthy they are applied
            through ``configure_llama_index`` before embedding
        **kwargs: Additional arguments (accepted but not used here)

    Returns:
        ``np.ndarray`` built from the model's embeddings for ``texts``.

    Raises:
        ValueError: If ``embed_model`` is None.
    """
    # Settings are applied first, even if the model check below then fails.
    if settings:
        configure_llama_index(settings)

    if embed_model is None:
        raise ValueError("embed_model must be provided")

    # Batch call through the model's _get_text_embeddings hook.
    # NOTE(review): this is a synchronous, underscore-prefixed method invoked
    # inside a coroutine -- confirm it does not block the event loop for
    # remote embedding backends.
    return np.array(embed_model._get_text_embeddings(texts))
|