# bedrock.py - Amazon Bedrock chat-completion and embedding bindings.
import copy
import inspect
import json
import logging
import warnings
from collections.abc import AsyncIterator
from typing import Any, NoReturn, Union

import pipmaster as pm  # Pipmaster for dynamic library install

# aioboto3 is installed on demand, so this check must run before the import.
if not pm.is_installed("aioboto3"):
    pm.install("aioboto3")

import aioboto3
import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import wrap_embedding_func_with_attrs

# Import botocore exceptions for proper exception handling
try:
    from botocore.exceptions import (
        ClientError,
        ConnectionError as BotocoreConnectionError,
        ReadTimeoutError,
    )
except ImportError:
    # If botocore is not installed, define placeholders
    ClientError = Exception
    BotocoreConnectionError = Exception
    ReadTimeoutError = Exception
  32. class BedrockError(Exception):
  33. """Generic error for issues related to Amazon Bedrock"""
  34. class BedrockRateLimitError(BedrockError):
  35. """Error for rate limiting and throttling issues"""
  36. class BedrockConnectionError(BedrockError):
  37. """Error for network and connection issues"""
  38. class BedrockTimeoutError(BedrockError):
  39. """Error for timeout issues"""
  40. def _normalize_bedrock_endpoint_url(endpoint_url: str | None) -> str | None:
  41. """Return a usable Bedrock endpoint override or None for SDK defaults."""
  42. if endpoint_url is None:
  43. return None
  44. normalized = endpoint_url.strip()
  45. if not normalized or normalized == "DEFAULT_BEDROCK_ENDPOINT":
  46. return None
  47. return normalized
  48. def _bedrock_client_kwargs(
  49. region: str | None,
  50. endpoint_url: str | None,
  51. aws_access_key_id: str | None = None,
  52. aws_secret_access_key: str | None = None,
  53. aws_session_token: str | None = None,
  54. ) -> dict:
  55. """Build kwargs for aioboto3 ``session.client("bedrock-runtime", ...)``."""
  56. client_kwargs: dict = {"region_name": region}
  57. if endpoint_url is not None:
  58. client_kwargs["endpoint_url"] = endpoint_url
  59. if aws_access_key_id:
  60. client_kwargs["aws_access_key_id"] = aws_access_key_id
  61. if aws_secret_access_key:
  62. client_kwargs["aws_secret_access_key"] = aws_secret_access_key
  63. if aws_session_token:
  64. client_kwargs["aws_session_token"] = aws_session_token
  65. return client_kwargs
  66. def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
  67. """Convert AWS Bedrock exceptions to appropriate custom exceptions.
  68. Args:
  69. e: The exception to handle
  70. operation: Description of the operation for error messages
  71. Raises:
  72. BedrockRateLimitError: For rate limiting and throttling issues (retryable)
  73. BedrockConnectionError: For network and server issues (retryable)
  74. BedrockTimeoutError: For timeout issues (retryable)
  75. BedrockError: For validation and other non-retryable errors
  76. """
  77. error_message = str(e)
  78. # Handle botocore ClientError with specific error codes
  79. if isinstance(e, ClientError):
  80. error_code = e.response.get("Error", {}).get("Code", "")
  81. error_msg = e.response.get("Error", {}).get("Message", error_message)
  82. # Rate limiting and throttling errors (retryable)
  83. if error_code in [
  84. "ThrottlingException",
  85. "ProvisionedThroughputExceededException",
  86. ]:
  87. logging.error(f"{operation} rate limit error: {error_msg}")
  88. raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
  89. # Server errors (retryable)
  90. elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
  91. logging.error(f"{operation} connection error: {error_msg}")
  92. raise BedrockConnectionError(f"Service error: {error_msg}")
  93. # Check for 5xx HTTP status codes (retryable)
  94. elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
  95. logging.error(f"{operation} server error: {error_msg}")
  96. raise BedrockConnectionError(f"Server error: {error_msg}")
  97. # Validation and other client errors (non-retryable)
  98. else:
  99. logging.error(f"{operation} client error: {error_msg}")
  100. raise BedrockError(f"Client error: {error_msg}")
  101. # Connection errors (retryable)
  102. elif isinstance(e, BotocoreConnectionError):
  103. logging.error(f"{operation} connection error: {error_message}")
  104. raise BedrockConnectionError(f"Connection error: {error_message}")
  105. # Timeout errors (retryable)
  106. elif isinstance(e, (ReadTimeoutError, TimeoutError)):
  107. logging.error(f"{operation} timeout error: {error_message}")
  108. raise BedrockTimeoutError(f"Timeout error: {error_message}")
  109. # Custom Bedrock errors (already properly typed)
  110. elif isinstance(
  111. e,
  112. (
  113. BedrockRateLimitError,
  114. BedrockConnectionError,
  115. BedrockTimeoutError,
  116. BedrockError,
  117. ),
  118. ):
  119. raise
  120. # Unknown errors (non-retryable)
  121. else:
  122. logging.error(f"{operation} unexpected error: {error_message}")
  123. raise BedrockError(f"Unexpected error: {error_message}")
  124. @retry(
  125. stop=stop_after_attempt(5),
  126. wait=wait_exponential(multiplier=1, min=4, max=60),
  127. retry=(
  128. retry_if_exception_type(BedrockRateLimitError)
  129. | retry_if_exception_type(BedrockConnectionError)
  130. | retry_if_exception_type(BedrockTimeoutError)
  131. ),
  132. )
  133. async def bedrock_complete_if_cache(
  134. model,
  135. prompt,
  136. system_prompt=None,
  137. history_messages=[],
  138. enable_cot: bool = False,
  139. aws_access_key_id=None,
  140. aws_secret_access_key=None,
  141. aws_session_token=None,
  142. aws_region: str | None = None,
  143. api_key: str | None = None,
  144. endpoint_url: str | None = None,
  145. image_inputs: list[Any] | None = None,
  146. **kwargs,
  147. ) -> Union[str, AsyncIterator[str]]:
  148. """Call Amazon Bedrock Converse API with LightRAG-compatible shims.
  149. Structured output note:
  150. - This adapter does not support OpenAI-style ``response_format`` JSON mode.
  151. - If callers pass ``response_format``, it is stripped before the request.
  152. - Deprecated ``keyword_extraction`` and ``entity_extraction`` booleans are
  153. accepted only as compatibility shims; they emit warnings and are ignored.
  154. Authentication note:
  155. - Bedrock does not use LightRAG's generic ``api_key`` fields.
  156. - ``LLM_BINDING_API_KEY`` and ``EMBEDDING_BINDING_API_KEY`` are ignored for
  157. Bedrock.
  158. - To use Bedrock API key / bearer-token auth, set
  159. ``AWS_BEARER_TOKEN_BEDROCK`` before starting the process; this is a
  160. process-level AWS SDK setting.
  161. - For role-specific Bedrock LLMs, use explicit SigV4 parameters
  162. (``aws_access_key_id``, ``aws_secret_access_key``, ``aws_session_token``,
  163. ``aws_region``). Per-role bearer-token overrides are not supported.
  164. Endpoint note:
  165. - ``endpoint_url`` overrides the default regional Bedrock endpoint. Pass
  166. ``None``, an empty string, or the sentinel ``DEFAULT_BEDROCK_ENDPOINT``
  167. to let the AWS SDK select its default endpoint.
  168. """
  169. if enable_cot:
  170. logging.debug(
  171. "enable_cot=True is not supported for Bedrock and will be ignored."
  172. )
  173. # Bedrock Converse API has no JSON mode; drop legacy extraction flags and
  174. # response_format below and rely on the prompt template plus downstream
  175. # tolerant JSON parsing.
  176. keyword_extraction = kwargs.pop("keyword_extraction", False)
  177. entity_extraction = kwargs.pop("entity_extraction", False)
  178. if keyword_extraction:
  179. warnings.warn(
  180. "bedrock_complete_if_cache(keyword_extraction=True) is deprecated; "
  181. "pass response_format={'type': 'json_object'} instead.",
  182. DeprecationWarning,
  183. stacklevel=2,
  184. )
  185. if entity_extraction:
  186. warnings.warn(
  187. "bedrock_complete_if_cache(entity_extraction=True) is deprecated; "
  188. "pass response_format={'type': 'json_object'} instead.",
  189. DeprecationWarning,
  190. stacklevel=2,
  191. )
  192. if api_key:
  193. warnings.warn(
  194. "bedrock_complete_if_cache(api_key=...) is ignored; use SigV4 "
  195. "parameters or set AWS_BEARER_TOKEN_BEDROCK before process start.",
  196. DeprecationWarning,
  197. stacklevel=2,
  198. )
  199. region = aws_region or kwargs.pop("aws_region", None)
  200. endpoint_url = _normalize_bedrock_endpoint_url(endpoint_url)
  201. kwargs.pop("hashing_kv", None)
  202. # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
  203. # We'll use this to determine whether to call converse_stream or converse
  204. stream = bool(kwargs.pop("stream", False))
  205. # Remove unsupported args for Bedrock Converse API
  206. for k in [
  207. "response_format",
  208. "tools",
  209. "tool_choice",
  210. "seed",
  211. "presence_penalty",
  212. "frequency_penalty",
  213. "n",
  214. "logprobs",
  215. "top_logprobs",
  216. "max_completion_tokens",
  217. ]:
  218. kwargs.pop(k, None)
  219. # Fix message history format
  220. messages = []
  221. for history_message in history_messages:
  222. message = copy.copy(history_message)
  223. message["content"] = [{"text": message["content"]}]
  224. messages.append(message)
  225. # Add user prompt
  226. if image_inputs:
  227. from lightrag.llm._vision_utils import normalize_image_inputs
  228. normalized_images = normalize_image_inputs(image_inputs)
  229. user_content: list[dict[str, Any]] = [{"text": prompt}]
  230. for img in normalized_images:
  231. fmt = img.mime_type.split("/", 1)[1] if "/" in img.mime_type else "png"
  232. user_content.append(
  233. {"image": {"format": fmt, "source": {"bytes": img.raw_bytes}}}
  234. )
  235. messages.append({"role": "user", "content": user_content})
  236. if stream:
  237. logging.getLogger(__name__).debug(
  238. "[bedrock] image_inputs provided; forcing non-stream Converse "
  239. "(stream + image combination has SDK limitations)"
  240. )
  241. stream = False
  242. else:
  243. messages.append({"role": "user", "content": [{"text": prompt}]})
  244. # Initialize Converse API arguments
  245. args = {"modelId": model, "messages": messages}
  246. # Define system prompt
  247. if system_prompt:
  248. args["system"] = [{"text": system_prompt}]
  249. # Map and set up inference parameters
  250. inference_params_map = {
  251. "max_tokens": "maxTokens",
  252. "top_p": "topP",
  253. "stop_sequences": "stopSequences",
  254. }
  255. inference_config: dict[str, Any] = {}
  256. for param in ("max_tokens", "temperature", "top_p", "stop_sequences"):
  257. if param not in kwargs:
  258. continue
  259. value = kwargs.pop(param)
  260. # Bedrock rejects None; a None default means "inherit provider default"
  261. if value is None:
  262. continue
  263. inference_config[inference_params_map.get(param, param)] = value
  264. if inference_config:
  265. args["inferenceConfig"] = inference_config
  266. # Pass-through for model-specific parameters (e.g. Anthropic reasoning_config,
  267. # Nova inferenceConfig extensions). Mirrors OpenAI's `extra_body`.
  268. extra_fields = kwargs.pop("extra_fields", None)
  269. if extra_fields:
  270. args["additionalModelRequestFields"] = extra_fields
  271. # For streaming responses, we need a different approach to keep the connection open
  272. if stream:
  273. # Create a session that will be used throughout the streaming process
  274. session = aioboto3.Session()
  275. client_kwargs = _bedrock_client_kwargs(
  276. region,
  277. endpoint_url,
  278. aws_access_key_id=aws_access_key_id,
  279. aws_secret_access_key=aws_secret_access_key,
  280. aws_session_token=aws_session_token,
  281. )
  282. # Define the generator function that will manage the client lifecycle
  283. async def stream_generator():
  284. # async with ensures the aioboto3 client is closed even under
  285. # task cancellation, avoiding aiohttp "Unclosed connection" warnings.
  286. async with session.client("bedrock-runtime", **client_kwargs) as client:
  287. event_stream = None
  288. try:
  289. # Make the API call
  290. response = await client.converse_stream(**args, **kwargs)
  291. event_stream = response.get("stream")
  292. # Process the stream
  293. async for event in event_stream:
  294. # Validate event structure
  295. if not event or not isinstance(event, dict):
  296. continue
  297. if "contentBlockDelta" in event:
  298. delta = event["contentBlockDelta"].get("delta", {})
  299. text = delta.get("text")
  300. if text:
  301. yield text
  302. # Handle other event types that might indicate stream end
  303. elif "messageStop" in event:
  304. break
  305. except Exception as e:
  306. # Convert to appropriate exception type
  307. _handle_bedrock_exception(e, "Bedrock streaming")
  308. finally:
  309. # Close the event stream once; client cleanup is handled by async with.
  310. # aiobotocore's EventStream exposes sync `close()`, while generic
  311. # async iterators expose async `aclose()` — handle both and dispatch
  312. # awaitable results accordingly.
  313. if event_stream is not None:
  314. close_fn = getattr(event_stream, "close", None) or getattr(
  315. event_stream, "aclose", None
  316. )
  317. if callable(close_fn):
  318. try:
  319. result = close_fn()
  320. if inspect.isawaitable(result):
  321. await result
  322. except Exception as close_error:
  323. logging.warning(
  324. f"Failed to close Bedrock event stream: {close_error}"
  325. )
  326. # Return the generator that manages its own lifecycle
  327. return stream_generator()
  328. # For non-streaming responses, use the standard async context manager pattern
  329. session = aioboto3.Session()
  330. async with session.client(
  331. "bedrock-runtime",
  332. **_bedrock_client_kwargs(
  333. region,
  334. endpoint_url,
  335. aws_access_key_id=aws_access_key_id,
  336. aws_secret_access_key=aws_secret_access_key,
  337. aws_session_token=aws_session_token,
  338. ),
  339. ) as bedrock_async_client:
  340. try:
  341. # Use converse for non-streaming responses
  342. response = await bedrock_async_client.converse(**args, **kwargs)
  343. # Validate response structure
  344. if (
  345. not response
  346. or "output" not in response
  347. or "message" not in response["output"]
  348. or "content" not in response["output"]["message"]
  349. or not response["output"]["message"]["content"]
  350. ):
  351. raise BedrockError("Invalid response structure from Bedrock API")
  352. # When thinking/reasoning is enabled, the first content block is a
  353. # `reasoningContent` block and the visible text follows in a later
  354. # block. Pick the first block that carries a text payload.
  355. content = next(
  356. (
  357. block["text"]
  358. for block in response["output"]["message"]["content"]
  359. if isinstance(block, dict) and block.get("text")
  360. ),
  361. None,
  362. )
  363. if not content or content.strip() == "":
  364. raise BedrockError("Received empty content from Bedrock API")
  365. return content
  366. except Exception as e:
  367. # Convert to appropriate exception type
  368. _handle_bedrock_exception(e, "Bedrock converse")
  369. # Generic Bedrock completion function
  370. async def bedrock_complete(
  371. prompt,
  372. system_prompt=None,
  373. history_messages=[],
  374. keyword_extraction=False,
  375. entity_extraction=False,
  376. **kwargs,
  377. ) -> Union[str, AsyncIterator[str]]:
  378. # Bedrock Converse API has no JSON mode; the shim booleans are absorbed
  379. # and forwarded so bedrock_complete_if_cache can emit DeprecationWarnings
  380. # with accurate stack frames.
  381. model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
  382. result = await bedrock_complete_if_cache(
  383. model_name,
  384. prompt,
  385. system_prompt=system_prompt,
  386. history_messages=history_messages,
  387. keyword_extraction=keyword_extraction,
  388. entity_extraction=entity_extraction,
  389. **kwargs,
  390. )
  391. return result
  392. @wrap_embedding_func_with_attrs(
  393. embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
  394. )
  395. @retry(
  396. stop=stop_after_attempt(5),
  397. wait=wait_exponential(multiplier=1, min=4, max=60),
  398. retry=(
  399. retry_if_exception_type(BedrockRateLimitError)
  400. | retry_if_exception_type(BedrockConnectionError)
  401. | retry_if_exception_type(BedrockTimeoutError)
  402. ),
  403. )
  404. async def bedrock_embed(
  405. texts: list[str],
  406. model: str = "amazon.titan-embed-text-v2:0",
  407. aws_access_key_id=None,
  408. aws_secret_access_key=None,
  409. aws_session_token=None,
  410. aws_region: str | None = None,
  411. api_key: str | None = None,
  412. endpoint_url: str | None = None,
  413. ) -> np.ndarray:
  414. """Generate embeddings with Amazon Bedrock Runtime.
  415. Authentication note:
  416. - Bedrock does not use LightRAG's generic ``api_key`` fields.
  417. - ``LLM_BINDING_API_KEY`` and ``EMBEDDING_BINDING_API_KEY`` are ignored for
  418. Bedrock.
  419. - To use Bedrock API key / bearer-token auth, set
  420. ``AWS_BEARER_TOKEN_BEDROCK`` before starting the process; this is a
  421. process-level AWS SDK setting.
  422. - For role-specific Bedrock configuration, use explicit SigV4 parameters
  423. (``aws_access_key_id``, ``aws_secret_access_key``, ``aws_session_token``,
  424. ``aws_region``). Per-role bearer-token overrides are not supported.
  425. """
  426. if api_key:
  427. warnings.warn(
  428. "bedrock_embed(api_key=...) is ignored; use SigV4 parameters or "
  429. "set AWS_BEARER_TOKEN_BEDROCK before process start.",
  430. DeprecationWarning,
  431. stacklevel=2,
  432. )
  433. region = aws_region
  434. endpoint_url = _normalize_bedrock_endpoint_url(endpoint_url)
  435. session = aioboto3.Session()
  436. async with session.client(
  437. "bedrock-runtime",
  438. **_bedrock_client_kwargs(
  439. region,
  440. endpoint_url,
  441. aws_access_key_id=aws_access_key_id,
  442. aws_secret_access_key=aws_secret_access_key,
  443. aws_session_token=aws_session_token,
  444. ),
  445. ) as bedrock_async_client:
  446. try:
  447. if (model_provider := model.split(".")[0]) == "amazon":
  448. embed_texts = []
  449. for text in texts:
  450. try:
  451. if "v2" in model:
  452. body = json.dumps(
  453. {
  454. "inputText": text,
  455. # 'dimensions': embedding_dim,
  456. "embeddingTypes": ["float"],
  457. }
  458. )
  459. elif "v1" in model:
  460. body = json.dumps({"inputText": text})
  461. else:
  462. raise BedrockError(f"Model {model} is not supported!")
  463. response = await bedrock_async_client.invoke_model(
  464. modelId=model,
  465. body=body,
  466. accept="application/json",
  467. contentType="application/json",
  468. )
  469. response_body = await response.get("body").json()
  470. # Validate response structure
  471. if not response_body or "embedding" not in response_body:
  472. raise BedrockError(
  473. f"Invalid embedding response structure for text: {text[:50]}..."
  474. )
  475. embedding = response_body["embedding"]
  476. if not embedding:
  477. raise BedrockError(
  478. f"Received empty embedding for text: {text[:50]}..."
  479. )
  480. embed_texts.append(embedding)
  481. except Exception as e:
  482. # Convert to appropriate exception type
  483. _handle_bedrock_exception(
  484. e, "Bedrock embedding (amazon, text chunk)"
  485. )
  486. elif model_provider == "cohere":
  487. try:
  488. body = json.dumps(
  489. {
  490. "texts": texts,
  491. "input_type": "search_document",
  492. "truncate": "NONE",
  493. }
  494. )
  495. response = await bedrock_async_client.invoke_model(
  496. model=model,
  497. body=body,
  498. accept="application/json",
  499. contentType="application/json",
  500. )
  501. response_body = json.loads(response.get("body").read())
  502. # Validate response structure
  503. if not response_body or "embeddings" not in response_body:
  504. raise BedrockError(
  505. "Invalid embedding response structure from Cohere"
  506. )
  507. embeddings = response_body["embeddings"]
  508. if not embeddings or len(embeddings) != len(texts):
  509. raise BedrockError(
  510. f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
  511. )
  512. embed_texts = embeddings
  513. except Exception as e:
  514. # Convert to appropriate exception type
  515. _handle_bedrock_exception(e, "Bedrock embedding (cohere)")
  516. else:
  517. raise BedrockError(
  518. f"Model provider '{model_provider}' is not supported!"
  519. )
  520. # Final validation
  521. if not embed_texts:
  522. raise BedrockError("No embeddings generated")
  523. return np.array(embed_texts)
  524. except Exception as e:
  525. # Convert to appropriate exception type
  526. _handle_bedrock_exception(e, "Bedrock embedding")