nvidia_openai.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import sys
  2. import os
  3. if sys.version_info < (3, 9):
  4. pass
  5. else:
  6. pass
  7. import pipmaster as pm # Pipmaster for dynamic library install
  8. # install specific modules
  9. if not pm.is_installed("openai"):
  10. pm.install("openai")
  11. from openai import (
  12. AsyncOpenAI,
  13. APIConnectionError,
  14. RateLimitError,
  15. APITimeoutError,
  16. )
  17. from tenacity import (
  18. retry,
  19. stop_after_attempt,
  20. wait_exponential,
  21. retry_if_exception_type,
  22. )
  23. from lightrag.utils import (
  24. wrap_embedding_func_with_attrs,
  25. )
  26. import numpy as np
  27. @wrap_embedding_func_with_attrs(
  28. embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
  29. )
  30. @retry(
  31. stop=stop_after_attempt(3),
  32. wait=wait_exponential(multiplier=1, min=4, max=60),
  33. retry=retry_if_exception_type(
  34. (RateLimitError, APIConnectionError, APITimeoutError)
  35. ),
  36. )
  37. async def nvidia_openai_embed(
  38. texts: list[str],
  39. model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
  40. # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
  41. base_url: str = "https://integrate.api.nvidia.com/v1",
  42. api_key: str = None,
  43. input_type: str = "passage", # query for retrieval, passage for embedding
  44. trunc: str = "NONE", # NONE or START or END
  45. encode: str = "float", # float or base64
  46. ) -> np.ndarray:
  47. if api_key:
  48. os.environ["OPENAI_API_KEY"] = api_key
  49. openai_async_client = (
  50. AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
  51. )
  52. response = await openai_async_client.embeddings.create(
  53. model=model,
  54. input=texts,
  55. encoding_format=encode,
  56. extra_body={"input_type": input_type, "truncate": trunc},
  57. )
  58. return np.array([dp.embedding for dp in response.data])