opensearch_impl.py 158 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917
  1. """
  2. OpenSearch Storage Implementation for LightRAG
  3. This module provides OpenSearch-based storage backends for LightRAG,
  4. including KV storage, document status storage, graph storage, and vector storage.
  5. Requirements:
  6. - opensearch-py >= 3.0.0
  7. - OpenSearch 3.x or higher with k-NN plugin enabled
  8. """
  9. import os
  10. import re
  11. import ssl as ssl_module
  12. import time
  13. import asyncio
  14. from dataclasses import dataclass, field
  15. from typing import Any, AsyncIterator, Union, final
  16. import numpy as np
  17. import configparser
  18. from ..base import (
  19. BaseGraphStorage,
  20. BaseKVStorage,
  21. BaseVectorStorage,
  22. DocProcessingStatus,
  23. DocStatus,
  24. DocStatusStorage,
  25. )
  26. from ..utils import logger, compute_mdhash_id, _cooperative_yield
  27. from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  28. from ..constants import GRAPH_FIELD_SEP
  29. from ..kg.shared_storage import get_data_init_lock, get_namespace_lock
  30. import pipmaster as pm
  31. if not pm.is_installed("opensearch-py"):
  32. pm.install("opensearch-py")
  33. from opensearchpy import AsyncOpenSearch, helpers # type: ignore
  34. from opensearchpy.exceptions import OpenSearchException, NotFoundError, RequestError # type: ignore
# Optional file-based settings. ``config.ini`` is resolved relative to the
# current working directory; ConfigParser.read() silently skips a missing
# file, in which case every lookup falls through to its fallback value.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
  37. def _get_opensearch_env(key, fallback):
  38. cfg_key = key.replace("OPENSEARCH_", "").lower()
  39. return os.environ.get(key, config.get("opensearch", cfg_key, fallback=fallback))
  40. def _get_index_number_of_shards() -> int:
  41. return int(_get_opensearch_env("OPENSEARCH_NUMBER_OF_SHARDS", "1"))
  42. def _get_index_number_of_replicas() -> int:
  43. return int(_get_opensearch_env("OPENSEARCH_NUMBER_OF_REPLICAS", "0"))
  44. def _sanitize_index_name(name: str) -> str:
  45. """Sanitize a string to be a valid OpenSearch index name."""
  46. sanitized = re.sub(r"[^a-z0-9_-]", "_", name.lower())
  47. if sanitized and sanitized[0] in "-_+":
  48. sanitized = "x" + sanitized
  49. return sanitized
# HTTP statuses that indicate a transient failure where retrying makes sense:
# request timeout (408), rate limit (429), and the transient server errors
# 500 / 502 / 503 / 504. Other 5xx codes are deliberately absent from the set.
# A missing status (None) typically means a network or parse error before the
# server responded, which is also retriable; None is not a member of this set,
# so code that wants to retry it has to test for it separately.
_RETRYABLE_BULK_STATUSES: frozenset[int] = frozenset({408, 429, 500, 502, 503, 504})
# Cap the length of error summaries dumped to logs so a multi-MB mapping
# explanation can't flood the log file. The cap is in characters and includes
# the trailing "..." marker appended on truncation.
_BULK_ERROR_SUMMARY_MAX_LEN = 200
@dataclass(frozen=True)
class _FailedBulkOp:
    """Structured representation of a non-retryable per-action bulk failure.

    Instances are immutable (``frozen=True``).
    """

    # Bulk action name, i.e. the key of the failure entry ("index", "delete", "create", ...).
    op: str
    # ``_id`` of the document the action targeted.
    doc_id: str
    # HTTP status reported for the action; None when it was not an int.
    status: int | None
    # Short, length-capped description of the failure.
    error: str
@dataclass
class _PendingVectorDoc:
    """Buffered vector upsert waiting for embedding and/or bulk flush."""

    # NOTE(review): presumably the document body that will be indexed -- confirm
    # against the vector storage upsert path (not visible in this part of the file).
    source: dict[str, Any]
    # NOTE(review): presumably the text handed to the embedding function -- confirm.
    content: str
    # Defaults to None; presumably filled in once the embedding is available -- confirm.
    vector: list[float] | None = None
  71. def _summarize_bulk_error(error: Any) -> str:
  72. """Turn an opensearch-py per-action ``error`` payload into a short string.
  73. The field may be a string, dict (``{"type": ..., "reason": ...}``) or
  74. something else entirely. We prefer ``reason`` / ``type`` from dicts to
  75. keep the log readable.
  76. """
  77. if error is None:
  78. return ""
  79. if isinstance(error, str):
  80. summary = error
  81. elif isinstance(error, dict):
  82. reason = error.get("reason") or error.get("type")
  83. summary = reason if isinstance(reason, str) else repr(error)
  84. else:
  85. summary = repr(error)
  86. if len(summary) > _BULK_ERROR_SUMMARY_MAX_LEN:
  87. summary = summary[: _BULK_ERROR_SUMMARY_MAX_LEN - 3] + "..."
  88. return summary
  89. def _extract_bulk_failed_ids(
  90. failed: list[Any] | None,
  91. ) -> tuple[set[str], list[_FailedBulkOp]]:
  92. """Split an opensearch-py bulk ``failed`` list into retryable / dead ops.
  93. ``async_bulk(raise_on_error=False)`` returns ``(success, failed)`` where
  94. ``failed`` is a list of per-action error dicts shaped like::
  95. {"index": {"_id": "...", "status": 500, "error": {...}}}
  96. {"delete": {"_id": "...", "status": 404, ...}}
  97. {"create": {"_id": "...", "status": 409, ...}}
  98. Returns ``(retryable, non_retryable)``:
  99. * ``retryable`` — ``set[str]`` of ids that should be retried on
  100. the next flush (408 / 429 / 5xx, plus a missing status which
  101. usually means a network-level failure before the server responded).
  102. * ``non_retryable`` — ``list[_FailedBulkOp]`` of permanent failures
  103. (most 4xx, mapping errors, etc.) carrying op-name, id, status and
  104. a short ``error`` summary so callers can log meaningful context.
  105. ``404`` on a delete is treated as success-equivalent and dropped
  106. from both sets.
  107. Unrecognised or malformed entries are skipped so a stray dict shape
  108. never crashes the flush path.
  109. """
  110. retryable: set[str] = set()
  111. non_retryable: list[_FailedBulkOp] = []
  112. if not failed:
  113. return retryable, non_retryable
  114. for entry in failed:
  115. if not isinstance(entry, dict):
  116. continue
  117. for op_name, op_payload in entry.items():
  118. if not isinstance(op_payload, dict):
  119. continue
  120. doc_id = op_payload.get("_id")
  121. if not isinstance(doc_id, str):
  122. continue
  123. status = op_payload.get("status")
  124. # Deleting a missing doc is not a real failure -- the row is
  125. # already gone, so we don't carry it forward on every flush.
  126. if op_name == "delete" and status == 404:
  127. continue
  128. if status is None or status in _RETRYABLE_BULK_STATUSES:
  129. retryable.add(doc_id)
  130. else:
  131. non_retryable.append(
  132. _FailedBulkOp(
  133. op=op_name,
  134. doc_id=doc_id,
  135. status=status if isinstance(status, int) else None,
  136. error=_summarize_bulk_error(op_payload.get("error")),
  137. )
  138. )
  139. return retryable, non_retryable
