| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560 |
- #!/usr/bin/env python3
- """
- LLM Cache Migration Tool for LightRAG
- This tool migrates LLM response cache (default:extract:* and default:summary:*)
- between different KV storage implementations while preserving workspace isolation.
- Usage:
- python -m lightrag.tools.migrate_llm_cache
- # or
- python lightrag/tools/migrate_llm_cache.py
- Supported KV Storage Types:
- - JsonKVStorage
- - RedisKVStorage
- - PGKVStorage
- - MongoKVStorage
- - OpenSearchKVStorage
- """
- import asyncio
- import os
- import sys
- import time
- from typing import Any, Dict, List
- from dataclasses import dataclass, field
- from dotenv import load_dotenv
- # Add project root to path for imports
- sys.path.insert(
- 0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- )
- from lightrag.kg import STORAGE_ENV_REQUIREMENTS
- from lightrag.namespace import NameSpace
- from lightrag.utils import setup_logger
# Environment bootstrap: read ./.env relative to the current working directory,
# so each LightRAG instance can ship its own .env file. override=False means
# variables already present in the OS environment win over values in the file.
load_dotenv(dotenv_path=".env", override=False)

# Configure the "lightrag" logger at INFO level.
setup_logger("lightrag", level="INFO")

# Menu choice ("1".."5") -> KV storage implementation name, in menu order.
_KV_STORAGE_NAMES = (
    "JsonKVStorage",
    "RedisKVStorage",
    "PGKVStorage",
    "MongoKVStorage",
    "OpenSearchKVStorage",
)
STORAGE_TYPES = {
    str(position): name for position, name in enumerate(_KV_STORAGE_NAMES, start=1)
}

# Storage implementation -> environment variable holding its dedicated workspace.
# JsonKVStorage has no entry and relies on the generic WORKSPACE variable.
WORKSPACE_ENV_MAP = {
    "PGKVStorage": "POSTGRES_WORKSPACE",
    "MongoKVStorage": "MONGODB_WORKSPACE",
    "RedisKVStorage": "REDIS_WORKSPACE",
    "OpenSearchKVStorage": "OPENSEARCH_WORKSPACE",
}

# Records per migration batch.
DEFAULT_BATCH_SIZE = 1_000
# SCAN/page size hint used while counting records.
DEFAULT_COUNT_BATCH_SIZE = 1_000

# ANSI escape sequences used to highlight names in terminal output.
BOLD_CYAN = "\033[1;36m"
RESET = "\033[0m"
@dataclass
class MigrationStats:
    """Running totals and a per-batch error log for one migration run.

    All counters start at zero; add_error() appends one record per failed
    batch and bumps the failure counters.
    """

    # Source-side total and batch bookkeeping
    total_source_records: int = 0
    total_batches: int = 0
    successful_batches: int = 0
    failed_batches: int = 0
    # Per-record outcome counters
    successful_records: int = 0
    failed_records: int = 0
    # One dict per failed batch; keys are documented on add_error()
    errors: List[Dict[str, Any]] = field(default_factory=list)

    def add_error(self, batch_idx: int, error: Exception, batch_size: int):
        """Register a failed batch.

        Appends a record with the keys ``batch``, ``error_type``,
        ``error_msg``, ``records_lost`` and ``timestamp`` (``time.time()``),
        then counts the batch and its ``batch_size`` records as failed.

        Args:
            batch_idx: Index of the batch that failed
            error: Exception raised while processing the batch
            batch_size: Number of records in the failed batch
        """
        record = dict(
            batch=batch_idx,
            error_type=error.__class__.__name__,
            error_msg=str(error),
            records_lost=batch_size,
            timestamp=time.time(),
        )
        self.errors.append(record)
        self.failed_records += batch_size
        self.failed_batches += 1
- class MigrationTool:
- """LLM Cache Migration Tool"""
def __init__(self):
    """Start with no storages attached, empty workspaces and the default batch size."""
    # Storage handles are filled in later once the user picks source/target.
    self.source_storage, self.target_storage = None, None
    self.source_workspace = self.target_workspace = ""
    self.batch_size = DEFAULT_BATCH_SIZE
def get_workspace_for_storage(self, storage_name: str) -> str:
    """Resolve the workspace a storage backend should use.

    Lookup order:
      1. the backend-specific variable named in WORKSPACE_ENV_MAP, if it is
         set to a non-empty value;
      2. the generic WORKSPACE variable;
      3. the empty string.

    Args:
        storage_name: Storage implementation name

    Returns:
        Workspace name
    """
    env_name = WORKSPACE_ENV_MAP.get(storage_name)
    dedicated = os.getenv(env_name) if env_name is not None else None
    # An empty dedicated value falls through to the generic variable.
    return dedicated or os.getenv("WORKSPACE", "")
def check_config_ini_for_storage(self, storage_name: str) -> bool:
    """Tell whether ./config.ini carries the settings a storage type needs.

    Each supported backend maps to the (section, option) pairs that must all
    be present. Backends without an entry (e.g. JsonKVStorage) yield False.
    Any error while reading or parsing the file also yields False.

    Args:
        storage_name: Storage implementation name

    Returns:
        True if config.ini has the necessary configuration
    """
    required_options = {
        "RedisKVStorage": [("redis", "uri")],
        "PGKVStorage": [
            ("postgres", "user"),
            ("postgres", "password"),
            ("postgres", "database"),
        ],
        "MongoKVStorage": [("mongodb", "uri"), ("mongodb", "database")],
        "OpenSearchKVStorage": [("opensearch", "hosts")],
    }
    try:
        import configparser

        parser = configparser.ConfigParser()
        # A missing file is silently skipped by read(); has_option then reports False.
        parser.read("config.ini", "utf-8")
        needed = required_options.get(storage_name)
        if needed is None:
            return False
        return all(parser.has_option(section, option) for section, option in needed)
    except Exception:
        # Best-effort probe: malformed or unreadable config counts as "not configured".
        return False
