lightrag_vllm_demo.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """
  2. LightRAG Demo with vLLM (LLM, Embeddings, and Reranker)
  3. This example demonstrates how to use LightRAG with:
  4. - vLLM-served LLM (OpenAI-compatible API)
  5. - vLLM-served embedding model
  6. - Jina-compatible reranker (also vLLM-served)
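Prerequisites:
    1. Create a .env file or export the following environment variables.
       Each one falls back to a built-in default aimed at local vLLM servers,
       except RERANK_BINDING_API_KEY, which is simply left unset if missing:
        - LLM_MODEL
        - LLM_BINDING_HOST
        - LLM_BINDING_API_KEY
        - EMBEDDING_MODEL
        - EMBEDDING_BINDING_HOST
        - EMBEDDING_BINDING_API_KEY
        - EMBEDDING_DIM
        - EMBEDDING_TOKEN_LIMIT
        - RERANK_MODEL
        - RERANK_BINDING_HOST
        - RERANK_BINDING_API_KEY
    2. Prepare a text file to index (default: Data/book-small.txt).
    3. Choose storage backends with the KV_STORAGE, DOC_STATUS_STORAGE,
       VECTOR_STORAGE and GRAPH_STORAGE environment variables, or edit the
       storage parameters in initialize_rag() below. The defaults
       (PGKVStorage, PGDocStatusStorage, PGVectorStorage, Neo4JStorage)
       require reachable PostgreSQL and Neo4j instances.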
  23. Usage:
  24. python examples/lightrag_vllm_demo.py
  25. """
  26. import os
  27. import asyncio
  28. from functools import partial
  29. from dotenv import load_dotenv
  30. from lightrag import LightRAG, QueryParam
  31. from lightrag.llm.openai import openai_complete_if_cache, openai_embed
  32. from lightrag.utils import EmbeddingFunc
  33. from lightrag.rerank import jina_rerank
  34. load_dotenv()
  35. # --------------------------------------------------
  36. # Constants
  37. # --------------------------------------------------
  38. WORKING_DIR = "./LightRAG_Data"
  39. BOOK_FILE = "Data/book-small.txt"
  40. # --------------------------------------------------
  41. # LLM function (vLLM, OpenAI-compatible)
  42. # --------------------------------------------------
  43. async def llm_model_func(
  44. prompt, system_prompt=None, history_messages=[], **kwargs
  45. ) -> str:
  46. return await openai_complete_if_cache(
  47. model=os.getenv("LLM_MODEL", "Qwen/Qwen3-14B-AWQ"),
  48. prompt=prompt,
  49. system_prompt=system_prompt,
  50. history_messages=history_messages,
  51. base_url=os.getenv("LLM_BINDING_HOST", "http://0.0.0.0:4646/v1"),
  52. api_key=os.getenv("LLM_BINDING_API_KEY", "not_needed"),
  53. timeout=600,
  54. **kwargs,
  55. )
  56. # --------------------------------------------------
  57. # Embedding function (vLLM)
  58. # --------------------------------------------------
  59. vLLM_emb_func = EmbeddingFunc(
  60. model_name=os.getenv("EMBEDDING_MODEL", "Qwen/Qwen3-Embedding-0.6B"),
  61. send_dimensions=False,
  62. embedding_dim=int(os.getenv("EMBEDDING_DIM", 1024)),
  63. max_token_size=int(os.getenv("EMBEDDING_TOKEN_LIMIT", 4096)),
  64. func=partial(
  65. openai_embed.func,
  66. model=os.getenv("EMBEDDING_MODEL", "Qwen/Qwen3-Embedding-0.6B"),
  67. base_url=os.getenv(
  68. "EMBEDDING_BINDING_HOST",
  69. "http://0.0.0.0:1234/v1",
  70. ),
  71. api_key=os.getenv("EMBEDDING_BINDING_API_KEY", "not_needed"),
  72. ),
  73. )
  74. # --------------------------------------------------
  75. # Reranker (Jina-compatible, vLLM-served)
  76. # --------------------------------------------------
  77. jina_rerank_model_func = partial(
  78. jina_rerank,
  79. model=os.getenv("RERANK_MODEL", "Qwen/Qwen3-Reranker-0.6B"),
  80. api_key=os.getenv("RERANK_BINDING_API_KEY"),
  81. base_url=os.getenv(
  82. "RERANK_BINDING_HOST",
  83. "http://0.0.0.0:3535/v1/rerank",
  84. ),
  85. )
  86. # --------------------------------------------------
  87. # Initialize RAG
  88. # --------------------------------------------------
  89. async def initialize_rag():
  90. rag = LightRAG(
  91. working_dir=WORKING_DIR,
  92. llm_model_func=llm_model_func,
  93. embedding_func=vLLM_emb_func,
  94. rerank_model_func=jina_rerank_model_func,
  95. # Storage backends (configurable via environment or modify here)
  96. kv_storage=os.getenv("KV_STORAGE", "PGKVStorage"),
  97. doc_status_storage=os.getenv("DOC_STATUS_STORAGE", "PGDocStatusStorage"),
  98. vector_storage=os.getenv("VECTOR_STORAGE", "PGVectorStorage"),
  99. graph_storage=os.getenv("GRAPH_STORAGE", "Neo4JStorage"),
  100. )
  101. await rag.initialize_storages()
  102. return rag
  103. # --------------------------------------------------
  104. # Main
  105. # --------------------------------------------------
  106. async def main():
  107. rag = None
  108. try:
  109. # Validate book file exists
  110. if not os.path.exists(BOOK_FILE):
  111. raise FileNotFoundError(
  112. f"'{BOOK_FILE}' not found. Please provide a text file to index."
  113. )
  114. rag = await initialize_rag()
  115. # --------------------------------------------------
  116. # Data Ingestion
  117. # --------------------------------------------------
  118. print(f"Indexing {BOOK_FILE}...")
  119. with open(BOOK_FILE, "r", encoding="utf-8") as f:
  120. await rag.ainsert(f.read())
  121. print("Indexing complete.")
  122. # --------------------------------------------------
  123. # Query
  124. # --------------------------------------------------
  125. query = (
  126. "What are the main themes of the book, and how do the key characters "
  127. "evolve throughout the story?"
  128. )
  129. print("\nHybrid Search with Reranking:")
  130. result = await rag.aquery(
  131. query,
  132. param=QueryParam(
  133. mode="hybrid",
  134. stream=False,
  135. enable_rerank=True,
  136. ),
  137. )
  138. print("\nResult:\n", result)
  139. except Exception as e:
  140. print(f"An error occurred: {e}")
  141. finally:
  142. if rag:
  143. await rag.finalize_storages()
  144. if __name__ == "__main__":
  145. asyncio.run(main())
  146. print("\nDone!")