# lightrag_llamaindex_litellm_opik_demo.py (4.2 KB)
  1. import os
  2. from lightrag import LightRAG, QueryParam
  3. from lightrag.llm.llama_index_impl import (
  4. llama_index_complete_if_cache,
  5. llama_index_embed,
  6. )
  7. from lightrag.utils import EmbeddingFunc
  8. from llama_index.llms.litellm import LiteLLM
  9. from llama_index.embeddings.litellm import LiteLLMEmbedding
  10. import asyncio
  11. import nest_asyncio
  12. nest_asyncio.apply()
  13. # Configure working directory
  14. WORKING_DIR = "./index_default"
  15. print(f"WORKING_DIR: {WORKING_DIR}")
  16. # Model configuration
  17. LLM_MODEL = os.environ.get("LLM_MODEL", "gemma-3-4b")
  18. print(f"LLM_MODEL: {LLM_MODEL}")
  19. EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "arctic-embed")
  20. print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
  21. EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
  22. print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
  23. # LiteLLM configuration
  24. LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
  25. print(f"LITELLM_URL: {LITELLM_URL}")
  26. LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-4JdvGFKqSA3S0k_5p0xufw")
  27. if not os.path.exists(WORKING_DIR):
  28. os.mkdir(WORKING_DIR)
  29. # Initialize LLM function
  30. async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
  31. try:
  32. # Initialize LiteLLM if not in kwargs
  33. if "llm_instance" not in kwargs:
  34. llm_instance = LiteLLM(
  35. model=f"openai/{LLM_MODEL}", # Format: "provider/model_name"
  36. api_base=LITELLM_URL,
  37. api_key=LITELLM_KEY,
  38. temperature=0.7,
  39. )
  40. kwargs["llm_instance"] = llm_instance
  41. chat_kwargs = {}
  42. chat_kwargs["litellm_params"] = {
  43. "metadata": {
  44. "opik": {
  45. "project_name": "lightrag_llamaindex_litellm_opik_demo",
  46. "tags": ["lightrag", "litellm"],
  47. }
  48. }
  49. }
  50. response = await llama_index_complete_if_cache(
  51. kwargs["llm_instance"],
  52. prompt,
  53. system_prompt=system_prompt,
  54. history_messages=history_messages,
  55. chat_kwargs=chat_kwargs,
  56. )
  57. return response
  58. except Exception as e:
  59. print(f"LLM request failed: {str(e)}")
  60. raise
  61. # Initialize embedding function
  62. async def embedding_func(texts):
  63. try:
  64. embed_model = LiteLLMEmbedding(
  65. model_name=f"openai/{EMBEDDING_MODEL}",
  66. api_base=LITELLM_URL,
  67. api_key=LITELLM_KEY,
  68. )
  69. return await llama_index_embed(texts, embed_model=embed_model)
  70. except Exception as e:
  71. print(f"Embedding failed: {str(e)}")
  72. raise
  73. # Get embedding dimension
  74. async def get_embedding_dim():
  75. test_text = ["This is a test sentence."]
  76. embedding = await embedding_func(test_text)
  77. embedding_dim = embedding.shape[1]
  78. print(f"embedding_dim={embedding_dim}")
  79. return embedding_dim
  80. async def initialize_rag():
  81. embedding_dimension = await get_embedding_dim()
  82. rag = LightRAG(
  83. working_dir=WORKING_DIR,
  84. llm_model_func=llm_model_func,
  85. embedding_func=EmbeddingFunc(
  86. embedding_dim=embedding_dimension,
  87. max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
  88. func=embedding_func,
  89. ),
  90. )
  91. await rag.initialize_storages() # Auto-initializes pipeline_status
  92. return rag
  93. def main():
  94. # Initialize RAG instance
  95. rag = asyncio.run(initialize_rag())
  96. # Insert example text
  97. with open("./book.txt", "r", encoding="utf-8") as f:
  98. rag.insert(f.read())
  99. # Test different query modes
  100. print("\nNaive Search:")
  101. print(
  102. rag.query(
  103. "What are the top themes in this story?", param=QueryParam(mode="naive")
  104. )
  105. )
  106. print("\nLocal Search:")
  107. print(
  108. rag.query(
  109. "What are the top themes in this story?", param=QueryParam(mode="local")
  110. )
  111. )
  112. print("\nGlobal Search:")
  113. print(
  114. rag.query(
  115. "What are the top themes in this story?", param=QueryParam(mode="global")
  116. )
  117. )
  118. print("\nHybrid Search:")
  119. print(
  120. rag.query(
  121. "What are the top themes in this story?", param=QueryParam(mode="hybrid")
  122. )
  123. )
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()