def check_env_vars(self, storage_name: str) -> bool:
    """Report on the environment variables a storage type declares.

    Missing variables only produce warnings; the method never fails. When
    some are missing it also reports whether config.ini has settings for the
    backend, or that defaults will be attempted.

    Args:
        storage_name: Storage implementation name

    Returns:
        Always True (warnings only, no hard failure)
    """
    needed = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    if not needed:
        print("✓ No environment variables required")
        return True

    absent = [name for name in needed if name not in os.environ]
    if not absent:
        print("✓ All required environment variables are set")
        return True

    print(f"⚠️ Warning: Missing environment variables: {', '.join(absent)}")
    if self.check_config_ini_for_storage(storage_name):
        print(" ✓ Found configuration in config.ini")
    else:
        print(f" Will attempt to use defaults for {storage_name}")
    return True
def count_available_storage_types(self) -> int:
    """Count storage types that look usable.

    A type counts as available when it declares no required environment
    variables, when all of its required variables are set, or when
    config.ini carries its settings.

    Returns:
        Number of available storage types
    """

    def _is_available(name: str) -> bool:
        needed = STORAGE_ENV_REQUIREMENTS.get(name, [])
        if not needed:
            return True
        if all(var in os.environ for var in needed):
            return True
        return self.check_config_ini_for_storage(name)

    return sum(1 for name in STORAGE_TYPES.values() if _is_available(name))
def get_storage_class(self, storage_name: str):
    """Import lazily and return the class implementing a KV storage type.

    Only the selected backend's module is imported, so optional dependencies
    of the other backends are never loaded.

    Args:
        storage_name: Storage implementation name

    Returns:
        Storage class

    Raises:
        ValueError: If storage_name is not a supported storage type
    """
    if storage_name == "JsonKVStorage":
        from lightrag.kg.json_kv_impl import JsonKVStorage as storage_cls
    elif storage_name == "RedisKVStorage":
        from lightrag.kg.redis_impl import RedisKVStorage as storage_cls
    elif storage_name == "PGKVStorage":
        from lightrag.kg.postgres_impl import PGKVStorage as storage_cls
    elif storage_name == "MongoKVStorage":
        from lightrag.kg.mongo_impl import MongoKVStorage as storage_cls
    elif storage_name == "OpenSearchKVStorage":
        from lightrag.kg.opensearch_impl import OpenSearchKVStorage as storage_cls
    else:
        raise ValueError(f"Unsupported storage type: {storage_name}")
    return storage_cls
async def initialize_storage(self, storage_name: str, workspace: str):
    """Create a storage instance for the LLM response cache and initialize it.

    The instance is built with a minimal global config (WORKING_DIR from the
    environment, default "./rag_storage") and no embedding function. Any
    config.ini/default fallback is presumably handled inside the storage
    class itself; it is not visible here.

    Args:
        storage_name: Storage implementation name
        workspace: Workspace name

    Returns:
        Initialized storage instance

    Raises:
        Exception: Whatever the storage's initialize() raises (e.g. on a failed connection)
    """
    storage_cls = self.get_storage_class(storage_name)
    config = dict(
        working_dir=os.getenv("WORKING_DIR", "./rag_storage"),
        embedding_batch_num=10,
    )
    instance = storage_cls(
        namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
        workspace=workspace,
        global_config=config,
        embedding_func=None,
    )
    # initialize() is where connection problems surface; they propagate to the caller.
    await instance.initialize()
    return instance
async def get_default_caches_json(self, storage) -> Dict[str, Any]:
    """Collect default extract/summary cache entries from a JsonKVStorage.

    Reads ``storage._data`` directly while holding ``storage._storage_lock``
    and returns shallow copies of the matching entries.

    Args:
        storage: JsonKVStorage instance

    Returns:
        Dictionary of cache entries with default:extract:* or default:summary:* keys
    """
    wanted_prefixes = ("default:extract:", "default:summary:")
    async with storage._storage_lock:
        return {
            cache_key: entry.copy()
            for cache_key, entry in storage._data.items()
            if cache_key.startswith(wanted_prefixes)
        }
async def get_default_caches_redis(
    self, storage, batch_size: int = 1000
) -> Dict[str, Any]:
    """Collect default extract/summary cache entries from a RedisKVStorage.

    Walks the namespaced keys with SCAN, once per pattern. Each scanned batch
    of keys is fetched in one pipeline. If the pipeline (or JSON decoding of
    its results) fails, the batch is refetched key by key; keys that still
    fail are reported and skipped. Empty values are ignored.

    Args:
        storage: RedisKVStorage instance
        batch_size: COUNT hint passed to SCAN for each iteration

    Returns:
        Dictionary of cache entries (namespace prefix stripped from keys)
    """
    import json

    collected = {}
    async with storage._get_redis_connection() as redis:
        ns_prefix = f"{storage.final_namespace}:"

        def _original_key(raw_key):
            # SCAN may return bytes; drop the namespace prefix once.
            text = raw_key.decode() if isinstance(raw_key, bytes) else raw_key
            return text.replace(ns_prefix, "", 1)

        for pattern in ("default:extract:*", "default:summary:*"):
            scan_cursor = 0
            while True:
                scan_cursor, keys = await redis.scan(
                    scan_cursor, match=ns_prefix + pattern, count=batch_size
                )
                if keys:
                    try:
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        raw_values = await pipe.execute()
                        for key, raw in zip(keys, raw_values):
                            if raw:
                                name = _original_key(key)
                                collected[name] = json.loads(raw)
                    except Exception as e:
                        print(
                            f"⚠️ Pipeline execution failed for batch, using individual gets: {e}"
                        )
                        for key in keys:
                            try:
                                raw = await redis.get(key)
                                if raw:
                                    name = _original_key(key)
                                    collected[name] = json.loads(raw)
                            except Exception as individual_error:
                                print(
                                    f"⚠️ Failed to get individual key {key}: {individual_error}"
                                )
                if scan_cursor == 0:
                    break
                # Let other tasks run between SCAN round-trips.
                await asyncio.sleep(0)
    return collected
