| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030 |
- import asyncio
- import os
- from typing import Any, final, Optional, Dict
- from dataclasses import dataclass, fields
- import numpy as np
- from lightrag.utils import logger, compute_mdhash_id, _cooperative_yield
- from ..base import BaseVectorStorage
- from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
- from ..kg.shared_storage import get_data_init_lock, get_namespace_lock
- import pipmaster as pm
# pymilvus is installed on demand the first time this module is imported.
# This must stay ahead of the `from pymilvus import ...` line below.
# NOTE(review): only the package name is passed to is_installed(), so this
# presumably checks presence only. An already-installed pymilvus older than
# 2.6.2 would not be upgraded here; confirm against pipmaster's semantics.
if not pm.is_installed("pymilvus"):
    pm.install("pymilvus>=2.6.2")
- import configparser
- from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
- from packaging import version
# Optional INI settings. The [milvus] section supplies connection defaults
# whenever the matching MILVUS_* environment variable is unset.
# ConfigParser.read() silently skips files that do not exist.
config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")
- @dataclass
- class _PendingVectorDoc:
- """Buffered vector upsert waiting for embedding and/or bulk flush."""
- source: dict[str, Any]
- content: str
- vector: list[float] | None = None
# Vector index types accepted by MilvusIndexConfig (MILVUS_INDEX_TYPE).
SUPPORTED_INDEX_TYPES = {
    "AUTOINDEX",
    # HNSW family
    "HNSW", "HNSW_SQ", "HNSW_PQ", "HNSW_PRQ",
    # IVF family
    "IVF_FLAT", "IVF_SQ8", "IVF_PQ",
    # Other index types
    "DISKANN", "SCANN",
}

# Distance metrics accepted by MilvusIndexConfig (MILVUS_METRIC_TYPE).
SUPPORTED_METRIC_TYPES = {"COSINE", "L2", "IP"}

# Quantization types accepted for HNSW_SQ (MILVUS_HNSW_SQ_TYPE).
SUPPORTED_SQ_TYPES = {"SQ4U", "SQ6", "SQ8", "BF16", "FP16"}

# Refinement precisions accepted for HNSW_SQ refine (MILVUS_HNSW_SQ_REFINE_TYPE).
SUPPORTED_REFINE_TYPES = {"SQ6", "SQ8", "BF16", "FP16", "FP32"}

# Minimum Milvus server version per index type. Types not listed here have
# no version requirement.
# Important: HNSW_SQ was first introduced in Milvus 2.6.8 (not 2.5); 2.6.8+
# supports sq_types such as SQ4U, SQ6, SQ8, BF16, FP16.
INDEX_VERSION_REQUIREMENTS = {"HNSW_SQ": "2.6.8"}
- def _get_env_bool(key: str, default: bool = False) -> bool:
- """Parse environment variable as boolean"""
- val = os.environ.get(key, "").lower()
- if val in ("true", "1", "yes", "on"):
- return True
- elif val in ("false", "0", "no", "off"):
- return False
- return default
- def _get_env_int(key: str, default: int) -> int:
- """Parse environment variable as integer"""
- val = os.environ.get(key, "")
- if val:
- try:
- return int(val)
- except ValueError:
- logger.warning(
- f"Invalid integer value for {key}: {val}, using default {default}"
- )
- return default
@dataclass
class MilvusIndexConfig:
    """
    Milvus vector index configuration.

    Every field may be passed at construction time. Fields left unset are
    read from ``MILVUS_*`` environment variables and otherwise fall back to
    built-in defaults, so initialization parameters take precedence over
    environment variables. Name-valued options (index type, metric type,
    SQ type, refine type) are stripped and upper-cased whatever their
    source. The resolved configuration is validated in ``__post_init__``.

    Raises:
        ValueError: From ``__post_init__`` when a value is unsupported or
            out of range.
    """

    # Index-type families that decide which parameters apply. These are plain
    # class attributes (no annotation), so they are NOT dataclass fields.
    _HNSW_FAMILY = ("HNSW", "HNSW_SQ", "HNSW_PQ", "HNSW_PRQ")
    _IVF_FAMILY = ("IVF_FLAT", "IVF_SQ8", "IVF_PQ")

    # Base configuration
    index_type: Optional[str] = None  # env MILVUS_INDEX_TYPE, default "AUTOINDEX"
    metric_type: Optional[str] = None  # env MILVUS_METRIC_TYPE, default "COSINE"

    # HNSW series parameters
    hnsw_m: Optional[int] = None  # env MILVUS_HNSW_M, default 16
    hnsw_ef_construction: Optional[int] = None  # env MILVUS_HNSW_EF_CONSTRUCTION, default 360
    hnsw_ef: Optional[int] = None  # env MILVUS_HNSW_EF, default 200 (search-time)

    # HNSW_SQ specific parameters
    sq_type: Optional[str] = None  # env MILVUS_HNSW_SQ_TYPE, default "SQ8"
    sq_refine: Optional[bool] = None  # env MILVUS_HNSW_SQ_REFINE, default False
    sq_refine_type: Optional[str] = None  # env MILVUS_HNSW_SQ_REFINE_TYPE, default "FP32"
    sq_refine_k: Optional[int] = None  # env MILVUS_HNSW_SQ_REFINE_K, default 10 (search-time)

    # IVF series parameters
    ivf_nlist: Optional[int] = None  # env MILVUS_IVF_NLIST, default 1024
    ivf_nprobe: Optional[int] = None  # env MILVUS_IVF_NPROBE, default 16 (search-time)

    @staticmethod
    def _resolve_name(value: Optional[str], env_key: str, default: str) -> str:
        """Resolve a name-valued option: explicit value, then env var, then default.

        Empty or whitespace-only candidates count as "not set". An empty
        ``MILVUS_INDEX_TYPE=`` entry therefore falls back to the default
        instead of failing validation with an empty name. The chosen value
        is stripped and upper-cased.
        """
        for candidate in (value, os.environ.get(env_key)):
            if candidate is not None and str(candidate).strip():
                return str(candidate).strip().upper()
        return default.strip().upper()

    def __post_init__(self):
        """Fill unset fields from the environment, normalize them, then validate."""
        self.index_type = self._resolve_name(
            self.index_type, "MILVUS_INDEX_TYPE", "AUTOINDEX"
        )
        self.metric_type = self._resolve_name(
            self.metric_type, "MILVUS_METRIC_TYPE", "COSINE"
        )

        # HNSW parameters (defaults: M=16, efConstruction=360, ef=200)
        if self.hnsw_m is None:
            self.hnsw_m = _get_env_int("MILVUS_HNSW_M", 16)
        if self.hnsw_ef_construction is None:
            self.hnsw_ef_construction = _get_env_int("MILVUS_HNSW_EF_CONSTRUCTION", 360)
        if self.hnsw_ef is None:
            self.hnsw_ef = _get_env_int("MILVUS_HNSW_EF", 200)

        # HNSW_SQ parameters. The names are normalized even when passed
        # explicitly, the same way index_type / metric_type are, so "sq8" is
        # accepted as "SQ8".
        self.sq_type = self._resolve_name(self.sq_type, "MILVUS_HNSW_SQ_TYPE", "SQ8")
        if self.sq_refine is None:
            self.sq_refine = _get_env_bool("MILVUS_HNSW_SQ_REFINE", False)
        self.sq_refine_type = self._resolve_name(
            self.sq_refine_type, "MILVUS_HNSW_SQ_REFINE_TYPE", "FP32"
        )
        if self.sq_refine_k is None:
            self.sq_refine_k = _get_env_int("MILVUS_HNSW_SQ_REFINE_K", 10)

        # IVF parameters
        if self.ivf_nlist is None:
            self.ivf_nlist = _get_env_int("MILVUS_IVF_NLIST", 1024)
        if self.ivf_nprobe is None:
            self.ivf_nprobe = _get_env_int("MILVUS_IVF_NPROBE", 16)

        self._validate()

    def _validate(self):
        """Validate the resolved configuration.

        Raises:
            ValueError: On an unsupported index, metric, SQ or refine type, or
                on a numeric parameter outside its accepted range.
        """
        if self.index_type not in SUPPORTED_INDEX_TYPES:
            raise ValueError(
                f"Unsupported index type: {self.index_type}. "
                f"Supported: {SUPPORTED_INDEX_TYPES}"
            )
        if self.metric_type not in SUPPORTED_METRIC_TYPES:
            raise ValueError(
                f"Unsupported metric type: {self.metric_type}. "
                f"Supported: {SUPPORTED_METRIC_TYPES}"
            )
        if self.index_type == "HNSW_SQ":
            if self.sq_type not in SUPPORTED_SQ_TYPES:
                raise ValueError(
                    f"Unsupported sq_type: {self.sq_type}. "
                    f"Supported: {SUPPORTED_SQ_TYPES}"
                )
            if self.sq_refine and self.sq_refine_type not in SUPPORTED_REFINE_TYPES:
                raise ValueError(
                    f"Unsupported refine_type: {self.sq_refine_type}. "
                    f"Supported: {SUPPORTED_REFINE_TYPES}"
                )
            if self.sq_refine and self.sq_refine_k < 1:
                raise ValueError(f"sq_refine_k must be >= 1, got {self.sq_refine_k}")

        # Build-parameter range validation (checked for every index type)
        if not (2 <= self.hnsw_m <= 2048):
            raise ValueError(f"hnsw_m must be in [2, 2048], got {self.hnsw_m}")
        if self.hnsw_ef_construction < 1:
            raise ValueError(
                f"hnsw_ef_construction must be >= 1, got {self.hnsw_ef_construction}"
            )
        if self.ivf_nlist < 1 or self.ivf_nlist > 65536:
            raise ValueError(f"ivf_nlist must be in [1, 65536], got {self.ivf_nlist}")

        # Search-time parameters are only checked for index types that actually
        # send them (see build_search_params), so unrelated env settings cannot
        # block startup.
        if self.index_type in self._HNSW_FAMILY and self.hnsw_ef < 1:
            raise ValueError(f"hnsw_ef must be >= 1, got {self.hnsw_ef}")
        if self.index_type in self._IVF_FAMILY and self.ivf_nprobe < 1:
            raise ValueError(f"ivf_nprobe must be >= 1, got {self.ivf_nprobe}")

    def validate_milvus_version(self, server_version: str) -> None:
        """
        Validate that the Milvus server version supports the configured index type.

        Only index types listed in ``INDEX_VERSION_REQUIREMENTS`` are checked.
        For every other type the version string is not parsed at all, so an
        unusual server version cannot break e.g. AUTOINDEX. When a requirement
        exists but the version cannot be parsed, a warning is logged and the
        check is skipped; the server itself rejects unsupported index types.

        Args:
            server_version: Milvus server version string (e.g. "2.6.9", "v2.6.9",
                "2.6.9-dev").

        Raises:
            ValueError: The parsed version is below the index type's requirement.
        """
        required = INDEX_VERSION_REQUIREMENTS.get(self.index_type)
        if required is not None:
            # Drop build suffixes such as "-dev" before parsing.
            try:
                current_ver = version.parse(server_version.split("-")[0])
            except version.InvalidVersion:
                logger.warning(
                    f"Could not parse Milvus server version '{server_version}'; "
                    f"skipping the {required}+ requirement check for {self.index_type}"
                )
                return
            if current_ver < version.parse(required):
                raise ValueError(
                    f"{self.index_type} requires Milvus {required}+, "
                    f"current version: {server_version}"
                )

        logger.info(
            f"Milvus version {server_version} validated for index type "
            f"{self.index_type}"
            + (f" with sq_type {self.sq_type}" if self.index_type == "HNSW_SQ" else "")
        )

    def build_index_params(self, index_params, field_name: str = "vector"):
        """
        Build pymilvus index parameters for the vector field.

        Args:
            index_params: IndexParams instance (from the compatibility helper or
                client.prepare_index_params()), or None when unavailable.
            field_name: Vector field name.

        Returns:
            The IndexParams object with the vector index added. When
            *index_params* is None and the index type is AUTOINDEX, a plain dict
            (field_name / index_type / metric_type / params) is returned instead
            for the direct-API fallback.

        Raises:
            RuntimeError: *index_params* is None for a non-AUTOINDEX type.
        """
        if index_params is None:
            if self.index_type == "AUTOINDEX":
                logger.info(
                    "Using AUTOINDEX with direct API fallback because IndexParams is unavailable"
                )
                return {
                    "field_name": field_name,
                    "index_type": self.index_type,
                    "metric_type": self.metric_type,
                    "params": {},
                }
            raise RuntimeError(
                f"IndexParams not available but required for index type "
                f"'{self.index_type}'. Ensure pymilvus is installed correctly."
            )

        params: Dict[str, Any] = {}
        if self.index_type in self._HNSW_FAMILY:
            params["M"] = self.hnsw_m
            params["efConstruction"] = self.hnsw_ef_construction
            if self.index_type == "HNSW_SQ":
                params["sq_type"] = self.sq_type
                if self.sq_refine:
                    params["refine"] = True
                    params["refine_type"] = self.sq_refine_type
        elif self.index_type in self._IVF_FAMILY:
            params["nlist"] = self.ivf_nlist
        # AUTOINDEX / DISKANN / SCANN: no extra build params are passed, so
        # server-side defaults apply.

        index_params.add_index(
            field_name=field_name,
            index_type=self.index_type,
            metric_type=self.metric_type,
            params=params,
        )
        logger.info(
            f"Milvus index configured: type={self.index_type}, "
            f"metric={self.metric_type}, params={params}"
        )
        return index_params

    def build_search_params(self) -> Dict[str, Any]:
        """
        Build search parameters for the configured index type.

        Returns:
            ``{"params": {...}}`` with ``ef`` (+ ``refine_k`` for refined
            HNSW_SQ) or ``nprobe`` for IVF types; ``{}`` for the other types.
        """
        search_params: Dict[str, Any] = {}
        if self.index_type in self._HNSW_FAMILY:
            search_params["ef"] = self.hnsw_ef
            if self.index_type == "HNSW_SQ" and self.sq_refine:
                search_params["refine_k"] = self.sq_refine_k
        elif self.index_type in self._IVF_FAMILY:
            search_params["nprobe"] = self.ivf_nprobe
        return {"params": search_params} if search_params else {}

    @classmethod
    def get_config_field_names(cls) -> set:
        """Return the names of all dataclass configuration fields.

        This is the single source of truth for configuration parameter names,
        so callers can extract configuration from kwargs without keeping a
        duplicate hardcoded list.
        """
        return {f.name for f in fields(cls)}

    def to_dict(self) -> Dict[str, Any]:
        """Export the configuration as a dictionary (for logging/debugging).

        Options irrelevant to the configured index type are reported as None.
        """
        is_sq = self.index_type == "HNSW_SQ"
        is_ivf = self.index_type.startswith("IVF")
        return {
            "index_type": self.index_type,
            "metric_type": self.metric_type,
            "hnsw_m": self.hnsw_m,
            "hnsw_ef_construction": self.hnsw_ef_construction,
            "hnsw_ef": self.hnsw_ef,
            "sq_type": self.sq_type if is_sq else None,
            "sq_refine": self.sq_refine if is_sq else None,
            "sq_refine_type": self.sq_refine_type if is_sq and self.sq_refine else None,
            "sq_refine_k": self.sq_refine_k if is_sq and self.sq_refine else None,
            "ivf_nlist": self.ivf_nlist if is_ivf else None,
            "ivf_nprobe": self.ivf_nprobe if is_ivf else None,
        }
