# redis_impl.py
import asyncio
import configparser
import inspect
import json
import logging
import os
import threading
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, final, Union

import pipmaster as pm

if not pm.is_installed("redis"):
    pm.install("redis")

# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis, ConnectionPool  # type: ignore
from redis.exceptions import RedisError, ConnectionError, TimeoutError  # type: ignore

# Import tenacity for retry logic
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
)

from lightrag.utils import logger, get_pinyin_sort_key, _cooperative_yield
from lightrag.base import (
    BaseKVStorage,
    DocStatusStorage,
    DocStatus,
    DocProcessingStatus,
)
from ..kg.shared_storage import get_data_init_lock
  31. config = configparser.ConfigParser()
  32. config.read("config.ini", "utf-8")
  33. # Constants for Redis connection pool with environment variable support
  34. MAX_CONNECTIONS = int(os.getenv("REDIS_MAX_CONNECTIONS", "200"))
  35. SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30.0"))
  36. SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_CONNECT_TIMEOUT", "10.0"))
  37. RETRY_ATTEMPTS = int(os.getenv("REDIS_RETRY_ATTEMPTS", "3"))
  38. # Tenacity retry decorator for Redis operations
  39. redis_retry = retry(
  40. stop=stop_after_attempt(RETRY_ATTEMPTS),
  41. wait=wait_exponential(multiplier=1, min=1, max=8),
  42. retry=(
  43. retry_if_exception_type(ConnectionError)
  44. | retry_if_exception_type(TimeoutError)
  45. | retry_if_exception_type(RedisError)
  46. ),
  47. before_sleep=before_sleep_log(logger, logging.WARNING),
  48. )
  49. class RedisConnectionManager:
  50. """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
  51. _pools = {}
  52. _pool_refs = {} # Track reference count for each pool
  53. _lock = threading.Lock()
  54. @classmethod
  55. def get_pool(cls, redis_url: str) -> ConnectionPool:
  56. """Get or create a connection pool for the given Redis URL"""
  57. with cls._lock:
  58. if redis_url not in cls._pools:
  59. cls._pools[redis_url] = ConnectionPool.from_url(
  60. redis_url,
  61. max_connections=MAX_CONNECTIONS,
  62. decode_responses=True,
  63. socket_timeout=SOCKET_TIMEOUT,
  64. socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
  65. )
  66. cls._pool_refs[redis_url] = 0
  67. logger.info(f"Created shared Redis connection pool for {redis_url}")
  68. # Increment reference count
  69. cls._pool_refs[redis_url] += 1
  70. logger.debug(
  71. f"Redis pool {redis_url} reference count: {cls._pool_refs[redis_url]}"
  72. )
  73. return cls._pools[redis_url]
  74. @classmethod
  75. def release_pool(cls, redis_url: str):
  76. """Release a reference to the connection pool"""
  77. with cls._lock:
  78. if redis_url in cls._pool_refs:
  79. cls._pool_refs[redis_url] -= 1
  80. logger.debug(
  81. f"Redis pool {redis_url} reference count: {cls._pool_refs[redis_url]}"
  82. )
  83. # If no more references, close the pool
  84. if cls._pool_refs[redis_url] <= 0:
  85. try:
  86. cls._pools[redis_url].disconnect()
  87. logger.info(
  88. f"Closed Redis connection pool for {redis_url} (no more references)"
  89. )
  90. except Exception as e:
  91. logger.error(f"Error closing Redis pool for {redis_url}: {e}")
  92. finally:
  93. del cls._pools[redis_url]
  94. del cls._pool_refs[redis_url]
  95. @classmethod
  96. def close_all_pools(cls):
  97. """Close all connection pools (for cleanup)"""
  98. with cls._lock:
  99. for url, pool in cls._pools.items():
  100. try:
  101. pool.disconnect()
  102. logger.info(f"Closed Redis connection pool for {url}")
  103. except Exception as e:
  104. logger.error(f"Error closing Redis pool for {url}: {e}")
  105. cls._pools.clear()
  106. cls._pool_refs.clear()
  107. @final
  108. @dataclass
  109. class RedisKVStorage(BaseKVStorage):
  110. def __post_init__(self):
  111. # Check for REDIS_WORKSPACE environment variable first (higher priority)
  112. # This allows administrators to force a specific workspace for all Redis storage instances
  113. redis_workspace = os.environ.get("REDIS_WORKSPACE")
  114. if redis_workspace and redis_workspace.strip():
  115. # Use environment variable value, overriding the passed workspace parameter
  116. effective_workspace = redis_workspace.strip()
  117. logger.info(
  118. f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
  119. )
  120. else:
  121. # Use the workspace parameter passed during initialization
  122. effective_workspace = self.workspace
  123. if effective_workspace:
  124. logger.debug(
  125. f"Using passed workspace parameter: '{effective_workspace}'"
  126. )
  127. # Build final_namespace with workspace prefix for data isolation
  128. # Keep original namespace unchanged for type detection logic
  129. if effective_workspace:
  130. self.final_namespace = f"{effective_workspace}_{self.namespace}"
  131. logger.debug(
  132. f"Final namespace with workspace prefix: '{self.final_namespace}'"
  133. )
  134. else:
  135. # When workspace is empty, final_namespace equals original namespace
  136. self.final_namespace = self.namespace
  137. self.workspace = ""
  138. logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
  139. self._redis_url = os.environ.get(
  140. "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
  141. )
  142. self._pool = None
  143. self._redis = None
  144. self._initialized = False
  145. try:
  146. # Use shared connection pool
  147. self._pool = RedisConnectionManager.get_pool(self._redis_url)
  148. self._redis = Redis(connection_pool=self._pool)
  149. logger.info(
  150. f"[{self.workspace}] Initialized Redis KV storage for {self.namespace} using shared connection pool"
  151. )
  152. except Exception as e:
  153. # Clean up on initialization failure
  154. if self._redis_url:
  155. RedisConnectionManager.release_pool(self._redis_url)
  156. logger.error(
  157. f"[{self.workspace}] Failed to initialize Redis KV storage: {e}"
  158. )
  159. raise
  160. async def initialize(self):
  161. """Initialize Redis connection and migrate legacy cache structure if needed"""
  162. async with get_data_init_lock():
  163. if self._initialized:
  164. return
  165. # Test connection
  166. try:
  167. async with self._get_redis_connection() as redis:
  168. await redis.ping()
  169. logger.info(
  170. f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
  171. )
  172. self._initialized = True
  173. except Exception as e:
  174. logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
  175. # Clean up on connection failure
  176. await self.close()
  177. raise
  178. # Migrate legacy cache structure if this is a cache namespace
  179. if self.namespace.endswith("_cache"):
  180. try:
  181. await self._migrate_legacy_cache_structure()
  182. except Exception as e:
  183. logger.error(
  184. f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
  185. )
  186. # Don't fail initialization for migration errors, just log them
  187. @asynccontextmanager
  188. async def _get_redis_connection(self):
  189. """Safe context manager for Redis operations."""
  190. if not self._redis:
  191. raise ConnectionError("Redis connection not initialized")
  192. try:
  193. # Use the existing Redis instance with shared pool
  194. yield self._redis
  195. except ConnectionError as e:
  196. logger.error(
  197. f"[{self.workspace}] Redis connection error in {self.namespace}: {e}"
  198. )
  199. raise
  200. except RedisError as e:
  201. logger.error(
  202. f"[{self.workspace}] Redis operation error in {self.namespace}: {e}"
  203. )
  204. raise
  205. except Exception as e:
  206. logger.error(
  207. f"[{self.workspace}] Unexpected error in Redis operation for {self.namespace}: {e}"
  208. )
  209. raise
  210. async def close(self):
  211. """Close the Redis connection and release pool reference to prevent resource leaks."""
  212. if hasattr(self, "_redis") and self._redis:
  213. try:
  214. await self._redis.close()
  215. logger.debug(
  216. f"[{self.workspace}] Closed Redis connection for {self.namespace}"
  217. )
  218. except Exception as e:
  219. logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
  220. finally:
  221. self._redis = None
  222. # Release the pool reference (will auto-close pool if no more references)
  223. if hasattr(self, "_redis_url") and self._redis_url:
  224. RedisConnectionManager.release_pool(self._redis_url)
  225. self._pool = None
  226. logger.debug(
  227. f"[{self.workspace}] Released Redis connection pool reference for {self.namespace}"
  228. )
  229. async def __aenter__(self):
  230. """Support for async context manager."""
  231. return self
  232. async def __aexit__(self, exc_type, exc_val, exc_tb):
  233. """Ensure Redis resources are cleaned up when exiting context."""
  234. await self.close()
  235. @redis_retry
  236. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  237. async with self._get_redis_connection() as redis:
  238. try:
  239. data = await redis.get(f"{self.final_namespace}:{id}")
  240. if data:
  241. result = json.loads(data)
  242. # Ensure time fields are present, provide default values for old data
  243. result.setdefault("create_time", 0)
  244. result.setdefault("update_time", 0)
  245. return result
  246. return None
  247. except json.JSONDecodeError as e:
  248. logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
  249. raise
  250. @redis_retry
  251. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  252. async with self._get_redis_connection() as redis:
  253. try:
  254. pipe = redis.pipeline()
  255. for id in ids:
  256. pipe.get(f"{self.final_namespace}:{id}")
  257. results = await pipe.execute()
  258. processed_results = []
  259. for result in results:
  260. if result:
  261. data = json.loads(result)
  262. # Ensure time fields are present for all documents
  263. data.setdefault("create_time", 0)
  264. data.setdefault("update_time", 0)
  265. processed_results.append(data)
  266. else:
  267. processed_results.append(None)
  268. return processed_results
  269. except json.JSONDecodeError as e:
  270. logger.error(f"[{self.workspace}] JSON decode error in batch get: {e}")
  271. raise
  272. async def filter_keys(self, keys: set[str]) -> set[str]:
  273. async with self._get_redis_connection() as redis:
  274. pipe = redis.pipeline()
  275. keys_list = list(keys) # Convert set to list for indexing
  276. for key in keys_list:
  277. pipe.exists(f"{self.final_namespace}:{key}")
  278. results = await pipe.execute()
  279. existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
  280. return set(keys) - existing_ids
  281. @redis_retry
  282. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  283. if not data:
  284. return
  285. import time
  286. current_time = int(time.time()) # Get current Unix timestamp
  287. async with self._get_redis_connection() as redis:
  288. try:
  289. # Check which keys already exist to determine create vs update
  290. pipe = redis.pipeline()
  291. for i, k in enumerate(data.keys(), start=1):
  292. pipe.exists(f"{self.final_namespace}:{k}")
  293. await _cooperative_yield(i)
  294. exists_results = await pipe.execute()
  295. # Add timestamps to data
  296. for i, (k, v) in enumerate(data.items(), start=1):
  297. # For text_chunks namespace, ensure llm_cache_list field exists
  298. if self.namespace.endswith("text_chunks"):
  299. if "llm_cache_list" not in v:
  300. v["llm_cache_list"] = []
  301. # Add timestamps based on whether key exists
  302. if exists_results[i - 1]: # Key exists, only update update_time
  303. v["update_time"] = current_time
  304. else: # New key, set both create_time and update_time
  305. v["create_time"] = current_time
  306. v["update_time"] = current_time
  307. v["_id"] = k
  308. await _cooperative_yield(i)
  309. # Store the data
  310. pipe = redis.pipeline()
  311. for i, (k, v) in enumerate(data.items(), start=1):
  312. pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
  313. await _cooperative_yield(i)
  314. await pipe.execute()
  315. except json.JSONDecodeError as e:
  316. logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
  317. raise
  318. async def index_done_callback(self) -> None:
  319. # Redis handles persistence automatically
  320. pass
  321. async def is_empty(self) -> bool:
  322. """Check if the storage is empty for the current workspace and namespace
  323. Returns:
  324. bool: True if storage is empty, False otherwise
  325. """
  326. pattern = f"{self.final_namespace}:*"
  327. try:
  328. async with self._get_redis_connection() as redis:
  329. # Use scan to check if any keys exist
  330. async for key in redis.scan_iter(match=pattern, count=1):
  331. return False # Found at least one key
  332. return True # No keys found
  333. except Exception as e:
  334. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  335. return True
  336. async def delete(self, ids: list[str]) -> None:
  337. """Delete specific records from storage by their IDs"""
  338. if not ids:
  339. return
  340. async with self._get_redis_connection() as redis:
  341. pipe = redis.pipeline()
  342. for id in ids:
  343. pipe.delete(f"{self.final_namespace}:{id}")
  344. results = await pipe.execute()
  345. deleted_count = sum(results)
  346. logger.info(
  347. f"[{self.workspace}] Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
  348. )
  349. async def drop(self) -> dict[str, str]:
  350. """Drop the storage by removing all keys under the current namespace.
  351. Returns:
  352. dict[str, str]: Status of the operation with keys 'status' and 'message'
  353. """
  354. async with self._get_redis_connection() as redis:
  355. try:
  356. # Use SCAN to find all keys with the namespace prefix
  357. pattern = f"{self.final_namespace}:*"
  358. cursor = 0
  359. deleted_count = 0
  360. while True:
  361. cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
  362. if keys:
  363. # Delete keys in batches
  364. pipe = redis.pipeline()
  365. for key in keys:
  366. pipe.delete(key)
  367. results = await pipe.execute()
  368. deleted_count += sum(results)
  369. if cursor == 0:
  370. break
  371. logger.info(
  372. f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
  373. )
  374. return {
  375. "status": "success",
  376. "message": f"{deleted_count} keys dropped",
  377. }
  378. except Exception as e:
  379. logger.error(
  380. f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
  381. )
  382. return {"status": "error", "message": str(e)}
  383. async def _migrate_legacy_cache_structure(self):
  384. """Migrate legacy nested cache structure to flattened structure for Redis
  385. Redis already stores data in a flattened way, but we need to check for
  386. legacy keys that might contain nested JSON structures and migrate them.
  387. Early exit if any flattened key is found (indicating migration already done).
  388. """
  389. from lightrag.utils import generate_cache_key
  390. async with self._get_redis_connection() as redis:
  391. # Get all keys for this namespace
  392. keys = await redis.keys(f"{self.final_namespace}:*")
  393. if not keys:
  394. return
  395. # Check if we have any flattened keys already - if so, skip migration
  396. has_flattened_keys = False
  397. keys_to_migrate = []
  398. for key in keys:
  399. # Extract the ID part (after namespace:)
  400. key_id = key.split(":", 1)[1]
  401. # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
  402. if ":" in key_id and len(key_id.split(":")) == 3:
  403. has_flattened_keys = True
  404. break # Early exit - migration already done
  405. # Get the data to check if it's a legacy nested structure
  406. data = await redis.get(key)
  407. if data:
  408. try:
  409. parsed_data = json.loads(data)
  410. # Check if this looks like a legacy cache mode with nested structure
  411. if isinstance(parsed_data, dict) and all(
  412. isinstance(v, dict) and "return" in v
  413. for v in parsed_data.values()
  414. ):
  415. keys_to_migrate.append((key, key_id, parsed_data))
  416. except json.JSONDecodeError:
  417. continue
  418. # If we found any flattened keys, assume migration is already done
  419. if has_flattened_keys:
  420. logger.debug(
  421. f"[{self.workspace}] Found flattened cache keys in {self.namespace}, skipping migration"
  422. )
  423. return
  424. if not keys_to_migrate:
  425. return
  426. # Perform migration
  427. pipe = redis.pipeline()
  428. migration_count = 0
  429. for old_key, mode, nested_data in keys_to_migrate:
  430. # Delete the old key
  431. pipe.delete(old_key)
  432. # Create new flattened keys
  433. for cache_hash, cache_entry in nested_data.items():
  434. cache_type = cache_entry.get("cache_type", "extract")
  435. flattened_key = generate_cache_key(mode, cache_type, cache_hash)
  436. full_key = f"{self.final_namespace}:{flattened_key}"
  437. pipe.set(full_key, json.dumps(cache_entry))
  438. migration_count += 1
  439. await pipe.execute()
  440. if migration_count > 0:
  441. logger.info(
  442. f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure in Redis"
  443. )
  444. @final
  445. @dataclass
  446. class RedisDocStatusStorage(DocStatusStorage):
  447. """Redis implementation of document status storage"""
  448. def __post_init__(self):
  449. # Check for REDIS_WORKSPACE environment variable first (higher priority)
  450. # This allows administrators to force a specific workspace for all Redis storage instances
  451. redis_workspace = os.environ.get("REDIS_WORKSPACE")
  452. if redis_workspace and redis_workspace.strip():
  453. # Use environment variable value, overriding the passed workspace parameter
  454. effective_workspace = redis_workspace.strip()
  455. logger.info(
  456. f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
  457. )
  458. else:
  459. # Use the workspace parameter passed during initialization
  460. effective_workspace = self.workspace
  461. if effective_workspace:
  462. logger.debug(
  463. f"Using passed workspace parameter: '{effective_workspace}'"
  464. )
  465. # Build final_namespace with workspace prefix for data isolation
  466. # Keep original namespace unchanged for type detection logic
  467. if effective_workspace:
  468. self.final_namespace = f"{effective_workspace}_{self.namespace}"
  469. logger.debug(
  470. f"[{self.workspace}] Final namespace with workspace prefix: '{self.namespace}'"
  471. )
  472. else:
  473. # When workspace is empty, final_namespace equals original namespace
  474. self.final_namespace = self.namespace
  475. self.workspace = "_"
  476. logger.debug(
  477. f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
  478. )
  479. self._redis_url = os.environ.get(
  480. "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
  481. )
  482. self._pool = None
  483. self._redis = None
  484. self._initialized = False
  485. try:
  486. # Use shared connection pool
  487. self._pool = RedisConnectionManager.get_pool(self._redis_url)
  488. self._redis = Redis(connection_pool=self._pool)
  489. logger.info(
  490. f"[{self.workspace}] Initialized Redis doc status storage for {self.namespace} using shared connection pool"
  491. )
  492. except Exception as e:
  493. # Clean up on initialization failure
  494. if self._redis_url:
  495. RedisConnectionManager.release_pool(self._redis_url)
  496. logger.error(
  497. f"[{self.workspace}] Failed to initialize Redis doc status storage: {e}"
  498. )
  499. raise
  500. async def initialize(self):
  501. """Initialize Redis connection"""
  502. async with get_data_init_lock():
  503. if self._initialized:
  504. return
  505. try:
  506. async with self._get_redis_connection() as redis:
  507. await redis.ping()
  508. logger.info(
  509. f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
  510. )
  511. self._initialized = True
  512. except Exception as e:
  513. logger.error(
  514. f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
  515. )
  516. # Clean up on connection failure
  517. await self.close()
  518. raise
  519. @asynccontextmanager
  520. async def _get_redis_connection(self):
  521. """Safe context manager for Redis operations."""
  522. if not self._redis:
  523. raise ConnectionError("Redis connection not initialized")
  524. try:
  525. # Use the existing Redis instance with shared pool
  526. yield self._redis
  527. except ConnectionError as e:
  528. logger.error(
  529. f"[{self.workspace}] Redis connection error in doc status {self.namespace}: {e}"
  530. )
  531. raise
  532. except RedisError as e:
  533. logger.error(
  534. f"[{self.workspace}] Redis operation error in doc status {self.namespace}: {e}"
  535. )
  536. raise
  537. except Exception as e:
  538. logger.error(
  539. f"[{self.workspace}] Unexpected error in Redis doc status operation for {self.namespace}: {e}"
  540. )
  541. raise
  542. async def close(self):
  543. """Close the Redis connection and release pool reference to prevent resource leaks."""
  544. if hasattr(self, "_redis") and self._redis:
  545. try:
  546. await self._redis.close()
  547. logger.debug(
  548. f"[{self.workspace}] Closed Redis connection for doc status {self.namespace}"
  549. )
  550. except Exception as e:
  551. logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
  552. finally:
  553. self._redis = None
  554. # Release the pool reference (will auto-close pool if no more references)
  555. if hasattr(self, "_redis_url") and self._redis_url:
  556. RedisConnectionManager.release_pool(self._redis_url)
  557. self._pool = None
  558. logger.debug(
  559. f"[{self.workspace}] Released Redis connection pool reference for doc status {self.namespace}"
  560. )
  561. async def __aenter__(self):
  562. """Support for async context manager."""
  563. return self
  564. async def __aexit__(self, exc_type, exc_val, exc_tb):
  565. """Ensure Redis resources are cleaned up when exiting context."""
  566. await self.close()
  567. async def filter_keys(self, keys: set[str]) -> set[str]:
  568. """Return keys that should be processed (not in storage or not successfully processed)"""
  569. async with self._get_redis_connection() as redis:
  570. pipe = redis.pipeline()
  571. keys_list = list(keys)
  572. for key in keys_list:
  573. pipe.exists(f"{self.final_namespace}:{key}")
  574. results = await pipe.execute()
  575. existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
  576. return set(keys) - existing_ids
  577. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  578. ordered_results: list[dict[str, Any] | None] = []
  579. async with self._get_redis_connection() as redis:
  580. try:
  581. pipe = redis.pipeline()
  582. for id in ids:
  583. pipe.get(f"{self.final_namespace}:{id}")
  584. results = await pipe.execute()
  585. for result_data in results:
  586. if result_data:
  587. try:
  588. ordered_results.append(json.loads(result_data))
  589. except json.JSONDecodeError as e:
  590. logger.error(
  591. f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
  592. )
  593. raise
  594. else:
  595. ordered_results.append(None)
  596. except Exception as e:
  597. logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
  598. raise
  599. return ordered_results
  600. async def get_status_counts(self) -> dict[str, int]:
  601. """Get counts of documents in each status"""
  602. counts = {status.value: 0 for status in DocStatus}
  603. async with self._get_redis_connection() as redis:
  604. try:
  605. # Use SCAN to iterate through all keys in the namespace
  606. cursor = 0
  607. while True:
  608. cursor, keys = await redis.scan(
  609. cursor, match=f"{self.final_namespace}:*", count=1000
  610. )
  611. if keys:
  612. # Get all values in batch
  613. pipe = redis.pipeline()
  614. for key in keys:
  615. pipe.get(key)
  616. values = await pipe.execute()
  617. # Count statuses
  618. for value in values:
  619. if value:
  620. try:
  621. doc_data = json.loads(value)
  622. status = doc_data.get("status")
  623. if status in counts:
  624. counts[status] += 1
  625. except json.JSONDecodeError:
  626. continue
  627. if cursor == 0:
  628. break
  629. except Exception as e:
  630. logger.error(f"[{self.workspace}] Error getting status counts: {e}")
  631. return counts
  632. async def get_docs_by_status(
  633. self, status: DocStatus
  634. ) -> dict[str, DocProcessingStatus]:
  635. """Get all documents with a specific status"""
  636. return await self.get_docs_by_statuses([status])
  637. async def get_docs_by_statuses(
  638. self, statuses: list[DocStatus]
  639. ) -> dict[str, DocProcessingStatus]:
  640. """Get all documents matching any of the given statuses in a single SCAN pass.
  641. Redis has no server-side multi-value filter, so documents must be fetched
  642. and filtered in Python. This override performs a single SCAN + pipeline
  643. GET over the keyspace, filtering against a set of status values. The
  644. previous pattern of N separate get_docs_by_status() calls would do N full
  645. SCANs (one per status), so this reduces keyspace traversal from N passes to one.
  646. """
  647. if not statuses:
  648. return {}
  649. status_values = {s.value for s in statuses}
  650. result = {}
  651. async with self._get_redis_connection() as redis:
  652. try:
  653. cursor = 0
  654. while True:
  655. cursor, keys = await redis.scan(
  656. cursor, match=f"{self.final_namespace}:*", count=1000
  657. )
  658. if keys:
  659. pipe = redis.pipeline()
  660. for key in keys:
  661. pipe.get(key)
  662. values = await pipe.execute()
  663. for key, value in zip(keys, values):
  664. if not value:
  665. continue
  666. try:
  667. doc_data = json.loads(value)
  668. if doc_data.get("status") not in status_values:
  669. continue
  670. doc_id = key.split(":", 1)[1]
  671. data = doc_data.copy()
  672. data.pop("content", None)
  673. if "file_path" not in data:
  674. data["file_path"] = "no-file-path"
  675. if "metadata" not in data:
  676. data["metadata"] = {}
  677. if "error_msg" not in data:
  678. data["error_msg"] = None
  679. result[doc_id] = DocProcessingStatus(**data)
  680. except (json.JSONDecodeError, KeyError) as e:
  681. logger.error(
  682. f"[{self.workspace}] Error processing document {key}: {e}"
  683. )
  684. continue
  685. if cursor == 0:
  686. break
  687. except Exception as e:
  688. logger.error(
  689. f"[{self.workspace}] SCAN interrupted while fetching docs by statuses "
  690. f"— result is incomplete ({len(result)} documents collected): {e!r}"
  691. )
  692. raise
  693. return result
  694. async def get_docs_by_track_id(
  695. self, track_id: str
  696. ) -> dict[str, DocProcessingStatus]:
  697. """Get all documents with a specific track_id"""
  698. result = {}
  699. async with self._get_redis_connection() as redis:
  700. try:
  701. # Use SCAN to iterate through all keys in the namespace
  702. cursor = 0
  703. while True:
  704. cursor, keys = await redis.scan(
  705. cursor, match=f"{self.final_namespace}:*", count=1000
  706. )
  707. if keys:
  708. # Get all values in batch
  709. pipe = redis.pipeline()
  710. for key in keys:
  711. pipe.get(key)
  712. values = await pipe.execute()
  713. # Filter by track_id and create DocProcessingStatus objects
  714. for key, value in zip(keys, values):
  715. if value:
  716. try:
  717. doc_data = json.loads(value)
  718. if doc_data.get("track_id") == track_id:
  719. # Extract document ID from key
  720. doc_id = key.split(":", 1)[1]
  721. # Make a copy of the data to avoid modifying the original
  722. data = doc_data.copy()
  723. # Remove deprecated content field if it exists
  724. data.pop("content", None)
  725. # If file_path is not in data, use document id as file path
  726. if "file_path" not in data:
  727. data["file_path"] = "no-file-path"
  728. # Ensure new fields exist with default values
  729. if "metadata" not in data:
  730. data["metadata"] = {}
  731. if "error_msg" not in data:
  732. data["error_msg"] = None
  733. result[doc_id] = DocProcessingStatus(**data)
  734. except (json.JSONDecodeError, KeyError) as e:
  735. logger.error(
  736. f"[{self.workspace}] Error processing document {key}: {e}"
  737. )
  738. continue
  739. if cursor == 0:
  740. break
  741. except Exception as e:
  742. logger.error(f"[{self.workspace}] Error getting docs by track_id: {e}")
  743. return result
  744. async def index_done_callback(self) -> None:
  745. """Redis handles persistence automatically"""
  746. pass
  747. async def is_empty(self) -> bool:
  748. """Check if the storage is empty for the current workspace and namespace
  749. Returns:
  750. bool: True if storage is empty, False otherwise
  751. """
  752. pattern = f"{self.final_namespace}:*"
  753. try:
  754. async with self._get_redis_connection() as redis:
  755. # Use scan to check if any keys exist
  756. async for key in redis.scan_iter(match=pattern, count=1):
  757. return False # Found at least one key
  758. return True # No keys found
  759. except Exception as e:
  760. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  761. return True
  762. @redis_retry
  763. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  764. """Insert or update document status data"""
  765. if not data:
  766. return
  767. logger.debug(
  768. f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
  769. )
  770. async with self._get_redis_connection() as redis:
  771. try:
  772. # Ensure chunks_list field exists for new documents
  773. for i, (doc_id, doc_data) in enumerate(data.items(), start=1):
  774. if "chunks_list" not in doc_data:
  775. doc_data["chunks_list"] = []
  776. await _cooperative_yield(i)
  777. pipe = redis.pipeline()
  778. for i, (k, v) in enumerate(data.items(), start=1):
  779. pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
  780. await _cooperative_yield(i)
  781. await pipe.execute()
  782. except json.JSONDecodeError as e:
  783. logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
  784. raise
  785. @redis_retry
  786. async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
  787. async with self._get_redis_connection() as redis:
  788. try:
  789. data = await redis.get(f"{self.final_namespace}:{id}")
  790. return json.loads(data) if data else None
  791. except json.JSONDecodeError as e:
  792. logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
  793. raise
  794. async def delete(self, doc_ids: list[str]) -> None:
  795. """Delete specific records from storage by their IDs"""
  796. if not doc_ids:
  797. return
  798. async with self._get_redis_connection() as redis:
  799. pipe = redis.pipeline()
  800. for doc_id in doc_ids:
  801. pipe.delete(f"{self.final_namespace}:{doc_id}")
  802. results = await pipe.execute()
  803. deleted_count = sum(results)
  804. logger.info(
  805. f"[{self.workspace}] Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
  806. )
  807. async def get_docs_paginated(
  808. self,
  809. status_filter: DocStatus | None = None,
  810. status_filters: list[DocStatus] | None = None,
  811. page: int = 1,
  812. page_size: int = 50,
  813. sort_field: str = "updated_at",
  814. sort_direction: str = "desc",
  815. ) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
  816. """Get documents with pagination support
  817. Args:
  818. status_filter: Filter by document status, None for all statuses
  819. page: Page number (1-based)
  820. page_size: Number of documents per page (10-200)
  821. sort_field: Field to sort by ('created_at', 'updated_at', 'id')
  822. sort_direction: Sort direction ('asc' or 'desc')
  823. Returns:
  824. Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count)
  825. """
  826. status_filter_values = self.resolve_status_filter_values(
  827. status_filter=status_filter,
  828. status_filters=status_filters,
  829. )
  830. # Validate parameters
  831. if page < 1:
  832. page = 1
  833. if page_size < 10:
  834. page_size = 10
  835. elif page_size > 200:
  836. page_size = 200
  837. if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
  838. sort_field = "updated_at"
  839. if sort_direction.lower() not in ["asc", "desc"]:
  840. sort_direction = "desc"
  841. # For Redis, we need to load all data and sort/filter in memory
  842. all_docs = []
  843. total_count = 0
  844. async with self._get_redis_connection() as redis:
  845. try:
  846. # Use SCAN to iterate through all keys in the namespace
  847. cursor = 0
  848. while True:
  849. cursor, keys = await redis.scan(
  850. cursor, match=f"{self.final_namespace}:*", count=1000
  851. )
  852. if keys:
  853. # Get all values in batch
  854. pipe = redis.pipeline()
  855. for key in keys:
  856. pipe.get(key)
  857. values = await pipe.execute()
  858. # Process documents
  859. for key, value in zip(keys, values):
  860. if value:
  861. try:
  862. doc_data = json.loads(value)
  863. # Apply status filter
  864. if (
  865. status_filter_values is not None
  866. and doc_data.get("status")
  867. not in status_filter_values
  868. ):
  869. continue
  870. # Extract document ID from key
  871. doc_id = key.split(":", 1)[1]
  872. # Prepare document data
  873. data = doc_data.copy()
  874. data.pop("content", None)
  875. if "file_path" not in data:
  876. data["file_path"] = "no-file-path"
  877. if "metadata" not in data:
  878. data["metadata"] = {}
  879. if "error_msg" not in data:
  880. data["error_msg"] = None
  881. # Calculate sort key for sorting (but don't add to data)
  882. if sort_field == "id":
  883. sort_key = doc_id
  884. elif sort_field == "file_path":
  885. # Use pinyin sorting for file_path field to support Chinese characters
  886. file_path_value = data.get(sort_field, "")
  887. sort_key = get_pinyin_sort_key(file_path_value)
  888. else:
  889. sort_key = data.get(sort_field, "")
  890. doc_status = DocProcessingStatus(**data)
  891. all_docs.append((doc_id, doc_status, sort_key))
  892. except (json.JSONDecodeError, KeyError) as e:
  893. logger.error(
  894. f"[{self.workspace}] Error processing document {key}: {e}"
  895. )
  896. continue
  897. if cursor == 0:
  898. break
  899. except Exception as e:
  900. logger.error(f"[{self.workspace}] Error getting paginated docs: {e}")
  901. return [], 0
  902. # Sort documents using the separate sort key
  903. reverse_sort = sort_direction.lower() == "desc"
  904. all_docs.sort(key=lambda x: x[2], reverse=reverse_sort)
  905. # Remove sort key from tuples and keep only (doc_id, doc_status)
  906. all_docs = [(doc_id, doc_status) for doc_id, doc_status, _ in all_docs]
  907. total_count = len(all_docs)
  908. # Apply pagination
  909. start_idx = (page - 1) * page_size
  910. end_idx = start_idx + page_size
  911. paginated_docs = all_docs[start_idx:end_idx]
  912. return paginated_docs, total_count
  913. async def get_all_status_counts(self) -> dict[str, int]:
  914. """Get counts of documents in each status for all documents
  915. Returns:
  916. Dictionary mapping status names to counts, including 'all' field
  917. """
  918. counts = await self.get_status_counts()
  919. # Add 'all' field with total count
  920. total_count = sum(counts.values())
  921. counts["all"] = total_count
  922. return counts
  923. async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
  924. """Get document by file path
  925. Args:
  926. file_path: The file path to search for
  927. Returns:
  928. Union[dict[str, Any], None]: Document data if found, None otherwise
  929. Returns the same format as get_by_id method
  930. """
  931. async with self._get_redis_connection() as redis:
  932. try:
  933. # Use SCAN to iterate through all keys in the namespace
  934. cursor = 0
  935. while True:
  936. cursor, keys = await redis.scan(
  937. cursor, match=f"{self.final_namespace}:*", count=1000
  938. )
  939. if keys:
  940. # Get all values in batch
  941. pipe = redis.pipeline()
  942. for key in keys:
  943. pipe.get(key)
  944. values = await pipe.execute()
  945. # Check each document for matching file_path
  946. for value in values:
  947. if value:
  948. try:
  949. doc_data = json.loads(value)
  950. if doc_data.get("file_path") == file_path:
  951. return doc_data
  952. except json.JSONDecodeError as e:
  953. logger.error(
  954. f"[{self.workspace}] JSON decode error in get_doc_by_file_path: {e}"
  955. )
  956. continue
  957. if cursor == 0:
  958. break
  959. return None
  960. except Exception as e:
  961. logger.error(f"[{self.workspace}] Error in get_doc_by_file_path: {e}")
  962. return None
  963. async def get_doc_by_file_basename(
  964. self, basename: str
  965. ) -> Union[tuple[str, dict[str, Any]], None]:
  966. """Find an existing record whose canonical basename matches.
  967. The caller is responsible for passing an already-canonical basename.
  968. Stored ``file_path`` values are canonicalized by the business layer, so
  969. this lookup intentionally performs an exact match only.
  970. """
  971. if not basename:
  972. return None
  973. if basename == "unknown_source":
  974. return None
  975. async with self._get_redis_connection() as redis:
  976. try:
  977. cursor = 0
  978. while True:
  979. cursor, keys = await redis.scan(
  980. cursor, match=f"{self.final_namespace}:*", count=1000
  981. )
  982. if keys:
  983. pipe = redis.pipeline()
  984. for key in keys:
  985. pipe.get(key)
  986. values = await pipe.execute()
  987. for key, value in zip(keys, values):
  988. if not value:
  989. continue
  990. try:
  991. doc_data = json.loads(value)
  992. except json.JSONDecodeError as e:
  993. logger.error(
  994. f"[{self.workspace}] JSON decode error in get_doc_by_file_basename: {e}"
  995. )
  996. continue
  997. if doc_data.get("file_path") == basename:
  998. doc_id = key.split(":", 1)[1]
  999. return doc_id, doc_data
  1000. if cursor == 0:
  1001. break
  1002. return None
  1003. except Exception as e:
  1004. logger.error(
  1005. f"[{self.workspace}] Error in get_doc_by_file_basename: {e}"
  1006. )
  1007. return None
  1008. async def get_doc_by_content_hash(
  1009. self, content_hash: str
  1010. ) -> Union[tuple[str, dict[str, Any]], None]:
  1011. """Find an existing record whose content_hash field matches."""
  1012. if not content_hash:
  1013. return None
  1014. async with self._get_redis_connection() as redis:
  1015. try:
  1016. cursor = 0
  1017. while True:
  1018. cursor, keys = await redis.scan(
  1019. cursor, match=f"{self.final_namespace}:*", count=1000
  1020. )
  1021. if keys:
  1022. pipe = redis.pipeline()
  1023. for key in keys:
  1024. pipe.get(key)
  1025. values = await pipe.execute()
  1026. for key, value in zip(keys, values):
  1027. if not value:
  1028. continue
  1029. try:
  1030. doc_data = json.loads(value)
  1031. except json.JSONDecodeError as e:
  1032. logger.error(
  1033. f"[{self.workspace}] JSON decode error in get_doc_by_content_hash: {e}"
  1034. )
  1035. continue
  1036. if doc_data.get("content_hash") == content_hash:
  1037. doc_id = key.split(":", 1)[1]
  1038. return doc_id, doc_data
  1039. if cursor == 0:
  1040. break
  1041. return None
  1042. except Exception as e:
  1043. logger.error(
  1044. f"[{self.workspace}] Error in get_doc_by_content_hash: {e}"
  1045. )
  1046. return None
  1047. async def drop(self) -> dict[str, str]:
  1048. """Drop all document status data from storage and clean up resources"""
  1049. try:
  1050. async with self._get_redis_connection() as redis:
  1051. # Use SCAN to find all keys with the namespace prefix
  1052. pattern = f"{self.final_namespace}:*"
  1053. cursor = 0
  1054. deleted_count = 0
  1055. while True:
  1056. cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
  1057. if keys:
  1058. # Delete keys in batches
  1059. pipe = redis.pipeline()
  1060. for key in keys:
  1061. pipe.delete(key)
  1062. results = await pipe.execute()
  1063. deleted_count += sum(results)
  1064. if cursor == 0:
  1065. break
  1066. logger.info(
  1067. f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
  1068. )
  1069. return {"status": "success", "message": "data dropped"}
  1070. except Exception as e:
  1071. logger.error(
  1072. f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
  1073. )
  1074. return {"status": "error", "message": str(e)}