lightrag_embedding_prefixes.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import os
  2. import asyncio
  3. import inspect
  4. import logging
  5. import logging.config
  6. from functools import partial
  7. from lightrag import LightRAG, QueryParam
  8. from lightrag.llm.openai import openai_complete_if_cache
  9. from lightrag.llm.ollama import ollama_embed
  10. from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
  11. from dotenv import load_dotenv
  12. load_dotenv(dotenv_path=".env", override=False)
  13. WORKING_DIR = "./dickens"
  14. def configure_logging():
  15. """Configure logging for the application"""
  16. # Reset any existing handlers to ensure clean configuration
  17. for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
  18. logger_instance = logging.getLogger(logger_name)
  19. logger_instance.handlers = []
  20. logger_instance.filters = []
  21. # Get log directory path from environment variable or use current directory
  22. log_dir = os.getenv("LOG_DIR", os.getcwd())
  23. log_file_path = os.path.abspath(
  24. os.path.join(log_dir, "lightrag_compatible_demo.log")
  25. )
  26. print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
  27. os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
  28. # Get log file max size and backup count from environment variables
  29. log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
  30. log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
  31. logging.config.dictConfig(
  32. {
  33. "version": 1,
  34. "disable_existing_loggers": False,
  35. "formatters": {
  36. "default": {
  37. "format": "%(levelname)s: %(message)s",
  38. },
  39. "detailed": {
  40. "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  41. },
  42. },
  43. "handlers": {
  44. "console": {
  45. "formatter": "default",
  46. "class": "logging.StreamHandler",
  47. "stream": "ext://sys.stderr",
  48. },
  49. "file": {
  50. "formatter": "detailed",
  51. "class": "logging.handlers.RotatingFileHandler",
  52. "filename": log_file_path,
  53. "maxBytes": log_max_bytes,
  54. "backupCount": log_backup_count,
  55. "encoding": "utf-8",
  56. },
  57. },
  58. "loggers": {
  59. "lightrag": {
  60. "handlers": ["console", "file"],
  61. "level": "INFO",
  62. "propagate": False,
  63. },
  64. },
  65. }
  66. )
  67. # Set the logger level to INFO
  68. logger.setLevel(logging.INFO)
  69. # Enable verbose debug if needed
  70. set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
  71. if not os.path.exists(WORKING_DIR):
  72. os.mkdir(WORKING_DIR)
  73. async def llm_model_func(
  74. prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
  75. ) -> str:
  76. return await openai_complete_if_cache(
  77. os.getenv("LLM_MODEL", "deepseek-chat"),
  78. prompt,
  79. system_prompt=system_prompt,
  80. history_messages=history_messages,
  81. api_key=os.getenv("LLM_BINDING_API_KEY") or os.getenv("OPENAI_API_KEY"),
  82. base_url=os.getenv("LLM_BINDING_HOST", "https://api.deepseek.com"),
  83. **kwargs,
  84. )
  85. async def print_stream(stream):
  86. async for chunk in stream:
  87. if chunk:
  88. print(chunk, end="", flush=True)
  89. async def initialize_rag():
  90. rag = LightRAG(
  91. working_dir=WORKING_DIR,
  92. llm_model_func=llm_model_func,
  93. # Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
  94. # which wraps it in an EmbeddingFunc. Using .func accesses the original
  95. # unwrapped function to avoid double wrapping when we create our own
  96. # EmbeddingFunc with custom configuration (embedding_dim, max_token_size and prefixes).
  97. embedding_func=EmbeddingFunc(
  98. embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
  99. max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
  100. supports_asymmetric=True,
  101. func=partial(
  102. ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
  103. embed_model=os.getenv("EMBEDDING_MODEL", "FRIDA:latest"),
  104. host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
  105. query_prefix=os.getenv("EMBEDDING_QUERY_PREFIX", "search_query: "),
  106. document_prefix=os.getenv(
  107. "EMBEDDING_DOCUMENT_PREFIX", "search_document: "
  108. ),
  109. ),
  110. ),
  111. )
  112. await rag.initialize_storages() # Auto-initializes pipeline_status
  113. return rag
  114. async def main():
  115. rag = None
  116. try:
  117. # Clear old data files
  118. files_to_delete = [
  119. "graph_chunk_entity_relation.graphml",
  120. "kv_store_doc_status.json",
  121. "kv_store_full_docs.json",
  122. "kv_store_text_chunks.json",
  123. "vdb_chunks.json",
  124. "vdb_entities.json",
  125. "vdb_relationships.json",
  126. ]
  127. for file in files_to_delete:
  128. file_path = os.path.join(WORKING_DIR, file)
  129. if os.path.exists(file_path):
  130. os.remove(file_path)
  131. print(f"Deleting old file:: {file_path}")
  132. # Initialize RAG instance
  133. rag = await initialize_rag()
  134. # Test embedding function
  135. test_text = ["This is a test string for embedding."]
  136. embedding = await rag.embedding_func(test_text)
  137. embedding_dim = embedding.shape[1]
  138. print("\n=======================")
  139. print("Test embedding function")
  140. print("========================")
  141. print(f"Test dict: {test_text}")
  142. print(f"Detected embedding dimension: {embedding_dim}\n\n")
  143. with open("./book.txt", "r", encoding="utf-8") as f:
  144. await rag.ainsert(f.read())
  145. # Perform naive search
  146. print("\n=====================")
  147. print("Query mode: naive")
  148. print("=====================")
  149. resp = await rag.aquery(
  150. "What are the top themes in this story?",
  151. param=QueryParam(mode="naive", stream=True),
  152. )
  153. if inspect.isasyncgen(resp):
  154. await print_stream(resp)
  155. else:
  156. print(resp)
  157. # Perform local search
  158. print("\n=====================")
  159. print("Query mode: local")
  160. print("=====================")
  161. resp = await rag.aquery(
  162. "What are the top themes in this story?",
  163. param=QueryParam(mode="local", stream=True),
  164. )
  165. if inspect.isasyncgen(resp):
  166. await print_stream(resp)
  167. else:
  168. print(resp)
  169. # Perform global search
  170. print("\n=====================")
  171. print("Query mode: global")
  172. print("=====================")
  173. resp = await rag.aquery(
  174. "What are the top themes in this story?",
  175. param=QueryParam(mode="global", stream=True),
  176. )
  177. if inspect.isasyncgen(resp):
  178. await print_stream(resp)
  179. else:
  180. print(resp)
  181. # Perform hybrid search
  182. print("\n=====================")
  183. print("Query mode: hybrid")
  184. print("=====================")
  185. resp = await rag.aquery(
  186. "What are the top themes in this story?",
  187. param=QueryParam(mode="hybrid", stream=True),
  188. )
  189. if inspect.isasyncgen(resp):
  190. await print_stream(resp)
  191. else:
  192. print(resp)
  193. except Exception as e:
  194. print(f"An error occurred: {e}")
  195. finally:
  196. if rag:
  197. await rag.finalize_storages()
  198. if __name__ == "__main__":
  199. # Configure logging before running the main function
  200. configure_logging()
  201. asyncio.run(main())
  202. print("\nDone!")