| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- import asyncio
- import os
- import inspect
- import logging
- import logging.config
- from functools import partial
- from lightrag import LightRAG, QueryParam
- from lightrag.llm.ollama import ollama_model_complete, ollama_embed
- from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
- from dotenv import load_dotenv
# Load settings from a local ".env" file. override is disabled, so values
# already set in the process environment are presumably left untouched
# (see the python-dotenv documentation to confirm).
load_dotenv(override=False, dotenv_path=".env")

# Directory passed to LightRAG as its working_dir; main() also removes stale
# data files from it before each run.
WORKING_DIR = "./dickens"
def configure_logging():
    """Set up console and rotating-file logging for the ``lightrag`` logger.

    Handlers and filters already attached to the uvicorn and lightrag loggers
    are dropped first, then ``logging.config.dictConfig`` installs a stderr
    handler and a rotating file handler on ``lightrag``.

    Environment variables read:
        LOG_DIR: directory that receives ``lightrag_ollama_demo.log``
            (default: the current working directory).
        LOG_MAX_BYTES: ``maxBytes`` for the rotating file handler
            (default 10485760, i.e. 10MB).
        LOG_BACKUP_COUNT: ``backupCount`` for the rotating file handler
            (default 5).
        VERBOSE_DEBUG: the string "true" (case-insensitive) enables verbose
            debug output via ``set_verbose_debug``.
    """
    # Wipe handlers/filters so the configuration below starts from a clean slate.
    for name in ("uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"):
        stale_logger = logging.getLogger(name)
        stale_logger.handlers = []
        stale_logger.filters = []

    # Resolve the log file location and make sure its directory exists.
    target_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(
        os.path.join(target_dir, "lightrag_ollama_demo.log")
    )
    print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    # Rotation settings come from the environment, with fixed fallbacks.
    rotate_at_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))
    keep_backups = int(os.getenv("LOG_BACKUP_COUNT", 5))

    formatter_specs = {
        "default": {"format": "%(levelname)s: %(message)s"},
        "detailed": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    }
    handler_specs = {
        "console": {
            "formatter": "default",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stderr",
        },
        "file": {
            "formatter": "detailed",
            "class": "logging.handlers.RotatingFileHandler",
            "filename": log_file_path,
            "maxBytes": rotate_at_bytes,
            "backupCount": keep_backups,
            "encoding": "utf-8",
        },
    }
    logger_specs = {
        "lightrag": {
            "handlers": ["console", "file"],
            "level": "INFO",
            "propagate": False,
        },
    }
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": formatter_specs,
            "handlers": handler_specs,
            "loggers": logger_specs,
        }
    )

    # Pin the shared LightRAG logger object to INFO as well.
    logger.setLevel(logging.INFO)

    # Verbose debug is opt-in through the environment.
    verbose_requested = os.getenv("VERBOSE_DEBUG", "false").lower() == "true"
    set_verbose_debug(verbose_requested)
# Make sure the working directory exists before LightRAG is pointed at it.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: there is no
# window between the check and the creation, and missing parent directories
# are created too.
os.makedirs(WORKING_DIR, exist_ok=True)
async def initialize_rag():
    """Build a LightRAG instance backed by Ollama and initialize its storages.

    Model names, hosts, timeout and embedding sizes are read from the
    environment (LLM_MODEL, LLM_BINDING_HOST, TIMEOUT, EMBEDDING_DIM,
    MAX_EMBED_TOKENS, EMBEDDING_MODEL, EMBEDDING_BINDING_HOST), each with a
    hard-coded fallback.

    Returns:
        The LightRAG instance, after ``initialize_storages()`` has been awaited.
    """
    # Keyword arguments forwarded to the Ollama completion function.
    llm_settings = {
        "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
        "options": {"num_ctx": 8192},
        "timeout": int(os.getenv("TIMEOUT", "300")),
    }

    embed_dim = int(os.getenv("EMBEDDING_DIM", "1024"))
    embed_max_tokens = int(os.getenv("MAX_EMBED_TOKENS", "8192"))

    # Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
    # which wraps it in an EmbeddingFunc. Going through .func reaches the
    # original unwrapped function, so building our own EmbeddingFunc (with a
    # custom embedding_dim and max_token_size) does not wrap it twice.
    raw_embed = partial(
        ollama_embed.func,
        embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
        host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
    )
    embedder = EmbeddingFunc(
        embedding_dim=embed_dim,
        max_token_size=embed_max_tokens,
        func=raw_embed,
    )

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=ollama_model_complete,
        llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"),
        summary_max_tokens=8192,
        llm_model_kwargs=llm_settings,
        embedding_func=embedder,
    )

    # Auto-initializes pipeline_status
    await rag.initialize_storages()
    return rag
