# lightrag_llamaindex_litellm_demo.py (3.8 KB)
  1. import os
  2. from lightrag import LightRAG, QueryParam
  3. from lightrag.llm.llama_index_impl import (
  4. llama_index_complete_if_cache,
  5. llama_index_embed,
  6. )
  7. from lightrag.utils import EmbeddingFunc
  8. from llama_index.llms.litellm import LiteLLM
  9. from llama_index.embeddings.litellm import LiteLLMEmbedding
  10. import asyncio
  11. import nest_asyncio
  12. nest_asyncio.apply()
  13. # Configure working directory
  14. WORKING_DIR = "./index_default"
  15. print(f"WORKING_DIR: {WORKING_DIR}")
  16. # Model configuration
  17. LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
  18. print(f"LLM_MODEL: {LLM_MODEL}")
  19. EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
  20. print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
  21. EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
  22. print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
  23. # LiteLLM configuration
  24. LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
  25. print(f"LITELLM_URL: {LITELLM_URL}")
  26. LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-1234")
  27. if not os.path.exists(WORKING_DIR):
  28. os.mkdir(WORKING_DIR)
  29. # Initialize LLM function
  30. async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
  31. try:
  32. # Initialize LiteLLM if not in kwargs
  33. if "llm_instance" not in kwargs:
  34. llm_instance = LiteLLM(
  35. model=f"openai/{LLM_MODEL}", # Format: "provider/model_name"
  36. api_base=LITELLM_URL,
  37. api_key=LITELLM_KEY,
  38. temperature=0.7,
  39. )
  40. kwargs["llm_instance"] = llm_instance
  41. response = await llama_index_complete_if_cache(
  42. kwargs["llm_instance"],
  43. prompt,
  44. system_prompt=system_prompt,
  45. history_messages=history_messages,
  46. )
  47. return response
  48. except Exception as e:
  49. print(f"LLM request failed: {str(e)}")
  50. raise
  51. # Initialize embedding function
  52. async def embedding_func(texts):
  53. try:
  54. embed_model = LiteLLMEmbedding(
  55. model_name=f"openai/{EMBEDDING_MODEL}",
  56. api_base=LITELLM_URL,
  57. api_key=LITELLM_KEY,
  58. )
  59. return await llama_index_embed(texts, embed_model=embed_model)
  60. except Exception as e:
  61. print(f"Embedding failed: {str(e)}")
  62. raise
  63. # Get embedding dimension
  64. async def get_embedding_dim():
  65. test_text = ["This is a test sentence."]
  66. embedding = await embedding_func(test_text)
  67. embedding_dim = embedding.shape[1]
  68. print(f"embedding_dim={embedding_dim}")
  69. return embedding_dim
  70. async def initialize_rag():
  71. embedding_dimension = await get_embedding_dim()
  72. rag = LightRAG(
  73. working_dir=WORKING_DIR,
  74. llm_model_func=llm_model_func,
  75. embedding_func=EmbeddingFunc(
  76. embedding_dim=embedding_dimension,
  77. max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
  78. func=embedding_func,
  79. ),
  80. )
  81. await rag.initialize_storages() # Auto-initializes pipeline_status
  82. return rag
  83. def main():
  84. # Initialize RAG instance
  85. rag = asyncio.run(initialize_rag())
  86. # Insert example text
  87. with open("./book.txt", "r", encoding="utf-8") as f:
  88. rag.insert(f.read())
  89. # Test different query modes
  90. print("\nNaive Search:")
  91. print(
  92. rag.query(
  93. "What are the top themes in this story?", param=QueryParam(mode="naive")
  94. )
  95. )
  96. print("\nLocal Search:")
  97. print(
  98. rag.query(
  99. "What are the top themes in this story?", param=QueryParam(mode="local")
  100. )
  101. )
  102. print("\nGlobal Search:")
  103. print(
  104. rag.query(
  105. "What are the top themes in this story?", param=QueryParam(mode="global")
  106. )
  107. )
  108. print("\nHybrid Search:")
  109. print(
  110. rag.query(
  111. "What are the top themes in this story?", param=QueryParam(mode="hybrid")
  112. )
  113. )
  114. if __name__ == "__main__":
  115. main()