storage_migrations.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. """Storage data migration helpers for :class:`LightRAG`.
  2. Mixed into LightRAG and runs once at startup (``initialize_storages`` →
  3. ``check_and_migrate_data``) to upgrade legacy data layouts:
  4. - Backfill ``full_entities`` / ``full_relations`` from the graph + doc_status
  5. history when those KV stores are empty (entity-relation migration).
  6. - Rebuild ``entity_chunks`` / ``relation_chunks`` indexes by walking nodes/
  7. edges in the graph storage when they are empty
  8. (chunk-tracking migration).
  9. """
  10. from __future__ import annotations
  11. from lightrag.base import DocStatus
  12. from lightrag.constants import GRAPH_FIELD_SEP
  13. from lightrag.kg.shared_storage import get_data_init_lock
  14. from lightrag.utils import logger, make_relation_chunk_key
  15. class _StorageMigrationMixin:
  16. """Mixin that owns one-shot data migrations on :class:`LightRAG`.
  17. Mixed into LightRAG only. Relies on attributes that the main class
  18. initializes in ``__post_init__`` (``doc_status``, ``full_entities``,
  19. ``full_relations``, ``chunk_entity_relation_graph``, ``entity_chunks``,
  20. ``relation_chunks``).
  21. """
  22. async def check_and_migrate_data(self):
  23. """Check if data migration is needed and perform migration if necessary"""
  24. async with get_data_init_lock():
  25. try:
  26. # Check if migration is needed:
  27. # 1. chunk_entity_relation_graph has entities and relations (count > 0)
  28. # 2. full_entities and full_relations are empty
  29. # Get all entity labels from graph
  30. all_entity_labels = (
  31. await self.chunk_entity_relation_graph.get_all_labels()
  32. )
  33. if not all_entity_labels:
  34. logger.debug("No entities found in graph, skipping migration check")
  35. return
  36. try:
  37. # Initialize chunk tracking storage after migration
  38. await self._migrate_chunk_tracking_storage()
  39. except Exception as e:
  40. logger.error(f"Error during chunk_tracking migration: {e}")
  41. raise e
  42. # Check if full_entities and full_relations are empty
  43. # Get all processed documents to check their entity/relation data
  44. try:
  45. processed_docs = await self.doc_status.get_docs_by_status(
  46. DocStatus.PROCESSED
  47. )
  48. if not processed_docs:
  49. logger.debug("No processed documents found, skipping migration")
  50. return
  51. # Check first few documents to see if they have full_entities/full_relations data
  52. migration_needed = True
  53. checked_count = 0
  54. max_check = min(5, len(processed_docs)) # Check up to 5 documents
  55. for doc_id in list(processed_docs.keys())[:max_check]:
  56. checked_count += 1
  57. entity_data = await self.full_entities.get_by_id(doc_id)
  58. relation_data = await self.full_relations.get_by_id(doc_id)
  59. if entity_data or relation_data:
  60. migration_needed = False
  61. break
  62. if not migration_needed:
  63. logger.debug(
  64. "Full entities/relations data already exists, no migration needed"
  65. )
  66. return
  67. logger.info(
  68. f"Data migration needed: found {len(all_entity_labels)} entities in graph but no full_entities/full_relations data"
  69. )
  70. # Perform migration
  71. await self._migrate_entity_relation_data(processed_docs)
  72. except Exception as e:
  73. logger.error(f"Error during migration check: {e}")
  74. raise e
  75. except Exception as e:
  76. logger.error(f"Error in data migration check: {e}")
  77. raise e
  78. async def _migrate_entity_relation_data(self, processed_docs: dict):
  79. """Migrate existing entity and relation data to full_entities and full_relations storage"""
  80. logger.info(f"Starting data migration for {len(processed_docs)} documents")
  81. # Create mapping from chunk_id to doc_id
  82. chunk_to_doc = {}
  83. for doc_id, doc_status in processed_docs.items():
  84. chunk_ids = (
  85. doc_status.chunks_list
  86. if hasattr(doc_status, "chunks_list") and doc_status.chunks_list
  87. else []
  88. )
  89. for chunk_id in chunk_ids:
  90. chunk_to_doc[chunk_id] = doc_id
  91. # Initialize document entity and relation mappings
  92. doc_entities = {} # doc_id -> set of entity_names
  93. doc_relations = {} # doc_id -> set of relation_pairs (as tuples)
  94. # Get all nodes and edges from graph
  95. all_nodes = await self.chunk_entity_relation_graph.get_all_nodes()
  96. all_edges = await self.chunk_entity_relation_graph.get_all_edges()
  97. # Process all nodes once
  98. for node in all_nodes:
  99. if "source_id" in node:
  100. entity_id = node.get("entity_id") or node.get("id")
  101. if not entity_id:
  102. continue
  103. # Get chunk IDs from source_id
  104. source_ids = node["source_id"].split(GRAPH_FIELD_SEP)
  105. # Find which documents this entity belongs to
  106. for chunk_id in source_ids:
  107. doc_id = chunk_to_doc.get(chunk_id)
  108. if doc_id:
  109. if doc_id not in doc_entities:
  110. doc_entities[doc_id] = set()
  111. doc_entities[doc_id].add(entity_id)
  112. # Process all edges once
  113. for edge in all_edges:
  114. if "source_id" in edge:
  115. src = edge.get("source")
  116. tgt = edge.get("target")
  117. if not src or not tgt:
  118. continue
  119. # Get chunk IDs from source_id
  120. source_ids = edge["source_id"].split(GRAPH_FIELD_SEP)
  121. # Find which documents this relation belongs to
  122. for chunk_id in source_ids:
  123. doc_id = chunk_to_doc.get(chunk_id)
  124. if doc_id:
  125. if doc_id not in doc_relations:
  126. doc_relations[doc_id] = set()
  127. # Use tuple for set operations, convert to list later
  128. doc_relations[doc_id].add(tuple(sorted((src, tgt))))
  129. # Store the results in full_entities and full_relations
  130. migration_count = 0
  131. # Store entities
  132. if doc_entities:
  133. entities_data = {}
  134. for doc_id, entity_set in doc_entities.items():
  135. entities_data[doc_id] = {
  136. "entity_names": list(entity_set),
  137. "count": len(entity_set),
  138. }
  139. await self.full_entities.upsert(entities_data)
  140. # Store relations
  141. if doc_relations:
  142. relations_data = {}
  143. for doc_id, relation_set in doc_relations.items():
  144. # Convert tuples back to lists
  145. relations_data[doc_id] = {
  146. "relation_pairs": [list(pair) for pair in relation_set],
  147. "count": len(relation_set),
  148. }
  149. await self.full_relations.upsert(relations_data)
  150. migration_count = len(
  151. set(list(doc_entities.keys()) + list(doc_relations.keys()))
  152. )
  153. # Persist the migrated data
  154. await self.full_entities.index_done_callback()
  155. await self.full_relations.index_done_callback()
  156. logger.info(
  157. f"Data migration completed: migrated {migration_count} documents with entities/relations"
  158. )
  159. async def _migrate_chunk_tracking_storage(self) -> None:
  160. """Ensure entity/relation chunk tracking KV stores exist and are seeded."""
  161. if not self.entity_chunks or not self.relation_chunks:
  162. return
  163. need_entity_migration = False
  164. need_relation_migration = False
  165. try:
  166. need_entity_migration = await self.entity_chunks.is_empty()
  167. except Exception as exc: # pragma: no cover - defensive logging
  168. logger.error(f"Failed to check entity chunks storage: {exc}")
  169. raise exc
  170. try:
  171. need_relation_migration = await self.relation_chunks.is_empty()
  172. except Exception as exc: # pragma: no cover - defensive logging
  173. logger.error(f"Failed to check relation chunks storage: {exc}")
  174. raise exc
  175. if not need_entity_migration and not need_relation_migration:
  176. return
  177. BATCH_SIZE = 500 # Process 500 records per batch
  178. if need_entity_migration:
  179. try:
  180. nodes = await self.chunk_entity_relation_graph.get_all_nodes()
  181. except Exception as exc:
  182. logger.error(f"Failed to fetch nodes for chunk migration: {exc}")
  183. nodes = []
  184. logger.info(f"Starting chunk_tracking data migration: {len(nodes)} nodes")
  185. # Process nodes in batches
  186. total_nodes = len(nodes)
  187. total_batches = (total_nodes + BATCH_SIZE - 1) // BATCH_SIZE
  188. total_migrated = 0
  189. for batch_idx in range(total_batches):
  190. start_idx = batch_idx * BATCH_SIZE
  191. end_idx = min((batch_idx + 1) * BATCH_SIZE, total_nodes)
  192. batch_nodes = nodes[start_idx:end_idx]
  193. upsert_payload: dict[str, dict[str, object]] = {}
  194. for node in batch_nodes:
  195. entity_id = node.get("entity_id") or node.get("id")
  196. if not entity_id:
  197. continue
  198. raw_source = node.get("source_id") or ""
  199. chunk_ids = [
  200. chunk_id
  201. for chunk_id in raw_source.split(GRAPH_FIELD_SEP)
  202. if chunk_id
  203. ]
  204. if not chunk_ids:
  205. continue
  206. upsert_payload[entity_id] = {
  207. "chunk_ids": chunk_ids,
  208. "count": len(chunk_ids),
  209. }
  210. if upsert_payload:
  211. await self.entity_chunks.upsert(upsert_payload)
  212. total_migrated += len(upsert_payload)
  213. logger.info(
  214. f"Processed entity batch {batch_idx + 1}/{total_batches}: {len(upsert_payload)} records (total: {total_migrated}/{total_nodes})"
  215. )
  216. if total_migrated > 0:
  217. # Persist entity_chunks data to disk
  218. await self.entity_chunks.index_done_callback()
  219. logger.info(
  220. f"Entity chunk_tracking migration completed: {total_migrated} records persisted"
  221. )
  222. if need_relation_migration:
  223. try:
  224. edges = await self.chunk_entity_relation_graph.get_all_edges()
  225. except Exception as exc:
  226. logger.error(f"Failed to fetch edges for chunk migration: {exc}")
  227. edges = []
  228. logger.info(f"Starting chunk_tracking data migration: {len(edges)} edges")
  229. # Process edges in batches
  230. total_edges = len(edges)
  231. total_batches = (total_edges + BATCH_SIZE - 1) // BATCH_SIZE
  232. total_migrated = 0
  233. for batch_idx in range(total_batches):
  234. start_idx = batch_idx * BATCH_SIZE
  235. end_idx = min((batch_idx + 1) * BATCH_SIZE, total_edges)
  236. batch_edges = edges[start_idx:end_idx]
  237. upsert_payload: dict[str, dict[str, object]] = {}
  238. for edge in batch_edges:
  239. src = edge.get("source") or edge.get("src_id") or edge.get("src")
  240. tgt = edge.get("target") or edge.get("tgt_id") or edge.get("tgt")
  241. if not src or not tgt:
  242. continue
  243. raw_source = edge.get("source_id") or ""
  244. chunk_ids = [
  245. chunk_id
  246. for chunk_id in raw_source.split(GRAPH_FIELD_SEP)
  247. if chunk_id
  248. ]
  249. if not chunk_ids:
  250. continue
  251. storage_key = make_relation_chunk_key(src, tgt)
  252. upsert_payload[storage_key] = {
  253. "chunk_ids": chunk_ids,
  254. "count": len(chunk_ids),
  255. }
  256. if upsert_payload:
  257. await self.relation_chunks.upsert(upsert_payload)
  258. total_migrated += len(upsert_payload)
  259. logger.info(
  260. f"Processed relation batch {batch_idx + 1}/{total_batches}: {len(upsert_payload)} records (total: {total_migrated}/{total_edges})"
  261. )
  262. if total_migrated > 0:
  263. # Persist relation_chunks data to disk
  264. await self.relation_chunks.index_done_callback()
  265. logger.info(
  266. f"Relation chunk_tracking migration completed: {total_migrated} records persisted"
  267. )