async def print_stream(stream):
    """Write every item of an async iterable to stdout as it arrives.

    Items are printed back to back with no separator or trailing newline,
    and stdout is flushed after each one so output appears incrementally.
    """
    async for piece in stream:
        print(piece, end="", flush=True)
async def main():
    """Run the end-to-end LightRAG + Ollama demo.

    Steps: delete data files left in ``WORKING_DIR`` by a previous run, build
    the RAG instance, probe the embedding function and report its output
    dimension, insert the contents of ``./book.txt``, then ask the same
    question in naive, local, global and hybrid modes. A response that is an
    async generator is streamed to stdout; anything else is printed directly.

    Any ``Exception`` is reported on stdout instead of being propagated. If
    the RAG instance was created, its LLM response cache is flushed and its
    storages are finalized whether or not an error occurred.
    """
    # Bind the name before entering ``try`` so the ``finally`` clause can test
    # it even when a failure happens before initialize_rag() returns.
    # Without this, such a failure raised UnboundLocalError from ``finally``.
    rag = None
    try:
        # Clear old data files
        files_to_delete = (
            "graph_chunk_entity_relation.graphml",
            "kv_store_doc_status.json",
            "kv_store_full_docs.json",
            "kv_store_text_chunks.json",
            "vdb_chunks.json",
            "vdb_entities.json",
            "vdb_relationships.json",
        )
        for file in files_to_delete:
            file_path = os.path.join(WORKING_DIR, file)
            if os.path.exists(file_path):
                os.remove(file_path)
                print(f"Deleting old file:: {file_path}")

        # Initialize RAG instance
        rag = await initialize_rag()

        # Test embedding function
        test_text = ["This is a test string for embedding."]
        embedding = await rag.embedding_func(test_text)
        embedding_dim = embedding.shape[1]
        print("\n=======================")
        print("Test embedding function")
        print("========================")
        print(f"Test dict: {test_text}")
        print(f"Detected embedding dimension: {embedding_dim}\n\n")

        # Read the whole book first so the file handle is closed before the
        # (potentially long-running) insert starts.
        with open("./book.txt", "r", encoding="utf-8") as f:
            book_text = f.read()
        await rag.ainsert(book_text)

        # Ask the same question once per retrieval mode.
        for mode in ("naive", "local", "global", "hybrid"):
            print("\n=====================")
            print(f"Query mode: {mode}")
            print("=====================")
            resp = await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode=mode, stream=True),
            )
            # With stream=True the response may be an async generator; fall
            # back to a plain print when it is not.
            if inspect.isasyncgen(resp):
                await print_stream(resp)
            else:
                print(resp)
    except Exception as e:
        # Top-level boundary of the demo: report the error and fall through
        # to cleanup.
        print(f"An error occurred: {e}")
    finally:
        if rag is not None:
            await rag.llm_response_cache.index_done_callback()
            await rag.finalize_storages()
if __name__ == "__main__":
    # Logging is set up before the event loop starts, so records sent to the
    # "lightrag" logger during the demo go to both the console and the log file.
    configure_logging()

    # Drive the async demo to completion.
    asyncio.run(main())

    print("\nDone!")
|