async def get_default_caches_pg(
    self, storage, batch_size: int = 1000
) -> Dict[str, Any]:
    """Get default caches from PGKVStorage with keyset pagination

    Rows are read in ``id`` order, ``batch_size`` at a time. Each page resumes
    after the last ``id`` of the previous page (``id > $2``) instead of using
    ``OFFSET``. With ``OFFSET``, PostgreSQL reads and discards every earlier
    row on each page, so a full read was quadratic in the number of rows.

    Args:
        storage: PGKVStorage instance
        batch_size: Number of records to fetch per batch

    Returns:
        Dictionary of cache entries with default:extract:* or default:summary:* keys
    """
    from lightrag.kg.postgres_impl import namespace_to_table_name

    cache_data = {}
    table_name = namespace_to_table_name(storage.namespace)
    # The query text is loop-invariant; only the resume key changes per page.
    query = f"""
    SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
    EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
    EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
    FROM {table_name}
    WHERE workspace = $1
    AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
    AND id > $2
    ORDER BY id
    LIMIT $3
    """
    # The empty string sorts before any non-empty id, so the first page starts at the beginning.
    last_id = ""
    while True:
        results = await storage.db.query(
            query, [storage.workspace, last_id, batch_size], multirows=True
        )
        if not results:
            break

        for row in results:
            # Map PostgreSQL columns back to the in-memory cache entry format
            cache_data[row["key"]] = {
                "return": row.get("return_value", ""),
                "cache_type": row.get("cache_type"),
                "original_prompt": row.get("original_prompt", ""),
                "chunk_id": row.get("chunk_id"),
                "queryparam": row.get("queryparam"),
                "create_time": row.get("create_time", 0),
                "update_time": row.get("update_time", 0),
            }

        # A short page means there is nothing left to read.
        if len(results) < batch_size:
            break
        # Results are ORDER BY id, so the last row is where the next page resumes.
        last_id = results[-1]["key"]
        # Yield control periodically
        await asyncio.sleep(0)

    return cache_data
async def get_default_caches_mongo(
    self, storage, batch_size: int = 1000
) -> Dict[str, Any]:
    """Collect default extract/summary cache entries from a MongoKVStorage.

    Streams documents whose ``_id`` matches ``^default:(extract|summary):``
    through a cursor with the given batch size. The storage-internal fields
    namespace, workspace, _id and content are dropped from each entry.

    Args:
        storage: MongoKVStorage instance
        batch_size: Cursor batch size; also how often control is yielded

    Returns:
        Dictionary of cache entries keyed by the document _id
    """
    result = {}
    selector = {"_id": {"$regex": "^default:(extract|summary):"}}
    documents = storage._data.find(selector).batch_size(batch_size)
    async for document in documents:
        entry = document.copy()
        doc_id = entry.pop("_id")
        # Strip storage-internal fields so only the cache payload is migrated
        for internal_field in ("namespace", "workspace", "_id", "content"):
            entry.pop(internal_field, None)
        result[doc_id] = entry
        # Give the event loop a turn every batch_size documents
        if len(result) % batch_size == 0:
            await asyncio.sleep(0)
    return result
async def get_default_caches_opensearch(
    self, storage, batch_size: int = 1000
) -> Dict[str, Any]:
    """Get default caches from OpenSearchKVStorage.

    Iterates every page of raw hits from ``storage._iter_raw_docs`` and keeps
    a copy of ``_source`` for hits whose ``_id`` starts with
    ``default:extract:`` or ``default:summary:``.
    """
    keep_prefixes = ("default:extract:", "default:summary:")
    found = {}
    async for page in storage._iter_raw_docs(batch_size=batch_size):
        found.update(
            (hit["_id"], hit["_source"].copy())
            for hit in page
            if hit["_id"].startswith(keep_prefixes)
        )
    return found
async def get_default_caches(self, storage, storage_name: str) -> Dict[str, Any]:
    """Load all default extract/summary caches via the backend-specific reader.

    Args:
        storage: Storage instance
        storage_name: Storage type name

    Returns:
        Dictionary of cache entries

    Raises:
        ValueError: If storage_name is not a supported storage type
    """
    readers = {
        "JsonKVStorage": self.get_default_caches_json,
        "RedisKVStorage": self.get_default_caches_redis,
        "PGKVStorage": self.get_default_caches_pg,
        "MongoKVStorage": self.get_default_caches_mongo,
        "OpenSearchKVStorage": self.get_default_caches_opensearch,
    }
    reader = readers.get(storage_name)
    if reader is None:
        raise ValueError(f"Unsupported storage type: {storage_name}")
    return await reader(storage)
async def count_default_caches_json(self, storage) -> int:
    """Count default extract/summary keys in a JsonKVStorage.

    Walks the in-memory key set once while holding ``storage._storage_lock``.

    Args:
        storage: JsonKVStorage instance

    Returns:
        Total count of cache records
    """
    prefixes = ("default:extract:", "default:summary:")
    async with storage._storage_lock:
        total = 0
        for cache_key in storage._data.keys():
            if cache_key.startswith(prefixes):
                total += 1
    return total
async def count_default_caches_redis(self, storage) -> int:
    """Count default extract/summary keys in a RedisKVStorage via SCAN.

    Runs one SCAN pass per pattern and rewrites a single progress line with
    the running total.

    NOTE(review): Redis documents that SCAN may return the same key more than
    once. Duplicates are not removed here, so the total may overcount.

    Args:
        storage: RedisKVStorage instance

    Returns:
        Total count of cache records
    """
    total = 0
    print("Scanning Redis keys...", end="", flush=True)
    async with storage._get_redis_connection() as redis:
        for suffix in ("default:extract:*", "default:summary:*"):
            match_expr = f"{storage.final_namespace}:{suffix}"
            position = 0
            while True:
                position, keys = await redis.scan(
                    position, match=match_expr, count=DEFAULT_COUNT_BATCH_SIZE
                )
                total += len(keys)
                print(
                    f"\rScanning Redis keys... found {total:,} records",
                    end="",
                    flush=True,
                )
                if position == 0:
                    break
    print()  # finish the progress line
    return total
