Step_1_openai_compatible.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import json
  3. import time
  4. import asyncio
  5. import numpy as np
  6. from lightrag import LightRAG
  7. from lightrag.utils import EmbeddingFunc
  8. from lightrag.llm.openai import openai_complete_if_cache, openai_embed
## For Upstage API
# Please check that embedding_dim=4096 is set in lightrag.py and llm.py in the lightrag directory.
  11. async def llm_model_func(
  12. prompt, system_prompt=None, history_messages=[], **kwargs
  13. ) -> str:
  14. return await openai_complete_if_cache(
  15. "solar-mini",
  16. prompt,
  17. system_prompt=system_prompt,
  18. history_messages=history_messages,
  19. api_key=os.getenv("UPSTAGE_API_KEY"),
  20. base_url="https://api.upstage.ai/v1/solar",
  21. **kwargs,
  22. )
  23. async def embedding_func(texts: list[str]) -> np.ndarray:
  24. return await openai_embed(
  25. texts,
  26. model="solar-embedding-1-large-query",
  27. api_key=os.getenv("UPSTAGE_API_KEY"),
  28. base_url="https://api.upstage.ai/v1/solar",
  29. )
  30. ## /For Upstage API
  31. def insert_text(rag, file_path):
  32. with open(file_path, mode="r") as f:
  33. unique_contexts = json.load(f)
  34. retries = 0
  35. max_retries = 3
  36. while retries < max_retries:
  37. try:
  38. rag.insert(unique_contexts)
  39. break
  40. except Exception as e:
  41. retries += 1
  42. print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
  43. time.sleep(10)
  44. if retries == max_retries:
  45. print("Insertion failed after exceeding the maximum number of retries")
  46. cls = "mix"
  47. WORKING_DIR = f"../{cls}"
  48. if not os.path.exists(WORKING_DIR):
  49. os.mkdir(WORKING_DIR)
  50. async def initialize_rag():
  51. rag = LightRAG(
  52. working_dir=WORKING_DIR,
  53. llm_model_func=llm_model_func,
  54. embedding_func=EmbeddingFunc(embedding_dim=4096, func=embedding_func),
  55. )
  56. await rag.initialize_storages() # Auto-initializes pipeline_status
  57. return rag
  58. def main():
  59. # Initialize RAG instance
  60. rag = asyncio.run(initialize_rag())
  61. insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
  62. if __name__ == "__main__":
  63. main()