- @final
- @dataclass
- class MilvusVectorDBStorage(BaseVectorStorage):
def _get_milvus_connection_kwargs(self, include_db_name: bool = True) -> dict:
    """Assemble ``MilvusClient`` keyword arguments from env vars and config.ini.

    Each setting comes from its ``MILVUS_*`` environment variable when set,
    otherwise from the ``[milvus]`` section of ``config.ini``. The URI
    defaults to ``<working_dir>/milvus_lite.db``; user/password/token default
    to None. ``db_name`` is added only when *include_db_name* is true and a
    non-empty name is configured.
    """
    default_uri = os.path.join(self.global_config["working_dir"], "milvus_lite.db")

    def _setting(env_key: str, option: str, fallback=None):
        # The environment variable wins; config.ini is consulted otherwise.
        return os.environ.get(env_key, config.get("milvus", option, fallback=fallback))

    kwargs = {
        "uri": _setting("MILVUS_URI", "uri", default_uri),
        "user": _setting("MILVUS_USER", "user"),
        "password": _setting("MILVUS_PASSWORD", "password"),
        "token": _setting("MILVUS_TOKEN", "token"),
    }

    database = _setting("MILVUS_DB_NAME", "db_name")
    if include_db_name and database:
        kwargs["db_name"] = database
    return kwargs
- def _get_milvus_db_name(self) -> Optional[str]:
- """Return the configured Milvus database name, if any."""
- db_name = self._get_milvus_connection_kwargs(include_db_name=True).get(
- "db_name"
- )
- if db_name is None:
- return None
- normalized_name = str(db_name).strip()
- return normalized_name or None
def _create_milvus_client(self) -> MilvusClient:
    """Create a ``MilvusClient`` and make sure the configured database exists.

    With no database configured, the bootstrap client (connected without a
    ``db_name``) is returned as-is. Otherwise:

    - the database is created if ``list_databases()`` does not report it;
    - it is selected on the same client via ``use_database`` (or
      ``using_database``) when the client offers one;
    - failing that, the bootstrap client is closed and a new client is
      connected with ``db_name`` set.

    The normalized (stripped) name from ``_get_milvus_db_name`` is used for
    both creating and connecting, so the two always agree.

    Raises:
        Any error raised by the client while listing, creating or selecting
        the database. The bootstrap client is closed before it propagates.
    """
    base_kwargs = self._get_milvus_connection_kwargs(include_db_name=False)
    client = MilvusClient(**base_kwargs)

    db_name = self._get_milvus_db_name()
    if not db_name:
        return client

    try:
        if db_name not in set(client.list_databases()):
            logger.warning(
                f"[{self.workspace}] Milvus database '{db_name}' not found, creating it"
            )
            client.create_database(db_name)

        # Prefer use_database(); fall back to using_database() if that is the
        # name this client exposes.
        use_database = getattr(client, "use_database", None) or getattr(
            client, "using_database", None
        )
        if callable(use_database):
            use_database(db_name)
            logger.debug(
                f"[{self.workspace}] Using Milvus database '{db_name}' for namespace '{self.namespace}'"
            )
            return client
    except Exception:
        # Do not leak the bootstrap connection; a failure while closing must
        # not mask the original error, which is re-raised below.
        try:
            client.close()
        except Exception:
            pass
        raise

    # No in-place database switch available: close the bootstrap client first
    # (so its connection is not leaked), then reconnect bound to db_name.
    client.close()
    return MilvusClient(**base_kwargs, db_name=db_name)
