| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609 |
- import copy
- import inspect
- import json
- import logging
- import warnings
- import pipmaster as pm # Pipmaster for dynamic library install
- if not pm.is_installed("aioboto3"):
- pm.install("aioboto3")
- import aioboto3
- import numpy as np
- from tenacity import (
- retry,
- stop_after_attempt,
- wait_exponential,
- retry_if_exception_type,
- )
- from collections.abc import AsyncIterator
- from typing import Any, Union
- from lightrag.utils import wrap_embedding_func_with_attrs
- # Import botocore exceptions for proper exception handling
- try:
- from botocore.exceptions import (
- ClientError,
- ConnectionError as BotocoreConnectionError,
- ReadTimeoutError,
- )
- except ImportError:
- # If botocore is not installed, define placeholders
- ClientError = Exception
- BotocoreConnectionError = Exception
- ReadTimeoutError = Exception
class BedrockError(Exception):
    """Base class for the errors raised by this Bedrock adapter.

    Raised directly for failures that are not worth retrying, such as
    validation errors, malformed responses and unsupported models.
    """


class BedrockRateLimitError(BedrockError):
    """Bedrock throttled the request; the adapters' retry policy retries it."""


class BedrockConnectionError(BedrockError):
    """Network or server-side failure; the adapters' retry policy retries it."""


class BedrockTimeoutError(BedrockError):
    """The request timed out; the adapters' retry policy retries it."""
- def _normalize_bedrock_endpoint_url(endpoint_url: str | None) -> str | None:
- """Return a usable Bedrock endpoint override or None for SDK defaults."""
- if endpoint_url is None:
- return None
- normalized = endpoint_url.strip()
- if not normalized or normalized == "DEFAULT_BEDROCK_ENDPOINT":
- return None
- return normalized
- def _bedrock_client_kwargs(
- region: str | None,
- endpoint_url: str | None,
- aws_access_key_id: str | None = None,
- aws_secret_access_key: str | None = None,
- aws_session_token: str | None = None,
- ) -> dict:
- """Build kwargs for aioboto3 ``session.client("bedrock-runtime", ...)``."""
- client_kwargs: dict = {"region_name": region}
- if endpoint_url is not None:
- client_kwargs["endpoint_url"] = endpoint_url
- if aws_access_key_id:
- client_kwargs["aws_access_key_id"] = aws_access_key_id
- if aws_secret_access_key:
- client_kwargs["aws_secret_access_key"] = aws_secret_access_key
- if aws_session_token:
- client_kwargs["aws_session_token"] = aws_session_token
- return client_kwargs
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
    """Convert an exception from a Bedrock call into this module's error types.

    This function never returns normally; it always raises. Errors that are
    not already a ``BedrockError`` are logged at ERROR level and re-raised as
    one of the typed errors below, chained to the original (``from e``).

    Args:
        e: The exception to handle.
        operation: Description of the operation, used as the log prefix.

    Raises:
        BedrockRateLimitError: Throttling errors (retryable).
        BedrockConnectionError: Network, service-unavailable and 5xx errors
            (retryable).
        BedrockTimeoutError: Timeout errors (retryable).
        BedrockError: Validation and other non-retryable errors. Any
            ``BedrockError`` passed in (including the subclasses above) is
            re-raised unchanged.
    """
    import asyncio  # only for asyncio.TimeoutError, distinct from TimeoutError on Python 3.10

    # Already-typed errors (e.g. response validation failures raised by the
    # callers) pass through untouched. Use `raise e`, not a bare `raise`, so
    # this also works when the handler is invoked outside an `except` block.
    if isinstance(e, BedrockError):
        raise e

    error_message = str(e)

    # botocore ClientError: classify by the service error code.
    if isinstance(e, ClientError):
        response = getattr(e, "response", None)
        if not isinstance(response, dict):
            response = {}
        error_info = response.get("Error") or {}
        error_code = error_info.get("Code", "")
        error_msg = error_info.get("Message", error_message)
        # `or 0` also covers an explicit None status code, which would
        # otherwise make the `>= 500` comparison raise TypeError.
        status_code = (response.get("ResponseMetadata") or {}).get(
            "HTTPStatusCode"
        ) or 0

        if error_code in (
            "ThrottlingException",
            "ProvisionedThroughputExceededException",
        ):
            logging.error(f"{operation} rate limit error: {error_msg}")
            raise BedrockRateLimitError(f"Rate limit error: {error_msg}") from e
        if error_code in (
            "ServiceUnavailableException",
            "InternalServerException",
            "ModelNotReadyException",
        ):
            logging.error(f"{operation} connection error: {error_msg}")
            raise BedrockConnectionError(f"Service error: {error_msg}") from e
        if error_code == "ModelTimeoutException":
            logging.error(f"{operation} timeout error: {error_msg}")
            raise BedrockTimeoutError(f"Timeout error: {error_msg}") from e
        if status_code >= 500:
            logging.error(f"{operation} server error: {error_msg}")
            raise BedrockConnectionError(f"Server error: {error_msg}") from e
        # Validation and other client errors are not retryable.
        logging.error(f"{operation} client error: {error_msg}")
        raise BedrockError(f"Client error: {error_msg}") from e

    if isinstance(e, BotocoreConnectionError):
        logging.error(f"{operation} connection error: {error_message}")
        raise BedrockConnectionError(f"Connection error: {error_message}") from e

    if isinstance(e, (ReadTimeoutError, TimeoutError, asyncio.TimeoutError)):
        logging.error(f"{operation} timeout error: {error_message}")
        raise BedrockTimeoutError(f"Timeout error: {error_message}") from e

    # Anything else is treated as non-retryable.
    logging.error(f"{operation} unexpected error: {error_message}")
    raise BedrockError(f"Unexpected error: {error_message}") from e
