# lightrag_azure_openai_demo.py
import asyncio
import logging
import os

import numpy as np
from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from openai import AsyncAzureOpenAI, AzureOpenAI
  9. logging.basicConfig(level=logging.INFO)
  10. load_dotenv()
  11. AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
  12. AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
  13. AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
  14. AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
  15. AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
  16. AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
  17. WORKING_DIR = "./dickens"
  18. if os.path.exists(WORKING_DIR):
  19. import shutil
  20. shutil.rmtree(WORKING_DIR)
  21. os.mkdir(WORKING_DIR)
  22. async def llm_model_func(
  23. prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
  24. ) -> str:
  25. client = AzureOpenAI(
  26. api_key=AZURE_OPENAI_API_KEY,
  27. api_version=AZURE_OPENAI_API_VERSION,
  28. azure_endpoint=AZURE_OPENAI_ENDPOINT,
  29. )
  30. messages = []
  31. if system_prompt:
  32. messages.append({"role": "system", "content": system_prompt})
  33. if history_messages:
  34. messages.extend(history_messages)
  35. messages.append({"role": "user", "content": prompt})
  36. chat_completion = client.chat.completions.create(
  37. model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name".
  38. messages=messages,
  39. temperature=kwargs.get("temperature", 0),
  40. top_p=kwargs.get("top_p", 1),
  41. n=kwargs.get("n", 1),
  42. )
  43. if not chat_completion.choices or chat_completion.choices[0].message is None:
  44. return ""
  45. return chat_completion.choices[0].message.content
  46. async def embedding_func(texts: list[str]) -> np.ndarray:
  47. client = AzureOpenAI(
  48. api_key=AZURE_OPENAI_API_KEY,
  49. api_version=AZURE_EMBEDDING_API_VERSION,
  50. azure_endpoint=AZURE_OPENAI_ENDPOINT,
  51. )
  52. embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
  53. embeddings = [item.embedding for item in embedding.data]
  54. return np.array(embeddings)
  55. async def test_funcs():
  56. result = await llm_model_func("How are you?")
  57. print("Resposta do llm_model_func: ", result)
  58. result = await embedding_func(["How are you?"])
  59. print("Resultado do embedding_func: ", result.shape)
  60. print("Dimensão da embedding: ", result.shape[1])
  61. asyncio.run(test_funcs())
  62. embedding_dimension = 3072
  63. async def initialize_rag():
  64. rag = LightRAG(
  65. working_dir=WORKING_DIR,
  66. llm_model_func=llm_model_func,
  67. embedding_func=EmbeddingFunc(
  68. embedding_dim=embedding_dimension,
  69. max_token_size=8192,
  70. func=embedding_func,
  71. ),
  72. )
  73. await rag.initialize_storages() # Auto-initializes pipeline_status
  74. return rag
  75. def main():
  76. rag = asyncio.run(initialize_rag())
  77. book1 = open("./book_1.txt", encoding="utf-8")
  78. book2 = open("./book_2.txt", encoding="utf-8")
  79. rag.insert([book1.read(), book2.read()])
  80. query_text = "What are the main themes?"
  81. print("Result (Naive):")
  82. print(rag.query(query_text, param=QueryParam(mode="naive")))
  83. print("\nResult (Local):")
  84. print(rag.query(query_text, param=QueryParam(mode="local")))
  85. print("\nResult (Global):")
  86. print(rag.query(query_text, param=QueryParam(mode="global")))
  87. print("\nResult (Hybrid):")
  88. print(rag.query(query_text, param=QueryParam(mode="hybrid")))
  89. if __name__ == "__main__":
  90. main()