lightrag_llamaindex_direct_demo.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import os
  2. from lightrag import LightRAG, QueryParam
  3. from lightrag.llm.llama_index_impl import (
  4. llama_index_complete_if_cache,
  5. llama_index_embed,
  6. )
  7. from lightrag.utils import EmbeddingFunc
  8. from llama_index.llms.openai import OpenAI
  9. from llama_index.embeddings.openai import OpenAIEmbedding
  10. import asyncio
  11. import nest_asyncio
  # Patch asyncio so nested event loops are allowed (e.g. inside Jupyter).
  nest_asyncio.apply()
  # Configure working directory
  WORKING_DIR = "./index_default"
  print(f"WORKING_DIR: {WORKING_DIR}")
  # Model configuration; each value can be overridden via an environment variable.
  LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
  print(f"LLM_MODEL: {LLM_MODEL}")
  EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
  print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
  EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
  print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
  # OpenAI configuration. The fallback is a placeholder string, not a usable
  # key: export OPENAI_API_KEY before running the demo.
  OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key-here")
  if not os.path.isdir(WORKING_DIR):
      print(f"Creating working directory: {WORKING_DIR}")
      # makedirs(exist_ok=True) tolerates a concurrent creator and nested
      # paths; if a non-directory occupies the path it raises immediately
      # instead of letting later storage writes fail obscurely.
      os.makedirs(WORKING_DIR, exist_ok=True)
  # Initialize LLM function
  async def llm_model_func(prompt, system_prompt=None, history_messages=None, **kwargs):
      """LightRAG completion callback backed by a LlamaIndex ``OpenAI`` LLM.

      Args:
          prompt: Prompt text to complete.
          system_prompt: Optional system prompt, forwarded unchanged.
          history_messages: Prior chat messages. ``None`` (the default)
              means no history and is normalized to an empty list.
          **kwargs: Extra options forwarded to
              ``llama_index_complete_if_cache``. When ``llm_instance`` is
              missing, a new ``OpenAI`` instance built from ``LLM_MODEL`` and
              ``OPENAI_API_KEY`` (temperature 0.7) is stored under that key.

      Returns:
          The result of ``llama_index_complete_if_cache`` (presumably the
          completion text -- confirm against the installed LightRAG version).

      Raises:
          Any exception from the LLM call, re-raised after being printed.
      """
      # Use a None sentinel instead of a shared mutable `[]` default, so no
      # state can leak between calls if anything downstream mutates the list.
      if history_messages is None:
          history_messages = []
      try:
          # Initialize OpenAI if not in kwargs
          if "llm_instance" not in kwargs:
              llm_instance = OpenAI(
                  model=LLM_MODEL,
                  api_key=OPENAI_API_KEY,
                  temperature=0.7,
              )
              kwargs["llm_instance"] = llm_instance
          response = await llama_index_complete_if_cache(
              kwargs["llm_instance"],
              prompt,
              system_prompt=system_prompt,
              history_messages=history_messages,
              **kwargs,
          )
          return response
      except Exception as e:
          print(f"LLM request failed: {str(e)}")
          raise
  # Initialize embedding function
  async def embedding_func(texts):
      """Embed ``texts`` with a freshly built LlamaIndex ``OpenAIEmbedding``.

      The model name and key come from ``EMBEDDING_MODEL`` and
      ``OPENAI_API_KEY``. Any failure is printed and then re-raised.
      """
      try:
          encoder = OpenAIEmbedding(model=EMBEDDING_MODEL, api_key=OPENAI_API_KEY)
          vectors = await llama_index_embed(texts, embed_model=encoder)
      except Exception as err:
          print(f"Embedding failed: {str(err)}")
          raise
      return vectors
  # Get embedding dimension
  async def get_embedding_dim():
      """Embed one probe sentence, print the vector width and return it.

      Reads ``shape[1]`` of the result, so it assumes ``embedding_func``
      returns a 2-D array-like (one row per input text) -- confirm if the
      embedding backend changes.
      """
      probe = await embedding_func(["This is a test sentence."])
      width = probe.shape[1]
      print(f"embedding_dim={width}")
      return width
  async def initialize_rag():
      """Build a ``LightRAG`` wired to this module's LLM and embedding callbacks.

      The embedding dimension is probed at runtime rather than hard-coded.
      Storages are initialized before the instance is returned.
      """
      dim = await get_embedding_dim()
      embedder = EmbeddingFunc(
          embedding_dim=dim,
          max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
          func=embedding_func,
      )
      instance = LightRAG(
          working_dir=WORKING_DIR,
          llm_model_func=llm_model_func,
          embedding_func=embedder,
      )
      # Auto-initializes pipeline_status as well as the storages.
      await instance.initialize_storages()
      return instance
  # The question asked in every query mode, and the (label, mode) pairs demoed.
  _DEMO_QUESTION = "What are the top themes in this story?"
  _DEMO_MODES = (
      ("Naive", "naive"),
      ("Local", "local"),
      ("Global", "global"),
      ("Hybrid", "hybrid"),
  )


  async def _run_demo():
      """Index ./book.txt and answer the demo question in every query mode.

      Everything runs inside a single event loop. Building the storages under
      one ``asyncio.run()`` call (whose loop is closed on return) and then using
      them through the synchronous ``insert``/``query`` wrappers forces them
      onto a different loop, which can break loop-bound asyncio primitives.
      """
      # Read the input first so a missing file fails before any API request.
      with open("./book.txt", "r", encoding="utf-8") as f:
          text = f.read()

      rag = await initialize_rag()
      try:
          await rag.ainsert(text)
          # Test different query modes
          for label, mode in _DEMO_MODES:
              print(f"\n{label} Search:")
              print(await rag.aquery(_DEMO_QUESTION, param=QueryParam(mode=mode)))
      finally:
          # Counterpart of initialize_storages() in LightRAG's storage lifecycle.
          await rag.finalize_storages()


  def main():
      """Run the LightRAG + LlamaIndex demo end to end in one event loop."""
      asyncio.run(_run_demo())


  if __name__ == "__main__":
      main()