def _create_schema_for_namespace(self) -> CollectionSchema:
    """Build the collection schema that matches this instance's namespace.

    Every collection gets ``id`` (VARCHAR(64) primary key), ``vector``
    (FLOAT_VECTOR sized by ``embedding_func.embedding_dim``) and
    ``created_at`` (INT64). Namespaces ending in ``entities``,
    ``relationships`` or ``chunks`` add their own nullable VARCHAR columns.
    Every namespace ends with a nullable ``file_path`` column. Dynamic
    fields are enabled.
    """

    def nullable_varchar(name: str, max_length: int) -> FieldSchema:
        # All namespace-specific metadata columns are nullable VARCHARs.
        return FieldSchema(
            name=name,
            dtype=DataType.VARCHAR,
            max_length=max_length,
            nullable=True,
        )

    embedding_dim = self.embedding_func.embedding_dim
    columns = [
        FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
        FieldSchema(name="created_at", dtype=DataType.INT64),
    ]

    namespace = self.namespace
    if namespace.endswith("entities"):
        columns.append(nullable_varchar("entity_name", 512))
        description = "LightRAG entities vector storage"
    elif namespace.endswith("relationships"):
        columns.extend((nullable_varchar("src_id", 512), nullable_varchar("tgt_id", 512)))
        description = "LightRAG relationships vector storage"
    elif namespace.endswith("chunks"):
        columns.append(nullable_varchar("full_doc_id", 64))
        description = "LightRAG chunks vector storage"
    else:
        # Generic schema kept for backward compatibility.
        description = "LightRAG generic vector storage"

    columns.append(nullable_varchar("file_path", DEFAULT_MAX_FILE_PATH_LENGTH))

    return CollectionSchema(
        fields=columns,
        description=description,
        enable_dynamic_field=True,
    )
- def _get_index_params(self):
- """Get IndexParams in a version-compatible way"""
- try:
- # Try to use client's prepare_index_params method (most common)
- if hasattr(self._client, "prepare_index_params"):
- return self._client.prepare_index_params()
- except Exception:
- pass
- try:
- # Try to import IndexParams from different possible locations
- from pymilvus.client.prepare import IndexParams # type: ignore
- return IndexParams()
- except ImportError:
- pass
- try:
- from pymilvus.client.types import IndexParams # type: ignore
- return IndexParams()
- except ImportError:
- pass
- try:
- from pymilvus import IndexParams # type: ignore
- return IndexParams()
- except ImportError:
- pass
- # If all else fails, return None to use fallback method
- return None
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
    """Best-effort scalar index creation through ``create_index`` with a dict payload.

    ``SORTED`` is never attempted on this path. Any exception from the client
    is logged at INFO level instead of being raised.

    NOTE(review): confirm the installed MilvusClient.create_index accepts a
    plain dict for ``index_params`` (it may expect an IndexParams object);
    if it does not, this path only ever logs the failure.

    Args:
        field_name: Scalar field to index.
        index_type: Milvus index type name, e.g. ``"INVERTED"``.
    """
    if index_type == "SORTED":
        logger.info(
            f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)"
        )
    else:
        payload = {"index_type": index_type}
        try:
            self._client.create_index(
                collection_name=self.final_namespace,
                field_name=field_name,
                index_params=payload,
            )
            logger.debug(
                f"[{self.workspace}] Created {field_name} index using fallback method"
            )
        except Exception as exc:
            logger.info(
                f"[{self.workspace}] Could not create {field_name} index using fallback method: {exc}"
            )
def _create_indexes_after_collection(self):
    """Create the vector index and the namespace-specific scalar indexes.

    The vector index comes from ``self.index_config.build_index_params``;
    errors while creating it propagate. The scalar ``INVERTED`` indexes are
    entity_name, src_id/tgt_id or full_doc_id, depending on the namespace.
    They are best effort: each one tries the ``IndexParams`` API and then
    ``_create_scalar_index_fallback``, and any remaining exception is logged
    as a warning.
    """
    # build_index_params returns the IndexParams object, or a plain dict when
    # IndexParams is unavailable and the index type is AUTOINDEX.
    index_params_for_vector = self._get_index_params()
    vector_index_params = self.index_config.build_index_params(
        index_params_for_vector, field_name="vector"
    )
    # Deliberately not wrapped: vector index failures must surface.
    if isinstance(vector_index_params, dict):
        self._client.create_index(
            collection_name=self.final_namespace,
            field_name=vector_index_params["field_name"],
            index_params={
                "index_type": vector_index_params["index_type"],
                "metric_type": vector_index_params["metric_type"],
                "params": vector_index_params["params"],
            },
        )
    else:
        self._client.create_index(
            collection_name=self.final_namespace,
            index_params=vector_index_params,
        )
    logger.debug(
        f"[{self.workspace}] Created vector index with config: {self.index_config.to_dict()}"
    )

    # Scalar index failures are logged as warnings (not critical).
    try:
        index_params_available = self._get_index_params() is not None
        scalar_fields = self._get_scalar_index_fields()
        if index_params_available:
            for field_name in scalar_fields:
                self._create_inverted_index_with_fallback(field_name)
        else:
            logger.info(
                f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}"
            )
            for field_name in scalar_fields:
                self._create_scalar_index_fallback(field_name, "INVERTED")
        logger.info(
            f"[{self.workspace}] Created indexes for collection: {self.namespace}"
        )
    except Exception as e:
        logger.warning(
            f"[{self.workspace}] Failed to create some scalar indexes for {self.namespace}: {e}"
        )


def _get_scalar_index_fields(self) -> list:
    """Return the scalar fields that receive an INVERTED index for this namespace."""
    if self.namespace.endswith("entities"):
        return ["entity_name"]
    if self.namespace.endswith("relationships"):
        return ["src_id", "tgt_id"]
    if self.namespace.endswith("chunks"):
        return ["full_doc_id"]
    return []


def _create_inverted_index_with_fallback(self, field_name: str) -> None:
    """Create an INVERTED index on *field_name*, falling back to the direct API on any error."""
    try:
        field_index = self._get_index_params()
        field_index.add_index(field_name=field_name, index_type="INVERTED")
        self._client.create_index(
            collection_name=self.final_namespace,
            index_params=field_index,
        )
    except Exception as e:
        logger.debug(
            f"[{self.workspace}] IndexParams method failed for {field_name}: {e}"
        )
        self._create_scalar_index_fallback(field_name, "INVERTED")
- def _get_required_fields_for_namespace(self) -> dict:
- """Get required core field definitions for current namespace"""
- # Base fields (common to all types)
- base_fields = {
- "id": {"type": "VarChar", "is_primary": True},
- "vector": {"type": "FloatVector"},
- "created_at": {"type": "Int64"},
- }
- # Add specific fields based on namespace
- if self.namespace.endswith("entities"):
- specific_fields = {
- "entity_name": {"type": "VarChar"},
- "file_path": {"type": "VarChar"},
- }
- elif self.namespace.endswith("relationships"):
- specific_fields = {
- "src_id": {"type": "VarChar"},
- "tgt_id": {"type": "VarChar"},
- "file_path": {"type": "VarChar"},
- }
- elif self.namespace.endswith("chunks"):
- specific_fields = {
- "full_doc_id": {"type": "VarChar"},
- "file_path": {"type": "VarChar"},
- }
- else:
- specific_fields = {
- "file_path": {"type": "VarChar"},
- }
- return {**base_fields, **specific_fields}
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
    """Check whether one existing collection field matches its expected definition.

    Type names are normalized before comparison. Enum values use ``.name``,
    raw integer Milvus type codes are mapped, and aliases such as "VARCHAR" /
    "String" -> "VarChar" are applied.

    Args:
        existing_field: Field description from the existing collection. Reads
            ``name``, ``type`` and the primary-key flags ``is_primary_key`` /
            ``is_primary`` / ``primary_key``.
        expected_config: Expected definition, ``{"type": <name>}`` plus an
            optional ``"is_primary": True``.

    Returns:
        True when the normalized types match and, if a primary key is expected,
        the field is primary. The ``id`` field is exempt from the primary-key
        requirement.
    """
    field_name = existing_field.get("name", "unknown")
    existing_type = existing_field.get("type")
    expected_type = expected_config.get("type")

    logger.debug(
        f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
    )

    # Convert DataType enum values to string names if needed
    original_existing_type = existing_type
    if hasattr(existing_type, "name"):
        existing_type = existing_type.name
        logger.debug(
            f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}"
        )
    elif isinstance(existing_type, int):
        # Raw Milvus DataType codes: Int64=5, Double=11, VarChar=21,
        # FloatVector=101. Unknown codes fall through as their string form.
        type_mapping = {
            21: "VarChar",
            101: "FloatVector",
            5: "Int64",
            11: "Double",
        }
        mapped_type = type_mapping.get(existing_type, str(existing_type))
        logger.debug(
            f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}"
        )
        existing_type = mapped_type

    # Normalize type names for comparison
    type_aliases = {
        "VARCHAR": "VarChar",
        "String": "VarChar",
        "FLOAT_VECTOR": "FloatVector",
        "INT64": "Int64",
        "BigInt": "Int64",
        "DOUBLE": "Double",
        "Float": "Double",
    }

    original_existing = existing_type
    original_expected = expected_type
    existing_type = type_aliases.get(existing_type, existing_type)
    expected_type = type_aliases.get(expected_type, expected_type)

    if original_existing != existing_type or original_expected != expected_type:
        logger.debug(
            f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
        )

    # Basic type compatibility check
    type_compatible = existing_type == expected_type
    logger.debug(
        f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
    )

    if not type_compatible:
        logger.warning(
            f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
        )
        return False

    # Primary key check. Different sources name the flag differently, so
    # accept any of the three spellings.
    if expected_config.get("is_primary"):
        is_primary = (
            existing_field.get("is_primary_key", False)
            or existing_field.get("is_primary", False)
            or existing_field.get("primary_key", False)
        )
        logger.debug(
            f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}"
        )
        logger.debug(
            f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}"
        )

        # The "id" field is treated as compatible even if not flagged primary.
        if field_name == "id" and not is_primary:
            logger.info(
                f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
            )
        elif not is_primary:
            logger.warning(
                f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
            )
            return False

    logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible")
    return True