async def count_default_caches_pg(self, storage) -> int:
    """Count default extract/summary rows for this workspace with COUNT(*).

    Prints a progress label, and appends the elapsed time when the query
    took longer than one second.

    Args:
        storage: PGKVStorage instance

    Returns:
        Total count of cache records (0 if the query returned nothing)
    """
    from lightrag.kg.postgres_impl import namespace_to_table_name

    sql = f"""
    SELECT COUNT(*) as count
    FROM {namespace_to_table_name(storage.namespace)}
    WHERE workspace = $1
    AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
    """
    print("Counting PostgreSQL records...", end="", flush=True)
    started = time.time()
    row = await storage.db.query(sql, [storage.workspace])
    duration = time.time() - started
    if duration > 1:
        print(f" (took {duration:.1f}s)", end="")
    print()  # finish the progress line
    return row["count"] if row else 0
async def count_default_caches_mongo(self, storage) -> int:
    """Count default extract/summary documents with count_documents.

    Prints a progress label, and appends the elapsed time when counting took
    longer than one second.

    Args:
        storage: MongoKVStorage instance

    Returns:
        Total count of cache records
    """
    selector = {"_id": {"$regex": "^default:(extract|summary):"}}
    print("Counting MongoDB documents...", end="", flush=True)
    began = time.time()
    total = await storage._data.count_documents(selector)
    took = time.time() - began
    if took > 1:
        print(f" (took {took:.1f}s)", end="")
    print()  # finish the progress line
    return total
async def count_default_caches_opensearch(self, storage) -> int:
    """Count default extract/summary documents in an OpenSearchKVStorage.

    Iterates every page yielded by ``storage._iter_raw_docs`` and counts
    hits whose ``_id`` has a default:extract: or default:summary: prefix.
    The elapsed time is printed when the scan takes longer than one second.
    """
    prefixes = ("default:extract:", "default:summary:")
    print("Scanning OpenSearch documents...", end="", flush=True)
    began = time.time()
    matched = 0
    async for page in storage._iter_raw_docs(batch_size=DEFAULT_COUNT_BATCH_SIZE):
        matched += sum(1 for hit in page if hit["_id"].startswith(prefixes))
    took = time.time() - began
    if took > 1:
        print(f" (took {took:.1f}s)", end="")
    print()
    return matched
async def count_default_caches(self, storage, storage_name: str) -> int:
    """Count default extract/summary caches via the backend-specific counter.

    Args:
        storage: Storage instance
        storage_name: Storage type name

    Returns:
        Total count of cache records

    Raises:
        ValueError: If storage_name is not a supported storage type
    """
    counters = {
        "JsonKVStorage": self.count_default_caches_json,
        "RedisKVStorage": self.count_default_caches_redis,
        "PGKVStorage": self.count_default_caches_pg,
        "MongoKVStorage": self.count_default_caches_mongo,
        "OpenSearchKVStorage": self.count_default_caches_opensearch,
    }
    counter = counters.get(storage_name)
    if counter is None:
        raise ValueError(f"Unsupported storage type: {storage_name}")
    return await counter(storage)
async def stream_default_caches_json(self, storage, batch_size: int):
    """Yield default extract/summary entries of a JsonKVStorage in batches.

    Matching (key, value) pairs are snapshotted while holding
    ``storage._storage_lock``. The lock is released before anything is
    yielded, so a consumer writing into a JsonKVStorage that takes the same
    lock cannot deadlock. Values are shallow-copied as each batch is built.

    Args:
        storage: JsonKVStorage instance
        batch_size: Size of each batch to yield (values below 1 behave like 1)

    Yields:
        Dictionary batches of cache entries
    """
    prefixes = ("default:extract:", "default:summary:")
    async with storage._storage_lock:
        snapshot = [
            (cache_key, entry)
            for cache_key, entry in storage._data.items()
            if cache_key.startswith(prefixes)
        ]

    step = max(batch_size, 1)
    for start in range(0, len(snapshot), step):
        yield {
            cache_key: entry.copy()
            for cache_key, entry in snapshot[start : start + step]
        }
async def stream_default_caches_redis(self, storage, batch_size: int):
    """Yield default extract/summary entries of a RedisKVStorage in batches.

    Each SCAN iteration's keys are fetched in one pipeline and yielded as one
    batch; empty batches are not yielded. If the pipeline (or JSON decoding
    of its results) fails, the keys are refetched one by one; keys that still
    fail are reported and skipped.

    Args:
        storage: RedisKVStorage instance
        batch_size: COUNT hint passed to SCAN

    Yields:
        Dictionary batches of cache entries (namespace prefix stripped from keys)
    """
    import json

    async with storage._get_redis_connection() as redis:
        ns_prefix = f"{storage.final_namespace}:"

        def _strip_namespace(raw_key):
            # SCAN may return bytes; drop the namespace prefix once.
            text = raw_key.decode() if isinstance(raw_key, bytes) else raw_key
            return text.replace(ns_prefix, "", 1)

        for pattern in ("default:extract:*", "default:summary:*"):
            scan_cursor = 0
            while True:
                scan_cursor, keys = await redis.scan(
                    scan_cursor, match=ns_prefix + pattern, count=batch_size
                )
                if keys:
                    try:
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        fetched = await pipe.execute()
                        chunk = {}
                        for key, raw in zip(keys, fetched):
                            if raw:
                                name = _strip_namespace(key)
                                chunk[name] = json.loads(raw)
                        if chunk:
                            yield chunk
                    except Exception as e:
                        print(f"⚠️ Pipeline execution failed for batch: {e}")
                        # Fall back to one GET per key
                        chunk = {}
                        for key in keys:
                            try:
                                raw = await redis.get(key)
                                if raw:
                                    name = _strip_namespace(key)
                                    chunk[name] = json.loads(raw)
                            except Exception as individual_error:
                                print(
                                    f"⚠️ Failed to get individual key {key}: {individual_error}"
                                )
                        if chunk:
                            yield chunk
                if scan_cursor == 0:
                    break
                await asyncio.sleep(0)
