lightrag_gemini_demo.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. """
  2. LightRAG Demo with Google Gemini Models
  3. This example demonstrates how to use LightRAG with Google's Gemini 2.0 Flash model
  4. for text generation and the text-embedding-004 model for embeddings.
  5. Prerequisites:
  6. 1. Set GEMINI_API_KEY environment variable:
  7. export GEMINI_API_KEY='your-actual-api-key'
  8. 2. Prepare a text file named 'book.txt' in the current directory
  9. (or modify BOOK_FILE constant to point to your text file)
  10. Usage:
  11. python examples/lightrag_gemini_demo.py
  12. """
  13. import os
  14. import asyncio
  15. import nest_asyncio
  16. import numpy as np
  17. from lightrag import LightRAG, QueryParam
  18. from lightrag.llm.gemini import gemini_model_complete, gemini_embed
  19. from lightrag.utils import wrap_embedding_func_with_attrs
  20. nest_asyncio.apply()
  21. WORKING_DIR = "./rag_storage"
  22. BOOK_FILE = "./book.txt"
  23. # Validate API key
  24. GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
  25. if not GEMINI_API_KEY:
  26. raise ValueError(
  27. "GEMINI_API_KEY environment variable is not set. "
  28. "Please set it with: export GEMINI_API_KEY='your-api-key'"
  29. )
  30. if not os.path.exists(WORKING_DIR):
  31. os.mkdir(WORKING_DIR)
  32. # --------------------------------------------------
  33. # LLM function
  34. # --------------------------------------------------
  35. async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
  36. return await gemini_model_complete(
  37. prompt,
  38. system_prompt=system_prompt,
  39. history_messages=history_messages,
  40. api_key=GEMINI_API_KEY,
  41. model_name="gemini-2.0-flash",
  42. **kwargs,
  43. )
  44. # --------------------------------------------------
  45. # Embedding function
  46. # --------------------------------------------------
  47. @wrap_embedding_func_with_attrs(
  48. embedding_dim=768,
  49. send_dimensions=True,
  50. max_token_size=2048,
  51. model_name="models/text-embedding-004",
  52. )
  53. async def embedding_func(texts: list[str]) -> np.ndarray:
  54. return await gemini_embed.func(
  55. texts, api_key=GEMINI_API_KEY, model="models/text-embedding-004"
  56. )
  57. # --------------------------------------------------
  58. # Initialize RAG
  59. # --------------------------------------------------
  60. async def initialize_rag():
  61. rag = LightRAG(
  62. working_dir=WORKING_DIR,
  63. llm_model_func=llm_model_func,
  64. embedding_func=embedding_func,
  65. llm_model_name="gemini-2.0-flash",
  66. )
  67. # 🔑 REQUIRED
  68. await rag.initialize_storages()
  69. return rag
  70. # --------------------------------------------------
  71. # Main
  72. # --------------------------------------------------
  73. def main():
  74. # Validate book file exists
  75. if not os.path.exists(BOOK_FILE):
  76. raise FileNotFoundError(
  77. f"'{BOOK_FILE}' not found. "
  78. "Please provide a text file to index in the current directory."
  79. )
  80. rag = asyncio.run(initialize_rag())
  81. # Insert text
  82. with open(BOOK_FILE, "r", encoding="utf-8") as f:
  83. rag.insert(f.read())
  84. query = "What are the top themes?"
  85. print("\nNaive Search:")
  86. print(rag.query(query, param=QueryParam(mode="naive")))
  87. print("\nLocal Search:")
  88. print(rag.query(query, param=QueryParam(mode="local")))
  89. print("\nGlobal Search:")
  90. print(rag.query(query, param=QueryParam(mode="global")))
  91. print("\nHybrid Search:")
  92. print(rag.query(query, param=QueryParam(mode="hybrid")))
  93. if __name__ == "__main__":
  94. main()