def _check_vector_dimension(self, collection_info: dict):
    """Ensure the collection's vector field matches the embedding dimension.

    Looks through ``collection_info["fields"]`` for a field named
    ``"vector"`` whose type is ``FloatVector``/``FLOAT_VECTOR`` (taken from a
    DataType enum's ``.name`` or from a plain string) and compares its
    ``params["dim"]`` with ``self.embedding_func.embedding_dim``. Both values
    are coerced with ``int()`` first because the Milvus API may report the
    dimension as a string (e.g. ``"1024"`` instead of ``1024``).

    Raises:
        ValueError: a dimension cannot be converted to ``int``, or the stored
            and current dimensions differ.

    If no float-vector ``"vector"`` field is present (e.g. an old collection
    created with the simple schema), two warnings are logged and the method
    returns normally.
    """
    embedding_dim = self.embedding_func.embedding_dim

    def _to_int(raw):
        # None is passed through so "unknown" compares as-is.
        return None if raw is None else int(raw)

    for candidate in collection_info.get("fields", []):
        if candidate.get("name") != "vector":
            continue

        declared = candidate.get("type")
        if hasattr(declared, "name"):
            type_label = declared.name
        elif isinstance(declared, str):
            type_label = declared
        else:
            type_label = str(declared)
        if type_label not in ("FloatVector", "FLOAT_VECTOR"):
            continue

        stored_dim = candidate.get("params", {}).get("dim")
        try:
            stored = _to_int(stored_dim)
            expected = _to_int(embedding_dim)
        except (TypeError, ValueError) as parse_error:
            logger.error(
                f"[{self.workspace}] Failed to parse dimensions: existing={stored_dim} (type={type(stored_dim)}), "
                f"current={embedding_dim} (type={type(embedding_dim)}), error={parse_error}"
            )
            raise ValueError(
                f"Invalid dimension values for collection '{self.final_namespace}': "
                f"existing={stored_dim}, current={embedding_dim}"
            ) from parse_error

        if stored != expected:
            raise ValueError(
                f"Vector dimension mismatch for collection '{self.final_namespace}': "
                f"existing={stored}, current={expected}"
            )
        logger.debug(
            f"[{self.workspace}] Vector dimension check passed: {expected}"
        )
        return

    # No float "vector" field: most likely an old simple-schema collection.
    logger.warning(
        f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
    )
    logger.warning(
        f"[{self.workspace}] Consider recreating the collection for optimal performance."
    )
    return
def _check_file_path_length_restriction(self, collection_info: dict) -> bool:
    """Report whether the collection's ``file_path`` field is too short.

    A collection whose ``file_path`` field declares a ``max_length`` below
    ``DEFAULT_MAX_FILE_PATH_LENGTH`` needs a schema migration.

    Args:
        collection_info: Result of ``describe_collection``; each entry of
            ``collection_info["fields"]`` is a dict with a ``"name"`` and an
            optional ``"params"`` dict.

    Returns:
        bool: True if migration is needed. False when the field is absent,
        has no (or a zero) ``max_length``, has an unparseable ``max_length``
        (a warning is logged), or is already long enough.
    """
    fields_by_name = {
        field.get("name"): field for field in collection_info.get("fields", [])
    }
    file_path_field = fields_by_name.get("file_path")
    if file_path_field is None:
        return False

    raw_max_length = (file_path_field.get("params") or {}).get("max_length")
    if raw_max_length is None:
        return False

    # Milvus may report field params as strings (e.g. "1024"), just as it does
    # for the vector "dim"; compare numerically instead of str < int, which
    # would raise TypeError.
    try:
        max_length = int(raw_max_length)
    except (TypeError, ValueError):
        logger.warning(
            f"[{self.workspace}] Collection {self.namespace} has unparseable "
            f"file_path max_length={raw_max_length!r}; skipping migration check"
        )
        return False

    if max_length and max_length < DEFAULT_MAX_FILE_PATH_LENGTH:
        logger.info(
            f"[{self.workspace}] Collection {self.namespace} has file_path max_length={max_length}, "
            f"needs migration to {DEFAULT_MAX_FILE_PATH_LENGTH}"
        )
        return True
    return False
def _check_schema_compatibility(self, collection_info: dict):
    """Validate an existing collection's fields and migrate it if needed.

    Decision order:
      1. No ``"vector"`` field: treated as an old simple-schema collection;
         warnings are logged and the check stops.
      2. ``_check_file_path_length_restriction`` says migration is needed:
         ``_migrate_collection_schema`` runs and the check stops.
      3. Critical fields (currently only ``id``: VarChar primary key) are
         checked with ``_is_field_compatible``.
      4. Fields expected by ``_get_required_fields_for_namespace`` but absent
         are only reported at info level.

    Raises:
        ValueError: a critical field exists but is incompatible.
    """
    fields_by_name = {
        entry["name"]: entry for entry in collection_info.get("fields", [])
    }

    if "vector" not in fields_by_name:
        for note in (
            f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)",
            f"[{self.workspace}] This collection will work but may have suboptimal performance",
            f"[{self.workspace}] Consider recreating the collection for optimal performance",
        ):
            logger.warning(note)
        return

    if self._check_file_path_length_restriction(collection_info):
        logger.info(
            f"[{self.workspace}] Starting automatic migration for collection {self.namespace}"
        )
        self._migrate_collection_schema()
        return

    # Only hard incompatibilities fail validation; absent optional fields do not.
    critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
    problems = [
        f"{name}: expected {spec['type']}, got {fields_by_name[name].get('type')}"
        for name, spec in critical_fields.items()
        if name in fields_by_name
        and not self._is_field_compatible(fields_by_name[name], spec)
    ]
    if problems:
        raise ValueError(
            f"Critical schema incompatibility in collection '{self.final_namespace}': {problems}"
        )

    absent = [
        name
        for name in self._get_required_fields_for_namespace()
        if name not in fields_by_name
    ]
    if absent:
        logger.info(
            f"[{self.workspace}] Collection {self.namespace} missing optional fields: {absent}"
        )
        logger.info(
            "These fields would be available in a newly created collection for better performance"
        )

    logger.debug(
        f"[{self.workspace}] Schema compatibility check passed for {self.namespace}"
    )