async def stream_default_caches_pg(self, storage, batch_size: int):
    """Stream default caches from PostgreSQL - yields batches

    Pages through the cache table with keyset pagination: each query resumes
    after the last ``id`` returned by the previous page
    (``id > $2 ORDER BY id LIMIT $3``). The previous ``LIMIT/OFFSET`` form
    made PostgreSQL read and discard every already-returned row on each page,
    so streaming a large cache did quadratic work.

    Args:
        storage: PGKVStorage instance
        batch_size: Maximum number of rows fetched (and yielded) per batch

    Yields:
        Dictionary batches of cache entries keyed by cache id
    """
    from lightrag.kg.postgres_impl import namespace_to_table_name

    table_name = namespace_to_table_name(storage.namespace)
    # The query text is loop-invariant; only the resume key changes per page.
    query = f"""
    SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
    EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
    EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
    FROM {table_name}
    WHERE workspace = $1
    AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
    AND id > $2
    ORDER BY id
    LIMIT $3
    """
    # The empty string sorts before any non-empty id, so the first page starts at the beginning.
    last_id = ""
    while True:
        results = await storage.db.query(
            query, [storage.workspace, last_id, batch_size], multirows=True
        )
        if not results:
            break

        batch = {}
        for row in results:
            # Map PostgreSQL columns back to the in-memory cache entry format
            batch[row["key"]] = {
                "return": row.get("return_value", ""),
                "cache_type": row.get("cache_type"),
                "original_prompt": row.get("original_prompt", ""),
                "chunk_id": row.get("chunk_id"),
                "queryparam": row.get("queryparam"),
                "create_time": row.get("create_time", 0),
                "update_time": row.get("update_time", 0),
            }
        # Results are ORDER BY id: remember where this page ended before
        # handing the batch to the consumer.
        last_id = results[-1]["key"]

        if batch:
            yield batch

        # A short page means there is nothing left to read.
        if len(results) < batch_size:
            break
        await asyncio.sleep(0)
async def stream_default_caches_mongo(self, storage, batch_size: int):
    """Yield default extract/summary entries of a MongoKVStorage in batches.

    Documents whose ``_id`` matches ``^default:(extract|summary):`` are read
    through a cursor. The storage-internal fields namespace, workspace, _id
    and content are dropped, and entries are grouped into batches of
    ``batch_size``; the final partial batch is yielded too.

    Args:
        storage: MongoKVStorage instance
        batch_size: Size of each batch to yield (also the cursor batch size)

    Yields:
        Dictionary batches of cache entries
    """
    selector = {"_id": {"$regex": "^default:(extract|summary):"}}
    documents = storage._data.find(selector).batch_size(batch_size)
    pending = {}
    async for document in documents:
        entry = document.copy()
        doc_id = entry.pop("_id")
        for internal_field in ("namespace", "workspace", "_id", "content"):
            entry.pop(internal_field, None)
        pending[doc_id] = entry
        if len(pending) >= batch_size:
            yield pending
            pending = {}
    if pending:
        yield pending
async def stream_default_caches_opensearch(self, storage, batch_size: int):
    """Stream default caches from OpenSearchKVStorage - yields batches.

    Hits from ``storage._iter_raw_docs`` whose ``_id`` has a default:extract:
    or default:summary: prefix are collected (copying ``_source``) into
    batches of ``batch_size``; the final partial batch is yielded too.
    """
    prefixes = ("default:extract:", "default:summary:")
    pending = {}
    async for page in storage._iter_raw_docs(batch_size=batch_size):
        for hit in page:
            doc_id = hit["_id"]
            if not doc_id.startswith(prefixes):
                continue
            pending[doc_id] = hit["_source"].copy()
            if len(pending) >= batch_size:
                yield pending
                pending = {}
    if pending:
        yield pending
async def stream_default_caches(
    self, storage, storage_name: str, batch_size: int = None
):
    """Yield default cache batches using the backend-specific streamer.

    Args:
        storage: Storage instance
        storage_name: Storage type name
        batch_size: Size of each batch to yield; None means self.batch_size

    Yields:
        Dictionary batches of cache entries

    Raises:
        ValueError: If storage_name is not a supported storage type (raised
            when iteration starts)
    """
    effective_size = self.batch_size if batch_size is None else batch_size
    streamers = {
        "JsonKVStorage": self.stream_default_caches_json,
        "RedisKVStorage": self.stream_default_caches_redis,
        "PGKVStorage": self.stream_default_caches_pg,
        "MongoKVStorage": self.stream_default_caches_mongo,
        "OpenSearchKVStorage": self.stream_default_caches_opensearch,
    }
    streamer = streamers.get(storage_name)
    if streamer is None:
        raise ValueError(f"Unsupported storage type: {storage_name}")
    async for chunk in streamer(storage, effective_size):
        yield chunk
async def count_cache_types(self, cache_data: Dict[str, Any]) -> Dict[str, int]:
    """Tally cache entries by their type prefix.

    Keys starting with ``default:extract:`` count as "extract" and keys
    starting with ``default:summary:`` count as "summary". Other keys are
    ignored.

    Args:
        cache_data: Dictionary of cache entries

    Returns:
        Dictionary with counts for each cache type
    """
    tallies = dict.fromkeys(("extract", "summary"), 0)
    for key in cache_data.keys():
        for kind in tallies:
            if key.startswith(f"default:{kind}:"):
                tallies[kind] += 1
                break
    return tallies
def print_header(self):
    """Print the banner shown when the tool starts."""
    rule = "=" * 50
    print(f"\n{rule}")
    print("LLM Cache Migration Tool - LightRAG")
    print(rule)
def print_storage_types(self):
    """List every supported KV storage type with its menu number."""
    print("\nSupported KV Storage Types:")
    for menu_number, storage_name in STORAGE_TYPES.items():
        print(f"[{menu_number}] {storage_name}")
def format_workspace(self, workspace: str) -> str:
    """Wrap a workspace name in bold-cyan ANSI codes for terminal display.

    Args:
        workspace: Workspace name; an empty value is shown as "(default)"

    Returns:
        Formatted workspace string with ANSI color codes
    """
    label = workspace or "(default)"
    return f"{BOLD_CYAN}{label}{RESET}"
def format_storage_name(self, storage_name: str) -> str:
    """Wrap a storage type name in bold-cyan ANSI codes for terminal display.

    Args:
        storage_name: Storage type name

    Returns:
        Formatted storage name string with ANSI color codes
    """
    return "".join((BOLD_CYAN, storage_name, RESET))
