lightrag_ollama_demo.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import asyncio
  2. import os
  3. import inspect
  4. import logging
  5. import logging.config
  6. from functools import partial
  7. from lightrag import LightRAG, QueryParam
  8. from lightrag.llm.ollama import ollama_model_complete, ollama_embed
  9. from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
  10. from dotenv import load_dotenv
  11. load_dotenv(dotenv_path=".env", override=False)
  12. WORKING_DIR = "./dickens"
  13. def configure_logging():
  14. """Configure logging for the application"""
  15. # Reset any existing handlers to ensure clean configuration
  16. for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
  17. logger_instance = logging.getLogger(logger_name)
  18. logger_instance.handlers = []
  19. logger_instance.filters = []
  20. # Get log directory path from environment variable or use current directory
  21. log_dir = os.getenv("LOG_DIR", os.getcwd())
  22. log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_ollama_demo.log"))
  23. print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
  24. os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
  25. # Get log file max size and backup count from environment variables
  26. log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
  27. log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
  28. logging.config.dictConfig(
  29. {
  30. "version": 1,
  31. "disable_existing_loggers": False,
  32. "formatters": {
  33. "default": {
  34. "format": "%(levelname)s: %(message)s",
  35. },
  36. "detailed": {
  37. "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  38. },
  39. },
  40. "handlers": {
  41. "console": {
  42. "formatter": "default",
  43. "class": "logging.StreamHandler",
  44. "stream": "ext://sys.stderr",
  45. },
  46. "file": {
  47. "formatter": "detailed",
  48. "class": "logging.handlers.RotatingFileHandler",
  49. "filename": log_file_path,
  50. "maxBytes": log_max_bytes,
  51. "backupCount": log_backup_count,
  52. "encoding": "utf-8",
  53. },
  54. },
  55. "loggers": {
  56. "lightrag": {
  57. "handlers": ["console", "file"],
  58. "level": "INFO",
  59. "propagate": False,
  60. },
  61. },
  62. }
  63. )
  64. # Set the logger level to INFO
  65. logger.setLevel(logging.INFO)
  66. # Enable verbose debug if needed
  67. set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
  68. if not os.path.exists(WORKING_DIR):
  69. os.mkdir(WORKING_DIR)
  70. async def initialize_rag():
  71. rag = LightRAG(
  72. working_dir=WORKING_DIR,
  73. llm_model_func=ollama_model_complete,
  74. llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"),
  75. summary_max_tokens=8192,
  76. llm_model_kwargs={
  77. "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
  78. "options": {"num_ctx": 8192},
  79. "timeout": int(os.getenv("TIMEOUT", "300")),
  80. },
  81. # Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
  82. # which wraps it in an EmbeddingFunc. Using .func accesses the original
  83. # unwrapped function to avoid double wrapping when we create our own
  84. # EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
  85. embedding_func=EmbeddingFunc(
  86. embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
  87. max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
  88. func=partial(
  89. ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
  90. embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
  91. host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
  92. ),
  93. ),
  94. )
  95. await rag.initialize_storages() # Auto-initializes pipeline_status
  96. return rag
  97. async def print_stream(stream):
  98. async for chunk in stream:
  99. print(chunk, end="", flush=True)
  100. async def main():
  101. try:
  102. # Clear old data files
  103. files_to_delete = [
  104. "graph_chunk_entity_relation.graphml",
  105. "kv_store_doc_status.json",
  106. "kv_store_full_docs.json",
  107. "kv_store_text_chunks.json",
  108. "vdb_chunks.json",
  109. "vdb_entities.json",
  110. "vdb_relationships.json",
  111. ]
  112. for file in files_to_delete:
  113. file_path = os.path.join(WORKING_DIR, file)
  114. if os.path.exists(file_path):
  115. os.remove(file_path)
  116. print(f"Deleting old file:: {file_path}")
  117. # Initialize RAG instance
  118. rag = await initialize_rag()
  119. # Test embedding function
  120. test_text = ["This is a test string for embedding."]
  121. embedding = await rag.embedding_func(test_text)
  122. embedding_dim = embedding.shape[1]
  123. print("\n=======================")
  124. print("Test embedding function")
  125. print("========================")
  126. print(f"Test dict: {test_text}")
  127. print(f"Detected embedding dimension: {embedding_dim}\n\n")
  128. with open("./book.txt", "r", encoding="utf-8") as f:
  129. await rag.ainsert(f.read())
  130. # Perform naive search
  131. print("\n=====================")
  132. print("Query mode: naive")
  133. print("=====================")
  134. resp = await rag.aquery(
  135. "What are the top themes in this story?",
  136. param=QueryParam(mode="naive", stream=True),
  137. )
  138. if inspect.isasyncgen(resp):
  139. await print_stream(resp)
  140. else:
  141. print(resp)
  142. # Perform local search
  143. print("\n=====================")
  144. print("Query mode: local")
  145. print("=====================")
  146. resp = await rag.aquery(
  147. "What are the top themes in this story?",
  148. param=QueryParam(mode="local", stream=True),
  149. )
  150. if inspect.isasyncgen(resp):
  151. await print_stream(resp)
  152. else:
  153. print(resp)
  154. # Perform global search
  155. print("\n=====================")
  156. print("Query mode: global")
  157. print("=====================")
  158. resp = await rag.aquery(
  159. "What are the top themes in this story?",
  160. param=QueryParam(mode="global", stream=True),
  161. )
  162. if inspect.isasyncgen(resp):
  163. await print_stream(resp)
  164. else:
  165. print(resp)
  166. # Perform hybrid search
  167. print("\n=====================")
  168. print("Query mode: hybrid")
  169. print("=====================")
  170. resp = await rag.aquery(
  171. "What are the top themes in this story?",
  172. param=QueryParam(mode="hybrid", stream=True),
  173. )
  174. if inspect.isasyncgen(resp):
  175. await print_stream(resp)
  176. else:
  177. print(resp)
  178. except Exception as e:
  179. print(f"An error occurred: {e}")
  180. finally:
  181. if rag:
  182. await rag.llm_response_cache.index_done_callback()
  183. await rag.finalize_storages()
  184. if __name__ == "__main__":
  185. # Configure logging before running the main function
  186. configure_logging()
  187. asyncio.run(main())
  188. print("\nDone!")