def _migrate_collection_schema(self):
    """Migrate the collection to the current schema by copying it.

    Steps:
      1. Create ``<final_namespace>_temp`` with the current schema, try to
         build its indexes (failures only warn), and load it.
      2. Stream every row of the original collection into the temp one via
         ``query_iterator`` (batches of 2000), which avoids the query-window
         limit of plain ``query``.
      3. Rename the original to ``<final_namespace>_old`` as a backup; if the
         rename fails, drop the original instead.
      4. Rename the temp collection to the original name.

    ``self.final_namespace`` is temporarily pointed at the temp collection so
    that ``_create_indexes_after_collection`` targets it. It is restored on
    every exit path, including failures.

    Raises:
        RuntimeError: wrapping the underlying error if any step fails. On
            failure the temp collection is dropped only while the original
            collection is still in place. Once the original has been renamed
            away or dropped, the temp collection may hold the only live copy
            of the data, so it is kept for manual recovery.
    """
    original_collection_name = self.final_namespace
    temp_collection_name = f"{original_collection_name}_temp"
    backup_collection_name = f"{original_collection_name}_old"
    iterator = None
    # Becomes True once the original collection has been renamed away or
    # dropped; from then on the temp collection must never be dropped.
    original_detached = False
    try:
        logger.info(
            f"[{self.workspace}] Starting iterator-based schema migration for {self.namespace}"
        )

        # Step 1: Create temporary collection with new schema
        logger.info(
            f"[{self.workspace}] Step 1: Creating temporary collection: {temp_collection_name}"
        )
        # _create_indexes_after_collection() targets self.final_namespace; the
        # finally clause below restores the original name on every path.
        self.final_namespace = temp_collection_name
        new_schema = self._create_schema_for_namespace()
        self._client.create_collection(
            collection_name=temp_collection_name, schema=new_schema
        )
        try:
            self._create_indexes_after_collection()
        except Exception as index_error:
            logger.warning(
                f"[{self.workspace}] Failed to create indexes for new collection: {index_error}"
            )
            # Continue with migration even if index creation fails
        self._client.load_collection(temp_collection_name)

        # Step 2: Copy data using query_iterator (solves query window limitation)
        logger.info(
            f"[{self.workspace}] Step 2: Copying data using query_iterator from: {original_collection_name}"
        )
        try:
            iterator = self._client.query_iterator(
                collection_name=original_collection_name,
                batch_size=2000,  # Adjustable batch size for optimal performance
                output_fields=["*"],  # Get all fields
            )
            logger.debug(f"[{self.workspace}] Query iterator created successfully")
        except Exception as iterator_error:
            logger.error(
                f"[{self.workspace}] Failed to create query iterator: {iterator_error}"
            )
            raise

        total_migrated = 0
        batch_number = 1
        while True:
            # Keep fetch and insert failures separate so each is logged once
            # with the right description.
            try:
                batch_data = iterator.next()
            except Exception as next_error:
                logger.error(
                    f"[{self.workspace}] Iterator next() failed at batch {batch_number}: {next_error}"
                )
                raise
            if not batch_data:
                # No more data available
                break
            try:
                self._client.insert(
                    collection_name=temp_collection_name, data=batch_data
                )
            except Exception as batch_error:
                logger.error(
                    f"[{self.workspace}] Failed to insert iterator batch {batch_number}: {batch_error}"
                )
                raise
            total_migrated += len(batch_data)
            logger.info(
                f"[{self.workspace}] Iterator batch {batch_number}: "
                f"processed {len(batch_data)} records, total migrated: {total_migrated}"
            )
            batch_number += 1

        if total_migrated > 0:
            logger.info(
                f"[{self.workspace}] Successfully migrated {total_migrated} records using iterator"
            )
        else:
            logger.info(
                f"[{self.workspace}] No data found in original collection, migration completed"
            )

        # Step 3: Rename origin collection (keep for safety)
        logger.info(
            f"[{self.workspace}] Step 3: Rename origin collection to {backup_collection_name}"
        )
        try:
            self._client.rename_collection(
                original_collection_name, backup_collection_name
            )
        except Exception as rename_error:
            logger.warning(
                f"[{self.workspace}] Rename of origin collection failed ({rename_error}); "
                f"try to drop origin collection instead"
            )
            try:
                self._client.drop_collection(original_collection_name)
            except Exception as drop_error:
                logger.error(
                    f"[{self.workspace}] Rename operation failed: {rename_error}"
                )
                raise drop_error from rename_error
        original_detached = True

        # Step 4: Rename temporary collection to original name
        logger.info(
            f"[{self.workspace}] Step 4: Renaming collection {temp_collection_name} -> {original_collection_name}"
        )
        try:
            self._client.rename_collection(
                temp_collection_name, original_collection_name
            )
            logger.info(f"[{self.workspace}] Rename operation completed")
        except Exception as rename_error:
            logger.error(
                f"[{self.workspace}] Rename operation failed: {rename_error}"
            )
            raise RuntimeError(
                f"Failed to rename collection: {rename_error}"
            ) from rename_error
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Iterator-based migration failed for {self.namespace}: {e}"
        )
        if original_detached:
            # The temp collection may be the only live copy of the data.
            logger.error(
                f"[{self.workspace}] Original collection '{original_collection_name}' was already "
                f"renamed or dropped; keeping temporary collection '{temp_collection_name}' "
                f"for manual recovery"
            )
        else:
            # Original is intact: discard the partial copy.
            try:
                if self._client and self._client.has_collection(temp_collection_name):
                    logger.info(
                        f"[{self.workspace}] Cleaning up failed migration temporary collection"
                    )
                    self._client.drop_collection(temp_collection_name)
            except Exception as cleanup_error:
                logger.warning(
                    f"[{self.workspace}] Failed to cleanup temporary collection: {cleanup_error}"
                )
        raise RuntimeError(
            f"Iterator-based migration failed for collection {self.namespace}: {e}"
        ) from e
    finally:
        # Always point the instance back at the original collection name.
        self.final_namespace = original_collection_name
        if iterator is not None:
            try:
                iterator.close()
                logger.debug(
                    f"[{self.workspace}] Query iterator closed successfully"
                )
            except Exception as close_error:
                logger.warning(
                    f"[{self.workspace}] Failed to close query iterator: {close_error}"
                )
def _validate_collection_compatibility(self):
    """Run the dimension and schema checks against the existing collection.

    The collection is described once, and that description is passed to
    ``_check_vector_dimension`` and then to ``_check_schema_compatibility``
    (which may trigger a schema migration). Any failure is logged and
    re-raised unchanged.
    """
    try:
        description = self._client.describe_collection(self.final_namespace)
        checks = (self._check_vector_dimension, self._check_schema_compatibility)
        for check in checks:
            check(description)
        logger.info(
            f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed"
        )
    except Exception as failure:
        logger.error(
            f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {failure}"
        )
        raise
@staticmethod
def _is_missing_vector_index_error(error: Exception) -> bool:
    """Tell whether *error* reports that the collection has no vector index.

    The match is case-insensitive on ``str(error)`` and looks for either of
    the two messages Milvus is known (per this code) to use for that case.
    """
    text = str(error).lower()
    markers = ("no vector index", "please create index firstly")
    return any(marker in text for marker in markers)
def _repair_missing_vector_index(self):
    """Rebuild the indexes of an existing collection lacking a vector index.

    Logs a warning and delegates to ``_create_indexes_after_collection``;
    any exception raised there propagates to the caller.
    """
    notice = (
        f"[{self.workspace}] Collection '{self.namespace}' is missing a vector index, attempting repair"
    )
    logger.warning(notice)
    self._create_indexes_after_collection()
def _ensure_collection_loaded(self):
    """Load the collection so that search/query/upsert/delete can run on it.

    ``load_collection`` is called unconditionally; this relies on Milvus
    tolerating a load of an already-loaded collection (presumably a no-op;
    confirm against the server version in use).

    Raises:
        ValueError: the collection does not exist.
        Exception: anything raised by the client calls.
    Every failure is logged before it is re-raised.
    """
    try:
        if self._client.has_collection(self.final_namespace):
            # Milvus only serves searches from loaded collections.
            self._client.load_collection(self.final_namespace)
            return
        logger.error(
            f"[{self.workspace}] Collection {self.namespace} does not exist"
        )
        raise ValueError(f"Collection {self.final_namespace} does not exist")
    except Exception as failure:
        logger.error(
            f"[{self.workspace}] Failed to load collection {self.namespace}: {failure}"
        )
        raise
def _create_collection_if_not_exist(self):
    """Create the collection if missing, otherwise validate and load it.

    Existing collection:
        Described, validated via ``_validate_collection_compatibility`` (which
        may run a schema migration), and loaded. If loading fails because the
        vector index is missing, the indexes are rebuilt once and loading is
        retried. Any failure on this path becomes a ``RuntimeError`` (with
        the cause chained). An existing collection is never dropped here,
        to prevent data loss.

    Missing collection:
        Created from ``_create_schema_for_namespace``, indexed, and loaded.
        If that fails, the half-created collection is dropped and creation is
        retried once ("force create").

    If ``has_collection`` itself fails, existence is unknown. The error is
    re-raised instead of force-creating, because the force path drops
    whatever collection it finds under this name.

    Raises:
        RuntimeError: validation/repair of an existing collection failed.
        Exception: existence could not be determined, or force-creation
            failed.
    """
    # Set to True only once has_collection() has actually answered.
    existence_known = False
    try:
        collection_exists = self._client.has_collection(self.final_namespace)
        existence_known = True
        logger.info(
            f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}"
        )

        if collection_exists:
            try:
                # Double-check by trying to describe the collection
                self._client.describe_collection(self.final_namespace)
                self._validate_collection_compatibility()
                try:
                    # Ensure the collection is loaded after validation
                    self._ensure_collection_loaded()
                    return
                except Exception as load_error:
                    if not self._is_missing_vector_index_error(load_error):
                        raise
                    try:
                        self._repair_missing_vector_index()
                        self._ensure_collection_loaded()
                        logger.info(
                            f"[{self.workspace}] Repaired missing vector index for existing collection '{self.namespace}'"
                        )
                        return
                    except Exception as repair_error:
                        raise RuntimeError(
                            f"Index repair failed for collection '{self.final_namespace}'. "
                            f"Original error: {repair_error}"
                        ) from repair_error
            except Exception as validation_error:
                # CRITICAL: Collection exists but validation failed.
                # This indicates potential data migration failure or an
                # incompatible schema. Stop execution to prevent data loss
                # and require manual intervention.
                logger.error(
                    f"[{self.workspace}] CRITICAL ERROR: Collection '{self.namespace}' exists but validation failed!"
                )
                logger.error(
                    f"[{self.workspace}] This indicates potential data migration failure or schema incompatibility."
                )
                logger.error(
                    f"[{self.workspace}] Validation error: {validation_error}"
                )
                logger.error(f"[{self.workspace}] MANUAL INTERVENTION REQUIRED:")
                logger.error(
                    f"[{self.workspace}] 1. Check the existing collection schema and data integrity"
                )
                logger.error(
                    f"[{self.workspace}] 2. Backup existing data if needed"
                )
                logger.error(
                    f"[{self.workspace}] 3. Manually resolve schema compatibility issues"
                )
                logger.error(
                    f"[{self.workspace}] 4. Consider dropping and recreating the collection if data is not critical"
                )
                logger.error(
                    f"[{self.workspace}] Program execution stopped to prevent potential data loss."
                )
                raise RuntimeError(
                    f"Collection validation failed for '{self.final_namespace}'. "
                    f"Data migration failure detected. Manual intervention required to prevent data loss. "
                    f"Original error: {validation_error}"
                ) from validation_error

        # Collection doesn't exist, create new collection
        logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}")
        schema = self._create_schema_for_namespace()
        # Create collection with schema only first, then indexes
        self._client.create_collection(
            collection_name=self.final_namespace, schema=schema
        )
        self._create_indexes_after_collection()
        self._ensure_collection_loaded()
        logger.info(
            f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}"
        )
    except RuntimeError:
        # Validation failures are critical and must stop execution unchanged.
        raise
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}"
        )
        if not existence_known:
            # A transient has_collection() failure must not lead to dropping
            # a collection that may hold data.
            logger.error(
                f"[{self.workspace}] Could not determine whether collection {self.namespace} exists; "
                f"refusing to force-create it to avoid dropping existing data"
            )
            raise
        # The collection was confirmed missing, so anything found under this
        # name now is the half-created one from the attempt above.
        logger.info(
            f"[{self.workspace}] Attempting to force create collection {self.namespace}..."
        )
        try:
            try:
                if self._client.has_collection(self.final_namespace):
                    logger.info(
                        f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}"
                    )
                    self._client.drop_collection(self.final_namespace)
            except Exception as drop_error:
                logger.warning(
                    f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}"
                )
            schema = self._create_schema_for_namespace()
            self._client.create_collection(
                collection_name=self.final_namespace, schema=schema
            )
            self._create_indexes_after_collection()
            self._ensure_collection_loaded()
            logger.info(
                f"[{self.workspace}] Successfully force-created collection {self.namespace}"
            )
        except Exception as create_error:
            logger.error(
                f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}"
            )
            raise