async def setup_storage(
    self,
    storage_type: str,
    use_streaming: bool = False,
    exclude_storage_name: str = None,
) -> tuple:
    """Interactively select, initialize and size a KV storage backend.

    Steps:
    1. The user picks a storage type from a numbered menu.
    2. Configuration is checked (warnings only).
    3. The storage is initialized; environment variables take precedence
       over config.ini.
    4. The default LLM cache records are counted, or loaded in legacy mode.

    Args:
        storage_type: Label shown in prompts, e.g. "Source" or "Target".
        use_streaming: If True, only count records without loading them.
            If False, load all data (legacy mode) and print a per-type
            breakdown.
        exclude_storage_name: Storage type name to hide from the menu, used to
            stop the target from matching the source. The remaining options
            are renumbered 1..N.

    Returns:
        Tuple of (storage_instance, storage_name, workspace, total_count), or
        (None, None, None, 0) if the user exits or setup fails. When setup
        fails after the storage was initialized, that storage is finalized
        here before returning, because the caller never receives it.
    """
    print(f"\n=== {storage_type} Storage Setup ===")

    if exclude_storage_name:
        # Hide the excluded type and renumber the remaining ones 1..N
        remaining = [
            name for name in STORAGE_TYPES.values() if name != exclude_storage_name
        ]
        available_types = {str(i): name for i, name in enumerate(remaining, start=1)}
        print(
            f"\nAvailable Storage Types for Target (source: {exclude_storage_name} excluded):"
        )
        for key, value in available_types.items():
            print(f"[{key}] {value}")
    else:
        # Source selection keeps the original menu numbering
        available_types = STORAGE_TYPES.copy()
        self.print_storage_types()

    num_options = len(available_types)
    prompt_range = "1" if num_options == 1 else f"1-{num_options}"

    # Prompt until a valid choice is made; empty input or "0" exits
    while True:
        choice = input(
            f"\nSelect {storage_type} storage type ({prompt_range}) (Press Enter to exit): "
        ).strip()
        if choice in ("", "0"):
            print("\n✓ Migration cancelled by user")
            return None, None, None, 0
        if choice in available_types:
            break
        print(
            f"✗ Invalid choice. Please enter one of: {', '.join(available_types.keys())}"
        )

    storage_name = available_types[choice]

    async def _release(instance) -> None:
        """Best-effort finalize of a storage that will not be handed to the caller."""
        if instance is None:
            return
        try:
            await instance.finalize()
        except Exception as cleanup_error:
            print(f"⚠️ Failed to finalize {storage_name}: {cleanup_error}")

    # The configuration check only warns; initialization is the real validation
    print("\nChecking configuration...")
    self.check_env_vars(storage_name)

    workspace = self.get_workspace_for_storage(storage_name)

    print(f"\nInitializing {storage_type} storage...")
    storage = None
    try:
        storage = await self.initialize_storage(storage_name, workspace)
        # Report the workspace the initialized storage actually uses
        workspace = storage.workspace
        print(f"- Storage Type: {storage_name}")
        print(f"- Workspace: {workspace if workspace else '(default)'}")
        print("- Connection Status: ✓ Success")

        # Show where the configuration came from, for transparency
        if storage_name == "RedisKVStorage":
            config_source = (
                "environment variable"
                if "REDIS_URI" in os.environ
                else "config.ini or default"
            )
            print(f"- Configuration Source: {config_source}")
        elif storage_name in ("PGKVStorage", "MongoKVStorage", "OpenSearchKVStorage"):
            # Env-configured only when every required variable is present
            config_source = (
                "environment variables"
                if all(
                    var in os.environ
                    for var in STORAGE_ENV_REQUIREMENTS[storage_name]
                )
                else "config.ini or defaults"
            )
            print(f"- Configuration Source: {config_source}")
    except Exception as e:
        print(f"✗ Initialization failed: {e}")
        print(f"\nFor {storage_name}, you can configure using:")
        print(" 1. Environment variables (highest priority)")
        if storage_name in STORAGE_ENV_REQUIREMENTS:
            for var in STORAGE_ENV_REQUIREMENTS[storage_name]:
                print(f" - {var}")
        print(" 2. config.ini file (medium priority)")
        if storage_name == "RedisKVStorage":
            print(" [redis]")
            print(" uri = redis://localhost:6379")
        elif storage_name == "PGKVStorage":
            print(" [postgres]")
            print(" host = localhost")
            print(" port = 5432")
            print(" user = postgres")
            print(" password = yourpassword")
            print(" database = lightrag")
        elif storage_name == "MongoKVStorage":
            print(" [mongodb]")
            print(" uri = mongodb://root:root@localhost:27017/")
            print(" database = LightRAG")
        elif storage_name == "OpenSearchKVStorage":
            print(" [opensearch]")
            print(" hosts = localhost:9200")
        # Initialization may have succeeded before a later step raised;
        # don't leak that connection
        await _release(storage)
        return None, None, None, 0

    action = "Counting" if use_streaming else "Loading"
    print(f"\n{action} cache records...")
    try:
        if use_streaming:
            # Count only; the data itself is streamed later
            total_count = await self.count_default_caches(storage, storage_name)
            print(f"- Total: {total_count:,} records")
        else:
            # Legacy mode: load everything to report a per-type breakdown
            cache_data = await self.get_default_caches(storage, storage_name)
            counts = await self.count_cache_types(cache_data)
            total_count = len(cache_data)
            print(f"- default:extract: {counts['extract']:,} records")
            print(f"- default:summary: {counts['summary']:,} records")
            print(f"- Total: {total_count:,} records")
    except Exception as e:
        print(f"✗ {action} failed: {e}")
        # The caller gets None instead of this storage, so release it here
        await _release(storage)
        return None, None, None, 0

    return storage, storage_name, workspace, total_count
