# mongo_impl.py (113 KB)
  1. import os
  2. import re
  3. import time
  4. from dataclasses import dataclass, field
  5. import numpy as np
  6. import configparser
  7. import asyncio
  8. from typing import Any, Union, final
  9. from ..base import (
  10. BaseGraphStorage,
  11. BaseKVStorage,
  12. BaseVectorStorage,
  13. DocProcessingStatus,
  14. DocStatus,
  15. DocStatusStorage,
  16. )
  17. from ..utils import logger, compute_mdhash_id, _cooperative_yield
  18. from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  19. from ..constants import GRAPH_FIELD_SEP
  20. from .._version import __version__
  21. from ..kg.shared_storage import get_data_init_lock, get_namespace_lock
  22. import pipmaster as pm
  23. if not pm.is_installed("pymongo"):
  24. pm.install("pymongo")
  25. from pymongo import AsyncMongoClient # type: ignore
  26. from pymongo import UpdateOne, DeleteOne # type: ignore
  27. from pymongo.asynchronous.database import AsyncDatabase # type: ignore
  28. from pymongo.asynchronous.collection import AsyncCollection # type: ignore
  29. from pymongo.operations import SearchIndexModel # type: ignore
  30. from pymongo.driver_info import DriverInfo # type: ignore
  31. from pymongo.errors import PyMongoError # type: ignore
  32. config = configparser.ConfigParser()
  33. config.read("config.ini", "utf-8")
  34. GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
  35. class ClientManager:
  36. _instances = {"db": None, "ref_count": 0}
  37. _lock = asyncio.Lock()
  38. @classmethod
  39. async def get_client(cls) -> AsyncMongoClient:
  40. async with cls._lock:
  41. if cls._instances["db"] is None:
  42. uri = os.environ.get(
  43. "MONGO_URI",
  44. config.get(
  45. "mongodb",
  46. "uri",
  47. fallback="mongodb://root:root@localhost:27017/",
  48. ),
  49. )
  50. database_name = os.environ.get(
  51. "MONGO_DATABASE",
  52. config.get("mongodb", "database", fallback="LightRAG"),
  53. )
  54. client = AsyncMongoClient(
  55. uri,
  56. driver=DriverInfo(name="LightRAG", version=__version__),
  57. )
  58. db = client.get_database(database_name)
  59. cls._instances["db"] = db
  60. cls._instances["ref_count"] = 0
  61. cls._instances["ref_count"] += 1
  62. return cls._instances["db"]
  63. @classmethod
  64. async def release_client(cls, db: AsyncDatabase):
  65. async with cls._lock:
  66. if db is not None:
  67. if db is cls._instances["db"]:
  68. cls._instances["ref_count"] -= 1
  69. if cls._instances["ref_count"] == 0:
  70. cls._instances["db"] = None
  71. @final
  72. @dataclass
  73. class MongoKVStorage(BaseKVStorage):
  74. db: AsyncDatabase = field(default=None)
  75. _data: AsyncCollection = field(default=None)
  76. def __init__(self, namespace, global_config, embedding_func, workspace=None):
  77. super().__init__(
  78. namespace=namespace,
  79. workspace=workspace or "",
  80. global_config=global_config,
  81. embedding_func=embedding_func,
  82. )
  83. self.__post_init__()
  84. def __post_init__(self):
  85. # Check for MONGODB_WORKSPACE environment variable first (higher priority)
  86. # This allows administrators to force a specific workspace for all MongoDB storage instances
  87. mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
  88. if mongodb_workspace and mongodb_workspace.strip():
  89. # Use environment variable value, overriding the passed workspace parameter
  90. effective_workspace = mongodb_workspace.strip()
  91. logger.info(
  92. f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
  93. )
  94. else:
  95. # Use the workspace parameter passed during initialization
  96. effective_workspace = self.workspace
  97. if effective_workspace:
  98. logger.debug(
  99. f"Using passed workspace parameter: '{effective_workspace}'"
  100. )
  101. # Build final_namespace with workspace prefix for data isolation
  102. # Keep original namespace unchanged for type detection logic
  103. if effective_workspace:
  104. self.final_namespace = f"{effective_workspace}_{self.namespace}"
  105. self.workspace = effective_workspace
  106. logger.debug(
  107. f"Final namespace with workspace prefix: '{self.final_namespace}'"
  108. )
  109. else:
  110. # When workspace is empty, final_namespace equals original namespace
  111. self.final_namespace = self.namespace
  112. self.workspace = ""
  113. logger.debug(
  114. f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
  115. )
  116. self._collection_name = self.final_namespace
  117. async def initialize(self):
  118. async with get_data_init_lock():
  119. if self.db is None:
  120. self.db = await ClientManager.get_client()
  121. self._data = await get_or_create_collection(self.db, self._collection_name)
  122. logger.debug(
  123. f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
  124. )
  125. async def finalize(self):
  126. if self.db is not None:
  127. await ClientManager.release_client(self.db)
  128. self.db = None
  129. self._data = None
  130. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  131. # Unified handling for flattened keys
  132. doc = await self._data.find_one({"_id": id})
  133. if doc:
  134. # Ensure time fields are present, provide default values for old data
  135. doc.setdefault("create_time", 0)
  136. doc.setdefault("update_time", 0)
  137. return doc
  138. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  139. cursor = self._data.find({"_id": {"$in": ids}})
  140. docs = await cursor.to_list(length=None)
  141. doc_map: dict[str, dict[str, Any]] = {}
  142. for doc in docs:
  143. if not doc:
  144. continue
  145. doc.setdefault("create_time", 0)
  146. doc.setdefault("update_time", 0)
  147. doc_map[str(doc.get("_id"))] = doc
  148. ordered_results: list[dict[str, Any] | None] = []
  149. for id_value in ids:
  150. ordered_results.append(doc_map.get(str(id_value)))
  151. return ordered_results
  152. async def filter_keys(self, keys: set[str]) -> set[str]:
  153. cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
  154. existing_ids = {str(x["_id"]) async for x in cursor}
  155. return keys - existing_ids
  156. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  157. logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
  158. if not data:
  159. return
  160. # Unified handling for all namespaces with flattened keys
  161. # Use bulk_write for better performance
  162. operations = []
  163. current_time = int(time.time()) # Get current Unix timestamp
  164. for i, (k, v) in enumerate(data.items(), start=1):
  165. # For text_chunks namespace, ensure llm_cache_list field exists
  166. if self.namespace.endswith("text_chunks"):
  167. if "llm_cache_list" not in v:
  168. v["llm_cache_list"] = []
  169. # Create a copy of v for $set operation, excluding create_time to avoid conflicts
  170. v_for_set = v.copy()
  171. v_for_set["_id"] = k # Use flattened key as _id
  172. v_for_set["update_time"] = current_time # Always update update_time
  173. # Remove create_time from $set to avoid conflict with $setOnInsert
  174. v_for_set.pop("create_time", None)
  175. operations.append(
  176. UpdateOne(
  177. {"_id": k},
  178. {
  179. "$set": v_for_set, # Update all fields except create_time
  180. "$setOnInsert": {
  181. "create_time": current_time
  182. }, # Set create_time only on insert
  183. },
  184. upsert=True,
  185. )
  186. )
  187. await _cooperative_yield(i)
  188. if operations:
  189. await self._data.bulk_write(operations)
  190. async def index_done_callback(self) -> None:
  191. # Mongo handles persistence automatically
  192. pass
  193. async def is_empty(self) -> bool:
  194. """Check if the storage is empty for the current workspace and namespace
  195. Returns:
  196. bool: True if storage is empty, False otherwise
  197. """
  198. try:
  199. # Use count_documents with limit 1 for efficiency
  200. count = await self._data.count_documents({}, limit=1)
  201. return count == 0
  202. except PyMongoError as e:
  203. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  204. return True
  205. async def delete(self, ids: list[str]) -> None:
  206. """Delete documents with specified IDs
  207. Args:
  208. ids: List of document IDs to be deleted
  209. """
  210. if not ids:
  211. return
  212. # Convert to list if it's a set (MongoDB BSON cannot encode sets)
  213. if isinstance(ids, set):
  214. ids = list(ids)
  215. try:
  216. result = await self._data.delete_many({"_id": {"$in": ids}})
  217. logger.info(
  218. f"[{self.workspace}] Deleted {result.deleted_count} documents from {self.namespace}"
  219. )
  220. except PyMongoError as e:
  221. logger.error(
  222. f"[{self.workspace}] Error deleting documents from {self.namespace}: {e}"
  223. )
  224. async def drop(self) -> dict[str, str]:
  225. """Drop the storage by removing all documents in the collection.
  226. Returns:
  227. dict[str, str]: Status of the operation with keys 'status' and 'message'
  228. """
  229. try:
  230. result = await self._data.delete_many({})
  231. deleted_count = result.deleted_count
  232. logger.info(
  233. f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
  234. )
  235. return {
  236. "status": "success",
  237. "message": f"{deleted_count} documents dropped",
  238. }
  239. except PyMongoError as e:
  240. logger.error(
  241. f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
  242. )
  243. return {"status": "error", "message": str(e)}
  244. @final
  245. @dataclass
  246. class MongoDocStatusStorage(DocStatusStorage):
  247. db: AsyncDatabase = field(default=None)
  248. _data: AsyncCollection = field(default=None)
  249. def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]:
  250. """Normalize and migrate a raw Mongo document to DocProcessingStatus-compatible dict."""
  251. # Make a copy of the data to avoid modifying the original
  252. data = doc.copy()
  253. # Remove deprecated content field if it exists
  254. data.pop("content", None)
  255. # Remove MongoDB _id field if it exists
  256. data.pop("_id", None)
  257. # If file_path is not in data, use document id as file path
  258. if "file_path" not in data:
  259. data["file_path"] = "no-file-path"
  260. # Ensure new fields exist with default values
  261. if "metadata" not in data:
  262. data["metadata"] = {}
  263. if "error_msg" not in data:
  264. data["error_msg"] = None
  265. # Backward compatibility: migrate legacy 'error' field to 'error_msg'
  266. if "error" in data:
  267. if "error_msg" not in data or data["error_msg"] in (None, ""):
  268. data["error_msg"] = data.pop("error")
  269. else:
  270. data.pop("error", None)
  271. return data
  272. def __init__(self, namespace, global_config, embedding_func, workspace=None):
  273. super().__init__(
  274. namespace=namespace,
  275. workspace=workspace or "",
  276. global_config=global_config,
  277. embedding_func=embedding_func,
  278. )
  279. self.__post_init__()
  280. def __post_init__(self):
  281. # Check for MONGODB_WORKSPACE environment variable first (higher priority)
  282. # This allows administrators to force a specific workspace for all MongoDB storage instances
  283. mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
  284. if mongodb_workspace and mongodb_workspace.strip():
  285. # Use environment variable value, overriding the passed workspace parameter
  286. effective_workspace = mongodb_workspace.strip()
  287. logger.info(
  288. f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
  289. )
  290. else:
  291. # Use the workspace parameter passed during initialization
  292. effective_workspace = self.workspace
  293. if effective_workspace:
  294. logger.debug(
  295. f"Using passed workspace parameter: '{effective_workspace}'"
  296. )
  297. # Build final_namespace with workspace prefix for data isolation
  298. # Keep original namespace unchanged for type detection logic
  299. if effective_workspace:
  300. self.final_namespace = f"{effective_workspace}_{self.namespace}"
  301. self.workspace = effective_workspace
  302. logger.debug(
  303. f"Final namespace with workspace prefix: '{self.final_namespace}'"
  304. )
  305. else:
  306. # When workspace is empty, final_namespace equals original namespace
  307. self.final_namespace = self.namespace
  308. self.workspace = ""
  309. logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
  310. self._collection_name = self.final_namespace
  311. async def initialize(self):
  312. async with get_data_init_lock():
  313. if self.db is None:
  314. self.db = await ClientManager.get_client()
  315. self._data = await get_or_create_collection(self.db, self._collection_name)
  316. # Create and migrate all indexes including Chinese collation for file_path
  317. await self.create_and_migrate_indexes_if_not_exists()
  318. logger.debug(
  319. f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}"
  320. )
  321. async def finalize(self):
  322. if self.db is not None:
  323. await ClientManager.release_client(self.db)
  324. self.db = None
  325. self._data = None
  326. async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
  327. return await self._data.find_one({"_id": id})
  328. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  329. cursor = self._data.find({"_id": {"$in": ids}})
  330. docs = await cursor.to_list(length=None)
  331. doc_map: dict[str, dict[str, Any]] = {}
  332. for doc in docs:
  333. if not doc:
  334. continue
  335. doc_map[str(doc.get("_id"))] = doc
  336. ordered_results: list[dict[str, Any] | None] = []
  337. for id_value in ids:
  338. ordered_results.append(doc_map.get(str(id_value)))
  339. return ordered_results
  340. async def filter_keys(self, data: set[str]) -> set[str]:
  341. cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
  342. existing_ids = {str(x["_id"]) async for x in cursor}
  343. return data - existing_ids
  344. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  345. logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
  346. if not data:
  347. return
  348. update_tasks: list[Any] = []
  349. for i, (k, v) in enumerate(data.items(), start=1):
  350. # Ensure chunks_list field exists and is an array
  351. if "chunks_list" not in v:
  352. v["chunks_list"] = []
  353. data[k]["_id"] = k
  354. update_tasks.append(
  355. self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
  356. )
  357. await _cooperative_yield(i)
  358. await asyncio.gather(*update_tasks)
  359. async def get_status_counts(self) -> dict[str, int]:
  360. """Get counts of documents in each status"""
  361. pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
  362. cursor = await self._data.aggregate(pipeline, allowDiskUse=True)
  363. result = await cursor.to_list()
  364. counts = {}
  365. for doc in result:
  366. counts[doc["_id"]] = doc["count"]
  367. return counts
  368. async def get_docs_by_status(
  369. self, status: DocStatus
  370. ) -> dict[str, DocProcessingStatus]:
  371. """Get all documents with a specific status"""
  372. return await self.get_docs_by_statuses([status])
  373. async def get_docs_by_statuses(
  374. self, statuses: list[DocStatus]
  375. ) -> dict[str, DocProcessingStatus]:
  376. """Get all documents matching any of the given statuses in a single query.
  377. Uses MongoDB's $in operator to fetch all matching statuses in one
  378. round-trip instead of one find() call per status.
  379. """
  380. if not statuses:
  381. return {}
  382. status_values = [s.value for s in statuses]
  383. cursor = self._data.find({"status": {"$in": status_values}})
  384. docs = await cursor.to_list(length=None)
  385. result = {}
  386. for doc in docs:
  387. try:
  388. data = self._prepare_doc_status_data(doc)
  389. result[doc["_id"]] = DocProcessingStatus(**data)
  390. except KeyError as e:
  391. logger.error(
  392. f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
  393. )
  394. continue
  395. return result
  396. async def get_docs_by_track_id(
  397. self, track_id: str
  398. ) -> dict[str, DocProcessingStatus]:
  399. """Get all documents with a specific track_id"""
  400. cursor = self._data.find({"track_id": track_id})
  401. result = await cursor.to_list()
  402. processed_result = {}
  403. for doc in result:
  404. try:
  405. data = self._prepare_doc_status_data(doc)
  406. processed_result[doc["_id"]] = DocProcessingStatus(**data)
  407. except KeyError as e:
  408. logger.error(
  409. f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
  410. )
  411. continue
  412. return processed_result
  413. async def index_done_callback(self) -> None:
  414. # Mongo handles persistence automatically
  415. pass
  416. async def is_empty(self) -> bool:
  417. """Check if the storage is empty for the current workspace and namespace
  418. Returns:
  419. bool: True if storage is empty, False otherwise
  420. """
  421. try:
  422. # Use count_documents with limit 1 for efficiency
  423. count = await self._data.count_documents({}, limit=1)
  424. return count == 0
  425. except PyMongoError as e:
  426. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  427. return True
  428. async def drop(self) -> dict[str, str]:
  429. """Drop the storage by removing all documents in the collection.
  430. Returns:
  431. dict[str, str]: Status of the operation with keys 'status' and 'message'
  432. """
  433. try:
  434. result = await self._data.delete_many({})
  435. deleted_count = result.deleted_count
  436. logger.info(
  437. f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
  438. )
  439. return {
  440. "status": "success",
  441. "message": f"{deleted_count} documents dropped",
  442. }
  443. except PyMongoError as e:
  444. logger.error(
  445. f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
  446. )
  447. return {"status": "error", "message": str(e)}
  448. async def delete(self, ids: list[str]) -> None:
  449. await self._data.delete_many({"_id": {"$in": ids}})
  450. async def create_and_migrate_indexes_if_not_exists(self):
  451. """Create indexes to optimize pagination queries and migrate file_path indexes for Chinese collation"""
  452. try:
  453. # Get indexes for the current collection only
  454. indexes_cursor = await self._data.list_indexes()
  455. existing_indexes = await indexes_cursor.to_list(length=None)
  456. existing_index_names = {idx.get("name", "") for idx in existing_indexes}
  457. # Define collation configuration for Chinese pinyin sorting
  458. collation_config = {"locale": "zh", "numericOrdering": True}
  459. # Use workspace-specific index names to avoid cross-workspace conflicts
  460. workspace_prefix = f"{self.workspace}_" if self.workspace != "" else ""
  461. # 1. Define all indexes needed with workspace-specific names
  462. all_indexes = [
  463. # Original pagination indexes
  464. {
  465. "name": f"{workspace_prefix}status_updated_at",
  466. "keys": [("status", 1), ("updated_at", -1)],
  467. },
  468. {
  469. "name": f"{workspace_prefix}status_created_at",
  470. "keys": [("status", 1), ("created_at", -1)],
  471. },
  472. {"name": f"{workspace_prefix}updated_at", "keys": [("updated_at", -1)]},
  473. {"name": f"{workspace_prefix}created_at", "keys": [("created_at", -1)]},
  474. {"name": f"{workspace_prefix}id", "keys": [("_id", 1)]},
  475. {"name": f"{workspace_prefix}track_id", "keys": [("track_id", 1)]},
  476. # New file_path indexes with Chinese collation and workspace-specific names
  477. {
  478. "name": f"{workspace_prefix}file_path_zh_collation",
  479. "keys": [("file_path", 1)],
  480. "collation": collation_config,
  481. },
  482. {
  483. "name": f"{workspace_prefix}status_file_path_zh_collation",
  484. "keys": [("status", 1), ("file_path", 1)],
  485. "collation": collation_config,
  486. },
  487. # Partial index on content_hash for content-based dedup lookups.
  488. # Mirrors the PG partial index: skip legacy/empty values so the
  489. # index stays small and a content_hash="" query is a guaranteed miss.
  490. {
  491. "name": f"{workspace_prefix}content_hash",
  492. "keys": [("content_hash", 1)],
  493. "partialFilterExpression": {
  494. "content_hash": {"$exists": True, "$type": "string", "$gt": ""}
  495. },
  496. },
  497. ]
  498. # 2. Handle legacy index cleanup: only drop old indexes that exist in THIS collection
  499. legacy_index_names = [
  500. "file_path_zh_collation",
  501. "status_file_path_zh_collation",
  502. "status_updated_at",
  503. "status_created_at",
  504. "updated_at",
  505. "created_at",
  506. "id",
  507. "track_id",
  508. "content_hash",
  509. ]
  510. for legacy_name in legacy_index_names:
  511. if (
  512. legacy_name in existing_index_names
  513. and legacy_name
  514. != f"{workspace_prefix}{legacy_name.replace(workspace_prefix, '')}"
  515. ):
  516. try:
  517. await self._data.drop_index(legacy_name)
  518. logger.debug(
  519. f"[{self.workspace}] Migrated: dropped legacy index '{legacy_name}' from collection {self._collection_name}"
  520. )
  521. existing_index_names.discard(legacy_name)
  522. except PyMongoError as drop_error:
  523. logger.warning(
  524. f"[{self.workspace}] Failed to drop legacy index '{legacy_name}' from collection {self._collection_name}: {drop_error}"
  525. )
  526. # 3. Create all needed indexes with workspace-specific names
  527. for index_info in all_indexes:
  528. index_name = index_info["name"]
  529. if index_name not in existing_index_names:
  530. create_kwargs = {"name": index_name}
  531. if "collation" in index_info:
  532. create_kwargs["collation"] = index_info["collation"]
  533. if "partialFilterExpression" in index_info:
  534. create_kwargs["partialFilterExpression"] = index_info[
  535. "partialFilterExpression"
  536. ]
  537. try:
  538. await self._data.create_index(
  539. index_info["keys"], **create_kwargs
  540. )
  541. logger.debug(
  542. f"[{self.workspace}] Created index '{index_name}' for collection {self._collection_name}"
  543. )
  544. except PyMongoError as create_error:
  545. # If creation still fails, log the error but continue with other indexes
  546. logger.error(
  547. f"[{self.workspace}] Failed to create index '{index_name}' for collection {self._collection_name}: {create_error}"
  548. )
  549. else:
  550. logger.debug(
  551. f"[{self.workspace}] Index '{index_name}' already exists for collection {self._collection_name}"
  552. )
  553. except PyMongoError as e:
  554. logger.error(
  555. f"[{self.workspace}] Error creating/migrating indexes for {self._collection_name}: {e}"
  556. )
  async def get_docs_paginated(
      self,
      status_filter: DocStatus | None = None,
      status_filters: list[DocStatus] | None = None,
      page: int = 1,
      page_size: int = 50,
      sort_field: str = "updated_at",
      sort_direction: str = "desc",
  ) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
      """Return one page of document statuses plus the total match count.

      Out-of-range arguments are normalised rather than rejected.

      Args:
          status_filter: Single status to filter by; forwarded to
              ``resolve_status_filter_values``.
          status_filters: Several statuses to filter by; forwarded to
              ``resolve_status_filter_values``. No status condition is
              applied when that helper returns None.
          page: Page number (1-based); values below 1 become 1.
          page_size: Number of documents per page, clamped to 10-200.
          sort_field: 'created_at', 'updated_at', '_id' or 'file_path';
              anything else falls back to 'updated_at'.
          sort_direction: 'asc' or 'desc' (case-insensitive); anything else
              falls back to 'desc'.

      Returns:
          Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count),
          where total_count covers every document matching the filter, not
          only the returned page.
      """
      status_filter_values = self.resolve_status_filter_values(
          status_filter=status_filter,
          status_filters=status_filters,
      )

      # Normalise paging arguments
      page = max(page, 1)
      page_size = min(max(page_size, 10), 200)

      # Normalise sorting arguments
      if sort_field not in ("created_at", "updated_at", "_id", "file_path"):
          sort_field = "updated_at"
      ascending = sort_direction.lower() == "asc"
      sort_criteria = [(sort_field, 1 if ascending else -1)]

      query_filter = {}
      if status_filter_values is not None:
          query_filter["status"] = {"$in": sorted(status_filter_values)}

      total_count = await self._data.count_documents(query_filter)

      skip = (page - 1) * page_size
      cursor = self._data.find(query_filter).sort(sort_criteria)
      if sort_field == "file_path":
          # Chinese collation gives pinyin ordering for file paths
          cursor = cursor.collation({"locale": "zh", "numericOrdering": True})
      cursor = cursor.skip(skip).limit(page_size)
      result = await cursor.to_list(length=page_size)

      # Convert raw documents; a document that raises KeyError is logged and skipped
      documents = []
      for doc in result:
          try:
              doc_id = doc["_id"]
              doc_status = DocProcessingStatus(**self._prepare_doc_status_data(doc))
          except KeyError as e:
              logger.error(
                  f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
              )
              continue
          documents.append((doc_id, doc_status))

      return documents, total_count
  async def get_all_status_counts(self) -> dict[str, int]:
      """Count documents per status across the whole collection.

      Returns:
          Dictionary mapping each status value to its document count, plus an
          'all' entry holding the sum of all counts.
      """
      group_by_status = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
      cursor = await self._data.aggregate(group_by_status, allowDiskUse=True)
      rows = await cursor.to_list()

      counts = {row["_id"]: row["count"] for row in rows}
      # 'all' is written last so it always carries the grand total
      counts["all"] = sum(row["count"] for row in rows)
      return counts
  async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
      """Look up a single document by its exact ``file_path`` value.

      Args:
          file_path: The file path to search for

      Returns:
          The raw document returned by ``find_one``, or None when nothing matches.
      """
      query = {"file_path": file_path}
      document = await self._data.find_one(query)
      return document
  async def get_doc_by_file_basename(
      self, basename: str
  ) -> Union[tuple[str, dict[str, Any]], None]:
      """Mongo-native override of basename-based document lookup.

      Performs an exact match of ``basename`` against the stored ``file_path``
      field; the caller is responsible for passing an already-canonical
      basename. An empty basename and the placeholder ``"unknown_source"``
      are treated as misses without querying. Database errors are logged and
      reported as a miss.

      NOTE(review): the file_path indexes created by
      ``create_and_migrate_indexes_if_not_exists`` carry a collation while this
      query specifies none - confirm the lookup can actually use them.

      Returns:
          ``(str(_id), document)`` for the matched document, or None.
      """
      if not basename or basename == "unknown_source":
          return None

      try:
          match = await self._data.find_one({"file_path": basename})
      except PyMongoError as e:
          logger.error(f"[{self.workspace}] Error in get_doc_by_file_basename: {e}")
          return None

      identifier = match.get("_id") if match else None
      if identifier is None:
          return None
      return str(identifier), match
  async def get_doc_by_content_hash(
      self, content_hash: str
  ) -> Union[tuple[str, dict[str, Any]], None]:
      """Mongo-native override of content-hash document lookup.

      Performs an exact match on ``content_hash``. An empty hash is treated
      as a miss without querying, which mirrors the partial ``content_hash``
      index that skips empty values. Database errors are logged and reported
      as a miss.

      Returns:
          ``(str(_id), document)`` for the matched document, or None.
      """
      if not content_hash:
          return None

      try:
          match = await self._data.find_one({"content_hash": content_hash})
      except PyMongoError as e:
          logger.error(f"[{self.workspace}] Error in get_doc_by_content_hash: {e}")
          return None

      identifier = match.get("_id") if match else None
      if identifier is None:
          return None
      return str(identifier), match
  706. @final
  707. @dataclass
  708. class MongoGraphStorage(BaseGraphStorage):
  709. """
  710. A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
  711. """
  712. db: AsyncDatabase = field(default=None)
  713. # node collection storing node_id, node_properties
  714. collection: AsyncCollection = field(default=None)
  715. # edge collection storing source_node_id, target_node_id, and edge_properties
  716. edgeCollection: AsyncCollection = field(default=None)
  def __init__(self, namespace, global_config, embedding_func, workspace=None):
      """Resolve the effective workspace and derive the collection names.

      A non-blank ``MONGODB_WORKSPACE`` environment variable takes priority
      over the ``workspace`` argument. The node collection is named
      ``<workspace>_<namespace>`` (or just ``<namespace>`` without a
      workspace) and the edge collection appends ``_edges`` to that name.
      ``self.namespace`` itself is left unchanged.
      """
      super().__init__(
          namespace=namespace,
          workspace=workspace or "",
          global_config=global_config,
          embedding_func=embedding_func,
      )

      # The environment variable lets administrators force one workspace for
      # all MongoDB storage instances; blank values are ignored.
      env_workspace = (os.environ.get("MONGODB_WORKSPACE") or "").strip()
      if env_workspace:
          effective_workspace = env_workspace
          logger.info(
              f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
          )
      else:
          effective_workspace = self.workspace
          if effective_workspace:
              logger.debug(
                  f"Using passed workspace parameter: '{effective_workspace}'"
              )

      # Prefix the namespace with the workspace for data isolation
      if not effective_workspace:
          self.final_namespace = self.namespace
          self.workspace = ""
          logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
      else:
          self.final_namespace = f"{effective_workspace}_{self.namespace}"
          self.workspace = effective_workspace
          logger.debug(
              f"Final namespace with workspace prefix: '{self.final_namespace}'"
          )

      self._collection_name = self.final_namespace
      self._edge_collection_name = f"{self._collection_name}_edges"
  async def initialize(self):
      """Acquire the database client and open the node and edge collections.

      Runs under the data-init lock. The client is only requested when
      ``self.db`` is still None. Handles are stored on ``self.collection``
      and ``self.edge_collection``, after which the search index is created
      if it does not exist yet.
      """
      async with get_data_init_lock():
          if self.db is None:
              self.db = await ClientManager.get_client()

          # Node collection first, then the edge collection
          targets = (
              ("collection", self._collection_name),
              ("edge_collection", self._edge_collection_name),
          )
          for attribute, collection_name in targets:
              handle = await get_or_create_collection(self.db, collection_name)
              setattr(self, attribute, handle)

          # Create Atlas Search index for better search performance if possible
          await self.create_search_index_if_not_exists()

          logger.debug(
              f"[{self.workspace}] Use MongoDB as KG {self._collection_name}"
          )
  async def finalize(self):
      """Release the database client and drop the cached handles.

      Does nothing when no client is held (``self.db`` is None).
      """
      if self.db is None:
          return

      await ClientManager.release_client(self.db)
      self.db = None
      self.collection = None
      self.edge_collection = None
  776. # Sample entity document
  777. # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
  778. # {
  779. # "_id" : "CompanyA",
  780. # "entity_id" : "CompanyA",
  781. # "entity_type" : "Organization",
  782. # "description" : "A major technology company",
  783. # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
  784. # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
  785. # "file_path" : "custom_kg",
  786. # "created_at" : 1749904575
  787. # }
  788. # Sample relation document
  789. # {
  790. # "_id" : ObjectId("6856ac6e7c6bad9b5470b678"), // MongoDB built-in ObjectId
  791. # "description" : "CompanyA develops ProductX",
  792. # "source_node_id" : "CompanyA",
  793. # "target_node_id" : "ProductX",
  794. # "relationship": "Develops", // To distinguish multiple same-target relations
  795. # "weight" : Double("1"),
  796. # "keywords" : "develop, produce",
  797. # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
  798. # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
  799. # "file_path" : "custom_kg",
  800. # "created_at" : 1749904575
  801. # }
  802. #
  803. # -------------------------------------------------------------------------
  804. # BASIC QUERIES
  805. # -------------------------------------------------------------------------
  806. #
  async def has_node(self, node_id: str) -> bool:
      """
      Tell whether a node document with ``_id == node_id`` exists.
      Only the ``_id`` field is projected, since just the presence matters.
      """
      found = await self.collection.find_one({"_id": node_id}, {"_id": 1})
      return found is not None
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
      """
      Tell whether a direct edge joins the two nodes, in either direction.
      """
      forward = {
          "source_node_id": source_node_id,
          "target_node_id": target_node_id,
      }
      backward = {
          "source_node_id": target_node_id,
          "target_node_id": source_node_id,
      }
      # Edges are treated as undirected, so either orientation counts
      match = await self.edge_collection.find_one(
          {"$or": [forward, backward]},
          {"_id": 1},
      )
      return match is not None
  834. #
  835. # -------------------------------------------------------------------------
  836. # DEGREES
  837. # -------------------------------------------------------------------------
  838. #
  async def node_degree(self, node_id: str) -> int:
      """
      Count the edge documents that have node_id as source or as target.
      """
      touches_node = {
          "$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]
      }
      return await self.edge_collection.count_documents(touches_node)
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
      """Return the combined degree of an edge's two endpoints.

      Args:
          src_id: Label of the source node
          tgt_id: Label of the target node

      Returns:
          int: ``node_degree(src_id) + node_degree(tgt_id)``
      """
      # Queried one after the other, source first
      degrees = [await self.node_degree(endpoint) for endpoint in (src_id, tgt_id)]
      return degrees[0] + degrees[1]
  857. #
  858. # -------------------------------------------------------------------------
  859. # GETTERS
  860. # -------------------------------------------------------------------------
  861. #
  862. async def get_node(self, node_id: str) -> dict[str, str] | None:
  863. """
  864. Return the full node document, or None if missing.
  865. """
  866. return await self.collection.find_one({"_id": node_id})
  867. async def get_edge(
  868. self, source_node_id: str, target_node_id: str
  869. ) -> dict[str, str] | None:
  870. return await self.edge_collection.find_one(
  871. {
  872. "$or": [
  873. {
  874. "source_node_id": source_node_id,
  875. "target_node_id": target_node_id,
  876. },
  877. {
  878. "source_node_id": target_node_id,
  879. "target_node_id": source_node_id,
  880. },
  881. ]
  882. }
  883. )
  884. async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
  885. """
  886. Retrieves all edges (relationships) for a particular node identified by its label.
  887. Args:
  888. source_node_id: Label of the node to get edges for
  889. Returns:
  890. list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
  891. None: If no edges found
  892. """
  893. cursor = self.edge_collection.find(
  894. {
  895. "$or": [
  896. {"source_node_id": source_node_id},
  897. {"target_node_id": source_node_id},
  898. ]
  899. },
  900. {"source_node_id": 1, "target_node_id": 1},
  901. )
  902. return [
  903. (e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
  904. ]
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
      """Fetch several node documents with one ``$in`` query.

      Returns:
          Mapping of ``_id`` to node document; ids without a document are absent.
      """
      cursor = self.collection.find({"_id": {"$in": node_ids}})
      return {doc.get("_id"): doc async for doc in cursor}
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
      """Compute degrees for several nodes with two aggregations.

      Outbound counts (node as source) and inbound counts (node as target)
      are aggregated separately and summed per node id.

      Returns:
          Mapping of node id to degree; ids that touch no edge are absent.
      """

      def count_by(endpoint: str) -> list:
          # Count edges per node id on one side of the edge
          return [
              {"$match": {endpoint: {"$in": node_ids}}},
              {"$group": {"_id": f"${endpoint}", "degree": {"$sum": 1}}},
          ]

      degrees = {}

      # Outbound degrees seed the mapping
      outbound = await self.edge_collection.aggregate(
          count_by("source_node_id"), allowDiskUse=True
      )
      async for row in outbound:
          degrees[row.get("_id")] = row.get("degree")

      # Inbound degrees are added on top
      inbound = await self.edge_collection.aggregate(
          count_by("target_node_id"), allowDiskUse=True
      )
      async for row in inbound:
          key = row.get("_id")
          degrees[key] = degrees.get(key, 0) + row.get("degree")

      return degrees
  async def get_nodes_edges_batch(
      self, node_ids: list[str]
  ) -> dict[str, list[tuple[str, str]]]:
      """
      Batch retrieve edges for multiple nodes.

      Both outgoing and incoming edges are collected for each node so the
      result reflects the undirected nature of the graph.

      Args:
          node_ids: List of node IDs (entity_id) for which to retrieve edges.

      Returns:
          A dictionary mapping each requested node ID to its list of
          (source, target) tuples. Outgoing edges are listed first, then
          incoming ones; nodes without edges map to an empty list.
      """
      adjacency = {node_id: [] for node_id in node_ids}
      endpoints_only = {"source_node_id": 1, "target_node_id": 1}

      # First pass: node is the source; second pass: node is the target
      for endpoint in ("source_node_id", "target_node_id"):
          cursor = self.edge_collection.find(
              {endpoint: {"$in": node_ids}},
              endpoints_only,
          )
          async for edge in cursor:
              pair = (edge["source_node_id"], edge["target_node_id"])
              adjacency[edge[endpoint]].append(pair)

      return adjacency
  971. #
  972. # -------------------------------------------------------------------------
  973. # UPSERTS
  974. # -------------------------------------------------------------------------
  975. #
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
      """
      Insert or update the node document whose ``_id`` is node_id.

      All entries of node_data are written with ``$set``. When node_data has
      a non-empty "source_id", its GRAPH_FIELD_SEP-split form is stored as
      the "source_ids" array as well. node_data itself is not modified.
      """
      fields = dict(node_data)
      if node_data.get("source_id", ""):
          fields["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP)

      await self.collection.update_one({"_id": node_id}, {"$set": fields}, upsert=True)
  async def upsert_edge(
      self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  ) -> None:
      """
      Upsert the edge between source_node_id and target_node_id.

      The source node is upserted first so that it exists. An existing edge
      between the two nodes, in either direction, is updated in place;
      otherwise a new edge document is inserted. The stored document always
      gets source_node_id/target_node_id set to the arguments, and a
      non-empty "source_id" is additionally stored split by GRAPH_FIELD_SEP
      as the "source_ids" array.

      Args:
          source_node_id: ID of the edge's source node.
          target_node_id: ID of the edge's target node.
          edge_data: Edge properties to store. The dict is copied and is not
              modified by this call.
      """
      # Ensure source node exists
      await self.upsert_node(source_node_id, {})

      # Work on a copy: writing the bookkeeping fields into edge_data itself
      # would leak them into the caller's dict.
      fields = {**edge_data}
      if edge_data.get("source_id", ""):
          fields["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
      fields["source_node_id"] = source_node_id
      fields["target_node_id"] = target_node_id

      # Match the edge regardless of the direction it was stored in
      either_direction = {
          "$or": [
              {
                  "source_node_id": source_node_id,
                  "target_node_id": target_node_id,
              },
              {
                  "source_node_id": target_node_id,
                  "target_node_id": source_node_id,
              },
          ]
      }
      await self.edge_collection.update_one(
          either_direction,
          {"$set": fields},
          upsert=True,
      )
  async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
      """Batch insert/update multiple nodes using a single bulk_write() call.

      Each node is written like ``upsert_node`` does: all properties via
      ``$set``, plus a "source_ids" array when "source_id" is non-empty.
      An empty input is a no-op.

      Args:
          nodes: List of (node_id, node_data) tuples.
      """
      if not nodes:
          return

      def as_operation(node_id, node_data):
          # One upsert per node, keyed by _id
          fields = {**node_data}
          if node_data.get("source_id", ""):
              fields["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP)
          return UpdateOne({"_id": node_id}, {"$set": fields}, upsert=True)

      ops = [as_operation(node_id, node_data) for node_id, node_data in nodes]
      await self.collection.bulk_write(ops, ordered=True)
  async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
      """Check existence of multiple nodes using a single $in query.

      Args:
          node_ids: List of node IDs to check.

      Returns:
          Set of node_ids that exist in the graph; empty when node_ids is
          empty (no query is issued in that case).
      """
      if not node_ids:
          return set()

      existing = set()
      # Only _id is projected; the documents' contents are not needed
      async for doc in self.collection.find({"_id": {"$in": node_ids}}, {"_id": 1}):
          existing.add(doc["_id"])
      return existing
  async def upsert_edges_batch(
      self, edges: list[tuple[str, str, dict[str, str]]]
  ) -> None:
      """Batch insert/update multiple edges using a single bulk_write() call.

      Before the edges are written, a separate unordered bulk_write on the
      node collection upserts an ``_id``-only placeholder for every distinct
      source node, mirroring how upsert_edge() ensures the source node exists.
      Each edge matches an existing edge between the same two nodes in either
      direction. An empty input is a no-op.

      Args:
          edges: List of (source_node_id, target_node_id, edge_data) tuples.
      """
      if not edges:
          return

      # Distinct source ids in first-seen order
      distinct_sources = dict.fromkeys(src for src, _tgt, _data in edges)
      placeholder_ops = [
          UpdateOne({"_id": src}, {"$setOnInsert": {"_id": src}}, upsert=True)
          for src in distinct_sources
      ]
      await self.collection.bulk_write(placeholder_ops, ordered=False)

      def as_operation(source_node_id, target_node_id, edge_data):
          # Copy so the caller's edge_data is left untouched
          fields = {**edge_data}
          if edge_data.get("source_id", ""):
              fields["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
          fields["source_node_id"] = source_node_id
          fields["target_node_id"] = target_node_id
          either_direction = {
              "$or": [
                  {
                      "source_node_id": source_node_id,
                      "target_node_id": target_node_id,
                  },
                  {
                      "source_node_id": target_node_id,
                      "target_node_id": source_node_id,
                  },
              ]
          }
          return UpdateOne(either_direction, {"$set": fields}, upsert=True)

      edge_ops = [
          as_operation(source_node_id, target_node_id, edge_data)
          for source_node_id, target_node_id, edge_data in edges
      ]
      await self.edge_collection.bulk_write(edge_ops, ordered=True)
  1092. #
  1093. # -------------------------------------------------------------------------
  1094. # DELETION
  1095. # -------------------------------------------------------------------------
  1096. #
  async def delete_node(self, node_id: str) -> None:
      """
      Delete a node together with every edge that references it.

      Edges (inbound and outbound) are removed first, then the node document.
      """
      touches_node = {
          "$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]
      }
      await self.edge_collection.delete_many(touches_node)
      await self.collection.delete_one({"_id": node_id})
  1108. #
  1109. # -------------------------------------------------------------------------
  1110. # QUERY
  1111. # -------------------------------------------------------------------------
  1112. #
  async def get_all_labels(self) -> list[str]:
      """
      Get all existing node _ids (entity names) in the database

      Returns:
          [id1, id2, ...] # id list sorted ascending by _id
      """
      # Aggregation with allowDiskUse so the sort can handle large datasets
      ids_sorted = [{"$project": {"_id": 1}}, {"$sort": {"_id": 1}}]
      cursor = await self.collection.aggregate(ids_sorted, allowDiskUse=True)
      return [doc["_id"] async for doc in cursor]
  def _construct_graph_node(
      self, node_id, node_data: dict[str, str]
  ) -> KnowledgeGraphNode:
      """Wrap a node document in a KnowledgeGraphNode.

      The node id doubles as its only label. Storage bookkeeping fields
      ("_id", "connected_edges", "source_ids", "edge_count") are left out
      of the node's properties.
      """
      hidden = ("_id", "connected_edges", "source_ids", "edge_count")
      properties = {}
      for key, value in node_data.items():
          if key not in hidden:
              properties[key] = value

      return KnowledgeGraphNode(
          id=node_id,
          labels=[node_id],
          properties=properties,
      )
  def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
      """Wrap an edge document in a KnowledgeGraphEdge.

      The edge type comes from "relationship" (empty string when absent).
      Endpoint ids, "_id", "relationship" and "source_ids" are left out of
      the edge's properties.
      """
      hidden = (
          "_id",
          "source_node_id",
          "target_node_id",
          "relationship",
          "source_ids",
      )
      properties = {}
      for key, value in edge.items():
          if key not in hidden:
              properties[key] = value

      return KnowledgeGraphEdge(
          id=edge_id,
          type=edge.get("relationship", ""),
          source=edge["source_node_id"],
          target=edge["target_node_id"],
          properties=properties,
      )
  1163. async def _fetch_nodes_by_ids(
  1164. self, node_ids: list[str], projection: dict[str, int] | None = None
  1165. ) -> list[dict[str, Any]]:
  1166. """Fetch nodes by ID while preserving the requested order."""
  1167. if not node_ids:
  1168. return []
  1169. cursor = self.collection.find({"_id": {"$in": node_ids}}, projection)
  1170. docs_by_id = {}
  1171. async for doc in cursor:
  1172. docs_by_id[str(doc["_id"])] = doc
  1173. return [docs_by_id[node_id] for node_id in node_ids if node_id in docs_by_id]
  async def get_knowledge_graph_all_by_degree(
      self, max_depth: int, max_nodes: int
  ) -> KnowledgeGraph:
      """
      Build a graph over the whole store, keeping at most ``max_nodes`` nodes.

      When the node collection holds more than ``max_nodes`` documents the
      result is flagged as truncated: node ids are ranked by descending
      degree (computed from the edge collection), topped up with other nodes
      if the ranking yields fewer than ``max_nodes`` ids, and only edges with
      both endpoints among those ids are returned. Otherwise all nodes and
      all edges are returned.

      ``max_depth`` is accepted but not used by this method.

      It's possible that a node with one or multiple relationships is
      retrieved while its neighbor is not. Such a node might then look
      disconnected in the UI.
      """
      node_total = await self.collection.count_documents({})
      graph = KnowledgeGraph()
      truncated = node_total > max_nodes
      graph.is_truncated = truncated

      if not truncated:
          # Everything fits: all nodes and all edges are needed
          async for doc in self.collection.find({}, {"source_ids": 0}):
              graph.nodes.append(self._construct_graph_node(doc["_id"], doc))
          edge_cursor = self.edge_collection.find({})
      else:
          # Rank node ids by degree, since the total node count exceeds max_nodes.
          # Source-side and target-side counts are unioned, then summed per id.
          degree_ranking = [
              {"$project": {"source_node_id": 1, "_id": 0}},
              {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
              {
                  "$unionWith": {
                      "coll": self._edge_collection_name,
                      "pipeline": [
                          {"$project": {"target_node_id": 1, "_id": 0}},
                          {
                              "$group": {
                                  "_id": "$target_node_id",
                                  "degree": {"$sum": 1},
                              }
                          },
                      ],
                  }
              },
              {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
              {"$sort": {"degree": -1}},
              {"$limit": max_nodes},
          ]
          ranked = await self.edge_collection.aggregate(
              degree_ranking, allowDiskUse=True
          )
          node_ids = [str(doc["_id"]) async for doc in ranked]

          # Fewer ranked ids than allowed: fill up with nodes not ranked yet
          shortfall = max_nodes - len(node_ids)
          if shortfall > 0:
              filler = self.collection.find(
                  {"_id": {"$nin": node_ids}},
                  {"source_ids": 0},
              ).limit(shortfall)
              async for doc in filler:
                  node_ids.append(str(doc["_id"]))

          for doc in await self._fetch_nodes_by_ids(node_ids, {"source_ids": 0}):
              graph.nodes.append(self._construct_graph_node(doc["_id"], doc))

          # As node count reaches the limit, only edges that directly connect
          # the selected nodes are fetched
          edge_cursor = self.edge_collection.find(
              {
                  "$and": [
                      {"source_node_id": {"$in": node_ids}},
                      {"target_node_id": {"$in": node_ids}},
                  ]
              }
          )

      # Keep one edge per "<source>-<target>" id
      emitted = set()
      async for edge in edge_cursor:
          edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
          if edge_id in emitted:
              continue
          emitted.add(edge_id)
          graph.edges.append(self._construct_graph_edge(edge_id, edge))

      return graph
  async def _bidirectional_bfs_nodes(
      self,
      node_labels: list[str],
      seen_nodes: set[str],
      result: KnowledgeGraph,
      depth: int,
      max_depth: int,
      max_nodes: int,
  ) -> KnowledgeGraph:
      """Collect nodes breadth-first, following edges in both directions.

      One call handles one BFS level: it adds the not-yet-seen nodes among
      ``node_labels`` to ``result``, gathers their unseen neighbours from the
      edge collection, and recurses on those with ``depth + 1``.

      Args:
          node_labels: Node ids forming the current frontier.
          seen_nodes: Ids already added to ``result``; updated in place.
          result: Graph being built; its ``nodes`` list is appended to.
          depth: Depth of the current frontier.
          max_depth: Frontiers deeper than this are not expanded.
          max_nodes: Upper bound on ``len(result.nodes)``; traversal stops as
              soon as the bound is reached.

      Returns:
          The same ``result`` object. Edges are not added here, and
          ``is_truncated`` is not updated here.
      """
      # ">=" keeps the result at max_nodes at most; ">" would admit one extra node
      if depth > max_depth or len(result.nodes) >= max_nodes:
          return result

      cursor = self.collection.find({"_id": {"$in": node_labels}})
      async for node in cursor:
          node_id = node["_id"]
          if node_id in seen_nodes:
              continue
          seen_nodes.add(node_id)
          result.nodes.append(self._construct_graph_node(node_id, node))
          if len(result.nodes) >= max_nodes:
              return result

      # Collect neighbors: both inbound and outbound one-hop nodes
      cursor = self.edge_collection.find(
          {
              "$or": [
                  {"source_node_id": {"$in": node_labels}},
                  {"target_node_id": {"$in": node_labels}},
              ]
          }
      )
      # A dict used as an ordered set, so the next $in list carries each
      # neighbour id only once
      neighbor_nodes = {}
      async for edge in cursor:
          for endpoint in (edge["source_node_id"], edge["target_node_id"]):
              if endpoint not in seen_nodes:
                  neighbor_nodes[endpoint] = None

      if neighbor_nodes:
          result = await self._bidirectional_bfs_nodes(
              list(neighbor_nodes), seen_nodes, result, depth + 1, max_depth, max_nodes
          )

      return result
  async def get_knowledge_subgraph_bidirectional_bfs(
      self,
      node_label: str,
      depth: int,
      max_depth: int,
      max_nodes: int,
  ) -> KnowledgeGraph:
      """Build the subgraph around ``node_label`` with a bidirectional BFS.

      Nodes are gathered by ``_bidirectional_bfs_nodes`` starting from
      ``node_label`` at the given ``depth``. Afterwards every edge whose two
      endpoints are both among the gathered nodes is added, one edge per
      "<source>-<target>" id.
      """
      visited = set()
      graph = await self._bidirectional_bfs_nodes(
          [node_label], visited, KnowledgeGraph(), depth, max_depth, max_nodes
      )

      # Only edges that stay inside the collected node set
      members = list(visited)
      both_inside = {
          "$and": [
              {"source_node_id": {"$in": members}},
              {"target_node_id": {"$in": members}},
          ]
      }

      emitted = set()
      async for edge in self.edge_collection.find(both_inside):
          edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
          if edge_id in emitted:
              continue
          graph.edges.append(self._construct_graph_edge(edge_id, edge))
          emitted.add(edge_id)

      return graph
  async def get_knowledge_subgraph_in_out_bound_bfs(
      self, node_label: str, max_depth: int, max_nodes: int
  ) -> KnowledgeGraph:
      """Build the subgraph around ``node_label`` from outbound and inbound traversals.

      Two ``$graphLookup`` traversals over the edge collection are started
      at the node: one follows edges forward (the ``target_node_id`` of an
      edge is matched against the ``source_node_id`` of the next one) and
      one follows them backward. The edges of both traversals are merged,
      ordered by depth (ascending) then weight (descending), and their
      endpoints are admitted in that order until the node budget is used.

      Args:
          node_label: ``_id`` of the starting node.
          max_depth: Maximum number of hops from the starting node. Values
              of 0 or less return only the starting node.
          max_nodes: Maximum number of nodes in the result. The starting
              node counts towards this limit and is always returned.

      Returns:
          KnowledgeGraph holding the selected nodes and every traversed
          edge whose two endpoints were both selected. ``is_truncated`` is
          set when at least one reachable node was dropped because of
          ``max_nodes``. An empty graph is returned when the starting node
          does not exist.
      """
      result = KnowledgeGraph()

      # Verify that the starting node exists before running the aggregation.
      start_node = await self.collection.find_one({"_id": node_label})
      if not start_node:
          logger.warning(
              f"[{self.workspace}] Starting node with label {node_label} does not exist!"
          )
          return result

      seen_nodes = {node_label}
      seen_edges = set()
      result.nodes.append(self._construct_graph_node(node_label, start_node))

      # A negative depth would be sent to $graphLookup as a negative
      # maxDepth, so treat it like 0: only the starting node is returned.
      if max_depth <= 0:
          return result

      # In $graphLookup, maxDepth = 0 already means one hop.
      lookup_depth = max_depth - 1

      # Fields of the start document that are excluded from the pipeline output.
      project_doc = {
          "source_ids": 0,
          "created_at": 0,
          "entity_type": 0,
          "file_path": 0,
      }

      def traversal(connect_from: str, connect_to: str) -> list[dict]:
          """Return the match/project/$graphLookup stages for one direction."""
          return [
              {"$match": {"_id": node_label}},
              {"$project": project_doc},
              {
                  "$graphLookup": {
                      "from": self._edge_collection_name,
                      "startWith": "$_id",
                      "connectFromField": connect_from,
                      "connectToField": connect_to,
                      "maxDepth": lookup_depth,
                      "depthField": "depth",
                      "as": "connected_edges",
                  }
              },
          ]

      pipeline = [
          # Outbound traversal.
          *traversal("target_node_id", "source_node_id"),
          # Inbound traversal, appended as a second result document.
          {
              "$unionWith": {
                  "coll": self._collection_name,
                  "pipeline": traversal("source_node_id", "target_node_id"),
              }
          },
      ]

      cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)

      # Two records for node_label are returned, one per direction.
      node_edges = []
      async for doc in cursor:
          node_edges.extend(doc.get("connected_edges") or [])

      # Closest edges first; within a depth, heaviest first. An edge without
      # a weight is ranked as weight 0 instead of aborting the whole query.
      node_edges.sort(key=lambda edge: (edge["depth"], -(edge.get("weight") or 0)))

      # The starting node already occupies one slot of the max_nodes budget.
      neighbour_budget = max(max_nodes - 1, 0)

      # Order matters, so admitted ids are kept in a list next to the set.
      node_ids = []
      for edge in node_edges:
          for endpoint in (edge["source_node_id"], edge["target_node_id"]):
              if endpoint in seen_nodes:
                  continue
              if len(node_ids) >= neighbour_budget:
                  # A reachable node had to be left out of the result.
                  result.is_truncated = True
                  continue
              node_ids.append(endpoint)
              seen_nodes.add(endpoint)

      # node_ids never contains node_label: it was put into seen_nodes up
      # front and its node was already added to the result.
      if node_ids:
          cursor = self.collection.find({"_id": {"$in": node_ids}})
          async for doc in cursor:
              result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))

      # Keep only edges whose two endpoints were both admitted.
      for edge in node_edges:
          if (
              edge["source_node_id"] not in seen_nodes
              or edge["target_node_id"] not in seen_nodes
          ):
              continue
          edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
          if edge_id not in seen_edges:
              result.edges.append(self._construct_graph_edge(edge_id, edge))
              seen_edges.add(edge_id)

      return result
  async def get_knowledge_graph(
      self,
      node_label: str,
      max_depth: int = 3,
      max_nodes: int = None,
  ) -> KnowledgeGraph:
      """
      Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.

      Args:
          node_label: Label of the starting node, * means all nodes
          max_depth: Maximum depth of the subgraph, Defaults to 3
          max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes

      Returns:
          KnowledgeGraph object containing nodes and edges, with an is_truncated flag
          indicating whether the graph was truncated due to max_nodes limit

      If a graph is like this and starting from B:
          A → B ← C ← F, B → E, C → D

      Outbound BFS:
          B → E

      Inbound BFS:
          A → B
          C → B
          F → C

      Bidirectional BFS:
          A → B
          B → E
          F → C
          C → B
          C → D
      """
      # The configured max_graph_nodes is both the default and the upper bound.
      node_cap = self.global_config.get("max_graph_nodes", 1000)
      max_nodes = node_cap if max_nodes is None else min(max_nodes, node_cap)

      result = KnowledgeGraph()
      start = time.perf_counter()

      try:
          # Pick the retrieval strategy, then await it in one place.
          if node_label == "*":
              subgraph_query = self.get_knowledge_graph_all_by_degree(
                  max_depth, max_nodes
              )
          elif GRAPH_BFS_MODE == "in_out_bound":
              subgraph_query = self.get_knowledge_subgraph_in_out_bound_bfs(
                  node_label, max_depth, max_nodes
              )
          else:
              subgraph_query = self.get_knowledge_subgraph_bidirectional_bfs(
                  node_label, 0, max_depth, max_nodes
              )
          result = await subgraph_query

          duration = time.perf_counter() - start
          logger.info(
              f"[{self.workspace}] Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
          )
      except PyMongoError as e:
          error_text = str(e).lower()
          if "memory limit" not in error_text and "sort exceeded" not in error_text:
              # Any other database failure: log it and return the empty graph.
              logger.error(f"[{self.workspace}] MongoDB query failed: {str(e)}")
              return result

          # Memory-limit failures fall back to a plain, edge-less node listing.
          logger.warning(
              f"[{self.workspace}] MongoDB memory limit exceeded, falling back to simple query: {str(e)}"
          )
          try:
              simple_cursor = self.collection.find({}).limit(max_nodes)
              async for doc in simple_cursor:
                  result.nodes.append(
                      self._construct_graph_node(str(doc["_id"]), doc)
                  )
              result.is_truncated = True
              logger.info(
                  f"[{self.workspace}] Fallback query completed | Node count: {len(result.nodes)}"
              )
          except PyMongoError as fallback_error:
              logger.error(
                  f"[{self.workspace}] Fallback query also failed: {str(fallback_error)}"
              )

      return result
  async def index_done_callback(self) -> None:
      """Do nothing: Mongo handles persistence automatically."""
      return None
  async def remove_nodes(self, nodes: list[str]) -> None:
      """Delete several nodes together with every edge that touches them.

      Args:
          nodes: List of node IDs to be deleted
      """
      logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes")
      if not nodes:
          return

      # Step 1: drop edges that reference any of the nodes at either end.
      touches_removed_node = {
          "$or": [
              {endpoint: {"$in": nodes}}
              for endpoint in ("source_node_id", "target_node_id")
          ]
      }
      await self.edge_collection.delete_many(touches_removed_node)

      # Step 2: drop the node documents themselves.
      await self.collection.delete_many({"_id": {"$in": nodes}})

      logger.debug(f"[{self.workspace}] Successfully deleted nodes: {nodes}")
  async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
      """Delete several edges, matching each pair in both directions.

      Args:
          edges: List of edges to be deleted, each edge is a (source, target) tuple
      """
      logger.info(f"[{self.workspace}] Deleting {len(edges)} edges")
      if not edges:
          return

      # Each (source, target) pair yields two filters: as given, then reversed.
      all_edge_pairs = [
          {"source_node_id": first, "target_node_id": second}
          for source_id, target_id in edges
          for first, second in ((source_id, target_id), (target_id, source_id))
      ]
      await self.edge_collection.delete_many({"$or": all_edge_pairs})

      logger.debug(f"[{self.workspace}] Successfully deleted edges: {edges}")
  async def get_all_nodes(self) -> list[dict]:
      """Return every node document of the graph.

      Returns:
          A list with one dictionary per node holding the stored properties
          plus an ``id`` key that mirrors ``_id`` for easier access.
      """
      return [
          {**node, "id": node.get("_id")}
          async for node in self.collection.find({})
      ]
  async def get_all_edges(self) -> list[dict]:
      """Return every edge document of the graph.

      Returns:
          A list with one dictionary per edge holding the stored properties
          plus ``source`` and ``target`` keys that mirror ``source_node_id``
          and ``target_node_id``.
      """
      return [
          {
              **edge,
              "source": edge.get("source_node_id"),
              "target": edge.get("target_node_id"),
          }
          async for edge in self.edge_collection.find({})
      ]
  async def get_popular_labels(self, limit: int = 300) -> list[str]:
      """Get popular labels (entity names) ranked by node degree.

      The degree of a node is the number of edges where it is the source
      plus the number of edges where it is the target.

      Args:
          limit: Maximum number of labels to return

      Returns:
          List of labels (entity names) sorted by degree (highest first),
          ties broken by label ascending. An empty list is returned when
          the aggregation fails.
      """
      # Per-node count of edges leaving the node.
      count_outbound = {"$group": {"_id": "$source_node_id", "out_degree": {"$sum": 1}}}
      # Per-node count of edges arriving at the node, unioned into the same stream.
      count_inbound = {
          "$unionWith": {
              "coll": self._edge_collection_name,
              "pipeline": [
                  {
                      "$group": {
                          "_id": "$target_node_id",
                          "in_degree": {"$sum": 1},
                      }
                  }
              ],
          }
      }
      # Collapse the (up to two) partial records per node into one total.
      sum_degrees = {
          "$group": {
              "_id": "$_id",
              "total_degree": {
                  "$sum": {
                      "$add": [
                          {"$ifNull": ["$out_degree", 0]},
                          {"$ifNull": ["$in_degree", 0]},
                      ]
                  }
              },
          }
      }
      pipeline = [
          count_outbound,
          count_inbound,
          sum_degrees,
          # Highest degree first, then label ascending.
          {"$sort": {"total_degree": -1, "_id": 1}},
          {"$limit": limit},
          # Only the label itself is needed.
          {"$project": {"_id": 1}},
      ]

      try:
          cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
          labels = [doc["_id"] async for doc in cursor if doc.get("_id")]
          logger.debug(
              f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
          )
          return labels
      except Exception as e:
          logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
          return []
  1618. async def _try_atlas_text_search(self, query_strip: str, limit: int) -> list[str]:
  1619. """Try Atlas Search using simple text search."""
  1620. try:
  1621. pipeline = [
  1622. {
  1623. "$search": {
  1624. "index": "entity_id_search_idx",
  1625. "text": {"query": query_strip, "path": "_id"},
  1626. }
  1627. },
  1628. {"$project": {"_id": 1, "score": {"$meta": "searchScore"}}},
  1629. {"$limit": limit},
  1630. ]
  1631. cursor = await self.collection.aggregate(pipeline)
  1632. labels = [doc["_id"] async for doc in cursor if doc.get("_id")]
  1633. if labels:
  1634. logger.debug(
  1635. f"[{self.workspace}] Atlas text search returned {len(labels)} results"
  1636. )
  1637. return labels
  1638. return []
  1639. except PyMongoError as e:
  1640. logger.debug(f"[{self.workspace}] Atlas text search failed: {e}")
  1641. return []
  1642. async def _try_atlas_autocomplete_search(
  1643. self, query_strip: str, limit: int
  1644. ) -> list[str]:
  1645. """Try Atlas Search using autocomplete for prefix matching."""
  1646. try:
  1647. pipeline = [
  1648. {
  1649. "$search": {
  1650. "index": "entity_id_search_idx",
  1651. "autocomplete": {
  1652. "query": query_strip,
  1653. "path": "_id",
  1654. "fuzzy": {"maxEdits": 1, "prefixLength": 1},
  1655. },
  1656. }
  1657. },
  1658. {"$project": {"_id": 1, "score": {"$meta": "searchScore"}}},
  1659. {"$limit": limit},
  1660. ]
  1661. cursor = await self.collection.aggregate(pipeline)
  1662. labels = [doc["_id"] async for doc in cursor if doc.get("_id")]
  1663. if labels:
  1664. logger.debug(
  1665. f"[{self.workspace}] Atlas autocomplete search returned {len(labels)} results"
  1666. )
  1667. return labels
  1668. return []
  1669. except PyMongoError as e:
  1670. logger.debug(f"[{self.workspace}] Atlas autocomplete search failed: {e}")
  1671. return []
  1672. async def _try_atlas_compound_search(
  1673. self, query_strip: str, limit: int
  1674. ) -> list[str]:
  1675. """Try Atlas Search using compound query for comprehensive matching."""
  1676. try:
  1677. pipeline = [
  1678. {
  1679. "$search": {
  1680. "index": "entity_id_search_idx",
  1681. "compound": {
  1682. "should": [
  1683. {
  1684. "text": {
  1685. "query": query_strip,
  1686. "path": "_id",
  1687. "score": {"boost": {"value": 10}},
  1688. }
  1689. },
  1690. {
  1691. "autocomplete": {
  1692. "query": query_strip,
  1693. "path": "_id",
  1694. "score": {"boost": {"value": 5}},
  1695. "fuzzy": {"maxEdits": 1, "prefixLength": 1},
  1696. }
  1697. },
  1698. {
  1699. "wildcard": {
  1700. "query": f"*{query_strip}*",
  1701. "path": "_id",
  1702. "score": {"boost": {"value": 2}},
  1703. }
  1704. },
  1705. ],
  1706. "minimumShouldMatch": 1,
  1707. },
  1708. }
  1709. },
  1710. {"$project": {"_id": 1, "score": {"$meta": "searchScore"}}},
  1711. {"$sort": {"score": {"$meta": "searchScore"}}},
  1712. {"$limit": limit},
  1713. ]
  1714. cursor = await self.collection.aggregate(pipeline)
  1715. labels = [doc["_id"] async for doc in cursor if doc.get("_id")]
  1716. if labels:
  1717. logger.debug(
  1718. f"[{self.workspace}] Atlas compound search returned {len(labels)} results"
  1719. )
  1720. return labels
  1721. return []
  1722. except PyMongoError as e:
  1723. logger.debug(f"[{self.workspace}] Atlas compound search failed: {e}")
  1724. return []
  async def _fallback_regex_search(self, query_strip: str, limit: int) -> list[str]:
      """Search labels with a case-insensitive regex when Atlas Search is not usable.

      Up to ``limit * 2`` matches are fetched, ranked as exact match, then
      starts-with, then contains (alphabetical within each rank), and cut
      to ``limit``. Any failure is logged with its traceback and yields an
      empty list.
      """
      try:
          logger.debug(
              f"[{self.workspace}] Using regex fallback search for: '{query_strip}'"
          )

          # The query is escaped so it is matched literally, not as a pattern.
          regex_condition = {
              "_id": {"$regex": re.escape(query_strip), "$options": "i"}
          }
          fetch_size = limit * 2
          cursor = self.collection.find(regex_condition, {"_id": 1}).limit(fetch_size)
          docs = await cursor.to_list(length=fetch_size)

          candidates = [doc.get("_id") for doc in docs]
          needle = query_strip.lower()

          def rank(label):
              """Order by match quality first, then by lowercase label."""
              lowered = label.lower()
              if lowered == needle:
                  quality = 0  # Exact match - highest priority
              elif lowered.startswith(needle):
                  quality = 1  # Starts with - medium priority
              else:
                  quality = 2  # Contains - lowest priority
              return (quality, lowered)

          # Rank first, then apply the final limit.
          labels = sorted((label for label in candidates if label), key=rank)[:limit]

          logger.debug(
              f"[{self.workspace}] Regex fallback search returned {len(labels)} results (limit: {limit})"
          )
          return labels
      except Exception as e:
          logger.error(f"[{self.workspace}] Regex fallback search failed: {e}")
          import traceback

          logger.error(f"[{self.workspace}] Traceback: {traceback.format_exc()}")
          return []
  async def search_labels(self, query: str, limit: int = 50) -> list[str]:
      """
      Search labels (entity names), trying each strategy until one returns results:

      1. Atlas text search (simple and fast)
      2. Atlas autocomplete search (prefix matching with fuzzy)
      3. Atlas compound search (comprehensive matching)
      4. Regex fallback (when Atlas Search is unavailable)

      A blank query, an empty node collection, or a failure while counting
      nodes yields an empty list.
      """
      query_strip = query.strip()
      if not query_strip:
          return []

      # Nothing to search when the node collection is empty.
      try:
          node_count = await self.collection.count_documents({})
      except PyMongoError as e:
          logger.error(f"[{self.workspace}] Error counting nodes: {e}")
          return []
      if node_count == 0:
          logger.debug(
              f"[{self.workspace}] No nodes found in collection {self._collection_name}"
          )
          return []

      # Atlas Search strategies, in the order they are attempted.
      search_methods = (
          ("text", self._try_atlas_text_search),
          ("autocomplete", self._try_atlas_autocomplete_search),
          ("compound", self._try_atlas_compound_search),
      )
      for method_name, search_method in search_methods:
          try:
              labels = await search_method(query_strip, limit)
          except Exception as e:
              logger.debug(
                  f"[{self.workspace}] {method_name} search failed: {e}, trying next method"
              )
              continue
          if labels:
              logger.debug(
                  f"[{self.workspace}] Search successful using {method_name} method: {len(labels)} results"
              )
              return labels
          logger.debug(
              f"[{self.workspace}] {method_name} search returned no results, trying next method"
          )

      # None of the Atlas strategies produced results: use the regex search.
      logger.info(
          f"[{self.workspace}] All Atlas Search methods failed, using regex fallback search for: '{query_strip}'"
      )
      return await self._fallback_regex_search(query_strip, limit)
  async def _check_if_index_needs_rebuild(
      self, indexes: list, index_name: str
  ) -> bool:
      """Decide whether the search index named ``index_name`` must be (re)built.

      Returns True when the index is absent from ``indexes`` or when its
      ``_id`` field mapping is not a list (multi-type) definition; returns
      False when the mapping is already a list.
      """
      existing = next((idx for idx in indexes if idx["name"] == index_name), None)
      if existing is None:
          # Index doesn't exist, needs creation
          return True

      id_field = (
          existing.get("latestDefinition", {})
          .get("mappings", {})
          .get("fields", {})
          .get("_id", {})
      )

      # Old single-type autocomplete mapping: rebuild.
      if isinstance(id_field, dict) and id_field.get("type") == "autocomplete":
          logger.info(
              f"[{self.workspace}] Found old index configuration for '{index_name}', will rebuild"
          )
          return True

      # A multi-type mapping is a list; anything that is not a list is rebuilt.
      if isinstance(id_field, list):
          logger.info(
              f"[{self.workspace}] Index '{index_name}' has correct configuration"
          )
          return False

      logger.info(
          f"[{self.workspace}] Index '{index_name}' needs upgrade to multi-type configuration"
      )
      return True
  async def _safely_drop_old_index(self, index_name: str):
      """Drop the search index ``index_name``; a ``PyMongoError`` is logged, not raised."""
      try:
          await self.collection.drop_search_index(index_name)
      except PyMongoError as e:
          logger.warning(
              f"[{self.workspace}] Could not drop old index '{index_name}': {e}"
          )
      else:
          logger.info(
              f"[{self.workspace}] Successfully dropped old search index '{index_name}'"
          )
  async def _create_improved_search_index(self, index_name: str):
      """Create the search index ``index_name`` with a multi-type mapping for ``_id``.

      ``_id`` is mapped as string, token and autocomplete at the same time.
      Only the creation request is awaited here; the log message notes
      that the index itself is built asynchronously.
      """
      id_field_types = [
          {"type": "string"},
          {"type": "token"},
          {"type": "autocomplete", "maxGrams": 15, "minGrams": 2},
      ]
      index_definition = {
          "mappings": {
              "dynamic": False,
              "fields": {"_id": id_field_types},
          },
          # Index-level analyzer for text processing
          "analyzer": "lucene.standard",
      }
      search_index_model = SearchIndexModel(
          definition=index_definition,
          name=index_name,
          type="search",
      )

      await self.collection.create_search_index(search_index_model)

      logger.info(
          f"[{self.workspace}] Created improved Atlas Search index '{index_name}' for collection {self._collection_name}. "
      )
      logger.info(
          f"[{self.workspace}] Index will be built asynchronously, using regex fallback until ready."
      )
  async def create_search_index_if_not_exists(self):
      """Create the Atlas Search index for entity search, rebuilding an outdated one.

      Failures are logged and never raised: a ``PyMongoError`` is treated
      as expected (no Atlas Search support), anything else is logged as a
      warning.
      """
      index_name = "entity_id_search_idx"
      try:
          # Listing search indexes only works where search indexes are supported.
          indexes_cursor = await self.collection.list_search_indexes()
          indexes = await indexes_cursor.to_list(length=None)

          if not await self._check_if_index_needs_rebuild(indexes, index_name):
              logger.info(
                  f"[{self.workspace}] Atlas Search index '{index_name}' already exists with correct configuration"
              )
              return

          # An existing index is dropped first; creation does not wait for the build.
          if any(idx["name"] == index_name for idx in indexes):
              await self._safely_drop_old_index(index_name)
          await self._create_improved_search_index(index_name)
      except PyMongoError as e:
          # This is expected if not using MongoDB Atlas or if search indexes are not supported
          logger.info(
              f"[{self.workspace}] Could not create Atlas Search index for {self._collection_name}: {e}. "
              "This is normal if not using MongoDB Atlas - search will use regex fallback."
          )
      except Exception as e:
          logger.warning(
              f"[{self.workspace}] Unexpected error creating Atlas Search index for {self._collection_name}: {e}"
          )
  async def drop(self) -> dict[str, str]:
      """Empty the graph by deleting all node documents, then all edge documents.

      Returns:
          dict[str, str]: Status of the operation with keys 'status' and 'message'
      """
      try:
          node_outcome = await self.collection.delete_many({})
          deleted_count = node_outcome.deleted_count
          logger.info(
              f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
          )

          edge_outcome = await self.edge_collection.delete_many({})
          edge_count = edge_outcome.deleted_count
          logger.info(
              f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
          )
      except PyMongoError as e:
          logger.error(
              f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
          )
          return {"status": "error", "message": str(e)}

      return {
          "status": "success",
          "message": f"{deleted_count} documents and {edge_count} edges dropped",
      }
  1946. @dataclass
  1947. class _PendingVectorDoc:
  1948. """Buffered vector upsert waiting for embedding and/or bulk flush."""
  1949. source: dict[str, Any]
  1950. content: str
  1951. vector: list[float] | None = None
  1952. @final
  1953. @dataclass
  1954. class MongoVectorDBStorage(BaseVectorStorage):
  1955. db: AsyncDatabase | None = field(default=None)
  1956. _data: AsyncCollection | None = field(default=None)
  1957. _index_name: str = field(default="", init=False)
  def __init__(
      self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
  ):
      """Initialise the base storage fields, then run ``__post_init__`` explicitly.

      A missing ``workspace`` becomes ``""`` and missing ``meta_fields``
      becomes an empty set.
      """
      base_fields = dict(
          namespace=namespace,
          workspace=workspace or "",
          global_config=global_config,
          embedding_func=embedding_func,
          meta_fields=meta_fields or set(),
      )
      super().__init__(**base_fields)
      self.__post_init__()
  def __post_init__(self):
      """Derive workspace, collection and index names and read the storage settings.

      Raises:
          ValueError: If ``cosine_better_than_threshold`` is missing from
              ``vector_db_storage_cls_kwargs`` in the global config.
      """
      self._validate_embedding_func()

      # MONGODB_WORKSPACE, when set to a non-blank value, takes priority over
      # the workspace passed at construction. This allows administrators to
      # force a specific workspace for all MongoDB storage instances.
      env_workspace = (os.environ.get("MONGODB_WORKSPACE") or "").strip()
      if env_workspace:
          effective_workspace = env_workspace
          logger.info(
              f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
          )
      else:
          effective_workspace = self.workspace
          if effective_workspace:
              logger.debug(
                  f"Using passed workspace parameter: '{effective_workspace}'"
              )

      # final_namespace carries the workspace prefix for data isolation; the
      # original namespace stays unchanged for type detection logic.
      if effective_workspace:
          self.final_namespace = f"{effective_workspace}_{self.namespace}"
          self.workspace = effective_workspace
          logger.debug(
              f"Final namespace with workspace prefix: '{self.final_namespace}'"
          )
          # Collection-specific index name for workspaced collections to avoid conflicts.
          self._index_name = f"vector_knn_index_{self.final_namespace}"
      else:
          self.final_namespace = self.namespace
          self.workspace = ""
          logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
          # Original index name, kept for backward compatibility with existing deployments.
          self._index_name = "vector_knn_index"

      storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
      cosine_threshold = storage_kwargs.get("cosine_better_than_threshold")
      if cosine_threshold is None:
          raise ValueError(
              "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
          )
      self.cosine_better_than_threshold = cosine_threshold
      self._collection_name = self.final_namespace
      self._max_batch_size = self.global_config["embedding_batch_num"]

      # Deferred-embedding buffers and the per-namespace flush lock.
      # The lock is constructed in initialize() once shared-storage
      # primitives are available; it is keyed on final_namespace so two
      # instances pointing at the same MongoDB collection (e.g. with the
      # MONGODB_WORKSPACE env override) share a single writer lock.
      self._pending_vector_docs: dict[str, _PendingVectorDoc] = {}
      self._pending_vector_deletes: set[str] = set()
      self._flush_lock = None
  async def initialize(self):
      """Connect to MongoDB, ensure collection and vector index exist, and set up the flush lock.

      The connection part runs only while ``self.db`` is unset; the flush
      lock is created only once.
      """
      async with get_data_init_lock():
          if self.db is None:
              self.db = await ClientManager.get_client()
              self._data = await get_or_create_collection(
                  self.db, self._collection_name
              )

              # Ensure vector index exists
              await self.create_vector_index_if_not_exists()

              logger.debug(
                  f"[{self.workspace}] Use MongoDB as VDB {self._collection_name}"
              )

          # Keyed on final_namespace with an empty workspace argument.
          if self._flush_lock is None:
              self._flush_lock = get_namespace_lock(
                  namespace=self.final_namespace, workspace=""
              )
  async def finalize(self):
      """Flush pending vector ops, release the Mongo client, surface unflushed data.

      Raises:
          RuntimeError: If the final flush raised (chained to the original
              exception), or if upserts or deletes are still buffered
              after the flush attempt.
      """
      # A flush failure is remembered so the client is still released first.
      flush_error: Exception | None = None
      try:
          await self._flush_pending_vector_ops()
      except Exception as e:
          flush_error = e

      if self.db is not None:
          await ClientManager.release_client(self.db)
          self.db = None
          self._data = None

      pending_docs = len(self._pending_vector_docs)
      pending_deletes = len(self._pending_vector_deletes)

      if flush_error is None and not (pending_docs or pending_deletes):
          return

      if flush_error is not None:
          raise RuntimeError(
              f"[{self.workspace}] MongoVectorDBStorage.finalize() flush raised; "
              f"{pending_docs} pending upserts and {pending_deletes} pending "
              f"deletes were left buffered (client released, data lost)"
          ) from flush_error

      raise RuntimeError(
          f"[{self.workspace}] MongoVectorDBStorage.finalize() left "
          f"{pending_docs} pending upserts and {pending_deletes} pending "
          f"deletes buffered after final flush attempt (these writes have been lost)"
      )
  async def create_vector_index_if_not_exists(self):
      """Create the Atlas Vector Search index unless one with this name already exists.

      Raises:
          ValueError: If the existing index declares a vector dimension
              that differs from the embedding function's dimension.
          SystemExit: If listing or creating the index raises a
              ``PyMongoError``.
      """
      try:
          indexes_cursor = await self._data.list_search_indexes()
          indexes = await indexes_cursor.to_list(length=None)

          for index in indexes:
              if index["name"] != self._index_name:
                  continue

              # Dimension of the first vector field on path "vector", if any.
              index_fields = index.get("latestDefinition", {}).get("fields", [])
              existing_dim = next(
                  (
                      spec.get("numDimensions")
                      for spec in index_fields
                      if spec.get("type") == "vector"
                      and spec.get("path") == "vector"
                  ),
                  None,
              )
              expected_dim = self.embedding_func.embedding_dim

              if existing_dim is not None and existing_dim != expected_dim:
                  error_msg = (
                      f"Vector dimension mismatch! Index '{self._index_name}' has "
                      f"dimension {existing_dim}, but current embedding model expects "
                      f"dimension {expected_dim}. Please drop the existing index or "
                      f"use an embedding model with matching dimensions."
                  )
                  logger.error(f"[{self.workspace}] {error_msg}")
                  raise ValueError(error_msg)

              logger.info(
                  f"[{self.workspace}] vector index {self._index_name} already exists with matching dimensions ({expected_dim})"
              )
              return

          # No index with this name yet: create it.
          vector_field = {
              "type": "vector",
              "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
              "path": "vector",
              "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
          }
          search_index_model = SearchIndexModel(
              definition={"fields": [vector_field]},
              name=self._index_name,
              type="vectorSearch",
          )
          await self._data.create_search_index(search_index_model)
          logger.info(
              f"[{self.workspace}] Vector index {self._index_name} created successfully."
          )
      except PyMongoError as e:
          error_msg = f"[{self.workspace}] Error creating vector index {self._index_name}: {e}"
          logger.error(error_msg)
          raise SystemExit(
              f"Failed to create MongoDB vector index. Program cannot continue. {error_msg}"
          )
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
      """Buffer vector docs for embedding and batched flush.

      Embedding deliberately does NOT happen here: repeated upserts of
      the same id, or many small batches, collapse into a single
      flush-time embedding pass. The buffers are updated while holding
      the flush lock.

      Args:
          data: Mapping of document id to its fields; every entry must
              contain a ``content`` key. Only fields listed in
              ``meta_fields`` are copied into the stored document.
      """
      if not data:
          return

      current_time = int(time.time())
      pending_docs: list[tuple[str, _PendingVectorDoc]] = []
      for position, (doc_key, payload) in enumerate(data.items(), start=1):
          meta = {
              name: value
              for name, value in payload.items()
              if name in self.meta_fields
          }
          source = {"_id": doc_key, "created_at": current_time, **meta}
          pending_docs.append(
              (doc_key, _PendingVectorDoc(source=source, content=payload["content"]))
          )
          await _cooperative_yield(position)

      # Installing a fresh _PendingVectorDoc invalidates any vector
      # cached by a prior get_vectors_by_ids() call on a stale revision.
      async with self._flush_lock:
          for doc_id, pdoc in pending_docs:
              # A new upsert supersedes a buffered delete of the same id.
              self._pending_vector_deletes.discard(doc_id)
              self._pending_vector_docs[doc_id] = pdoc
  async def query(
      self, query: str, top_k: int, query_embedding: list[float] = None
  ) -> list[dict[str, Any]]:
      """Queries the vector database using Atlas Vector Search.

      Reads from the server-side index only; buffered upserts and deletes
      are NOT visible until ``index_done_callback`` / ``finalize`` flushes
      them. Callers that need read-your-writes for a freshly upserted id
      should use ``get_by_id`` / ``get_by_ids`` (which consult the buffer)
      or flush first. Matches the deferred-embedding contract used by
      OpenSearch / FAISS / Nano.

      Args:
          query: Query text; embedded only when ``query_embedding`` is None.
          top_k: Maximum number of results requested from the index.
          query_embedding: Optional precomputed query vector (a list or
              any object exposing ``tolist()``).

      Returns:
          One dict per hit whose score is at least
          ``cosine_better_than_threshold``: the stored document without
          its ``vector`` field, plus ``id``, ``distance`` (the search
          score) and ``created_at``.
      """
      if query_embedding is not None:
          # Convert numpy array to list if needed for MongoDB compatibility
          if hasattr(query_embedding, "tolist"):
              query_vector = query_embedding.tolist()
          else:
              query_vector = list(query_embedding)
      else:
          # Generate the embedding, with higher priority for queries
          embedding = await self.embedding_func(
              [query], context="query", _priority=5
          )
          # Convert numpy array to a list to ensure compatibility with MongoDB
          query_vector = embedding[0].tolist()

      # $vectorSearch rejects numCandidates < limit and numCandidates above
      # 10000. Keep the previous value of 100 as the floor and grow it with
      # top_k so requests for more than 100 results no longer fail.
      min_candidates = 100
      max_candidates = 10000
      num_candidates = min(max(min_candidates, top_k), max_candidates)

      pipeline = [
          {
              "$vectorSearch": {
                  "index": self._index_name,  # Use stored index name for consistency
                  "path": "vector",
                  "queryVector": query_vector,
                  "numCandidates": num_candidates,
                  "limit": top_k,
              }
          },
          {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
          {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
          {"$project": {"vector": 0}},
      ]

      # Execute the aggregation pipeline
      cursor = await self._data.aggregate(pipeline, allowDiskUse=True)
      results = await cursor.to_list(length=None)

      # Format and return the results with created_at field
      return [
          {
              **doc,
              "id": doc["_id"],
              "distance": doc.get("score", None),
              "created_at": doc.get("created_at"),
          }
          for doc in results
      ]
  async def index_done_callback(self) -> None:
      """Write every buffered vector upsert and delete to MongoDB.

      This is a thin wrapper around ``_flush_pending_vector_ops``; any
      exception raised by that flush propagates to the caller.
      """
      flush = self._flush_pending_vector_ops
      await flush()
  async def _flush_pending_vector_ops(self) -> None:
      """Embed what still needs a vector, then push both buffers in one bulk_write.

      The whole routine, embedding included, runs while holding
      ``_flush_lock``. The deferred embedding and the server write therefore
      form one unit with respect to every other method that takes the same
      lock (``upsert``, ``delete``, ``drop`` and friends). If embedding or
      the bulk write raises, the error is logged and re-raised with both
      buffers left in place, so a later flush can try again; vectors that
      were computed before a failed write stay cached on their pending docs.

      Nothing is written when both buffers are empty or when ``self._data``
      is ``None``.

      Concurrency invariant: the lock is taken here with a plain
      ``async with``, so a caller that already holds ``_flush_lock`` must
      not call this method.
      NOTE(review): earlier notes on this method describe the lock as a
      non-reentrant asyncio lock and name ``index_done_callback`` and
      ``finalize`` as the only callers -- confirm against the rest of the
      file.

      Raises:
          RuntimeError: If the embedding function returns a different number
              of vectors than documents submitted.
          Exception: Whatever the embedding function or ``bulk_write``
              raised, after it has been logged.
      """
      async with self._flush_lock:
          if not (self._pending_vector_docs or self._pending_vector_deletes):
              return
          if self._data is None:
              return

          # Work on the live containers (no copies): they are cleared
          # in place at the end so other holders of the same objects see it.
          pending_docs = self._pending_vector_docs
          pending_deletes = self._pending_vector_deletes

          docs_to_embed: list[tuple[str, _PendingVectorDoc]] = []
          for doc_id, pdoc in pending_docs.items():
              if pdoc.vector is None:
                  docs_to_embed.append((doc_id, pdoc))

          if docs_to_embed:
              step = self._max_batch_size
              texts = [pdoc.content for _, pdoc in docs_to_embed]
              batches = [
                  texts[start : start + step] for start in range(0, len(texts), step)
              ]
              logger.info(
                  f"[{self.workspace}] {self.namespace} flush: embedding "
                  f"{len(docs_to_embed)} vectors in {len(batches)} batch(es) "
                  f"(batch_num={self._max_batch_size})"
              )
              try:
                  embeddings_list = await asyncio.gather(
                      *(
                          self.embedding_func(batch, context="document")
                          for batch in batches
                      )
                  )
              except Exception as e:
                  logger.error(
                      f"[{self.workspace}] Error embedding pending vector ops "
                      f"(upserts={len(docs_to_embed)}): {e}"
                  )
                  raise

              embeddings = np.concatenate(embeddings_list)
              if len(embeddings) != len(docs_to_embed):
                  raise RuntimeError(
                      f"[{self.workspace}] Embedding count mismatch: expected "
                      f"{len(docs_to_embed)}, got {len(embeddings)}"
                  )

              # Cache each vector on its pending doc as a plain float32 list.
              for position, ((_, pdoc), row) in enumerate(
                  zip(docs_to_embed, embeddings), start=1
              ):
                  pdoc.vector = np.array(row, dtype=np.float32).tolist()
                  await _cooperative_yield(position)

          # Upserts first, deletes after; the write itself is unordered.
          writable = [
              (doc_id, pdoc)
              for doc_id, pdoc in pending_docs.items()
              if pdoc.vector is not None
          ]
          committed_ids: list[str] = [doc_id for doc_id, _ in writable]
          ops: list[Any] = [
              UpdateOne(
                  {"_id": doc_id},
                  {"$set": {**pdoc.source, "vector": pdoc.vector}},
                  upsert=True,
              )
              for doc_id, pdoc in writable
          ]
          ops.extend(DeleteOne({"_id": doc_id}) for doc_id in pending_deletes)
          if not ops:
              return

          try:
              await self._data.bulk_write(ops, ordered=False)
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error flushing vector ops "
                  f"(upserts={len(pending_docs)}, "
                  f"deletes={len(pending_deletes)}): {e}"
              )
              raise

          # Success: empty the buffers in place so external references
          # (e.g. drop()) observe the cleared state.
          for doc_id in committed_ids:
              pending_docs.pop(doc_id, None)
          pending_deletes.clear()
  async def delete(self, ids: list[str]) -> None:
      """Record vector deletions in the buffer; nothing is sent to MongoDB here.

      For every id, a staged upsert (if any) is dropped and the id is added
      to the pending-delete set, all under ``_flush_lock``. The server-side
      delete happens when the buffers are flushed.

      Args:
          ids: Document ids to delete. A ``set`` is accepted as well as a
              list; an empty or ``None`` value is a no-op.
      """
      if not ids:
          return

      # Sets are snapshotted into a list before iterating.
      targets = list(ids) if isinstance(ids, set) else ids

      async with self._flush_lock:
          staged_docs = self._pending_vector_docs
          tombstones = self._pending_vector_deletes
          for doc_id in targets:
              staged_docs.pop(doc_id, None)
              tombstones.add(doc_id)

      logger.debug(
          f"[{self.workspace}] Buffered delete for {len(ids)} vectors in {self.namespace}"
      )
  async def delete_entity(self, entity_name: str) -> None:
      """Record the deletion of one entity's vector in the buffer.

      The vector id is derived from the entity name with
      ``compute_mdhash_id(entity_name, prefix="ent-")``. Under
      ``_flush_lock`` the id is marked for deletion and any staged upsert
      for it is discarded; the server is only touched at flush time.

      Args:
          entity_name: Name of the entity whose vector should be removed.
      """
      entity_id = compute_mdhash_id(entity_name, prefix="ent-")

      async with self._flush_lock:
          self._pending_vector_deletes.add(entity_id)
          self._pending_vector_docs.pop(entity_id, None)

      logger.debug(
          f"[{self.workspace}] Buffered delete for entity {entity_name} (id={entity_id})"
      )
  async def delete_entity_relation(self, entity_name: str) -> None:
      """Remove every relation vector that has ``entity_name`` as ``src_id`` or ``tgt_id``.

      Unlike ``delete`` this talks to the server immediately: matching rows
      are looked up with ``find`` and removed with ``delete_many``. The
      whole method runs while holding ``_flush_lock``, so the find + delete
      pair cannot interleave with a bulk write issued by a flush. Server
      errors are not caught here; they propagate to the caller.

      Buffer handling: staged upserts that reference the entity are pruned
      from ``_pending_vector_docs`` only *after* the server-side work
      succeeded (or when there was nothing to do on the server). If the
      server call raises, the buffer is left untouched, which lets the
      caller stop before a later flush writes a half-cleaned buffer.

      Args:
          entity_name: Entity whose incoming and outgoing relation vectors
              should be removed.
      """

      def _forget_buffered_relations() -> None:
          # Collect first, pop afterwards: the dict must not change size
          # while it is being iterated.
          doomed = []
          for doc_id, pdoc in self._pending_vector_docs.items():
              source = pdoc.source
              if (
                  source.get("src_id") == entity_name
                  or source.get("tgt_id") == entity_name
              ):
                  doomed.append(doc_id)
          for doc_id in doomed:
              self._pending_vector_docs.pop(doc_id, None)

      async with self._flush_lock:
          if self._data is None:
              # No collection to mutate; pruning the buffer is the only
              # delete intent that can be recorded.
              _forget_buffered_relations()
              return

          # Only _id is needed, so project everything else away.
          relations_cursor = self._data.find(
              {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]},
              {"_id": 1},
          )
          relations = await relations_cursor.to_list(length=None)

          if relations:
              relation_ids = [row["_id"] for row in relations]
              await self._data.delete_many({"_id": {"$in": relation_ids}})
              # The server delete went through, so staged upserts may go
              # too; otherwise the next flush would re-create the rows.
              _forget_buffered_relations()
              logger.debug(
                  f"[{self.workspace}] Deleted {len(relation_ids)} relations for {entity_name}"
              )
          else:
              # Nothing stored server-side, but staged upserts could still
              # create the relation later, so prune them anyway.
              _forget_buffered_relations()
              logger.debug(
                  f"[{self.workspace}] No relations found for entity {entity_name}"
              )
  2364. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  2365. """Get vector data by its ID, with read-your-writes against the buffer.
  2366. Pending buffer hits never include the `vector` field; server-side
  2367. fallback projects it out for parity.
  2368. """
  2369. async with self._flush_lock:
  2370. if id in self._pending_vector_deletes:
  2371. return None
  2372. pending = self._pending_vector_docs.get(id)
  2373. if pending is not None:
  2374. doc = dict(pending.source)
  2375. # Surface both _id (Mongo native) and id (API expectation).
  2376. doc.setdefault("_id", id)
  2377. doc["id"] = id
  2378. return doc
  2379. try:
  2380. result = await self._data.find_one({"_id": id}, {"vector": 0})
  2381. if result:
  2382. result_dict = dict(result)
  2383. if "_id" in result_dict and "id" not in result_dict:
  2384. result_dict["id"] = result_dict["_id"]
  2385. return result_dict
  2386. return None
  2387. except Exception as e:
  2388. logger.error(
  2389. f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
  2390. )
  2391. return None
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
      """Fetch several vector documents by id, honouring buffered writes.

      Ids are first resolved against the buffers under ``_flush_lock``
      (buffered delete -> ``None``, staged upsert -> copy of the staged
      source). Whatever is left is fetched from MongoDB in one ``find``
      with the ``vector`` field projected out.

      Args:
          ids: Document ids to look up.

      Returns:
          A list aligned with ``ids``: one document (with ``_id`` and
          ``id``) per id, or ``None`` for ids that are deleted or unknown.
          An empty list is returned for empty input and also when the
          server lookup fails (the failure is logged, not raised).
      """
      if not ids:
          return []

      buffered: dict[str, dict[str, Any] | None] = {}
      remaining: list[str] = []

      async with self._flush_lock:
          tombstones = self._pending_vector_deletes
          staged_docs = self._pending_vector_docs
          for doc_id in ids:
              if doc_id in tombstones:
                  buffered[doc_id] = None
              elif (staged := staged_docs.get(doc_id)) is not None:
                  snapshot = {**staged.source}
                  snapshot.setdefault("_id", doc_id)
                  snapshot["id"] = doc_id
                  buffered[doc_id] = snapshot
              else:
                  remaining.append(doc_id)

      formatted_map: dict[str, dict[str, Any]] = {}
      if remaining:
          try:
              cursor = self._data.find({"_id": {"$in": remaining}}, {"vector": 0})
              for row in await cursor.to_list(length=None):
                  row_dict = dict(row)
                  if "_id" in row_dict and "id" not in row_dict:
                      row_dict["id"] = row_dict["_id"]
                  # Keyed by the stringified id so the final lookup matches.
                  formatted_map[str(row_dict.get("id", row_dict.get("_id")))] = row_dict
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error retrieving vector data for IDs {remaining}: {e}"
              )
              return []

      # Re-assemble in the caller's order; buffer hits take precedence.
      ordered = []
      for doc_id in ids:
          if doc_id in buffered:
              ordered.append(buffered[doc_id])
          else:
              ordered.append(formatted_map.get(str(doc_id)))
      return ordered
  async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
      """Return the embedding vectors for ``ids``, honouring buffered writes.

      Ids buffered for deletion are skipped. Ids with a staged upsert are
      answered from the buffer; when such a doc has no vector yet it is
      embedded right here, inside ``_flush_lock``, and the vector is cached
      on the staged doc so the next flush does not embed it again. All
      other ids are read from MongoDB.

      Visibility caveat: the server-side ``find`` runs after the lock has
      been released. A ``delete()`` arriving in that window is only
      buffered, so the old vector is still stored and may be returned for
      an id that is already scheduled for deletion. Callers that need
      strict consistency must call ``index_done_callback()`` first.

      Args:
          ids: Document ids whose vectors are wanted.

      Returns:
          Mapping of id to vector for every id that could be resolved. A
          ``PyMongoError`` during the server lookup is logged and the
          partial result collected so far is returned.

      Raises:
          RuntimeError: If lazy embedding yields a different number of
              vectors than documents submitted.
          Exception: Whatever the embedding function raised, after logging.
      """
      if not ids:
          return {}

      result: dict[str, list[float]] = {}
      remaining: list[str] = []

      async with self._flush_lock:
          docs_to_embed: list[tuple[str, _PendingVectorDoc]] = []
          for doc_id in ids:
              if doc_id in self._pending_vector_deletes:
                  continue
              staged = self._pending_vector_docs.get(doc_id)
              if staged is None:
                  remaining.append(doc_id)
              elif staged.vector is None:
                  docs_to_embed.append((doc_id, staged))
              else:
                  result[doc_id] = staged.vector

          if docs_to_embed:
              step = self._max_batch_size
              texts = [pdoc.content for _, pdoc in docs_to_embed]
              batches = [
                  texts[start : start + step] for start in range(0, len(texts), step)
              ]
              try:
                  embeddings_list = await asyncio.gather(
                      *(
                          self.embedding_func(batch, context="document")
                          for batch in batches
                      )
                  )
              except Exception as e:
                  logger.error(
                      f"[{self.workspace}] Error lazily embedding pending vectors "
                      f"(upserts={len(docs_to_embed)}): {e}"
                  )
                  raise

              embeddings = np.concatenate(embeddings_list)
              if len(embeddings) != len(docs_to_embed):
                  raise RuntimeError(
                      f"[{self.workspace}] Embedding count mismatch: expected "
                      f"{len(docs_to_embed)}, got {len(embeddings)}"
                  )

              for position, ((doc_id, pdoc), row) in enumerate(
                  zip(docs_to_embed, embeddings), start=1
              ):
                  # Cache on the staged doc and hand the same list back.
                  pdoc.vector = np.array(row, dtype=np.float32).tolist()
                  result[doc_id] = pdoc.vector
                  await _cooperative_yield(position)

      if not remaining:
          return result

      try:
          cursor = self._data.find(
              {"_id": {"$in": remaining}}, {"_id": 1, "vector": 1}
          )
          for row in await cursor.to_list(length=None):
              if row and "vector" in row and "_id" in row:
                  result[row["_id"]] = row["vector"]
      except PyMongoError as e:
          logger.error(f"[{self.workspace}] Error getting vectors: {e}")
      return result
  async def drop(self) -> dict[str, str]:
      """Delete every document in the collection and recreate the vector index.

      Destructive. Under ``_flush_lock`` both pending buffers are emptied
      first, so that a flush running afterwards cannot write them back,
      and then all documents are removed with ``delete_many({})`` and
      ``create_vector_index_if_not_exists`` is invoked.

      Preconditions carried over from the existing contract of this
      method: it MUST only be called while ``pipeline_status`` is idle (see
      the pipeline concurrency contract in ``AGENTS.md``); the in-tree
      caller ``clear_documents`` is expected to enforce this.

      Caveat: only the buffers of *this* instance are cleared. Other
      ``MongoVectorDBStorage`` instances bound to the same
      ``final_namespace`` (multi-worker processes, or distinct workspaces
      collapsed by ``MONGODB_WORKSPACE``) keep their own buffers. A sibling
      whose earlier flush failed will, on its next flush, bulk-write its
      stale rows into the freshly emptied collection. Callers that bypass
      the idle precondition MUST flush every aliased instance first.

      Returns:
          dict[str, str]: ``{"status": "success"|"error", "message": str}``.
          A ``PyMongoError`` is logged and reported as ``"error"`` instead
          of being raised.
      """
      try:
          async with self._flush_lock:
              # Throw away buffered writes before wiping the collection;
              # a concurrent flush would otherwise resurrect them.
              self._pending_vector_docs.clear()
              self._pending_vector_deletes.clear()

              outcome = await self._data.delete_many({})
              deleted_count = outcome.deleted_count

              await self.create_vector_index_if_not_exists()
      except PyMongoError as e:
          logger.error(
              f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
          )
          return {"status": "error", "message": str(e)}
      else:
          logger.info(
              f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
          )
          return {
              "status": "success",
              "message": f"{deleted_count} documents dropped and vector index recreated",
          }
  async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
      """Return the named collection, creating it first when it does not exist.

      Existence is decided by looking the name up in
      ``db.list_collection_names()``.

      Args:
          db: Database handle used for the lookup and the creation.
          collection_name: Name of the collection to fetch or create.

      Returns:
          The object returned by ``db.create_collection`` for a new
          collection, otherwise the result of ``db.get_collection``.
      """
      known_names = await db.list_collection_names()
      if collection_name in known_names:
          logger.debug(f"Collection '{collection_name}' already exists.")
          return db.get_collection(collection_name)

      collection = await db.create_collection(collection_name)
      logger.info(f"Created collection: {collection_name}")
      return collection