def __post_init__(self):
    """Resolve configuration and set up in-memory state (no network I/O).

    * Index configuration: entries of ``vector_db_storage_cls_kwargs`` whose
      keys are ``MilvusIndexConfig`` field names (as reported by
      ``MilvusIndexConfig.get_config_field_names()``: index_type,
      metric_type, hnsw_m, hnsw_ef_construction, hnsw_ef, sq_type,
      sq_refine, sq_refine_type, sq_refine_k, ivf_nlist, ivf_nprobe) build
      ``self.index_config`` unless one is already set. This is the
      recommended way for frameworks (e.g. RAGAnything) to configure
      indexing, for example::

          LightRAG(
              vector_storage="MilvusVectorDBStorage",
              vector_db_storage_cls_kwargs={
                  "cosine_better_than_threshold": 0.2,
                  "index_type": "HNSW",
                  "metric_type": "COSINE",
                  "hnsw_m": 32,
                  "hnsw_ef_construction": 256,
              },
          )

    * Collection name: a non-blank ``MILVUS_WORKSPACE`` environment variable
      overrides the passed workspace. The physical collection is
      ``"<workspace>_<namespace>"``, or just ``namespace`` when no workspace
      applies (``self.workspace`` is then reset to ``""``). ``namespace``
      itself is left untouched because it drives type detection.
    * ``cosine_better_than_threshold`` is mandatory in the kwargs.
    * The Milvus client is not created here; ``initialize()`` does that.

    Raises:
        ValueError: ``cosine_better_than_threshold`` is missing or None.
    """
    self._validate_embedding_func()

    storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
    allowed_index_keys = MilvusIndexConfig.get_config_field_names()
    index_overrides = {}
    for key, value in storage_kwargs.items():
        if key in allowed_index_keys:
            index_overrides[key] = value
    # Keep an index_config supplied by the caller; otherwise build one.
    if getattr(self, "index_config", None) is None:
        self.index_config = MilvusIndexConfig(**index_overrides)

    # MILVUS_WORKSPACE lets administrators force one workspace for every
    # Milvus storage instance.
    env_workspace = (os.environ.get("MILVUS_WORKSPACE") or "").strip()
    if env_workspace:
        effective_workspace = env_workspace
        logger.info(
            f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
        )
    else:
        effective_workspace = self.workspace
        if effective_workspace:
            logger.debug(
                f"Using passed workspace parameter: '{effective_workspace}'"
            )

    if not effective_workspace:
        self.final_namespace = self.namespace
        self.workspace = ""
        logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
    else:
        # The workspace prefix isolates data between workspaces.
        self.final_namespace = f"{effective_workspace}_{self.namespace}"
        logger.debug(
            f"Final namespace with workspace prefix: '{self.final_namespace}'"
        )

    threshold = storage_kwargs.get("cosine_better_than_threshold")
    if threshold is None:
        raise ValueError(
            "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
        )
    self.cosine_better_than_threshold = threshold

    # created_at is always stored and returned alongside other meta fields.
    if "created_at" not in self.meta_fields:
        self.meta_fields.add("created_at")

    self._client = None
    self._max_batch_size = self.global_config["embedding_batch_num"]
    self._initialized = False

    # Deferred-embedding buffers plus their flush lock. The lock is keyed on
    # final_namespace (workspace="") so every instance targeting the same
    # Milvus collection shares one writer lock; it is built here rather than
    # in initialize() so buffer access before initialize() is still guarded.
    self._pending_vector_docs: dict[str, _PendingVectorDoc] = {}
    self._pending_vector_deletes: set[str] = set()
    self._flush_lock = get_namespace_lock(
        namespace=self.final_namespace, workspace=""
    )