async def migrate_caches(
    self, source_data: Dict[str, Any], target_storage, target_storage_name: str
) -> MigrationStats:
    """Upsert an in-memory cache snapshot into the target, batch by batch.

    Legacy (non-streaming) path: the whole source data is already loaded.
    When a batch fails, the error is recorded in the stats and migration
    moves on to the next batch. Afterwards ``index_done_callback`` is
    awaited on the target; a failure there is recorded as batch 0.

    Args:
        source_data: Mapping of cache key -> cache entry to migrate.
        target_storage: Initialized storage that receives the entries.
        target_storage_name: Target storage type name (not used here).

    Returns:
        MigrationStats describing successful batches and recorded errors.
    """
    stats = MigrationStats()
    record_total = len(source_data)
    stats.total_source_records = record_total

    if record_total == 0:
        print("\nNo records to migrate")
        return stats

    size = self.batch_size
    entries = list(source_data.items())
    stats.total_batches = (record_total + size - 1) // size

    print("\n=== Starting Migration ===")
    for batch_no in range(1, stats.total_batches + 1):
        lower = (batch_no - 1) * size
        upper = min(batch_no * size, record_total)
        chunk = entries[lower:upper]
        payload = dict(chunk)

        # The batch is labelled by its first key only
        head_key = chunk[0][0]
        cache_type = "summary"
        if "extract" in head_key:
            cache_type = "extract"

        try:
            await target_storage.upsert(payload)
            stats.successful_batches += 1
            stats.successful_records += len(payload)

            percent = upper / record_total * 100
            width = 20
            filled = int(width * upper // record_total)
            bar = "█" * filled + "░" * (width - filled)
            print(
                f"Batch {batch_no}/{stats.total_batches}: {bar} "
                f"{upper:,}/{record_total:,} ({percent:.0f}%) - "
                f"default:{cache_type} ✓"
            )
        except Exception as exc:
            # Record the failure and keep going with the next batch
            stats.add_error(batch_no, exc, len(payload))
            print(
                f"Batch {batch_no}/{stats.total_batches}: ✗ FAILED - "
                f"{type(exc).__name__}: {str(exc)}"
            )

    print("\nPersisting data to disk...")
    try:
        await target_storage.index_done_callback()
        print("✓ Data persisted successfully")
    except Exception as exc:
        print(f"✗ Persist failed: {exc}")
        stats.add_error(0, exc, 0)  # batch number 0 marks a persist failure
    return stats
async def migrate_caches_streaming(
    self,
    source_storage,
    source_storage_name: str,
    target_storage,
    target_storage_name: str,
    total_records: int,
) -> MigrationStats:
    """Migrate caches by writing each streamed source batch straight to the target.

    Batches are pulled from ``stream_default_caches`` and upserted
    immediately, so this method itself holds at most one batch at a time.
    When a batch fails, the error is recorded in the stats and migration
    continues. Afterwards ``index_done_callback`` is awaited on the target;
    a failure there is recorded as batch 0.

    Args:
        source_storage: Source storage instance.
        source_storage_name: Source storage type name.
        target_storage: Target storage instance.
        target_storage_name: Target storage type name (not used here).
        total_records: Record count taken before migration started. It is
            used for progress display and the batch estimate. The live
            stream may yield a different number if the source changed.

    Returns:
        MigrationStats with migration results and errors.
        ``total_batches`` is set to the number of batches actually streamed.
        If the streamed record count differs from ``total_records``,
        ``total_source_records`` is set to the streamed count, so the report's
        success rate reflects what was really migrated.
    """
    stats = MigrationStats()
    stats.total_source_records = total_records
    if stats.total_source_records == 0:
        print("\nNo records to migrate")
        return stats

    # Estimate only; the real batch count depends on what the stream yields
    stats.total_batches = (total_records + self.batch_size - 1) // self.batch_size

    print("\n=== Starting Streaming Migration ===")
    print(
        f"💡 Memory-optimized mode: Processing {self.batch_size:,} records at a time\n"
    )

    bar_length = 20
    batch_idx = 0
    processed = 0  # records pulled from the stream so far, successful or failed

    async for batch in self.stream_default_caches(
        source_storage, source_storage_name
    ):
        batch_idx += 1
        processed += len(batch)

        # Label the batch by its first key, using the same prefix test as
        # count_cache_types
        if batch:
            first_key = next(iter(batch))
            cache_type = (
                "extract" if first_key.startswith("default:extract:") else "summary"
            )
        else:
            cache_type = "unknown"

        try:
            await target_storage.upsert(batch)
            stats.successful_batches += 1
            stats.successful_records += len(batch)

            # Progress follows the stream position, so failed batches don't
            # stall the bar. Clamp the fill in case the live source grew
            # after it was counted.
            progress = processed / total_records * 100
            filled_length = min(bar_length, bar_length * processed // total_records)
            bar = "█" * filled_length + "░" * (bar_length - filled_length)
            print(
                f"Batch {batch_idx}/{stats.total_batches}: {bar} "
                f"{processed:,}/{total_records:,} ({progress:.1f}%) - "
                f"default:{cache_type} ✓"
            )
        except Exception as e:
            # Record the failure and continue with the next batch
            stats.add_error(batch_idx, e, len(batch))
            print(
                f"Batch {batch_idx}/{stats.total_batches}: ✗ FAILED - "
                f"{type(e).__name__}: {str(e)}"
            )

    # Replace the pre-migration estimates with what was actually streamed
    stats.total_batches = batch_idx
    if processed != total_records:
        print(
            f"\n⚠️ Streamed {processed:,} records but counted {total_records:,} "
            "before migration; the source may have changed. "
            "Report totals use the streamed count."
        )
        stats.total_source_records = processed

    print("\nPersisting data to disk...")
    try:
        await target_storage.index_done_callback()
        print("✓ Data persisted successfully")
    except Exception as e:
        print(f"✗ Persist failed: {e}")
        stats.add_error(0, e, 0)  # batch 0 = persist error
    return stats
def print_migration_report(self, stats: MigrationStats):
    """Print the end-of-run summary: totals, success rate and an error digest.

    Args:
        stats: Migration results. The attributes read are
            total_source_records, total_batches, successful_batches,
            failed_batches, successful_records, failed_records and errors.
            ``errors`` is a list of dicts with "batch", "error_type",
            "error_msg" and "records_lost" keys.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("Migration Complete - Final Report")
    print(divider)

    print("\n📊 Statistics:")
    rows = (
        ("Total source records", stats.total_source_records),
        ("Total batches", stats.total_batches),
        ("Successful batches", stats.successful_batches),
        ("Failed batches", stats.failed_batches),
        ("Successfully migrated", stats.successful_records),
        ("Failed to migrate", stats.failed_records),
    )
    for label, value in rows:
        print(f" {label}: {value:,}")

    total = stats.total_source_records
    rate = stats.successful_records / total * 100 if total > 0 else 0
    print(f" Success rate: {rate:.2f}%")

    errors = stats.errors
    if not errors:
        print("\n" + divider)
        print("✓ SUCCESS: All records migrated successfully!")
        print(divider)
        return

    print(f"\n⚠️ Errors encountered: {len(errors)}")
    print("\nError Details:")
    print("-" * 60)

    # Tally occurrences per error type, most frequent first (ties keep first-seen order)
    tally = {}
    for entry in errors:
        tally[entry["error_type"]] = tally.get(entry["error_type"], 0) + 1
    print("\nError Summary:")
    for err_type, count in sorted(tally.items(), key=lambda pair: pair[1], reverse=True):
        print(f" - {err_type}: {count} occurrence(s)")

    print("\nFirst 5 errors:")
    for position, entry in enumerate(errors[:5], start=1):
        print(f"\n {position}. Batch {entry['batch']}")
        print(f" Type: {entry['error_type']}")
        print(f" Message: {entry['error_msg']}")
        print(f" Records lost: {entry['records_lost']:,}")

    hidden = len(errors) - 5
    if hidden > 0:
        print(f"\n ... and {hidden} more errors")

    print("\n" + divider)
    print("⚠️ WARNING: Migration completed with errors!")
    print(" Please review the error details above.")
    print(divider)
async def run(self):
    """Run the interactive migration end to end.

    Steps:
    1. Initialize shared storage data.
    2. Choose the source storage and count its cache records.
    3. Check that a second storage type is available.
    4. Choose the target storage, with the source type excluded.
    5. Confirm with the user.
    6. Stream the migration and print the report.

    All cleanup happens once, in the ``finally`` block, on every path
    (normal completion, early return, interrupt or error):
    - each storage that was set up is finalized exactly once and its
      attribute reset to None;
    - ``finalize_share_data`` is called afterwards.

    Cleanup failures are reported as warnings instead of being ignored.
    """
    try:
        # Shared storage data must be initialized before storage classes are used
        from lightrag.kg.shared_storage import initialize_share_data

        initialize_share_data(workers=1)

        self.print_header()

        # Source: only count records now; the data is streamed later
        (
            self.source_storage,
            source_storage_name,
            self.source_workspace,
            source_count,
        ) = await self.setup_storage("Source", use_streaming=True)

        # setup_storage returns None for every field on exit or failure
        if self.source_storage is None:
            return

        # A migration needs a second, different storage type as its target
        available_count = self.count_available_storage_types()
        if available_count <= 1:
            print("\n" + "=" * 60)
            print("⚠️ Warning: Migration Not Possible")
            print("=" * 60)
            print(f"Only {available_count} storage type(s) available.")
            print("Migration requires at least 2 different storage types.")
            print("\nTo enable migration, configure additional storage:")
            print(" 1. Set environment variables, OR")
            print(" 2. Update config.ini file")
            print("\nSupported storage types:")
            for name in STORAGE_TYPES.values():
                if name != source_storage_name:
                    print(f" - {name}")
                    if name in STORAGE_ENV_REQUIREMENTS:
                        for var in STORAGE_ENV_REQUIREMENTS[name]:
                            print(f" Required: {var}")
            print("=" * 60)
            return

        if source_count == 0:
            print("\n⚠️ Source storage has no cache records to migrate")
            return

        # Target: count only, and never offer the source's storage type
        (
            self.target_storage,
            target_storage_name,
            self.target_workspace,
            target_count,
        ) = await self.setup_storage(
            "Target", use_streaming=True, exclude_storage_name=source_storage_name
        )
        if self.target_storage is None:
            print("\n✗ Target storage setup failed")
            return

        print("\n" + "=" * 50)
        print("Migration Confirmation")
        print("=" * 50)
        print(
            f"Source: {self.format_storage_name(source_storage_name)} (workspace: {self.format_workspace(self.source_workspace)}) - {source_count:,} records"
        )
        print(
            f"Target: {self.format_storage_name(target_storage_name)} (workspace: {self.format_workspace(self.target_workspace)}) - {target_count:,} records"
        )
        print(f"Batch Size: {self.batch_size:,} records/batch")
        print("Memory Mode: Streaming (memory-optimized)")
        if target_count > 0:
            print(
                f"\n⚠️ Warning: Target storage already has {target_count:,} records"
            )
            print("Migration will overwrite records with the same keys")

        confirm = input("\nContinue? (y/n): ").strip().lower()
        if confirm != "y":
            print("\n✗ Migration cancelled")
            return

        stats = await self.migrate_caches_streaming(
            self.source_storage,
            source_storage_name,
            self.target_storage,
            target_storage_name,
            source_count,
        )
        self.print_migration_report(stats)
    except KeyboardInterrupt:
        print("\n\n✗ Migration interrupted by user")
    except Exception as e:
        print(f"\n✗ Migration failed: {e}")
        import traceback

        traceback.print_exc()
    finally:
        # Single cleanup point: every path above ends here, so each storage
        # is finalized exactly once
        for role in ("source", "target"):
            attr = f"{role}_storage"
            storage = getattr(self, attr, None)
            if storage is None:
                continue
            # Clear first so a later run() can never finalize this instance again
            setattr(self, attr, None)
            try:
                await storage.finalize()
            except Exception as cleanup_error:
                print(f"⚠️ Failed to finalize {role} storage: {cleanup_error}")

        # Shared data is torn down only after the storages that depend on it
        try:
            from lightrag.kg.shared_storage import finalize_share_data

            finalize_share_data()
        except Exception as cleanup_error:
            print(f"⚠️ Failed to finalize shared storage data: {cleanup_error}")
async def main():
    """Script entry point: build a MigrationTool and run its interactive flow."""
    await MigrationTool().run()


if __name__ == "__main__":
    # asyncio.run creates the event loop, runs main() to completion, then closes it
    asyncio.run(main())
|