# Detected at first connection; True when OpenSearch >= 3.3.0.
# None means "not detected" (no shared client is currently open):
# ClientManager.get_client() assigns the detected value when it creates the
# client and ClientManager.release_client() resets it to None once the last
# reference is released. Both None and False select the field + _doc fallback
# in the PIT sort helpers.
_shard_doc_supported: bool | None = None
  142. def _pit_sort_with_field(field: str) -> list[dict]:
  143. """Return PIT sort clause with a unique field as primary sort.
  144. Used purely as a pagination tiebreaker — order is fixed to asc since the
  145. business sort (when present) is applied separately by the caller.
  146. >= 3.3.0: _shard_doc only (most efficient, already unique within PIT).
  147. < 3.3.0: field + _doc (field is unique, _doc for efficiency).
  148. """
  149. if _shard_doc_supported:
  150. return [{"_shard_doc": "asc"}]
  151. return [{field: {"order": "asc"}}, {"_doc": "asc"}]
  152. def _pit_sort_with_composite_key(*fields: str) -> list[dict]:
  153. """Return PIT sort clause with multiple fields forming a composite unique key.
  154. >= 3.3.0: _shard_doc (most efficient, ignores the fields).
  155. < 3.3.0: field1 + field2 + ... + _doc (composite is unique, _doc for efficiency).
  156. """
  157. if _shard_doc_supported:
  158. return [{"_shard_doc": "asc"}]
  159. return [{f: {"order": "asc"}} for f in fields] + [{"_doc": "asc"}]
  160. async def _detect_shard_doc_support(client: AsyncOpenSearch) -> bool:
  161. """Check if the cluster supports _shard_doc (OpenSearch >= 3.3.0)."""
  162. try:
  163. info = await client.info()
  164. version_str = info.get("version", {}).get("number", "0.0.0")
  165. # Strip pre-release suffixes (e.g. "3.3.0-SNAPSHOT" → "3", "3", "0")
  166. parts = [p.split("-")[0] for p in version_str.split(".")]
  167. major = int(parts[0]) if parts[0].isdigit() else 0
  168. minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
  169. supported = (major > 3) or (major == 3 and minor >= 3)
  170. logger.info(
  171. f"OpenSearch version {version_str}: "
  172. f"_shard_doc {'supported' if supported else 'not supported, using field+_doc fallback'}"
  173. )
  174. return supported
  175. except Exception as e:
  176. logger.warning(
  177. f"Failed to detect OpenSearch version, assuming _shard_doc not supported: {e}"
  178. )
  179. return False
  180. class ClientManager:
  181. """Singleton manager for OpenSearch client connections."""
  182. _instances = {"client": None, "ref_count": 0}
  183. _lock = asyncio.Lock()
  184. @classmethod
  185. async def get_client(cls) -> AsyncOpenSearch:
  186. """Get or create a shared AsyncOpenSearch client with reference counting."""
  187. global _shard_doc_supported
  188. async with cls._lock:
  189. if cls._instances["client"] is None:
  190. hosts_str = _get_opensearch_env("OPENSEARCH_HOSTS", "localhost:9200")
  191. hosts = [h.strip() for h in hosts_str.split(",") if h.strip()]
  192. username = _get_opensearch_env("OPENSEARCH_USER", "admin")
  193. password = _get_opensearch_env("OPENSEARCH_PASSWORD", "admin")
  194. use_ssl = _get_opensearch_env("OPENSEARCH_USE_SSL", "true").lower() in (
  195. "true",
  196. "1",
  197. "yes",
  198. )
  199. verify_certs = _get_opensearch_env(
  200. "OPENSEARCH_VERIFY_CERTS", "false"
  201. ).lower() in ("true", "1", "yes")
  202. timeout = int(_get_opensearch_env("OPENSEARCH_TIMEOUT", "30"))
  203. max_retries = int(_get_opensearch_env("OPENSEARCH_MAX_RETRIES", "3"))
  204. ssl_context = None
  205. if use_ssl and not verify_certs:
  206. ssl_context = ssl_module.create_default_context()
  207. ssl_context.check_hostname = False
  208. ssl_context.verify_mode = ssl_module.CERT_NONE
  209. client = AsyncOpenSearch(
  210. hosts=hosts,
  211. http_auth=(username, password) if username else None,
  212. use_ssl=use_ssl,
  213. verify_certs=verify_certs,
  214. ssl_context=ssl_context,
  215. ssl_show_warn=False,
  216. timeout=timeout,
  217. max_retries=max_retries,
  218. retry_on_timeout=True,
  219. )
  220. cls._instances["client"] = client
  221. cls._instances["ref_count"] = 0
  222. _shard_doc_supported = await _detect_shard_doc_support(client)
  223. logger.info(f"OpenSearch client connected to {hosts}")
  224. cls._instances["ref_count"] += 1
  225. return cls._instances["client"]
  226. @classmethod
  227. async def release_client(cls, client: AsyncOpenSearch):
  228. """Release a client reference. Closes the connection when ref count reaches 0."""
  229. global _shard_doc_supported
  230. async with cls._lock:
  231. if client is not None and client is cls._instances["client"]:
  232. cls._instances["ref_count"] -= 1
  233. if cls._instances["ref_count"] <= 0:
  234. try:
  235. await cls._instances["client"].close()
  236. except Exception:
  237. pass
  238. cls._instances["client"] = None
  239. cls._instances["ref_count"] = 0
  240. _shard_doc_supported = None
  241. logger.info("OpenSearch client connection closed")
  242. def _resolve_workspace(workspace: str, namespace: str):
  243. """Resolve effective workspace from env or parameter."""
  244. opensearch_workspace = os.environ.get("OPENSEARCH_WORKSPACE")
  245. if opensearch_workspace and opensearch_workspace.strip():
  246. effective = opensearch_workspace.strip()
  247. logger.info(
  248. f"Using OPENSEARCH_WORKSPACE: '{effective}' (overriding '{workspace}/{namespace}')"
  249. )
  250. return effective
  251. return workspace
  252. def _build_index_name(workspace: str, namespace: str) -> tuple[str, str, str]:
  253. """Build index name and return (effective_workspace, final_namespace, index_name)."""
  254. effective = _resolve_workspace(workspace, namespace)
  255. if effective:
  256. final_ns = f"{effective}_{namespace}"
  257. else:
  258. final_ns = namespace
  259. effective = ""
  260. index_name = _sanitize_index_name(final_ns)
  261. return effective, final_ns, index_name
  262. async def _mget_optional_doc(
  263. client: AsyncOpenSearch,
  264. index_name: str,
  265. doc_id: str,
  266. source_excludes: list[str] | None = None,
  267. ) -> dict[str, Any] | None:
  268. """Fetch a single document via mget and return None when it is absent.
  269. ``source_excludes`` is forwarded to OpenSearch's ``_source_excludes`` so
  270. callers can ask the server to omit specific fields (e.g. ``["vector"]``)
  271. and save network bandwidth.
  272. """
  273. kwargs: dict[str, Any] = {"index": index_name, "body": {"ids": [doc_id]}}
  274. if source_excludes:
  275. kwargs["_source_excludes"] = source_excludes
  276. response = await client.mget(**kwargs)
  277. docs = response.get("docs", [])
  278. if not docs:
  279. return None
  280. doc = docs[0]
  281. if not doc.get("found"):
  282. return None
  283. return doc
  284. def _is_missing_index_error(exc: Exception) -> bool:
  285. """Return True when an OpenSearch exception means the target index is missing."""
  286. return "index_not_found_exception" in str(exc)
  287. async def _verify_mirrored_id_mapping(client: AsyncOpenSearch, index_name: str) -> None:
  288. """Fail-fast when an existing index lacks the __mirrored_id keyword mapping.
  289. Only enforced on OpenSearch < 3.3.0, where __mirrored_id serves as the
  290. cross-shard pagination tiebreaker. Indices created by older LightRAG
  291. releases will be missing this mapping; sorting by a missing field on a
  292. multi-shard index can drop or duplicate documents during PIT pagination.
  293. """
  294. if _shard_doc_supported:
  295. return
  296. try:
  297. mapping = await client.indices.get_mapping(index=index_name)
  298. except OpenSearchException:
  299. return
  300. props = mapping.get(index_name, {}).get("mappings", {}).get("properties", {})
  301. if "__mirrored_id" not in props:
  302. raise RuntimeError(
  303. f"Index '{index_name}' lacks the '__mirrored_id' keyword mapping "
  304. f"required for stable PIT pagination on OpenSearch < 3.3.0. "
  305. f"This index was likely created by an older LightRAG release. "
  306. f"Please reindex the data, or upgrade the cluster to OpenSearch >= 3.3.0."
  307. )
  308. @final
  309. @dataclass
  310. class OpenSearchKVStorage(BaseKVStorage):
  311. """Key-Value storage using OpenSearch. Uses dynamic mapping to support varied schemas."""
  312. client: AsyncOpenSearch = field(default=None)
  313. _index_name: str = field(default="", init=False)
  314. _index_ready: bool = field(default=False, init=False)
  315. def __init__(self, namespace, global_config, embedding_func, workspace=None):
  316. super().__init__(
  317. namespace=namespace,
  318. workspace=workspace or "",
  319. global_config=global_config,
  320. embedding_func=embedding_func,
  321. )
  322. self.__post_init__()
  323. def __post_init__(self):
  324. self.workspace, self.final_namespace, self._index_name = _build_index_name(
  325. self.workspace, self.namespace
  326. )
  327. # Pending writes are flushed via _flush_pending_kv_ops() during
  328. # index_done_callback() / finalize(). Buffering many small upsert()
  329. # invocations into a single async_bulk roundtrip avoids the per-call
  330. # HTTP overhead profiled in issue #2785; the lock-everywhere model
  331. # mirrors what #3043 introduced for OpenSearchVectorDBStorage.
  332. self._pending_upserts: dict[str, dict[str, Any]] = {}
  333. self._pending_kv_deletes: set[str] = set()
  334. # Namespace-keyed lock (multi-process aware) is assigned in
  335. # initialize(). All buffer reads / writes and the flush itself
  336. # acquire this lock so an in-flight flush cannot interleave with
  337. # concurrent get_by_id / upsert / delete on the same workspace.
  338. self._flush_lock = None
  339. async def initialize(self):
  340. """Initialize client connection and create index if needed."""
  341. async with get_data_init_lock():
  342. if self.client is None:
  343. self.client = await ClientManager.get_client()
  344. await self._create_index_if_not_exists()
  345. self._index_ready = True
  346. logger.debug(
  347. f"[{self.workspace}] OpenSearch KV storage initialized: {self._index_name}"
  348. )
  349. if self._flush_lock is None:
  350. self._flush_lock = get_namespace_lock(
  351. self.namespace, workspace=self.workspace
  352. )
  353. async def _ensure_index_ready(self):
  354. """Recreate the KV index after drop before the next write."""
  355. if self._index_ready:
  356. return
  357. async with get_data_init_lock():
  358. if self.client is None:
  359. self.client = await ClientManager.get_client()
  360. if not self._index_ready:
  361. await self._create_index_if_not_exists()
  362. self._index_ready = True
    def _mark_index_missing(self):
        """Mark the KV index as unavailable for subsequent read short-circuiting.

        Clears ``_index_ready``, which also makes the next
        ``_ensure_index_ready()`` call run index creation again.
        """
        self._index_ready = False
  async def _create_index_if_not_exists(self):
      """Create the KV index when absent; otherwise verify its ``__mirrored_id`` mapping.

      A ``RequestError`` whose text contains
      ``resource_already_exists_exception`` is ignored; any other
      ``RequestError`` is re-raised.  Remaining ``OpenSearchException``
      errors are logged and re-raised.
      """
      try:
          if await self.client.indices.exists(index=self._index_name):
              await _verify_mirrored_id_mapping(self.client, self._index_name)
              return
          # Dynamic mapping so any namespace schema works; only the mirrored
          # id field is pinned to ``keyword``.
          mappings = {
              "dynamic": True,
              "properties": {"__mirrored_id": {"type": "keyword"}},
          }
          index_settings = {
              "number_of_shards": _get_index_number_of_shards(),
              "number_of_replicas": _get_index_number_of_replicas(),
          }
          create_body = {"mappings": mappings, "settings": {"index": index_settings}}
          await self.client.indices.create(index=self._index_name, body=create_body)
          logger.info(f"[{self.workspace}] Created index: {self._index_name}")
      except RequestError as e:
          if "resource_already_exists_exception" in str(e):
              return
          raise
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error creating index: {e}")
          raise
  async def finalize(self):
      """Flush buffered writes, release the client, and report anything left unflushed.

      An ``Exception`` raised by the flush is remembered and re-raised after
      the client has been released, wrapped in a ``RuntimeError`` that
      states how many upserts and deletes were still buffered.  Exceptions
      that are not ``Exception`` subclasses (for example
      ``asyncio.CancelledError``) are not caught here and propagate after
      the ``finally`` clause has released the client.

      Raises:
          RuntimeError: if the flush raised, or if it returned but buffered
              upserts or deletes remain.
      """
      flush_error: Exception | None = None
      try:
          try:
              await self._flush_pending_kv_ops()
          except Exception as e:
              # Keep the failure so it can be reported with buffer counts below.
              flush_error = e
      finally:
          held_client = self.client
          if held_client is not None:
              await ClientManager.release_client(held_client)
              self.client = None

      # Only reached when nothing propagated through the ``finally`` above.
      pending_upserts = len(self._pending_upserts)
      pending_deletes = len(self._pending_kv_deletes)

      if flush_error is not None:
          failure_text = (
              f"[{self.workspace}] OpenSearchKVStorage.finalize() flush "
              f"raised; {pending_upserts} pending upserts and "
              f"{pending_deletes} pending deletes were left buffered "
              f"(client released, data lost)"
          )
          raise RuntimeError(failure_text) from flush_error

      if not (pending_upserts or pending_deletes):
          return
      leftover_text = (
          f"[{self.workspace}] OpenSearchKVStorage.finalize() left "
          f"{pending_upserts} pending upserts and {pending_deletes} "
          f"pending deletes buffered after final flush attempt "
          f"(transient bulk failure); these writes have been lost"
      )
      raise RuntimeError(leftover_text)
  async def _iter_raw_docs(
      self, batch_size: int = 1000
  ) -> AsyncIterator[list[dict[str, Any]]]:
      """Yield batches of raw hits from the index via PIT + ``search_after``.

      Yields nothing when the index is not marked ready.  The point-in-time
      handle is deleted on exit, and a failure of that delete is ignored.
      A missing-index error marks the index missing and ends the iteration;
      other ``OpenSearchException`` errors are logged and re-raised.

      Args:
          batch_size: Number of hits requested per search call.
      """
      if not self._index_ready:
          return
      try:
          pit_response = await self.client.create_pit(
              index=self._index_name, params={"keep_alive": "1m"}
          )
          pit_id = pit_response["pit_id"]
          try:
              cursor = None
              while True:
                  request = {
                      "query": {"match_all": {}},
                      "size": batch_size,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                      "sort": _pit_sort_with_field("__mirrored_id"),
                  }
                  if cursor:
                      request["search_after"] = cursor
                  reply = await self.client.search(body=request)
                  page = reply["hits"]["hits"]
                  if not page:
                      break
                  yield page
                  cursor = page[-1]["sort"]
                  # A short page means there is nothing left to fetch.
                  if len(page) < batch_size:
                      break
          finally:
              # Best-effort cleanup of the PIT handle.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return
          logger.error(f"[{self.workspace}] Error scanning documents: {e}")
          raise
  479. def _materialize_pending_kv_doc(
  480. self, doc_id: str, source: dict[str, Any]
  481. ) -> dict[str, Any]:
  482. """Return a get_by_id-shaped view of a buffered upsert.
  483. Mirrors the post-processing applied to mget hits: drops the
  484. ``__mirrored_id`` PIT sort key, attaches the ``_id`` field and
  485. ensures ``create_time`` / ``update_time`` defaults are populated.
  486. The buffer entry itself is not mutated.
  487. """
  488. doc = {k: v for k, v in source.items() if k != "__mirrored_id"}
  489. doc["_id"] = doc_id
  490. doc.setdefault("create_time", 0)
  491. doc.setdefault("update_time", 0)
  492. return doc
  493. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  494. """Get a document by its ID, with read-your-writes against the buffer.
  495. Priority: pending delete (tombstone) → pending upsert (buffered
  496. write) → OpenSearch via mget. The buffered path strips
  497. ``__mirrored_id`` so the returned dict has the same shape as the
  498. mget path.
  499. """
  500. async with self._flush_lock:
  501. if id in self._pending_kv_deletes:
  502. return None
  503. pending = self._pending_upserts.get(id)
  504. if pending is not None:
  505. return self._materialize_pending_kv_doc(id, pending)
  506. if not self._index_ready:
  507. return None
  508. try:
  509. response = await _mget_optional_doc(self.client, self._index_name, id)
  510. if response is None:
  511. return None
  512. doc = response["_source"]
  513. doc.pop("__mirrored_id", None)
  514. doc["_id"] = response["_id"]
  515. doc.setdefault("create_time", 0)
  516. doc.setdefault("update_time", 0)
  517. return doc
  518. except OpenSearchException as e:
  519. if _is_missing_index_error(e):
  520. self._mark_index_missing()
  521. return None
  522. logger.error(f"[{self.workspace}] Error getting document {id}: {e}")
  523. return None
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
      """Fetch several documents in input order, consulting the write buffer first.

      Under the flush lock each id is resolved against pending deletes
      (result ``None``) and pending upserts (materialized view).  Ids not
      answered by the buffer are fetched with a single ``mget`` after the
      lock is released, provided the index was ready.

      Returns:
          One entry per requested id, in order; ``None`` for ids that are
          tombstoned, not found, or could not be fetched.
      """
      if not ids:
          return []

      from_buffer: dict[str, dict[str, Any] | None] = {}
      to_fetch: list[str] = []
      async with self._flush_lock:
          for doc_id in ids:
              if doc_id in self._pending_kv_deletes:
                  from_buffer[doc_id] = None
              elif (staged := self._pending_upserts.get(doc_id)) is not None:
                  from_buffer[doc_id] = self._materialize_pending_kv_doc(doc_id, staged)
              else:
                  to_fetch.append(doc_id)
          index_ready = self._index_ready

      from_server: dict[str, dict[str, Any] | None] = {}
      if to_fetch and index_ready:
          try:
              reply = await self.client.mget(
                  index=self._index_name, body={"ids": to_fetch}
              )
              for entry in reply["docs"]:
                  if not entry.get("found"):
                      continue
                  record = entry["_source"]
                  record.pop("__mirrored_id", None)
                  record["_id"] = entry["_id"]
                  record.setdefault("create_time", 0)
                  record.setdefault("update_time", 0)
                  from_server[entry["_id"]] = record
          except OpenSearchException as e:
              if _is_missing_index_error(e):
                  self._mark_index_missing()
              else:
                  logger.error(f"[{self.workspace}] Error getting documents: {e}")

      ordered = []
      for doc_id in ids:
          if doc_id in from_buffer:
              ordered.append(from_buffer[doc_id])
          else:
              ordered.append(from_server.get(doc_id))
      return ordered
  async def filter_keys(self, keys: set[str]) -> set[str]:
      """Return the members of ``keys`` that are not present in storage.

      Buffered upserts count as present.  Buffered deletes count as absent
      and are never sent to ``mget``.  Only keys in neither buffer are
      checked against the server.  When the index is not ready, or the
      ``mget`` fails, every key without a buffered upsert is reported as
      absent.
      """
      async with self._flush_lock:
          pending_upserts = set(self._pending_upserts)
          pending_deletes = set(self._pending_kv_deletes)
          index_ready = self._index_ready

      # Everything not shadowed by a buffered upsert is a candidate "missing" key.
      not_upserted = keys - pending_upserts
      to_check = not_upserted - pending_deletes
      if not to_check or not index_ready:
          return not_upserted

      try:
          reply = await self.client.mget(
              index=self._index_name,
              body={"ids": list(to_check)},
              _source=False,
          )
          found_on_server = set()
          for entry in reply["docs"]:
              if entry.get("found"):
                  found_on_server.add(entry["_id"])
          return not_upserted - found_on_server
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
          else:
              logger.error(f"[{self.workspace}] Error filtering keys: {e}")
          return not_upserted
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
      """Stage documents in the write buffer; nothing is sent to OpenSearch here.

      Each incoming document gets ``update_time`` set to the current epoch
      second and ``create_time`` defaulted to the same value; this is
      written into the caller's dicts.  The buffered body is a copy without
      ``_id`` and with ``__mirrored_id`` set to the document id.  Staging an
      upsert removes any pending delete for the same id.

      Args:
          data: Mapping of document id to document fields.  An empty
              mapping is a no-op.
      """
      if not data:
          return
      await self._ensure_index_ready()
      logger.debug(
          f"[{self.workspace}] Buffering {len(data)} documents for {self.namespace}"
      )
      now = int(time.time())

      # Bodies are built before taking the lock so it is held only for the
      # buffer update itself.
      staged: list[tuple[str, dict[str, Any]]] = []
      for position, (doc_id, fields) in enumerate(data.items(), start=1):
          fields["update_time"] = now
          if "create_time" not in fields:
              fields["create_time"] = now
          body = {name: value for name, value in fields.items() if name != "_id"}
          body["__mirrored_id"] = doc_id
          staged.append((doc_id, body))
          await _cooperative_yield(position)

      async with self._flush_lock:
          for doc_id, body in staged:
              self._pending_kv_deletes.discard(doc_id)
              self._pending_upserts[doc_id] = body
  async def delete(self, ids: list[str]) -> None:
      """Stage deletes in the write buffer; nothing is sent to OpenSearch here.

      For every id, any pending upsert is dropped and the id is added to the
      pending-delete set.  ``_index_ready`` is deliberately not consulted, so
      a buffered upsert is invalidated even while the index is marked
      missing.

      Args:
          ids: Document ids to delete.  A ``set`` is accepted and converted
              to a list.  Empty input is a no-op.
      """
      if not ids:
          return
      if isinstance(ids, set):
          ids = list(ids)
      async with self._flush_lock:
          upserts, deletes = self._pending_upserts, self._pending_kv_deletes
          for doc_id in ids:
              upserts.pop(doc_id, None)
              deletes.add(doc_id)
      logger.debug(
          f"[{self.workspace}] Buffered delete for {len(ids)} documents in {self.namespace}"
      )
  654. async def _flush_pending_kv_ops(self) -> None:
  655. """Flush buffered upserts + deletes via a single async_bulk call.
  656. Concurrency contract: the entire flush runs under ``_flush_lock``;
  657. ``upsert`` / ``delete`` / reads / ``drop`` all acquire the same lock
  658. so an in-flight flush cannot interleave with concurrent buffer
  659. mutations.
  660. Failure handling mirrors the Vector-side helper:
  661. * If ``_ensure_index_ready`` raises, the buffers are left intact
  662. and the next flush retries.
  663. * If ``async_bulk`` raises, the buffers are left intact.
  664. * Per-doc retryable failures (408 / 429 / 5xx) stay in the buffer.
  665. * Per-doc non-retryable failures (most 4xx) are cleared and a
  666. sample is logged at WARNING with op / id / status / error.
  667. """
  668. async with self._flush_lock:
  669. if not self._pending_upserts and not self._pending_kv_deletes:
  670. return
  671. if self.client is None:
  672. return
  673. await self._ensure_index_ready()
  674. pending_upserts = self._pending_upserts
  675. pending_deletes = self._pending_kv_deletes
  676. # Deletes come first so that a delete followed (in time) by an
  677. # upsert on the same id is still observable as an index after
  678. # the bulk completes; we also dedupe via the set/dict already.
  679. actions: list[dict[str, Any]] = []
  680. for doc_id in pending_deletes:
  681. actions.append(
  682. {
  683. "_op_type": "delete",
  684. "_index": self._index_name,
  685. "_id": doc_id,
  686. }
  687. )
  688. for doc_id, source in pending_upserts.items():
  689. actions.append(
  690. {
  691. "_op_type": "index",
  692. "_index": self._index_name,
  693. "_id": doc_id,
  694. "_source": source,
  695. }
  696. )
  697. try:
  698. success, failed = await helpers.async_bulk(
  699. self.client, actions, raise_on_error=False
  700. )
  701. except OpenSearchException as e:
  702. logger.error(
  703. f"[{self.workspace}] Error flushing KV ops "
  704. f"(upserts={len(pending_upserts)}, "
  705. f"deletes={len(pending_deletes)}): {e}"
  706. )
  707. raise
  708. retryable_ids, non_retryable_ops = _extract_bulk_failed_ids(failed)
  709. non_retryable_ids = {op.doc_id for op in non_retryable_ops}
  710. # Clear successful + non-retryable entries; keep retryable ones.
  711. for doc_id in list(pending_upserts.keys()):
  712. if doc_id not in retryable_ids:
  713. pending_upserts.pop(doc_id, None)
  714. new_deletes: set[str] = set()
  715. for doc_id in pending_deletes:
  716. if doc_id in retryable_ids:
  717. new_deletes.add(doc_id)
  718. pending_deletes.clear()
  719. pending_deletes.update(new_deletes)
  720. if retryable_ids:
  721. logger.warning(
  722. f"[{self.workspace}] {len(retryable_ids)} KV ops will "
  723. f"retry on the next flush (transient failure)"
  724. )
  725. if non_retryable_ops:
  726. sample = non_retryable_ops[:5]
  727. sample_text = ", ".join(
  728. f"{op.op}/{op.doc_id}/status={op.status}/{op.error}"
  729. for op in sample
  730. )
  731. logger.warning(
  732. f"[{self.workspace}] {len(non_retryable_ops)} KV ops "
  733. f"failed permanently and were dropped (non-retryable status). "
  734. f"Sample: {sample_text}"
  735. )
  736. if len(non_retryable_ops) > len(sample):
  737. logger.debug(
  738. f"[{self.workspace}] Remaining permanent failures: "
  739. + ", ".join(
  740. f"{op.op}/{op.doc_id}/status={op.status}/{op.error}"
  741. for op in non_retryable_ops[len(sample) :]
  742. )
  743. )
  744. logger.debug(
  745. f"[{self.workspace}] Flushed KV ops: {success} ok, "
  746. f"retry={len(retryable_ids)}, dropped={len(non_retryable_ids)}"
  747. )
  async def index_done_callback(self) -> None:
      """Flush pending KV ops, then refresh the index so writes become searchable.

      The flush runs first; it goes through ``_ensure_index_ready`` and can
      therefore recreate a previously missing index.  The refresh is skipped
      when the index is still not marked ready afterwards.

      The refresh itself is best-effort: a missing-index error marks the
      index missing, and any other failure is logged at WARNING instead of
      being raised, so a failed refresh never fails the caller.  Exceptions
      raised by the flush are not caught here.
      """
      await self._flush_pending_kv_ops()
      if not self._index_ready:
          return
      try:
          await self.client.indices.refresh(index=self._index_name)
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return
          # Previously swallowed silently; keep best-effort but leave a trace.
          logger.warning(
              f"[{self.workspace}] Index refresh failed for {self._index_name}: {e}"
          )
      except Exception as e:
          logger.warning(
              f"[{self.workspace}] Unexpected error refreshing {self._index_name}: {e}"
          )
  async def is_empty(self) -> bool:
      """Return True when neither the write buffer nor the index holds a document.

      A pending upsert makes the answer ``False`` immediately.  Pending
      deletes are ignored, because whether other persisted rows survive
      cannot be known without flushing.  When the index is not marked ready
      the storage is reported as empty.

      Returns:
          ``True`` if the index count is zero, the index is not ready, or
          the count request failed; ``False`` otherwise.  A failed count
          that is not a missing-index error is logged, so a transient
          outage is not indistinguishable from a genuinely empty index.
      """
      async with self._flush_lock:
          if self._pending_upserts:
              return False
          index_ready = self._index_ready
      if not index_ready:
          return True
      try:
          response = await self.client.count(index=self._index_name)
          return response["count"] == 0
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
          else:
              logger.error(
                  f"[{self.workspace}] Error checking if storage is empty: {e}"
              )
          return True
  async def drop(self) -> dict[str, str]:
      """Delete the index and discard everything in the write buffer.

      Runs under ``_flush_lock``.  Both buffers are cleared before the
      delete request is sent.  A ``NotFoundError`` from the delete is
      treated as success.  The index is marked missing on every path,
      including failures.

      Returns:
          ``{"status": "success", "message": ...}`` when the index was
          deleted or already absent, otherwise
          ``{"status": "error", "message": str(exception)}``.
      """
      async with self._flush_lock:
          # Buffered writes have no target once the index is gone.
          self._pending_upserts.clear()
          self._pending_kv_deletes.clear()
          try:
              try:
                  await self.client.indices.delete(index=self._index_name)
              except NotFoundError:
                  logger.info(
                      f"[{self.workspace}] Index already missing during drop: {self._index_name}"
                  )
              else:
                  logger.info(f"[{self.workspace}] Dropped index: {self._index_name}")
              self._mark_index_missing()
              outcome = {
                  "status": "success",
                  "message": f"Index {self._index_name} dropped",
              }
              return outcome
          except OpenSearchException as e:
              self._mark_index_missing()
              logger.error(f"[{self.workspace}] Error dropping index: {e}")
              return {"status": "error", "message": str(e)}
          except Exception as e:
              self._mark_index_missing()
              logger.error(f"[{self.workspace}] Unexpected error dropping index: {e}")
              return {"status": "error", "message": str(e)}
  817. @final
  818. @dataclass
  819. class OpenSearchDocStatusStorage(DocStatusStorage):
  820. """Document status storage using OpenSearch."""
  821. client: AsyncOpenSearch = field(default=None)
  822. _index_name: str = field(default="", init=False)
  823. _index_ready: bool = field(default=False, init=False)
  def __init__(self, namespace, global_config, embedding_func, workspace=None):
      """Forward the arguments to the base initializer, then run ``__post_init__``.

      A falsy ``workspace`` is passed to the base class as ``""``.
      """
      base_kwargs = {
          "namespace": namespace,
          "workspace": workspace or "",
          "global_config": global_config,
          "embedding_func": embedding_func,
      }
      super().__init__(**base_kwargs)
      self.__post_init__()
  def __post_init__(self):
      """Set ``workspace``, ``final_namespace`` and ``_index_name`` from ``_build_index_name``."""
      resolved = _build_index_name(self.workspace, self.namespace)
      self.workspace, self.final_namespace, self._index_name = resolved
  836. def _prepare_doc_status_data(self, doc: dict[str, Any]) -> dict[str, Any]:
  837. """Normalize a raw OpenSearch document to DocProcessingStatus-compatible dict."""
  838. data = doc.copy()
  839. data.pop("_id", None)
  840. data.pop("__mirrored_id", None)
  841. if "file_path" not in data:
  842. data["file_path"] = "no-file-path"
  843. data.setdefault("metadata", {})
  844. data.setdefault("error_msg", None)
  845. if "error" in data:
  846. if not data.get("error_msg"):
  847. data["error_msg"] = data.pop("error")
  848. else:
  849. data.pop("error", None)
  850. return data
  async def initialize(self):
      """Acquire the client if needed and make sure the doc status index exists.

      Everything runs while holding the lock returned by
      ``get_data_init_lock()``; ``_index_ready`` is set once the index has
      been created or verified.
      """
      async with get_data_init_lock():
          if self.client is None:
              self.client = await ClientManager.get_client()
          await self._create_index_if_not_exists()
          self._index_ready = True
          ready_message = (
              f"[{self.workspace}] OpenSearch DocStatus storage initialized: {self._index_name}"
          )
          logger.debug(ready_message)
  861. async def _ensure_index_ready(self):
  862. """Recreate the doc status index after drop before the next write."""
  863. if self._index_ready:
  864. return
  865. async with get_data_init_lock():
  866. if self.client is None:
  867. self.client = await ClientManager.get_client()
  868. if not self._index_ready:
  869. await self._create_index_if_not_exists()
  870. self._index_ready = True
  871. def _mark_index_missing(self):
  872. """Mark the doc status index as unavailable for subsequent read short-circuiting."""
  873. self._index_ready = False
  async def _create_index_if_not_exists(self):
      """Create the doc status index when absent; otherwise verify and patch its mapping.

      A new index gets explicit ``keyword`` mappings for the id mirror,
      status, file path, track id and content hash, and ``date`` mappings
      for the two timestamps.  An existing index has its ``__mirrored_id``
      mapping verified and its ``content_hash`` mapping added if missing.

      A ``RequestError`` whose text contains
      ``resource_already_exists_exception`` is ignored; any other
      ``RequestError`` is re-raised.  Remaining ``OpenSearchException``
      errors are logged and re-raised.
      """
      try:
          if await self.client.indices.exists(index=self._index_name):
              await _verify_mirrored_id_mapping(self.client, self._index_name)
              await self._ensure_content_hash_mapping()
              return
          properties = {
              field_name: {"type": "keyword"}
              for field_name in (
                  "__mirrored_id",
                  "status",
                  "file_path",
                  "track_id",
                  "content_hash",
              )
          }
          properties["created_at"] = {"type": "date"}
          properties["updated_at"] = {"type": "date"}
          mappings = {"dynamic": True, "properties": properties}
          index_settings = {
              "number_of_shards": _get_index_number_of_shards(),
              "number_of_replicas": _get_index_number_of_replicas(),
          }
          create_body = {"mappings": mappings, "settings": {"index": index_settings}}
          await self.client.indices.create(index=self._index_name, body=create_body)
          logger.info(
              f"[{self.workspace}] Created doc status index: {self._index_name}"
          )
      except RequestError as e:
          if "resource_already_exists_exception" in str(e):
              return
          raise
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error creating doc status index: {e}")
          raise
  async def _ensure_content_hash_mapping(self) -> None:
      """Add a ``keyword`` mapping for ``content_hash`` to an index that lacks one.

      The current mapping is fetched first; if that request fails, or the
      field is already mapped, nothing is done.  Otherwise ``put_mapping``
      is issued.  A failure of ``put_mapping`` is logged at WARNING and not
      raised, so this method never aborts index initialization.
      """
      try:
          current = await self.client.indices.get_mapping(index=self._index_name)
      except OpenSearchException:
          return

      index_entry = current.get(self._index_name, {})
      known_fields = index_entry.get("mappings", {}).get("properties", {})
      if "content_hash" in known_fields:
          return

      addition = {"properties": {"content_hash": {"type": "keyword"}}}
      try:
          await self.client.indices.put_mapping(index=self._index_name, body=addition)
          logger.info(
              f"[{self.workspace}] Added content_hash keyword mapping to {self._index_name}"
          )
      except OpenSearchException as e:
          logger.warning(
              f"[{self.workspace}] Failed to add content_hash mapping to "
              f"{self._index_name}: {e}"
          )
  async def finalize(self):
      """Hand the client back to ``ClientManager`` and forget it; no-op without a client."""
      if self.client is None:
          return
      await ClientManager.release_client(self.client)
      self.client = None
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
      """Fetch one doc status record by id.

      Returns:
          The stored ``_source`` dict with ``_id`` added, or ``None`` when
          the index is not ready, the document is not found, or an
          ``OpenSearchException`` occurred (non-missing-index errors are
          logged; a missing-index error marks the index missing).
      """
      if not self._index_ready:
          return None
      try:
          hit = await _mget_optional_doc(self.client, self._index_name, id)
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
          else:
              logger.error(f"[{self.workspace}] Error getting doc status {id}: {e}")
          return None
      if hit is None:
          return None
      record = hit["_source"]
      record["_id"] = hit["_id"]
      return record
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
      """Fetch several doc status records by id, preserving input order.

      Args:
          ids: Document ids to look up.  An empty list returns ``[]``
              without contacting OpenSearch.

      Returns:
          One entry per requested id, in order: the stored ``_source`` dict
          with ``_id`` added, or ``None`` for ids that were not found.  When
          the index is not ready or the request fails, every entry is
          ``None`` (non-missing-index errors are logged).
      """
      if not ids:
          # Nothing to look up: skip the round trip and the error log an
          # mget with no ids would otherwise end in.
          return []
      if not self._index_ready:
          return [None] * len(ids)
      try:
          response = await self.client.mget(index=self._index_name, body={"ids": ids})
          doc_map = {}
          for doc in response["docs"]:
              if doc.get("found"):
                  data = doc["_source"]
                  data["_id"] = doc["_id"]
                  doc_map[doc["_id"]] = data
          return [doc_map.get(id) for id in ids]
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return [None] * len(ids)
          logger.error(f"[{self.workspace}] Error getting doc statuses: {e}")
          return [None] * len(ids)
  async def filter_keys(self, keys: set[str]) -> set[str]:
      """Return the members of ``keys`` that do not exist in the index.

      Args:
          keys: Candidate document ids.  An empty set is returned as-is
              without contacting OpenSearch.

      Returns:
          ``keys`` minus the ids found by ``mget``.  When the index is not
          ready or the request fails, all of ``keys`` is returned
          (non-missing-index errors are logged).
      """
      if not keys:
          # Nothing to check: skip the round trip and the error log an
          # mget with no ids would otherwise end in.
          return keys
      if not self._index_ready:
          return keys
      try:
          response = await self.client.mget(
              index=self._index_name, body={"ids": list(keys)}, _source=False
          )
          existing_ids = {doc["_id"] for doc in response["docs"] if doc.get("found")}
          return keys - existing_ids
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return keys
          logger.error(f"[{self.workspace}] Error filtering keys: {e}")
          return keys
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
      """Insert or update doc status records with a single bulk request.

      Each record gets ``chunks_list`` defaulted to ``[]``; this is written
      into the caller's dict.  The indexed body omits ``_id`` and carries
      ``__mirrored_id`` set to the document id.

      Failures never raise out of this method: a request-level
      ``OpenSearchException`` is logged, and per-document failures returned
      by ``async_bulk`` are logged with a small sample, so rejected writes
      are no longer dropped silently.

      Args:
          data: Mapping of document id to status fields.  An empty mapping
              is a no-op.
      """
      if not data:
          return
      await self._ensure_index_ready()
      logger.debug(f"[{self.workspace}] Upserting {len(data)} doc statuses")
      actions = []
      for i, (k, v) in enumerate(data.items(), start=1):
          v.setdefault("chunks_list", [])
          source = {fk: fv for fk, fv in v.items() if fk != "_id"}
          source["__mirrored_id"] = k
          actions.append(
              {
                  "_op_type": "index",
                  "_index": self._index_name,
                  "_id": k,
                  "_source": source,
              }
          )
          await _cooperative_yield(i)
      try:
          # DocStatus needs refresh="wait_for" because get_docs_by_status
          # (search-based) is called immediately after enqueue upserts.
          _, failed = await helpers.async_bulk(
              self.client, actions, raise_on_error=False, refresh="wait_for"
          )
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error upserting doc statuses: {e}")
          return
      if failed:
          # With raise_on_error=False the helper reports per-document
          # failures in its second return value instead of raising.
          logger.error(
              f"[{self.workspace}] {len(failed)} doc status upserts failed; "
              f"sample: {failed[:5]}"
          )
  async def get_status_counts(self) -> dict[str, int]:
      """Count documents per ``status`` value using a terms aggregation.

      Returns:
          Mapping of status value to document count.  Empty when the index
          is not ready or the request fails (non-missing-index errors are
          logged; a missing-index error marks the index missing).
      """
      if not self._index_ready:
          return {}
      status_terms = {"terms": {"field": "status", "size": 100}}
      request = {"size": 0, "aggs": {"status_counts": status_terms}}
      try:
          reply = await self.client.search(index=self._index_name, body=request)
          tally = {}
          for bucket in reply["aggregations"]["status_counts"]["buckets"]:
              tally[bucket["key"]] = bucket["doc_count"]
          return tally
      except OpenSearchException as e:
          if not _is_missing_index_error(e):
              logger.error(f"[{self.workspace}] Error getting status counts: {e}")
          else:
              self._mark_index_missing()
          return {}
  async def _search_all_docs(self, query: dict) -> dict[str, DocProcessingStatus]:
      """Collect every document matching ``query`` via PIT + ``search_after``.

      Hits are requested 10000 at a time.  Each hit is normalized with
      ``_prepare_doc_status_data`` and converted to ``DocProcessingStatus``;
      hits that fail with ``KeyError`` / ``TypeError`` are logged and
      skipped.  The PIT handle is deleted on exit, and a failure of that
      delete is ignored.

      Returns:
          Mapping of document id to ``DocProcessingStatus``.  Empty when the
          index is not ready or turns out to be missing.  On any other
          ``OpenSearchException`` the error is logged and whatever was
          collected so far is returned.
      """
      if not self._index_ready:
          return {}
      collected = {}
      page_limit = 10000
      try:
          pit_response = await self.client.create_pit(
              index=self._index_name, params={"keep_alive": "1m"}
          )
          pit_id = pit_response["pit_id"]
          try:
              cursor = None
              while True:
                  request = {
                      "query": query,
                      "size": page_limit,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                      "sort": _pit_sort_with_field("__mirrored_id"),
                  }
                  if cursor:
                      request["search_after"] = cursor
                  reply = await self.client.search(body=request)
                  page = reply["hits"]["hits"]
                  if not page:
                      break
                  for hit in page:
                      try:
                          fields = self._prepare_doc_status_data(hit["_source"])
                          collected[hit["_id"]] = DocProcessingStatus(**fields)
                      except (KeyError, TypeError) as e:
                          logger.error(
                              f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}"
                          )
                  cursor = page[-1]["sort"]
                  # A short page means there is nothing left to fetch.
                  if len(page) < page_limit:
                      break
          finally:
              # Best-effort cleanup of the PIT handle.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return {}
          logger.error(f"[{self.workspace}] Error fetching docs: {e}")
      return collected
  async def get_docs_by_status(
      self, status: DocStatus
  ) -> dict[str, DocProcessingStatus]:
      """Return every document whose status equals ``status``.

      Delegates to ``get_docs_by_statuses`` with a one-element list.
      """
      single = [status]
      return await self.get_docs_by_statuses(single)
  async def get_docs_by_statuses(
      self, statuses: list[DocStatus]
  ) -> dict[str, DocProcessingStatus]:
      """Return every document whose status is one of ``statuses``.

      All statuses are matched by one ``terms`` query, so a single
      ``_search_all_docs`` pass covers them instead of one scan per status.
      An empty ``statuses`` list returns ``{}`` without querying.
      """
      if not statuses:
          return {}
      wanted = [status.value for status in statuses]
      terms_query = {"terms": {"status": wanted}}
      return await self._search_all_docs(terms_query)
  async def get_docs_by_track_id(
      self, track_id: str
  ) -> dict[str, DocProcessingStatus]:
      """Return every document whose ``track_id`` equals the given value."""
      term_query = {"term": {"track_id": track_id}}
      return await self._search_all_docs(term_query)
  async def get_docs_paginated(
      self,
      status_filter: DocStatus | None = None,
      status_filters: list[DocStatus] | None = None,
      page: int = 1,
      page_size: int = 50,
      sort_field: str = "updated_at",
      sort_direction: str = "desc",
  ) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
      """Return one page of doc status records plus the total match count.

      Paging uses a point-in-time handle with ``search_after``: the rows
      before the requested page are walked first, then the page itself is
      fetched.  The skip phase requests sort values only (``_source``
      disabled) in batches of up to 10000, so reaching a deep page does not
      download the bodies of skipped documents or need one round trip per
      page.

      Args:
          status_filter: Optional single status to match.
          status_filters: Optional list of statuses to match; combined with
              ``status_filter`` by ``resolve_status_filter_values``.
          page: 1-based page number; values below 1 are treated as 1.
          page_size: Rows per page, clamped to the range 10..200.
          sort_field: ``created_at``, ``updated_at``, ``file_path`` or
              ``id`` / ``_id``; anything else falls back to ``updated_at``.
          sort_direction: ``asc`` (case-insensitive) for ascending, anything
              else for descending.

      Returns:
          ``(documents, total_count)`` where ``documents`` is a list of
          ``(doc_id, DocProcessingStatus)``.  ``([], 0)`` when the index is
          not ready or the query fails; ``([], total_count)`` when the page
          lies beyond the last match.
      """
      if not self._index_ready:
          return [], 0
      status_filter_values = self.resolve_status_filter_values(
          status_filter=status_filter,
          status_filters=status_filters,
      )
      page = max(1, page)
      page_size = max(10, min(200, page_size))
      if sort_field == "id":
          sort_field = "_id"
      if sort_field not in ("created_at", "updated_at", "_id", "file_path"):
          sort_field = "updated_at"
      sort_order = "asc" if sort_direction.lower() == "asc" else "desc"
      query = {"match_all": {}}
      if status_filter_values is not None:
          if len(status_filter_values) == 1:
              query = {"term": {"status": next(iter(status_filter_values))}}
          else:
              query = {"terms": {"status": sorted(status_filter_values)}}
      skip_count = (page - 1) * page_size
      # Upper bound for one skip request; same size the full-scan helper uses.
      skip_batch_limit = 10000
      try:
          count_resp = await self.client.count(
              index=self._index_name, body={"query": query}
          )
          total_count = count_resp.get("count", 0)
          if total_count == 0 or skip_count >= total_count:
              return [], total_count
          # The __mirrored_id tiebreaker keeps the ordering total, which
          # search_after needs to resume reliably.
          sort_clause = [{sort_field: {"order": sort_order}}] + _pit_sort_with_field(
              "__mirrored_id"
          )
          pit = await self.client.create_pit(
              index=self._index_name, params={"keep_alive": "1m"}
          )
          pit_id = pit["pit_id"]
          try:
              search_after = None
              skipped = 0
              while skipped < skip_count:
                  batch = min(skip_batch_limit, skip_count - skipped)
                  body = {
                      "query": query,
                      "sort": sort_clause,
                      "size": batch,
                      # Only the sort values of the last hit are needed here.
                      "_source": False,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                  }
                  if search_after:
                      body["search_after"] = search_after
                  resp = await self.client.search(body=body)
                  hits = resp["hits"]["hits"]
                  if not hits:
                      # Fewer rows than the count promised; the page is empty.
                      return [], total_count
                  search_after = hits[-1]["sort"]
                  skipped += len(hits)
              body = {
                  "query": query,
                  "sort": sort_clause,
                  "size": page_size,
                  "pit": {"id": pit_id, "keep_alive": "1m"},
              }
              if search_after:
                  body["search_after"] = search_after
              response = await self.client.search(body=body)
          finally:
              # Best-effort cleanup of the PIT handle.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
          documents = []
          for hit in response["hits"]["hits"]:
              try:
                  data = self._prepare_doc_status_data(hit["_source"])
                  documents.append((hit["_id"], DocProcessingStatus(**data)))
              except (KeyError, TypeError) as e:
                  logger.error(
                      f"[{self.workspace}] Error parsing doc {hit['_id']}: {e}"
                  )
          return documents, total_count
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return [], 0
          logger.error(f"[{self.workspace}] Error in paginated query: {e}")
          return [], 0
  1208. async def get_all_status_counts(self) -> dict[str, int]:
  1209. """Get document counts for all statuses including an 'all' total."""
  1210. if not self._index_ready:
  1211. return {}
  1212. try:
  1213. body = {
  1214. "size": 0,
  1215. "aggs": {"status_counts": {"terms": {"field": "status", "size": 100}}},
  1216. }
  1217. response = await self.client.search(index=self._index_name, body=body)
  1218. counts = {}
  1219. total = 0
  1220. for bucket in response["aggregations"]["status_counts"]["buckets"]:
  1221. counts[bucket["key"]] = bucket["doc_count"]
  1222. total += bucket["doc_count"]
  1223. counts["all"] = total
  1224. return counts
  1225. except OpenSearchException as e:
  1226. if _is_missing_index_error(e):
  1227. self._mark_index_missing()
  1228. return {}
  1229. logger.error(f"[{self.workspace}] Error getting all status counts: {e}")
  1230. return {}
  1231. async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
  1232. """Find a document status record by its file_path field."""
  1233. if not self._index_ready:
  1234. return None
  1235. try:
  1236. body = {"query": {"term": {"file_path": file_path}}, "size": 1}
  1237. response = await self.client.search(index=self._index_name, body=body)
  1238. hits = response["hits"]["hits"]
  1239. if hits:
  1240. doc = hits[0]["_source"]
  1241. doc["_id"] = hits[0]["_id"]
  1242. return doc
  1243. return None
  1244. except OpenSearchException as e:
  1245. if _is_missing_index_error(e):
  1246. self._mark_index_missing()
  1247. return None
  1248. logger.error(f"[{self.workspace}] Error getting doc by file_path: {e}")
  1249. return None
  1250. async def get_doc_by_file_basename(
  1251. self, basename: str
  1252. ) -> Union[tuple[str, dict[str, Any]], None]:
  1253. """Find an existing record whose canonical basename matches.
  1254. The caller is responsible for passing an already-canonical basename;
  1255. stored ``file_path`` values are canonicalized by the business layer, so
  1256. this lookup performs an exact term query against the file_path keyword
  1257. field.
  1258. """
  1259. if not basename:
  1260. return None
  1261. if basename == "unknown_source":
  1262. return None
  1263. if not self._index_ready:
  1264. return None
  1265. try:
  1266. body = {"query": {"term": {"file_path": basename}}, "size": 1}
  1267. response = await self.client.search(index=self._index_name, body=body)
  1268. hits = response["hits"]["hits"]
  1269. if not hits:
  1270. return None
  1271. hit = hits[0]
  1272. doc = hit["_source"]
  1273. return hit["_id"], doc
  1274. except OpenSearchException as e:
  1275. if _is_missing_index_error(e):
  1276. self._mark_index_missing()
  1277. return None
  1278. logger.error(f"[{self.workspace}] Error getting doc by file_basename: {e}")
  1279. return None
  1280. async def get_doc_by_content_hash(
  1281. self, content_hash: str
  1282. ) -> Union[tuple[str, dict[str, Any]], None]:
  1283. """Find an existing record whose content_hash field matches.
  1284. Uses the content_hash keyword mapping created by
  1285. ``_create_index_if_not_exists`` / ``_ensure_content_hash_mapping``.
  1286. Empty values short-circuit so legacy rows without the field cannot
  1287. accidentally match via type coercion.
  1288. """
  1289. if not content_hash:
  1290. return None
  1291. if not self._index_ready:
  1292. return None
  1293. try:
  1294. body = {"query": {"term": {"content_hash": content_hash}}, "size": 1}
  1295. response = await self.client.search(index=self._index_name, body=body)
  1296. hits = response["hits"]["hits"]
  1297. if not hits:
  1298. return None
  1299. hit = hits[0]
  1300. doc = hit["_source"]
  1301. return hit["_id"], doc
  1302. except OpenSearchException as e:
  1303. if _is_missing_index_error(e):
  1304. self._mark_index_missing()
  1305. return None
  1306. logger.error(f"[{self.workspace}] Error getting doc by content_hash: {e}")
  1307. return None
  1308. async def index_done_callback(self) -> None:
  1309. """Refresh index to make recently indexed documents searchable."""
  1310. if not self._index_ready:
  1311. return
  1312. try:
  1313. await self.client.indices.refresh(index=self._index_name)
  1314. except OpenSearchException as e:
  1315. if _is_missing_index_error(e):
  1316. self._mark_index_missing()
  1317. return
  1318. except Exception:
  1319. pass
  1320. async def is_empty(self) -> bool:
  1321. """Return True if the index contains no documents."""
  1322. if not self._index_ready:
  1323. return True
  1324. try:
  1325. response = await self.client.count(index=self._index_name)
  1326. return response["count"] == 0
  1327. except OpenSearchException as e:
  1328. if _is_missing_index_error(e):
  1329. self._mark_index_missing()
  1330. return True
  1331. async def delete(self, ids: list[str]) -> None:
  1332. """Delete document status records by IDs."""
  1333. if not ids:
  1334. return
  1335. if not self._index_ready:
  1336. return
  1337. if isinstance(ids, set):
  1338. ids = list(ids)
  1339. try:
  1340. # DocStatus needs refresh="wait_for" because downstream readers
  1341. # (get_docs_by_status, get_docs_paginated, etc.) are search-based
  1342. # and callers like _validate_and_fix_document_consistency() may
  1343. # query immediately after deletion without index_done_callback().
  1344. actions = [
  1345. {"_op_type": "delete", "_index": self._index_name, "_id": doc_id}
  1346. for doc_id in ids
  1347. ]
  1348. await helpers.async_bulk(
  1349. self.client, actions, raise_on_error=False, refresh="wait_for"
  1350. )
  1351. except OpenSearchException as e:
  1352. if _is_missing_index_error(e):
  1353. self._mark_index_missing()
  1354. return
  1355. logger.error(f"[{self.workspace}] Error deleting doc statuses: {e}")
  1356. async def drop(self) -> dict[str, str]:
  1357. """Delete the entire doc status index."""
  1358. try:
  1359. try:
  1360. await self.client.indices.delete(index=self._index_name)
  1361. logger.info(
  1362. f"[{self.workspace}] Dropped doc status index: {self._index_name}"
  1363. )
  1364. except NotFoundError:
  1365. logger.info(
  1366. f"[{self.workspace}] Doc status index already missing during drop: {self._index_name}"
  1367. )
  1368. self._mark_index_missing()
  1369. return {"status": "success", "message": f"Index {self._index_name} dropped"}
  1370. except OpenSearchException as e:
  1371. self._mark_index_missing()
  1372. logger.error(f"[{self.workspace}] Error dropping doc status index: {e}")
  1373. return {"status": "error", "message": str(e)}
  1374. except Exception as e:
  1375. self._mark_index_missing()
  1376. logger.error(
  1377. f"[{self.workspace}] Unexpected error dropping doc status index: {e}"
  1378. )
  1379. return {"status": "error", "message": str(e)}
  1380. @final
  1381. @dataclass
  1382. class OpenSearchGraphStorage(BaseGraphStorage):
  1383. """Graph storage using OpenSearch with separate nodes and edges indices.
  1384. Supports two BFS traversal strategies:
  1385. - PPL graphlookup (server-side BFS, requires OpenSearch SQL plugin with Calcite engine)
  1386. - Application-level batched BFS (fallback, works on any OpenSearch 3.x+)
  1387. The strategy is auto-detected during initialize() and can be overridden via
  1388. the OPENSEARCH_USE_PPL_GRAPHLOOKUP environment variable (true/false).
  1389. """
  1390. client: AsyncOpenSearch = field(default=None)
  1391. _nodes_index: str = field(default="", init=False)
  1392. _edges_index: str = field(default="", init=False)
  1393. _indices_ready: bool = field(default=False, init=False)
  1394. _nodes_dirty: bool = field(default=False, init=False)
  1395. _edges_dirty: bool = field(default=False, init=False)
  1396. _ppl_graphlookup_available: bool = field(default=False, init=False)
  1397. def __init__(self, namespace, global_config, embedding_func, workspace=None):
  1398. super().__init__(
  1399. namespace=namespace,
  1400. workspace=workspace or "",
  1401. global_config=global_config,
  1402. embedding_func=embedding_func,
  1403. )
  1404. self.__post_init__()
  1405. def __post_init__(self):
  1406. self.workspace, self.final_namespace, base_name = _build_index_name(
  1407. self.workspace, self.namespace
  1408. )
  1409. self._nodes_index = f"{base_name}-nodes"
  1410. self._edges_index = f"{base_name}-edges"
  1411. async def initialize(self):
  1412. """Initialize client, create indices, and detect PPL graphlookup support."""
  1413. async with get_data_init_lock():
  1414. if self.client is None:
  1415. self.client = await ClientManager.get_client()
  1416. await self._create_indices_if_not_exist()
  1417. self._indices_ready = True
  1418. self._nodes_dirty = False
  1419. self._edges_dirty = False
  1420. await self._detect_ppl_graphlookup()
  1421. logger.debug(
  1422. f"[{self.workspace}] OpenSearch Graph storage initialized: "
  1423. f"{self._nodes_index}, {self._edges_index} "
  1424. f"(PPL graphlookup: {self._ppl_graphlookup_available})"
  1425. )
  1426. async def _ensure_indices_ready(self):
  1427. """Recreate graph indices after drop before the next write."""
  1428. if self._indices_ready:
  1429. return
  1430. async with get_data_init_lock():
  1431. if self.client is None:
  1432. self.client = await ClientManager.get_client()
  1433. if not self._indices_ready:
  1434. await self._create_indices_if_not_exist()
  1435. self._indices_ready = True
  1436. def _mark_indices_missing(self):
  1437. """Mark graph indices as unavailable for subsequent read short-circuiting."""
  1438. self._indices_ready = False
  1439. self._nodes_dirty = False
  1440. self._edges_dirty = False
  1441. async def _refresh_graph_indices_if_dirty(
  1442. self, *, refresh_nodes: bool = False, refresh_edges: bool = False
  1443. ) -> None:
  1444. """Refresh graph indices only when prior writes made search views stale."""
  1445. if not self._indices_ready:
  1446. return
  1447. if not (
  1448. (refresh_nodes and self._nodes_dirty)
  1449. or (refresh_edges and self._edges_dirty)
  1450. ):
  1451. return
  1452. try:
  1453. async with get_data_init_lock():
  1454. if refresh_nodes and self._nodes_dirty:
  1455. await self.client.indices.refresh(index=self._nodes_index)
  1456. self._nodes_dirty = False
  1457. if refresh_edges and self._edges_dirty:
  1458. await self.client.indices.refresh(index=self._edges_index)
  1459. self._edges_dirty = False
  1460. except OpenSearchException as e:
  1461. if _is_missing_index_error(e):
  1462. self._mark_indices_missing()
  1463. return
  1464. raise
  1465. async def _detect_ppl_graphlookup(self):
  1466. """Detect whether PPL graphlookup command is available on this cluster."""
  1467. env_override = os.environ.get("OPENSEARCH_USE_PPL_GRAPHLOOKUP", "").lower()
  1468. if env_override == "true":
  1469. self._ppl_graphlookup_available = True
  1470. return
  1471. if env_override == "false":
  1472. self._ppl_graphlookup_available = False
  1473. return
  1474. # Auto-detect by sending a minimal PPL query
  1475. try:
  1476. await self.client.transport.perform_request(
  1477. "POST",
  1478. "/_plugins/_ppl",
  1479. body={"query": f"source = {self._edges_index} | head 0"},
  1480. )
  1481. # PPL endpoint works; now test graphlookup syntax with a no-op query
  1482. await self.client.transport.perform_request(
  1483. "POST",
  1484. "/_plugins/_ppl",
  1485. body={
  1486. "query": (
  1487. f"source = {self._edges_index} | head 1 "
  1488. f"| graphLookup {self._edges_index} "
  1489. f"start=source_node_id edge=target_node_id-->source_node_id "
  1490. f"maxDepth=0 as _gl_probe"
  1491. )
  1492. },
  1493. )
  1494. self._ppl_graphlookup_available = True
  1495. logger.info(
  1496. f"[{self.workspace}] PPL graphlookup is available, using server-side BFS"
  1497. )
  1498. except Exception:
  1499. self._ppl_graphlookup_available = False
  1500. logger.info(
  1501. f"[{self.workspace}] PPL graphlookup not available, using client-side BFS"
  1502. )
  1503. async def _create_indices_if_not_exist(self):
  1504. try:
  1505. if not await self.client.indices.exists(index=self._nodes_index):
  1506. body = {
  1507. "mappings": {
  1508. "dynamic": True,
  1509. "properties": {
  1510. "entity_id": {"type": "keyword"},
  1511. "entity_type": {"type": "keyword"},
  1512. "description": {"type": "text"},
  1513. "source_id": {"type": "text"},
  1514. "source_ids": {"type": "keyword"},
  1515. "file_path": {"type": "keyword"},
  1516. "created_at": {"type": "long"},
  1517. },
  1518. },
  1519. "settings": {
  1520. "index": {
  1521. "number_of_shards": _get_index_number_of_shards(),
  1522. "number_of_replicas": _get_index_number_of_replicas(),
  1523. }
  1524. },
  1525. }
  1526. await self.client.indices.create(index=self._nodes_index, body=body)
  1527. logger.info(
  1528. f"[{self.workspace}] Created nodes index: {self._nodes_index}"
  1529. )
  1530. except RequestError as e:
  1531. if "resource_already_exists_exception" not in str(e):
  1532. raise
  1533. try:
  1534. if not await self.client.indices.exists(index=self._edges_index):
  1535. body = {
  1536. "mappings": {
  1537. "dynamic": True,
  1538. "properties": {
  1539. "source_node_id": {"type": "keyword"},
  1540. "target_node_id": {"type": "keyword"},
  1541. "relationship": {"type": "keyword"},
  1542. "description": {"type": "text"},
  1543. "weight": {"type": "float"},
  1544. "keywords": {"type": "text"},
  1545. "source_id": {"type": "text"},
  1546. "source_ids": {"type": "keyword"},
  1547. "file_path": {"type": "keyword"},
  1548. "created_at": {"type": "long"},
  1549. },
  1550. },
  1551. "settings": {
  1552. "index": {
  1553. "number_of_shards": _get_index_number_of_shards(),
  1554. "number_of_replicas": _get_index_number_of_replicas(),
  1555. }
  1556. },
  1557. }
  1558. await self.client.indices.create(index=self._edges_index, body=body)
  1559. logger.info(
  1560. f"[{self.workspace}] Created edges index: {self._edges_index}"
  1561. )
  1562. except RequestError as e:
  1563. if "resource_already_exists_exception" not in str(e):
  1564. raise
  1565. async def finalize(self):
  1566. """Release the OpenSearch client connection."""
  1567. if self.client is not None:
  1568. await ClientManager.release_client(self.client)
  1569. self.client = None
  1570. # --- Basic queries ---
  1571. async def has_node(self, node_id: str) -> bool:
  1572. """Check whether a node exists in the graph."""
  1573. if not self._indices_ready:
  1574. return False
  1575. try:
  1576. return await self.client.exists(index=self._nodes_index, id=node_id)
  1577. except OpenSearchException as e:
  1578. if _is_missing_index_error(e):
  1579. self._mark_indices_missing()
  1580. return False
  1581. async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
  1582. """Check whether an edge exists between two nodes (bidirectional).
  1583. Uses mget with the two candidate edge IDs so the check is real-time
  1584. (translog-backed), consistent with has_node() and independent of the
  1585. index refresh cycle.
  1586. """
  1587. if not self._indices_ready:
  1588. return False
  1589. try:
  1590. forward_id = compute_mdhash_id(
  1591. f"{source_node_id}-{target_node_id}", prefix="edge-"
  1592. )
  1593. reverse_id = compute_mdhash_id(
  1594. f"{target_node_id}-{source_node_id}", prefix="edge-"
  1595. )
  1596. response = await self.client.mget(
  1597. index=self._edges_index, body={"ids": [forward_id, reverse_id]}
  1598. )
  1599. return any(doc.get("found") for doc in response.get("docs", []))
  1600. except OpenSearchException as e:
  1601. if _is_missing_index_error(e):
  1602. self._mark_indices_missing()
  1603. return False
  1604. async def node_degree(self, node_id: str) -> int:
  1605. """Count the number of edges connected to a node."""
  1606. if not self._indices_ready:
  1607. return 0
  1608. try:
  1609. await self._refresh_graph_indices_if_dirty(refresh_edges=True)
  1610. response = await self.client.count(
  1611. index=self._edges_index,
  1612. body={
  1613. "query": {
  1614. "bool": {
  1615. "should": [
  1616. {"term": {"source_node_id": node_id}},
  1617. {"term": {"target_node_id": node_id}},
  1618. ]
  1619. }
  1620. }
  1621. },
  1622. )
  1623. return response.get("count", 0)
  1624. except OpenSearchException as e:
  1625. if _is_missing_index_error(e):
  1626. self._mark_indices_missing()
  1627. return 0
  1628. async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  1629. """Sum of degrees of both endpoint nodes."""
  1630. src_degree = await self.node_degree(src_id)
  1631. tgt_degree = await self.node_degree(tgt_id)
  1632. return src_degree + tgt_degree
  1633. async def get_node(self, node_id: str) -> dict[str, str] | None:
  1634. """Get a node document by ID, or None if not found."""
  1635. if not self._indices_ready:
  1636. return None
  1637. try:
  1638. response = await _mget_optional_doc(self.client, self._nodes_index, node_id)
  1639. if response is None:
  1640. return None
  1641. doc = response["_source"]
  1642. doc["_id"] = response["_id"]
  1643. return doc
  1644. except OpenSearchException as e:
  1645. if _is_missing_index_error(e):
  1646. self._mark_indices_missing()
  1647. return None
  1648. async def get_edge(
  1649. self, source_node_id: str, target_node_id: str
  1650. ) -> dict[str, str] | None:
  1651. """Get an edge between two nodes (bidirectional), or None.
  1652. Uses mget with the two candidate edge IDs so the read is real-time
  1653. (translog-backed), consistent with get_node() and independent of the
  1654. index refresh cycle.
  1655. """
  1656. if not self._indices_ready:
  1657. return None
  1658. try:
  1659. forward_id = compute_mdhash_id(
  1660. f"{source_node_id}-{target_node_id}", prefix="edge-"
  1661. )
  1662. reverse_id = compute_mdhash_id(
  1663. f"{target_node_id}-{source_node_id}", prefix="edge-"
  1664. )
  1665. response = await self.client.mget(
  1666. index=self._edges_index, body={"ids": [forward_id, reverse_id]}
  1667. )
  1668. for doc in response.get("docs", []):
  1669. if doc.get("found"):
  1670. result = doc["_source"]
  1671. result["_id"] = doc["_id"]
  1672. return result
  1673. return None
  1674. except OpenSearchException as e:
  1675. if _is_missing_index_error(e):
  1676. self._mark_indices_missing()
  1677. return None
  1678. async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
  1679. """Get all (source, target) edge tuples connected to a node."""
  1680. if not self._indices_ready:
  1681. return None
  1682. try:
  1683. await self._refresh_graph_indices_if_dirty(refresh_edges=True)
  1684. query = {
  1685. "bool": {
  1686. "should": [
  1687. {"term": {"source_node_id": source_node_id}},
  1688. {"term": {"target_node_id": source_node_id}},
  1689. ]
  1690. }
  1691. }
  1692. edges = []
  1693. pit = await self.client.create_pit(
  1694. index=self._edges_index, params={"keep_alive": "1m"}
  1695. )
  1696. pit_id = pit["pit_id"]
  1697. try:
  1698. search_after = None
  1699. while True:
  1700. body = {
  1701. "query": query,
  1702. "_source": ["source_node_id", "target_node_id"],
  1703. "size": 10000,
  1704. "pit": {"id": pit_id, "keep_alive": "1m"},
  1705. "sort": _pit_sort_with_composite_key(
  1706. "source_node_id", "target_node_id"
  1707. ),
  1708. }
  1709. if search_after:
  1710. body["search_after"] = search_after
  1711. response = await self.client.search(body=body)
  1712. hits = response["hits"]["hits"]
  1713. if not hits:
  1714. break
  1715. for hit in hits:
  1716. edges.append(
  1717. (
  1718. hit["_source"]["source_node_id"],
  1719. hit["_source"]["target_node_id"],
  1720. )
  1721. )
  1722. search_after = hits[-1]["sort"]
  1723. if len(hits) < 10000:
  1724. break
  1725. finally:
  1726. try:
  1727. await self.client.delete_pit(body={"pit_id": [pit_id]})
  1728. except Exception:
  1729. pass
  1730. return edges
  1731. except OpenSearchException as e:
  1732. if _is_missing_index_error(e):
  1733. self._mark_indices_missing()
  1734. return None
  1735. # --- Batch operations ---
  1736. async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
  1737. """Batch-fetch multiple nodes by ID."""
  1738. if not self._indices_ready:
  1739. return {}
  1740. try:
  1741. response = await self.client.mget(
  1742. index=self._nodes_index, body={"ids": node_ids}
  1743. )
  1744. result = {}
  1745. for doc in response["docs"]:
  1746. if doc.get("found"):
  1747. data = doc["_source"]
  1748. data["_id"] = doc["_id"]
  1749. result[doc["_id"]] = data
  1750. return result
  1751. except OpenSearchException as e:
  1752. if _is_missing_index_error(e):
  1753. self._mark_indices_missing()
  1754. return {}
  1755. async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
  1756. """Batch-fetch edge counts for multiple nodes using aggregations."""
  1757. if not node_ids:
  1758. return {}
  1759. if not self._indices_ready:
  1760. return {}
  1761. try:
  1762. await self._refresh_graph_indices_if_dirty(refresh_edges=True)
  1763. # Use a single query with aggregations for both source and target
  1764. body = {
  1765. "size": 0,
  1766. "query": {
  1767. "bool": {
  1768. "should": [
  1769. {"terms": {"source_node_id": node_ids}},
  1770. {"terms": {"target_node_id": node_ids}},
  1771. ]
  1772. }
  1773. },
  1774. "aggs": {
  1775. "source_degrees": {
  1776. "terms": {
  1777. "field": "source_node_id",
  1778. "size": len(node_ids) * 2,
  1779. }
  1780. },
  1781. "target_degrees": {
  1782. "terms": {
  1783. "field": "target_node_id",
  1784. "size": len(node_ids) * 2,
  1785. }
  1786. },
  1787. },
  1788. }
  1789. response = await self.client.search(index=self._edges_index, body=body)
  1790. result = {}
  1791. for bucket in response["aggregations"]["source_degrees"]["buckets"]:
  1792. if bucket["key"] in node_ids:
  1793. result[bucket["key"]] = (
  1794. result.get(bucket["key"], 0) + bucket["doc_count"]
  1795. )
  1796. for bucket in response["aggregations"]["target_degrees"]["buckets"]:
  1797. if bucket["key"] in node_ids:
  1798. result[bucket["key"]] = (
  1799. result.get(bucket["key"], 0) + bucket["doc_count"]
  1800. )
  1801. return result
  1802. except OpenSearchException as e:
  1803. if _is_missing_index_error(e):
  1804. self._mark_indices_missing()
  1805. return {}
  1806. async def get_nodes_edges_batch(
  1807. self, node_ids: list[str]
  1808. ) -> dict[str, list[tuple[str, str]]]:
  1809. """Batch-fetch edge tuples for multiple nodes."""
  1810. result = {nid: [] for nid in node_ids}
  1811. if not self._indices_ready:
  1812. return result
  1813. try:
  1814. await self._refresh_graph_indices_if_dirty(refresh_edges=True)
  1815. query = {
  1816. "bool": {
  1817. "should": [
  1818. {"terms": {"source_node_id": node_ids}},
  1819. {"terms": {"target_node_id": node_ids}},
  1820. ]
  1821. }
  1822. }
  1823. pit = await self.client.create_pit(
  1824. index=self._edges_index, params={"keep_alive": "1m"}
  1825. )
  1826. pit_id = pit["pit_id"]
  1827. try:
  1828. search_after = None
  1829. while True:
  1830. body = {
  1831. "query": query,
  1832. "_source": ["source_node_id", "target_node_id"],
  1833. "size": 10000,
  1834. "pit": {"id": pit_id, "keep_alive": "1m"},
  1835. "sort": _pit_sort_with_composite_key(
  1836. "source_node_id", "target_node_id"
  1837. ),
  1838. }
  1839. if search_after:
  1840. body["search_after"] = search_after
  1841. response = await self.client.search(body=body)
  1842. hits = response["hits"]["hits"]
  1843. if not hits:
  1844. break
  1845. for hit in hits:
  1846. src = hit["_source"]["source_node_id"]
  1847. tgt = hit["_source"]["target_node_id"]
  1848. if src in result:
  1849. result[src].append((src, tgt))
  1850. if tgt in result:
  1851. result[tgt].append((src, tgt))
  1852. search_after = hits[-1]["sort"]
  1853. if len(hits) < 10000:
  1854. break
  1855. finally:
  1856. try:
  1857. await self.client.delete_pit(body={"pit_id": [pit_id]})
  1858. except Exception:
  1859. pass
  1860. except OpenSearchException as e:
  1861. if _is_missing_index_error(e):
  1862. self._mark_indices_missing()
  1863. pass
  1864. return result
  1865. # --- Upsert operations ---
  1866. async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
  1867. """Insert or update a node. Adds entity_id for PPL compatibility."""
  1868. try:
  1869. await self._ensure_indices_ready()
  1870. doc = {k: v for k, v in node_data.items() if k != "_id"}
  1871. doc["entity_id"] = node_id
  1872. if node_data.get("source_id", ""):
  1873. doc["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP)
  1874. # No per-operation refresh: node reads use ID-based mget/exists
  1875. # (translog, real-time). Search visibility after index_done_callback().
  1876. await self.client.index(index=self._nodes_index, id=node_id, body=doc)
  1877. self._nodes_dirty = True
  1878. except OpenSearchException as e:
  1879. logger.error(f"[{self.workspace}] Error upserting node {node_id}: {e}")
  1880. async def upsert_edge(
  1881. self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  1882. ) -> None:
  1883. """Insert or update an edge with deterministic ID for bidirectional handling."""
  1884. try:
  1885. await self._ensure_indices_ready()
  1886. # Ensure source node exists (don't overwrite if it already has data)
  1887. if not await self.has_node(source_node_id):
  1888. await self.upsert_node(source_node_id, {})
  1889. doc = {k: v for k, v in edge_data.items() if k != "_id"}
  1890. doc["source_node_id"] = source_node_id
  1891. doc["target_node_id"] = target_node_id
  1892. if edge_data.get("source_id", ""):
  1893. doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
  1894. # Use a deterministic ID for the edge so upserts work
  1895. edge_id = compute_mdhash_id(
  1896. f"{source_node_id}-{target_node_id}", prefix="edge-"
  1897. )
  1898. # Check if reverse edge exists
  1899. reverse_id = compute_mdhash_id(
  1900. f"{target_node_id}-{source_node_id}", prefix="edge-"
  1901. )
  1902. try:
  1903. if await self.client.exists(index=self._edges_index, id=reverse_id):
  1904. edge_id = reverse_id
  1905. except OpenSearchException:
  1906. pass
  1907. await self.client.index(index=self._edges_index, id=edge_id, body=doc)
  1908. self._edges_dirty = True
  1909. except OpenSearchException as e:
  1910. logger.error(
  1911. f"[{self.workspace}] Error upserting edge {source_node_id}->{target_node_id}: {e}"
  1912. )
  1913. async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
  1914. """Batch insert/update multiple nodes using the OpenSearch bulk API.
  1915. Args:
  1916. nodes: List of (node_id, node_data) tuples.
  1917. """
  1918. if not nodes:
  1919. return
  1920. try:
  1921. await self._ensure_indices_ready()
  1922. actions = []
  1923. for node_id, node_data in nodes:
  1924. doc = {k: v for k, v in node_data.items() if k != "_id"}
  1925. doc["entity_id"] = node_id
  1926. if node_data.get("source_id", ""):
  1927. doc["source_ids"] = node_data["source_id"].split(GRAPH_FIELD_SEP)
  1928. actions.append(
  1929. {
  1930. "_op_type": "index",
  1931. "_index": self._nodes_index,
  1932. "_id": node_id,
  1933. "_source": doc,
  1934. }
  1935. )
  1936. await helpers.async_bulk(self.client, actions)
  1937. self._nodes_dirty = True
  1938. except OpenSearchException as e:
  1939. logger.error(f"[{self.workspace}] Error during batch node upsert: {e}")
  1940. async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
  1941. """Check existence of multiple nodes using a single mget request.
  1942. Args:
  1943. node_ids: List of node IDs to check.
  1944. Returns:
  1945. Set of node_ids that exist in the graph.
  1946. """
  1947. if not node_ids:
  1948. return set()
  1949. if not self._indices_ready:
  1950. return set()
  1951. try:
  1952. response = await self.client.mget(
  1953. index=self._nodes_index, body={"ids": node_ids}
  1954. )
  1955. return {doc["_id"] for doc in response.get("docs", []) if doc.get("found")}
  1956. except OpenSearchException as e:
  1957. if _is_missing_index_error(e):
  1958. self._mark_indices_missing()
  1959. return set()
  1960. async def upsert_edges_batch(
  1961. self, edges: list[tuple[str, str, dict[str, str]]]
  1962. ) -> None:
  1963. """Batch insert/update multiple edges using the OpenSearch bulk API.
  1964. Replicates the bidirectional edge-ID logic of upsert_edge(): a canonical
  1965. forward ID is used unless a reverse-direction document already exists, in
  1966. which case the reverse ID is used so the update lands on the existing doc.
  1967. The reverse-ID look-up is done in a single mget call before the bulk write.
  1968. Args:
  1969. edges: List of (source_node_id, target_node_id, edge_data) tuples.
  1970. """
  1971. if not edges:
  1972. return
  1973. try:
  1974. await self._ensure_indices_ready()
  1975. # Ensure all source nodes exist (mirrors upsert_edge behaviour)
  1976. source_ids = list({src for src, _tgt, _data in edges})
  1977. existing_sources = await self.has_nodes_batch(source_ids)
  1978. missing_sources = [
  1979. (nid, {}) for nid in source_ids if nid not in existing_sources
  1980. ]
  1981. if missing_sources:
  1982. await self.upsert_nodes_batch(missing_sources)
  1983. # Compute forward and reverse edge IDs, then batch-check which
  1984. # reverse-direction docs already exist (one mget instead of N exists).
  1985. forward_ids = [
  1986. compute_mdhash_id(f"{src}-{tgt}", prefix="edge-")
  1987. for src, tgt, _ in edges
  1988. ]
  1989. reverse_ids = [
  1990. compute_mdhash_id(f"{tgt}-{src}", prefix="edge-")
  1991. for src, tgt, _ in edges
  1992. ]
  1993. try:
  1994. rev_response = await self.client.mget(
  1995. index=self._edges_index, body={"ids": reverse_ids}
  1996. )
  1997. existing_reverse = {
  1998. doc["_id"]
  1999. for doc in rev_response.get("docs", [])
  2000. if doc.get("found")
  2001. }
  2002. except OpenSearchException:
  2003. existing_reverse = set()
  2004. actions = []
  2005. reserved_edge_ids = set(existing_reverse)
  2006. for (src, tgt, edge_data), fwd_id, rev_id in zip(
  2007. edges, forward_ids, reverse_ids
  2008. ):
  2009. edge_id = rev_id if rev_id in reserved_edge_ids else fwd_id
  2010. reserved_edge_ids.add(edge_id)
  2011. doc = {k: v for k, v in edge_data.items() if k != "_id"}
  2012. doc["source_node_id"] = src
  2013. doc["target_node_id"] = tgt
  2014. if edge_data.get("source_id", ""):
  2015. doc["source_ids"] = edge_data["source_id"].split(GRAPH_FIELD_SEP)
  2016. actions.append(
  2017. {
  2018. "_op_type": "index",
  2019. "_index": self._edges_index,
  2020. "_id": edge_id,
  2021. "_source": doc,
  2022. }
  2023. )
  2024. await helpers.async_bulk(self.client, actions)
  2025. self._edges_dirty = True
  2026. except OpenSearchException as e:
  2027. logger.error(f"[{self.workspace}] Error during batch edge upsert: {e}")
  2028. # --- Delete operations ---
  async def delete_node(self, node_id: str) -> None:
      """Remove a single node together with every edge attached to it.

      Edges are removed first with a delete-by-query matching the node at
      either endpoint; ``conflicts="proceed"`` tolerates matches that were
      already deleted. The node document is then deleted by ID, and a node
      that is already absent is not treated as a failure. Both dirty flags
      are raised afterwards so the refresh happens lazily on the next read
      that needs it.

      Args:
          node_id: ID of the node to delete.

      OpenSearch errors are logged rather than raised.
      """
      # An edge touches the node when the node is its source or its target.
      incident_edges = {
          "query": {
              "bool": {
                  "should": [
                      {"term": {endpoint: node_id}}
                      for endpoint in ("source_node_id", "target_node_id")
                  ]
              }
          }
      }
      try:
          # delete_by_query only sees refreshed documents, so flush pending
          # edge writes into the search view first.
          await self._refresh_graph_indices_if_dirty(refresh_edges=True)
          await self.client.delete_by_query(
              index=self._edges_index,
              body=incident_edges,
              params={"conflicts": "proceed"},
          )
          try:
              await self.client.delete(index=self._nodes_index, id=node_id)
          except NotFoundError:
              # The node document is already gone; nothing left to remove.
              pass
          self._nodes_dirty = self._edges_dirty = True
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error deleting node {node_id}: {e}")
  async def remove_nodes(self, nodes: list[str]) -> None:
      """Delete several nodes and all edges attached to any of them.

      Edges naming one of the nodes as source or target are removed with a
      single delete-by-query (``conflicts="proceed"`` tolerates matches that
      were already deleted); the node documents are then removed with one
      bulk request that does not raise on per-item errors. Both dirty flags
      are raised so the refresh happens lazily on the next read.

      Args:
          nodes: IDs of the nodes to delete. An empty list is a no-op.

      OpenSearch errors are logged rather than raised.
      """
      if not nodes:
          return
      logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes")

      # Matches every edge that has one of the nodes at either endpoint.
      incident_edges = {
          "query": {
              "bool": {
                  "should": [
                      {"terms": {endpoint: nodes}}
                      for endpoint in ("source_node_id", "target_node_id")
                  ]
              }
          }
      }
      try:
          # delete_by_query only sees refreshed documents, so flush pending
          # edge writes into the search view first.
          await self._refresh_graph_indices_if_dirty(refresh_edges=True)
          await self.client.delete_by_query(
              index=self._edges_index,
              body=incident_edges,
              params={"conflicts": "proceed"},
          )
          node_deletes = [
              {"_op_type": "delete", "_index": self._nodes_index, "_id": node_id}
              for node_id in nodes
          ]
          await helpers.async_bulk(self.client, node_deletes, raise_on_error=False)
          self._nodes_dirty = self._edges_dirty = True
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error removing nodes: {e}")
  async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
      """Delete several edges by their deterministic document IDs.

      An edge between ``src`` and ``tgt`` lives under exactly one of two IDs:

          forward = compute_mdhash_id("src-tgt", prefix="edge-")
          reverse = compute_mdhash_id("tgt-src", prefix="edge-")

      Since the stored direction is not known here, a delete is issued for
      both candidates of every requested edge. The edge index is flagged
      dirty afterwards so the refresh happens lazily on the next read.

      Args:
          edges: ``(source, target)`` pairs to delete. Empty input is a no-op.

      OpenSearch errors are logged rather than raised.
      """
      if not edges:
          return
      logger.info(f"[{self.workspace}] Deleting {len(edges)} edges")
      try:
          # Two delete operations per edge: forward key first, then reverse.
          operations = [
              {
                  "delete": {
                      "_index": self._edges_index,
                      "_id": compute_mdhash_id(pair_key, prefix="edge-"),
                  }
              }
              for src, tgt in edges
              for pair_key in (f"{src}-{tgt}", f"{tgt}-{src}")
          ]
          await self.client.bulk(body=operations)
          self._edges_dirty = True
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error removing edges: {e}")
  2133. # --- Query operations ---
  async def get_all_labels(self) -> list[str]:
      """Return every node ID (entity name) in alphabetical order.

      The node index is walked page by page through a point-in-time (PIT)
      context with ``search_after``; the PIT is released on a best-effort
      basis once paging ends.

      Returns:
          Sorted list of node IDs. Empty when the indices are not ready or
          an OpenSearch error occurs; a missing-index error additionally
          marks the indices as missing.
      """
      if not self._indices_ready:
          return []
      page_size = 10000
      try:
          await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
          collected = []
          pit_id = (
              await self.client.create_pit(
                  index=self._nodes_index, params={"keep_alive": "1m"}
              )
          )["pit_id"]
          try:
              cursor = None
              while True:
                  request = {
                      "query": {"match_all": {}},
                      "_source": False,
                      "size": page_size,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                      "sort": _pit_sort_with_field("entity_id"),
                  }
                  if cursor:
                      request["search_after"] = cursor
                  page = (await self.client.search(body=request))["hits"]["hits"]
                  if not page:
                      break
                  collected.extend(hit["_id"] for hit in page)
                  cursor = page[-1]["sort"]
                  # A short page means the index has been exhausted.
                  if len(page) < page_size:
                      break
          finally:
              # Releasing the PIT is best-effort; a failure here is ignored.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
          return sorted(collected)
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          return []
  async def _collect_node_ids(
      self, limit: int, exclude_ids: set[str] | None = None
  ) -> list[str]:
      """Gather at most ``limit`` node IDs, leaving out any in ``exclude_ids``.

      With nothing to exclude and a limit of at most 10000, one plain search
      is enough. Otherwise the node index is paged through a point-in-time
      (PIT) context until enough IDs are gathered or the index runs out.

      Args:
          limit: Maximum number of IDs to return; non-positive yields ``[]``.
          exclude_ids: IDs that must not appear in the result.

      Returns:
          Up to ``limit`` node IDs.
      """
      if limit <= 0:
          return []
      page_size = 10000
      skip = exclude_ids or set()

      # Fast path: one request can satisfy the whole call.
      if limit <= page_size and not skip:
          resp = await self.client.search(
              index=self._nodes_index,
              body={"query": {"match_all": {}}, "_source": False, "size": limit},
          )
          return [hit["_id"] for hit in resp["hits"]["hits"]]

      collected: list[str] = []
      pit_id = (
          await self.client.create_pit(
              index=self._nodes_index, params={"keep_alive": "1m"}
          )
      )["pit_id"]
      try:
          cursor = None
          while len(collected) < limit:
              request = {
                  "query": {"match_all": {}},
                  "_source": False,
                  "size": page_size,
                  "pit": {"id": pit_id, "keep_alive": "1m"},
                  "sort": _pit_sort_with_field("entity_id"),
              }
              if cursor:
                  request["search_after"] = cursor
              page = (await self.client.search(body=request))["hits"]["hits"]
              if not page:
                  break
              # Keep only as many unseen IDs as the remaining budget allows.
              fresh = [hit["_id"] for hit in page if hit["_id"] not in skip]
              collected.extend(fresh[: limit - len(collected)])
              cursor = page[-1].get("sort")
              if len(page) < page_size:
                  break
      finally:
          # Releasing the PIT is best-effort; a failure here is ignored.
          try:
              await self.client.delete_pit(body={"pit_id": [pit_id]})
          except Exception:
              pass
      return collected
  @staticmethod
  def _edge_rank_key(edge: dict[str, Any]) -> tuple[int, float]:
      """Build a sort key that puts shallow edges first, then heavy ones.

      Depth is read from ``_depth`` (falling back to ``depth``, then 0) and
      weight from ``weight`` (default 0). A value that cannot be converted
      counts as 0. The weight is negated so that an ascending sort yields
      higher weights first within the same depth.

      Args:
          edge: Edge row; only its depth and weight entries are read.

      Returns:
          ``(depth, -weight)``.
      """

      def _coerce(raw, caster, fallback):
          # Fall back when the raw value is of the wrong type or unparsable.
          try:
              return caster(raw)
          except (TypeError, ValueError):
              return fallback

      depth_value = _coerce(edge.get("_depth", edge.get("depth", 0)), int, 0)
      weight_value = _coerce(edge.get("weight", 0), float, 0.0)
      return (depth_value, -weight_value)
  async def _append_edges_between_nodes(
      self, node_ids: list[str], result: KnowledgeGraph
  ) -> None:
      """Add to ``result`` every edge that starts and ends inside ``node_ids``.

      Matching edges are paged through a point-in-time (PIT) context. Each
      edge is keyed by ``"<source>-<target>"`` and appended once per key.

      Args:
          node_ids: IDs defining the induced subgraph; empty is a no-op.
          result: Graph whose ``edges`` list is extended in place.
      """
      if not node_ids:
          return
      page_size = 10000
      # Both endpoints must belong to the node set.
      both_endpoints_inside = {
          "bool": {
              "must": [
                  {"terms": {endpoint: node_ids}}
                  for endpoint in ("source_node_id", "target_node_id")
              ]
          }
      }
      emitted = set()
      pit_id = (
          await self.client.create_pit(
              index=self._edges_index, params={"keep_alive": "1m"}
          )
      )["pit_id"]
      try:
          cursor = None
          while True:
              request = {
                  "query": both_endpoints_inside,
                  "size": page_size,
                  "pit": {"id": pit_id, "keep_alive": "1m"},
                  "sort": _pit_sort_with_composite_key(
                      "source_node_id", "target_node_id"
                  ),
              }
              if cursor:
                  request["search_after"] = cursor
              page = (await self.client.search(body=request))["hits"]["hits"]
              if not page:
                  break
              for hit in page:
                  edge_doc = hit["_source"]
                  edge_key = (
                      f"{edge_doc['source_node_id']}-{edge_doc['target_node_id']}"
                  )
                  if edge_key in emitted:
                      continue
                  emitted.add(edge_key)
                  result.edges.append(
                      self._construct_graph_edge(edge_key, edge_doc)
                  )
              cursor = page[-1].get("sort")
              if len(page) < page_size:
                  break
      finally:
          # Releasing the PIT is best-effort; a failure here is ignored.
          try:
              await self.client.delete_pit(body={"pit_id": [pit_id]})
          except Exception:
              pass
  def _construct_graph_node(self, node_id, node_data: dict) -> KnowledgeGraphNode:
      """Wrap a stored node document as a ``KnowledgeGraphNode``.

      The node ID doubles as its only label. Storage bookkeeping fields are
      left out of the exposed properties.

      Args:
          node_id: ID of the node.
          node_data: Stored document for the node.
      """
      internal_fields = {
          "_id",
          "entity_id",
          "source_ids",
          "connected_edges",
          "edge_count",
      }
      properties = {}
      for key, value in node_data.items():
          if key in internal_fields:
              continue
          properties[key] = value
      return KnowledgeGraphNode(id=node_id, labels=[node_id], properties=properties)
  def _construct_graph_edge(self, edge_id: str, edge: dict) -> KnowledgeGraphEdge:
      """Wrap a stored edge document as a ``KnowledgeGraphEdge``.

      The edge type comes from ``relationship`` (empty string when absent);
      endpoints come from ``source_node_id`` / ``target_node_id``, which must
      be present. Those fields, plus ``_id`` and ``source_ids``, are left out
      of the exposed properties.

      Args:
          edge_id: Identifier to give the edge.
          edge: Stored document for the edge.
      """
      internal_fields = {
          "_id",
          "source_node_id",
          "target_node_id",
          "relationship",
          "source_ids",
      }
      properties = {}
      for key, value in edge.items():
          if key in internal_fields:
              continue
          properties[key] = value
      return KnowledgeGraphEdge(
          id=edge_id,
          type=edge.get("relationship", ""),
          source=edge["source_node_id"],
          target=edge["target_node_id"],
          properties=properties,
      )
  async def get_knowledge_graph(
      self,
      node_label: str,
      max_depth: int = 3,
      max_nodes: int = None,
  ) -> KnowledgeGraph:
      """Fetch a subgraph around ``node_label``.

      ``"*"`` selects the whole-graph view. Any other label is expanded
      breadth-first, server-side through PPL graphlookup when that is
      available and client-side otherwise.

      Args:
          node_label: Start node ID, or ``"*"`` for the whole graph.
          max_depth: Maximum number of hops from the start node.
          max_nodes: Node budget; capped by the ``max_graph_nodes`` setting
              (default 1000), which is also used when this is ``None``.

      Returns:
          The subgraph. Empty when the indices are not ready or the query
          fails; a missing-index error also marks the indices as missing.
      """
      if not self._indices_ready:
          return KnowledgeGraph()

      configured_cap = self.global_config.get("max_graph_nodes", 1000)
      max_nodes = configured_cap if max_nodes is None else min(max_nodes, configured_cap)

      result = KnowledgeGraph()
      start = time.perf_counter()
      try:
          await self._refresh_graph_indices_if_dirty(
              refresh_nodes=True, refresh_edges=True
          )
          if node_label == "*":
              result = await self._get_knowledge_graph_all(max_nodes)
          elif not self._ppl_graphlookup_available:
              result = await self._bfs_subgraph(node_label, max_depth, max_nodes)
          else:
              result = await self._bfs_subgraph_ppl(node_label, max_depth, max_nodes)
          duration = time.perf_counter() - start
          logger.info(
              f"[{self.workspace}] Subgraph query in {duration:.4f}s | "
              f"Nodes: {len(result.nodes)} | Edges: {len(result.edges)} | Truncated: {result.is_truncated}"
          )
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
              return KnowledgeGraph()
          logger.error(f"[{self.workspace}] Graph query failed: {e}")
      return result
  async def _get_knowledge_graph_all(self, max_nodes: int) -> KnowledgeGraph:
      """Build the whole-graph view, limited to ``max_nodes`` nodes.

      When the node index holds more than ``max_nodes`` documents the result
      is flagged truncated and nodes are picked by degree, estimated from
      terms aggregations over edge sources and targets; remaining slots are
      filled with other node IDs. Otherwise all node IDs are taken. Edges
      are then added between the nodes that were actually found.

      Args:
          max_nodes: Maximum number of nodes to include.

      Returns:
          The assembled graph; possibly partial or empty on OpenSearch errors.
      """
      result = KnowledgeGraph()
      if not self._indices_ready:
          return result
      try:
          node_total = (await self.client.count(index=self._nodes_index))["count"]
          result.is_truncated = node_total > max_nodes

          if not result.is_truncated:
              top_ids = await self._collect_node_ids(max_nodes)
          else:
              # Rank nodes by how often they appear as an edge endpoint.
              degree_request = {
                  "size": 0,
                  "aggs": {
                      agg_name: {"terms": {"field": endpoint, "size": max_nodes}}
                      for agg_name, endpoint in (
                          ("src", "source_node_id"),
                          ("tgt", "target_node_id"),
                      )
                  },
              }
              resp = await self.client.search(
                  index=self._edges_index, body=degree_request
              )
              degree_map = {}
              for agg_name in ("src", "tgt"):
                  for bucket in resp["aggregations"][agg_name]["buckets"]:
                      node_key = bucket["key"]
                      degree_map[node_key] = (
                          degree_map.get(node_key, 0) + bucket["doc_count"]
                      )
              ranked = sorted(degree_map, key=degree_map.get, reverse=True)
              top_ids = ranked[:max_nodes]
              # Too few connected nodes: top up with IDs not yet chosen.
              shortfall = max_nodes - len(top_ids)
              if shortfall > 0:
                  top_ids.extend(
                      await self._collect_node_ids(
                          shortfall, exclude_ids=set(top_ids)
                      )
                  )

          if top_ids:
              node_resp = await self.client.mget(
                  index=self._nodes_index, body={"ids": top_ids}
              )
              found_node_ids = []
              for doc in node_resp["docs"]:
                  if not doc.get("found"):
                      continue
                  found_node_ids.append(doc["_id"])
                  result.nodes.append(
                      self._construct_graph_node(doc["_id"], doc["_source"])
                  )
              await self._append_edges_between_nodes(found_node_ids, result)
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
              return result
          logger.error(f"[{self.workspace}] Error in get_knowledge_graph_all: {e}")
      return result
  async def _bfs_subgraph_ppl(
      self, start_label: str, max_depth: int, max_nodes: int
  ) -> KnowledgeGraph:
      """Server-side BFS using the PPL graphLookup command.

      The start node is selected from the nodes index and graphLookup walks
      the edges index in both directions. Returned edge rows are ranked by
      depth and weight to decide which nodes fit into ``max_nodes``. Node
      documents are then fetched in one batch, and edges are added between
      the nodes that were actually found.

      Falls back to the client-side BFS when the PPL request fails or its
      response cannot be interpreted.

      Args:
          start_label: ID of the node to start from.
          max_depth: Maximum number of hops; 0 returns only the start node.
          max_nodes: Maximum number of nodes in the result.

      Returns:
          The subgraph; empty when the start node does not exist.
      """
      result = KnowledgeGraph()

      # Verify start node exists
      start_node = await self.get_node(start_label)
      if not start_node:
          return result
      result.nodes.append(self._construct_graph_node(start_label, start_node))
      if max_depth == 0:
          return result

      # PPL maxDepth=0 means 1 hop (direct match), so max_depth-1
      ppl_depth = max(0, max_depth - 1)
      escaped = self._escape_ppl(start_label)
      ppl_query = (
          f"source = {self._nodes_index}"
          f" | where entity_id = '{escaped}'"
          f" | graphLookup {self._edges_index}"
          f" start=entity_id"
          f" edge=target_node_id<->source_node_id"
          f" maxDepth={ppl_depth}"
          f" depthField=_depth"
          f" usePIT=true"
          f" as connected_edges"
      )
      try:
          resp = await self.client.transport.perform_request(
              "POST",
              "/_plugins/_ppl",
              body={"query": ppl_query},
          )
      except Exception as e:
          logger.warning(
              f"[{self.workspace}] PPL graphlookup failed, falling back to client BFS: {e}"
          )
          return await self._bfs_subgraph(start_label, max_depth, max_nodes)

      # Parse PPL response — schema-driven to avoid fragile positional access
      try:
          datarows = resp.get("datarows", [])
          schema = [col["name"] for col in resp.get("schema", [])]
          ce_idx = (
              schema.index("connected_edges") if "connected_edges" in schema else -1
          )
          # Collect all edge rows from connected_edges arrays
          all_edge_rows = []
          for row in datarows:
              edges_arr = row[ce_idx] if ce_idx >= 0 else []
              if isinstance(edges_arr, list):
                  all_edge_rows.extend(edges_arr)
          if not all_edge_rows:
              return result
          if isinstance(all_edge_rows[0], dict):
              sorted_edge_rows = sorted(all_edge_rows, key=self._edge_rank_key)
          else:
              # Positional array — column positions are unknown, fall back to client BFS
              logger.warning(
                  f"[{self.workspace}] PPL returned positional arrays, falling back to client BFS"
              )
              return await self._bfs_subgraph(start_label, max_depth, max_nodes)
      except (AttributeError, KeyError, IndexError, TypeError, ValueError) as e:
          # AttributeError covers payloads whose rows or top level are not
          # dict-shaped (e.g. a non-dict row reaching the rank key's .get).
          logger.warning(
              f"[{self.workspace}] Error parsing PPL response, falling back: {e}"
          )
          return await self._bfs_subgraph(start_label, max_depth, max_nodes)

      # Walk edges best-first; keep discovering past the budget so that
      # truncation can be reported accurately.
      ordered_node_ids = [start_label]
      discovered_nodes = {start_label}
      for edge_row in sorted_edge_rows:
          for node_id in (
              edge_row.get("source_node_id"),
              edge_row.get("target_node_id"),
          ):
              if not node_id or node_id in discovered_nodes:
                  continue
              discovered_nodes.add(node_id)
              if len(ordered_node_ids) < max_nodes:
                  ordered_node_ids.append(node_id)
      result.is_truncated = len(discovered_nodes) > max_nodes

      # Batch fetch node data (start node already added). Track which IDs
      # really have a node document: an edge may reference an endpoint that
      # was never stored, and such edges must not appear in the result
      # without their node.
      present_node_ids = [start_label]
      new_node_ids = [nid for nid in ordered_node_ids if nid != start_label]
      if new_node_ids:
          node_resp = await self.client.mget(
              index=self._nodes_index, body={"ids": new_node_ids}
          )
          for doc in node_resp["docs"]:
              if doc.get("found"):
                  present_node_ids.append(doc["_id"])
                  result.nodes.append(
                      self._construct_graph_node(doc["_id"], doc["_source"])
                  )

      await self._append_edges_between_nodes(present_node_ids, result)
      return result
  @staticmethod
  def _escape_ppl(value: str) -> str:
      """Make a string safe to embed in a single-quoted PPL literal.

      Backslashes and single quotes get a leading backslash; newline,
      carriage return and tab are each replaced by one space so they cannot
      break the literal.

      Args:
          value: Raw text to embed.

      Returns:
          The escaped text.
      """
      # Each character is mapped independently in a single pass, so the
      # backslashes added for quotes are never escaped a second time.
      substitutions = str.maketrans(
          {
              "\\": "\\\\",
              "'": "\\'",
              "\n": " ",
              "\r": " ",
              "\t": " ",
          }
      )
      return value.translate(substitutions)
  @staticmethod
  def _escape_wildcard(value: str) -> str:
      """Neutralise wildcard metacharacters in user-supplied text.

      A backslash is put in front of every ``\\``, ``*`` and ``?`` so that an
      OpenSearch wildcard query treats them literally instead of as pattern
      operators, preventing DoS via expensive patterns.

      Args:
          value: Raw user input.

      Returns:
          The input with wildcard metacharacters escaped.
      """
      # Single-pass mapping: added backslashes are never escaped again.
      substitutions = str.maketrans({"\\": "\\\\", "*": "\\*", "?": "\\?"})
      return value.translate(substitutions)
  async def _bfs_subgraph(
      self, start_label: str, max_depth: int, max_nodes: int
  ) -> KnowledgeGraph:
      """Client-side breadth-first expansion from ``start_label``.

      Each level issues one edge search covering all nodes of the current
      frontier and one batched node fetch for the newly reached IDs.
      Expansion ends after ``max_depth`` levels, when the frontier is empty,
      when the node budget is used up, or when an edge search fails. Edges
      between all collected nodes are attached at the end.

      Args:
          start_label: ID of the node to start from.
          max_depth: Maximum number of levels to expand.
          max_nodes: Maximum number of nodes to collect.

      Returns:
          The subgraph; empty when the start node does not exist.
          ``is_truncated`` is set once the node budget has been reached.
      """
      graph = KnowledgeGraph()
      visited = set()

      root = await self.get_node(start_label)
      if not root:
          return graph
      visited.add(start_label)
      graph.nodes.append(self._construct_graph_node(start_label, root))

      frontier = [start_label]
      for _ in range(max_depth):
          if len(visited) >= max_nodes or not frontier:
              break

          # One search returns every edge touching the current frontier.
          edge_request = {
              "query": {
                  "bool": {
                      "should": [
                          {"terms": {endpoint: frontier}}
                          for endpoint in ("source_node_id", "target_node_id")
                      ]
                  }
              },
              "_source": ["source_node_id", "target_node_id"],
              "size": 10000,
          }
          try:
              resp = await self.client.search(
                  index=self._edges_index, body=edge_request
              )
          except OpenSearchException:
              break

          reached = set()
          for hit in resp["hits"]["hits"]:
              endpoints = hit["_source"]
              src = endpoints["source_node_id"]
              tgt = endpoints["target_node_id"]
              reached.update(nid for nid in (src, tgt) if nid not in visited)

          # Only as many candidates as the remaining budget are fetched.
          budget = max_nodes - len(visited)
          candidates = list(reached)[:budget]
          if candidates:
              node_resp = await self.client.mget(
                  index=self._nodes_index, body={"ids": candidates}
              )
              for doc in node_resp["docs"]:
                  if not doc.get("found"):
                      continue
                  visited.add(doc["_id"])
                  graph.nodes.append(
                      self._construct_graph_node(doc["_id"], doc["_source"])
                  )
          frontier = candidates

      # Attach the edges running between the collected nodes.
      collected_ids = list(visited)
      if collected_ids:
          try:
              await self._append_edges_between_nodes(collected_ids, graph)
          except OpenSearchException:
              pass
      graph.is_truncated = len(visited) >= max_nodes
      return graph
  async def get_all_nodes(self) -> list[dict]:
      """Return every node document, each with its ID under ``"id"``.

      The node index is walked page by page through a point-in-time (PIT)
      context with ``search_after``; the PIT is released on a best-effort
      basis once paging ends.

      Returns:
          One dict per node. Empty when the indices are not ready or an
          OpenSearch error occurs; a missing-index error additionally marks
          the indices as missing.
      """
      if not self._indices_ready:
          return []
      page_size = 10000
      try:
          await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
          collected = []
          pit_id = (
              await self.client.create_pit(
                  index=self._nodes_index, params={"keep_alive": "1m"}
              )
          )["pit_id"]
          try:
              cursor = None
              while True:
                  request = {
                      "query": {"match_all": {}},
                      "size": page_size,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                      "sort": _pit_sort_with_field("entity_id"),
                  }
                  if cursor:
                      request["search_after"] = cursor
                  page = (await self.client.search(body=request))["hits"]["hits"]
                  if not page:
                      break
                  collected.extend(
                      {**hit["_source"], "id": hit["_id"]} for hit in page
                  )
                  cursor = page[-1]["sort"]
                  # A short page means the index has been exhausted.
                  if len(page) < page_size:
                      break
          finally:
              # Releasing the PIT is best-effort; a failure here is ignored.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
          return collected
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          return []
  async def get_all_edges(self) -> list[dict]:
      """Return every edge document with ``source`` / ``target`` aliases.

      The edge index is walked page by page through a point-in-time (PIT)
      context with ``search_after``. Each document gains ``source`` and
      ``target`` keys copied from ``source_node_id`` and ``target_node_id``.

      Returns:
          One dict per edge. Empty when the indices are not ready or an
          OpenSearch error occurs; a missing-index error additionally marks
          the indices as missing.
      """
      if not self._indices_ready:
          return []
      page_size = 10000
      try:
          await self._refresh_graph_indices_if_dirty(refresh_edges=True)
          collected = []
          pit_id = (
              await self.client.create_pit(
                  index=self._edges_index, params={"keep_alive": "1m"}
              )
          )["pit_id"]
          try:
              cursor = None
              while True:
                  request = {
                      "query": {"match_all": {}},
                      "size": page_size,
                      "pit": {"id": pit_id, "keep_alive": "1m"},
                      "sort": _pit_sort_with_composite_key(
                          "source_node_id", "target_node_id"
                      ),
                  }
                  if cursor:
                      request["search_after"] = cursor
                  page = (await self.client.search(body=request))["hits"]["hits"]
                  if not page:
                      break
                  for hit in page:
                      record = hit["_source"]
                      record.update(
                          source=record.get("source_node_id"),
                          target=record.get("target_node_id"),
                      )
                      collected.append(record)
                  cursor = page[-1]["sort"]
                  # A short page means the index has been exhausted.
                  if len(page) < page_size:
                      break
          finally:
              # Releasing the PIT is best-effort; a failure here is ignored.
              try:
                  await self.client.delete_pit(body={"pit_id": [pit_id]})
              except Exception:
                  pass
          return collected
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          return []
  async def get_popular_labels(self, limit: int = 300) -> list[str]:
      """Return node labels ordered from most to least connected.

      Degree is estimated by adding, per node, the bucket counts of two
      terms aggregations over the edge index: one on the source field and
      one on the target field, each sized ``limit * 2``.

      Args:
          limit: Maximum number of labels to return.

      Returns:
          Up to ``limit`` labels, highest degree first. Empty when the
          indices are not ready or an OpenSearch error occurs; a
          missing-index error additionally marks the indices as missing.
      """
      if not self._indices_ready:
          return []
      try:
          await self._refresh_graph_indices_if_dirty(refresh_edges=True)
          bucket_count = limit * 2
          request = {
              "size": 0,
              "aggs": {
                  "src": {"terms": {"field": "source_node_id", "size": bucket_count}},
                  "tgt": {"terms": {"field": "target_node_id", "size": bucket_count}},
              },
          }
          response = await self.client.search(index=self._edges_index, body=request)

          # Sum outgoing and incoming counts per node.
          degree_map = {}
          for agg_name in ("src", "tgt"):
              for bucket in response["aggregations"][agg_name]["buckets"]:
                  label = bucket["key"]
                  degree_map[label] = degree_map.get(label, 0) + bucket["doc_count"]

          ranked = sorted(degree_map, key=degree_map.get, reverse=True)
          return ranked[:limit]
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          return []
  async def search_labels(self, query: str, limit: int = 50) -> list[str]:
      """Search node labels by exact, prefix and substring match.

      Three clauses are combined, strongest first: an exact term match on
      the query as typed (boost 10), a prefix match (boost 5) and a
      substring wildcard match (boost 2). The prefix and wildcard clauses
      are both case-insensitive, so a query such as ``"app"`` gives the
      prefix boost to a label stored as ``"Apple"`` and not only to labels
      that are stored in lower case.

      Args:
          query: Text to look for; surrounding whitespace is ignored.
          limit: Maximum number of labels to return.

      Returns:
          Matching node IDs in relevance order. Empty for a blank query,
          when the indices are not ready, or when an OpenSearch error
          occurs; a missing-index error additionally marks the indices as
          missing.
      """
      query = query.strip()
      if not query:
          return []
      if not self._indices_ready:
          return []
      try:
          await self._refresh_graph_indices_if_dirty(refresh_nodes=True)
          lowered = query.lower()
          body = {
              "query": {
                  "bool": {
                      "should": [
                          {"term": {"entity_id": {"value": query, "boost": 10}}},
                          {
                              "prefix": {
                                  "entity_id": {
                                      "value": lowered,
                                      # Without this flag a lower-cased
                                      # prefix can miss labels stored with
                                      # capitals.
                                      "case_insensitive": True,
                                      "boost": 5,
                                  }
                              }
                          },
                          {
                              "wildcard": {
                                  "entity_id": {
                                      # Escape user-supplied * and ? so they
                                      # are matched literally.
                                      "value": f"*{self._escape_wildcard(lowered)}*",
                                      "case_insensitive": True,
                                      "boost": 2,
                                  }
                              }
                          },
                      ]
                  }
              },
              "_source": False,
              "size": limit,
          }
          response = await self.client.search(index=self._nodes_index, body=body)
          return [hit["_id"] for hit in response["hits"]["hits"]]
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          return []
  async def index_done_callback(self) -> None:
      """Refresh the node and edge indices if they have pending writes.

      The refresh stays best-effort: no exception leaves this method. A
      missing-index error marks the indices as missing; every other failure
      is logged as a warning so that a refresh which did not happen is
      visible in the logs instead of being dropped silently.
      """
      if not self._indices_ready:
          return
      try:
          await self._refresh_graph_indices_if_dirty(
              refresh_nodes=True, refresh_edges=True
          )
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_indices_missing()
          else:
              logger.warning(
                  f"[{self.workspace}] Graph index refresh failed: {e}"
              )
      except Exception as e:
          # Still swallowed on purpose, but no longer invisible.
          logger.warning(
              f"[{self.workspace}] Unexpected error during graph index refresh: {e}"
          )
  async def drop(self) -> dict[str, str]:
      """Delete the node index and the edge index.

      Both deletions are attempted even if the first one fails. An index
      that does not exist is not an error. Afterwards the indices are marked
      as missing regardless of the outcome.

      Returns:
          ``{"status": "success", ...}`` when no deletion failed, otherwise
          ``{"status": "error", ...}`` with the collected failure messages.
      """
      failures = []
      for index_name in (self._nodes_index, self._edges_index):
          try:
              await self.client.indices.delete(index=index_name)
              logger.info(f"[{self.workspace}] Dropped graph index: {index_name}")
          except NotFoundError:
              logger.info(
                  f"[{self.workspace}] Graph index already missing during drop: {index_name}"
              )
          except OpenSearchException as e:
              failures.append(f"{index_name}: {e}")
              logger.error(
                  f"[{self.workspace}] Error dropping graph index {index_name}: {e}"
              )
          except Exception as e:
              failures.append(f"{index_name}: {e}")
              logger.error(
                  f"[{self.workspace}] Unexpected error dropping graph index {index_name}: {e}"
              )
      self._mark_indices_missing()

      if not failures:
          try:
              logger.info(f"[{self.workspace}] Dropped graph indices")
              return {"status": "success", "message": "Graph indices dropped"}
          except Exception as e:
              logger.error(f"[{self.workspace}] Error finalizing graph drop: {e}")
              return {"status": "error", "message": str(e)}
      return {
          "status": "error",
          "message": "Failed to drop graph indices: " + "; ".join(failures),
      }
  2818. @final
  2819. @dataclass
  2820. class OpenSearchVectorDBStorage(BaseVectorStorage):
  2821. """Vector storage using OpenSearch k-NN plugin with corrected cosine score handling."""
  2822. client: AsyncOpenSearch = field(default=None)
  2823. _index_name: str = field(default="", init=False)
  2824. _index_ready: bool = field(default=False, init=False)
  def __init__(
      self, namespace, global_config, embedding_func, workspace=None, meta_fields=None
  ):
      """Set up the base storage fields, then run ``__post_init__``.

      Args:
          namespace: Storage namespace.
          global_config: Global configuration mapping.
          embedding_func: Embedding function used by this storage.
          workspace: Workspace name; a falsy value becomes ``""``.
          meta_fields: Metadata field names; a falsy value becomes an empty
              set.
      """
      base_fields = {
          "namespace": namespace,
          "workspace": workspace or "",
          "global_config": global_config,
          "embedding_func": embedding_func,
          "meta_fields": meta_fields or set(),
      }
      super().__init__(**base_fields)
      # Invoked explicitly because this hand-written constructor is used.
      self.__post_init__()
  def __post_init__(self):
      """Derive index naming, read settings and create the write buffers.

      Raises:
          ValueError: If ``cosine_better_than_threshold`` is absent from
              ``vector_db_storage_cls_kwargs`` in the global config.
      """
      self._validate_embedding_func()

      resolved_names = _build_index_name(self.workspace, self.namespace)
      self.workspace, self.final_namespace, self._index_name = resolved_names

      storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
      threshold = storage_kwargs.get("cosine_better_than_threshold")
      if threshold is None:
          raise ValueError(
              "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
          )
      self.cosine_better_than_threshold = threshold
      self._max_batch_size = self.global_config["embedding_batch_num"]

      # Writes are buffered here and sent by _flush_pending_vector_ops()
      # from index_done_callback() / finalize(), so that many small upsert()
      # calls end up in a single async_bulk roundtrip. See issue #2785.
      self._pending_vector_docs: dict[str, _PendingVectorDoc] = {}
      self._pending_vector_deletes: set[str] = set()

      # The namespace-keyed (multi-process safe) lock is created in
      # initialize(). Every buffer access and every destructive server-side
      # mutation (delete_by_query, drop, finalize) is serialised through it,
      # which keeps in-process readers race-free during a flush and orders
      # cross-worker flushes against the same OpenSearch index.
      self._flush_lock = None
  async def initialize(self):
      """Obtain the client, make sure the k-NN index exists, set up the lock.

      Client acquisition and index creation run under the shared data-init
      lock. The namespace-keyed flush lock is created once, on first call.
      """
      async with get_data_init_lock():
          if self.client is None:
              self.client = await ClientManager.get_client()
          await self._create_knn_index_if_not_exists()
          self._index_ready = True
          logger.debug(
              f"[{self.workspace}] OpenSearch Vector storage initialized: {self._index_name}"
          )

      # Created lazily so repeated initialize() calls keep the same lock.
      if self._flush_lock is None:
          self._flush_lock = get_namespace_lock(
              self.namespace, workspace=self.workspace
          )
  async def _ensure_index_ready(self):
      """Bring the vector index back before a write if it is flagged missing.

      Does nothing while the index is marked ready. Otherwise the client is
      obtained if needed and the index is (re)created under the shared
      data-init lock; the ready flag is checked again inside the lock
      because it may have been set while waiting for it.
      """
      if not self._index_ready:
          async with get_data_init_lock():
              if self.client is None:
                  self.client = await ClientManager.get_client()
              if not self._index_ready:
                  await self._create_knn_index_if_not_exists()
                  self._index_ready = True
  def _mark_index_missing(self):
      """Flag the vector index as not ready.

      Clears ``_index_ready``, the flag that ``_ensure_index_ready`` consults
      to decide whether the index has to be created again.
      """
      self._index_ready = False
  async def _create_knn_index_if_not_exists(self):
      """Create the k-NN index when absent; otherwise validate its vector dimension.

      Raises:
          ValueError: The existing index declares a ``vector`` dimension that
              differs from ``self.embedding_func.embedding_dim``.
          RequestError: Index creation failed for a reason other than the
              index already existing.
          OpenSearchException: Any other OpenSearch failure (logged, then
              re-raised).
      """
      try:
          index_exists = await self.client.indices.exists(index=self._index_name)
          if index_exists:
              # Existing index: only check that its dimension matches.
              try:
                  mapping = await self.client.indices.get_mapping(
                      index=self._index_name
                  )
                  properties = mapping[self._index_name]["mappings"]["properties"]
                  existing_dim = properties.get("vector", {}).get("dimension")
                  expected_dim = self.embedding_func.embedding_dim
                  dims_conflict = (
                      existing_dim is not None and existing_dim != expected_dim
                  )
                  if dims_conflict:
                      raise ValueError(
                          f"Vector dimension mismatch! Index '{self._index_name}' has "
                          f"dimension {existing_dim}, but current embedding model expects "
                          f"dimension {expected_dim}. Please drop the existing index or "
                          f"use an embedding model with matching dimensions."
                      )
              except (KeyError, TypeError):
                  # Mapping did not have the expected shape; validation is skipped.
                  logger.warning(
                      f"[{self.workspace}] Could not read vector mapping for index "
                      f"'{self._index_name}'; skipping dimension validation"
                  )
              return

          # HNSW tuning knobs, read from the environment with defaults.
          ef_construction = int(
              _get_opensearch_env("OPENSEARCH_KNN_EF_CONSTRUCTION", "200")
          )
          m = int(_get_opensearch_env("OPENSEARCH_KNN_M", "16"))
          ef_search = int(_get_opensearch_env("OPENSEARCH_KNN_EF_SEARCH", "100"))

          index_settings = {
              "knn": True,
              "knn.algo_param.ef_search": ef_search,
              "number_of_shards": _get_index_number_of_shards(),
              "number_of_replicas": _get_index_number_of_replicas(),
          }
          vector_field = {
              "type": "knn_vector",
              "dimension": self.embedding_func.embedding_dim,
              "method": {
                  "name": "hnsw",
                  "space_type": "cosinesimil",
                  "engine": "lucene",
                  "parameters": {
                      "ef_construction": ef_construction,
                      "m": m,
                  },
              },
          }
          field_mappings = {
              "vector": vector_field,
              "content": {"type": "text"},
              "entity_name": {"type": "keyword"},
              "src_id": {"type": "keyword"},
              "tgt_id": {"type": "keyword"},
              "file_path": {"type": "keyword"},
              "created_at": {"type": "long"},
          }
          body = {
              "settings": {"index": index_settings},
              "mappings": {"properties": field_mappings, "dynamic": True},
          }

          await self.client.indices.create(index=self._index_name, body=body)
          logger.info(
              f"[{self.workspace}] Created k-NN index: {self._index_name} "
              f"(dim={self.embedding_func.embedding_dim})"
          )
      except RequestError as e:
          # An "already exists" response is tolerated; anything else is fatal.
          if "resource_already_exists_exception" in str(e):
              return
          logger.error(f"[{self.workspace}] Error creating k-NN index: {e}")
          raise
      except OpenSearchException as e:
          logger.error(f"[{self.workspace}] Error creating k-NN index: {e}")
          raise
  async def finalize(self):
      """Run a final flush, release the client, and report any writes left behind.

      A flush failure of type ``Exception`` is captured and re-raised after
      cleanup as a ``RuntimeError`` (chained to the original error) that names
      the buffered upsert / delete counts. Exceptions outside ``Exception``
      (e.g. ``asyncio.CancelledError``) are not caught here and propagate
      after the ``finally`` clause has released the client.

      Raises:
          RuntimeError: The flush raised, or the buffers are still non-empty
              after the flush attempt.
      """
      flush_error: Exception | None = None
      try:
          await self._flush_pending_vector_ops()
      except Exception as exc:
          flush_error = exc
      finally:
          # Release the client no matter how the flush ended.
          if self.client is not None:
              await ClientManager.release_client(self.client)
              self.client = None

      pending_docs = len(self._pending_vector_docs)
      pending_deletes = len(self._pending_vector_deletes)

      if flush_error is not None:
          raise RuntimeError(
              f"[{self.workspace}] OpenSearchVectorDBStorage.finalize() "
              f"flush raised; {pending_docs} pending upserts and "
              f"{pending_deletes} pending deletes were left buffered "
              f"(client released, data lost)"
          ) from flush_error

      if not (pending_docs or pending_deletes):
          return
      raise RuntimeError(
          f"[{self.workspace}] OpenSearchVectorDBStorage.finalize() "
          f"left {pending_docs} pending upserts and {pending_deletes} "
          f"pending deletes buffered after final flush attempt "
          f"(transient bulk failure); these writes have been lost"
      )
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
      """Stage vector documents in the pending buffer for a later flush.

      Nothing is embedded or sent to OpenSearch here: each entry is stored as
      a ``_PendingVectorDoc`` holding its ``content`` and a source dict made
      of ``created_at`` plus the entry's ``meta_fields``. Staging an id also
      cancels a pending delete for that id.

      Args:
          data: Mapping of document id to its fields; every entry must
              contain a ``"content"`` key.
      """
      if not data:
          return

      await self._ensure_index_ready()
      logger.debug(
          f"[{self.workspace}] Buffering {len(data)} vectors for {self.namespace}"
      )

      current_time = int(time.time())
      staged: list[tuple[str, _PendingVectorDoc]] = []
      for position, (doc_id, fields) in enumerate(data.items(), start=1):
          content = fields["content"]
          source = {"created_at": current_time}
          # Metadata is merged last, so a "created_at" meta field would win.
          source.update(
              {
                  name: value
                  for name, value in fields.items()
                  if name in self.meta_fields
              }
          )
          staged.append((doc_id, _PendingVectorDoc(source=source, content=content)))
          await _cooperative_yield(position)

      # Buffer mutation happens under the flush lock; an upsert overrides a
      # pending delete on the same id.
      async with self._flush_lock:
          for doc_id, pending_doc in staged:
              self._pending_vector_deletes.discard(doc_id)
              self._pending_vector_docs[doc_id] = pending_doc
  async def _flush_pending_vector_ops(self) -> None:
      """Embed and bulk-write every buffered upsert and delete in one pass.

      The whole method body runs under ``_flush_lock``. Steps:

      1. Return early when both buffers are empty or no client is attached.
      2. Ensure the index exists (``_ensure_index_ready``).
      3. Embed every buffered doc whose ``vector`` is still ``None``, in
         batches of ``_max_batch_size``.
      4. Issue one ``helpers.async_bulk`` call: deletes first, then index ops.
      5. Drop entries that are not reported as retryable by
         ``_extract_bulk_failed_ids``; retryable ones stay buffered.

      Failure handling visible in this method:

      * An error from ``_ensure_index_ready``, from embedding, or an
        ``OpenSearchException`` from ``async_bulk`` propagates before any
        buffer entry is removed, so the buffers keep their contents.
      * An embedding count mismatch raises ``RuntimeError`` instead of letting
        ``zip`` silently truncate.
      * Non-retryable per-doc failures are removed from the buffer and a
        sample of up to five is logged.
      """
      async with self._flush_lock:
          nothing_buffered = not (
              self._pending_vector_docs or self._pending_vector_deletes
          )
          if nothing_buffered or self.client is None:
              return

          # Recreate the index first if it was marked missing; a failure here
          # leaves the buffers untouched and reaches the caller.
          await self._ensure_index_ready()

          pending_docs = self._pending_vector_docs
          pending_deletes = self._pending_vector_deletes

          # --- deferred embedding -----------------------------------------
          docs_to_embed = [
              entry for entry in pending_docs.items() if entry[1].vector is None
          ]
          if docs_to_embed:
              contents = [pending_doc.content for _, pending_doc in docs_to_embed]
              batches = [
                  contents[start : start + self._max_batch_size]
                  for start in range(0, len(contents), self._max_batch_size)
              ]
              # TEMP diagnostic (remove later): confirms deferred batching is
              # coalescing per-id upserts (docs >> batches when it works).
              logger.info(
                  f"[{self.workspace}] {self.namespace} flush: embedding "
                  f"{len(docs_to_embed)} vectors in {len(batches)} batch(es) "
                  f"(batch_num={self._max_batch_size})"
              )
              try:
                  embeddings_list = await asyncio.gather(
                      *(
                          self.embedding_func(batch, context="document")
                          for batch in batches
                      )
                  )
              except Exception as e:
                  logger.error(
                      f"[{self.workspace}] Error embedding pending vector ops "
                      f"(upserts={len(docs_to_embed)}): {e}"
                  )
                  raise

              embeddings = np.concatenate(embeddings_list)
              # Explicit check rather than assert: a mismatch must not be
              # silently truncated by zip() (asserts vanish under -O).
              if len(embeddings) != len(docs_to_embed):
                  raise RuntimeError(
                      f"[{self.workspace}] Embedding count mismatch: expected "
                      f"{len(docs_to_embed)}, got {len(embeddings)}"
                  )
              for position, ((_, pending_doc), embedding) in enumerate(
                  zip(docs_to_embed, embeddings), start=1
              ):
                  pending_doc.vector = embedding.tolist()
                  await _cooperative_yield(position)

          # --- build bulk actions: deletes first, then index ops ------------
          actions: list[dict[str, Any]] = [
              {
                  "_op_type": "delete",
                  "_index": self._index_name,
                  "_id": doc_id,
              }
              for doc_id in pending_deletes
          ]
          committed_doc_ids: set[str] = set()
          for doc_id, pending_doc in pending_docs.items():
              if pending_doc.vector is not None:
                  committed_doc_ids.add(doc_id)
                  actions.append(
                      {
                          "_op_type": "index",
                          "_index": self._index_name,
                          "_id": doc_id,
                          "_source": {
                              **pending_doc.source,
                              "vector": pending_doc.vector,
                          },
                      }
                  )

          if not actions:
              return

          try:
              # No per-operation refresh is requested here.
              _, failed = await helpers.async_bulk(
                  self.client, actions, raise_on_error=False
              )
          except OpenSearchException as e:
              logger.error(
                  f"[{self.workspace}] Error flushing vector ops "
                  f"(upserts={len(pending_docs)}, "
                  f"deletes={len(pending_deletes)}): {e}"
              )
              # No per-doc statuses were returned, so nothing is cleared.
              raise

          retryable_ids, non_retryable_ops = _extract_bulk_failed_ids(failed)

          # --- reconcile buffers: keep only retryable entries ---------------
          for doc_id in committed_doc_ids:
              if doc_id in retryable_ids:
                  continue
              pending_docs.pop(doc_id, None)

          deletes_to_retry: set[str] = {
              doc_id for doc_id in pending_deletes if doc_id in retryable_ids
          }
          pending_deletes.clear()
          pending_deletes.update(deletes_to_retry)

          if retryable_ids:
              logger.warning(
                  f"[{self.workspace}] {len(retryable_ids)} vector ops will "
                  f"retry on the next flush (transient failure)"
              )
          if non_retryable_ops:
              sample_text = ", ".join(
                  f"{op.op}/{op.doc_id}/status={op.status}/{op.error}"
                  for op in non_retryable_ops[:5]
              )
              logger.warning(
                  f"[{self.workspace}] {len(non_retryable_ops)} vector ops "
                  f"failed permanently and were dropped (non-retryable status). "
                  f"Sample: {sample_text}"
              )
  3209. async def query(
  3210. self, query: str, top_k: int, query_embedding: list[float] = None
  3211. ) -> list[dict[str, Any]]:
  3212. """k-NN similarity search with cosine score conversion for lucene engine."""
  3213. if not self._index_ready:
  3214. return []
  3215. if query_embedding is not None:
  3216. query_vector = (
  3217. query_embedding.tolist()
  3218. if hasattr(query_embedding, "tolist")
  3219. else list(query_embedding)
  3220. )
  3221. else:
  3222. embedding = await self.embedding_func([query], context="query", _priority=5)
  3223. query_vector = embedding[0].tolist()
  3224. search_body = {
  3225. "size": top_k,
  3226. "query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}},
  3227. "_source": {"excludes": ["vector"]},
  3228. }
  3229. try:
  3230. response = await self.client.search(
  3231. index=self._index_name, body=search_body
  3232. )
  3233. results = []
  3234. for hit in response["hits"]["hits"]:
  3235. # OpenSearch k-NN with lucene engine and cosinesimil space type
  3236. # returns scores that can be used directly as similarity measure.
  3237. score = hit["_score"]
  3238. if score >= self.cosine_better_than_threshold:
  3239. doc = hit["_source"]
  3240. doc["id"] = hit["_id"]
  3241. doc["distance"] = score
  3242. results.append(doc)
  3243. logger.info(
  3244. f"[{self.workspace}] Vector query on {self._index_name}: "
  3245. f"top_k={top_k}, threshold={self.cosine_better_than_threshold}, "
  3246. f"total_hits={len(response['hits']['hits'])}, "
  3247. f"passed_filter={len(results)}, "
  3248. f"score_range=[{min((h['_score'] for h in response['hits']['hits']), default=0):.4f}, "
  3249. f"{max((h['_score'] for h in response['hits']['hits']), default=0):.4f}]"
  3250. )
  3251. return results
  3252. except OpenSearchException as e:
  3253. if _is_missing_index_error(e):
  3254. self._mark_index_missing()
  3255. return []
  3256. logger.error(f"[{self.workspace}] Error querying vectors: {e}")
  3257. return []
  async def index_done_callback(self) -> None:
      """Flush pending vector ops, then refresh the index.

      The flush runs first; the refresh is skipped when the index is still
      not marked ready afterwards. The refresh is best-effort: a failure
      never raises out of this method. A missing-index error marks the index
      missing; any other failure is logged as a warning instead of being
      silently discarded.

      Raises:
          Whatever ``_flush_pending_vector_ops`` raises; flush errors are not
          caught here.
      """
      await self._flush_pending_vector_ops()
      if not self._index_ready:
          return

      try:
          await self.client.indices.refresh(index=self._index_name)
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return
          logger.warning(
              f"[{self.workspace}] Error refreshing vector index "
              f"{self._index_name}: {e}"
          )
      except Exception as e:
          # Still best-effort, but leave a trace for operators.
          logger.warning(
              f"[{self.workspace}] Unexpected error refreshing vector index "
              f"{self._index_name}: {e}"
          )
  3284. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  3285. """Get a vector document by ID, with read-your-writes against the buffer.
  3286. The ``vector`` field is stripped from the result to match every other
  3287. LightRAG vector backend (see ``NanoVectorDBStorage.get_by_id``).
  3288. Callers that need the embedding itself must use ``get_vectors_by_ids``.
  3289. """
  3290. # Buffer lookups happen under the namespace lock so an in-flight
  3291. # flush is observed as either "completely before" or "completely
  3292. # after" -- never as a snapshot-swapped intermediate state.
  3293. async with self._flush_lock:
  3294. if id in self._pending_vector_deletes:
  3295. return None
  3296. pending = self._pending_vector_docs.get(id)
  3297. if pending is not None:
  3298. # pending.source is built in upsert from created_at + meta_fields
  3299. # and never carries the embedding, so no "vector" strip is needed
  3300. # here (unlike the mget path below, which excludes it server-side).
  3301. doc = dict(pending.source)
  3302. doc["id"] = id
  3303. return doc
  3304. if not self._index_ready:
  3305. return None
  3306. # Network IO outside the lock so mget RTT doesn't block flush.
  3307. try:
  3308. response = await _mget_optional_doc(
  3309. self.client,
  3310. self._index_name,
  3311. id,
  3312. source_excludes=["vector"],
  3313. )
  3314. if response is None:
  3315. return None
  3316. doc = response["_source"]
  3317. doc.pop("vector", None) # defensive in case _source_excludes is ignored
  3318. doc["id"] = response["_id"]
  3319. return doc
  3320. except OpenSearchException as e:
  3321. if _is_missing_index_error(e):
  3322. self._mark_index_missing()
  3323. return None
  3324. logger.error(f"[{self.workspace}] Error getting vector {id}: {e}")
  3325. return None
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
      """Fetch several vector documents by id, in the order the ids were given.

      Ids with a pending delete resolve to ``None``; ids with a pending
      upsert are answered from the buffer; the rest are fetched with one
      ``mget`` (``vector`` excluded) when the index is ready.

      Returns:
          One entry per requested id: the document with ``id`` added, or
          ``None`` when it is unavailable.
      """
      if not ids:
          return []

      from_buffer: dict[str, dict[str, Any] | None] = {}
      to_fetch: list[str] = []
      async with self._flush_lock:
          for doc_id in ids:
              if doc_id in self._pending_vector_deletes:
                  from_buffer[doc_id] = None
                  continue
              pending = self._pending_vector_docs.get(doc_id)
              if pending is None:
                  to_fetch.append(doc_id)
              else:
                  # Copy of the buffered source plus the id.
                  from_buffer[doc_id] = {**pending.source, "id": doc_id}
          index_ready = self._index_ready

      from_server: dict[str, dict[str, Any] | None] = {}
      if to_fetch and index_ready:
          try:
              response = await self.client.mget(
                  index=self._index_name,
                  body={"ids": to_fetch},
                  _source_excludes=["vector"],
              )
              for entry in response["docs"]:
                  if not entry.get("found"):
                      continue
                  data = entry["_source"]
                  data.pop("vector", None)
                  data["id"] = entry["_id"]
                  from_server[entry["_id"]] = data
          except OpenSearchException as e:
              if _is_missing_index_error(e):
                  self._mark_index_missing()
              else:
                  logger.error(
                      f"[{self.workspace}] Error getting vectors by ids: {e}"
                  )

      # A buffered answer (including an explicit None) wins over the server's.
      ordered: list[dict[str, Any] | None] = []
      for doc_id in ids:
          if doc_id in from_buffer:
              ordered.append(from_buffer[doc_id])
          else:
              ordered.append(from_server.get(doc_id))
      return ordered
  async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
      """Return the embeddings for the given ids, consulting the pending buffer first.

      Ids with a pending delete are omitted. Buffered upserts are answered
      from the buffer; those not yet embedded are embedded here and the
      vector is stored back on the buffered doc. Remaining ids are fetched
      with one ``mget`` when the index is ready.

      Returns:
          Mapping of id to vector for every id that could be resolved.

      Raises:
          RuntimeError: The embedding function returned a different number
              of vectors than documents.
          Exception: Any error raised by the embedding function (logged,
              then re-raised).
      """
      if not ids:
          return {}

      result: dict[str, list[float]] = {}
      remaining: list[str] = []
      async with self._flush_lock:
          docs_to_embed: list[tuple[str, _PendingVectorDoc]] = []
          for doc_id in ids:
              if doc_id in self._pending_vector_deletes:
                  continue
              pending = self._pending_vector_docs.get(doc_id)
              if pending is None:
                  remaining.append(doc_id)
              elif pending.vector is None:
                  docs_to_embed.append((doc_id, pending))
              else:
                  result[doc_id] = pending.vector
          index_ready = self._index_ready

          if docs_to_embed:
              contents = [pending_doc.content for _, pending_doc in docs_to_embed]
              batches = [
                  contents[start : start + self._max_batch_size]
                  for start in range(0, len(contents), self._max_batch_size)
              ]
              try:
                  embeddings_list = await asyncio.gather(
                      *(
                          self.embedding_func(batch, context="document")
                          for batch in batches
                      )
                  )
              except Exception as e:
                  logger.error(
                      f"[{self.workspace}] Error lazily embedding pending vectors "
                      f"(upserts={len(docs_to_embed)}): {e}"
                  )
                  raise

              embeddings = np.concatenate(embeddings_list)
              # Explicit check rather than assert so zip() cannot silently
              # truncate on a count mismatch.
              if len(embeddings) != len(docs_to_embed):
                  raise RuntimeError(
                      f"[{self.workspace}] Embedding count mismatch: expected "
                      f"{len(docs_to_embed)}, got {len(embeddings)}"
                  )
              for position, ((doc_id, pending_doc), embedding) in enumerate(
                  zip(docs_to_embed, embeddings), start=1
              ):
                  # Store the vector on the buffered doc and in the result.
                  pending_doc.vector = embedding.tolist()
                  result[doc_id] = pending_doc.vector
                  await _cooperative_yield(position)

      # Nothing left for the server, or no index to ask.
      if not remaining or not index_ready:
          return result

      try:
          response = await self.client.mget(
              index=self._index_name,
              body={"ids": remaining},
              _source_includes=["vector"],
          )
          for entry in response["docs"]:
              if entry.get("found") and "vector" in entry.get("_source", {}):
                  result[entry["_id"]] = entry["_source"]["vector"]
          return result
      except OpenSearchException as e:
          if _is_missing_index_error(e):
              self._mark_index_missing()
              return result
          logger.error(f"[{self.workspace}] Error getting vectors: {e}")
          return result
  async def delete(self, ids: list[str]) -> None:
      """Record deletes in the pending buffer for a later flush.

      Each id is removed from the pending upserts and added to the pending
      deletes; nothing is sent to OpenSearch by this method.

      Args:
          ids: Document ids to delete; a ``set`` is accepted and converted
              to a list.
      """
      if not ids:
          return

      ids = list(ids) if isinstance(ids, set) else ids

      async with self._flush_lock:
          upserts = self._pending_vector_docs
          deletes = self._pending_vector_deletes
          for doc_id in ids:
              # A delete cancels any buffered upsert for the same id.
              upserts.pop(doc_id, None)
              deletes.add(doc_id)

      logger.debug(
          f"[{self.workspace}] Buffered delete for {len(ids)} vectors in {self.namespace}"
      )
  async def delete_entity(self, entity_name: str) -> None:
      """Record a pending delete for the entity's vector.

      The document id is derived from ``entity_name`` via
      ``compute_mdhash_id`` with the ``"ent-"`` prefix.
      """
      target_id = compute_mdhash_id(entity_name, prefix="ent-")

      async with self._flush_lock:
          # Cancel any buffered upsert, then record the delete.
          self._pending_vector_docs.pop(target_id, None)
          self._pending_vector_deletes.add(target_id)

      logger.debug(f"[{self.workspace}] Buffered delete for entity {entity_name}")
  async def delete_entity_relation(self, entity_name: str) -> None:
      """Remove every relation vector whose ``src_id`` or ``tgt_id`` is the entity.

      Runs entirely under ``_flush_lock``. Persisted rows are removed with a
      server-side ``delete_by_query``; matching buffered upserts are pruned
      from the pending buffer.

      Pruning order matters: the buffer is pruned only after the server-side
      delete succeeded, or when there is no server state to touch (index not
      ready, or the server reports the index missing). On any other
      ``OpenSearchException`` the error is logged and re-raised with the
      buffer left intact.

      Raises:
          OpenSearchException: ``delete_by_query`` failed for a reason other
              than a missing index.
      """

      def _discard_buffered_relations() -> None:
          # Collect first, then pop: the dict must not change while iterated.
          doomed = []
          for doc_id, pending_doc in self._pending_vector_docs.items():
              source = pending_doc.source
              if (
                  source.get("src_id") == entity_name
                  or source.get("tgt_id") == entity_name
              ):
                  doomed.append(doc_id)
          for doc_id in doomed:
              self._pending_vector_docs.pop(doc_id, None)

      async with self._flush_lock:
          if not self._index_ready:
              # No server state to mutate; pruning the buffer is all that
              # can be done.
              _discard_buffered_relations()
              return

          endpoint_clauses = [
              {"term": {"src_id": entity_name}},
              {"term": {"tgt_id": entity_name}},
          ]
          body = {"query": {"bool": {"should": endpoint_clauses}}}

          try:
              # conflicts="proceed" tolerates stale search view after refresh removal.
              await self.client.delete_by_query(
                  index=self._index_name, body=body, params={"conflicts": "proceed"}
              )
          except OpenSearchException as e:
              if not _is_missing_index_error(e):
                  logger.error(
                      f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
                  )
                  raise
              # A missing index is treated as "everything already deleted".
              self._mark_index_missing()
              _discard_buffered_relations()
              return

          # Server-side delete succeeded: prune so a later flush does not
          # re-upsert the deleted relations.
          _discard_buffered_relations()
          logger.debug(
              f"[{self.workspace}] Deleted relations for entity {entity_name}"
          )
  async def drop(self) -> dict[str, str]:
      """Delete and recreate the vector index, discarding both pending buffers.

      Runs entirely under ``_flush_lock``. The buffers are cleared before
      the index is touched. An index that is already missing is not an
      error. Any failure marks the index missing and is reported in the
      returned dict rather than raised.

      Returns:
          ``{"status": "success", "message": ...}`` when the index was
          recreated, otherwise ``{"status": "error", "message": <error text>}``.
      """
      async with self._flush_lock:
          # Pending writes are meaningless once the index is dropped.
          for buffer in (self._pending_vector_docs, self._pending_vector_deletes):
              buffer.clear()

          try:
              try:
                  await self.client.indices.delete(index=self._index_name)
              except NotFoundError:
                  logger.info(
                      f"[{self.workspace}] Vector index already missing during drop: {self._index_name}"
                  )
              else:
                  logger.info(
                      f"[{self.workspace}] Dropped vector index: {self._index_name}"
                  )

              # Rebuild an empty index so the storage stays usable.
              await self._create_knn_index_if_not_exists()
              self._index_ready = True
              logger.info(
                  f"[{self.workspace}] Dropped and recreated vector index: {self._index_name}"
              )
              outcome = {
                  "status": "success",
                  "message": f"Vector index {self._index_name} dropped and recreated",
              }
          except OpenSearchException as e:
              self._mark_index_missing()
              logger.error(f"[{self.workspace}] Error dropping vector index: {e}")
              outcome = {"status": "error", "message": str(e)}
          except Exception as e:
              self._mark_index_missing()
              logger.error(
                  f"[{self.workspace}] Unexpected error dropping vector index: {e}"
              )
              outcome = {"status": "error", "message": str(e)}
          return outcome