async def initialize(self):
    """Connect to Milvus and prepare the collection; safe to call repeatedly.

    The work runs under ``get_data_init_lock()`` and is skipped once
    ``_initialized`` is set. It lazily creates the MilvusClient, and checks
    the server version when the configured index type appears in
    ``INDEX_VERSION_REQUIREMENTS``. It then creates or validates the
    collection via ``_create_collection_if_not_exist``. Failures are logged
    and re-raised, and ``_initialized`` stays False so a later call can retry.
    """

    def _check_server_version() -> None:
        # Only index types with a minimum server version need this check.
        if self.index_config.index_type not in INDEX_VERSION_REQUIREMENTS:
            return
        try:
            version = self._client.get_server_version()
            self.index_config.validate_milvus_version(version)
        except Exception as version_error:
            logger.error(
                f"[{self.workspace}] Milvus version validation failed: {version_error}"
            )
            raise

    async with get_data_init_lock():
        if self._initialized:
            return
        try:
            if self._client is None:
                self._client = self._create_milvus_client()
                logger.debug(
                    f"[{self.workspace}] MilvusClient created successfully"
                )
            _check_server_version()
            self._create_collection_if_not_exist()
        except Exception as failure:
            logger.error(
                f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {failure}"
            )
            raise
        self._initialized = True
        logger.info(
            f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully"
        )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Stage vector documents in the in-memory buffer; nothing is embedded yet.

    Each ``data[doc_id]`` must contain ``"content"`` (the text embedded later
    at flush time). Its stored record starts as
    ``{"id": doc_id, "created_at": <int(time.time())>}`` and is extended with
    every key of the entry that is listed in ``self.meta_fields``.

    Embedding is deferred to the flush (``index_done_callback``), so repeated
    upserts of one id, or many small batches, collapse into a single
    embedding pass. ``get_by_id``/``get_by_ids``/``get_vectors_by_ids`` read
    these staged docs under the same lock (read-your-writes).
    """
    if not data:
        return

    import time

    timestamp = int(time.time())
    staged: list[tuple[str, _PendingVectorDoc]] = []
    for step, (doc_id, fields) in enumerate(data.items(), start=1):
        record = {"id": doc_id, "created_at": timestamp}
        record.update(
            (name, value) for name, value in fields.items() if name in self.meta_fields
        )
        staged.append(
            (doc_id, _PendingVectorDoc(source=record, content=fields["content"]))
        )
        # Give _cooperative_yield a chance to hand control back to the event
        # loop during large batches (its cadence is defined elsewhere).
        await _cooperative_yield(step)

    # An upsert cancels any buffered delete of the same id. Installing a new
    # _PendingVectorDoc also discards any vector cached on an older revision
    # (e.g. by get_vectors_by_ids()).
    async with self._flush_lock:
        for doc_id, pending in staged:
            self._pending_vector_deletes.discard(doc_id)
            self._pending_vector_docs[doc_id] = pending
- async def query(
- self, query: str, top_k: int, query_embedding: list[float] = None
- ) -> list[dict[str, Any]]:
- """Similarity search against the persisted Milvus collection.
- Note: buffered-but-unflushed upserts are NOT visible to this method —
- they exist only in `_pending_vector_docs` until `index_done_callback()`
- embeds and writes them. Callers that need read-after-write visibility
- for similarity search must run an explicit flush first.
- """
- # Ensure collection is loaded before querying
- self._ensure_collection_loaded()
- # Use provided embedding or compute it
- if query_embedding is not None:
- embedding = [query_embedding] # Milvus expects a list of embeddings
- else:
- embedding = await self.embedding_func(
- [query], context="query", _priority=5
- ) # higher priority for query
- # Include all meta_fields (created_at is now always included)
- output_fields = list(self.meta_fields)
- # Build search params from index config
- search_params_base = self.index_config.build_search_params()
- # Merge with metric type and radius threshold
- search_params = {
- "metric_type": self.index_config.metric_type,
- "params": {
- **search_params_base.get("params", {}),
- "radius": self.cosine_better_than_threshold,
- },
- }
- results = self._client.search(
- collection_name=self.final_namespace,
- data=embedding,
- limit=top_k,
- output_fields=output_fields,
- search_params=search_params,
- )
- return [
- {
- **dp["entity"],
- "id": dp["id"],
- "distance": dp["distance"],
- "created_at": dp.get("created_at"),
- }
- for dp in results[0]
- ]
async def index_done_callback(self) -> None:
    """Make every buffered vector operation durable in Milvus.

    Returns normally only after all buffered upserts have been embedded and
    written and all buffered deletes have been issued. If embedding or a
    server call fails, the exception propagates and both buffers are kept,
    so the next call retries; callers must not assume persistence then.
    """
    flush = self._flush_pending_vector_ops
    await flush()
async def _flush_pending_vector_ops(self) -> None:
    """Flush buffered vector upserts and deletes to Milvus.

    Everything runs while holding ``self._flush_lock``, so concurrent
    ``upsert``/``delete``/read calls cannot interleave with a flush.
    Embedding also happens inside the lock (not in ``upsert`` and not
    lock-free), which makes deferred embedding plus bulk indexing atomic
    against concurrent upserts and destructive mutations.

    Behaviour:
      * No-op when both buffers are empty, or when no client exists yet (the
        buffers are then kept untouched).
      * Docs without a cached ``vector`` are embedded in batches of
        ``self._max_batch_size``. The batches run concurrently via
        ``asyncio.gather`` and each vector is cached on its
        ``_PendingVectorDoc``, so a retry after a later failure does not
        re-embed.
      * All pending docs are upserted first, then all pending deletes are
        issued.
      * Buffers are cleared only after both server calls succeed. On any
        failure (embedding, count mismatch, or server write) this raises and
        the buffered docs and deletes stay in place for the next
        ``index_done_callback`` (only cached vectors may have been added).

    Raises:
        Exception: whatever embedding or the Milvus client raised.
        RuntimeError: the embedding function returned a different number of
            vectors than requested.
    """
    async with self._flush_lock:
        if not self._pending_vector_docs and not self._pending_vector_deletes:
            return
        if self._client is None:
            # Not initialized yet: keep everything buffered for a later flush.
            return
        # Milvus requires the collection to be loaded before upsert/delete.
        self._ensure_collection_loaded()
        # Aliases to the live buffers (not copies); they are cleared in place
        # at the end so other holders of these objects see the cleared state.
        pending_docs = self._pending_vector_docs
        pending_deletes = self._pending_vector_deletes
        # Only docs whose vector was not already cached (e.g. by a previous
        # failed flush or by get_vectors_by_ids) need embedding.
        docs_to_embed: list[tuple[str, _PendingVectorDoc]] = [
            (doc_id, pdoc)
            for doc_id, pdoc in pending_docs.items()
            if pdoc.vector is None
        ]
        if docs_to_embed:
            contents = [pdoc.content for _, pdoc in docs_to_embed]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            logger.info(
                f"[{self.workspace}] {self.namespace} flush: embedding "
                f"{len(docs_to_embed)} vectors in {len(batches)} batch(es) "
                f"(batch_num={self._max_batch_size})"
            )
            try:
                embeddings_list = await asyncio.gather(
                    *[
                        self.embedding_func(batch, context="document")
                        for batch in batches
                    ]
                )
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error embedding pending vector ops "
                    f"(upserts={len(docs_to_embed)}): {e}"
                )
                raise
            # gather preserves batch order, so concatenation lines up with
            # docs_to_embed.
            embeddings = np.concatenate(embeddings_list)
            if len(embeddings) != len(docs_to_embed):
                raise RuntimeError(
                    f"[{self.workspace}] Embedding count mismatch: expected "
                    f"{len(docs_to_embed)}, got {len(embeddings)}"
                )
            # Cache each vector on its pending doc so a retry (after a
            # failed server write) skips re-embedding.
            for i, ((_, pdoc), embedding) in enumerate(
                zip(docs_to_embed, embeddings), start=1
            ):
                pdoc.vector = embedding.tolist()
                await _cooperative_yield(i)
        # Assemble final upsert payload. After the embed loop above every
        # pending doc has a non-None vector (count-mismatch was checked),
        # so we can iterate without re-guarding.
        committed_ids: list[str] = list(pending_docs.keys())
        list_data: list[dict[str, Any]] = [
            {**pending_docs[doc_id].source, "vector": pending_docs[doc_id].vector}
            for doc_id in committed_ids
        ]
        try:
            if list_data:
                self._client.upsert(
                    collection_name=self.final_namespace, data=list_data
                )
            if pending_deletes:
                self._client.delete(
                    collection_name=self.final_namespace,
                    pks=list(pending_deletes),
                )
        except Exception as e:
            # If the upsert succeeded but the delete failed, the next retry
            # re-upserts the same ids (presumably harmless since upsert
            # overwrites by primary key — confirm) and re-issues the deletes.
            logger.error(
                f"[{self.workspace}] Error flushing vector ops "
                f"(upserts={len(pending_docs)}, "
                f"deletes={len(pending_deletes)}): {e}"
            )
            raise
        # On success, clear the buffers in-place so external references
        # (e.g. drop()) see the cleared state.
        for doc_id in committed_ids:
            pending_docs.pop(doc_id, None)
        pending_deletes.clear()
async def delete_entity(self, entity_name: str) -> None:
    """Queue deletion of the vector stored for *entity_name*.

    The id is ``compute_mdhash_id(entity_name, prefix="ent-")``. Any buffered
    upsert for that id is discarded and the id is recorded as a pending
    delete; the server-side delete happens on the next flush.
    """
    target_id = compute_mdhash_id(entity_name, prefix="ent-")
    async with self._flush_lock:
        self._pending_vector_docs.pop(target_id, None)
        self._pending_vector_deletes.add(target_id)
    logger.debug(
        f"[{self.workspace}] Buffered delete for entity {entity_name} (id={target_id})"
    )
- async def delete_entity_relation(self, entity_name: str) -> None:
- """Delete all relation vectors where entity appears as src or tgt.
- The whole method runs under ``_flush_lock`` so the server-side query
- + delete cannot interleave with an in-flight bulk upsert.
- Server-side failures are re-raised (no log-and-swallow): the caller
- decides whether to retry.
- Buffer semantics — post-prune with caller short-circuit contract:
- Matching pending upserts in ``_pending_vector_docs`` are
- pruned **only after** the server-side query + delete
- succeeds. On failure the pending buffer stays intact and
- the exception propagates so the caller (``adelete_by_entity``
- in ``utils_graph.py``) can short-circuit before
- ``_persist_graph_updates`` flushes a half-cleaned buffer.
- Semantic note (deferred-buffer ↔ persisted divergence): pruning only
- consults the *current* buffered ``src_id`` / ``tgt_id`` view; we do
- not re-read the persisted row a buffered upsert is about to
- overwrite. So if a pending upsert is rewriting an already-persisted
- ``rel-X-Y`` so that its new ``src_id`` / ``tgt_id`` matches
- ``entity_name`` while the persisted row's do not (or vice versa),
- the persisted row will not be deleted by the server-side filter and
- the pending overwrite is dropped — i.e. the final state can diverge
- from the eager-flush ordering (upsert → flush → delete). Callers
- that require eager-equivalent semantics should call
- ``index_done_callback()`` before ``delete_entity_relation``.
- """
- def _prune_pending() -> None:
- for doc_id in [
- k
- for k, v in self._pending_vector_docs.items()
- if v.source.get("src_id") == entity_name
- or v.source.get("tgt_id") == entity_name
- ]:
- self._pending_vector_docs.pop(doc_id, None)
- async with self._flush_lock:
- if self._client is None:
- # No server state to mutate; buffer prune is the only
- # delete intent we can record.
- _prune_pending()
- return
- self._ensure_collection_loaded()
- expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
- results = self._client.query(
- collection_name=self.final_namespace,
- filter=expr,
- output_fields=["id"],
- )
- if not results:
- # No server rows to delete — still safe to prune any
- # pending upserts so they can't re-create the relation.
- _prune_pending()
- logger.debug(
- f"[{self.workspace}] No relations found for entity {entity_name}"
- )
- return
- relation_ids = [item["id"] for item in results]
- self._client.delete(collection_name=self.final_namespace, pks=relation_ids)
- # Server-side delete succeeded — safe to prune the pending
- # buffer so subsequent flushes don't re-upsert the deleted
- # relations.
- _prune_pending()
- logger.debug(
- f"[{self.workspace}] Deleted {len(relation_ids)} relations for {entity_name}"
- )
async def delete(self, ids: list[str]) -> None:
    """Queue vector deletions, which the next flush applies.

    Buffered upserts for the same ids are discarded so they cannot bring
    the rows back. A ``set`` argument is converted to a list first; an empty
    argument is a no-op.
    """
    if not ids:
        return
    if isinstance(ids, set):
        ids = list(ids)

    async with self._flush_lock:
        staged_docs = self._pending_vector_docs
        staged_deletes = self._pending_vector_deletes
        for doc_id in ids:
            staged_docs.pop(doc_id, None)
            staged_deletes.add(doc_id)

    logger.debug(
        f"[{self.workspace}] Buffered delete for {len(ids)} vectors in {self.namespace}"
    )
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
- """Get vector data by its ID, with read-your-writes against the buffer."""
- async with self._flush_lock:
- if id in self._pending_vector_deletes:
- return None
- pending = self._pending_vector_docs.get(id)
- if pending is not None:
- doc = dict(pending.source)
- doc["id"] = id
- return doc
- try:
- # Ensure collection is loaded before querying
- self._ensure_collection_loaded()
- # Include all meta_fields (created_at is now always included) plus id
- output_fields = list(self.meta_fields) + ["id"]
- result = self._client.query(
- collection_name=self.final_namespace,
- filter=f'id == "{id}"',
- output_fields=output_fields,
- )
- if not result or len(result) == 0:
- return None
- return result[0]
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
- )
- return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch several vector records by ID with read-your-writes, in request order.

    Buffered state takes precedence over Milvus:
      * An ID with a pending delete resolves to ``None``.
      * An ID with a pending upsert resolves to a copy of its buffered
        ``source`` with ``id`` added.

    Only the remaining IDs are queried from Milvus (``meta_fields`` plus
    ``id``).

    Args:
        ids: Vector IDs to look up.

    Returns:
        A list aligned with ``ids``: one entry per requested ID, in the same
        order. An entry is ``None`` when its ID is pending deletion or was
        not found. Empty input returns ``[]``.

        If the Milvus query fails, the error is logged and the IDs that
        needed Milvus resolve to ``None``. Entries already resolved from
        the buffer are still returned.
    """
    if not ids:
        return []
    buffered: dict[str, dict[str, Any] | None] = {}
    remaining: list[str] = []
    async with self._flush_lock:
        for doc_id in ids:
            if doc_id in self._pending_vector_deletes:
                buffered[doc_id] = None
                continue
            pending = self._pending_vector_docs.get(doc_id)
            if pending is not None:
                doc = dict(pending.source)
                doc["id"] = doc_id
                buffered[doc_id] = doc
                continue
            remaining.append(doc_id)
    result_map: dict[str, dict[str, Any]] = {}
    if remaining:
        try:
            # Ensure collection is loaded before querying
            self._ensure_collection_loaded()
            # Include all meta_fields (created_at is now always included) plus id
            output_fields = list(self.meta_fields) + ["id"]
            id_list = '", "'.join(remaining)
            filter_expr = f'id in ["{id_list}"]'
            result = self._client.query(
                collection_name=self.final_namespace,
                filter=filter_expr,
                output_fields=output_fields,
            )
            for row in result or []:
                if not row:
                    continue
                row_id = row.get("id")
                if row_id is not None:
                    # Stringify so lookups by the requested (str) id match.
                    result_map[str(row_id)] = row
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error retrieving vector data for IDs {remaining}: {e}"
            )
            # Returning [] here would break two guarantees:
            #   - the positional alignment with ``ids`` that the success
            #     path provides;
            #   - the read-your-writes hits already resolved from the buffer.
            # Drop any partial rows instead, so every Milvus-backed ID maps
            # to None. This matches get_by_id, which returns None on error.
            result_map = {}
    return [
        buffered[doc_id] if doc_id in buffered else result_map.get(str(doc_id))
        for doc_id in ids
    ]
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
    """Return embedding vectors for the given IDs, with read-your-writes.

    Resolution per ID:
      * Pending delete: omitted from the result.
      * Pending upsert: its cached vector. When ``vector is None``, the
        content is embedded lazily inside the flush lock, and the vector is
        cached on the buffered ``_PendingVectorDoc`` so the next flush does
        not re-embed the same content.
      * Otherwise: fetched from Milvus.

    Duplicate IDs are collapsed first. A repeated pending ID is therefore
    embedded once, and each remaining ID appears once in the Milvus filter.

    Args:
        ids: Vector IDs to look up.

    Returns:
        Mapping of ID to vector for the IDs that were found. IDs missing
        from both the buffer and Milvus are absent. If the Milvus query
        fails, the error is logged and only buffer-derived vectors are
        returned.

    Raises:
        Exception: Whatever ``embedding_func`` raises during lazy embedding
            (logged, then re-raised).
        RuntimeError: If the embedder returns a different number of
            vectors than contents submitted.
    """
    if not ids:
        return {}
    # dict.fromkeys keeps first-occurrence order while dropping repeats.
    unique_ids = list(dict.fromkeys(ids))
    result: dict[str, list[float]] = {}
    remaining: list[str] = []
    async with self._flush_lock:
        docs_to_embed: list[tuple[str, _PendingVectorDoc]] = []
        for doc_id in unique_ids:
            if doc_id in self._pending_vector_deletes:
                continue
            pending = self._pending_vector_docs.get(doc_id)
            if pending is not None:
                if pending.vector is None:
                    docs_to_embed.append((doc_id, pending))
                else:
                    result[doc_id] = pending.vector
                continue
            remaining.append(doc_id)
        if docs_to_embed:
            contents = [pdoc.content for _, pdoc in docs_to_embed]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            try:
                embeddings_list = await asyncio.gather(
                    *[
                        self.embedding_func(batch, context="document")
                        for batch in batches
                    ]
                )
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error lazily embedding pending vectors "
                    f"(upserts={len(docs_to_embed)}): {e}"
                )
                raise
            embeddings = np.concatenate(embeddings_list)
            if len(embeddings) != len(docs_to_embed):
                raise RuntimeError(
                    f"[{self.workspace}] Embedding count mismatch: expected "
                    f"{len(docs_to_embed)}, got {len(embeddings)}"
                )
            for i, ((doc_id, pdoc), embedding) in enumerate(
                zip(docs_to_embed, embeddings), start=1
            ):
                # Cache on the buffered doc so the next flush skips re-embedding.
                pdoc.vector = embedding.tolist()
                result[doc_id] = pdoc.vector
                await _cooperative_yield(i)
    if not remaining:
        return result
    try:
        self._ensure_collection_loaded()
        id_list = '", "'.join(remaining)
        filter_expr = f'id in ["{id_list}"]'
        rows = self._client.query(
            collection_name=self.final_namespace,
            filter=filter_expr,
            output_fields=["id", "vector"],
        )
        for item in rows or []:
            if item and "vector" in item and "id" in item:
                vector_data = item["vector"]
                if isinstance(vector_data, np.ndarray):
                    vector_data = vector_data.tolist()
                # Match get_by_ids: stringify the server-returned id so
                # callers can index the dict by the original requested id.
                result[str(item["id"])] = vector_data
        return result
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
        )
        return result
