# lightrag_openai_compatible_demo.py
import asyncio
import inspect
import logging
import logging.config
import os
import traceback
from functools import partial

from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_embed
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
  12. load_dotenv(dotenv_path=".env", override=False)
  13. WORKING_DIR = "./dickens"
  14. def configure_logging():
  15. """Configure logging for the application"""
  16. # Reset any existing handlers to ensure clean configuration
  17. for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
  18. logger_instance = logging.getLogger(logger_name)
  19. logger_instance.handlers = []
  20. logger_instance.filters = []
  21. # Get log directory path from environment variable or use current directory
  22. log_dir = os.getenv("LOG_DIR", os.getcwd())
  23. log_file_path = os.path.abspath(
  24. os.path.join(log_dir, "lightrag_compatible_demo.log")
  25. )
  26. print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
  27. os.makedirs(os.path.dirname(log_dir), exist_ok=True)
  28. # Get log file max size and backup count from environment variables
  29. log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
  30. log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
  31. logging.config.dictConfig(
  32. {
  33. "version": 1,
  34. "disable_existing_loggers": False,
  35. "formatters": {
  36. "default": {
  37. "format": "%(levelname)s: %(message)s",
  38. },
  39. "detailed": {
  40. "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  41. },
  42. },
  43. "handlers": {
  44. "console": {
  45. "formatter": "default",
  46. "class": "logging.StreamHandler",
  47. "stream": "ext://sys.stderr",
  48. },
  49. "file": {
  50. "formatter": "detailed",
  51. "class": "logging.handlers.RotatingFileHandler",
  52. "filename": log_file_path,
  53. "maxBytes": log_max_bytes,
  54. "backupCount": log_backup_count,
  55. "encoding": "utf-8",
  56. },
  57. },
  58. "loggers": {
  59. "lightrag": {
  60. "handlers": ["console", "file"],
  61. "level": "INFO",
  62. "propagate": False,
  63. },
  64. },
  65. }
  66. )
  67. # Set the logger level to INFO
  68. logger.setLevel(logging.INFO)
  69. # Enable verbose debug if needed
  70. set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
  71. if not os.path.exists(WORKING_DIR):
  72. os.mkdir(WORKING_DIR)
  73. async def llm_model_func(
  74. prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
  75. ) -> str:
  76. return await openai_complete_if_cache(
  77. os.getenv("LLM_MODEL", "deepseek-chat"),
  78. prompt,
  79. system_prompt=system_prompt,
  80. history_messages=history_messages,
  81. api_key=os.getenv("LLM_BINDING_API_KEY") or os.getenv("OPENAI_API_KEY"),
  82. base_url=os.getenv("LLM_BINDING_HOST", "https://api.deepseek.com"),
  83. **kwargs,
  84. )
  85. async def print_stream(stream):
  86. async for chunk in stream:
  87. if chunk:
  88. print(chunk, end="", flush=True)
  89. async def initialize_rag():
  90. rag = LightRAG(
  91. working_dir=WORKING_DIR,
  92. llm_model_func=llm_model_func,
  93. # Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
  94. # which wraps it in an EmbeddingFunc. Using .func accesses the original
  95. # unwrapped function to avoid double wrapping when we create our own
  96. # EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
  97. embedding_func=EmbeddingFunc(
  98. embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
  99. max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
  100. func=partial(
  101. ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
  102. embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
  103. host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
  104. ),
  105. ),
  106. )
  107. await rag.initialize_storages() # Auto-initializes pipeline_status
  108. return rag
  109. async def main():
  110. try:
  111. # Clear old data files
  112. files_to_delete = [
  113. "graph_chunk_entity_relation.graphml",
  114. "kv_store_doc_status.json",
  115. "kv_store_full_docs.json",
  116. "kv_store_text_chunks.json",
  117. "vdb_chunks.json",
  118. "vdb_entities.json",
  119. "vdb_relationships.json",
  120. ]
  121. for file in files_to_delete:
  122. file_path = os.path.join(WORKING_DIR, file)
  123. if os.path.exists(file_path):
  124. os.remove(file_path)
  125. print(f"Deleting old file:: {file_path}")
  126. # Initialize RAG instance
  127. rag = await initialize_rag()
  128. # Test embedding function
  129. test_text = ["This is a test string for embedding."]
  130. embedding = await rag.embedding_func(test_text)
  131. embedding_dim = embedding.shape[1]
  132. print("\n=======================")
  133. print("Test embedding function")
  134. print("========================")
  135. print(f"Test dict: {test_text}")
  136. print(f"Detected embedding dimension: {embedding_dim}\n\n")
  137. with open("./book.txt", "r", encoding="utf-8") as f:
  138. await rag.ainsert(f.read())
  139. # Perform naive search
  140. print("\n=====================")
  141. print("Query mode: naive")
  142. print("=====================")
  143. resp = await rag.aquery(
  144. "What are the top themes in this story?",
  145. param=QueryParam(mode="naive", stream=True),
  146. )
  147. if inspect.isasyncgen(resp):
  148. await print_stream(resp)
  149. else:
  150. print(resp)
  151. # Perform local search
  152. print("\n=====================")
  153. print("Query mode: local")
  154. print("=====================")
  155. resp = await rag.aquery(
  156. "What are the top themes in this story?",
  157. param=QueryParam(mode="local", stream=True),
  158. )
  159. if inspect.isasyncgen(resp):
  160. await print_stream(resp)
  161. else:
  162. print(resp)
  163. # Perform global search
  164. print("\n=====================")
  165. print("Query mode: global")
  166. print("=====================")
  167. resp = await rag.aquery(
  168. "What are the top themes in this story?",
  169. param=QueryParam(mode="global", stream=True),
  170. )
  171. if inspect.isasyncgen(resp):
  172. await print_stream(resp)
  173. else:
  174. print(resp)
  175. # Perform hybrid search
  176. print("\n=====================")
  177. print("Query mode: hybrid")
  178. print("=====================")
  179. resp = await rag.aquery(
  180. "What are the top themes in this story?",
  181. param=QueryParam(mode="hybrid", stream=True),
  182. )
  183. if inspect.isasyncgen(resp):
  184. await print_stream(resp)
  185. else:
  186. print(resp)
  187. except Exception as e:
  188. print(f"An error occurred: {e}")
  189. finally:
  190. if rag:
  191. await rag.finalize_storages()
  192. if __name__ == "__main__":
  193. # Configure logging before running the main function
  194. configure_logging()
  195. asyncio.run(main())
  196. print("\nDone!")