# Keyword arguments that callers may pass but the Bedrock Converse API does
# not accept; they are dropped before the request is built.
_UNSUPPORTED_CONVERSE_KWARGS = (
    "response_format",
    "tools",
    "tool_choice",
    "seed",
    "presence_penalty",
    "frequency_penalty",
    "n",
    "logprobs",
    "top_logprobs",
    "max_completion_tokens",
)

# Caller-side sampling parameter name -> Converse ``inferenceConfig`` key.
_CONVERSE_INFERENCE_PARAMS = {
    "max_tokens": "maxTokens",
    "temperature": "temperature",
    "top_p": "topP",
    "stop_sequences": "stopSequences",
}


def _to_converse_message(message: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of a chat-history message in Converse format.

    Plain-string ``content`` is wrapped as ``[{"text": content}]``. Any other
    value (e.g. a list that is already Converse-shaped) is passed through
    unchanged. The caller's dict is never mutated.
    """
    converted = copy.copy(message)
    if isinstance(converted.get("content"), str):
        converted["content"] = [{"text": converted["content"]}]
    return converted


def _build_user_message(
    prompt: str, image_inputs: list[Any] | None
) -> dict[str, Any]:
    """Build the final user turn: the prompt text followed by any images."""
    content: list[dict[str, Any]] = [{"text": prompt}]
    if image_inputs:
        # Imported only when images are present.
        from lightrag.llm._vision_utils import normalize_image_inputs

        for img in normalize_image_inputs(image_inputs):
            # Converse takes the bare subtype ("image/png" -> "png"); fall
            # back to "png" when the MIME type has no subtype.
            fmt = img.mime_type.split("/", 1)[1] if "/" in img.mime_type else "png"
            content.append(
                {"image": {"format": fmt, "source": {"bytes": img.raw_bytes}}}
            )
    return {"role": "user", "content": content}


def _pop_inference_config(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Pop sampling parameters out of *kwargs* and map them to Converse keys.

    A parameter whose value is ``None`` is popped but not forwarded: Bedrock
    rejects ``None``, and ``None`` here means "inherit the provider default".
    """
    inference_config: dict[str, Any] = {}
    for param, converse_key in _CONVERSE_INFERENCE_PARAMS.items():
        if param in kwargs:
            value = kwargs.pop(param)
            if value is not None:
                inference_config[converse_key] = value
    return inference_config


async def _close_event_stream(event_stream: Any) -> None:
    """Close a Converse event stream on a best-effort basis; failures are only logged.

    aiobotocore's EventStream exposes a sync ``close()``, while generic async
    iterators expose ``aclose()``. Whichever is found is called, and its
    result is awaited if it is awaitable.
    """
    close_fn = getattr(event_stream, "close", None) or getattr(
        event_stream, "aclose", None
    )
    if not callable(close_fn):
        return
    try:
        result = close_fn()
        if inspect.isawaitable(result):
            await result
    except Exception as close_error:
        logging.warning(f"Failed to close Bedrock event stream: {close_error}")


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    aws_region: str | None = None,
    api_key: str | None = None,
    endpoint_url: str | None = None,
    image_inputs: list[Any] | None = None,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Call Amazon Bedrock Converse API with LightRAG-compatible shims.

    Args:
        model: Bedrock model id, sent as ``modelId``.
        prompt: Text of the final user turn.
        system_prompt: Optional system prompt, sent as Converse ``system``.
        history_messages: Optional prior messages (dicts with ``role`` and
            ``content``). String content is wrapped into Converse text blocks.
            Defaults to no history.
        enable_cot: Not supported by this adapter; logged and ignored.
        image_inputs: Optional images, appended to the final user turn.
            Supplying images forces the non-streaming path.
        **kwargs: ``stream`` selects ``converse_stream``. ``max_tokens``,
            ``temperature``, ``top_p`` and ``stop_sequences`` go into
            ``inferenceConfig``. ``extra_fields`` becomes
            ``additionalModelRequestFields``. Known-unsupported keys and
            ``hashing_kv`` are dropped. Everything else is forwarded to the
            Converse call as-is.

    Returns:
        The first non-empty text block of the response. When streaming, an
        async iterator that yields text deltas instead.

    Raises:
        BedrockError (or a retryable subclass): see ``_handle_bedrock_exception``.
        The decorator retries rate-limit, connection and timeout errors up to
        5 attempts. Errors raised while iterating a streamed response occur
        after this coroutine has returned, so they are not retried.

    Structured output note:
    - This adapter does not support OpenAI-style ``response_format`` JSON mode.
    - If callers pass ``response_format``, it is stripped before the request.
    - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
      accepted only as compatibility shims; they emit warnings and are ignored.

    Authentication note:
    - Bedrock does not use LightRAG's generic ``api_key`` fields.
    - ``LLM_BINDING_API_KEY`` and ``EMBEDDING_BINDING_API_KEY`` are ignored for
      Bedrock.
    - To use Bedrock API key / bearer-token auth, set
      ``AWS_BEARER_TOKEN_BEDROCK`` before starting the process; this is a
      process-level AWS SDK setting.
    - For role-specific Bedrock LLMs, use explicit SigV4 parameters
      (``aws_access_key_id``, ``aws_secret_access_key``, ``aws_session_token``,
      ``aws_region``). Per-role bearer-token overrides are not supported.

    Endpoint note:
    - ``endpoint_url`` overrides the default regional Bedrock endpoint. Pass
      ``None``, an empty string, or the sentinel ``DEFAULT_BEDROCK_ENDPOINT``
      to let the AWS SDK select its default endpoint.
    """
    if enable_cot:
        logging.debug(
            "enable_cot=True is not supported for Bedrock and will be ignored."
        )

    # Bedrock Converse API has no JSON mode. Drop the legacy extraction flags
    # (and response_format below) and rely on the prompt template plus
    # tolerant JSON parsing downstream.
    # NOTE(review): the warning text suggests response_format, which this
    # adapter also strips. The message is left unchanged here; confirm the
    # intended guidance.
    keyword_extraction = kwargs.pop("keyword_extraction", False)
    entity_extraction = kwargs.pop("entity_extraction", False)
    if keyword_extraction:
        warnings.warn(
            "bedrock_complete_if_cache(keyword_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    if entity_extraction:
        warnings.warn(
            "bedrock_complete_if_cache(entity_extraction=True) is deprecated; "
            "pass response_format={'type': 'json_object'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    if api_key:
        warnings.warn(
            "bedrock_complete_if_cache(api_key=...) is ignored; use SigV4 "
            "parameters or set AWS_BEARER_TOKEN_BEDROCK before process start.",
            DeprecationWarning,
            stacklevel=2,
        )

    endpoint_url = _normalize_bedrock_endpoint_url(endpoint_url)
    kwargs.pop("hashing_kv", None)
    # `stream` is not a Converse parameter; it selects converse_stream vs converse.
    stream = bool(kwargs.pop("stream", False))
    for unsupported in _UNSUPPORTED_CONVERSE_KWARGS:
        kwargs.pop(unsupported, None)

    messages = [_to_converse_message(m) for m in (history_messages or [])]
    messages.append(_build_user_message(prompt, image_inputs))
    if image_inputs and stream:
        logging.getLogger(__name__).debug(
            "[bedrock] image_inputs provided; forcing non-stream Converse "
            "(stream + image combination has SDK limitations)"
        )
        stream = False

    args: dict[str, Any] = {"modelId": model, "messages": messages}
    if system_prompt:
        args["system"] = [{"text": system_prompt}]
    inference_config = _pop_inference_config(kwargs)
    if inference_config:
        args["inferenceConfig"] = inference_config
    # Pass-through for model-specific parameters (e.g. Anthropic reasoning_config,
    # Nova inferenceConfig extensions). Mirrors OpenAI's `extra_body`.
    extra_fields = kwargs.pop("extra_fields", None)
    if extra_fields:
        args["additionalModelRequestFields"] = extra_fields

    session = aioboto3.Session()
    client_kwargs = _bedrock_client_kwargs(
        aws_region,
        endpoint_url,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )

    if stream:

        async def stream_generator():
            # The generator owns the client. `async with` closes it even on
            # task cancellation, avoiding aiohttp "Unclosed connection" warnings.
            async with session.client("bedrock-runtime", **client_kwargs) as client:
                event_stream = None
                try:
                    response = await client.converse_stream(**args, **kwargs)
                    event_stream = response.get("stream")
                    if event_stream is None:
                        raise BedrockError(
                            "Invalid streaming response from Bedrock API"
                        )
                    async for event in event_stream:
                        if not event or not isinstance(event, dict):
                            continue
                        if "contentBlockDelta" in event:
                            text = event["contentBlockDelta"].get("delta", {}).get(
                                "text"
                            )
                            if text:
                                yield text
                        elif "messageStop" in event:
                            break
                except Exception as e:
                    _handle_bedrock_exception(e, "Bedrock streaming")
                finally:
                    # Close the event stream once; the client itself is
                    # closed by `async with`.
                    if event_stream is not None:
                        await _close_event_stream(event_stream)

        return stream_generator()

    async with session.client(
        "bedrock-runtime", **client_kwargs
    ) as bedrock_async_client:
        try:
            response = await bedrock_async_client.converse(**args, **kwargs)

            if (
                not response
                or "output" not in response
                or "message" not in response["output"]
                or "content" not in response["output"]["message"]
                or not response["output"]["message"]["content"]
            ):
                raise BedrockError("Invalid response structure from Bedrock API")

            # With thinking/reasoning enabled, the first content block is a
            # `reasoningContent` block and the visible text comes later, so
            # take the first block that carries a text payload.
            content = next(
                (
                    block["text"]
                    for block in response["output"]["message"]["content"]
                    if isinstance(block, dict) and block.get("text")
                ),
                None,
            )
            if not content or not content.strip():
                raise BedrockError("Received empty content from Bedrock API")
            return content
        except Exception as e:
            _handle_bedrock_exception(e, "Bedrock converse")
# Generic Bedrock completion function
async def bedrock_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    entity_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Complete *prompt* with the Bedrock model configured on the LightRAG instance.

    The model id is read from
    ``kwargs["hashing_kv"].global_config["llm_model_name"]``, so ``hashing_kv``
    must be supplied (a missing key raises ``KeyError``). All other arguments
    are forwarded to ``bedrock_complete_if_cache``.

    Returns:
        Whatever ``bedrock_complete_if_cache`` returns: a string, or an async
        iterator of text chunks when streaming.
    """
    # Bedrock Converse API has no JSON mode. The legacy extraction flags are
    # forwarded only so bedrock_complete_if_cache can emit its
    # DeprecationWarnings for them.
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await bedrock_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        # Always forward a list; avoids a shared mutable default argument.
        history_messages=history_messages if history_messages is not None else [],
        keyword_extraction=keyword_extraction,
        entity_extraction=entity_extraction,
        **kwargs,
    )
