lightrag_openai_demo.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import os
  2. import asyncio
  3. import logging
  4. import logging.config
  5. from lightrag import LightRAG, QueryParam
  6. from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
  7. from lightrag.utils import logger, set_verbose_debug
  # Working directory handed to LightRAG; main() clears known data files from it on each run.
  WORKING_DIR: str = "./dickens"
  def configure_logging():
      """Configure console and rotating-file logging for the demo.

      Clears handlers and filters already attached to the uvicorn and lightrag
      loggers, then installs a ``dictConfig`` that routes the ``lightrag``
      logger's INFO-and-above records both to stderr and to
      ``<LOG_DIR>/lightrag_demo.log``.

      Environment variables:
          LOG_DIR: directory for the log file (default: current working directory).
          LOG_MAX_BYTES: size in bytes at which the log file rotates
              (default: 10485760, i.e. 10 MB).
          LOG_BACKUP_COUNT: number of rotated files to keep (default: 5).
          VERBOSE_DEBUG: "true" (case-insensitive) enables LightRAG verbose debug.

      Raises:
          ValueError: if LOG_MAX_BYTES or LOG_BACKUP_COUNT is not an integer.
      """
      # Reset any existing handlers/filters so the configuration below starts clean.
      for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
          logger_instance = logging.getLogger(logger_name)
          logger_instance.handlers = []
          logger_instance.filters = []

      # Get log directory path from environment variable or use current directory.
      log_dir = os.getenv("LOG_DIR", os.getcwd())
      log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_demo.log"))
      print(f"\nLightRAG demo log file: {log_file_path}\n")
      # Create the directory that will actually hold the log file. log_file_path is
      # absolute, so its dirname is never empty. Using dirname(log_dir) instead would
      # create only LOG_DIR's parent, and for a bare relative LOG_DIR like "logs"
      # it would be "", which os.makedirs rejects.
      os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

      # Get log file max size and backup count from environment variables.
      log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
      log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

      logging.config.dictConfig(
          {
              "version": 1,
              "disable_existing_loggers": False,
              "formatters": {
                  "default": {
                      "format": "%(levelname)s: %(message)s",
                  },
                  "detailed": {
                      "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                  },
              },
              "handlers": {
                  "console": {
                      "formatter": "default",
                      "class": "logging.StreamHandler",
                      "stream": "ext://sys.stderr",
                  },
                  "file": {
                      "formatter": "detailed",
                      "class": "logging.handlers.RotatingFileHandler",
                      "filename": log_file_path,
                      "maxBytes": log_max_bytes,
                      "backupCount": log_backup_count,
                      "encoding": "utf-8",
                  },
              },
              "loggers": {
                  "lightrag": {
                      "handlers": ["console", "file"],
                      "level": "INFO",
                      # Keep records out of the root logger to avoid duplicate output.
                      "propagate": False,
                  },
              },
          }
      )

      # Set the (LightRAG-provided) logger level to INFO.
      logger.setLevel(logging.INFO)

      # Enable verbose debug if VERBOSE_DEBUG is "true".
      set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
  # Make sure LightRAG's working directory exists at import time.
  # makedirs(exist_ok=True) avoids the race in a separate exists()/mkdir() pair and
  # creates missing parents, so WORKING_DIR may also point at a nested path.
  os.makedirs(WORKING_DIR, exist_ok=True)
  async def initialize_rag():
      """Build the demo's LightRAG instance and initialise its storages.

      The instance stores its data under WORKING_DIR, uses
      ``gpt_4o_mini_complete`` as the LLM function and ``openai_embed`` as the
      embedding function.

      Returns:
          The LightRAG instance, after ``initialize_storages()`` has been awaited.
      """
      instance = LightRAG(
          working_dir=WORKING_DIR,
          llm_model_func=gpt_4o_mini_complete,
          embedding_func=openai_embed,
      )
      # Auto-initializes pipeline_status as well.
      await instance.initialize_storages()
      return instance
  async def main():
      """Run the end-to-end LightRAG + OpenAI demo.

      Steps:
        1. Verify that OPENAI_API_KEY is set; otherwise print instructions and return.
        2. Delete known LightRAG data files left in WORKING_DIR by a previous run.
        3. Build the RAG instance and print the embedding dimension for a sample string.
        4. Insert ./book.txt and print the answer to one question in each query
           mode ("naive", "local", "global", "hybrid").

      Any exception raised after the API-key check is reported rather than
      propagated: the message goes to stdout and the traceback to stderr.
      Storages are finalized whenever the RAG instance was created.
      """
      import traceback

      # Check if OPENAI_API_KEY environment variable exists.
      if not os.getenv("OPENAI_API_KEY"):
          print(
              "Error: OPENAI_API_KEY environment variable is not set. Please set this variable before running the program."
          )
          print("You can set the environment variable by running:")
          print(" export OPENAI_API_KEY='your-openai-api-key'")
          return  # Exit the async function

      # Bind before `try` so `finally` can test it even if a failure happens before
      # initialize_rag() returns. Otherwise `finally` would raise UnboundLocalError
      # and hide the real error.
      rag = None
      try:
          # Clear old data files so the demo starts from an empty store.
          files_to_delete = [
              "graph_chunk_entity_relation.graphml",
              "kv_store_doc_status.json",
              "kv_store_full_docs.json",
              "kv_store_text_chunks.json",
              "vdb_chunks.json",
              "vdb_entities.json",
              "vdb_relationships.json",
          ]
          for file in files_to_delete:
              file_path = os.path.join(WORKING_DIR, file)
              if os.path.exists(file_path):
                  os.remove(file_path)
                  print(f"Deleting old file: {file_path}")

          # Initialize RAG instance.
          rag = await initialize_rag()

          # Test embedding function.
          test_text = ["This is a test string for embedding."]
          embedding = await rag.embedding_func(test_text)
          # Assumes a 2-D array-like with one row per input string -- TODO confirm.
          embedding_dim = embedding.shape[1]
          print("\n=======================")
          print("Test embedding function")
          print("========================")
          print(f"Test dict: {test_text}")
          print(f"Detected embedding dimension: {embedding_dim}\n\n")

          with open("./book.txt", "r", encoding="utf-8") as f:
              await rag.ainsert(f.read())

          # Ask the same question in every query mode, in a fixed order.
          question = "What are the top themes in this story?"
          for mode in ("naive", "local", "global", "hybrid"):
              print("\n=====================")
              print(f"Query mode: {mode}")
              print("=====================")
              print(await rag.aquery(question, param=QueryParam(mode=mode)))
      except Exception as e:
          print(f"An error occurred: {e}")
          # Keep the stack trace; the one-line message alone hides where it failed.
          traceback.print_exc()
      finally:
          if rag is not None:
              await rag.finalize_storages()
  if __name__ == "__main__":
      # Logging is set up first, before the async demo starts.
      configure_logging()
      demo = main()
      # Drive the demo coroutine to completion on a fresh event loop.
      asyncio.run(demo)
      print("\nDone!")