async def finalize(self):
    """Run one last flush of buffered vector ops and fail loudly if any remain.

    There is no Milvus client connection to release here (the MilvusClient
    is stateless from the storage layer's point of view). The purpose of
    finalize is to keep a transient bulk-write failure from being mistaken
    for a clean shutdown.

    Raises:
        RuntimeError: Raised in two cases:
            - The final flush raised. The RuntimeError is chained to the
              original exception.
            - Upserts or deletes are still buffered after the flush.
    """
    caught: Exception | None = None
    try:
        await self._flush_pending_vector_ops()
    except Exception as exc:
        caught = exc

    # Count leftovers under the flush lock so the snapshot is consistent
    # with any mutator that races in after the flush (cancellation paths
    # can land an upsert/delete in between).
    async with self._flush_lock:
        upserts_left = len(self._pending_vector_docs)
        deletes_left = len(self._pending_vector_deletes)

    if caught is None and not (upserts_left or deletes_left):
        return
    if caught is not None:
        raise RuntimeError(
            f"[{self.workspace}] MilvusVectorDBStorage.finalize() flush raised; "
            f"{upserts_left} pending upserts and {deletes_left} pending "
            f"deletes were left buffered (data lost)"
        ) from caught
    raise RuntimeError(
        f"[{self.workspace}] MilvusVectorDBStorage.finalize() left "
        f"{upserts_left} pending upserts and {deletes_left} pending "
        f"deletes buffered after final flush attempt (these writes have been lost)"
    )
async def drop(self) -> dict[str, str]:
    """Destructively wipe the Milvus collection and recreate it empty.

    Precondition: call this only while ``pipeline_status`` is idle (see
    the Pipeline concurrency contract in ``AGENTS.md``). The only in-tree
    caller, ``clear_documents``, enforces this.

    Caveat: only this instance's write buffers are cleared. Other
    ``MilvusVectorDBStorage`` instances that alias the same
    ``final_namespace`` keep their own buffers. Such aliases arise from
    multi-worker processes, or from distinct workspaces collapsed by
    ``MILVUS_WORKSPACE``. If a sibling's earlier flush failed and left its
    buffers intact, its next flush will upsert those stale rows into the
    freshly recreated collection. Direct callers that bypass the idle
    precondition MUST flush every aliased instance first.

    Returns:
        dict[str, str]: ``{"status": "success"|"error", "message": str}``
    """
    try:
        async with self._flush_lock:
            # Throw away buffered writes before the collection disappears;
            # a concurrent flush would otherwise write them back.
            self._pending_vector_docs.clear()
            self._pending_vector_deletes.clear()
            collection = self.final_namespace
            if self._client.has_collection(collection):
                self._client.drop_collection(collection)
            # Recreate an empty collection in place of the dropped one.
            self._create_collection_if_not_exist()
        logger.info(
            f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
        )
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
        )
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "data dropped"}
|