@wrap_embedding_func_with_attrs(
    embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    aws_region: str | None = None,
    api_key: str | None = None,
    endpoint_url: str | None = None,
) -> np.ndarray:
    """Generate embeddings with Amazon Bedrock Runtime.

    The provider is chosen from the text before the first "." in *model*:

    - ``amazon`` (Titan): model ids containing ``v2`` or ``v1`` are
      supported. One ``invoke_model`` call is made per input text.
    - ``cohere``: all texts are sent in a single ``invoke_model`` call with
      ``input_type="search_document"`` and ``truncate="NONE"``.

    Authentication note:
    - Bedrock does not use LightRAG's generic ``api_key`` fields.
    - ``LLM_BINDING_API_KEY`` and ``EMBEDDING_BINDING_API_KEY`` are ignored for
      Bedrock.
    - To use Bedrock API key / bearer-token auth, set
      ``AWS_BEARER_TOKEN_BEDROCK`` before starting the process; this is a
      process-level AWS SDK setting.
    - For role-specific Bedrock configuration, use explicit SigV4 parameters
      (``aws_access_key_id``, ``aws_secret_access_key``, ``aws_session_token``,
      ``aws_region``). Per-role bearer-token overrides are not supported.

    Returns:
        ``np.ndarray`` with one row per input text.

    Raises:
        BedrockRateLimitError, BedrockConnectionError, BedrockTimeoutError:
            Retried by the decorator, up to 5 attempts.
        BedrockError: Unsupported provider or model, a malformed or empty
            response, or no embeddings generated (e.g. empty *texts* with a
            Titan model).
    """
    if api_key:
        warnings.warn(
            "bedrock_embed(api_key=...) is ignored; use SigV4 parameters or "
            "set AWS_BEARER_TOKEN_BEDROCK before process start.",
            DeprecationWarning,
            stacklevel=2,
        )

    endpoint_url = _normalize_bedrock_endpoint_url(endpoint_url)
    session = aioboto3.Session()
    async with session.client(
        "bedrock-runtime",
        **_bedrock_client_kwargs(
            aws_region,
            endpoint_url,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        ),
    ) as bedrock_async_client:
        try:
            model_provider = model.split(".")[0]

            if model_provider == "amazon":
                # The request shape depends only on the model version, so it
                # is decided once rather than per text.
                if "v2" in model:
                    # "dimensions" is not sent, so Titan v2 uses its default
                    # output size (1024, matching the decorator's embedding_dim).
                    extra_fields: dict[str, Any] = {"embeddingTypes": ["float"]}
                elif "v1" in model:
                    extra_fields = {}
                else:
                    raise BedrockError(f"Model {model} is not supported!")

                embed_texts = []
                for text in texts:
                    try:
                        response = await bedrock_async_client.invoke_model(
                            modelId=model,
                            body=json.dumps({"inputText": text, **extra_fields}),
                            accept="application/json",
                            contentType="application/json",
                        )
                        response_body = await response.get("body").json()

                        if not response_body or "embedding" not in response_body:
                            raise BedrockError(
                                f"Invalid embedding response structure for text: {text[:50]}..."
                            )
                        embedding = response_body["embedding"]
                        if not embedding:
                            raise BedrockError(
                                f"Received empty embedding for text: {text[:50]}..."
                            )
                        embed_texts.append(embedding)
                    except Exception as e:
                        _handle_bedrock_exception(
                            e, "Bedrock embedding (amazon, text chunk)"
                        )

            elif model_provider == "cohere":
                try:
                    response = await bedrock_async_client.invoke_model(
                        # boto's invoke_model parameter is `modelId`; `model`
                        # fails client-side parameter validation.
                        modelId=model,
                        body=json.dumps(
                            {
                                "texts": texts,
                                "input_type": "search_document",
                                "truncate": "NONE",
                            }
                        ),
                        accept="application/json",
                        contentType="application/json",
                    )
                    # aiobotocore's StreamingBody.read() is a coroutine and
                    # must be awaited before decoding.
                    response_body = json.loads(await response.get("body").read())

                    if not response_body or "embeddings" not in response_body:
                        raise BedrockError(
                            "Invalid embedding response structure from Cohere"
                        )
                    embeddings = response_body["embeddings"]
                    if not embeddings or len(embeddings) != len(texts):
                        raise BedrockError(
                            f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
                        )
                    embed_texts = embeddings
                except Exception as e:
                    _handle_bedrock_exception(e, "Bedrock embedding (cohere)")

            else:
                raise BedrockError(
                    f"Model provider '{model_provider}' is not supported!"
                )

            if not embed_texts:
                raise BedrockError("No embeddings generated")
            return np.array(embed_texts)
        except Exception as e:
            # Already-typed BedrockErrors from the branches above pass through
            # unchanged; anything else is converted.
            _handle_bedrock_exception(e, "Bedrock embedding")
|