postgres_impl.py 326 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764277427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358435943604361436243634364436543664367436843694370437143724373437443754376437743784379438043814382438343844385438643874388438943904391439243934394439543964397439843994400440144024403440444054406440744084409441044114412441344144415441644174418441944204421442244234424442544264427442844294430443144324433443444354436443744384439444044414442444344444445444644474448444944504451445244534454445544564457445844594460446144624463446444654466446744684469447044714472447344744475447644774478447944804481448244834484448544864487448844894490449144924493449444954496449744984499450045014502450345044505450645074508450945104511451245134514451545164517451845194520452145224523452445254526452745284529453045314532453345344535453645374538453945404541454245434544454545464547454845494550455145524553455445554556455745584559456045614562456345644565456645674568456945704571457245734574457545764577457845794580458145824583458445854586458745884589459045914592459345944595459645974598459946004601460246034604460546064607460846094610461146124613461446154616461746184619462046214622462346244625462646274628462946304631463246334634463546364637463846394640464146424643464446454646464746484649465046514652465346544655465646574658465946604661466246634664466546664667466846694670467146724673467446754676467746784679468046814682468346844685468646874688468946904691469246934694469546964697469846994700470147024703470447054706470747084709471047114712471347144715471647174718471947204721472247234724472547264727472847294730473147324733473447354736473747384739474047414742474347444745474647474748474947504751475247534754475547564757475847594760476147624763476447654766476747684769477047714772477347744775477647774778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765777577857795780578157825783578457855786578757885789579057915792579357945795579657975798579958005801580258035804580558065807580858095810581158125813581458155816581758185819582058215822582358245825582658275828582958305831583258335834583558365837583858395840584158425843584458455846584758485849585058515852585358545855585658575858585958605861586258635864586558665867586858695870587158725873587458755876587758785879588058815882588358845885588658875888588958905891589258935894589558965897589858995900590159025903590459055906590759085909591059115912591359145915591659175918591959205921592259235924592559265927592859295930593159325933593459355936593759385939594059415942594359445945594659475948594959505951595259535954595559565957595859595960596159625963596459655966596759685969597059715972597359745975597659775978597959805981598259835984598559865987598859895990599159925993599459955996599759985999600060016002600360046005600660076008600960106011601260136014601560166017601860196020602160226023602460256026602760286029603060316032603360346035603660376038603960406041604260436044604560466047604860496050605160526053605460556056605760586059606060616062606360646065606660676068606960706071607260736074607560766077607860796080608160826083608460856086608760886089609060916092609360946095609660976098609961006101610261036104610561066107610861096110611161126113611461156116611761186119612061216122612361246125612661276128612961306131613261336134613561366137613861396140614161426143614461456146614761486149615061516152615361546155615661576158615961606161616261636164616561666167616861696170617161726173617461756176617761786179618061816182618361846185618661876188618961906191619261936194619561966197619861996200620162026203620462056206620762086209621062116212621362146215621662176218621962206221622262236224622562266227622862296230623162326233623462356236623762386239624062416242624362446245624662476248624962506251625262536254625562566257625862596260626162626263626462656266626762686269627062716272627362746275627662776278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677767786779678067816782678367846785678667876788678967906791679267936794679567966797679867996800680168026803680468056806680768086809681068116812681368146815681668176818681968206821682268236824682568266827682868296830683168326833683468356836683768386839684068416842684368446845684668476848684968506851685268536854685568566857685868596860686168626863686468656866686768686869687068716872687368746875687668776878687968806881688268836884688568866887688868896890689168926893689468956896689768986899690069016902690369046905690669076908690969106911691269136914691569166917691869196920692169226923692469256926692769286929693069316932693369346935693669376938693969406941694269436944694569466947694869496950695169526953695469556956695769586959696069616962696369646965696669676968696969706971697269736974697569766977697869796980698169826983698469856986698769886989699069916992699369946995699669976998699970007001700270037004700570067007700870097010701170127013701470157016701770187019702070217022702370247025702670277028702970307031703270337034703570367037703870397040704170427043704470457046704770487049705070517052705370547055705670577058705970607061706270637064706570667067706870697070707170727073707470757076707770787079708070817082708370847085708670877088708970907091709270937094709570967097709870997100710171027103710471057106710771087109711071117112711371147115711671177118711971207121712271237124712571267127712871297130713171327133713471357136713771387139714071417142714371447145714671477148714971507151715271537154715571567157715871597160716171627163716471657166716771687169717071717172717371747175717671777178717971807181718271837184718571867187718871897190719171927193719471957196719771987199720072017202720372047205720672077208720972107211721272137214721572167217721872197220722172227223722472257226722772287229723072317232723372347235723672377238723972407241724272437244724572467247724872497250725172527253725472557256725772587259726072617262726372647265726672677268726972707271727272737274727572767277727872797280728172827283728472857286728772887289729072917292729372947295729672977298729973007301730273037304730573067307730873097310731173127313731473157316731773187319732073217322732373247325732673277328732973307331733273337334733573367337733873397340734173427343734473457346734773487349735073517352735373547355735673577358735973607361736273637364736573667367736873697370737173727373737473757376737773787379738073817382738373847385738673877388738973907391739273937394739573967397739873997400740174027403740474057406740774087409741074117412741374147415741674177418741974207421742274237424742574267427742874297430743174327433743474357436743774387439744074417442744374447445744674477448744974507451745274537454745574567457745874597460746174627463746474657466746774687469747074717472747374747475747674777478747974807481748274837484748574867487748874897490749174927493749474957496749774987499750075017502750375047505750675077508750975107511751275137514751575167517751875197520752175227523752475257526752775287529753075317532753375347535753675377538753975407541754275437544754575467547754875497550755175527553755475557556755775587559756075617562756375647565756675677568756975707571757275737574757575767577757875797580758175827583758475857586758775887589759075917592759375947595759675977598759976007601760276037604760576067607760876097610761176127613761476157616761776187619762076217622762376247625762676277628762976307631763276337634763576367637763876397640764176427643764476457646764776487649765076517652765376547655765676577658765976607661766276637664766576667667766876697670767176727673767476757676767776787679768076817682768376847685768676877688768976907691
  1. import asyncio
  2. import time
  3. import hashlib
  4. import json
  5. import os
  6. import re
  7. import datetime
  8. from datetime import timezone
  9. from dataclasses import dataclass, field
  10. from typing import Any, Awaitable, Callable, TypeVar, Union, final
  11. import numpy as np
  12. import configparser
  13. import ssl
  14. import itertools
  15. from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  16. from tenacity import (
  17. AsyncRetrying,
  18. RetryCallState,
  19. retry,
  20. retry_if_exception,
  21. retry_if_exception_type,
  22. stop_after_attempt,
  23. wait_exponential,
  24. wait_fixed,
  25. )
  26. from ..base import (
  27. BaseGraphStorage,
  28. BaseKVStorage,
  29. BaseVectorStorage,
  30. DocProcessingStatus,
  31. DocStatus,
  32. DocStatusStorage,
  33. )
  34. from ..exceptions import DataMigrationError
  35. from ..namespace import NameSpace, is_namespace
  36. from ..utils import (
  37. logger,
  38. compute_mdhash_id,
  39. _cooperative_yield,
  40. performance_timing_log,
  41. )
  42. from ..kg.shared_storage import get_data_init_lock, get_namespace_lock
  43. import pipmaster as pm
  44. if not pm.is_installed("asyncpg"):
  45. pm.install("asyncpg")
  46. if not pm.is_installed("pgvector"):
  47. pm.install("pgvector")
  48. import asyncpg # type: ignore
  49. from asyncpg import Pool # type: ignore
  50. from pgvector.asyncpg import register_vector # type: ignore
  51. from dotenv import load_dotenv
  52. # use the .env that is inside the current folder
  53. # allows to use different .env file for each lightrag instance
  54. # the OS environment variables take precedence over the .env file
  55. load_dotenv(dotenv_path=".env", override=False)
  56. T = TypeVar("T")
  57. # PostgreSQL identifier length limit (in bytes)
  58. PG_MAX_IDENTIFIER_LENGTH = 63
  59. # All known vector index suffixes, used to drop conflicting indexes when switching types
  60. _VECTOR_INDEX_SUFFIXES = [
  61. "hnsw_cosine",
  62. "hnsw_halfvec_cosine",
  63. "ivfflat_cosine",
  64. "vchordrq_cosine",
  65. ]
  66. def _safe_index_name(table_name: str, index_suffix: str) -> str:
  67. """
  68. Generate a PostgreSQL-safe index name that won't be truncated.
  69. PostgreSQL silently truncates identifiers to 63 bytes. This function
  70. ensures index names stay within that limit by hashing long table names.
  71. Args:
  72. table_name: The table name (may be long with model suffix)
  73. index_suffix: The index type suffix (e.g., 'hnsw_cosine', 'id', 'workspace_id')
  74. Returns:
  75. A deterministic index name that fits within 63 bytes
  76. """
  77. # Construct the full index name
  78. full_name = f"idx_{table_name.lower()}_{index_suffix}"
  79. # If it fits within the limit, use it as-is
  80. if len(full_name.encode("utf-8")) <= PG_MAX_IDENTIFIER_LENGTH:
  81. return full_name
  82. # Otherwise, hash the table name to create a shorter unique identifier
  83. # Keep 'idx_' prefix and suffix readable, hash the middle
  84. hash_input = table_name.lower().encode("utf-8")
  85. table_hash = hashlib.md5(hash_input).hexdigest()[:12] # 12 hex chars
  86. # Format: idx_{hash}_{suffix} - guaranteed to fit
  87. # Maximum: idx_ (4) + hash (12) + _ (1) + suffix (variable) = 17 + suffix
  88. shortened_name = f"idx_{table_hash}_{index_suffix}"
  89. return shortened_name
  90. def _timing_details_suffix(**details: Any) -> str:
  91. parts = [f"{key}={value}" for key, value in details.items()]
  92. return f" {' '.join(parts)}" if parts else ""
  93. def _dollar_quote(s: str, tag_prefix: str = "AGE") -> str:
  94. """
  95. Generate a PostgreSQL dollar-quoted string with a unique tag.
  96. PostgreSQL dollar-quoting uses $tag$ as delimiters. If the content contains
  97. the same delimiter (e.g., $$ or $AGE1$), it will break the query.
  98. This function finds a unique tag that doesn't conflict with the content.
  99. Args:
  100. s: The string to quote
  101. tag_prefix: Prefix for generating unique tags (default: "AGE")
  102. Returns:
  103. The dollar-quoted string with a unique tag, e.g., $AGE1$content$AGE1$
  104. Example:
  105. >>> _dollar_quote("hello")
  106. '$AGE1$hello$AGE1$'
  107. >>> _dollar_quote("$AGE1$ test")
  108. '$AGE2$$AGE1$ test$AGE2$'
  109. >>> _dollar_quote("$$$") # Content with dollar signs
  110. '$AGE1$$$$AGE1$'
  111. """
  112. s = "" if s is None else str(s)
  113. for i in itertools.count(1):
  114. tag = f"{tag_prefix}{i}"
  115. wrapper = f"${tag}$"
  116. if wrapper not in s:
  117. return f"{wrapper}{s}{wrapper}"
  118. class PostgreSQLDB:
  119. def __init__(self, config: dict[str, Any], **kwargs: Any):
  120. self.host = config["host"]
  121. self.port = config["port"]
  122. self.user = config["user"]
  123. self.password = config["password"]
  124. self.database = config["database"]
  125. self.workspace = config["workspace"]
  126. self.max = int(config["max_connections"])
  127. self.increment = 1
  128. self.pool: Pool | None = None
  129. # SSL configuration
  130. self.ssl_mode = config.get("ssl_mode")
  131. self.ssl_cert = config.get("ssl_cert")
  132. self.ssl_key = config.get("ssl_key")
  133. self.ssl_root_cert = config.get("ssl_root_cert")
  134. self.ssl_crl = config.get("ssl_crl")
  135. # Vector configuration
  136. _ev = config.get("enable_vector", True)
  137. self.enable_vector = (
  138. _ev
  139. if isinstance(_ev, bool)
  140. else str(_ev).lower() in ("true", "1", "yes", "on")
  141. ) # True for backward compatibility, can be set to False to disable vector features
  142. self.vector_index_type = config.get("vector_index_type")
  143. self.hnsw_m = config.get("hnsw_m")
  144. self.hnsw_ef = config.get("hnsw_ef")
  145. self.ivfflat_lists = config.get("ivfflat_lists")
  146. self.vchordrq_build_options = config.get("vchordrq_build_options")
  147. self.vchordrq_probes = config.get("vchordrq_probes")
  148. self.vchordrq_epsilon = config.get("vchordrq_epsilon")
  149. # Server settings
  150. self.server_settings = config.get("server_settings")
  151. # Statement LRU cache size (keep as-is, allow None for optional configuration)
  152. self.statement_cache_size = config.get("statement_cache_size")
  153. if self.user is None or self.password is None or self.database is None:
  154. raise ValueError("Missing database user, password, or database")
  155. # Guard concurrent pool resets
  156. self._pool_reconnect_lock = asyncio.Lock()
  157. self._transient_exceptions = (
  158. asyncio.TimeoutError,
  159. TimeoutError,
  160. ConnectionError,
  161. OSError,
  162. asyncpg.exceptions.InterfaceError,
  163. asyncpg.exceptions.TooManyConnectionsError,
  164. asyncpg.exceptions.CannotConnectNowError,
  165. asyncpg.exceptions.PostgresConnectionError,
  166. asyncpg.exceptions.ConnectionDoesNotExistError,
  167. asyncpg.exceptions.ConnectionFailureError,
  168. )
  169. # Connection retry configuration
  170. self.connection_retry_attempts = config["connection_retry_attempts"]
  171. self.connection_retry_backoff = config["connection_retry_backoff"]
  172. self.connection_retry_backoff_max = max(
  173. self.connection_retry_backoff,
  174. config["connection_retry_backoff_max"],
  175. )
  176. self.pool_close_timeout = config["pool_close_timeout"]
  177. logger.info(
  178. "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
  179. self.connection_retry_attempts,
  180. self.connection_retry_backoff,
  181. self.connection_retry_backoff_max,
  182. self.pool_close_timeout,
  183. )
  184. def _create_ssl_context(self) -> ssl.SSLContext | None:
  185. """Create SSL context based on configuration parameters."""
  186. if not self.ssl_mode:
  187. return None
  188. ssl_mode = self.ssl_mode.lower()
  189. # For simple modes that don't require custom context
  190. if ssl_mode in ["disable", "allow", "prefer", "require"]:
  191. if ssl_mode == "disable":
  192. return None
  193. elif ssl_mode in ["require", "prefer", "allow"]:
  194. # Return None for simple SSL requirement, handled in initdb
  195. return None
  196. # For modes that require certificate verification
  197. if ssl_mode in ["verify-ca", "verify-full"]:
  198. try:
  199. context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
  200. # Configure certificate verification
  201. if ssl_mode == "verify-ca":
  202. context.check_hostname = False
  203. elif ssl_mode == "verify-full":
  204. context.check_hostname = True
  205. # Load root certificate if provided
  206. if self.ssl_root_cert:
  207. if os.path.exists(self.ssl_root_cert):
  208. context.load_verify_locations(cafile=self.ssl_root_cert)
  209. logger.info(
  210. f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
  211. )
  212. else:
  213. logger.warning(
  214. f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
  215. )
  216. # Load client certificate and key if provided
  217. if self.ssl_cert and self.ssl_key:
  218. if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
  219. context.load_cert_chain(self.ssl_cert, self.ssl_key)
  220. logger.info(
  221. f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
  222. )
  223. else:
  224. logger.warning(
  225. "PostgreSQL, SSL client certificate or key file not found"
  226. )
  227. # Load certificate revocation list if provided
  228. if self.ssl_crl:
  229. if os.path.exists(self.ssl_crl):
  230. context.load_verify_locations(crlfile=self.ssl_crl)
  231. logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
  232. else:
  233. logger.warning(
  234. f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
  235. )
  236. return context
  237. except Exception as e:
  238. logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
  239. raise ValueError(f"SSL configuration error: {e}")
  240. # Unknown SSL mode
  241. logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
  242. return None
  243. async def initdb(self):
  244. # Prepare connection parameters
  245. connection_params = {
  246. "user": self.user,
  247. "password": self.password,
  248. "database": self.database,
  249. "host": self.host,
  250. "port": self.port,
  251. "min_size": 1,
  252. "max_size": self.max,
  253. }
  254. # Only add statement_cache_size if it's configured
  255. if self.statement_cache_size is not None:
  256. connection_params["statement_cache_size"] = int(self.statement_cache_size)
  257. logger.info(
  258. f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
  259. )
  260. # Add SSL configuration if provided
  261. ssl_context = self._create_ssl_context()
  262. if ssl_context is not None:
  263. connection_params["ssl"] = ssl_context
  264. logger.info("PostgreSQL, SSL configuration applied")
  265. elif self.ssl_mode:
  266. # Handle simple SSL modes without custom context
  267. if self.ssl_mode.lower() in ["require", "prefer"]:
  268. connection_params["ssl"] = True
  269. elif self.ssl_mode.lower() == "disable":
  270. connection_params["ssl"] = False
  271. logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
  272. # Add server settings if provided
  273. if self.server_settings:
  274. try:
  275. settings = {}
  276. # The format is expected to be a query string, e.g., "key1=value1&key2=value2"
  277. pairs = self.server_settings.split("&")
  278. for pair in pairs:
  279. if "=" in pair:
  280. key, value = pair.split("=", 1)
  281. settings[key] = value
  282. if settings:
  283. connection_params["server_settings"] = settings
  284. logger.info(f"PostgreSQL, Server settings applied: {settings}")
  285. except Exception as e:
  286. logger.warning(
  287. f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
  288. )
  289. wait_strategy = (
  290. wait_exponential(
  291. multiplier=self.connection_retry_backoff,
  292. min=self.connection_retry_backoff,
  293. max=self.connection_retry_backoff_max,
  294. )
  295. if self.connection_retry_backoff > 0
  296. else wait_fixed(0)
  297. )
  298. async def _init_connection(connection: asyncpg.Connection) -> None:
  299. """Initialize each new connection with pgvector codec and VCHORDRQ session params.
  300. Called once per physical connection creation (not on pool reuse).
  301. register_vector is a Python-level codec registration that survives
  302. asyncpg's RESET ALL; VCHORDRQ GUCs do not — they are re-applied in
  303. _reset_connection after each pool release.
  304. """
  305. if self.enable_vector:
  306. await register_vector(connection)
  307. if self.enable_vector and self.vector_index_type == "VCHORDRQ":
  308. await self.configure_vchordrq(connection)
  309. async def _reset_connection(connection: asyncpg.Connection) -> None:
  310. """Run the default asyncpg cleanup, then re-apply VCHORDRQ session GUCs.
  311. When a custom reset= callback is registered with create_pool(), asyncpg
  312. calls Connection._reset() (private — clears listeners and rolls back open
  313. transactions if any) and then this function. It does NOT call the public
  314. Connection.reset(), which is the method that calls _reset() and then
  315. executes the cleanup query returned by get_reset_query() — the exact SQL
  316. depends on detected server capabilities and typically includes
  317. pg_advisory_unlock_all(), CLOSE ALL, UNLISTEN *, and RESET ALL.
  318. We must therefore run that cleanup ourselves via get_reset_query() before
  319. restoring VCHORDRQ GUCs. Skipping this step leaks session state across
  320. pool checkouts — for example configure_age() sets search_path and that
  321. modified path would persist into the next non-AGE connection checkout.
  322. register_vector is NOT repeated here: it is a Python-side encoder/decoder
  323. registration on the asyncpg Connection object and is unaffected by RESET ALL.
  324. Note that set_type_codec() clears the statement cache, which is naturally
  325. repopulated on subsequent queries.
  326. """
  327. try:
  328. # Run the default cleanup that asyncpg would otherwise handle.
  329. reset_query = connection.get_reset_query()
  330. if reset_query:
  331. await connection.execute(reset_query)
  332. except Exception as e:
  333. logger.error(
  334. f"[{self.workspace}] Pool reset cleanup query failed — connection "
  335. f"will be terminated and removed from pool: {e}"
  336. )
  337. raise
  338. # RESET ALL clears session GUCs; restore VCHORDRQ values afterward.
  339. if self.enable_vector and self.vector_index_type == "VCHORDRQ":
  340. try:
  341. await self.configure_vchordrq(connection)
  342. except asyncpg.exceptions.UndefinedObjectError:
  343. logger.error(
  344. f"[{self.workspace}] VCHORDRQ extension is not installed. "
  345. "Install the extension or set vector_index_type to a supported value. "
  346. "Connection will be terminated and removed from pool."
  347. )
  348. raise
  349. except asyncpg.exceptions.InvalidParameterValueError as e:
  350. logger.error(
  351. f"[{self.workspace}] Invalid VCHORDRQ GUC parameter — "
  352. f"check vchordrq_probes and vchordrq_epsilon config. "
  353. f"Connection will be terminated: {e}"
  354. )
  355. raise
  356. except Exception as e:
  357. logger.error(
  358. f"[{self.workspace}] VCHORDRQ session configuration failed "
  359. f"after pool reset — connection will be terminated: {e}"
  360. )
  361. raise
  362. async def _create_pool_once() -> None:
  363. # STEP 1: Bootstrap - ensure vector extension exists BEFORE pool creation.
  364. # On a fresh database, register_vector() in _init_connection will fail
  365. # if the vector extension doesn't exist yet, because the 'vector' type
  366. # won't be found in pg_catalog. We must create the extension first
  367. # using a standalone bootstrap connection.
  368. # Skip this step if vector support is not enabled.
  369. if self.enable_vector:
  370. bootstrap_conn = await asyncpg.connect(
  371. user=self.user,
  372. password=self.password,
  373. database=self.database,
  374. host=self.host,
  375. port=self.port,
  376. ssl=connection_params.get("ssl"),
  377. )
  378. try:
  379. await self.configure_vector_extension(bootstrap_conn)
  380. finally:
  381. await bootstrap_conn.close()
  382. # STEP 2: Now safe to create pool with register_vector callback.
  383. # The vector extension is guaranteed to exist at this point (if enabled).
  384. pool = await asyncpg.create_pool(
  385. **connection_params,
  386. init=_init_connection, # register pgvector codec on new connections
  387. reset=_reset_connection, # re-apply VCHORDRQ GUCs after RESET ALL
  388. ) # type: ignore
  389. self.pool = pool
  390. try:
  391. async for attempt in AsyncRetrying(
  392. stop=stop_after_attempt(self.connection_retry_attempts),
  393. retry=retry_if_exception_type(self._transient_exceptions),
  394. wait=wait_strategy,
  395. before_sleep=self._before_sleep,
  396. reraise=True,
  397. ):
  398. with attempt:
  399. await _create_pool_once()
  400. ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
  401. logger.info(
  402. f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
  403. )
  404. except Exception as e:
  405. logger.error(
  406. f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
  407. )
  408. raise
  409. async def _ensure_pool(self) -> None:
  410. """Ensure the connection pool is initialised."""
  411. if self.pool is None:
  412. async with self._pool_reconnect_lock:
  413. if self.pool is None:
  414. await self.initdb()
  415. async def _reset_pool(self) -> None:
  416. async with self._pool_reconnect_lock:
  417. if self.pool is not None:
  418. try:
  419. await asyncio.wait_for(
  420. self.pool.close(), timeout=self.pool_close_timeout
  421. )
  422. except asyncio.TimeoutError:
  423. logger.error(
  424. "PostgreSQL, Timed out closing connection pool after %.2fs",
  425. self.pool_close_timeout,
  426. )
  427. except Exception as close_error: # pragma: no cover - defensive logging
  428. logger.warning(
  429. f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
  430. )
  431. self.pool = None
  432. async def _before_sleep(self, retry_state: RetryCallState) -> None:
  433. """Hook invoked by tenacity before sleeping between retries."""
  434. exc = retry_state.outcome.exception() if retry_state.outcome else None
  435. logger.warning(
  436. "PostgreSQL transient connection issue on attempt %s/%s: %r",
  437. retry_state.attempt_number,
  438. self.connection_retry_attempts,
  439. exc,
  440. )
  441. await self._reset_pool()
  442. async def _run_with_retry(
  443. self,
  444. operation: Callable[[asyncpg.Connection], Awaitable[T]],
  445. *,
  446. with_age: bool = False,
  447. graph_name: str | None = None,
  448. timing_label: str | None = None,
  449. ) -> T:
  450. """
  451. Execute a database operation with automatic retry for transient failures.
  452. Args:
  453. operation: Async callable that receives an active connection.
  454. with_age: Whether to configure Apache AGE on the connection.
  455. graph_name: AGE graph name; required when with_age is True.
  456. Returns:
  457. The result returned by the operation.
  458. Raises:
  459. Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
  460. """
  461. wait_strategy = (
  462. wait_exponential(
  463. multiplier=self.connection_retry_backoff,
  464. min=self.connection_retry_backoff,
  465. max=self.connection_retry_backoff_max,
  466. )
  467. if self.connection_retry_backoff > 0
  468. else wait_fixed(0)
  469. )
  470. async for attempt in AsyncRetrying(
  471. stop=stop_after_attempt(self.connection_retry_attempts),
  472. retry=retry_if_exception_type(self._transient_exceptions),
  473. wait=wait_strategy,
  474. before_sleep=self._before_sleep,
  475. reraise=True,
  476. ):
  477. with attempt:
  478. await self._ensure_pool()
  479. assert self.pool is not None
  480. if timing_label:
  481. pool_snapshot_before = self._get_pool_snapshot()
  482. performance_timing_log(
  483. "[%s] pool.acquire waiting %s",
  484. timing_label,
  485. pool_snapshot_before,
  486. )
  487. acquire_start = time.perf_counter()
  488. async with self.pool.acquire() as connection: # type: ignore[arg-type]
  489. acquire_elapsed = time.perf_counter() - acquire_start
  490. if timing_label:
  491. pool_snapshot_after = self._get_pool_snapshot()
  492. performance_timing_log(
  493. "[%s] pool.acquire completed in %.4fs %s",
  494. timing_label,
  495. acquire_elapsed,
  496. pool_snapshot_after,
  497. )
  498. if with_age and graph_name:
  499. await self.configure_age(connection, graph_name)
  500. elif with_age and not graph_name:
  501. raise ValueError("Graph name is required when with_age is True")
  502. return await operation(connection)
  503. def _get_pool_snapshot(self) -> str:
  504. """Best-effort snapshot of asyncpg pool state for diagnostics.
  505. Uses asyncpg private attributes defensively; if a field is unavailable in the
  506. installed asyncpg version, return '?' for that metric instead of failing.
  507. """
  508. pool = self.pool
  509. if pool is None:
  510. return "pool_state=uninitialized"
  511. holders = getattr(pool, "_holders", None)
  512. queue = getattr(pool, "_queue", None)
  513. max_size = getattr(pool, "_maxsize", None)
  514. min_size = getattr(pool, "_minsize", None)
  515. total_holders = len(holders) if holders is not None else "?"
  516. idle_count: int | str = "?"
  517. acquired_count: int | str = "?"
  518. if holders is not None:
  519. idle_count = 0
  520. acquired_count = 0
  521. for holder in holders:
  522. # asyncpg holder uses _in_use Future/Event-like marker; treat present value as acquired
  523. in_use_marker = getattr(holder, "_in_use", None)
  524. if in_use_marker:
  525. acquired_count += 1
  526. else:
  527. idle_count += 1
  528. waiting_count: int | str = "?"
  529. if queue is not None:
  530. getters = getattr(queue, "_getters", None)
  531. if getters is not None:
  532. waiting_count = len(getters)
  533. return (
  534. f"pool_state[min={min_size}, max={max_size}, holders={total_holders}, "
  535. f"acquired={acquired_count}, idle={idle_count}, waiting={waiting_count}]"
  536. )
  537. async def configure_vector_extension(self, connection: asyncpg.Connection) -> None:
  538. """Create VECTOR extension if it doesn't exist for vector similarity operations.
  539. When vector_index_type is HNSW_HALFVEC, validates that pgvector >= 0.7.0
  540. (required for halfvec support) and raises RuntimeError if older.
  541. """
  542. try:
  543. await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
  544. logger.info("PostgreSQL, VECTOR extension enabled")
  545. except Exception as e:
  546. logger.warning(f"Could not create VECTOR extension: {e}")
  547. # Don't raise - let the system continue without vector extension
  548. return
  549. if getattr(self, "vector_index_type", None) == "HNSW_HALFVEC":
  550. row = await connection.fetchrow(
  551. "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
  552. )
  553. if not row or not row["extversion"]:
  554. raise RuntimeError(
  555. "POSTGRES_VECTOR_INDEX_TYPE=HNSW_HALFVEC requires the pgvector "
  556. "extension. Ensure it is installed and CREATE EXTENSION vector succeeded."
  557. )
  558. raw_version = row["extversion"]
  559. try:
  560. parts = [int(p) for p in str(raw_version).split(".")[:3]]
  561. while len(parts) < 3:
  562. parts.append(0)
  563. version_tuple = (parts[0], parts[1], parts[2])
  564. except (ValueError, IndexError):
  565. raise RuntimeError(
  566. f"Could not parse pgvector version {raw_version!r}. "
  567. "HNSW_HALFVEC requires pgvector >= 0.7.0."
  568. ) from None
  569. if version_tuple < (0, 7, 0):
  570. raise RuntimeError(
  571. f"POSTGRES_VECTOR_INDEX_TYPE=HNSW_HALFVEC requires pgvector >= 0.7.0, "
  572. f"but installed version is {raw_version}. Upgrade the pgvector extension "
  573. "or use a different index type (e.g. HNSW with embeddings <= 2000 dimensions)."
  574. )
  575. @staticmethod
  576. async def configure_age_extension(connection: asyncpg.Connection) -> None:
  577. """Create AGE extension if it doesn't exist for graph operations."""
  578. try:
  579. await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE") # type: ignore
  580. logger.info("PostgreSQL, AGE extension enabled")
  581. except Exception as e:
  582. logger.warning(f"Could not create AGE extension: {e}")
  583. # Don't raise - let the system continue without AGE extension
  584. @staticmethod
  585. async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
  586. """Set the Apache AGE environment and creates a graph if it does not exist.
  587. This method:
  588. - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
  589. - Attempts to create a new graph with the provided `graph_name` if it does not already exist.
  590. - Silently ignores errors related to the graph already existing.
  591. """
  592. try:
  593. await connection.execute( # type: ignore
  594. 'SET search_path = ag_catalog, "$user", public'
  595. )
  596. await connection.execute( # type: ignore
  597. f"select create_graph('{graph_name}')"
  598. )
  599. except (
  600. asyncpg.exceptions.InvalidSchemaNameError,
  601. asyncpg.exceptions.UniqueViolationError,
  602. ):
  603. pass
  604. async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
  605. """Configure VCHORDRQ extension for vector similarity search.
  606. Raises:
  607. asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed
  608. asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid
  609. Note:
  610. This method does not catch exceptions. Configuration errors will fail-fast,
  611. while transient connection errors will be retried by _run_with_retry.
  612. """
  613. # Handle probes parameter - only set if non-empty value is provided
  614. if self.vchordrq_probes and str(self.vchordrq_probes).strip():
  615. await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'")
  616. logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}")
  617. # Handle epsilon parameter independently - check for None to allow 0.0 as valid value
  618. if self.vchordrq_epsilon is not None:
  619. await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}")
  620. logger.debug(
  621. f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}"
  622. )
  623. async def _migrate_llm_cache_schema(self):
  624. """Migrate LLM cache schema: add new columns and remove deprecated mode field"""
  625. try:
  626. # Check if all columns exist
  627. check_columns_sql = """
  628. SELECT column_name
  629. FROM information_schema.columns
  630. WHERE table_name = 'lightrag_llm_cache'
  631. AND column_name IN ('chunk_id', 'cache_type', 'queryparam', 'mode')
  632. """
  633. existing_columns = await self.query(check_columns_sql, multirows=True)
  634. existing_column_names = (
  635. {col["column_name"] for col in existing_columns}
  636. if existing_columns
  637. else set()
  638. )
  639. # Add missing chunk_id column
  640. if "chunk_id" not in existing_column_names:
  641. logger.info("Adding chunk_id column to LIGHTRAG_LLM_CACHE table")
  642. add_chunk_id_sql = """
  643. ALTER TABLE LIGHTRAG_LLM_CACHE
  644. ADD COLUMN chunk_id VARCHAR(255) NULL
  645. """
  646. await self.execute(add_chunk_id_sql)
  647. logger.info(
  648. "Successfully added chunk_id column to LIGHTRAG_LLM_CACHE table"
  649. )
  650. else:
  651. logger.info(
  652. "chunk_id column already exists in LIGHTRAG_LLM_CACHE table"
  653. )
  654. # Add missing cache_type column
  655. if "cache_type" not in existing_column_names:
  656. logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
  657. add_cache_type_sql = """
  658. ALTER TABLE LIGHTRAG_LLM_CACHE
  659. ADD COLUMN cache_type VARCHAR(32) NULL
  660. """
  661. await self.execute(add_cache_type_sql)
  662. logger.info(
  663. "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
  664. )
  665. # Migrate existing data using optimized regex pattern
  666. logger.info(
  667. "Migrating existing LLM cache data to populate cache_type field (optimized)"
  668. )
  669. optimized_update_sql = """
  670. UPDATE LIGHTRAG_LLM_CACHE
  671. SET cache_type = CASE
  672. WHEN id ~ '^[^:]+:[^:]+:' THEN split_part(id, ':', 2)
  673. ELSE 'extract'
  674. END
  675. WHERE cache_type IS NULL
  676. """
  677. await self.execute(optimized_update_sql)
  678. logger.info("Successfully migrated existing LLM cache data")
  679. else:
  680. logger.info(
  681. "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
  682. )
  683. # Add missing queryparam column
  684. if "queryparam" not in existing_column_names:
  685. logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
  686. add_queryparam_sql = """
  687. ALTER TABLE LIGHTRAG_LLM_CACHE
  688. ADD COLUMN queryparam JSONB NULL
  689. """
  690. await self.execute(add_queryparam_sql)
  691. logger.info(
  692. "Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
  693. )
  694. else:
  695. logger.info(
  696. "queryparam column already exists in LIGHTRAG_LLM_CACHE table"
  697. )
  698. # Remove deprecated mode field if it exists
  699. if "mode" in existing_column_names:
  700. logger.info(
  701. "Removing deprecated mode column from LIGHTRAG_LLM_CACHE table"
  702. )
  703. # First, drop the primary key constraint that includes mode
  704. drop_pk_sql = """
  705. ALTER TABLE LIGHTRAG_LLM_CACHE
  706. DROP CONSTRAINT IF EXISTS LIGHTRAG_LLM_CACHE_PK
  707. """
  708. await self.execute(drop_pk_sql)
  709. logger.info("Dropped old primary key constraint")
  710. # Drop the mode column
  711. drop_mode_sql = """
  712. ALTER TABLE LIGHTRAG_LLM_CACHE
  713. DROP COLUMN mode
  714. """
  715. await self.execute(drop_mode_sql)
  716. logger.info(
  717. "Successfully removed mode column from LIGHTRAG_LLM_CACHE table"
  718. )
  719. # Create new primary key constraint without mode
  720. add_pk_sql = """
  721. ALTER TABLE LIGHTRAG_LLM_CACHE
  722. ADD CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
  723. """
  724. await self.execute(add_pk_sql)
  725. logger.info("Created new primary key constraint (workspace, id)")
  726. else:
  727. logger.info("mode column does not exist in LIGHTRAG_LLM_CACHE table")
  728. except Exception as e:
  729. logger.warning(f"Failed to migrate LLM cache schema: {e}")
  730. async def _migrate_timestamp_columns(self):
  731. """Migrate timestamp columns in tables to witimezone-free types, assuming original data is in UTC time"""
  732. # Tables and columns that need migration
  733. tables_to_migrate = {
  734. "LIGHTRAG_VDB_ENTITY": ["create_time", "update_time"],
  735. "LIGHTRAG_VDB_RELATION": ["create_time", "update_time"],
  736. "LIGHTRAG_DOC_CHUNKS": ["create_time", "update_time"],
  737. "LIGHTRAG_DOC_STATUS": ["created_at", "updated_at"],
  738. }
  739. try:
  740. # Filter out tables that don't exist (e.g., legacy vector tables may not exist)
  741. existing_tables = {}
  742. for table_name, columns in tables_to_migrate.items():
  743. if await self.check_table_exists(table_name):
  744. existing_tables[table_name] = columns
  745. else:
  746. logger.debug(
  747. f"Table {table_name} does not exist, skipping timestamp migration"
  748. )
  749. # Skip if no tables to migrate
  750. if not existing_tables:
  751. logger.debug("No tables found for timestamp migration")
  752. return
  753. # Use filtered tables for migration
  754. tables_to_migrate = existing_tables
  755. # Optimization: Batch check all columns in one query instead of 8 separate queries
  756. table_names_lower = [t.lower() for t in tables_to_migrate.keys()]
  757. all_column_names = list(
  758. set(col for cols in tables_to_migrate.values() for col in cols)
  759. )
  760. check_all_columns_sql = """
  761. SELECT table_name, column_name, data_type
  762. FROM information_schema.columns
  763. WHERE table_name = ANY($1)
  764. AND column_name = ANY($2)
  765. """
  766. all_columns_result = await self.query(
  767. check_all_columns_sql,
  768. [table_names_lower, all_column_names],
  769. multirows=True,
  770. )
  771. # Build lookup dict: (table_name, column_name) -> data_type
  772. column_types = {}
  773. if all_columns_result:
  774. column_types = {
  775. (row["table_name"].upper(), row["column_name"]): row["data_type"]
  776. for row in all_columns_result
  777. }
  778. # Now iterate and migrate only what's needed
  779. for table_name, columns in tables_to_migrate.items():
  780. for column_name in columns:
  781. try:
  782. data_type = column_types.get((table_name, column_name))
  783. if not data_type:
  784. logger.warning(
  785. f"Column {table_name}.{column_name} does not exist, skipping migration"
  786. )
  787. continue
  788. # Check column type
  789. if data_type == "timestamp without time zone":
  790. logger.debug(
  791. f"Column {table_name}.{column_name} is already witimezone-free, no migration needed"
  792. )
  793. continue
  794. # Execute migration, explicitly specifying UTC timezone for interpreting original data
  795. logger.info(
  796. f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type"
  797. )
  798. migration_sql = f"""
  799. ALTER TABLE {table_name}
  800. ALTER COLUMN {column_name} TYPE TIMESTAMP(0),
  801. ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP
  802. """
  803. await self.execute(migration_sql)
  804. logger.info(
  805. f"Successfully migrated {table_name}.{column_name} to timezone-free type"
  806. )
  807. except Exception as e:
  808. # Log error but don't interrupt the process
  809. logger.warning(
  810. f"Failed to migrate {table_name}.{column_name}: {e}"
  811. )
  812. except Exception as e:
  813. logger.error(f"Failed to batch check timestamp columns: {e}")
  814. async def _migrate_doc_chunks_to_vdb_chunks(self):
  815. """
  816. Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
  817. This migration is intended for users who are upgrading and have an older table structure
  818. where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
  819. """
  820. try:
  821. # 0. Check if both tables exist before proceeding
  822. vdb_chunks_exists = await self.check_table_exists("LIGHTRAG_VDB_CHUNKS")
  823. doc_chunks_exists = await self.check_table_exists("LIGHTRAG_DOC_CHUNKS")
  824. if not vdb_chunks_exists:
  825. logger.debug(
  826. "Skipping migration: LIGHTRAG_VDB_CHUNKS table does not exist"
  827. )
  828. return
  829. if not doc_chunks_exists:
  830. logger.debug(
  831. "Skipping migration: LIGHTRAG_DOC_CHUNKS table does not exist"
  832. )
  833. return
  834. # 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
  835. vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
  836. vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
  837. if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
  838. logger.info(
  839. "Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
  840. )
  841. return
  842. # 2. Check if `content_vector` column exists in the old table
  843. check_column_sql = """
  844. SELECT 1 FROM information_schema.columns
  845. WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
  846. """
  847. column_exists = await self.query(check_column_sql)
  848. if not column_exists:
  849. logger.info(
  850. "Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
  851. )
  852. return
  853. # 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
  854. doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
  855. doc_chunks_count_result = await self.query(doc_chunks_count_sql)
  856. if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
  857. logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
  858. return
  859. # 4. Perform the migration
  860. logger.info(
  861. "Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
  862. )
  863. migration_sql = """
  864. INSERT INTO LIGHTRAG_VDB_CHUNKS (
  865. id, workspace, full_doc_id, chunk_order_index, tokens, content,
  866. content_vector, file_path, create_time, update_time
  867. )
  868. SELECT
  869. id, workspace, full_doc_id, chunk_order_index, tokens, content,
  870. content_vector, file_path, create_time, update_time
  871. FROM LIGHTRAG_DOC_CHUNKS
  872. ON CONFLICT (workspace, id) DO NOTHING;
  873. """
  874. await self.execute(migration_sql)
  875. logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
  876. except Exception as e:
  877. logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
  878. # Do not re-raise, to allow the application to start
  879. async def _check_llm_cache_needs_migration(self):
  880. """Check if LLM cache data needs migration by examining any record with old format"""
  881. try:
  882. # Optimized query: directly check for old format records without sorting
  883. check_sql = """
  884. SELECT 1 FROM LIGHTRAG_LLM_CACHE
  885. WHERE id NOT LIKE '%:%'
  886. LIMIT 1
  887. """
  888. result = await self.query(check_sql)
  889. # If any old format record exists, migration is needed
  890. return result is not None
  891. except Exception as e:
  892. logger.warning(f"Failed to check LLM cache migration status: {e}")
  893. return False
  894. async def _migrate_llm_cache_to_flattened_keys(self):
  895. """Optimized version: directly execute single UPDATE migration to migrate old format cache keys to flattened format"""
  896. try:
  897. # Check if migration is needed
  898. check_sql = """
  899. SELECT COUNT(*) as count FROM LIGHTRAG_LLM_CACHE
  900. WHERE id NOT LIKE '%:%'
  901. """
  902. result = await self.query(check_sql)
  903. if not result or result["count"] == 0:
  904. logger.info("No old format LLM cache data found, skipping migration")
  905. return
  906. old_count = result["count"]
  907. logger.info(f"Found {old_count} old format cache records")
  908. # Check potential primary key conflicts (optional but recommended)
  909. conflict_check_sql = """
  910. WITH new_ids AS (
  911. SELECT
  912. workspace,
  913. mode,
  914. id as old_id,
  915. mode || ':' ||
  916. CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' ||
  917. md5(original_prompt) as new_id
  918. FROM LIGHTRAG_LLM_CACHE
  919. WHERE id NOT LIKE '%:%'
  920. )
  921. SELECT COUNT(*) as conflicts
  922. FROM new_ids n1
  923. JOIN LIGHTRAG_LLM_CACHE existing
  924. ON existing.workspace = n1.workspace
  925. AND existing.mode = n1.mode
  926. AND existing.id = n1.new_id
  927. WHERE existing.id LIKE '%:%' -- Only check conflicts with existing new format records
  928. """
  929. conflict_result = await self.query(conflict_check_sql)
  930. if conflict_result and conflict_result["conflicts"] > 0:
  931. logger.warning(
  932. f"Found {conflict_result['conflicts']} potential ID conflicts with existing records"
  933. )
  934. # Can choose to continue or abort, here we choose to continue and log warning
  935. # Execute single UPDATE migration
  936. logger.info("Starting optimized LLM cache migration...")
  937. migration_sql = """
  938. UPDATE LIGHTRAG_LLM_CACHE
  939. SET
  940. id = mode || ':' ||
  941. CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' ||
  942. md5(original_prompt),
  943. cache_type = CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END,
  944. update_time = CURRENT_TIMESTAMP
  945. WHERE id NOT LIKE '%:%'
  946. """
  947. # Execute migration
  948. await self.execute(migration_sql)
  949. # Verify migration results
  950. verify_sql = """
  951. SELECT COUNT(*) as remaining_old FROM LIGHTRAG_LLM_CACHE
  952. WHERE id NOT LIKE '%:%'
  953. """
  954. verify_result = await self.query(verify_sql)
  955. remaining = verify_result["remaining_old"] if verify_result else -1
  956. if remaining == 0:
  957. logger.info(
  958. f"✅ Successfully migrated {old_count} LLM cache records to flattened format"
  959. )
  960. else:
  961. logger.warning(
  962. f"⚠️ Migration completed but {remaining} old format records remain"
  963. )
  964. except Exception as e:
  965. logger.error(f"Optimized LLM cache migration failed: {e}")
  966. raise
  967. async def _migrate_doc_status_add_chunks_list(self):
  968. """Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist"""
  969. try:
  970. # Check if chunks_list column exists
  971. check_column_sql = """
  972. SELECT column_name
  973. FROM information_schema.columns
  974. WHERE table_name = 'lightrag_doc_status'
  975. AND column_name = 'chunks_list'
  976. """
  977. column_info = await self.query(check_column_sql)
  978. if not column_info:
  979. logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
  980. add_column_sql = """
  981. ALTER TABLE LIGHTRAG_DOC_STATUS
  982. ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
  983. """
  984. await self.execute(add_column_sql)
  985. logger.info(
  986. "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
  987. )
  988. else:
  989. logger.info(
  990. "chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
  991. )
  992. except Exception as e:
  993. logger.warning(
  994. f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
  995. )
  996. async def _migrate_text_chunks_add_llm_cache_list(self):
  997. """Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist"""
  998. try:
  999. # Check if llm_cache_list column exists
  1000. check_column_sql = """
  1001. SELECT column_name
  1002. FROM information_schema.columns
  1003. WHERE table_name = 'lightrag_doc_chunks'
  1004. AND column_name = 'llm_cache_list'
  1005. """
  1006. column_info = await self.query(check_column_sql)
  1007. if not column_info:
  1008. logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
  1009. add_column_sql = """
  1010. ALTER TABLE LIGHTRAG_DOC_CHUNKS
  1011. ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
  1012. """
  1013. await self.execute(add_column_sql)
  1014. logger.info(
  1015. "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
  1016. )
  1017. else:
  1018. logger.info(
  1019. "llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
  1020. )
  1021. except Exception as e:
  1022. logger.warning(
  1023. f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
  1024. )
  1025. async def _migrate_doc_status_add_track_id(self):
  1026. """Add track_id column to LIGHTRAG_DOC_STATUS table if it doesn't exist and create index"""
  1027. try:
  1028. # Check if track_id column exists
  1029. check_column_sql = """
  1030. SELECT column_name
  1031. FROM information_schema.columns
  1032. WHERE table_name = 'lightrag_doc_status'
  1033. AND column_name = 'track_id'
  1034. """
  1035. column_info = await self.query(check_column_sql)
  1036. if not column_info:
  1037. logger.info("Adding track_id column to LIGHTRAG_DOC_STATUS table")
  1038. add_column_sql = """
  1039. ALTER TABLE LIGHTRAG_DOC_STATUS
  1040. ADD COLUMN track_id VARCHAR(255) NULL
  1041. """
  1042. await self.execute(add_column_sql)
  1043. logger.info(
  1044. "Successfully added track_id column to LIGHTRAG_DOC_STATUS table"
  1045. )
  1046. else:
  1047. logger.info(
  1048. "track_id column already exists in LIGHTRAG_DOC_STATUS table"
  1049. )
  1050. # Check if track_id index exists
  1051. check_index_sql = """
  1052. SELECT indexname
  1053. FROM pg_indexes
  1054. WHERE tablename = 'lightrag_doc_status'
  1055. AND indexname = 'idx_lightrag_doc_status_track_id'
  1056. """
  1057. index_info = await self.query(check_index_sql)
  1058. if not index_info:
  1059. logger.info(
  1060. "Creating index on track_id column for LIGHTRAG_DOC_STATUS table"
  1061. )
  1062. create_index_sql = """
  1063. CREATE INDEX idx_lightrag_doc_status_track_id ON LIGHTRAG_DOC_STATUS (track_id)
  1064. """
  1065. await self.execute(create_index_sql)
  1066. logger.info(
  1067. "Successfully created index on track_id column for LIGHTRAG_DOC_STATUS table"
  1068. )
  1069. else:
  1070. logger.info(
  1071. "Index on track_id column already exists for LIGHTRAG_DOC_STATUS table"
  1072. )
  1073. except Exception as e:
  1074. logger.warning(
  1075. f"Failed to add track_id column or index to LIGHTRAG_DOC_STATUS: {e}"
  1076. )
  1077. async def _migrate_doc_status_add_metadata_error_msg(self):
  1078. """Add metadata and error_msg columns to LIGHTRAG_DOC_STATUS table if they don't exist"""
  1079. try:
  1080. # Check if metadata column exists
  1081. check_metadata_sql = """
  1082. SELECT column_name
  1083. FROM information_schema.columns
  1084. WHERE table_name = 'lightrag_doc_status'
  1085. AND column_name = 'metadata'
  1086. """
  1087. metadata_info = await self.query(check_metadata_sql)
  1088. if not metadata_info:
  1089. logger.info("Adding metadata column to LIGHTRAG_DOC_STATUS table")
  1090. add_metadata_sql = """
  1091. ALTER TABLE LIGHTRAG_DOC_STATUS
  1092. ADD COLUMN metadata JSONB NULL DEFAULT '{}'::jsonb
  1093. """
  1094. await self.execute(add_metadata_sql)
  1095. logger.info(
  1096. "Successfully added metadata column to LIGHTRAG_DOC_STATUS table"
  1097. )
  1098. else:
  1099. logger.info(
  1100. "metadata column already exists in LIGHTRAG_DOC_STATUS table"
  1101. )
  1102. # Check if error_msg column exists
  1103. check_error_msg_sql = """
  1104. SELECT column_name
  1105. FROM information_schema.columns
  1106. WHERE table_name = 'lightrag_doc_status'
  1107. AND column_name = 'error_msg'
  1108. """
  1109. error_msg_info = await self.query(check_error_msg_sql)
  1110. if not error_msg_info:
  1111. logger.info("Adding error_msg column to LIGHTRAG_DOC_STATUS table")
  1112. add_error_msg_sql = """
  1113. ALTER TABLE LIGHTRAG_DOC_STATUS
  1114. ADD COLUMN error_msg TEXT NULL
  1115. """
  1116. await self.execute(add_error_msg_sql)
  1117. logger.info(
  1118. "Successfully added error_msg column to LIGHTRAG_DOC_STATUS table"
  1119. )
  1120. else:
  1121. logger.info(
  1122. "error_msg column already exists in LIGHTRAG_DOC_STATUS table"
  1123. )
  1124. except Exception as e:
  1125. logger.warning(
  1126. f"Failed to add metadata/error_msg columns to LIGHTRAG_DOC_STATUS: {e}"
  1127. )
  1128. async def _migrate_doc_full_add_pipeline_fields(self):
  1129. """Add pipeline-derived fields to LIGHTRAG_DOC_FULL if they don't exist.
  1130. Each ALTER is guarded individually so a single failure does not abort
  1131. the remaining columns; the migration is idempotent and retried on
  1132. every startup until all columns are present.
  1133. """
  1134. # content_hash uses TEXT (not VARCHAR(N)) so the column stays
  1135. # algorithm-agnostic; future SHA-512 / base64 hashes do not require a
  1136. # schema change. process_options is an opaque selector string emitted
  1137. # by sanitize_process_options() (e.g. "Fi").
  1138. columns_to_add = [
  1139. ("sidecar_location", "TEXT NULL"),
  1140. ("parse_format", "VARCHAR(32) NULL DEFAULT 'raw'"),
  1141. ("content_hash", "TEXT NULL"),
  1142. ("process_options", "TEXT NULL"),
  1143. ("chunk_options", "JSONB NULL DEFAULT '{}'::jsonb"),
  1144. ("parse_engine", "VARCHAR(32) NULL"),
  1145. ]
  1146. try:
  1147. existing = await self.query(
  1148. """
  1149. SELECT column_name
  1150. FROM information_schema.columns
  1151. WHERE table_name = 'lightrag_doc_full'
  1152. AND column_name = ANY($1)
  1153. """,
  1154. [[c for c, _ in columns_to_add]],
  1155. multirows=True,
  1156. )
  1157. existing_names = {row["column_name"] for row in (existing or [])}
  1158. except Exception as e:
  1159. logger.warning(
  1160. f"Failed to inspect LIGHTRAG_DOC_FULL columns for migration: {e}"
  1161. )
  1162. existing_names = set()
  1163. for col_name, col_type in columns_to_add:
  1164. if col_name in existing_names:
  1165. logger.debug(f"Column {col_name} already exists in LIGHTRAG_DOC_FULL")
  1166. continue
  1167. try:
  1168. alter_sql = (
  1169. f"ALTER TABLE LIGHTRAG_DOC_FULL ADD COLUMN {col_name} {col_type}"
  1170. )
  1171. logger.info(f"Adding {col_name} column to LIGHTRAG_DOC_FULL table")
  1172. await self.execute(alter_sql)
  1173. logger.info(
  1174. f"Successfully added {col_name} column to LIGHTRAG_DOC_FULL table"
  1175. )
  1176. except Exception as e:
  1177. logger.error(
  1178. f"Failed to add column {col_name} to LIGHTRAG_DOC_FULL: {e}"
  1179. )
  1180. async def _migrate_doc_status_add_content_hash(self):
  1181. """Add content_hash column to LIGHTRAG_DOC_STATUS table if it doesn't exist."""
  1182. try:
  1183. check_column_sql = """
  1184. SELECT column_name
  1185. FROM information_schema.columns
  1186. WHERE table_name = 'lightrag_doc_status'
  1187. AND column_name = 'content_hash'
  1188. """
  1189. column_info = await self.query(check_column_sql)
  1190. if not column_info:
  1191. logger.info("Adding content_hash column to LIGHTRAG_DOC_STATUS table")
  1192. # TEXT (not VARCHAR(N)) so the column is agnostic to the hash
  1193. # algorithm; today the pipeline writes 64-char SHA-256 hex.
  1194. await self.execute(
  1195. "ALTER TABLE LIGHTRAG_DOC_STATUS ADD COLUMN content_hash TEXT NULL"
  1196. )
  1197. logger.info(
  1198. "Successfully added content_hash column to LIGHTRAG_DOC_STATUS table"
  1199. )
  1200. else:
  1201. logger.debug(
  1202. "content_hash column already exists in LIGHTRAG_DOC_STATUS table"
  1203. )
  1204. except Exception as e:
  1205. logger.error(
  1206. f"Failed to add content_hash column to LIGHTRAG_DOC_STATUS: {e}"
  1207. )
  1208. try:
  1209. check_index_sql = """
  1210. SELECT indexname FROM pg_indexes
  1211. WHERE tablename = 'lightrag_doc_status'
  1212. AND indexname = 'idx_lightrag_doc_status_workspace_content_hash'
  1213. """
  1214. index_info = await self.query(check_index_sql)
  1215. if not index_info:
  1216. logger.info(
  1217. "Creating partial index idx_lightrag_doc_status_workspace_content_hash"
  1218. )
  1219. await self.execute(
  1220. """
  1221. CREATE INDEX IF NOT EXISTS idx_lightrag_doc_status_workspace_content_hash
  1222. ON LIGHTRAG_DOC_STATUS (workspace, content_hash)
  1223. WHERE content_hash IS NOT NULL AND content_hash <> ''
  1224. """
  1225. )
  1226. except Exception as e:
  1227. logger.error(
  1228. f"Failed to create partial content_hash index on LIGHTRAG_DOC_STATUS: {e}"
  1229. )
  1230. async def _migrate_text_chunks_add_heading_sidecar(self):
  1231. """Add heading and sidecar JSONB columns to LIGHTRAG_DOC_CHUNKS if missing."""
  1232. columns_to_add = [
  1233. ("heading", "JSONB NULL DEFAULT '{}'::jsonb"),
  1234. ("sidecar", "JSONB NULL DEFAULT '{}'::jsonb"),
  1235. ]
  1236. try:
  1237. existing = await self.query(
  1238. """
  1239. SELECT column_name
  1240. FROM information_schema.columns
  1241. WHERE table_name = 'lightrag_doc_chunks'
  1242. AND column_name = ANY($1)
  1243. """,
  1244. [[c for c, _ in columns_to_add]],
  1245. multirows=True,
  1246. )
  1247. existing_names = {row["column_name"] for row in (existing or [])}
  1248. except Exception as e:
  1249. logger.warning(
  1250. f"Failed to inspect LIGHTRAG_DOC_CHUNKS columns for migration: {e}"
  1251. )
  1252. existing_names = set()
  1253. for col_name, col_type in columns_to_add:
  1254. if col_name in existing_names:
  1255. logger.debug(f"Column {col_name} already exists in LIGHTRAG_DOC_CHUNKS")
  1256. continue
  1257. try:
  1258. alter_sql = (
  1259. f"ALTER TABLE LIGHTRAG_DOC_CHUNKS ADD COLUMN {col_name} {col_type}"
  1260. )
  1261. logger.info(f"Adding {col_name} column to LIGHTRAG_DOC_CHUNKS table")
  1262. await self.execute(alter_sql)
  1263. logger.info(
  1264. f"Successfully added {col_name} column to LIGHTRAG_DOC_CHUNKS table"
  1265. )
  1266. except Exception as e:
  1267. logger.error(
  1268. f"Failed to add column {col_name} to LIGHTRAG_DOC_CHUNKS: {e}"
  1269. )
  1270. async def _migrate_field_lengths(self):
  1271. """Migrate database field lengths: entity_name, source_id, target_id, and file_path"""
  1272. # Define the field changes needed
  1273. field_migrations = [
  1274. {
  1275. "table": "LIGHTRAG_VDB_ENTITY",
  1276. "column": "entity_name",
  1277. "old_type": "character varying(255)",
  1278. "new_type": "VARCHAR(512)",
  1279. "description": "entity_name from 255 to 512",
  1280. },
  1281. {
  1282. "table": "LIGHTRAG_VDB_RELATION",
  1283. "column": "source_id",
  1284. "old_type": "character varying(256)",
  1285. "new_type": "VARCHAR(512)",
  1286. "description": "source_id from 256 to 512",
  1287. },
  1288. {
  1289. "table": "LIGHTRAG_VDB_RELATION",
  1290. "column": "target_id",
  1291. "old_type": "character varying(256)",
  1292. "new_type": "VARCHAR(512)",
  1293. "description": "target_id from 256 to 512",
  1294. },
  1295. {
  1296. "table": "LIGHTRAG_DOC_CHUNKS",
  1297. "column": "file_path",
  1298. "old_type": "character varying(256)",
  1299. "new_type": "TEXT",
  1300. "description": "file_path to TEXT NULL",
  1301. },
  1302. {
  1303. "table": "LIGHTRAG_VDB_CHUNKS",
  1304. "column": "file_path",
  1305. "old_type": "character varying(256)",
  1306. "new_type": "TEXT",
  1307. "description": "file_path to TEXT NULL",
  1308. },
  1309. ]
  1310. try:
  1311. # Filter out tables that don't exist (e.g., legacy vector tables may not exist)
  1312. existing_migrations = []
  1313. for migration in field_migrations:
  1314. if await self.check_table_exists(migration["table"]):
  1315. existing_migrations.append(migration)
  1316. else:
  1317. logger.debug(
  1318. f"Table {migration['table']} does not exist, skipping field length migration for {migration['column']}"
  1319. )
  1320. # Skip if no migrations to process
  1321. if not existing_migrations:
  1322. logger.debug("No tables found for field length migration")
  1323. return
  1324. # Use filtered migrations for processing
  1325. field_migrations = existing_migrations
  1326. # Optimization: Batch check all columns in one query instead of 5 separate queries
  1327. unique_tables = list(set(m["table"].lower() for m in field_migrations))
  1328. unique_columns = list(set(m["column"] for m in field_migrations))
  1329. check_all_columns_sql = """
  1330. SELECT table_name, column_name, data_type, character_maximum_length, is_nullable
  1331. FROM information_schema.columns
  1332. WHERE table_name = ANY($1)
  1333. AND column_name = ANY($2)
  1334. """
  1335. all_columns_result = await self.query(
  1336. check_all_columns_sql, [unique_tables, unique_columns], multirows=True
  1337. )
  1338. # Build lookup dict: (table_name, column_name) -> column_info
  1339. column_info_map = {}
  1340. if all_columns_result:
  1341. column_info_map = {
  1342. (row["table_name"].upper(), row["column_name"]): row
  1343. for row in all_columns_result
  1344. }
  1345. # Now iterate and migrate only what's needed
  1346. for migration in field_migrations:
  1347. try:
  1348. column_info = column_info_map.get(
  1349. (migration["table"], migration["column"])
  1350. )
  1351. if not column_info:
  1352. logger.warning(
  1353. f"Column {migration['table']}.{migration['column']} does not exist, skipping migration"
  1354. )
  1355. continue
  1356. current_type = column_info.get("data_type", "").lower()
  1357. current_length = column_info.get("character_maximum_length")
  1358. # Check if migration is needed
  1359. needs_migration = False
  1360. if migration["column"] == "entity_name" and current_length == 255:
  1361. needs_migration = True
  1362. elif (
  1363. migration["column"] in ["source_id", "target_id"]
  1364. and current_length == 256
  1365. ):
  1366. needs_migration = True
  1367. elif (
  1368. migration["column"] == "file_path"
  1369. and current_type == "character varying"
  1370. ):
  1371. needs_migration = True
  1372. if needs_migration:
  1373. logger.info(
  1374. f"Migrating {migration['table']}.{migration['column']}: {migration['description']}"
  1375. )
  1376. # Execute the migration
  1377. alter_sql = f"""
  1378. ALTER TABLE {migration["table"]}
  1379. ALTER COLUMN {migration["column"]} TYPE {migration["new_type"]}
  1380. """
  1381. await self.execute(alter_sql)
  1382. logger.info(
  1383. f"Successfully migrated {migration['table']}.{migration['column']}"
  1384. )
  1385. else:
  1386. logger.debug(
  1387. f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed"
  1388. )
  1389. except Exception as e:
  1390. # Log error but don't interrupt the process
  1391. logger.warning(
  1392. f"Failed to migrate {migration['table']}.{migration['column']}: {e}"
  1393. )
  1394. except Exception as e:
  1395. logger.error(f"Failed to batch check field lengths: {e}")
  1396. async def check_tables(self):
  1397. # Vector tables that should be skipped - they are created by PGVectorStorage.setup_table()
  1398. # with proper embedding model and dimension suffix for data isolation
  1399. vector_tables_to_skip = {
  1400. "LIGHTRAG_VDB_CHUNKS",
  1401. "LIGHTRAG_VDB_ENTITY",
  1402. "LIGHTRAG_VDB_RELATION",
  1403. }
  1404. # First create all tables (except vector tables)
  1405. for k, v in TABLES.items():
  1406. # Skip vector tables - they are created by PGVectorStorage.setup_table()
  1407. if k in vector_tables_to_skip:
  1408. continue
  1409. try:
  1410. await self.query(f"SELECT 1 FROM {k} LIMIT 1")
  1411. except Exception:
  1412. try:
  1413. logger.info(f"PostgreSQL, Try Creating table {k} in database")
  1414. await self.execute(v["ddl"])
  1415. logger.info(
  1416. f"PostgreSQL, Creation success table {k} in PostgreSQL database"
  1417. )
  1418. except Exception as e:
  1419. logger.error(
  1420. f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
  1421. )
  1422. raise e
  1423. # Batch check all indexes at once (optimization: single query instead of N queries)
  1424. try:
  1425. # Exclude vector tables from index creation since they are created by PGVectorStorage.setup_table()
  1426. table_names = [k for k in TABLES.keys() if k not in vector_tables_to_skip]
  1427. table_names_lower = [t.lower() for t in table_names]
  1428. # Get all existing indexes for our tables in one query
  1429. check_all_indexes_sql = """
  1430. SELECT indexname, tablename
  1431. FROM pg_indexes
  1432. WHERE tablename = ANY($1)
  1433. """
  1434. existing_indexes_result = await self.query(
  1435. check_all_indexes_sql, [table_names_lower], multirows=True
  1436. )
  1437. # Build a set of existing index names for fast lookup
  1438. existing_indexes = set()
  1439. if existing_indexes_result:
  1440. existing_indexes = {row["indexname"] for row in existing_indexes_result}
  1441. # Create missing indexes
  1442. for k in table_names:
  1443. # Create index for id column if missing
  1444. index_name = f"idx_{k.lower()}_id"
  1445. if index_name not in existing_indexes:
  1446. try:
  1447. create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)"
  1448. logger.info(
  1449. f"PostgreSQL, Creating index {index_name} on table {k}"
  1450. )
  1451. await self.execute(create_index_sql)
  1452. except Exception as e:
  1453. logger.error(
  1454. f"PostgreSQL, Failed to create index {index_name}, Got: {e}"
  1455. )
  1456. # Create composite index for (workspace, id) if missing
  1457. composite_index_name = f"idx_{k.lower()}_workspace_id"
  1458. if composite_index_name not in existing_indexes:
  1459. try:
  1460. create_composite_index_sql = (
  1461. f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)"
  1462. )
  1463. logger.info(
  1464. f"PostgreSQL, Creating composite index {composite_index_name} on table {k}"
  1465. )
  1466. await self.execute(create_composite_index_sql)
  1467. except Exception as e:
  1468. logger.error(
  1469. f"PostgreSQL, Failed to create composite index {composite_index_name}, Got: {e}"
  1470. )
  1471. except Exception as e:
  1472. logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")
  1473. # NOTE: Vector index creation moved to PGVectorStorage.setup_table()
  1474. # Each vector storage instance creates its own index with correct embedding_dim
  1475. # After all tables are created, attempt to migrate timestamp fields
  1476. try:
  1477. await self._migrate_timestamp_columns()
  1478. except Exception as e:
  1479. logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}")
  1480. # Don't throw an exception, allow the initialization process to continue
  1481. # Migrate LLM cache schema: add new columns and remove deprecated mode field
  1482. try:
  1483. await self._migrate_llm_cache_schema()
  1484. except Exception as e:
  1485. logger.error(f"PostgreSQL, Failed to migrate LLM cache schema: {e}")
  1486. # Don't throw an exception, allow the initialization process to continue
  1487. # Finally, attempt to migrate old doc chunks data if needed
  1488. try:
  1489. await self._migrate_doc_chunks_to_vdb_chunks()
  1490. except Exception as e:
  1491. logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
  1492. # Check and migrate LLM cache to flattened keys if needed
  1493. try:
  1494. if await self._check_llm_cache_needs_migration():
  1495. await self._migrate_llm_cache_to_flattened_keys()
  1496. except Exception as e:
  1497. logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
  1498. # Migrate doc status to add chunks_list field if needed
  1499. try:
  1500. await self._migrate_doc_status_add_chunks_list()
  1501. except Exception as e:
  1502. logger.error(
  1503. f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}"
  1504. )
  1505. # Migrate text chunks to add llm_cache_list field if needed
  1506. try:
  1507. await self._migrate_text_chunks_add_llm_cache_list()
  1508. except Exception as e:
  1509. logger.error(
  1510. f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
  1511. )
  1512. # Migrate field lengths for entity_name, source_id, target_id, and file_path
  1513. try:
  1514. await self._migrate_field_lengths()
  1515. except Exception as e:
  1516. logger.error(f"PostgreSQL, Failed to migrate field lengths: {e}")
  1517. # Migrate doc status to add track_id field if needed
  1518. try:
  1519. await self._migrate_doc_status_add_track_id()
  1520. except Exception as e:
  1521. logger.error(
  1522. f"PostgreSQL, Failed to migrate doc status track_id field: {e}"
  1523. )
  1524. # Migrate doc status to add metadata and error_msg fields if needed
  1525. try:
  1526. await self._migrate_doc_status_add_metadata_error_msg()
  1527. except Exception as e:
  1528. logger.error(
  1529. f"PostgreSQL, Failed to migrate doc status metadata/error_msg fields: {e}"
  1530. )
  1531. # Create pagination optimization indexes for LIGHTRAG_DOC_STATUS
  1532. try:
  1533. await self._create_pagination_indexes()
  1534. except Exception as e:
  1535. logger.error(f"PostgreSQL, Failed to create pagination indexes: {e}")
  1536. # Migrate to ensure new tables LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS exist
  1537. try:
  1538. await self._migrate_create_full_entities_relations_tables()
  1539. except Exception as e:
  1540. logger.error(
  1541. f"PostgreSQL, Failed to create full entities/relations tables: {e}"
  1542. )
  1543. # Migrate LIGHTRAG_DOC_FULL to add pipeline-derived fields used by the
  1544. # JSON storage parity: sidecar_location / parse_format / content_hash /
  1545. # process_options / chunk_options / parse_engine
  1546. try:
  1547. await self._migrate_doc_full_add_pipeline_fields()
  1548. except Exception as e:
  1549. logger.error(
  1550. f"PostgreSQL, Failed to migrate LIGHTRAG_DOC_FULL pipeline fields: {e}"
  1551. )
  1552. # Migrate LIGHTRAG_DOC_STATUS to add content_hash column for content
  1553. # dedup queries
  1554. try:
  1555. await self._migrate_doc_status_add_content_hash()
  1556. except Exception as e:
  1557. logger.error(
  1558. f"PostgreSQL, Failed to migrate LIGHTRAG_DOC_STATUS content_hash field: {e}"
  1559. )
  1560. # Migrate LIGHTRAG_DOC_CHUNKS to add heading / sidecar JSONB columns
  1561. try:
  1562. await self._migrate_text_chunks_add_heading_sidecar()
  1563. except Exception as e:
  1564. logger.error(
  1565. f"PostgreSQL, Failed to migrate LIGHTRAG_DOC_CHUNKS heading/sidecar fields: {e}"
  1566. )
  1567. async def _migrate_create_full_entities_relations_tables(self):
  1568. """Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
  1569. tables_to_check = [
  1570. {
  1571. "name": "LIGHTRAG_FULL_ENTITIES",
  1572. "ddl": TABLES["LIGHTRAG_FULL_ENTITIES"]["ddl"],
  1573. "description": "Full entities storage table",
  1574. },
  1575. {
  1576. "name": "LIGHTRAG_FULL_RELATIONS",
  1577. "ddl": TABLES["LIGHTRAG_FULL_RELATIONS"]["ddl"],
  1578. "description": "Full relations storage table",
  1579. },
  1580. ]
  1581. for table_info in tables_to_check:
  1582. table_name = table_info["name"]
  1583. try:
  1584. # Check if table exists
  1585. check_table_sql = """
  1586. SELECT table_name
  1587. FROM information_schema.tables
  1588. WHERE table_name = $1
  1589. AND table_schema = 'public'
  1590. """
  1591. params = {"table_name": table_name.lower()}
  1592. table_exists = await self.query(check_table_sql, list(params.values()))
  1593. if not table_exists:
  1594. logger.info(f"Creating table {table_name}")
  1595. await self.execute(table_info["ddl"])
  1596. logger.info(
  1597. f"Successfully created {table_info['description']}: {table_name}"
  1598. )
  1599. # Create basic indexes for the new table
  1600. try:
  1601. # Create index for id column
  1602. index_name = f"idx_{table_name.lower()}_id"
  1603. create_index_sql = (
  1604. f"CREATE INDEX {index_name} ON {table_name}(id)"
  1605. )
  1606. await self.execute(create_index_sql)
  1607. logger.info(f"Created index {index_name} on table {table_name}")
  1608. # Create composite index for (workspace, id) columns
  1609. composite_index_name = f"idx_{table_name.lower()}_workspace_id"
  1610. create_composite_index_sql = f"CREATE INDEX {composite_index_name} ON {table_name}(workspace, id)"
  1611. await self.execute(create_composite_index_sql)
  1612. logger.info(
  1613. f"Created composite index {composite_index_name} on table {table_name}"
  1614. )
  1615. except Exception as e:
  1616. logger.warning(
  1617. f"Failed to create indexes for table {table_name}: {e}"
  1618. )
  1619. else:
  1620. logger.debug(f"Table {table_name} already exists")
  1621. except Exception as e:
  1622. logger.error(f"Failed to create table {table_name}: {e}")
  1623. async def _create_pagination_indexes(self):
  1624. """Create indexes to optimize pagination queries for LIGHTRAG_DOC_STATUS"""
  1625. indexes = [
  1626. {
  1627. "name": "idx_lightrag_doc_status_workspace_status_updated_at",
  1628. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_status_updated_at ON LIGHTRAG_DOC_STATUS (workspace, status, updated_at DESC)",
  1629. "description": "Composite index for workspace + status + updated_at pagination",
  1630. },
  1631. {
  1632. "name": "idx_lightrag_doc_status_workspace_status_created_at",
  1633. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_status_created_at ON LIGHTRAG_DOC_STATUS (workspace, status, created_at DESC)",
  1634. "description": "Composite index for workspace + status + created_at pagination",
  1635. },
  1636. {
  1637. "name": "idx_lightrag_doc_status_workspace_updated_at",
  1638. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_updated_at ON LIGHTRAG_DOC_STATUS (workspace, updated_at DESC)",
  1639. "description": "Index for workspace + updated_at pagination (all statuses)",
  1640. },
  1641. {
  1642. "name": "idx_lightrag_doc_status_workspace_created_at",
  1643. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_created_at ON LIGHTRAG_DOC_STATUS (workspace, created_at DESC)",
  1644. "description": "Index for workspace + created_at pagination (all statuses)",
  1645. },
  1646. {
  1647. "name": "idx_lightrag_doc_status_workspace_id",
  1648. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_id ON LIGHTRAG_DOC_STATUS (workspace, id)",
  1649. "description": "Index for workspace + id sorting",
  1650. },
  1651. {
  1652. "name": "idx_lightrag_doc_status_workspace_file_path",
  1653. "sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_lightrag_doc_status_workspace_file_path ON LIGHTRAG_DOC_STATUS (workspace, file_path)",
  1654. "description": "Index for workspace + file_path sorting",
  1655. },
  1656. ]
  1657. # Fetch all existing index names in one query instead of N separate checks.
  1658. index_names = [idx["name"] for idx in indexes]
  1659. check_sql = """
  1660. SELECT indexname FROM pg_indexes
  1661. WHERE tablename = 'lightrag_doc_status'
  1662. AND indexname = ANY($1)
  1663. """
  1664. try:
  1665. rows = await self.query(check_sql, [index_names], multirows=True)
  1666. existing_names = {row["indexname"] for row in (rows or [])}
  1667. except asyncpg.PostgresError as e:
  1668. logger.warning(
  1669. f"[{self.workspace}] Failed to query existing pagination indexes "
  1670. f"({type(e).__name__}), will attempt to create all: {e}"
  1671. )
  1672. existing_names = set()
  1673. for index in indexes:
  1674. if index["name"] in existing_names:
  1675. logger.debug(f"Index already exists: {index['name']}")
  1676. continue
  1677. try:
  1678. logger.info(f"Creating pagination index: {index['description']}")
  1679. await self.execute(index["sql"])
  1680. logger.info(f"Successfully created index: {index['name']}")
  1681. except asyncpg.PostgresError as e:
  1682. logger.warning(
  1683. f"Failed to create index {index['name']} ({type(e).__name__}): {e}"
  1684. )
  1685. async def _create_vector_index(self, table_name: str, embedding_dim: int):
  1686. """
  1687. Create vector index for a specific table.
  1688. Args:
  1689. table_name: Name of the table to create index on
  1690. embedding_dim: Embedding dimension for the vector column
  1691. """
  1692. if not self.vector_index_type:
  1693. return
  1694. create_sql = {
  1695. "HNSW": f"""
  1696. CREATE INDEX {{vector_index_name}}
  1697. ON {{table_name}} USING hnsw (content_vector vector_cosine_ops)
  1698. WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
  1699. """,
  1700. "HNSW_HALFVEC": f"""
  1701. CREATE INDEX {{vector_index_name}}
  1702. ON {{table_name}} USING hnsw (content_vector halfvec_cosine_ops)
  1703. WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
  1704. """,
  1705. "IVFFLAT": f"""
  1706. CREATE INDEX {{vector_index_name}}
  1707. ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops)
  1708. WITH (lists = {self.ivfflat_lists})
  1709. """,
  1710. "VCHORDRQ": f"""
  1711. CREATE INDEX {{vector_index_name}}
  1712. ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops)
  1713. {f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""}
  1714. """,
  1715. }
  1716. if self.vector_index_type not in create_sql:
  1717. logger.warning(
  1718. f"Unsupported vector index type: {self.vector_index_type}. "
  1719. "Supported types: HNSW, HNSW_HALFVEC, IVFFLAT, VCHORDRQ"
  1720. )
  1721. return
  1722. k = table_name
  1723. # Use _safe_index_name to avoid PostgreSQL's 63-byte identifier truncation
  1724. index_suffix = f"{self.vector_index_type.lower()}_cosine"
  1725. vector_index_name = _safe_index_name(k, index_suffix)
  1726. check_vector_index_sql = f"""
  1727. SELECT 1 FROM pg_indexes
  1728. WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
  1729. """
  1730. if self.vector_index_type == "HNSW_HALFVEC":
  1731. column_type = "HALFVEC"
  1732. else:
  1733. column_type = "VECTOR"
  1734. try:
  1735. vector_index_exists = await self.query(check_vector_index_sql)
  1736. if not vector_index_exists:
  1737. for suffix in _VECTOR_INDEX_SUFFIXES:
  1738. if suffix == index_suffix:
  1739. continue
  1740. old_name = _safe_index_name(k, suffix)
  1741. await self.execute(f"DROP INDEX IF EXISTS {old_name}")
  1742. alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE {column_type}({embedding_dim})"
  1743. await self.execute(alter_sql)
  1744. logger.debug(f"Ensured vector dimension for {k}")
  1745. logger.info(
  1746. f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
  1747. )
  1748. await self.execute(
  1749. create_sql[self.vector_index_type].format(
  1750. vector_index_name=vector_index_name, table_name=k
  1751. )
  1752. )
  1753. logger.info(
  1754. f"Successfully created vector index {vector_index_name} on table {k}"
  1755. )
  1756. else:
  1757. logger.info(
  1758. f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
  1759. )
  1760. except Exception as e:
  1761. logger.error(f"Failed to create vector index on table {k}, Got: {e}")
  1762. async def query(
  1763. self,
  1764. sql: str,
  1765. params: list[Any] | None = None,
  1766. multirows: bool = False,
  1767. with_age: bool = False,
  1768. graph_name: str | None = None,
  1769. timing_label: str | None = None,
  1770. ) -> dict[str, Any] | None | list[dict[str, Any]]:
  1771. async def _operation(connection: asyncpg.Connection) -> Any:
  1772. prepared_params = tuple(params) if params else ()
  1773. fetch_start = time.perf_counter()
  1774. if prepared_params:
  1775. rows = await connection.fetch(sql, *prepared_params)
  1776. else:
  1777. rows = await connection.fetch(sql)
  1778. fetch_elapsed = time.perf_counter() - fetch_start
  1779. if timing_label:
  1780. performance_timing_log(
  1781. "[%s] connection.fetch completed in %.4fs row_count=%s",
  1782. timing_label,
  1783. fetch_elapsed,
  1784. len(rows),
  1785. )
  1786. conversion_start = time.perf_counter()
  1787. if multirows:
  1788. if rows:
  1789. columns = [col for col in rows[0].keys()]
  1790. converted_rows = [dict(zip(columns, row)) for row in rows]
  1791. else:
  1792. converted_rows = []
  1793. if timing_label:
  1794. conversion_elapsed = time.perf_counter() - conversion_start
  1795. performance_timing_log(
  1796. "[%s] result conversion completed in %.4fs multirows=%s",
  1797. timing_label,
  1798. conversion_elapsed,
  1799. True,
  1800. )
  1801. return converted_rows
  1802. if rows:
  1803. columns = rows[0].keys()
  1804. converted_row = dict(zip(columns, rows[0]))
  1805. else:
  1806. converted_row = None
  1807. if timing_label:
  1808. conversion_elapsed = time.perf_counter() - conversion_start
  1809. performance_timing_log(
  1810. "[%s] result conversion completed in %.4fs multirows=%s",
  1811. timing_label,
  1812. conversion_elapsed,
  1813. False,
  1814. )
  1815. if converted_row is not None:
  1816. return converted_row
  1817. return None
  1818. try:
  1819. return await self._run_with_retry(
  1820. _operation,
  1821. with_age=with_age,
  1822. graph_name=graph_name,
  1823. timing_label=timing_label,
  1824. )
  1825. except Exception as e:
  1826. logger.error(f"PostgreSQL database, error:{e}")
  1827. raise
  1828. async def check_table_exists(self, table_name: str) -> bool:
  1829. """Check if a table exists in PostgreSQL database
  1830. Args:
  1831. table_name: Name of the table to check
  1832. Returns:
  1833. bool: True if table exists, False otherwise
  1834. """
  1835. query = """
  1836. SELECT EXISTS (
  1837. SELECT FROM information_schema.tables
  1838. WHERE table_name = $1
  1839. )
  1840. """
  1841. result = await self.query(query, [table_name.lower()])
  1842. return result.get("exists", False) if result else False
  1843. async def execute(
  1844. self,
  1845. sql: str,
  1846. data: dict[str, Any] | None = None,
  1847. upsert: bool = False,
  1848. ignore_if_exists: bool = False,
  1849. with_age: bool = False,
  1850. graph_name: str | None = None,
  1851. timing_label: str | None = None,
  1852. ):
  1853. async def _operation(connection: asyncpg.Connection) -> Any:
  1854. prepared_values = tuple(data.values()) if data else ()
  1855. execute_start = time.perf_counter()
  1856. try:
  1857. if not data:
  1858. result = await connection.execute(sql)
  1859. else:
  1860. result = await connection.execute(sql, *prepared_values)
  1861. except (
  1862. asyncpg.exceptions.UniqueViolationError,
  1863. asyncpg.exceptions.DuplicateTableError,
  1864. asyncpg.exceptions.DuplicateObjectError,
  1865. asyncpg.exceptions.InvalidSchemaNameError,
  1866. ) as e:
  1867. if ignore_if_exists:
  1868. logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
  1869. result = None
  1870. elif upsert:
  1871. logger.info(
  1872. "PostgreSQL, duplicate detected but treated as upsert success: %r",
  1873. e,
  1874. )
  1875. result = None
  1876. else:
  1877. raise
  1878. except Exception:
  1879. if timing_label:
  1880. performance_timing_log(
  1881. "[%s] connection.execute failed after %.4fs",
  1882. timing_label,
  1883. time.perf_counter() - execute_start,
  1884. )
  1885. raise
  1886. if timing_label:
  1887. performance_timing_log(
  1888. "[%s] connection.execute completed in %.4fs result=%s",
  1889. timing_label,
  1890. time.perf_counter() - execute_start,
  1891. result,
  1892. )
  1893. return result
  1894. try:
  1895. await self._run_with_retry(
  1896. _operation,
  1897. with_age=with_age,
  1898. graph_name=graph_name,
  1899. timing_label=timing_label,
  1900. )
  1901. except Exception as e:
  1902. logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
  1903. raise
  1904. class ClientManager:
  1905. """Manage the process-wide PostgreSQL client pool shared by PG storages.
  1906. The first successful initialization defines the pool configuration for the
  1907. lifetime of the shared client. Reusing the pool with a different vector
  1908. storage setup is not supported and will raise a fail-fast error.
  1909. """
  1910. _instances: dict[str, Any] = {
  1911. "db": None,
  1912. "ref_count": 0,
  1913. "vector_signature": None,
  1914. }
  1915. _lock = asyncio.Lock()
  1916. @staticmethod
  1917. def get_config(vector_storage: str | None = None) -> dict[str, Any]:
  1918. config = configparser.ConfigParser()
  1919. config.read("config.ini", "utf-8")
  1920. return {
  1921. "host": os.environ.get(
  1922. "POSTGRES_HOST",
  1923. config.get("postgres", "host", fallback="localhost"),
  1924. ),
  1925. "port": os.environ.get(
  1926. "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
  1927. ),
  1928. "user": os.environ.get(
  1929. "POSTGRES_USER", config.get("postgres", "user", fallback="postgres")
  1930. ),
  1931. "password": os.environ.get(
  1932. "POSTGRES_PASSWORD",
  1933. config.get("postgres", "password", fallback=None),
  1934. ),
  1935. "database": os.environ.get(
  1936. "POSTGRES_DATABASE",
  1937. config.get("postgres", "database", fallback="postgres"),
  1938. ),
  1939. "workspace": os.environ.get(
  1940. "POSTGRES_WORKSPACE",
  1941. config.get("postgres", "workspace", fallback=None),
  1942. ),
  1943. "max_connections": os.environ.get(
  1944. "POSTGRES_MAX_CONNECTIONS",
  1945. config.get("postgres", "max_connections", fallback=50),
  1946. ),
  1947. # SSL configuration
  1948. "ssl_mode": os.environ.get(
  1949. "POSTGRES_SSL_MODE",
  1950. config.get("postgres", "ssl_mode", fallback=None),
  1951. ),
  1952. "ssl_cert": os.environ.get(
  1953. "POSTGRES_SSL_CERT",
  1954. config.get("postgres", "ssl_cert", fallback=None),
  1955. ),
  1956. "ssl_key": os.environ.get(
  1957. "POSTGRES_SSL_KEY",
  1958. config.get("postgres", "ssl_key", fallback=None),
  1959. ),
  1960. "ssl_root_cert": os.environ.get(
  1961. "POSTGRES_SSL_ROOT_CERT",
  1962. config.get("postgres", "ssl_root_cert", fallback=None),
  1963. ),
  1964. "ssl_crl": os.environ.get(
  1965. "POSTGRES_SSL_CRL",
  1966. config.get("postgres", "ssl_crl", fallback=None),
  1967. ),
  1968. # Vector configuration: derived from the vector storage backend in use.
  1969. # PGVectorStorage requires pgvector; all other backends do not.
  1970. "enable_vector": vector_storage == "PGVectorStorage"
  1971. if vector_storage is not None
  1972. else True,
  1973. "vector_index_type": os.environ.get(
  1974. "POSTGRES_VECTOR_INDEX_TYPE",
  1975. config.get("postgres", "vector_index_type", fallback="HNSW"),
  1976. ),
  1977. "hnsw_m": int(
  1978. os.environ.get(
  1979. "POSTGRES_HNSW_M",
  1980. config.get("postgres", "hnsw_m", fallback="16"),
  1981. )
  1982. ),
  1983. "hnsw_ef": int(
  1984. os.environ.get(
  1985. "POSTGRES_HNSW_EF",
  1986. config.get("postgres", "hnsw_ef", fallback="64"),
  1987. )
  1988. ),
  1989. "ivfflat_lists": int(
  1990. os.environ.get(
  1991. "POSTGRES_IVFFLAT_LISTS",
  1992. config.get("postgres", "ivfflat_lists", fallback="100"),
  1993. )
  1994. ),
  1995. "vchordrq_build_options": os.environ.get(
  1996. "POSTGRES_VCHORDRQ_BUILD_OPTIONS",
  1997. config.get("postgres", "vchordrq_build_options", fallback=""),
  1998. ),
  1999. "vchordrq_probes": os.environ.get(
  2000. "POSTGRES_VCHORDRQ_PROBES",
  2001. config.get("postgres", "vchordrq_probes", fallback=""),
  2002. ),
  2003. "vchordrq_epsilon": float(
  2004. os.environ.get(
  2005. "POSTGRES_VCHORDRQ_EPSILON",
  2006. config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
  2007. )
  2008. ),
  2009. # Server settings for Supabase
  2010. "server_settings": os.environ.get(
  2011. "POSTGRES_SERVER_SETTINGS",
  2012. config.get("postgres", "server_options", fallback=None),
  2013. ),
  2014. "statement_cache_size": os.environ.get(
  2015. "POSTGRES_STATEMENT_CACHE_SIZE",
  2016. config.get("postgres", "statement_cache_size", fallback=None),
  2017. ),
  2018. # Connection retry configuration
  2019. "connection_retry_attempts": min(
  2020. 100, # Increased from 10 to 100 for long-running operations
  2021. int(
  2022. os.environ.get(
  2023. "POSTGRES_CONNECTION_RETRIES",
  2024. config.get("postgres", "connection_retries", fallback=10),
  2025. )
  2026. ),
  2027. ),
  2028. "connection_retry_backoff": min(
  2029. 300.0, # Increased from 5.0 to 300.0 (5 minutes) for PG switchover scenarios
  2030. float(
  2031. os.environ.get(
  2032. "POSTGRES_CONNECTION_RETRY_BACKOFF",
  2033. config.get(
  2034. "postgres", "connection_retry_backoff", fallback=3.0
  2035. ),
  2036. )
  2037. ),
  2038. ),
  2039. "connection_retry_backoff_max": min(
  2040. 600.0, # Increased from 60.0 to 600.0 (10 minutes) for PG switchover scenarios
  2041. float(
  2042. os.environ.get(
  2043. "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
  2044. config.get(
  2045. "postgres",
  2046. "connection_retry_backoff_max",
  2047. fallback=30.0,
  2048. ),
  2049. )
  2050. ),
  2051. ),
  2052. "pool_close_timeout": min(
  2053. 30.0,
  2054. float(
  2055. os.environ.get(
  2056. "POSTGRES_POOL_CLOSE_TIMEOUT",
  2057. config.get("postgres", "pool_close_timeout", fallback=5.0),
  2058. )
  2059. ),
  2060. ),
  2061. }
  2062. @classmethod
  2063. def _build_vector_signature(
  2064. cls, config: dict[str, Any], vector_storage: str | None
  2065. ) -> dict[str, Any]:
  2066. signature = {
  2067. "vector_storage": vector_storage,
  2068. "enable_vector": config["enable_vector"],
  2069. }
  2070. if config["enable_vector"]:
  2071. signature.update(
  2072. {
  2073. "vector_index_type": config["vector_index_type"],
  2074. "hnsw_m": config["hnsw_m"],
  2075. "hnsw_ef": config["hnsw_ef"],
  2076. "ivfflat_lists": config["ivfflat_lists"],
  2077. "vchordrq_build_options": config["vchordrq_build_options"],
  2078. "vchordrq_probes": config["vchordrq_probes"],
  2079. "vchordrq_epsilon": config["vchordrq_epsilon"],
  2080. }
  2081. )
  2082. return signature
  2083. @classmethod
  2084. def _assert_compatible_vector_signature(
  2085. cls, requested_signature: dict[str, Any]
  2086. ) -> None:
  2087. active_signature = cls._instances["vector_signature"]
  2088. if active_signature is None or active_signature == requested_signature:
  2089. return
  2090. raise RuntimeError(
  2091. "PostgreSQL client pool is process-wide and already initialized with "
  2092. f"vector settings {active_signature}. Received incompatible settings "
  2093. f"{requested_signature}. Multiple LightRAG instances with different "
  2094. "PostgreSQL/vector storage configurations are not supported in the "
  2095. "same process."
  2096. )
  2097. @classmethod
  2098. async def get_client(cls, vector_storage: str | None = None) -> PostgreSQLDB:
  2099. """Return the shared PostgreSQL client for all PG storages in this process.
  2100. The first caller fixes the vector-related pool configuration. Later calls
  2101. must provide a compatible vector storage setup or a RuntimeError is raised.
  2102. """
  2103. async with cls._lock:
  2104. config = ClientManager.get_config(vector_storage=vector_storage)
  2105. requested_signature = cls._build_vector_signature(config, vector_storage)
  2106. if cls._instances["db"] is None:
  2107. db = PostgreSQLDB(config)
  2108. await db.initdb()
  2109. await db.check_tables()
  2110. cls._instances["db"] = db
  2111. cls._instances["ref_count"] = 0
  2112. cls._instances["vector_signature"] = requested_signature
  2113. else:
  2114. cls._assert_compatible_vector_signature(requested_signature)
  2115. cls._instances["ref_count"] += 1
  2116. return cls._instances["db"]
  2117. @classmethod
  2118. async def release_client(cls, db: PostgreSQLDB):
  2119. async with cls._lock:
  2120. if db is not None:
  2121. if db is cls._instances["db"]:
  2122. cls._instances["ref_count"] -= 1
  2123. if cls._instances["ref_count"] == 0:
  2124. if db.pool is not None:
  2125. await db.pool.close()
  2126. logger.info("Closed PostgreSQL database connection pool")
  2127. cls._instances["db"] = None
  2128. cls._instances["vector_signature"] = None
  2129. else:
  2130. if db.pool is not None:
  2131. await db.pool.close()
  2132. @final
  2133. @dataclass
  2134. class PGKVStorage(BaseKVStorage):
  2135. db: PostgreSQLDB = field(default=None)
  2136. def __post_init__(self):
  2137. self._max_batch_size = 200 # DB batch size, independent of embedding batch size
  2138. async def initialize(self):
  2139. async with get_data_init_lock():
  2140. if self.db is None:
  2141. self.db = await ClientManager.get_client(
  2142. vector_storage=self.global_config.get("vector_storage")
  2143. )
  2144. # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
  2145. if self.db.workspace:
  2146. # Use PostgreSQLDB's workspace (highest priority)
  2147. logger.info(
  2148. f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
  2149. )
  2150. self.workspace = self.db.workspace
  2151. elif hasattr(self, "workspace") and self.workspace:
  2152. # Use storage class's workspace (medium priority)
  2153. pass
  2154. else:
  2155. # Use "default" for compatibility (lowest priority)
  2156. self.workspace = "default"
  2157. async def finalize(self):
  2158. if self.db is not None:
  2159. await ClientManager.release_client(self.db)
  2160. self.db = None
  2161. ################ QUERY METHODS ################
  2162. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  2163. """Get data by id."""
  2164. sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
  2165. params = {"workspace": self.workspace, "id": id}
  2166. response = await self.db.query(sql, list(params.values()))
  2167. if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
  2168. # Parse llm_cache_list JSON string back to list
  2169. llm_cache_list = response.get("llm_cache_list", [])
  2170. if isinstance(llm_cache_list, str):
  2171. try:
  2172. llm_cache_list = json.loads(llm_cache_list)
  2173. except json.JSONDecodeError:
  2174. llm_cache_list = []
  2175. response["llm_cache_list"] = llm_cache_list
  2176. # Parse heading JSON string back to dict; normalize None/missing to {}
  2177. heading = response.get("heading")
  2178. if isinstance(heading, str):
  2179. try:
  2180. heading = json.loads(heading)
  2181. except json.JSONDecodeError:
  2182. heading = {}
  2183. if not isinstance(heading, dict):
  2184. heading = {}
  2185. response["heading"] = heading
  2186. # Parse sidecar JSON string back to dict; normalize None/missing to {}
  2187. sidecar = response.get("sidecar")
  2188. if isinstance(sidecar, str):
  2189. try:
  2190. sidecar = json.loads(sidecar)
  2191. except json.JSONDecodeError:
  2192. sidecar = {}
  2193. if not isinstance(sidecar, dict):
  2194. sidecar = {}
  2195. response["sidecar"] = sidecar
  2196. create_time = response.get("create_time", 0)
  2197. update_time = response.get("update_time", 0)
  2198. response["create_time"] = create_time
  2199. response["update_time"] = create_time if update_time == 0 else update_time
  2200. if response and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
  2201. # Parse chunk_options JSON string back to dict; normalize None/missing to {}
  2202. chunk_options = response.get("chunk_options")
  2203. if isinstance(chunk_options, str):
  2204. try:
  2205. chunk_options = json.loads(chunk_options)
  2206. except json.JSONDecodeError:
  2207. chunk_options = {}
  2208. if not isinstance(chunk_options, dict):
  2209. chunk_options = {}
  2210. response["chunk_options"] = chunk_options
  2211. # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
  2212. if response and is_namespace(
  2213. self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
  2214. ):
  2215. create_time = response.get("create_time", 0)
  2216. update_time = response.get("update_time", 0)
  2217. # Parse queryparam JSON string back to dict
  2218. queryparam = response.get("queryparam")
  2219. if isinstance(queryparam, str):
  2220. try:
  2221. queryparam = json.loads(queryparam)
  2222. except json.JSONDecodeError:
  2223. queryparam = None
  2224. # Map field names for compatibility (mode field removed)
  2225. response = {
  2226. **response,
  2227. "return": response.get("return_value", ""),
  2228. "cache_type": response.get("cache_type"),
  2229. "original_prompt": response.get("original_prompt", ""),
  2230. "chunk_id": response.get("chunk_id"),
  2231. "queryparam": queryparam,
  2232. "create_time": create_time,
  2233. "update_time": create_time if update_time == 0 else update_time,
  2234. }
  2235. # Special handling for FULL_ENTITIES namespace
  2236. if response and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
  2237. # Parse entity_names JSON string back to list
  2238. entity_names = response.get("entity_names", [])
  2239. if isinstance(entity_names, str):
  2240. try:
  2241. entity_names = json.loads(entity_names)
  2242. except json.JSONDecodeError:
  2243. entity_names = []
  2244. response["entity_names"] = entity_names
  2245. create_time = response.get("create_time", 0)
  2246. update_time = response.get("update_time", 0)
  2247. response["create_time"] = create_time
  2248. response["update_time"] = create_time if update_time == 0 else update_time
  2249. # Special handling for FULL_RELATIONS namespace
  2250. if response and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
  2251. # Parse relation_pairs JSON string back to list
  2252. relation_pairs = response.get("relation_pairs", [])
  2253. if isinstance(relation_pairs, str):
  2254. try:
  2255. relation_pairs = json.loads(relation_pairs)
  2256. except json.JSONDecodeError:
  2257. relation_pairs = []
  2258. response["relation_pairs"] = relation_pairs
  2259. create_time = response.get("create_time", 0)
  2260. update_time = response.get("update_time", 0)
  2261. response["create_time"] = create_time
  2262. response["update_time"] = create_time if update_time == 0 else update_time
  2263. # Special handling for ENTITY_CHUNKS namespace
  2264. if response and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
  2265. # Parse chunk_ids JSON string back to list
  2266. chunk_ids = response.get("chunk_ids", [])
  2267. if isinstance(chunk_ids, str):
  2268. try:
  2269. chunk_ids = json.loads(chunk_ids)
  2270. except json.JSONDecodeError:
  2271. chunk_ids = []
  2272. response["chunk_ids"] = chunk_ids
  2273. create_time = response.get("create_time", 0)
  2274. update_time = response.get("update_time", 0)
  2275. response["create_time"] = create_time
  2276. response["update_time"] = create_time if update_time == 0 else update_time
  2277. # Special handling for RELATION_CHUNKS namespace
  2278. if response and is_namespace(
  2279. self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS
  2280. ):
  2281. # Parse chunk_ids JSON string back to list
  2282. chunk_ids = response.get("chunk_ids", [])
  2283. if isinstance(chunk_ids, str):
  2284. try:
  2285. chunk_ids = json.loads(chunk_ids)
  2286. except json.JSONDecodeError:
  2287. chunk_ids = []
  2288. response["chunk_ids"] = chunk_ids
  2289. create_time = response.get("create_time", 0)
  2290. update_time = response.get("update_time", 0)
  2291. response["create_time"] = create_time
  2292. response["update_time"] = create_time if update_time == 0 else update_time
  2293. return response if response else None
  2294. # Query by id
  2295. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  2296. """Get data by ids"""
  2297. if not ids:
  2298. return []
  2299. sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
  2300. params = {"workspace": self.workspace, "ids": ids}
  2301. results = await self.db.query(sql, list(params.values()), multirows=True)
  2302. def _order_results(
  2303. rows: list[dict[str, Any]] | None,
  2304. ) -> list[dict[str, Any] | None]:
  2305. """Preserve the caller requested ordering for bulk id lookups."""
  2306. if not rows:
  2307. return [None for _ in ids]
  2308. id_map: dict[str, dict[str, Any]] = {}
  2309. for row in rows:
  2310. if row is None:
  2311. continue
  2312. row_id = row.get("id")
  2313. if row_id is not None:
  2314. id_map[str(row_id)] = row
  2315. ordered: list[dict[str, Any] | None] = []
  2316. for requested_id in ids:
  2317. ordered.append(id_map.get(str(requested_id)))
  2318. return ordered
  2319. if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
  2320. # Parse llm_cache_list / heading / sidecar JSON strings for each result
  2321. for result in results:
  2322. llm_cache_list = result.get("llm_cache_list", [])
  2323. if isinstance(llm_cache_list, str):
  2324. try:
  2325. llm_cache_list = json.loads(llm_cache_list)
  2326. except json.JSONDecodeError:
  2327. llm_cache_list = []
  2328. result["llm_cache_list"] = llm_cache_list
  2329. heading = result.get("heading")
  2330. if isinstance(heading, str):
  2331. try:
  2332. heading = json.loads(heading)
  2333. except json.JSONDecodeError:
  2334. heading = {}
  2335. if not isinstance(heading, dict):
  2336. heading = {}
  2337. result["heading"] = heading
  2338. sidecar = result.get("sidecar")
  2339. if isinstance(sidecar, str):
  2340. try:
  2341. sidecar = json.loads(sidecar)
  2342. except json.JSONDecodeError:
  2343. sidecar = {}
  2344. if not isinstance(sidecar, dict):
  2345. sidecar = {}
  2346. result["sidecar"] = sidecar
  2347. create_time = result.get("create_time", 0)
  2348. update_time = result.get("update_time", 0)
  2349. result["create_time"] = create_time
  2350. result["update_time"] = create_time if update_time == 0 else update_time
  2351. if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
  2352. for result in results:
  2353. chunk_options = result.get("chunk_options")
  2354. if isinstance(chunk_options, str):
  2355. try:
  2356. chunk_options = json.loads(chunk_options)
  2357. except json.JSONDecodeError:
  2358. chunk_options = {}
  2359. if not isinstance(chunk_options, dict):
  2360. chunk_options = {}
  2361. result["chunk_options"] = chunk_options
  2362. # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
  2363. if results and is_namespace(
  2364. self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
  2365. ):
  2366. processed_results = []
  2367. for row in results:
  2368. create_time = row.get("create_time", 0)
  2369. update_time = row.get("update_time", 0)
  2370. # Parse queryparam JSON string back to dict
  2371. queryparam = row.get("queryparam")
  2372. if isinstance(queryparam, str):
  2373. try:
  2374. queryparam = json.loads(queryparam)
  2375. except json.JSONDecodeError:
  2376. queryparam = None
  2377. # Map field names for compatibility (mode field removed)
  2378. processed_row = {
  2379. **row,
  2380. "return": row.get("return_value", ""),
  2381. "cache_type": row.get("cache_type"),
  2382. "original_prompt": row.get("original_prompt", ""),
  2383. "chunk_id": row.get("chunk_id"),
  2384. "queryparam": queryparam,
  2385. "create_time": create_time,
  2386. "update_time": create_time if update_time == 0 else update_time,
  2387. }
  2388. processed_results.append(processed_row)
  2389. return _order_results(processed_results)
  2390. # Special handling for FULL_ENTITIES namespace
  2391. if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
  2392. for result in results:
  2393. # Parse entity_names JSON string back to list
  2394. entity_names = result.get("entity_names", [])
  2395. if isinstance(entity_names, str):
  2396. try:
  2397. entity_names = json.loads(entity_names)
  2398. except json.JSONDecodeError:
  2399. entity_names = []
  2400. result["entity_names"] = entity_names
  2401. create_time = result.get("create_time", 0)
  2402. update_time = result.get("update_time", 0)
  2403. result["create_time"] = create_time
  2404. result["update_time"] = create_time if update_time == 0 else update_time
  2405. # Special handling for FULL_RELATIONS namespace
  2406. if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
  2407. for result in results:
  2408. # Parse relation_pairs JSON string back to list
  2409. relation_pairs = result.get("relation_pairs", [])
  2410. if isinstance(relation_pairs, str):
  2411. try:
  2412. relation_pairs = json.loads(relation_pairs)
  2413. except json.JSONDecodeError:
  2414. relation_pairs = []
  2415. result["relation_pairs"] = relation_pairs
  2416. create_time = result.get("create_time", 0)
  2417. update_time = result.get("update_time", 0)
  2418. result["create_time"] = create_time
  2419. result["update_time"] = create_time if update_time == 0 else update_time
  2420. # Special handling for ENTITY_CHUNKS namespace
  2421. if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
  2422. for result in results:
  2423. # Parse chunk_ids JSON string back to list
  2424. chunk_ids = result.get("chunk_ids", [])
  2425. if isinstance(chunk_ids, str):
  2426. try:
  2427. chunk_ids = json.loads(chunk_ids)
  2428. except json.JSONDecodeError:
  2429. chunk_ids = []
  2430. result["chunk_ids"] = chunk_ids
  2431. create_time = result.get("create_time", 0)
  2432. update_time = result.get("update_time", 0)
  2433. result["create_time"] = create_time
  2434. result["update_time"] = create_time if update_time == 0 else update_time
  2435. # Special handling for RELATION_CHUNKS namespace
  2436. if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
  2437. for result in results:
  2438. # Parse chunk_ids JSON string back to list
  2439. chunk_ids = result.get("chunk_ids", [])
  2440. if isinstance(chunk_ids, str):
  2441. try:
  2442. chunk_ids = json.loads(chunk_ids)
  2443. except json.JSONDecodeError:
  2444. chunk_ids = []
  2445. result["chunk_ids"] = chunk_ids
  2446. create_time = result.get("create_time", 0)
  2447. update_time = result.get("update_time", 0)
  2448. result["create_time"] = create_time
  2449. result["update_time"] = create_time if update_time == 0 else update_time
  2450. return _order_results(results)
  2451. async def filter_keys(self, keys: set[str]) -> set[str]:
  2452. """Filter out duplicated content"""
  2453. if not keys:
  2454. return set()
  2455. table_name = namespace_to_table_name(self.namespace)
  2456. sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
  2457. params = {"workspace": self.workspace, "ids": list(keys)}
  2458. try:
  2459. res = await self.db.query(sql, list(params.values()), multirows=True)
  2460. if res:
  2461. exist_keys = [key["id"] for key in res]
  2462. else:
  2463. exist_keys = []
  2464. new_keys = set([s for s in keys if s not in exist_keys])
  2465. return new_keys
  2466. except Exception as e:
  2467. logger.error(
  2468. f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
  2469. )
  2470. raise
  2471. ################ INSERT METHODS ################
  2472. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  2473. logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
  2474. if not data:
  2475. return
  2476. timing_label = f"{self.workspace} PGKVStorage.upsert[{self.namespace}]"
  2477. total_start = time.perf_counter()
  2478. performance_timing_log(
  2479. "[%s] start records=%s max_batch_size=%s",
  2480. timing_label,
  2481. len(data),
  2482. self._max_batch_size,
  2483. )
  2484. batch_values: list[tuple] = []
  2485. upsert_sql = ""
  2486. batch_values_build_start = time.perf_counter()
  2487. if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
  2488. upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
  2489. # Get current UTC time and convert to naive datetime for database storage
  2490. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  2491. for i, (k, v) in enumerate(data.items(), start=1):
  2492. # Tuple order must match SQL: (workspace, id, tokens, chunk_order_index,
  2493. # full_doc_id, content, file_path, llm_cache_list, heading, sidecar,
  2494. # create_time, update_time)
  2495. batch_values.append(
  2496. (
  2497. self.workspace,
  2498. k,
  2499. v["tokens"],
  2500. v["chunk_order_index"],
  2501. v["full_doc_id"],
  2502. v["content"],
  2503. v["file_path"],
  2504. json.dumps(v.get("llm_cache_list", [])),
  2505. json.dumps(v.get("heading") or {}),
  2506. json.dumps(v.get("sidecar") or {}),
  2507. current_time,
  2508. current_time,
  2509. )
  2510. )
  2511. await _cooperative_yield(i)
  2512. elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
  2513. upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
  2514. for i, (k, v) in enumerate(data.items(), start=1):
  2515. # Tuple order must match SQL: (id, content, doc_name, workspace,
  2516. # sidecar_location, parse_format, content_hash, process_options,
  2517. # chunk_options, parse_engine)
  2518. #
  2519. # All pipeline-derived fields pass through untouched so the
  2520. # SQL-level COALESCE guard in upsert_doc_full can distinguish
  2521. # "caller did not supply" (None/'') from "caller supplied a
  2522. # real value". The 'raw' default for parse_format is provided
  2523. # by the column DDL on initial insert; do NOT default it here
  2524. # or the COALESCE guard never triggers on subsequent partial
  2525. # writes.
  2526. batch_values.append(
  2527. (
  2528. k,
  2529. v["content"],
  2530. v.get("file_path", ""),
  2531. self.workspace,
  2532. v.get("sidecar_location"),
  2533. v.get("parse_format"),
  2534. v.get("content_hash"),
  2535. v.get("process_options"),
  2536. json.dumps(v.get("chunk_options") or {}),
  2537. v.get("parse_engine"),
  2538. )
  2539. )
  2540. await _cooperative_yield(i)
  2541. elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
  2542. upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
  2543. for i, (k, v) in enumerate(data.items(), start=1):
  2544. # Tuple order must match SQL: (workspace, id, original_prompt, return_value,
  2545. # chunk_id, cache_type, queryparam)
  2546. batch_values.append(
  2547. (
  2548. self.workspace,
  2549. k,
  2550. v["original_prompt"],
  2551. v["return"],
  2552. v.get("chunk_id"),
  2553. v.get("cache_type", "extract"),
  2554. json.dumps(v.get("queryparam"))
  2555. if v.get("queryparam")
  2556. else None,
  2557. )
  2558. )
  2559. await _cooperative_yield(i)
  2560. elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
  2561. upsert_sql = SQL_TEMPLATES["upsert_full_entities"]
  2562. # Get current UTC time and convert to naive datetime for database storage
  2563. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  2564. for i, (k, v) in enumerate(data.items(), start=1):
  2565. # Tuple order must match SQL: (workspace, id, entity_names, count,
  2566. # create_time, update_time)
  2567. batch_values.append(
  2568. (
  2569. self.workspace,
  2570. k,
  2571. json.dumps(v["entity_names"]),
  2572. v["count"],
  2573. current_time,
  2574. current_time,
  2575. )
  2576. )
  2577. await _cooperative_yield(i)
  2578. elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
  2579. upsert_sql = SQL_TEMPLATES["upsert_full_relations"]
  2580. # Get current UTC time and convert to naive datetime for database storage
  2581. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  2582. for i, (k, v) in enumerate(data.items(), start=1):
  2583. # Tuple order must match SQL: (workspace, id, relation_pairs, count,
  2584. # create_time, update_time)
  2585. batch_values.append(
  2586. (
  2587. self.workspace,
  2588. k,
  2589. json.dumps(v["relation_pairs"]),
  2590. v["count"],
  2591. current_time,
  2592. current_time,
  2593. )
  2594. )
  2595. await _cooperative_yield(i)
  2596. elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
  2597. upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"]
  2598. # Get current UTC time and convert to naive datetime for database storage
  2599. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  2600. for i, (k, v) in enumerate(data.items(), start=1):
  2601. # Tuple order must match SQL: (workspace, id, chunk_ids, count,
  2602. # create_time, update_time)
  2603. batch_values.append(
  2604. (
  2605. self.workspace,
  2606. k,
  2607. json.dumps(v["chunk_ids"]),
  2608. v["count"],
  2609. current_time,
  2610. current_time,
  2611. )
  2612. )
  2613. await _cooperative_yield(i)
  2614. elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
  2615. upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"]
  2616. # Get current UTC time and convert to naive datetime for database storage
  2617. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  2618. for i, (k, v) in enumerate(data.items(), start=1):
  2619. # Tuple order must match SQL: (workspace, id, chunk_ids, count,
  2620. # create_time, update_time)
  2621. batch_values.append(
  2622. (
  2623. self.workspace,
  2624. k,
  2625. json.dumps(v["chunk_ids"]),
  2626. v["count"],
  2627. current_time,
  2628. current_time,
  2629. )
  2630. )
  2631. await _cooperative_yield(i)
  2632. else:
  2633. logger.error(f"Unknown namespace: {self.namespace}")
  2634. raise ValueError(f"Unknown namespace: {self.namespace}")
  2635. # upsert_sql is always set here; unknown namespace raises ValueError above
  2636. performance_timing_log(
  2637. "[%s] batch_values build completed in %.4fs records=%s%s",
  2638. timing_label,
  2639. time.perf_counter() - batch_values_build_start,
  2640. len(batch_values),
  2641. _timing_details_suffix(namespace=self.namespace),
  2642. )
  2643. if batch_values:
  2644. # Split into sub-batches to prevent database overload
  2645. num_batches = (
  2646. len(batch_values) + self._max_batch_size - 1
  2647. ) // self._max_batch_size
  2648. for batch_index, i in enumerate(
  2649. range(0, len(batch_values), self._max_batch_size), start=1
  2650. ):
  2651. sub_batch = batch_values[i : i + self._max_batch_size]
  2652. async def _batch_upsert(
  2653. connection: asyncpg.Connection,
  2654. _sql: str = upsert_sql,
  2655. _data: list[tuple] = sub_batch,
  2656. _batch_index: int = batch_index,
  2657. _num_batches: int = num_batches,
  2658. ) -> None:
  2659. execute_start = time.perf_counter()
  2660. await connection.executemany(_sql, _data)
  2661. performance_timing_log(
  2662. "[%s] sub-batch %s/%s executemany completed in %.4fs batch_size=%s",
  2663. timing_label,
  2664. _batch_index,
  2665. _num_batches,
  2666. time.perf_counter() - execute_start,
  2667. len(_data),
  2668. )
  2669. await self.db._run_with_retry(_batch_upsert, timing_label=timing_label)
  2670. logger.debug(
  2671. f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace} "
  2672. f"in {num_batches} sub-batches"
  2673. )
  2674. performance_timing_log(
  2675. "[%s] total complete in %.4fs records=%s",
  2676. timing_label,
  2677. time.perf_counter() - total_start,
  2678. len(batch_values),
  2679. )
  2680. async def index_done_callback(self) -> None:
  2681. # PG handles persistence automatically
  2682. pass
  2683. async def is_empty(self) -> bool:
  2684. """Check if the storage is empty for the current workspace and namespace
  2685. Returns:
  2686. bool: True if storage is empty, False otherwise
  2687. """
  2688. table_name = namespace_to_table_name(self.namespace)
  2689. if not table_name:
  2690. logger.error(
  2691. f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
  2692. )
  2693. return True
  2694. sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
  2695. try:
  2696. result = await self.db.query(sql, [self.workspace])
  2697. return not result.get("has_data", False) if result else True
  2698. except Exception as e:
  2699. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  2700. return True
  2701. async def delete(self, ids: list[str]) -> None:
  2702. """Delete specific records from storage by their IDs
  2703. Args:
  2704. ids (list[str]): List of document IDs to be deleted from storage
  2705. Returns:
  2706. None
  2707. """
  2708. if not ids:
  2709. return
  2710. table_name = namespace_to_table_name(self.namespace)
  2711. if not table_name:
  2712. logger.error(
  2713. f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
  2714. )
  2715. return
  2716. delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
  2717. try:
  2718. await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
  2719. logger.debug(
  2720. f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
  2721. )
  2722. except Exception as e:
  2723. logger.error(
  2724. f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
  2725. )
  2726. async def drop(self) -> dict[str, str]:
  2727. """Drop the storage"""
  2728. try:
  2729. table_name = namespace_to_table_name(self.namespace)
  2730. if not table_name:
  2731. return {
  2732. "status": "error",
  2733. "message": f"Unknown namespace: {self.namespace}",
  2734. }
  2735. drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
  2736. table_name=table_name
  2737. )
  2738. await self.db.execute(drop_sql, {"workspace": self.workspace})
  2739. return {"status": "success", "message": "data dropped"}
  2740. except Exception as e:
  2741. return {"status": "error", "message": str(e)}
  2742. @dataclass
  2743. class _PendingPGVectorDoc:
  2744. """Buffered PG vector upsert awaiting embedding and batched flush.
  2745. ``vector`` is stored as a numpy ndarray (typically float32 from the
  2746. embedding function) once embedded; pgvector's asyncpg codec accepts
  2747. ndarray directly so no per-flush conversion is needed.
  2748. """
  2749. item: dict[str, Any]
  2750. created_at: datetime.datetime
  2751. vector: np.ndarray | None = None
  2752. @final
  2753. @dataclass
  2754. class PGVectorStorage(BaseVectorStorage):
  2755. db: PostgreSQLDB | None = field(default=None)
  2756. def __post_init__(self):
  2757. self._validate_embedding_func()
  2758. self._max_batch_size = self.global_config["embedding_batch_num"]
  2759. config = self.global_config.get("vector_db_storage_cls_kwargs", {})
  2760. cosine_threshold = config.get("cosine_better_than_threshold")
  2761. if cosine_threshold is None:
  2762. raise ValueError(
  2763. "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
  2764. )
  2765. self.cosine_better_than_threshold = cosine_threshold
  2766. # Generate model suffix for table isolation
  2767. self.model_suffix = self._generate_collection_suffix()
  2768. # Get base table name
  2769. base_table = namespace_to_table_name(self.namespace)
  2770. if not base_table:
  2771. raise ValueError(f"Unknown namespace: {self.namespace}")
  2772. # New table name (with suffix)
  2773. # Ensure model_suffix is not empty before appending
  2774. if self.model_suffix:
  2775. self.table_name = f"{base_table}_{self.model_suffix}"
  2776. logger.info(f"PostgreSQL table: {self.table_name}")
  2777. else:
  2778. # Fallback: use base table name if model_suffix is unavailable
  2779. self.table_name = base_table
  2780. logger.warning(
  2781. f"PostgreSQL table: {self.table_name} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation."
  2782. )
  2783. # Legacy table name (without suffix, for migration)
  2784. self.legacy_table_name = base_table
  2785. # Validate table name length (PostgreSQL identifier limit is 63 characters)
  2786. if len(self.table_name) > PG_MAX_IDENTIFIER_LENGTH:
  2787. raise ValueError(
  2788. f"PostgreSQL table name exceeds {PG_MAX_IDENTIFIER_LENGTH} character limit: '{self.table_name}' "
  2789. f"(length: {len(self.table_name)}). "
  2790. f"Consider using a shorter embedding model name or workspace name."
  2791. )
  2792. # Pending buffers: upsert() and delete() queue work here until
  2793. # _flush_pending_vector_ops() runs from index_done_callback() /
  2794. # finalize(). Mirrors OpenSearchVectorDBStorage / NanoVectorDBStorage.
  2795. self._pending_vector_docs: dict[str, _PendingPGVectorDoc] = {}
  2796. self._pending_vector_deletes: set[str] = set()
  2797. # Namespace-keyed lock; created in initialize() after workspace is final.
  2798. self._flush_lock = None
  2799. @staticmethod
  2800. async def _pg_create_table(
  2801. db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
  2802. ) -> None:
  2803. """Create a new vector table by replacing the table name in DDL template,
  2804. and create indexes on id and (workspace, id) columns.
  2805. Args:
  2806. db: PostgreSQLDB instance
  2807. table_name: Name of the new table to create
  2808. base_table: Base table name for DDL template lookup
  2809. embedding_dim: Embedding dimension for vector column
  2810. """
  2811. if base_table not in TABLES:
  2812. raise ValueError(f"No DDL template found for table: {base_table}")
  2813. ddl_template = TABLES[base_table]["ddl"]
  2814. # Determine vector column type based on configuration
  2815. # HALFVEC is used when HNSW_HALFVEC is selected
  2816. vector_type = "VECTOR"
  2817. if getattr(db, "vector_index_type", None) == "HNSW_HALFVEC":
  2818. vector_type = "HALFVEC"
  2819. # Replace embedding dimension placeholder if exists
  2820. ddl = ddl_template.replace(
  2821. "VECTOR(dimension)", f"{vector_type}({embedding_dim})"
  2822. )
  2823. # Replace table name
  2824. ddl = ddl.replace(base_table, table_name)
  2825. # Make creation idempotent to handle restarts and race conditions
  2826. ddl = ddl.replace("CREATE TABLE ", "CREATE TABLE IF NOT EXISTS ", 1)
  2827. await db.execute(ddl)
  2828. # Create indexes similar to check_tables() but with safe index names
  2829. # Create index for id column
  2830. id_index_name = _safe_index_name(table_name, "id")
  2831. try:
  2832. create_id_index_sql = (
  2833. f"CREATE INDEX IF NOT EXISTS {id_index_name} ON {table_name}(id)"
  2834. )
  2835. logger.info(
  2836. f"PostgreSQL, Creating index {id_index_name} on table {table_name}"
  2837. )
  2838. await db.execute(create_id_index_sql)
  2839. except Exception as e:
  2840. logger.error(
  2841. f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}"
  2842. )
  2843. # Create composite index for (workspace, id)
  2844. workspace_id_index_name = _safe_index_name(table_name, "workspace_id")
  2845. try:
  2846. create_composite_index_sql = f"CREATE INDEX IF NOT EXISTS {workspace_id_index_name} ON {table_name}(workspace, id)"
  2847. logger.info(
  2848. f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}"
  2849. )
  2850. await db.execute(create_composite_index_sql)
  2851. except Exception as e:
  2852. logger.error(
  2853. f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}"
  2854. )
  2855. @staticmethod
  2856. async def _pg_migrate_workspace_data(
  2857. db: PostgreSQLDB,
  2858. legacy_table_name: str,
  2859. new_table_name: str,
  2860. workspace: str,
  2861. expected_count: int,
  2862. embedding_dim: int,
  2863. ) -> int:
  2864. """Migrate workspace data from legacy table to new table using batch insert.
  2865. This function uses asyncpg's executemany for efficient batch insertion,
  2866. reducing database round-trips from N to 1 per batch.
  2867. Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
  2868. This ensures every legacy row is migrated exactly once, avoiding the
  2869. non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
  2870. Args:
  2871. db: PostgreSQLDB instance
  2872. legacy_table_name: Name of the legacy table to migrate from
  2873. new_table_name: Name of the new table to migrate to
  2874. workspace: Workspace to filter records for migration
  2875. expected_count: Expected number of records to migrate
  2876. embedding_dim: Embedding dimension for vector column
  2877. Returns:
  2878. Number of records migrated
  2879. """
  2880. migrated_count = 0
  2881. last_id: str | None = None
  2882. batch_size = 500
  2883. while True:
  2884. # Use keyset pagination with ORDER BY id for deterministic ordering
  2885. # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
  2886. if workspace:
  2887. if last_id is not None:
  2888. select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
  2889. rows = await db.query(
  2890. select_query, [workspace, last_id, batch_size], multirows=True
  2891. )
  2892. else:
  2893. select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
  2894. rows = await db.query(
  2895. select_query, [workspace, batch_size], multirows=True
  2896. )
  2897. else:
  2898. if last_id is not None:
  2899. select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
  2900. rows = await db.query(
  2901. select_query, [last_id, batch_size], multirows=True
  2902. )
  2903. else:
  2904. select_query = (
  2905. f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
  2906. )
  2907. rows = await db.query(select_query, [batch_size], multirows=True)
  2908. if not rows:
  2909. break
  2910. # Track the last ID for keyset pagination cursor
  2911. last_id = rows[-1]["id"]
  2912. # Batch insert optimization: use executemany instead of individual inserts
  2913. # Get column names from the first row
  2914. first_row = dict(rows[0])
  2915. columns = list(first_row.keys())
  2916. columns_str = ", ".join(columns)
  2917. placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
  2918. insert_query = f"""
  2919. INSERT INTO {new_table_name} ({columns_str})
  2920. VALUES ({placeholders})
  2921. ON CONFLICT (workspace, id) DO NOTHING
  2922. """
  2923. # Prepare batch data: convert rows to list of tuples
  2924. batch_values = []
  2925. for row in rows:
  2926. row_dict = dict(row)
  2927. # FIX: Parse vector strings from connections without register_vector codec.
  2928. # When pgvector codec is not registered on the read connection, vector
  2929. # columns are returned as text strings like "[0.1,0.2,...]" instead of
  2930. # lists/arrays. We need to convert these to numpy arrays before passing
  2931. # to executemany, which uses a connection WITH register_vector codec
  2932. # that expects list/tuple/ndarray types.
  2933. if "content_vector" in row_dict:
  2934. vec = row_dict["content_vector"]
  2935. if isinstance(vec, str):
  2936. # pgvector text format: "[0.1,0.2,0.3,...]"
  2937. vec = vec.strip("[]")
  2938. if vec:
  2939. row_dict["content_vector"] = np.array(
  2940. [float(x) for x in vec.split(",")], dtype=np.float32
  2941. )
  2942. else:
  2943. row_dict["content_vector"] = None
  2944. # Extract values in column order to match placeholders
  2945. values_tuple = tuple(row_dict[col] for col in columns)
  2946. batch_values.append(values_tuple)
  2947. # Use executemany for batch execution - significantly reduces DB round-trips
  2948. # Note: register_vector is already called on pool init, no need to call it again
  2949. async def _batch_insert(connection: asyncpg.Connection) -> None:
  2950. await connection.executemany(insert_query, batch_values)
  2951. await db._run_with_retry(_batch_insert)
  2952. migrated_count += len(rows)
  2953. workspace_info = f" for workspace '{workspace}'" if workspace else ""
  2954. logger.info(
  2955. f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
  2956. )
  2957. return migrated_count
  2958. @staticmethod
  2959. async def setup_table(
  2960. db: PostgreSQLDB,
  2961. table_name: str,
  2962. workspace: str,
  2963. embedding_dim: int,
  2964. legacy_table_name: str,
  2965. base_table: str,
  2966. ):
  2967. """
  2968. Setup PostgreSQL table with migration support from legacy tables.
  2969. Ensure final table has workspace isolation index.
  2970. Check vector dimension compatibility before new table creation.
  2971. Drop legacy table if it exists and is empty.
  2972. Only migrate data from legacy table to new table when new table first created and legacy table is not empty.
  2973. This function must be call ClientManager.get_client() to legacy table is migrated to latest schema.
  2974. Args:
  2975. db: PostgreSQLDB instance
  2976. table_name: Name of the new table
  2977. workspace: Workspace to filter records for migration
  2978. legacy_table_name: Name of the legacy table to check for migration
  2979. base_table: Base table name for DDL template lookup
  2980. embedding_dim: Embedding dimension for vector column
  2981. """
  2982. if not workspace:
  2983. raise ValueError("workspace must be provided")
  2984. new_table_exists = await db.check_table_exists(table_name)
  2985. legacy_exists = legacy_table_name and await db.check_table_exists(
  2986. legacy_table_name
  2987. )
  2988. # Case 1: Only new table exists or new table is the same as legacy table
  2989. # No data migration needed, ensuring index is created then return
  2990. if (new_table_exists and not legacy_exists) or (
  2991. new_table_exists and (table_name.lower() == legacy_table_name.lower())
  2992. ):
  2993. await db._create_vector_index(table_name, embedding_dim)
  2994. workspace_count_query = (
  2995. f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
  2996. )
  2997. workspace_count_result = await db.query(workspace_count_query, [workspace])
  2998. workspace_count = (
  2999. workspace_count_result.get("count", 0) if workspace_count_result else 0
  3000. )
  3001. if workspace_count == 0 and not (
  3002. table_name.lower() == legacy_table_name.lower()
  3003. ):
  3004. logger.warning(
  3005. f"PostgreSQL: workspace data in table '{table_name}' is empty. "
  3006. f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
  3007. )
  3008. return
  3009. legacy_count = None
  3010. if not new_table_exists:
  3011. # Check vector dimension compatibility before creating new table
  3012. if legacy_exists:
  3013. count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
  3014. count_result = await db.query(count_query, [workspace])
  3015. legacy_count = count_result.get("count", 0) if count_result else 0
  3016. if legacy_count > 0:
  3017. legacy_dim = None
  3018. try:
  3019. sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1"
  3020. sample_result = await db.query(sample_query, [workspace])
  3021. # Fix: Use 'is not None' instead of truthiness check to avoid
  3022. # NumPy array boolean ambiguity error
  3023. if (
  3024. sample_result
  3025. and sample_result.get("content_vector") is not None
  3026. ):
  3027. vector_data = sample_result["content_vector"]
  3028. # pgvector returns list directly, but may also return NumPy arrays
  3029. # when register_vector codec is active on the connection
  3030. if isinstance(vector_data, (list, tuple)):
  3031. legacy_dim = len(vector_data)
  3032. elif hasattr(vector_data, "__len__") and not isinstance(
  3033. vector_data, str
  3034. ):
  3035. # Handle NumPy arrays and other array-like objects
  3036. legacy_dim = len(vector_data)
  3037. elif hasattr(vector_data, "dimensions") and callable(
  3038. vector_data.dimensions
  3039. ):
  3040. # pgvector HalfVector / SparseVector expose dimensions()
  3041. legacy_dim = vector_data.dimensions()
  3042. elif isinstance(vector_data, str):
  3043. import json
  3044. vector_list = json.loads(vector_data)
  3045. legacy_dim = len(vector_list)
  3046. if legacy_dim and legacy_dim != embedding_dim:
  3047. logger.error(
  3048. f"PostgreSQL: Dimension mismatch detected! "
  3049. f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
  3050. f"but new embedding model expects {embedding_dim}d."
  3051. )
  3052. raise DataMigrationError(
  3053. f"Dimension mismatch between legacy table '{legacy_table_name}' "
  3054. f"and new embedding model. Expected {embedding_dim}d but got {legacy_dim}d."
  3055. )
  3056. except DataMigrationError:
  3057. # Re-raise DataMigrationError as-is to preserve specific error messages
  3058. raise
  3059. except Exception as e:
  3060. raise DataMigrationError(
  3061. f"Could not verify legacy table vector dimension: {e}. "
  3062. f"Proceeding with caution..."
  3063. )
  3064. await PGVectorStorage._pg_create_table(
  3065. db, table_name, base_table, embedding_dim
  3066. )
  3067. logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
  3068. if not legacy_exists:
  3069. await db._create_vector_index(table_name, embedding_dim)
  3070. logger.info(
  3071. "Ensure this new table creation is caused by new workspace setup and not an unexpected embedding model change."
  3072. )
  3073. return
  3074. # Ensure vector index is created
  3075. await db._create_vector_index(table_name, embedding_dim)
  3076. # Case 2: Legacy table exist
  3077. if legacy_exists:
  3078. workspace_info = f" for workspace '{workspace}'"
  3079. # Only drop legacy table if entire table is empty
  3080. total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
  3081. total_count_result = await db.query(total_count_query, [])
  3082. total_count = (
  3083. total_count_result.get("count", 0) if total_count_result else 0
  3084. )
  3085. if total_count == 0:
  3086. logger.info(
  3087. f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully"
  3088. )
  3089. drop_query = f"DROP TABLE {legacy_table_name}"
  3090. await db.execute(drop_query, None)
  3091. return
  3092. # No data migration needed if legacy workspace is empty
  3093. if legacy_count is None:
  3094. count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
  3095. count_result = await db.query(count_query, [workspace])
  3096. legacy_count = count_result.get("count", 0) if count_result else 0
  3097. if legacy_count == 0:
  3098. logger.info(
  3099. f"PostgreSQL: No records{workspace_info} found in legacy table. "
  3100. f"No data migration needed."
  3101. )
  3102. return
  3103. new_count_query = (
  3104. f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
  3105. )
  3106. new_count_result = await db.query(new_count_query, [workspace])
  3107. new_table_workspace_count = (
  3108. new_count_result.get("count", 0) if new_count_result else 0
  3109. )
  3110. if new_table_workspace_count > 0:
  3111. logger.warning(
  3112. f"PostgreSQL: Both new and legacy collection have data. "
  3113. f"{legacy_count} records in {legacy_table_name} require manual deletion after migration verification."
  3114. )
  3115. return
  3116. # Case 3: Legacy has workspace data and new table is empty for workspace
  3117. logger.info(
  3118. f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}."
  3119. )
  3120. logger.info(
  3121. f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'"
  3122. )
  3123. try:
  3124. migrated_count = await PGVectorStorage._pg_migrate_workspace_data(
  3125. db,
  3126. legacy_table_name,
  3127. table_name,
  3128. workspace,
  3129. legacy_count,
  3130. embedding_dim,
  3131. )
  3132. if migrated_count != legacy_count:
  3133. logger.warning(
  3134. "PostgreSQL: Read %s legacy records%s during migration, expected %s.",
  3135. migrated_count,
  3136. workspace_info,
  3137. legacy_count,
  3138. )
  3139. new_count_result = await db.query(new_count_query, [workspace])
  3140. new_table_count_after = (
  3141. new_count_result.get("count", 0) if new_count_result else 0
  3142. )
  3143. inserted_count = new_table_count_after - new_table_workspace_count
  3144. if inserted_count != legacy_count:
  3145. error_msg = (
  3146. "PostgreSQL: Migration verification failed, "
  3147. f"expected {legacy_count} inserted records, got {inserted_count}."
  3148. )
  3149. logger.error(error_msg)
  3150. raise DataMigrationError(error_msg)
  3151. except DataMigrationError:
  3152. # Re-raise DataMigrationError as-is to preserve specific error messages
  3153. raise
  3154. except Exception as e:
  3155. logger.error(
  3156. f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}"
  3157. )
  3158. raise DataMigrationError(
  3159. f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'"
  3160. ) from e
  3161. logger.info(
  3162. f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
  3163. )
  3164. logger.warning(
  3165. "PostgreSQL: Manual deletion is required after data migration verification."
  3166. )
  3167. async def initialize(self):
  3168. async with get_data_init_lock():
  3169. if self.db is None:
  3170. self.db = await ClientManager.get_client(
  3171. vector_storage=self.global_config.get("vector_storage")
  3172. )
  3173. # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
  3174. if self.db.workspace:
  3175. # Use PostgreSQLDB's workspace (highest priority)
  3176. logger.info(
  3177. f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
  3178. )
  3179. self.workspace = self.db.workspace
  3180. elif hasattr(self, "workspace") and self.workspace:
  3181. # Use storage class's workspace (medium priority)
  3182. pass
  3183. else:
  3184. # Use "default" for compatibility (lowest priority)
  3185. self.workspace = "default"
  3186. # Setup table (create if not exists and handle migration)
  3187. await PGVectorStorage.setup_table(
  3188. self.db,
  3189. self.table_name,
  3190. self.workspace, # CRITICAL: Filter migration by workspace
  3191. embedding_dim=self.embedding_func.embedding_dim,
  3192. legacy_table_name=self.legacy_table_name,
  3193. base_table=self.legacy_table_name, # base_table for DDL template lookup
  3194. )
  3195. if self._flush_lock is None:
  3196. self._flush_lock = get_namespace_lock(
  3197. self.namespace, workspace=self.workspace
  3198. )
  3199. async def finalize(self):
  3200. """Flush pending vector ops then release the shared PG client.
  3201. Captures regular ``Exception`` from the flush so it can be re-raised
  3202. as a ``RuntimeError`` naming the unflushed buffer counts after the
  3203. client is released. ``BaseException`` (``CancelledError``,
  3204. ``KeyboardInterrupt``, ``SystemExit``) is intentionally NOT caught
  3205. so it can propagate through ``finally`` — the buffer-count reframing
  3206. below is skipped in that case (the propagating exception already
  3207. signals shutdown; conflating it with "left N pending" would be
  3208. misleading).
  3209. Idempotency:
  3210. Re-entry after a successful or failed first call is a no-op for
  3211. the flush (client is already released), but still raises if
  3212. buffers remain non-empty so the operator sees the data-loss
  3213. signal again.
  3214. """
  3215. if self.db is None:
  3216. pending_docs = len(self._pending_vector_docs)
  3217. pending_deletes = len(self._pending_vector_deletes)
  3218. if pending_docs or pending_deletes:
  3219. raise RuntimeError(
  3220. f"[{self.workspace}] PGVectorStorage.finalize() re-entry: "
  3221. f"client already released; {pending_docs} pending upserts "
  3222. f"and {pending_deletes} pending deletes cannot be flushed"
  3223. )
  3224. return
  3225. flush_error: Exception | None = None
  3226. try:
  3227. try:
  3228. await self._flush_pending_vector_ops()
  3229. except Exception as e:
  3230. flush_error = e
  3231. finally:
  3232. if self.db is not None:
  3233. await ClientManager.release_client(self.db)
  3234. self.db = None
  3235. pending_docs = len(self._pending_vector_docs)
  3236. pending_deletes = len(self._pending_vector_deletes)
  3237. if flush_error is not None:
  3238. raise RuntimeError(
  3239. f"[{self.workspace}] PGVectorStorage.finalize() flush raised; "
  3240. f"{pending_docs} pending upserts and {pending_deletes} pending "
  3241. f"deletes were left buffered (client released, data lost)"
  3242. ) from flush_error
  3243. if pending_docs or pending_deletes:
  3244. raise RuntimeError(
  3245. f"[{self.workspace}] PGVectorStorage.finalize() left "
  3246. f"{pending_docs} pending upserts and {pending_deletes} "
  3247. f"pending deletes buffered after final flush attempt"
  3248. )
  3249. def _upsert_chunks(
  3250. self, item: dict[str, Any], current_time: datetime.datetime
  3251. ) -> tuple[str, tuple[Any, ...]]:
  3252. """Prepare upsert data for chunks.
  3253. Returns:
  3254. Tuple of (SQL template, values tuple for executemany)
  3255. """
  3256. try:
  3257. upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
  3258. table_name=self.table_name
  3259. )
  3260. # Return tuple in the exact order of SQL parameters ($1, $2, ...)
  3261. values: tuple[Any, ...] = (
  3262. self.workspace, # $1
  3263. item["__id__"], # $2
  3264. item["tokens"], # $3
  3265. item["chunk_order_index"], # $4
  3266. item["full_doc_id"], # $5
  3267. item["content"], # $6
  3268. item["__vector__"], # $7 - numpy array, handled by pgvector codec
  3269. item["file_path"], # $8
  3270. current_time, # $9
  3271. current_time, # $10
  3272. )
  3273. except Exception as e:
  3274. logger.error(
  3275. f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}"
  3276. )
  3277. raise
  3278. return upsert_sql, values
  3279. def _upsert_entities(
  3280. self, item: dict[str, Any], current_time: datetime.datetime
  3281. ) -> tuple[str, tuple[Any, ...]]:
  3282. """Prepare upsert data for entities.
  3283. Returns:
  3284. Tuple of (SQL template, values tuple for executemany)
  3285. """
  3286. upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
  3287. source_id = item["source_id"]
  3288. if isinstance(source_id, str) and "<SEP>" in source_id:
  3289. chunk_ids = source_id.split("<SEP>")
  3290. else:
  3291. chunk_ids = [source_id]
  3292. # Return tuple in the exact order of SQL parameters ($1, $2, ...)
  3293. values: tuple[Any, ...] = (
  3294. self.workspace, # $1
  3295. item["__id__"], # $2
  3296. item["entity_name"], # $3
  3297. item["content"], # $4
  3298. item["__vector__"], # $5 - numpy array, handled by pgvector codec
  3299. chunk_ids, # $6
  3300. item.get("file_path", None), # $7
  3301. current_time, # $8
  3302. current_time, # $9
  3303. )
  3304. return upsert_sql, values
  3305. def _upsert_relationships(
  3306. self, item: dict[str, Any], current_time: datetime.datetime
  3307. ) -> tuple[str, tuple[Any, ...]]:
  3308. """Prepare upsert data for relationships.
  3309. Returns:
  3310. Tuple of (SQL template, values tuple for executemany)
  3311. """
  3312. upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
  3313. table_name=self.table_name
  3314. )
  3315. source_id = item["source_id"]
  3316. if isinstance(source_id, str) and "<SEP>" in source_id:
  3317. chunk_ids = source_id.split("<SEP>")
  3318. else:
  3319. chunk_ids = [source_id]
  3320. # Return tuple in the exact order of SQL parameters ($1, $2, ...)
  3321. values: tuple[Any, ...] = (
  3322. self.workspace, # $1
  3323. item["__id__"], # $2
  3324. item["src_id"], # $3
  3325. item["tgt_id"], # $4
  3326. item["content"], # $5
  3327. item["__vector__"], # $6 - numpy array, handled by pgvector codec
  3328. chunk_ids, # $7
  3329. item.get("file_path", None), # $8
  3330. current_time, # $9
  3331. current_time, # $10
  3332. )
  3333. return upsert_sql, values
  3334. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  3335. """Buffer vector docs for embedding and batched flush.
  3336. Correctness premise:
  3337. LightRAG's pipeline is the normal write path for graph/vector
  3338. mutations and guarantees a single writer process per workspace.
  3339. This storage follows the same deferred-embedding contract as
  3340. OpenSearchVectorDBStorage: the pending buffer is process-local.
  3341. Committed PG rows are immediately visible across workers, but
  3342. *buffered* writes are not — readers in other workers will not
  3343. see them until the writing worker calls index_done_callback().
  3344. Non-pipeline writers must provide equivalent single-writer
  3345. serialization and must flush explicitly before depending on
  3346. reads from another worker.
  3347. Memory expectation:
  3348. Pending docs (raw ``content`` strings, plus cached float32
  3349. vectors once embedded) accumulate in process memory until the
  3350. next ``index_done_callback()`` / ``finalize()``. This matches
  3351. the OpenSearch/Nano/Faiss contract. Callers performing very
  3352. large ingests should flush periodically (every N upserts) to
  3353. cap working-set size.
  3354. """
  3355. if not data:
  3356. return
  3357. logger.debug(
  3358. f"[{self.workspace}] Buffering {len(data)} vectors for {self.namespace}"
  3359. )
  3360. # Build pending docs outside the lock; UTC naive datetime mirrors
  3361. # the previous direct-write code path (the _upsert_* helpers feed
  3362. # this straight into asyncpg as a timestamp).
  3363. current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
  3364. pending_docs: list[tuple[str, _PendingPGVectorDoc]] = []
  3365. for i, (k, v) in enumerate(data.items(), start=1):
  3366. pending_docs.append(
  3367. (
  3368. k,
  3369. _PendingPGVectorDoc(
  3370. item={"__id__": k, **v},
  3371. created_at=current_time,
  3372. ),
  3373. )
  3374. )
  3375. await _cooperative_yield(i)
  3376. async with self._flush_lock:
  3377. for doc_id, pending_doc in pending_docs:
  3378. # Invariant: a later upsert wins over an earlier delete; the
  3379. # unconditional dict assignment also discards any cached
  3380. # stale vector from a prior upsert of the same id.
  3381. self._pending_vector_deletes.discard(doc_id)
  3382. self._pending_vector_docs[doc_id] = pending_doc
  3383. async def _flush_pending_vector_ops(self) -> None:
  3384. """Flush buffered PG vector upserts and deletes in one transaction.
  3385. Concurrency:
  3386. All buffer reads/writes and destructive server mutations on
  3387. this storage run under ``self._flush_lock``. Embedding stays
  3388. inside that lock so a destructive operation cannot interleave
  3389. between embedding and the PG write in the same process.
  3390. Failure handling:
  3391. PG cannot expose per-document statuses, so flush is
  3392. all-or-nothing:
  3393. * If embedding fails the buffers stay intact (next flush
  3394. retries; cached vectors are reused).
  3395. * If ``_run_with_retry`` raises the transaction rolls back
  3396. and the buffers stay intact. Cached vectors stay attached
  3397. to pending docs so the next flush does not re-embed.
  3398. * On success both buffers are cleared.
  3399. Post-finalize / pre-initialize:
  3400. Calling this after ``finalize()`` (``self.db is None``) or
  3401. before ``initialize()`` (``self._flush_lock is None``) with a
  3402. non-empty buffer raises ``RuntimeError`` — silently dropping
  3403. buffered writes would defeat the data-loss visibility that
  3404. ``finalize()`` provides. An empty-buffer call is a no-op.
  3405. """
  3406. if self._flush_lock is None:
  3407. pending_docs = len(self._pending_vector_docs)
  3408. pending_deletes = len(self._pending_vector_deletes)
  3409. if pending_docs or pending_deletes:
  3410. raise RuntimeError(
  3411. f"[{self.workspace}] PGVectorStorage._flush_pending_vector_ops "
  3412. f"called before initialize(); {pending_docs} pending upserts "
  3413. f"and {pending_deletes} pending deletes cannot be flushed"
  3414. )
  3415. return
  3416. async with self._flush_lock:
  3417. if not self._pending_vector_docs and not self._pending_vector_deletes:
  3418. return
  3419. if self.db is None:
  3420. pending_docs = len(self._pending_vector_docs)
  3421. pending_deletes = len(self._pending_vector_deletes)
  3422. raise RuntimeError(
  3423. f"[{self.workspace}] PGVectorStorage._flush_pending_vector_ops "
  3424. f"called after client release; {pending_docs} pending upserts "
  3425. f"and {pending_deletes} pending deletes cannot be flushed"
  3426. )
  3427. timing_label = f"{self.workspace} PGVectorStorage.flush[{self.namespace}]"
  3428. total_start = time.perf_counter()
  3429. performance_timing_log(
  3430. "[%s] start upserts=%s deletes=%s max_batch_size=%s",
  3431. timing_label,
  3432. len(self._pending_vector_docs),
  3433. len(self._pending_vector_deletes),
  3434. self._max_batch_size,
  3435. )
  3436. # --- Embedding phase ---------------------------------------------
  3437. docs_to_embed = [
  3438. (doc_id, pending_doc)
  3439. for doc_id, pending_doc in self._pending_vector_docs.items()
  3440. if pending_doc.vector is None
  3441. ]
  3442. if docs_to_embed:
  3443. contents = [
  3444. pending_doc.item["content"] for _, pending_doc in docs_to_embed
  3445. ]
  3446. batches = [
  3447. contents[i : i + self._max_batch_size]
  3448. for i in range(0, len(contents), self._max_batch_size)
  3449. ]
  3450. logger.info(
  3451. f"[{self.workspace}] {self.namespace} flush: embedding "
  3452. f"{len(docs_to_embed)} vectors in {len(batches)} batch(es) "
  3453. f"(batch_num={self._max_batch_size})"
  3454. )
  3455. embedding_start = time.perf_counter()
  3456. try:
  3457. embeddings_list = await asyncio.gather(
  3458. *[
  3459. self.embedding_func(batch, context="document")
  3460. for batch in batches
  3461. ]
  3462. )
  3463. except Exception as e:
  3464. logger.error(
  3465. f"[{self.workspace}] Error embedding pending vector ops "
  3466. f"(upserts={len(docs_to_embed)}): {e}"
  3467. )
  3468. raise
  3469. performance_timing_log(
  3470. "[%s] embedding completed in %.4fs docs=%s batches=%s",
  3471. timing_label,
  3472. time.perf_counter() - embedding_start,
  3473. len(docs_to_embed),
  3474. len(batches),
  3475. )
  3476. embeddings = np.concatenate(embeddings_list)
  3477. # Explicit check: a count mismatch under `python -O` would
  3478. # silently truncate via zip(), mispairing vectors with docs.
  3479. if len(embeddings) != len(docs_to_embed):
  3480. raise RuntimeError(
  3481. f"[{self.workspace}] Embedding count mismatch: "
  3482. f"expected {len(docs_to_embed)}, got {len(embeddings)}"
  3483. )
  3484. for i, ((_, pending_doc), embedding) in enumerate(
  3485. zip(docs_to_embed, embeddings), start=1
  3486. ):
  3487. pending_doc.vector = embedding
  3488. await _cooperative_yield(i)
  3489. # --- Build batch tuples ------------------------------------------
  3490. if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
  3491. build_tuple = self._upsert_chunks
  3492. elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
  3493. build_tuple = self._upsert_entities
  3494. elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
  3495. build_tuple = self._upsert_relationships
  3496. else:
  3497. raise ValueError(f"{self.namespace} is not supported")
  3498. batch_values: list[tuple[Any, ...]] = []
  3499. upsert_sql: str | None = None
  3500. for i, (doc_id, pending_doc) in enumerate(
  3501. self._pending_vector_docs.items(), start=1
  3502. ):
  3503. if pending_doc.vector is None:
  3504. # Should not happen: every pending doc was embedded above
  3505. # or had a cached vector from a previous lazy embed.
  3506. raise RuntimeError(
  3507. f"[{self.workspace}] Pending vector for id={doc_id} "
  3508. f"missing after embedding phase"
  3509. )
  3510. # Coerce to float32 ndarray if not already (defensive; the
  3511. # embedding func typically returns float32 but a custom
  3512. # provider may return float64 — pgvector wants float32).
  3513. item = dict(pending_doc.item)
  3514. vector = pending_doc.vector
  3515. if not isinstance(vector, np.ndarray) or vector.dtype != np.float32:
  3516. vector = np.asarray(vector, dtype=np.float32)
  3517. item["__vector__"] = vector
  3518. upsert_sql, values = build_tuple(item, pending_doc.created_at)
  3519. batch_values.append(values)
  3520. await _cooperative_yield(i)
  3521. pending_delete_ids = list(self._pending_vector_deletes)
  3522. # --- Persistence -------------------------------------------------
  3523. async def _flush_batch(connection: asyncpg.Connection) -> None:
  3524. async with connection.transaction():
  3525. if batch_values and upsert_sql:
  3526. execute_start = time.perf_counter()
  3527. await connection.executemany(upsert_sql, batch_values)
  3528. performance_timing_log(
  3529. "[%s] executemany completed in %.4fs batch_size=%s",
  3530. timing_label,
  3531. time.perf_counter() - execute_start,
  3532. len(batch_values),
  3533. )
  3534. if pending_delete_ids:
  3535. delete_sql = (
  3536. f"DELETE FROM {self.table_name} "
  3537. "WHERE workspace=$1 AND id = ANY($2)"
  3538. )
  3539. await connection.execute(
  3540. delete_sql, self.workspace, pending_delete_ids
  3541. )
  3542. try:
  3543. await self.db._run_with_retry(_flush_batch, timing_label=timing_label)
  3544. except Exception as e:
  3545. logger.error(
  3546. f"[{self.workspace}] Error flushing vector ops "
  3547. f"(upserts={len(batch_values)}, "
  3548. f"deletes={len(pending_delete_ids)}): {e}"
  3549. )
  3550. raise
  3551. # Success: clear committed buffers. Cached vectors live on
  3552. # those records and are GC'd with them.
  3553. self._pending_vector_docs.clear()
  3554. self._pending_vector_deletes.clear()
  3555. performance_timing_log(
  3556. "[%s] total complete in %.4fs upserts=%s deletes=%s",
  3557. timing_label,
  3558. time.perf_counter() - total_start,
  3559. len(batch_values),
  3560. len(pending_delete_ids),
  3561. )
  3562. #################### query method ###############
  3563. async def query(
  3564. self, query: str, top_k: int, query_embedding: list[float] = None
  3565. ) -> list[dict[str, Any]]:
  3566. if query_embedding is not None:
  3567. embedding = query_embedding
  3568. else:
  3569. embeddings = await self.embedding_func(
  3570. [query], context="query", _priority=5
  3571. ) # higher priority for query
  3572. embedding = embeddings[0]
  3573. # Use positional $4 parameter instead of string-interpolated literal.
  3574. # asyncpg sends the embedding via register_vector binary codec, avoiding
  3575. # per-query text serialization and PostgreSQL text-to-vector parsing.
  3576. vector_cast = (
  3577. "halfvec"
  3578. if getattr(self.db, "vector_index_type", None) == "HNSW_HALFVEC"
  3579. else "vector"
  3580. )
  3581. sql = SQL_TEMPLATES[self.namespace].format(
  3582. table_name=self.table_name, vector_cast=vector_cast
  3583. )
  3584. params = {
  3585. "workspace": self.workspace,
  3586. "closer_than_threshold": 1 - self.cosine_better_than_threshold,
  3587. "top_k": top_k,
  3588. "embedding": embedding,
  3589. }
  3590. results = await self.db.query(sql, params=list(params.values()), multirows=True)
  3591. return results
  3592. async def index_done_callback(self) -> None:
  3593. await self._flush_pending_vector_ops()
  3594. async def delete(self, ids: list[str]) -> None:
  3595. """Buffer vector deletes for batched flush.
  3596. A delete cancels any pending upsert for the same id. The actual PG
  3597. delete is performed by ``_flush_pending_vector_ops`` during the next
  3598. ``index_done_callback`` / ``finalize`` call.
  3599. """
  3600. if not ids:
  3601. return
  3602. if isinstance(ids, set):
  3603. ids = list(ids)
  3604. async with self._flush_lock:
  3605. for doc_id in ids:
  3606. self._pending_vector_docs.pop(doc_id, None)
  3607. self._pending_vector_deletes.add(doc_id)
  3608. logger.debug(
  3609. f"[{self.workspace}] Buffered delete for {len(ids)} vectors in {self.namespace}"
  3610. )
  3611. async def delete_entity(self, entity_name: str) -> None:
  3612. """Delete an entity vector by entity name.
  3613. Runs the SQL predicate delete (``WHERE entity_name=$2``) immediately
  3614. under ``_flush_lock`` so it cannot interleave with a flush of the
  3615. same namespace, and — only after the SQL succeeds — prunes the
  3616. matching pending docs and any pending delete that would otherwise
  3617. re-fire. If the SQL raises, the buffer is left untouched so a
  3618. subsequent retry can still observe the pending state instead of
  3619. silently losing it, and the exception is logged and re-raised so
  3620. the caller (e.g. ``adelete_by_entity``) short-circuits before
  3621. ``_persist_graph_updates()`` flushes those preserved pending
  3622. upserts back into the table. Matches the cross-backend contract
  3623. documented on the Qdrant / Milvus / Mongo implementations: "server-
  3624. side failures are re-raised; the caller decides whether to retry."
  3625. The SQL predicate is kept (rather than ``self.delete([ent_id])``) as
  3626. a safety net for legacy rows whose ``id`` may not equal
  3627. ``compute_mdhash_id(entity_name, prefix="ent-")``.
  3628. Raises:
  3629. RuntimeError: if called before ``initialize()`` (``_flush_lock``
  3630. is still ``None``). Silently dropping a destructive intent
  3631. would defeat the data-loss visibility that the rest of this
  3632. storage enforces; the caller must initialize first.
  3633. """
  3634. if self._flush_lock is None:
  3635. raise RuntimeError(
  3636. f"[{self.workspace}] PGVectorStorage.delete_entity called before "
  3637. f"initialize(); call initialize_storages() on the LightRAG instance "
  3638. f"before issuing destructive operations"
  3639. )
  3640. entity_id = compute_mdhash_id(entity_name, prefix="ent-")
  3641. def _prune_pending() -> None:
  3642. # Drop any pending upsert keyed by hash id or matching
  3643. # entity_name in the buffered payload (relationship docs
  3644. # have no entity_name; the lookup is a harmless no-op).
  3645. self._pending_vector_docs.pop(entity_id, None)
  3646. for buffered_id in [
  3647. k
  3648. for k, v in self._pending_vector_docs.items()
  3649. if v.item.get("entity_name") == entity_name
  3650. ]:
  3651. self._pending_vector_docs.pop(buffered_id, None)
  3652. # Drop any redundant pending delete; the SQL above covered it.
  3653. self._pending_vector_deletes.discard(entity_id)
  3654. try:
  3655. async with self._flush_lock:
  3656. if self.db is None:
  3657. # Storage already finalized; buffer is the only state
  3658. # left, so apply the delete intent there.
  3659. _prune_pending()
  3660. return
  3661. delete_sql = (
  3662. f"DELETE FROM {self.table_name} "
  3663. "WHERE workspace=$1 AND entity_name=$2"
  3664. )
  3665. await self.db.execute(
  3666. delete_sql,
  3667. {"workspace": self.workspace, "entity_name": entity_name},
  3668. )
  3669. # SQL succeeded — safe to prune buffer. If it had raised,
  3670. # we'd skip this so the pending state remains for retry.
  3671. _prune_pending()
  3672. logger.debug(
  3673. f"[{self.workspace}] Successfully deleted entity {entity_name}"
  3674. )
  3675. except Exception as e:
  3676. # Re-raise so the caller can short-circuit and skip the
  3677. # subsequent flush; otherwise the pending upsert we just
  3678. # preserved would be persisted back, undoing the delete.
  3679. logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
  3680. raise
  3681. async def delete_entity_relation(self, entity_name: str) -> None:
  3682. """Delete all relation vectors where ``entity_name`` is src or tgt.
  3683. Predicate-based; runs immediately. The whole method holds
  3684. ``_flush_lock`` so it cannot interleave with a flush of buffered
  3685. relation upserts.
  3686. Buffer semantics — post-prune with caller short-circuit contract:
  3687. Any pending relation upsert whose ``src_id`` or ``tgt_id``
  3688. matches ``entity_name`` is pruned from ``_pending_vector_docs``
  3689. **only after** the SQL predicate delete succeeds. On SQL
  3690. failure the pending docs are left intact and the exception is
  3691. re-raised. This avoids silently dropping buffered relation
  3692. vectors that the user never told us to discard.
  3693. Correctness relies on the caller short-circuiting before it
  3694. can trigger ``index_done_callback`` and flush those preserved
  3695. pending upserts back into the table (which would undo the
  3696. delete intent on a partial server-side delete). The single
  3697. in-tree caller ``adelete_by_entity`` in ``utils_graph.py``
  3698. honors this: its ``except`` clause skips both ``delete_node``
  3699. and ``_persist_graph_updates``, so on failure both the graph
  3700. and the pending vector buffer stay consistent with the
  3701. "delete never happened" state and the operation converges on
  3702. the next retry. Callers that need to rename or re-link the
  3703. entity must re-issue the relation upserts after a successful
  3704. call.
  3705. Raises:
  3706. RuntimeError: if called before ``initialize()`` (``_flush_lock``
  3707. is still ``None``). Silently dropping a destructive intent
  3708. would defeat the data-loss visibility that the rest of this
  3709. storage enforces; the caller must initialize first.
  3710. """
  3711. if self._flush_lock is None:
  3712. raise RuntimeError(
  3713. f"[{self.workspace}] PGVectorStorage.delete_entity_relation called "
  3714. f"before initialize(); call initialize_storages() on the LightRAG "
  3715. f"instance before issuing destructive operations"
  3716. )
  3717. def _prune_pending() -> None:
  3718. for buffered_id in [
  3719. k
  3720. for k, v in self._pending_vector_docs.items()
  3721. if v.item.get("src_id") == entity_name
  3722. or v.item.get("tgt_id") == entity_name
  3723. ]:
  3724. self._pending_vector_docs.pop(buffered_id, None)
  3725. try:
  3726. async with self._flush_lock:
  3727. if self.db is None:
  3728. # Storage already finalized; buffer is the only state
  3729. # left, so apply the delete intent there.
  3730. _prune_pending()
  3731. return
  3732. delete_sql = (
  3733. f"DELETE FROM {self.table_name} "
  3734. "WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"
  3735. )
  3736. await self.db.execute(
  3737. delete_sql,
  3738. {"workspace": self.workspace, "entity_name": entity_name},
  3739. )
  3740. # SQL succeeded — safe to prune pending relation docs. If
  3741. # it had raised, we'd skip this so the pending state
  3742. # remains for retry on the next call.
  3743. _prune_pending()
  3744. logger.debug(
  3745. f"[{self.workspace}] Successfully deleted relations for entity {entity_name}"
  3746. )
  3747. except Exception as e:
  3748. logger.error(
  3749. f"[{self.workspace}] Error deleting relations for entity {entity_name}: {e}"
  3750. )
  3751. raise
  3752. async def get_by_id(self, id: str) -> dict[str, Any] | None:
  3753. """Get vector data by its ID with read-your-writes against the buffer.
  3754. ``__vector__`` and ``__id__`` are stripped from buffered results to
  3755. match the other vector backends; callers needing embeddings must use
  3756. ``get_vectors_by_ids``.
  3757. Response shape:
  3758. Buffered hits return ``{"id", "content", <payload fields>,
  3759. "created_at"}`` only — no embedding column. SQL-fallback hits
  3760. return the full row including ``content_vector`` (and any
  3761. namespace-specific columns such as ``entity_name`` or
  3762. ``source_id``). Callers that only read documented payload
  3763. fields (``content``, ``id``, ``created_at``) are unaffected;
  3764. consumers iterating over all keys must tolerate both shapes.
  3765. """
  3766. async with self._flush_lock:
  3767. if id in self._pending_vector_deletes:
  3768. return None
  3769. pending = self._pending_vector_docs.get(id)
  3770. if pending is not None:
  3771. doc = {
  3772. k: v
  3773. for k, v in pending.item.items()
  3774. if k not in ("__id__", "__vector__")
  3775. }
  3776. doc["id"] = id
  3777. doc["created_at"] = int(pending.created_at.timestamp())
  3778. return doc
  3779. query = (
  3780. f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at "
  3781. f"FROM {self.table_name} WHERE workspace=$1 AND id=$2"
  3782. )
  3783. try:
  3784. result = await self.db.query(query, [self.workspace, id])
  3785. if result:
  3786. return dict(result)
  3787. return None
  3788. except Exception as e:
  3789. logger.error(
  3790. f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
  3791. )
  3792. return None
  3793. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  3794. """Get multiple vector docs by ID, preserving caller order.
  3795. Pending deletes return ``None`` in their slot. Pending upserts are
  3796. served from the buffer; remaining ids fall through to a single
  3797. parameterized ``id = ANY($2)`` SQL query (replacing the previous
  3798. string-built ``IN (...)`` form).
  3799. Response shape: same buffered-vs-SQL inconsistency as
  3800. ``get_by_id`` — see that docstring for details.
  3801. """
  3802. if not ids:
  3803. return []
  3804. buffered: dict[str, dict[str, Any] | None] = {}
  3805. remaining: list[str] = []
  3806. async with self._flush_lock:
  3807. for doc_id in ids:
  3808. if doc_id in self._pending_vector_deletes:
  3809. buffered[doc_id] = None
  3810. continue
  3811. pending = self._pending_vector_docs.get(doc_id)
  3812. if pending is not None:
  3813. doc = {
  3814. k: v
  3815. for k, v in pending.item.items()
  3816. if k not in ("__id__", "__vector__")
  3817. }
  3818. doc["id"] = doc_id
  3819. doc["created_at"] = int(pending.created_at.timestamp())
  3820. buffered[doc_id] = doc
  3821. continue
  3822. remaining.append(doc_id)
  3823. id_map: dict[str, dict[str, Any]] = {}
  3824. if remaining:
  3825. query = (
  3826. f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at "
  3827. f"FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
  3828. )
  3829. try:
  3830. results = await self.db.query(
  3831. query, [self.workspace, remaining], multirows=True
  3832. )
  3833. for record in results or []:
  3834. if record is None:
  3835. continue
  3836. record_dict = dict(record)
  3837. row_id = record_dict.get("id")
  3838. if row_id is not None:
  3839. id_map[str(row_id)] = record_dict
  3840. except Exception as e:
  3841. logger.error(
  3842. f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
  3843. )
  3844. return []
  3845. ordered_results: list[dict[str, Any] | None] = []
  3846. for requested_id in ids:
  3847. if requested_id in buffered:
  3848. ordered_results.append(buffered[requested_id])
  3849. else:
  3850. ordered_results.append(id_map.get(str(requested_id)))
  3851. return ordered_results
  3852. async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
  3853. """Get vector embeddings by ID, with read-your-writes against the buffer.
  3854. Lazily embeds pending docs whose vector has not been computed yet,
  3855. caches the result on the pending record (so the next flush reuses
  3856. it), and falls through to a parameterized SQL query for ids not in
  3857. the buffer.
  3858. Embedding I/O runs *outside* ``_flush_lock`` so a slow embedding
  3859. provider cannot block concurrent ``upsert`` / ``delete`` / read
  3860. calls on this storage. The lock is re-acquired briefly to cache
  3861. the result, and the pending record's identity is re-checked
  3862. first: if a concurrent ``upsert`` / ``delete`` / ``drop`` replaced
  3863. or removed the record during the embedding window, that ID is
  3864. dropped from the response entirely — we neither cache the stale
  3865. vector on the new/missing record nor return it to the caller, so
  3866. callers cannot observe an embedding that no longer matches the
  3867. current buffer state. Affected callers should treat the missing
  3868. key the same as the existing "id was deleted before the call"
  3869. case and retry if needed.
  3870. """
  3871. if not ids:
  3872. return {}
  3873. result: dict[str, list[float]] = {}
  3874. remaining: list[str] = []
  3875. docs_to_embed: list[tuple[str, _PendingPGVectorDoc]] = []
  3876. async with self._flush_lock:
  3877. for doc_id in ids:
  3878. if doc_id in self._pending_vector_deletes:
  3879. continue
  3880. pending = self._pending_vector_docs.get(doc_id)
  3881. if pending is not None:
  3882. if pending.vector is None:
  3883. docs_to_embed.append((doc_id, pending))
  3884. else:
  3885. result[doc_id] = pending.vector.tolist()
  3886. continue
  3887. remaining.append(doc_id)
  3888. if docs_to_embed:
  3889. contents = [pending_doc.item["content"] for _, pending_doc in docs_to_embed]
  3890. batches = [
  3891. contents[i : i + self._max_batch_size]
  3892. for i in range(0, len(contents), self._max_batch_size)
  3893. ]
  3894. try:
  3895. embeddings_list = await asyncio.gather(
  3896. *[
  3897. self.embedding_func(batch, context="document")
  3898. for batch in batches
  3899. ]
  3900. )
  3901. except Exception as e:
  3902. logger.error(
  3903. f"[{self.workspace}] Error lazily embedding pending vectors "
  3904. f"(upserts={len(docs_to_embed)}): {e}"
  3905. )
  3906. raise
  3907. embeddings = np.concatenate(embeddings_list)
  3908. if len(embeddings) != len(docs_to_embed):
  3909. raise RuntimeError(
  3910. f"[{self.workspace}] Embedding count mismatch: "
  3911. f"expected {len(docs_to_embed)}, got {len(embeddings)}"
  3912. )
  3913. # Re-acquire the lock just long enough to cache results on
  3914. # the same record. The identity check gates BOTH the cache
  3915. # write and the response entry: if the pending record was
  3916. # swapped or removed during the embedding window (concurrent
  3917. # upsert / delete / drop), the just-computed vector no longer
  3918. # matches the current buffer state for this id, so we drop it
  3919. # from the response rather than return a stale embedding.
  3920. async with self._flush_lock:
  3921. for i, ((doc_id, original_pending), embedding) in enumerate(
  3922. zip(docs_to_embed, embeddings), start=1
  3923. ):
  3924. current = self._pending_vector_docs.get(doc_id)
  3925. if current is original_pending:
  3926. current.vector = embedding
  3927. result[doc_id] = embedding.tolist()
  3928. await _cooperative_yield(i)
  3929. if not remaining:
  3930. return result
  3931. query = (
  3932. f"SELECT id, content_vector FROM {self.table_name} "
  3933. f"WHERE workspace=$1 AND id = ANY($2)"
  3934. )
  3935. try:
  3936. results = await self.db.query(
  3937. query, [self.workspace, remaining], multirows=True
  3938. )
  3939. for row in results or []:
  3940. if not row or "content_vector" not in row or "id" not in row:
  3941. continue
  3942. vector_data = row["content_vector"]
  3943. try:
  3944. if isinstance(vector_data, (list, tuple)):
  3945. result[row["id"]] = list(vector_data)
  3946. elif isinstance(vector_data, str):
  3947. parsed = json.loads(vector_data)
  3948. if isinstance(parsed, list):
  3949. result[row["id"]] = parsed
  3950. elif hasattr(vector_data, "tolist"):
  3951. result[row["id"]] = vector_data.tolist()
  3952. elif hasattr(vector_data, "to_list") and callable(
  3953. vector_data.to_list
  3954. ):
  3955. result[row["id"]] = vector_data.to_list()
  3956. except (json.JSONDecodeError, TypeError) as e:
  3957. logger.warning(
  3958. f"[{self.workspace}] Failed to parse vector data for ID {row['id']}: {e}"
  3959. )
  3960. except Exception as e:
  3961. logger.error(f"[{self.workspace}] Error getting vectors: {e}")
  3962. return result
  3963. async def drop(self) -> dict[str, str]:
  3964. """Drop all rows scoped to this storage's workspace.
  3965. The underlying table is shared across workspaces and is NOT
  3966. dropped — this method issues ``DELETE FROM <table> WHERE
  3967. workspace=$1`` and clears the pending buffers (queued
  3968. upserts/deletes against rows that are about to disappear are
  3969. meaningless).
  3970. Concurrency contract:
  3971. ``_flush_lock`` guards same-process flush / upsert / delete
  3972. races only. Cross-worker buffered writes are NOT covered —
  3973. another worker's pending buffer can flush stale rows back
  3974. into the table immediately after this call returns. Callers
  3975. running inside the LightRAG framework MUST hold
  3976. ``pipeline_status["destructive_busy"] = True`` (acquired
  3977. atomically via ``_acquire_destructive_busy``) for the entire
  3978. duration of the drop; the ``/documents/clear`` endpoint
  3979. already does this before invoking ``drop()`` on every
  3980. storage. Direct callers (tests, ops scripts, debugging) are
  3981. responsible for ensuring no other writer is touching this
  3982. workspace.
  3983. Returns:
  3984. ``{"status": "success" | "error", "message": ...}``. Unlike
  3985. ``delete()`` / ``delete_entity()`` / ``delete_entity_relation()``
  3986. which re-raise on failure, ``drop()`` swallows the exception
  3987. into the return dict — callers MUST inspect ``status`` to
  3988. detect failure. The exception is also logged at ``error``
  3989. level so a missed status check still leaves a trail.
  3990. """
  3991. try:
  3992. async with self._flush_lock:
  3993. self._pending_vector_docs.clear()
  3994. self._pending_vector_deletes.clear()
  3995. drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
  3996. table_name=self.table_name
  3997. )
  3998. await self.db.execute(drop_sql, {"workspace": self.workspace})
  3999. return {"status": "success", "message": "data dropped"}
  4000. except Exception as e:
  4001. logger.error(
  4002. f"[{self.workspace}] Error dropping vector storage "
  4003. f"{self.namespace}: {e}"
  4004. )
  4005. return {"status": "error", "message": str(e)}
  4006. def _parse_doc_status_datetime(
  4007. dt_str: Any,
  4008. context: str = "",
  4009. ) -> datetime.datetime | None:
  4010. """Convert a datetime value to a naive UTC datetime for database storage.
  4011. Accepts `datetime.datetime` objects, `datetime.date` objects, or ISO-format
  4012. strings. Returns None on failure (which may trigger a NOT NULL constraint
  4013. violation if the column does not allow nulls).
  4014. The optional context string (e.g. "[workspace] doc <id> created_at") is
  4015. included in the error log to help locate the offending record.
  4016. """
  4017. if dt_str is None:
  4018. return None
  4019. if isinstance(dt_str, datetime.datetime):
  4020. if dt_str.tzinfo is None:
  4021. dt_str = dt_str.replace(tzinfo=timezone.utc)
  4022. return dt_str.astimezone(timezone.utc).replace(tzinfo=None)
  4023. if isinstance(dt_str, datetime.date):
  4024. return datetime.datetime(
  4025. dt_str.year, dt_str.month, dt_str.day, tzinfo=timezone.utc
  4026. ).replace(tzinfo=None)
  4027. try:
  4028. dt = datetime.datetime.fromisoformat(dt_str)
  4029. if dt.tzinfo is None:
  4030. dt = dt.replace(tzinfo=timezone.utc)
  4031. return dt.astimezone(timezone.utc).replace(tzinfo=None)
  4032. except (ValueError, TypeError):
  4033. logger.error(
  4034. f"Unable to parse doc status datetime string"
  4035. f"{f' ({context})' if context else ''}: {dt_str!r}"
  4036. )
  4037. return None
  4038. @final
  4039. @dataclass
  4040. class PGDocStatusStorage(DocStatusStorage):
  4041. db: PostgreSQLDB = field(default=None)
  4042. def _format_datetime_with_timezone(self, dt):
  4043. """Convert datetime to ISO format string with timezone info"""
  4044. if dt is None:
  4045. return None
  4046. # If no timezone info, assume it's UTC time (as stored in database)
  4047. if dt.tzinfo is None:
  4048. dt = dt.replace(tzinfo=timezone.utc)
  4049. # If datetime already has timezone info, keep it as is
  4050. return dt.isoformat()
  4051. async def initialize(self):
  4052. async with get_data_init_lock():
  4053. if self.db is None:
  4054. self.db = await ClientManager.get_client(
  4055. vector_storage=self.global_config.get("vector_storage")
  4056. )
  4057. # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
  4058. if self.db.workspace:
  4059. # Use PostgreSQLDB's workspace (highest priority)
  4060. logger.info(
  4061. f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
  4062. )
  4063. self.workspace = self.db.workspace
  4064. elif hasattr(self, "workspace") and self.workspace:
  4065. # Use storage class's workspace (medium priority)
  4066. pass
  4067. else:
  4068. # Use "default" for compatibility (lowest priority)
  4069. self.workspace = "default"
  4070. # NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization
  4071. # No need to create table here as it's already created in the TABLES dict
  4072. async def finalize(self):
  4073. if self.db is not None:
  4074. await ClientManager.release_client(self.db)
  4075. self.db = None
  4076. async def filter_keys(self, keys: set[str]) -> set[str]:
  4077. """Filter out duplicated content"""
  4078. if not keys:
  4079. return set()
  4080. table_name = namespace_to_table_name(self.namespace)
  4081. sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
  4082. params = {"workspace": self.workspace, "ids": list(keys)}
  4083. try:
  4084. res = await self.db.query(sql, list(params.values()), multirows=True)
  4085. if res:
  4086. exist_keys = [key["id"] for key in res]
  4087. else:
  4088. exist_keys = []
  4089. new_keys = set([s for s in keys if s not in exist_keys])
  4090. # print(f"keys: {keys}")
  4091. # print(f"new_keys: {new_keys}")
  4092. return new_keys
  4093. except Exception as e:
  4094. logger.error(
  4095. f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
  4096. )
  4097. raise
  4098. async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
  4099. sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
  4100. params = {"workspace": self.workspace, "id": id}
  4101. result = await self.db.query(sql, list(params.values()), True)
  4102. if result is None or result == []:
  4103. return None
  4104. else:
  4105. # Parse chunks_list JSON string back to list
  4106. chunks_list = result[0].get("chunks_list", [])
  4107. if isinstance(chunks_list, str):
  4108. try:
  4109. chunks_list = json.loads(chunks_list)
  4110. except json.JSONDecodeError:
  4111. chunks_list = []
  4112. # Parse metadata JSON string back to dict
  4113. metadata = result[0].get("metadata", {})
  4114. if isinstance(metadata, str):
  4115. try:
  4116. metadata = json.loads(metadata)
  4117. except json.JSONDecodeError:
  4118. metadata = {}
  4119. # Convert datetime objects to ISO format strings with timezone info
  4120. created_at = self._format_datetime_with_timezone(result[0]["created_at"])
  4121. updated_at = self._format_datetime_with_timezone(result[0]["updated_at"])
  4122. return dict(
  4123. content_length=result[0]["content_length"],
  4124. content_summary=result[0]["content_summary"],
  4125. status=result[0]["status"],
  4126. chunks_count=result[0]["chunks_count"],
  4127. created_at=created_at,
  4128. updated_at=updated_at,
  4129. file_path=result[0]["file_path"],
  4130. chunks_list=chunks_list,
  4131. metadata=metadata,
  4132. error_msg=result[0].get("error_msg"),
  4133. track_id=result[0].get("track_id"),
  4134. content_hash=result[0].get("content_hash"),
  4135. )
  4136. async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
  4137. """Get doc_chunks data by multiple IDs."""
  4138. if not ids:
  4139. return []
  4140. sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
  4141. params = {"workspace": self.workspace, "ids": ids}
  4142. results = await self.db.query(sql, list(params.values()), True)
  4143. if not results:
  4144. return []
  4145. processed_map: dict[str, dict[str, Any]] = {}
  4146. for row in results:
  4147. # Parse chunks_list JSON string back to list
  4148. chunks_list = row.get("chunks_list", [])
  4149. if isinstance(chunks_list, str):
  4150. try:
  4151. chunks_list = json.loads(chunks_list)
  4152. except json.JSONDecodeError:
  4153. chunks_list = []
  4154. # Parse metadata JSON string back to dict
  4155. metadata = row.get("metadata", {})
  4156. if isinstance(metadata, str):
  4157. try:
  4158. metadata = json.loads(metadata)
  4159. except json.JSONDecodeError:
  4160. metadata = {}
  4161. # Convert datetime objects to ISO format strings with timezone info
  4162. created_at = self._format_datetime_with_timezone(row["created_at"])
  4163. updated_at = self._format_datetime_with_timezone(row["updated_at"])
  4164. processed_map[str(row.get("id"))] = {
  4165. "content_length": row["content_length"],
  4166. "content_summary": row["content_summary"],
  4167. "status": row["status"],
  4168. "chunks_count": row["chunks_count"],
  4169. "created_at": created_at,
  4170. "updated_at": updated_at,
  4171. "file_path": row["file_path"],
  4172. "chunks_list": chunks_list,
  4173. "metadata": metadata,
  4174. "error_msg": row.get("error_msg"),
  4175. "track_id": row.get("track_id"),
  4176. "content_hash": row.get("content_hash"),
  4177. }
  4178. ordered_results: list[dict[str, Any] | None] = []
  4179. for requested_id in ids:
  4180. ordered_results.append(processed_map.get(str(requested_id)))
  4181. return ordered_results
  4182. async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
  4183. """Get document by file path
  4184. Args:
  4185. file_path: The file path to search for
  4186. Returns:
  4187. Union[dict[str, Any], None]: Document data if found, None otherwise
  4188. Returns the same format as get_by_id method
  4189. """
  4190. sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and file_path=$2"
  4191. params = {"workspace": self.workspace, "file_path": file_path}
  4192. result = await self.db.query(sql, list(params.values()), True)
  4193. if result is None or result == []:
  4194. return None
  4195. else:
  4196. # Parse chunks_list JSON string back to list
  4197. chunks_list = result[0].get("chunks_list", [])
  4198. if isinstance(chunks_list, str):
  4199. try:
  4200. chunks_list = json.loads(chunks_list)
  4201. except json.JSONDecodeError:
  4202. chunks_list = []
  4203. # Parse metadata JSON string back to dict
  4204. metadata = result[0].get("metadata", {})
  4205. if isinstance(metadata, str):
  4206. try:
  4207. metadata = json.loads(metadata)
  4208. except json.JSONDecodeError:
  4209. metadata = {}
  4210. # Convert datetime objects to ISO format strings with timezone info
  4211. created_at = self._format_datetime_with_timezone(result[0]["created_at"])
  4212. updated_at = self._format_datetime_with_timezone(result[0]["updated_at"])
  4213. return dict(
  4214. content_length=result[0]["content_length"],
  4215. content_summary=result[0]["content_summary"],
  4216. status=result[0]["status"],
  4217. chunks_count=result[0]["chunks_count"],
  4218. created_at=created_at,
  4219. updated_at=updated_at,
  4220. file_path=result[0]["file_path"],
  4221. chunks_list=chunks_list,
  4222. metadata=metadata,
  4223. error_msg=result[0].get("error_msg"),
  4224. track_id=result[0].get("track_id"),
  4225. content_hash=result[0].get("content_hash"),
  4226. )
  4227. async def get_doc_by_file_basename(
  4228. self, basename: str
  4229. ) -> tuple[str, dict[str, Any]] | None:
  4230. """PG-native override of basename-based document lookup.
  4231. Replaces the base-class full-table scan with a database-level query on
  4232. the canonical ``file_path`` column. The caller is responsible for
  4233. passing an already-canonical basename; storage performs an exact match
  4234. only.
  4235. """
  4236. if not basename:
  4237. return None
  4238. if basename == "unknown_source":
  4239. return None
  4240. sql = (
  4241. "SELECT * FROM LIGHTRAG_DOC_STATUS "
  4242. "WHERE workspace=$1 AND file_path = $2 "
  4243. "ORDER BY created_at ASC, id ASC LIMIT 1"
  4244. )
  4245. params = [self.workspace, basename]
  4246. result = await self.db.query(sql, params, True)
  4247. if not result:
  4248. return None
  4249. row = result[0]
  4250. chunks_list = row.get("chunks_list", [])
  4251. if isinstance(chunks_list, str):
  4252. try:
  4253. chunks_list = json.loads(chunks_list)
  4254. except json.JSONDecodeError:
  4255. chunks_list = []
  4256. metadata = row.get("metadata", {})
  4257. if isinstance(metadata, str):
  4258. try:
  4259. metadata = json.loads(metadata)
  4260. except json.JSONDecodeError:
  4261. metadata = {}
  4262. created_at = self._format_datetime_with_timezone(row["created_at"])
  4263. updated_at = self._format_datetime_with_timezone(row["updated_at"])
  4264. doc = dict(
  4265. content_length=row["content_length"],
  4266. content_summary=row["content_summary"],
  4267. status=row["status"],
  4268. chunks_count=row["chunks_count"],
  4269. created_at=created_at,
  4270. updated_at=updated_at,
  4271. file_path=row["file_path"],
  4272. chunks_list=chunks_list,
  4273. metadata=metadata,
  4274. error_msg=row.get("error_msg"),
  4275. track_id=row.get("track_id"),
  4276. content_hash=row.get("content_hash"),
  4277. )
  4278. return str(row["id"]), doc
  4279. async def get_doc_by_content_hash(
  4280. self, content_hash: str
  4281. ) -> tuple[str, dict[str, Any]] | None:
  4282. """PG-native override of content-hash document lookup.
  4283. Replaces the base-class full-table scan with an indexed query on
  4284. ``workspace + content_hash``. Empty strings are treated as a miss
  4285. to align with the partial-index predicate.
  4286. """
  4287. if not content_hash:
  4288. return None
  4289. sql = (
  4290. "SELECT * FROM LIGHTRAG_DOC_STATUS "
  4291. "WHERE workspace=$1 AND content_hash=$2 "
  4292. "ORDER BY created_at ASC, id ASC LIMIT 1"
  4293. )
  4294. result = await self.db.query(sql, [self.workspace, content_hash], True)
  4295. if not result:
  4296. return None
  4297. row = result[0]
  4298. chunks_list = row.get("chunks_list", [])
  4299. if isinstance(chunks_list, str):
  4300. try:
  4301. chunks_list = json.loads(chunks_list)
  4302. except json.JSONDecodeError:
  4303. chunks_list = []
  4304. metadata = row.get("metadata", {})
  4305. if isinstance(metadata, str):
  4306. try:
  4307. metadata = json.loads(metadata)
  4308. except json.JSONDecodeError:
  4309. metadata = {}
  4310. created_at = self._format_datetime_with_timezone(row["created_at"])
  4311. updated_at = self._format_datetime_with_timezone(row["updated_at"])
  4312. doc = dict(
  4313. content_length=row["content_length"],
  4314. content_summary=row["content_summary"],
  4315. status=row["status"],
  4316. chunks_count=row["chunks_count"],
  4317. created_at=created_at,
  4318. updated_at=updated_at,
  4319. file_path=row["file_path"],
  4320. chunks_list=chunks_list,
  4321. metadata=metadata,
  4322. error_msg=row.get("error_msg"),
  4323. track_id=row.get("track_id"),
  4324. content_hash=row.get("content_hash"),
  4325. )
  4326. return str(row["id"]), doc
  4327. async def get_status_counts(self) -> dict[str, int]:
  4328. """Get counts of documents in each status"""
  4329. sql = """SELECT status as "status", COUNT(1) as "count"
  4330. FROM LIGHTRAG_DOC_STATUS
  4331. where workspace=$1 GROUP BY STATUS
  4332. """
  4333. params = {"workspace": self.workspace}
  4334. result = await self.db.query(sql, list(params.values()), True)
  4335. counts = {}
  4336. for doc in result:
  4337. counts[doc["status"]] = doc["count"]
  4338. return counts
  4339. async def get_docs_by_status(
  4340. self, status: DocStatus
  4341. ) -> dict[str, DocProcessingStatus]:
  4342. """all documents with a specific status"""
  4343. sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
  4344. params = {"workspace": self.workspace, "status": status.value}
  4345. result = await self.db.query(sql, list(params.values()), True)
  4346. docs_by_status = {}
  4347. for element in result:
  4348. # Parse chunks_list JSON string back to list
  4349. chunks_list = element.get("chunks_list", [])
  4350. if isinstance(chunks_list, str):
  4351. try:
  4352. chunks_list = json.loads(chunks_list)
  4353. except json.JSONDecodeError:
  4354. chunks_list = []
  4355. # Parse metadata JSON string back to dict
  4356. metadata = element.get("metadata", {})
  4357. if isinstance(metadata, str):
  4358. try:
  4359. metadata = json.loads(metadata)
  4360. except json.JSONDecodeError:
  4361. metadata = {}
  4362. # Ensure metadata is a dict
  4363. if not isinstance(metadata, dict):
  4364. metadata = {}
  4365. # Safe handling for file_path
  4366. file_path = element.get("file_path")
  4367. if file_path is None:
  4368. file_path = "no-file-path"
  4369. # Convert datetime objects to ISO format strings with timezone info
  4370. created_at = self._format_datetime_with_timezone(element["created_at"])
  4371. updated_at = self._format_datetime_with_timezone(element["updated_at"])
  4372. docs_by_status[element["id"]] = DocProcessingStatus(
  4373. content_summary=element["content_summary"],
  4374. content_length=element["content_length"],
  4375. status=element["status"],
  4376. created_at=created_at,
  4377. updated_at=updated_at,
  4378. chunks_count=element["chunks_count"],
  4379. file_path=file_path,
  4380. chunks_list=chunks_list,
  4381. metadata=metadata,
  4382. error_msg=element.get("error_msg"),
  4383. track_id=element.get("track_id"),
  4384. content_hash=element.get("content_hash"),
  4385. )
  4386. return docs_by_status
  4387. async def get_docs_by_statuses(
  4388. self, statuses: list[DocStatus]
  4389. ) -> dict[str, DocProcessingStatus]:
  4390. """Fetch documents matching any of the given statuses in a single query.
  4391. Replaces multiple sequential/parallel get_docs_by_status() calls when the
  4392. caller needs documents across several statuses (e.g. PROCESSING + FAILED + PENDING).
  4393. Uses a single ANY($2) query instead of N separate round-trips.
  4394. """
  4395. if not statuses:
  4396. return {}
  4397. status_values = [s.value for s in statuses]
  4398. sql = (
  4399. "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND status = ANY($2)"
  4400. )
  4401. result = await self.db.query(
  4402. sql, [self.workspace, status_values], multirows=True
  4403. )
  4404. docs: dict[str, DocProcessingStatus] = {}
  4405. for element in result or []:
  4406. try:
  4407. chunks_list = element.get("chunks_list", [])
  4408. if isinstance(chunks_list, str):
  4409. try:
  4410. chunks_list = json.loads(chunks_list)
  4411. except json.JSONDecodeError:
  4412. chunks_list = []
  4413. metadata = element.get("metadata", {})
  4414. if isinstance(metadata, str):
  4415. try:
  4416. metadata = json.loads(metadata)
  4417. except json.JSONDecodeError:
  4418. metadata = {}
  4419. if not isinstance(metadata, dict):
  4420. metadata = {}
  4421. file_path = element.get("file_path") or "no-file-path"
  4422. docs[element["id"]] = DocProcessingStatus(
  4423. content_summary=element["content_summary"],
  4424. content_length=element["content_length"],
  4425. status=element["status"],
  4426. created_at=self._format_datetime_with_timezone(
  4427. element["created_at"]
  4428. ),
  4429. updated_at=self._format_datetime_with_timezone(
  4430. element["updated_at"]
  4431. ),
  4432. chunks_count=element["chunks_count"],
  4433. file_path=file_path,
  4434. chunks_list=chunks_list,
  4435. metadata=metadata,
  4436. error_msg=element.get("error_msg"),
  4437. track_id=element.get("track_id"),
  4438. content_hash=element.get("content_hash"),
  4439. )
  4440. except (KeyError, TypeError) as e:
  4441. doc_id_hint = element.get("id", "<unknown>") if element else "<unknown>"
  4442. logger.error(
  4443. f"[{self.workspace}] Skipping document '{doc_id_hint}' — "
  4444. f"required field missing or wrong type while parsing DB row: {e!r}"
  4445. )
  4446. continue
  4447. return docs
  4448. async def get_docs_by_track_id(
  4449. self, track_id: str
  4450. ) -> dict[str, DocProcessingStatus]:
  4451. """Get all documents with a specific track_id"""
  4452. sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
  4453. params = {"workspace": self.workspace, "track_id": track_id}
  4454. result = await self.db.query(sql, list(params.values()), True)
  4455. docs_by_track_id = {}
  4456. for element in result:
  4457. # Parse chunks_list JSON string back to list
  4458. chunks_list = element.get("chunks_list", [])
  4459. if isinstance(chunks_list, str):
  4460. try:
  4461. chunks_list = json.loads(chunks_list)
  4462. except json.JSONDecodeError:
  4463. chunks_list = []
  4464. # Parse metadata JSON string back to dict
  4465. metadata = element.get("metadata", {})
  4466. if isinstance(metadata, str):
  4467. try:
  4468. metadata = json.loads(metadata)
  4469. except json.JSONDecodeError:
  4470. metadata = {}
  4471. # Ensure metadata is a dict
  4472. if not isinstance(metadata, dict):
  4473. metadata = {}
  4474. # Safe handling for file_path
  4475. file_path = element.get("file_path")
  4476. if file_path is None:
  4477. file_path = "no-file-path"
  4478. # Convert datetime objects to ISO format strings with timezone info
  4479. created_at = self._format_datetime_with_timezone(element["created_at"])
  4480. updated_at = self._format_datetime_with_timezone(element["updated_at"])
  4481. docs_by_track_id[element["id"]] = DocProcessingStatus(
  4482. content_summary=element["content_summary"],
  4483. content_length=element["content_length"],
  4484. status=element["status"],
  4485. created_at=created_at,
  4486. updated_at=updated_at,
  4487. chunks_count=element["chunks_count"],
  4488. file_path=file_path,
  4489. chunks_list=chunks_list,
  4490. track_id=element.get("track_id"),
  4491. metadata=metadata,
  4492. error_msg=element.get("error_msg"),
  4493. content_hash=element.get("content_hash"),
  4494. )
  4495. return docs_by_track_id
  4496. async def get_docs_paginated(
  4497. self,
  4498. status_filter: DocStatus | None = None,
  4499. status_filters: list[DocStatus] | None = None,
  4500. page: int = 1,
  4501. page_size: int = 50,
  4502. sort_field: str = "updated_at",
  4503. sort_direction: str = "desc",
  4504. ) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
  4505. """Get documents with pagination support
  4506. Args:
  4507. status_filter: Filter by document status, None for all statuses
  4508. page: Page number (1-based)
  4509. page_size: Number of documents per page (10-200)
  4510. sort_field: Field to sort by ('created_at', 'updated_at', 'id')
  4511. sort_direction: Sort direction ('asc' or 'desc')
  4512. Returns:
  4513. Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count)
  4514. """
  4515. start = time.perf_counter()
  4516. status_filter_values = self.resolve_status_filter_values(
  4517. status_filter=status_filter,
  4518. status_filters=status_filters,
  4519. )
  4520. status_filter_value = status_filter.value if status_filter is not None else None
  4521. performance_timing_log(
  4522. "[%s] PGDocStatusStorage.get_docs_paginated start status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s",
  4523. self.workspace,
  4524. status_filter_value,
  4525. page,
  4526. page_size,
  4527. sort_field,
  4528. sort_direction,
  4529. )
  4530. # Validate parameters
  4531. if page < 1:
  4532. page = 1
  4533. if page_size < 10:
  4534. page_size = 10
  4535. elif page_size > 200:
  4536. page_size = 200
  4537. # Whitelist validation for sort_field to prevent SQL injection
  4538. allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
  4539. if sort_field not in allowed_sort_fields:
  4540. sort_field = "updated_at"
  4541. # Whitelist validation for sort_direction to prevent SQL injection
  4542. if sort_direction.lower() not in ["asc", "desc"]:
  4543. sort_direction = "desc"
  4544. else:
  4545. sort_direction = sort_direction.lower()
  4546. # Calculate offset
  4547. offset = (page - 1) * page_size
  4548. # Build parameterized query components
  4549. params = {"workspace": self.workspace}
  4550. param_count = 1
  4551. # Build WHERE clause with parameterized query
  4552. if status_filter_values is not None:
  4553. param_count += 1
  4554. where_clause = "WHERE workspace=$1 AND status = ANY($2)"
  4555. params["status_filters"] = sorted(status_filter_values)
  4556. else:
  4557. where_clause = "WHERE workspace=$1"
  4558. # Build ORDER BY clause using validated whitelist values.
  4559. # NULLS LAST is applied in both the inner paged CTE and the outer query so
  4560. # that the LIMIT/OFFSET slice boundary and the display order are identical.
  4561. # Without it, DESC defaults to NULLS FIRST: nulls land on earlier pages but
  4562. # are re-sorted to the end by the outer ORDER BY, dropping non-null rows.
  4563. order_clause = f"ORDER BY {sort_field} {sort_direction.upper()} NULLS LAST"
  4564. # Two-CTE query: total count + page data in a single round-trip.
  4565. #
  4566. # COUNT(*) OVER () was replaced because when the LIMIT/OFFSET clause yields
  4567. # no rows (out-of-range page), there are no result rows to carry the window
  4568. # function value — so total_count would not appear in the output at all,
  4569. # making it impossible to distinguish "0 matching documents" from "non-empty
  4570. # result set, page is past the end".
  4571. #
  4572. # The LEFT JOIN pattern fixes this: the `total` CTE always produces exactly
  4573. # one row (the aggregate count over the full WHERE clause), and the outer
  4574. # LEFT JOIN emits that one row even when `paged` is empty. Python then
  4575. # skips rows where id IS NULL (the empty-page sentinel).
  4576. #
  4577. # chunks_list is intentionally excluded from the paged CTE SELECT list:
  4578. # DocStatusResponse does not expose it, so transferring the full JSONB array
  4579. # would be pure overhead. The chunks_list=[] in the constructor below is
  4580. # intentional — see the paged CTE column list above.
  4581. params["limit"] = page_size
  4582. params["offset"] = offset
  4583. cte_sql = f"""
  4584. WITH total AS (
  4585. SELECT COUNT(*) AS _total_count
  4586. FROM LIGHTRAG_DOC_STATUS
  4587. {where_clause}
  4588. ),
  4589. paged AS (
  4590. SELECT id, workspace, content_summary, content_length, chunks_count,
  4591. status, file_path, track_id, metadata, error_msg, content_hash,
  4592. created_at, updated_at
  4593. FROM LIGHTRAG_DOC_STATUS
  4594. {where_clause}
  4595. {order_clause}
  4596. LIMIT ${param_count + 1} OFFSET ${param_count + 2}
  4597. )
  4598. SELECT p.*, t._total_count
  4599. FROM total t
  4600. LEFT JOIN paged p ON true
  4601. ORDER BY p.{sort_field} {sort_direction.upper()} NULLS LAST
  4602. """
  4603. query_timing_label = f"{self.workspace} PGDocStatusStorage.get_docs_paginated"
  4604. result = await self.db.query(
  4605. cte_sql,
  4606. list(params.values()),
  4607. True,
  4608. timing_label=query_timing_label,
  4609. )
  4610. total_count = result[0]["_total_count"] if result else 0
  4611. # Convert to (doc_id, DocProcessingStatus) tuples
  4612. documents = []
  4613. for element in result:
  4614. if element["id"] is None:
  4615. # Empty-page sentinel row from LEFT JOIN when paged has no rows.
  4616. continue
  4617. doc_id = element["id"]
  4618. # Parse metadata JSON string back to dict
  4619. metadata = element.get("metadata", {})
  4620. if isinstance(metadata, str):
  4621. try:
  4622. metadata = json.loads(metadata)
  4623. except json.JSONDecodeError:
  4624. metadata = {}
  4625. # Convert datetime objects to ISO format strings with timezone info
  4626. created_at = self._format_datetime_with_timezone(element["created_at"])
  4627. updated_at = self._format_datetime_with_timezone(element["updated_at"])
  4628. doc_status = DocProcessingStatus(
  4629. content_summary=element["content_summary"],
  4630. content_length=element["content_length"],
  4631. status=element["status"],
  4632. created_at=created_at,
  4633. updated_at=updated_at,
  4634. chunks_count=element["chunks_count"],
  4635. file_path=element["file_path"],
  4636. chunks_list=[], # not fetched: unused by pagination response
  4637. track_id=element.get("track_id"),
  4638. metadata=metadata,
  4639. error_msg=element.get("error_msg"),
  4640. content_hash=element.get("content_hash"),
  4641. )
  4642. documents.append((doc_id, doc_status))
  4643. elapsed = time.perf_counter() - start
  4644. performance_timing_log(
  4645. "[%s] PGDocStatusStorage.get_docs_paginated completed in %.4fs returned_rows=%s total_count=%s status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s",
  4646. self.workspace,
  4647. elapsed,
  4648. len(documents),
  4649. total_count,
  4650. status_filter_value,
  4651. page,
  4652. page_size,
  4653. sort_field,
  4654. sort_direction,
  4655. )
  4656. return documents, total_count
  4657. async def get_all_status_counts(self) -> dict[str, int]:
  4658. """Get counts of documents in each status for all documents
  4659. Returns:
  4660. Dictionary mapping status names to counts, including 'all' field
  4661. """
  4662. start = time.perf_counter()
  4663. performance_timing_log(
  4664. "[%s] PGDocStatusStorage.get_all_status_counts start", self.workspace
  4665. )
  4666. sql = """
  4667. SELECT status, COUNT(*) as count
  4668. FROM LIGHTRAG_DOC_STATUS
  4669. WHERE workspace=$1
  4670. GROUP BY status
  4671. """
  4672. params = {"workspace": self.workspace}
  4673. query_timing_label = (
  4674. f"{self.workspace} PGDocStatusStorage.get_all_status_counts"
  4675. )
  4676. result = await self.db.query(
  4677. sql,
  4678. list(params.values()),
  4679. True,
  4680. timing_label=query_timing_label,
  4681. )
  4682. counts = {}
  4683. total_count = 0
  4684. for row in result:
  4685. counts[row["status"]] = row["count"]
  4686. total_count += row["count"]
  4687. # Add 'all' field with total count
  4688. counts["all"] = total_count
  4689. elapsed = time.perf_counter() - start
  4690. performance_timing_log(
  4691. "[%s] PGDocStatusStorage.get_all_status_counts completed in %.4fs counts=%s",
  4692. self.workspace,
  4693. elapsed,
  4694. counts,
  4695. )
  4696. return counts
  4697. async def index_done_callback(self) -> None:
  4698. # PG handles persistence automatically
  4699. pass
  4700. async def is_empty(self) -> bool:
  4701. """Check if the storage is empty for the current workspace and namespace
  4702. Returns:
  4703. bool: True if storage is empty, False otherwise
  4704. """
  4705. table_name = namespace_to_table_name(self.namespace)
  4706. if not table_name:
  4707. logger.error(
  4708. f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
  4709. )
  4710. return True
  4711. sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
  4712. try:
  4713. result = await self.db.query(sql, [self.workspace])
  4714. return not result.get("has_data", False) if result else True
  4715. except Exception as e:
  4716. logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
  4717. return True
  4718. async def delete(self, ids: list[str]) -> None:
  4719. """Delete specific records from storage by their IDs
  4720. Args:
  4721. ids (list[str]): List of document IDs to be deleted from storage
  4722. Returns:
  4723. None
  4724. """
  4725. if not ids:
  4726. return
  4727. table_name = namespace_to_table_name(self.namespace)
  4728. if not table_name:
  4729. logger.error(
  4730. f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
  4731. )
  4732. return
  4733. delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
  4734. try:
  4735. await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
  4736. logger.debug(
  4737. f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
  4738. )
  4739. except Exception as e:
  4740. logger.error(
  4741. f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
  4742. )
  4743. async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
  4744. """Update or insert document status
  4745. Args:
  4746. data: dictionary of document IDs and their status data
  4747. """
  4748. logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
  4749. if not data:
  4750. return
  4751. timing_label = f"{self.workspace} PGDocStatusStorage.upsert"
  4752. total_start = time.perf_counter()
  4753. performance_timing_log(
  4754. "[%s] start records=%s",
  4755. timing_label,
  4756. len(data),
  4757. )
  4758. # NOTE: content_hash uses COALESCE(NULLIF(...,''), existing) rather than
  4759. # a straight EXCLUDED overwrite. This gives write-once-after-set
  4760. # semantics: once a non-empty content_hash is recorded, subsequent
  4761. # upserts that omit it (or pass '' / NULL) will NOT clear it. Required
  4762. # because pipeline state transitions (e.g. processing -> processed)
  4763. # reuse the existing DocProcessingStatus payload without re-supplying
  4764. # the hash, while _persist_parsed_full_docs patches the hash in a
  4765. # separate upsert.
  4766. #
  4767. # This is a deliberate behavioral divergence from JsonDocStatusStorage,
  4768. # which overwrites unconditionally. No caller today wants to clear a
  4769. # content_hash, so the divergence is invisible — but if that ever
  4770. # changes, this guard must be revisited.
  4771. sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status,file_path,chunks_list,track_id,metadata,error_msg,content_hash,created_at,updated_at)
  4772. values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)
  4773. on conflict(id,workspace) do update set
  4774. content_summary = EXCLUDED.content_summary,
  4775. content_length = EXCLUDED.content_length,
  4776. chunks_count = EXCLUDED.chunks_count,
  4777. status = EXCLUDED.status,
  4778. file_path = EXCLUDED.file_path,
  4779. chunks_list = EXCLUDED.chunks_list,
  4780. track_id = EXCLUDED.track_id,
  4781. metadata = EXCLUDED.metadata,
  4782. error_msg = EXCLUDED.error_msg,
  4783. content_hash = COALESCE(
  4784. NULLIF(EXCLUDED.content_hash, ''),
  4785. LIGHTRAG_DOC_STATUS.content_hash
  4786. ),
  4787. created_at = EXCLUDED.created_at,
  4788. updated_at = EXCLUDED.updated_at"""
  4789. # Tuple order must match SQL: (workspace, id, content_summary, content_length,
  4790. # chunks_count, status, file_path, chunks_list, track_id, metadata,
  4791. # error_msg, content_hash, created_at, updated_at)
  4792. batch: list[tuple] = []
  4793. skipped: list[str] = []
  4794. batch_build_start = time.perf_counter()
  4795. for i, (k, v) in enumerate(data.items(), start=1):
  4796. try:
  4797. batch.append(
  4798. (
  4799. self.workspace,
  4800. k,
  4801. v["content_summary"],
  4802. v["content_length"],
  4803. v.get("chunks_count", -1),
  4804. v["status"],
  4805. v["file_path"],
  4806. json.dumps(v.get("chunks_list", [])),
  4807. v.get("track_id"),
  4808. json.dumps(v.get("metadata", {})),
  4809. v.get("error_msg"),
  4810. v.get("content_hash"),
  4811. _parse_doc_status_datetime(
  4812. v.get("created_at"),
  4813. f"[{self.workspace}] doc {k} created_at",
  4814. ),
  4815. _parse_doc_status_datetime(
  4816. v.get("updated_at"),
  4817. f"[{self.workspace}] doc {k} updated_at",
  4818. ),
  4819. )
  4820. )
  4821. except (KeyError, TypeError, ValueError) as e:
  4822. logger.error(
  4823. f"[{self.workspace}] Skipping document '{k}' in batch upsert — "
  4824. f"invalid or missing field: {e!r}"
  4825. )
  4826. skipped.append(k)
  4827. await _cooperative_yield(i)
  4828. if skipped:
  4829. logger.warning(
  4830. f"[{self.workspace}] {len(skipped)} document(s) skipped in batch upsert: {skipped}"
  4831. )
  4832. performance_timing_log(
  4833. "[%s] batch validation/assembly completed in %.4fs valid_count=%s skipped_count=%s",
  4834. timing_label,
  4835. time.perf_counter() - batch_build_start,
  4836. len(batch),
  4837. len(skipped),
  4838. )
  4839. async def _batch_upsert(
  4840. connection: asyncpg.Connection,
  4841. _sql: str = sql,
  4842. _data: list[tuple] = batch,
  4843. ) -> None:
  4844. execute_start = time.perf_counter()
  4845. async with connection.transaction():
  4846. await connection.executemany(_sql, _data)
  4847. performance_timing_log(
  4848. "[%s] transaction + executemany completed in %.4fs batch_size=%s",
  4849. timing_label,
  4850. time.perf_counter() - execute_start,
  4851. len(_data),
  4852. )
  4853. await self.db._run_with_retry(_batch_upsert, timing_label=timing_label)
  4854. logger.debug(
  4855. f"[{self.workspace}] Batch upserted {len(batch)} records to {self.namespace}"
  4856. )
  4857. performance_timing_log(
  4858. "[%s] total complete in %.4fs valid_count=%s skipped_count=%s",
  4859. timing_label,
  4860. time.perf_counter() - total_start,
  4861. len(batch),
  4862. len(skipped),
  4863. )
  4864. async def drop(self) -> dict[str, str]:
  4865. """Drop the storage"""
  4866. try:
  4867. table_name = namespace_to_table_name(self.namespace)
  4868. if not table_name:
  4869. return {
  4870. "status": "error",
  4871. "message": f"Unknown namespace: {self.namespace}",
  4872. }
  4873. drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
  4874. table_name=table_name
  4875. )
  4876. await self.db.execute(drop_sql, {"workspace": self.workspace})
  4877. return {"status": "success", "message": "data dropped"}
  4878. except Exception as e:
  4879. return {"status": "error", "message": str(e)}
  4880. class PGGraphQueryException(Exception):
  4881. """Exception for the AGE queries."""
  4882. def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
  4883. if isinstance(exception, dict):
  4884. self.message = exception["message"] if "message" in exception else "unknown"
  4885. self.details = exception["details"] if "details" in exception else "unknown"
  4886. else:
  4887. self.message = exception
  4888. self.details = "unknown"
  4889. def get_message(self) -> str:
  4890. return self.message
  4891. def get_details(self) -> Any:
  4892. return self.details
  4893. def _is_transient_graph_write_error(exc: BaseException) -> bool:
  4894. """Return True when a PGGraphQueryException wraps a transient write-time error.
  4895. The inner _run_with_retry already handles connection-level transient errors
  4896. (pool reset, TCP failures, etc.). This predicate covers query-level transient
  4897. errors that survive the connection layer and surface as PGGraphQueryException:
  4898. deadlocks, serialization conflicts, and lock-acquisition timeouts that can
  4899. occur under concurrent document ingestion.
  4900. """
  4901. if not isinstance(exc, PGGraphQueryException):
  4902. return False
  4903. cause = exc.__cause__
  4904. if cause is None:
  4905. return False
  4906. return isinstance(
  4907. cause,
  4908. (
  4909. asyncpg.exceptions.DeadlockDetectedError,
  4910. asyncpg.exceptions.SerializationError,
  4911. asyncpg.exceptions.LockNotAvailableError,
  4912. asyncpg.exceptions.QueryCanceledError,
  4913. ),
  4914. )
  4915. @final
  4916. @dataclass
  4917. class PGGraphStorage(BaseGraphStorage):
  4918. def __post_init__(self):
  4919. # Graph name will be dynamically generated in initialize() based on workspace
  4920. self.db: PostgreSQLDB | None = None
  4921. def _get_workspace_graph_name(self) -> str:
  4922. """
  4923. Generate graph name based on workspace and namespace for data isolation.
  4924. Rules:
  4925. - If workspace is empty or "default": graph_name = namespace
  4926. - If workspace has other value: graph_name = workspace_namespace
  4927. Args:
  4928. None
  4929. Returns:
  4930. str: The graph name for the current workspace
  4931. """
  4932. workspace = self.workspace
  4933. namespace = self.namespace
  4934. if workspace and workspace.strip() and workspace.strip().lower() != "default":
  4935. # Ensure names comply with PostgreSQL identifier specifications
  4936. safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
  4937. safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
  4938. return f"{safe_workspace}_{safe_namespace}"
  4939. else:
  4940. # When the workspace is "default", use the namespace directly (for backward compatibility with legacy implementations)
  4941. return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
  4942. @staticmethod
  4943. def _normalize_node_id(node_id: str) -> str:
  4944. """
  4945. Normalize node ID to ensure special characters are properly handled in Cypher queries.
  4946. Used by write paths that still embed entity IDs in Cypher strings
  4947. (delete_node, remove_nodes, remove_edges). The upsert paths now use
  4948. parameterized Cypher instead.
  4949. Within a Cypher double-quoted string the only recognised escape
  4950. sequences are ``\\"`` and ``\\\\``. We also strip null bytes which
  4951. could truncate the string in some PostgreSQL/AGE code paths.
  4952. Args:
  4953. node_id: The original node ID
  4954. Returns:
  4955. Normalized node ID suitable for embedding in a Cypher double-quoted string
  4956. """
  4957. # Strip null bytes that could truncate the string
  4958. normalized_id = node_id.replace("\x00", "")
  4959. # Escape backslashes first (order matters)
  4960. normalized_id = normalized_id.replace("\\", "\\\\")
  4961. # Escape double quotes
  4962. normalized_id = normalized_id.replace('"', '\\"')
  4963. return normalized_id
  4964. async def initialize(self):
  4965. async with get_data_init_lock():
  4966. if self.db is None:
  4967. self.db = await ClientManager.get_client(
  4968. vector_storage=self.global_config.get("vector_storage")
  4969. )
  4970. # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
  4971. if self.db.workspace:
  4972. # Use PostgreSQLDB's workspace (highest priority)
  4973. logger.info(
  4974. f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
  4975. )
  4976. self.workspace = self.db.workspace
  4977. elif hasattr(self, "workspace") and self.workspace:
  4978. # Use storage class's workspace (medium priority)
  4979. pass
  4980. else:
  4981. # Use "default" for compatibility (lowest priority)
  4982. self.workspace = "default"
  4983. # Dynamically generate graph name based on workspace
  4984. self.graph_name = self._get_workspace_graph_name()
  4985. # Log the graph initialization for debugging
  4986. logger.info(
  4987. f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
  4988. )
  4989. # Create AGE extension and configure graph environment once at initialization
  4990. # Use _run_with_retry so transient connection errors are retried and pool=None
  4991. # is handled safely (unlike a bare pool.acquire() call).
  4992. async def _do_configure_age_extension(
  4993. connection: asyncpg.Connection,
  4994. ) -> None:
  4995. await PostgreSQLDB.configure_age_extension(connection)
  4996. await self.db._run_with_retry(_do_configure_age_extension)
  4997. # Execute each statement separately and ignore errors
  4998. queries = [
  4999. f"SELECT create_graph('{self.graph_name}')",
  5000. f"SELECT create_vlabel('{self.graph_name}', 'base');",
  5001. f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');",
  5002. # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)',
  5003. f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
  5004. # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)',
  5005. f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)',
  5006. f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)',
  5007. f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)',
  5008. f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)',
  5009. f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)',
  5010. f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)',
  5011. f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)',
  5012. f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)',
  5013. f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
  5014. f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)',
  5015. f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx',
  5016. ]
  5017. for query in queries:
  5018. # Use the new flag to silently ignore "already exists" errors
  5019. # at the source, preventing log spam.
  5020. await self.db.execute(
  5021. query,
  5022. upsert=True,
  5023. ignore_if_exists=True, # Pass the new flag
  5024. with_age=True,
  5025. graph_name=self.graph_name,
  5026. )
  5027. async def finalize(self):
  5028. if self.db is not None:
  5029. await ClientManager.release_client(self.db)
  5030. self.db = None
  5031. async def index_done_callback(self) -> None:
  5032. # PG handles persistence automatically
  5033. pass
  5034. @staticmethod
  5035. def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
  5036. """
  5037. Convert a record returned from an age query to a dictionary
  5038. Args:
  5039. record (): a record from an age query result
  5040. Returns:
  5041. dict[str, Any]: a dictionary representation of the record where
  5042. the dictionary key is the field name and the value is the
  5043. value converted to a python type
  5044. """
  5045. @staticmethod
  5046. def parse_agtype_string(agtype_str: str) -> tuple[str, str]:
  5047. """
  5048. Parse agtype string precisely, separating JSON content and type identifier
  5049. Args:
  5050. agtype_str: String like '{"json": "content"}::vertex'
  5051. Returns:
  5052. (json_content, type_identifier)
  5053. """
  5054. if not isinstance(agtype_str, str) or "::" not in agtype_str:
  5055. return agtype_str, ""
  5056. # Find the last :: from the right, which is the start of type identifier
  5057. last_double_colon = agtype_str.rfind("::")
  5058. if last_double_colon == -1:
  5059. return agtype_str, ""
  5060. # Separate JSON content and type identifier
  5061. json_content = agtype_str[:last_double_colon]
  5062. type_identifier = agtype_str[last_double_colon + 2 :]
  5063. return json_content, type_identifier
  5064. @staticmethod
  5065. def safe_json_parse(json_str: str, context: str = "") -> dict:
  5066. """
  5067. Safe JSON parsing with simplified error logging
  5068. """
  5069. try:
  5070. return json.loads(json_str)
  5071. except json.JSONDecodeError as e:
  5072. logger.error(f"JSON parsing failed ({context}): {e}")
  5073. logger.error(f"Raw data (first 100 chars): {repr(json_str[:100])}")
  5074. logger.error(f"Error position: line {e.lineno}, column {e.colno}")
  5075. return None
  5076. # result holder
  5077. d = {}
  5078. # prebuild a mapping of vertex_id to vertex mappings to be used
  5079. # later to build edges
  5080. vertices = {}
  5081. # First pass: preprocess vertices
  5082. for k in record.keys():
  5083. v = record[k]
  5084. if isinstance(v, str) and "::" in v:
  5085. if v.startswith("[") and v.endswith("]"):
  5086. # Handle vertex arrays
  5087. json_content, type_id = parse_agtype_string(v)
  5088. if type_id == "vertex":
  5089. vertexes = safe_json_parse(
  5090. json_content, f"vertices array for {k}"
  5091. )
  5092. if vertexes:
  5093. for vertex in vertexes:
  5094. vertices[vertex["id"]] = vertex.get("properties")
  5095. else:
  5096. # Handle single vertex
  5097. json_content, type_id = parse_agtype_string(v)
  5098. if type_id == "vertex":
  5099. vertex = safe_json_parse(json_content, f"single vertex for {k}")
  5100. if vertex:
  5101. vertices[vertex["id"]] = vertex.get("properties")
  5102. # Second pass: process all fields
  5103. for k in record.keys():
  5104. v = record[k]
  5105. if isinstance(v, str) and "::" in v:
  5106. if v.startswith("[") and v.endswith("]"):
  5107. # Handle array types
  5108. json_content, type_id = parse_agtype_string(v)
  5109. if type_id in ["vertex", "edge"]:
  5110. parsed_data = safe_json_parse(
  5111. json_content, f"array {type_id} for field {k}"
  5112. )
  5113. d[k] = parsed_data if parsed_data is not None else None
  5114. else:
  5115. logger.warning(f"Unknown array type: {type_id}")
  5116. d[k] = None
  5117. else:
  5118. # Handle single objects
  5119. json_content, type_id = parse_agtype_string(v)
  5120. if type_id in ["vertex", "edge"]:
  5121. parsed_data = safe_json_parse(
  5122. json_content, f"single {type_id} for field {k}"
  5123. )
  5124. d[k] = parsed_data if parsed_data is not None else None
  5125. else:
  5126. # May be other types of agtype data, keep as is
  5127. d[k] = v
  5128. else:
  5129. d[k] = v # Keep as string
  5130. return d
  5131. @staticmethod
  5132. def _format_properties(
  5133. properties: dict[str, Any], _id: Union[str, None] = None
  5134. ) -> str:
  5135. """
  5136. Convert a dictionary of properties to a string representation that
  5137. can be used in a cypher query insert/merge statement.
  5138. Args:
  5139. properties (dict[str,str]): a dictionary containing node/edge properties
  5140. _id (Union[str, None]): the id of the node or None if none exists
  5141. Returns:
  5142. str: the properties dictionary as a properly formatted string
  5143. """
  5144. props = []
  5145. # Wrap property keys in backticks and escape embedded backticks to
  5146. # preserve the Cypher structure when property names contain specials.
  5147. for k, v in properties.items():
  5148. safe_key = str(k).replace("`", "``")
  5149. prop = f"`{safe_key}`: {json.dumps(v, ensure_ascii=False)}"
  5150. props.append(prop)
  5151. if _id is not None and "id" not in properties:
  5152. props.append(
  5153. f"id: {json.dumps(_id, ensure_ascii=False)}"
  5154. if isinstance(_id, str)
  5155. else f"id: {_id}"
  5156. )
  5157. return "{" + ", ".join(props) + "}"
  5158. async def _query(
  5159. self,
  5160. query: str,
  5161. readonly: bool = True,
  5162. upsert: bool = False,
  5163. params: dict[str, Any] | None = None,
  5164. timing_label: str | None = None,
  5165. ) -> list[dict[str, Any]]:
  5166. """
  5167. Query the graph by taking a cypher query, converting it to an
  5168. age compatible query, executing it and converting the result
  5169. Args:
  5170. query (str): a cypher query to be executed
  5171. readonly (bool): if True, uses db.query; if False, uses db.execute.
  5172. Both paths support the ``params`` argument.
  5173. upsert (bool): passed through to db.execute for write operations.
  5174. params (dict | None): AGE agtype parameters for parameterized Cypher
  5175. (e.g. ``{"params": json.dumps({"entity_id": "..."})}``).
  5176. Honoured for both read and write paths.
  5177. timing_label (str | None): optional label for performance logging.
  5178. Returns:
  5179. list[dict[str, Any]]: a list of dictionaries containing the result set
  5180. """
  5181. try:
  5182. if readonly:
  5183. data = await self.db.query(
  5184. query,
  5185. list(params.values()) if params else None,
  5186. multirows=True,
  5187. with_age=True,
  5188. graph_name=self.graph_name,
  5189. timing_label=timing_label,
  5190. )
  5191. else:
  5192. age_execute_start = time.perf_counter()
  5193. data = await self.db.execute(
  5194. query,
  5195. data=params,
  5196. upsert=upsert,
  5197. with_age=True,
  5198. graph_name=self.graph_name,
  5199. timing_label=timing_label,
  5200. )
  5201. if timing_label:
  5202. performance_timing_log(
  5203. "[%s] AGE execute completed in %.4fs",
  5204. timing_label,
  5205. time.perf_counter() - age_execute_start,
  5206. )
  5207. except Exception as e:
  5208. if timing_label and not readonly:
  5209. performance_timing_log(
  5210. "[%s] AGE execute failed after %.4fs",
  5211. timing_label,
  5212. time.perf_counter() - age_execute_start,
  5213. )
  5214. raise PGGraphQueryException(
  5215. {
  5216. "message": f"Error executing graph query: {query}",
  5217. "wrapped": query,
  5218. "detail": repr(e),
  5219. "error_type": e.__class__.__name__,
  5220. }
  5221. ) from e
  5222. if data is None:
  5223. result = []
  5224. # decode records
  5225. else:
  5226. result = [self._record_to_dict(d) for d in data]
  5227. return result
  5228. async def has_node(self, node_id: str) -> bool:
  5229. query = f"""
  5230. SELECT EXISTS (
  5231. SELECT 1
  5232. FROM {self.graph_name}.base
  5233. WHERE ag_catalog.agtype_access_operator(
  5234. VARIADIC ARRAY[properties, '"entity_id"'::agtype]
  5235. ) = (to_json($1::text)::text)::agtype
  5236. LIMIT 1
  5237. ) AS node_exists;
  5238. """
  5239. params = {"node_id": node_id}
  5240. row = (await self._query(query, params=params))[0]
  5241. return bool(row["node_exists"])
  5242. async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
  5243. query = f"""
  5244. WITH a AS (
  5245. SELECT id AS vid
  5246. FROM {self.graph_name}.base
  5247. WHERE ag_catalog.agtype_access_operator(
  5248. VARIADIC ARRAY[properties, '"entity_id"'::agtype]
  5249. ) = (to_json($1::text)::text)::agtype
  5250. ),
  5251. b AS (
  5252. SELECT id AS vid
  5253. FROM {self.graph_name}.base
  5254. WHERE ag_catalog.agtype_access_operator(
  5255. VARIADIC ARRAY[properties, '"entity_id"'::agtype]
  5256. ) = (to_json($2::text)::text)::agtype
  5257. )
  5258. SELECT EXISTS (
  5259. SELECT 1
  5260. FROM {self.graph_name}."DIRECTED" d
  5261. JOIN a ON d.start_id = a.vid
  5262. JOIN b ON d.end_id = b.vid
  5263. LIMIT 1
  5264. )
  5265. OR EXISTS (
  5266. SELECT 1
  5267. FROM {self.graph_name}."DIRECTED" d
  5268. JOIN a ON d.end_id = a.vid
  5269. JOIN b ON d.start_id = b.vid
  5270. LIMIT 1
  5271. ) AS edge_exists;
  5272. """
  5273. params = {
  5274. "source_node_id": source_node_id,
  5275. "target_node_id": target_node_id,
  5276. }
  5277. row = (await self._query(query, params=params))[0]
  5278. return bool(row["edge_exists"])
  5279. async def get_node(self, node_id: str) -> dict[str, str] | None:
  5280. """Get node by its label identifier, return only node properties"""
  5281. result = await self.get_nodes_batch(node_ids=[node_id])
  5282. if result and node_id in result:
  5283. return result[node_id]
  5284. return None
  5285. async def node_degree(self, node_id: str) -> int:
  5286. result = await self.node_degrees_batch(node_ids=[node_id])
  5287. if result and node_id in result:
  5288. return result[node_id]
  5289. return 0
  5290. async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  5291. result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
  5292. if result and (src_id, tgt_id) in result:
  5293. return result[(src_id, tgt_id)]
  5294. return 0
  5295. async def get_edge(
  5296. self, source_node_id: str, target_node_id: str
  5297. ) -> dict[str, str] | None:
  5298. """Get edge properties between two nodes"""
  5299. result = await self.get_edges_batch(
  5300. [{"src": source_node_id, "tgt": target_node_id}]
  5301. )
  5302. if result and (source_node_id, target_node_id) in result:
  5303. return result[(source_node_id, target_node_id)]
  5304. return None
  5305. async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
  5306. """
  5307. Retrieves all edges (relationships) for a particular node identified by its label.
  5308. :return: list of dictionaries containing edge information
  5309. """
  5310. cypher_query = """MATCH (n:base {entity_id: $entity_id})
  5311. OPTIONAL MATCH (n)-[]-(connected:base)
  5312. RETURN n.entity_id AS source_id, connected.entity_id AS connected_id"""
  5313. query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(cypher_query)}::cstring, $1::agtype) AS (source_id text, connected_id text)"
  5314. pg_params = {
  5315. "params": json.dumps({"entity_id": source_node_id}, ensure_ascii=False)
  5316. }
  5317. results = await self._query(query, params=pg_params)
  5318. edges = []
  5319. for record in results:
  5320. source_id = record["source_id"]
  5321. connected_id = record["connected_id"]
  5322. if source_id and connected_id:
  5323. edges.append((source_id, connected_id))
  5324. return edges
  5325. @retry(
  5326. stop=stop_after_attempt(3),
  5327. wait=wait_exponential(multiplier=1, min=4, max=10),
  5328. retry=retry_if_exception(_is_transient_graph_write_error),
  5329. reraise=True,
  5330. )
  5331. async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
  5332. """
  5333. Upsert a node in the Neo4j database.
  5334. Args:
  5335. node_id: The unique identifier for the node (used as label)
  5336. node_data: Dictionary of node properties
  5337. """
  5338. if "entity_id" not in node_data:
  5339. raise ValueError(
  5340. "PostgreSQL: node properties must contain an 'entity_id' field"
  5341. )
  5342. # AGE supports binding scalar values in Cypher parameters here, but not
  5343. # using a bound agtype object on ``SET n += $props`` (verified on AGE 1.5.0).
  5344. # Keep the node ID parameterized and inline a safely escaped property map literal.
  5345. node_props = {k: v for k, v in node_data.items() if k != "entity_id"}
  5346. props_literal = self._format_properties(node_props)
  5347. cypher_query = f"""MERGE (n:base {{entity_id: $entity_id}})
  5348. SET n += {props_literal}
  5349. RETURN n"""
  5350. query = (
  5351. f"SELECT * FROM cypher("
  5352. f"{_dollar_quote(self.graph_name)}::name, "
  5353. f"{_dollar_quote(cypher_query)}::cstring, "
  5354. f"$1::agtype) AS (n agtype)"
  5355. )
  5356. pg_params = {
  5357. "params": json.dumps(
  5358. {"entity_id": node_id},
  5359. ensure_ascii=False,
  5360. )
  5361. }
  5362. timing_label = f"{self.workspace} PGGraphStorage.upsert_node"
  5363. total_start = time.perf_counter()
  5364. performance_timing_log(
  5365. "[%s] start node_id=%s",
  5366. timing_label,
  5367. node_id,
  5368. )
  5369. try:
  5370. await self._query(
  5371. query,
  5372. readonly=False,
  5373. upsert=True,
  5374. params=pg_params,
  5375. timing_label=timing_label,
  5376. )
  5377. performance_timing_log(
  5378. "[%s] total complete in %.4fs node_id=%s",
  5379. timing_label,
  5380. time.perf_counter() - total_start,
  5381. node_id,
  5382. )
  5383. except Exception:
  5384. performance_timing_log(
  5385. "[%s] total failed after %.4fs node_id=%s",
  5386. timing_label,
  5387. time.perf_counter() - total_start,
  5388. node_id,
  5389. )
  5390. logger.error(
  5391. f"[{self.workspace}] POSTGRES, upsert_node error on node_id: `{node_id}`"
  5392. )
  5393. raise
  5394. @retry(
  5395. stop=stop_after_attempt(3),
  5396. wait=wait_exponential(multiplier=1, min=4, max=10),
  5397. retry=retry_if_exception(_is_transient_graph_write_error),
  5398. reraise=True,
  5399. )
  5400. async def upsert_edge(
  5401. self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  5402. ) -> None:
  5403. """
  5404. Upsert an edge and its properties between two nodes identified by their labels.
  5405. Args:
  5406. source_node_id (str): Label of the source node (used as identifier)
  5407. target_node_id (str): Label of the target node (used as identifier)
  5408. edge_data (dict): dictionary of properties to set on the edge
  5409. """
  5410. # AGE does not support binding a full agtype map in ``SET r += $props``
  5411. # (verified on AGE 1.5.0), and the inlined literal form ``SET r += {map}``
  5412. # is also silently ignored for edges (though it works for nodes). Individual
  5413. # ``SET r.key = value`` assignments run without error but also do not persist.
  5414. # The only reliable way to write edge properties in AGE is to inline them
  5415. # directly in a CREATE clause. We use OPTIONAL MATCH to delete any existing
  5416. # edge first so the operation remains idempotent.
  5417. #
  5418. # Concurrency: OPTIONAL MATCH + DELETE + CREATE is not atomic against other
  5419. # writers — two transactions upserting the same pair could both observe no
  5420. # existing edge and both CREATE one, leaving duplicate DIRECTED rows that
  5421. # inflate degree counts and duplicate relations. We serialise per logical
  5422. # edge with a transaction-scoped advisory lock keyed on
  5423. # (graph_name, ordered (src_id, tgt_id)) so:
  5424. # - {A,B} and {B,A} collide on the same lock (the OPTIONAL MATCH is
  5425. # undirected), and
  5426. # - the same (A,B) pair in different AGE graphs / workspaces does NOT
  5427. # collide. pg_advisory_xact_lock is database-wide, and we don't want
  5428. # independent tenants to serialise each other's ingestion.
  5429. # AGE refuses to plan a join against a cypher() call that contains a
  5430. # CREATE clause ("cypher create clause cannot be rescanned"), so we cannot
  5431. # use a CTE for the lock. Instead we open an explicit transaction and run
  5432. # two statements on the same connection: the lock acquisition first, then
  5433. # the cypher upsert. The lock is released when the transaction commits.
  5434. props_literal = self._format_properties(edge_data) if edge_data else "{}"
  5435. cypher_query = f"""MATCH (source:base {{entity_id: $src_id}})
  5436. WITH source
  5437. MATCH (target:base {{entity_id: $tgt_id}})
  5438. WITH source, target
  5439. OPTIONAL MATCH (source)-[old:DIRECTED]-(target)
  5440. DELETE old
  5441. WITH source, target
  5442. CREATE (source)-[r:DIRECTED {props_literal}]->(target)
  5443. RETURN r"""
  5444. lock_sql = (
  5445. "SELECT pg_advisory_xact_lock("
  5446. " hashtextextended("
  5447. " $1::text || E'\\x01' ||"
  5448. " LEAST($2::text, $3::text) || E'\\x01' || GREATEST($2::text, $3::text),"
  5449. " 0"
  5450. " )"
  5451. ")"
  5452. )
  5453. cypher_sql = (
  5454. f"SELECT r FROM cypher("
  5455. f"{_dollar_quote(self.graph_name)}::name, "
  5456. f"{_dollar_quote(cypher_query)}::cstring, "
  5457. f"$1::agtype) AS (r agtype)"
  5458. )
  5459. params_json = json.dumps(
  5460. {"src_id": source_node_id, "tgt_id": target_node_id},
  5461. ensure_ascii=False,
  5462. )
  5463. timing_label = f"{self.workspace} PGGraphStorage.upsert_edge"
  5464. total_start = time.perf_counter()
  5465. performance_timing_log(
  5466. "[%s] start source_node_id=%s target_node_id=%s",
  5467. timing_label,
  5468. source_node_id,
  5469. target_node_id,
  5470. )
  5471. async def _operation(connection: asyncpg.Connection) -> None:
  5472. async with connection.transaction():
  5473. await connection.execute(
  5474. lock_sql, self.graph_name, source_node_id, target_node_id
  5475. )
  5476. await connection.execute(cypher_sql, params_json)
  5477. try:
  5478. await self.db._run_with_retry(
  5479. _operation,
  5480. with_age=True,
  5481. graph_name=self.graph_name,
  5482. timing_label=timing_label,
  5483. )
  5484. performance_timing_log(
  5485. "[%s] total complete in %.4fs source_node_id=%s target_node_id=%s",
  5486. timing_label,
  5487. time.perf_counter() - total_start,
  5488. source_node_id,
  5489. target_node_id,
  5490. )
  5491. except Exception as e:
  5492. performance_timing_log(
  5493. "[%s] total failed after %.4fs source_node_id=%s target_node_id=%s",
  5494. timing_label,
  5495. time.perf_counter() - total_start,
  5496. source_node_id,
  5497. target_node_id,
  5498. )
  5499. logger.error(
  5500. f"[{self.workspace}] POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
  5501. )
  5502. # Re-raise as PGGraphQueryException so the outer @retry's
  5503. # _is_transient_graph_write_error predicate can inspect __cause__ and
  5504. # retry on DeadlockDetectedError / SerializationError /
  5505. # LockNotAvailableError / QueryCanceledError — mirrors what _query
  5506. # does for upsert_node and the rest of the AGE write paths. Without
  5507. # this wrapping, query-level transient errors from connection.execute
  5508. # would surface as raw asyncpg exceptions, fail isinstance() in the
  5509. # predicate, and skip retries.
  5510. if isinstance(e, PGGraphQueryException):
  5511. raise
  5512. raise PGGraphQueryException(
  5513. {
  5514. "message": (
  5515. f"Error executing graph upsert_edge: "
  5516. f"`{source_node_id}`-`{target_node_id}`"
  5517. ),
  5518. "wrapped": cypher_sql,
  5519. "detail": repr(e),
  5520. "error_type": e.__class__.__name__,
  5521. }
  5522. ) from e
  5523. async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
  5524. """Batch insert/update multiple nodes while preserving input-order semantics.
  5525. PostgreSQL/AGE write paths embed properties directly in Cypher strings and do not
  5526. yet support parameterized UNWIND. Deduplicating by node ID first preserves the
  5527. last-write-wins behaviour of the historical serial fallback.
  5528. Args:
  5529. nodes: List of (node_id, node_data) tuples.
  5530. """
  5531. if not nodes:
  5532. return
  5533. deduped_nodes: dict[str, dict[str, str]] = {}
  5534. for node_id, node_data in nodes:
  5535. deduped_nodes.pop(node_id, None)
  5536. deduped_nodes[node_id] = node_data
  5537. for node_id, node_data in deduped_nodes.items():
  5538. await self.upsert_node(node_id, node_data=node_data)
  5539. async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
  5540. """Check existence of multiple nodes using a single array-based SQL query.
  5541. Args:
  5542. node_ids: List of node IDs to check.
  5543. Returns:
  5544. Set of node_ids that exist in the graph.
  5545. """
  5546. if not node_ids:
  5547. return set()
  5548. result = await self.get_nodes_batch(node_ids)
  5549. return set(result.keys())
  5550. async def upsert_edges_batch(
  5551. self, edges: list[tuple[str, str, dict[str, str]]]
  5552. ) -> None:
  5553. """Batch insert/update multiple edges while preserving input-order semantics.
  5554. PostgreSQL/AGE relationships are undirected (`MERGE (source)-[r:DIRECTED]-(target)`),
  5555. so batches containing reciprocal duplicates must retain the last update for each
  5556. endpoint pair to match the historical serial fallback.
  5557. Args:
  5558. edges: List of (source_node_id, target_node_id, edge_data) tuples.
  5559. """
  5560. if not edges:
  5561. return
  5562. deduped_edges: dict[tuple[str, str], tuple[str, str, dict[str, str]]] = {}
  5563. for src, tgt, edge_data in edges:
  5564. edge_key = tuple(sorted((src, tgt)))
  5565. deduped_edges.pop(edge_key, None)
  5566. deduped_edges[edge_key] = (src, tgt, edge_data)
  5567. # Iterate in canonical (LEAST, GREATEST) order rather than dict
  5568. # insertion order. upsert_edge opens an independent transaction per
  5569. # call and releases the advisory lock on commit, so this is not a
  5570. # deadlock fix — but a deterministic iteration order makes logs and
  5571. # replays reproducible across callers, and matches the dedup key
  5572. # already used above.
  5573. for edge_key in sorted(deduped_edges):
  5574. src, tgt, edge_data = deduped_edges[edge_key]
  5575. await self.upsert_edge(src, tgt, edge_data=edge_data)
  5576. async def delete_node(self, node_id: str) -> None:
  5577. """
  5578. Delete a node from the graph.
  5579. Args:
  5580. node_id (str): The ID of the node to delete.
  5581. """
  5582. label = self._normalize_node_id(node_id)
  5583. # Build Cypher query with dynamic dollar-quoting to handle entity_id containing $ sequences
  5584. cypher_query = f"""MATCH (n:base {{entity_id: "{label}"}})
  5585. DETACH DELETE n"""
  5586. query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (n agtype)"
  5587. try:
  5588. await self._query(query, readonly=False)
  5589. except Exception as e:
  5590. logger.error(f"[{self.workspace}] Error during node deletion: {e}")
  5591. raise
  5592. async def remove_nodes(self, node_ids: list[str]) -> None:
  5593. """
  5594. Remove multiple nodes from the graph.
  5595. Args:
  5596. node_ids (list[str]): A list of node IDs to remove.
  5597. """
  5598. node_ids_normalized = [self._normalize_node_id(node_id) for node_id in node_ids]
  5599. node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids_normalized])
  5600. # Build Cypher query with dynamic dollar-quoting to handle entity_id containing $ sequences
  5601. cypher_query = f"""MATCH (n:base)
  5602. WHERE n.entity_id IN [{node_id_list}]
  5603. DETACH DELETE n"""
  5604. query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (n agtype)"
  5605. try:
  5606. await self._query(query, readonly=False)
  5607. except Exception as e:
  5608. logger.error(f"[{self.workspace}] Error during node removal: {e}")
  5609. raise
  5610. async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
  5611. """
  5612. Remove multiple edges from the graph.
  5613. Args:
  5614. edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
  5615. """
  5616. for source, target in edges:
  5617. src_label = self._normalize_node_id(source)
  5618. tgt_label = self._normalize_node_id(target)
  5619. # Build Cypher query with dynamic dollar-quoting to handle entity_id containing $ sequences
  5620. cypher_query = f"""MATCH (a:base {{entity_id: "{src_label}"}})-[r]-(b:base {{entity_id: "{tgt_label}"}})
  5621. DELETE r"""
  5622. query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (r agtype)"
  5623. try:
  5624. await self._query(query, readonly=False)
  5625. logger.debug(
  5626. f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
  5627. )
  5628. except Exception as e:
  5629. logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
  5630. raise
  5631. async def get_nodes_batch(
  5632. self, node_ids: list[str], batch_size: int = 1000
  5633. ) -> dict[str, dict]:
  5634. """
  5635. Retrieve multiple nodes in one query using UNWIND.
  5636. Args:
  5637. node_ids: List of node entity IDs to fetch.
  5638. batch_size: Batch size for the query
  5639. Returns:
  5640. A dictionary mapping each node_id to its node data (or None if not found).
  5641. """
  5642. if not node_ids:
  5643. return {}
  5644. seen: set[str] = set()
  5645. unique_ids: list[str] = []
  5646. lookup: dict[str, str] = {}
  5647. requested: set[str] = set()
  5648. for nid in node_ids:
  5649. if nid not in seen:
  5650. seen.add(nid)
  5651. unique_ids.append(nid)
  5652. requested.add(nid)
  5653. lookup[nid] = nid
  5654. lookup[self._normalize_node_id(nid)] = nid
  5655. # Build result dictionary
  5656. nodes_dict = {}
  5657. for i in range(0, len(unique_ids), batch_size):
  5658. batch = unique_ids[i : i + batch_size]
  5659. query = f"""
  5660. WITH input(v, ord) AS (
  5661. SELECT v, ord
  5662. FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
  5663. ),
  5664. ids(node_id, ord) AS (
  5665. SELECT (to_json(v)::text)::agtype AS node_id, ord
  5666. FROM input
  5667. )
  5668. SELECT i.node_id::text AS node_id,
  5669. b.properties
  5670. FROM {self.graph_name}.base AS b
  5671. JOIN ids i
  5672. ON ag_catalog.agtype_access_operator(
  5673. VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
  5674. ) = i.node_id
  5675. ORDER BY i.ord;
  5676. """
  5677. results = await self._query(query, params={"ids": batch})
  5678. for result in results:
  5679. if result["node_id"] and result["properties"]:
  5680. node_dict = result["properties"]
  5681. # Process string result, parse it to JSON dictionary
  5682. if isinstance(node_dict, str):
  5683. try:
  5684. node_dict = json.loads(node_dict)
  5685. except json.JSONDecodeError:
  5686. logger.warning(
  5687. f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
  5688. )
  5689. node_key = result["node_id"]
  5690. original_key = lookup.get(node_key)
  5691. if original_key is None:
  5692. logger.warning(
  5693. f"[{self.workspace}] Node {node_key} not found in lookup map"
  5694. )
  5695. original_key = node_key
  5696. if original_key in requested:
  5697. nodes_dict[original_key] = node_dict
  5698. return nodes_dict
  5699. async def node_degrees_batch(
  5700. self, node_ids: list[str], batch_size: int = 500
  5701. ) -> dict[str, int]:
  5702. """
  5703. Retrieve the degree for multiple nodes in a single query using UNWIND.
  5704. Calculates the total degree by counting distinct relationships.
  5705. Uses separate queries for outgoing and incoming edges.
  5706. Args:
  5707. node_ids: List of node labels (entity_id values) to look up.
  5708. batch_size: Batch size for the query
  5709. Returns:
  5710. A dictionary mapping each node_id to its degree (total number of relationships).
  5711. If a node is not found, its degree will be set to 0.
  5712. """
  5713. if not node_ids:
  5714. return {}
  5715. seen: set[str] = set()
  5716. unique_ids: list[str] = []
  5717. lookup: dict[str, str] = {}
  5718. requested: set[str] = set()
  5719. for nid in node_ids:
  5720. if nid not in seen:
  5721. seen.add(nid)
  5722. unique_ids.append(nid)
  5723. requested.add(nid)
  5724. lookup[nid] = nid
  5725. lookup[self._normalize_node_id(nid)] = nid
  5726. out_degrees = {}
  5727. in_degrees = {}
  5728. for i in range(0, len(unique_ids), batch_size):
  5729. batch = unique_ids[i : i + batch_size]
  5730. query = f"""
  5731. WITH input(v, ord) AS (
  5732. SELECT v, ord
  5733. FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
  5734. ),
  5735. ids(node_id, ord) AS (
  5736. SELECT (to_json(v)::text)::agtype AS node_id, ord
  5737. FROM input
  5738. ),
  5739. vids AS (
  5740. SELECT b.id AS vid, i.node_id, i.ord
  5741. FROM {self.graph_name}.base AS b
  5742. JOIN ids i
  5743. ON ag_catalog.agtype_access_operator(
  5744. VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
  5745. ) = i.node_id
  5746. ),
  5747. deg_out AS (
  5748. SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
  5749. FROM {self.graph_name}."DIRECTED" AS d
  5750. JOIN vids v ON v.vid = d.start_id
  5751. GROUP BY d.start_id
  5752. ),
  5753. deg_in AS (
  5754. SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
  5755. FROM {self.graph_name}."DIRECTED" AS d
  5756. JOIN vids v ON v.vid = d.end_id
  5757. GROUP BY d.end_id
  5758. )
  5759. SELECT v.node_id::text AS node_id,
  5760. COALESCE(o.out_degree, 0) AS out_degree,
  5761. COALESCE(n.in_degree, 0) AS in_degree
  5762. FROM vids v
  5763. LEFT JOIN deg_out o ON o.vid = v.vid
  5764. LEFT JOIN deg_in n ON n.vid = v.vid
  5765. ORDER BY v.ord;
  5766. """
  5767. combined_results = await self._query(query, params={"ids": batch})
  5768. for row in combined_results:
  5769. node_id = row["node_id"]
  5770. if not node_id:
  5771. continue
  5772. node_key = node_id
  5773. original_key = lookup.get(node_key)
  5774. if original_key is None:
  5775. logger.warning(
  5776. f"[{self.workspace}] Node {node_key} not found in lookup map"
  5777. )
  5778. original_key = node_key
  5779. if original_key in requested:
  5780. out_degrees[original_key] = int(row.get("out_degree", 0) or 0)
  5781. in_degrees[original_key] = int(row.get("in_degree", 0) or 0)
  5782. degrees_dict = {}
  5783. for node_id in node_ids:
  5784. out_degree = out_degrees.get(node_id, 0)
  5785. in_degree = in_degrees.get(node_id, 0)
  5786. degrees_dict[node_id] = out_degree + in_degree
  5787. return degrees_dict
  5788. async def edge_degrees_batch(
  5789. self, edges: list[tuple[str, str]]
  5790. ) -> dict[tuple[str, str], int]:
  5791. """
  5792. Calculate the combined degree for each edge (sum of the source and target node degrees)
  5793. in batch using the already implemented node_degrees_batch.
  5794. Args:
  5795. edges: List of (source_node_id, target_node_id) tuples
  5796. Returns:
  5797. Dictionary mapping edge tuples to their combined degrees
  5798. """
  5799. if not edges:
  5800. return {}
  5801. # Use node_degrees_batch to get all node degrees efficiently
  5802. all_nodes = set()
  5803. for src, tgt in edges:
  5804. all_nodes.add(src)
  5805. all_nodes.add(tgt)
  5806. node_degrees = await self.node_degrees_batch(list(all_nodes))
  5807. # Calculate edge degrees
  5808. edge_degrees_dict = {}
  5809. for src, tgt in edges:
  5810. src_degree = node_degrees.get(src, 0)
  5811. tgt_degree = node_degrees.get(tgt, 0)
  5812. edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
  5813. return edge_degrees_dict
  5814. async def get_edges_batch(
  5815. self, pairs: list[dict[str, str]], batch_size: int = 500
  5816. ) -> dict[tuple[str, str], dict]:
  5817. """
  5818. Retrieve edge properties for multiple (src, tgt) pairs in one query.
  5819. Get forward and backward edges separately and merge them before return
  5820. Args:
  5821. pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
  5822. batch_size: Batch size for the query
  5823. Returns:
  5824. A dictionary mapping (src, tgt) tuples to their edge properties.
  5825. """
  5826. if not pairs:
  5827. return {}
  5828. seen = set()
  5829. uniq_pairs: list[dict[str, str]] = []
  5830. for p in pairs:
  5831. s = self._normalize_node_id(p["src"])
  5832. t = self._normalize_node_id(p["tgt"])
  5833. key = (s, t)
  5834. if s and t and key not in seen:
  5835. seen.add(key)
  5836. uniq_pairs.append(p)
  5837. edges_dict: dict[tuple[str, str], dict] = {}
  5838. for i in range(0, len(uniq_pairs), batch_size):
  5839. batch = uniq_pairs[i : i + batch_size]
  5840. pairs = [{"src": p["src"], "tgt": p["tgt"]} for p in batch]
  5841. forward_cypher = """
  5842. UNWIND $pairs AS p
  5843. WITH p.src AS src_eid, p.tgt AS tgt_eid
  5844. MATCH (a:base {entity_id: src_eid})
  5845. MATCH (b:base {entity_id: tgt_eid})
  5846. MATCH (a)-[r]->(b)
  5847. RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
  5848. backward_cypher = """
  5849. UNWIND $pairs AS p
  5850. WITH p.src AS src_eid, p.tgt AS tgt_eid
  5851. MATCH (a:base {entity_id: src_eid})
  5852. MATCH (b:base {entity_id: tgt_eid})
  5853. MATCH (a)<-[r]-(b)
  5854. RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
  5855. sql_fwd = f"""
  5856. SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name,
  5857. {_dollar_quote(forward_cypher)}::cstring,
  5858. $1::agtype)
  5859. AS (source text, target text, edge_properties agtype)
  5860. """
  5861. sql_bwd = f"""
  5862. SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name,
  5863. {_dollar_quote(backward_cypher)}::cstring,
  5864. $1::agtype)
  5865. AS (source text, target text, edge_properties agtype)
  5866. """
  5867. pg_params = {"params": json.dumps({"pairs": pairs}, ensure_ascii=False)}
  5868. forward_results = await self._query(sql_fwd, params=pg_params)
  5869. backward_results = await self._query(sql_bwd, params=pg_params)
  5870. for result in forward_results:
  5871. if result["source"] and result["target"] and result["edge_properties"]:
  5872. edge_props = result["edge_properties"]
  5873. # Process string result, parse it to JSON dictionary
  5874. if isinstance(edge_props, str):
  5875. try:
  5876. edge_props = json.loads(edge_props)
  5877. except json.JSONDecodeError:
  5878. logger.warning(
  5879. f"[{self.workspace}]Failed to parse edge properties string: {edge_props}"
  5880. )
  5881. continue
  5882. edges_dict[(result["source"], result["target"])] = edge_props
  5883. for result in backward_results:
  5884. if result["source"] and result["target"] and result["edge_properties"]:
  5885. edge_props = result["edge_properties"]
  5886. # Process string result, parse it to JSON dictionary
  5887. if isinstance(edge_props, str):
  5888. try:
  5889. edge_props = json.loads(edge_props)
  5890. except json.JSONDecodeError:
  5891. logger.warning(
  5892. f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
  5893. )
  5894. continue
  5895. edges_dict[(result["source"], result["target"])] = edge_props
  5896. return edges_dict
  5897. async def get_nodes_edges_batch(
  5898. self, node_ids: list[str], batch_size: int = 500
  5899. ) -> dict[str, list[tuple[str, str]]]:
  5900. """
  5901. Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
  5902. Args:
  5903. node_ids: List of node IDs to get edges for
  5904. batch_size: Batch size for the query
  5905. Returns:
  5906. Dictionary mapping node IDs to lists of (source, target) edge tuples
  5907. """
  5908. if not node_ids:
  5909. return {}
  5910. seen = set()
  5911. unique_ids: list[str] = []
  5912. for nid in node_ids:
  5913. if nid and nid not in seen:
  5914. seen.add(nid)
  5915. unique_ids.append(nid)
  5916. edges_norm: dict[str, list[tuple[str, str]]] = {n: [] for n in unique_ids}
  5917. for i in range(0, len(unique_ids), batch_size):
  5918. batch = unique_ids[i : i + batch_size]
  5919. pg_params = {"params": json.dumps({"node_ids": batch}, ensure_ascii=False)}
  5920. outgoing_cypher = """UNWIND $node_ids AS node_id
  5921. MATCH (n:base {entity_id: node_id})
  5922. OPTIONAL MATCH (n:base)-[]->(connected:base)
  5923. RETURN node_id, connected.entity_id AS connected_id"""
  5924. incoming_cypher = """UNWIND $node_ids AS node_id
  5925. MATCH (n:base {entity_id: node_id})
  5926. OPTIONAL MATCH (n:base)<-[]-(connected:base)
  5927. RETURN node_id, connected.entity_id AS connected_id"""
  5928. outgoing_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(outgoing_cypher)}::cstring, $1::agtype) AS (node_id text, connected_id text)"
  5929. incoming_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(incoming_cypher)}::cstring, $1::agtype) AS (node_id text, connected_id text)"
  5930. outgoing_results = await self._query(outgoing_query, params=pg_params)
  5931. incoming_results = await self._query(incoming_query, params=pg_params)
  5932. for result in outgoing_results:
  5933. if result["node_id"] and result["connected_id"]:
  5934. edges_norm[result["node_id"]].append(
  5935. (result["node_id"], result["connected_id"])
  5936. )
  5937. for result in incoming_results:
  5938. if result["node_id"] and result["connected_id"]:
  5939. edges_norm[result["node_id"]].append(
  5940. (result["connected_id"], result["node_id"])
  5941. )
  5942. out: dict[str, list[tuple[str, str]]] = {}
  5943. for orig in node_ids:
  5944. out[orig] = edges_norm.get(orig, [])
  5945. return out
  5946. async def get_all_labels(self) -> list[str]:
  5947. """
  5948. Get all labels(node IDs, entity names) in the graph.
  5949. Returns:
  5950. list[str]: A list of all labels in the graph.
  5951. """
  5952. query = (
  5953. """SELECT * FROM cypher('%s', $$
  5954. MATCH (n:base)
  5955. WHERE n.entity_id IS NOT NULL
  5956. RETURN DISTINCT n.entity_id AS label
  5957. ORDER BY n.entity_id
  5958. $$) AS (label text)"""
  5959. % self.graph_name
  5960. )
  5961. results = await self._query(query)
  5962. labels = []
  5963. for result in results:
  5964. if result and isinstance(result, dict) and "label" in result:
  5965. labels.append(result["label"])
  5966. return labels
  5967. async def _bfs_subgraph(
  5968. self, node_label: str, max_depth: int, max_nodes: int
  5969. ) -> KnowledgeGraph:
  5970. """
  5971. Implements a true breadth-first search algorithm for subgraph retrieval.
  5972. This method is used as a fallback when the standard Cypher query is too slow
  5973. or when we need to guarantee BFS ordering.
  5974. Args:
  5975. node_label: Label of the starting node
  5976. max_depth: Maximum depth of the subgraph
  5977. max_nodes: Maximum number of nodes to return
  5978. Returns:
  5979. KnowledgeGraph object containing nodes and edges
  5980. """
  5981. from collections import deque
  5982. result = KnowledgeGraph()
  5983. visited_nodes = set()
  5984. visited_node_ids = set()
  5985. visited_edges = set()
  5986. visited_edge_pairs = set()
  5987. # Get starting node data
  5988. label = self._normalize_node_id(node_label)
  5989. # Build Cypher query with dynamic dollar-quoting to handle entity_id containing $ sequences
  5990. cypher_query = f"""MATCH (n:base {{entity_id: "{label}"}})
  5991. RETURN id(n) as node_id, n"""
  5992. query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (node_id bigint, n agtype)"
  5993. node_result = await self._query(query)
  5994. if not node_result or not node_result[0].get("n"):
  5995. return result
  5996. # Create initial KnowledgeGraphNode
  5997. start_node_data = node_result[0]["n"]
  5998. entity_id = start_node_data["properties"]["entity_id"]
  5999. internal_id = str(start_node_data["id"])
  6000. start_node = KnowledgeGraphNode(
  6001. id=internal_id,
  6002. labels=[entity_id],
  6003. properties=start_node_data["properties"],
  6004. )
  6005. # Initialize BFS queue, each element is a tuple of (node, depth)
  6006. queue = deque([(start_node, 0)])
  6007. visited_nodes.add(entity_id)
  6008. visited_node_ids.add(internal_id)
  6009. result.nodes.append(start_node)
  6010. result.is_truncated = False
  6011. # BFS search main loop
  6012. while queue:
  6013. # Get all nodes at the current depth
  6014. current_level_nodes = []
  6015. current_depth = None
  6016. # Determine current depth
  6017. if queue:
  6018. current_depth = queue[0][1]
  6019. # Extract all nodes at current depth from the queue
  6020. while queue and queue[0][1] == current_depth:
  6021. node, depth = queue.popleft()
  6022. if depth > max_depth:
  6023. continue
  6024. current_level_nodes.append(node)
  6025. if not current_level_nodes:
  6026. continue
  6027. # Check depth limit
  6028. if current_depth > max_depth:
  6029. continue
  6030. # Prepare node IDs list
  6031. node_ids = [node.labels[0] for node in current_level_nodes]
  6032. formatted_ids = ", ".join(
  6033. [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
  6034. )
  6035. # Build Cypher queries with dynamic dollar-quoting to handle entity_id containing $ sequences
  6036. outgoing_cypher = f"""UNWIND [{formatted_ids}] AS node_id
  6037. MATCH (n:base {{entity_id: node_id}})
  6038. OPTIONAL MATCH (n)-[r]->(neighbor:base)
  6039. RETURN node_id AS current_id,
  6040. id(n) AS current_internal_id,
  6041. id(neighbor) AS neighbor_internal_id,
  6042. neighbor.entity_id AS neighbor_id,
  6043. id(r) AS edge_id,
  6044. r,
  6045. neighbor,
  6046. true AS is_outgoing"""
  6047. incoming_cypher = f"""UNWIND [{formatted_ids}] AS node_id
  6048. MATCH (n:base {{entity_id: node_id}})
  6049. OPTIONAL MATCH (n)<-[r]-(neighbor:base)
  6050. RETURN node_id AS current_id,
  6051. id(n) AS current_internal_id,
  6052. id(neighbor) AS neighbor_internal_id,
  6053. neighbor.entity_id AS neighbor_id,
  6054. id(r) AS edge_id,
  6055. r,
  6056. neighbor,
  6057. false AS is_outgoing"""
  6058. outgoing_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(outgoing_cypher)}) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"
  6059. incoming_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(incoming_cypher)}) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"
  6060. # Execute queries
  6061. outgoing_results = await self._query(outgoing_query)
  6062. incoming_results = await self._query(incoming_query)
  6063. # Combine results
  6064. neighbors = outgoing_results + incoming_results
  6065. # Create mapping from node ID to node object
  6066. node_map = {node.labels[0]: node for node in current_level_nodes}
  6067. # Process all results in a single loop
  6068. for record in neighbors:
  6069. if not record.get("neighbor") or not record.get("r"):
  6070. continue
  6071. # Get current node information
  6072. current_entity_id = record["current_id"]
  6073. current_node = node_map[current_entity_id]
  6074. # Get neighbor node information
  6075. neighbor_entity_id = record["neighbor_id"]
  6076. neighbor_internal_id = str(record["neighbor_internal_id"])
  6077. is_outgoing = record["is_outgoing"]
  6078. # Determine edge direction
  6079. if is_outgoing:
  6080. source_id = current_node.id
  6081. target_id = neighbor_internal_id
  6082. else:
  6083. source_id = neighbor_internal_id
  6084. target_id = current_node.id
  6085. if not neighbor_entity_id:
  6086. continue
  6087. # Get edge and node information
  6088. b_node = record["neighbor"]
  6089. rel = record["r"]
  6090. edge_id = str(record["edge_id"])
  6091. # Create neighbor node object
  6092. neighbor_node = KnowledgeGraphNode(
  6093. id=neighbor_internal_id,
  6094. labels=[neighbor_entity_id],
  6095. properties=b_node["properties"],
  6096. )
  6097. # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
  6098. sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
  6099. # Create edge object
  6100. edge = KnowledgeGraphEdge(
  6101. id=edge_id,
  6102. type=rel["label"],
  6103. source=source_id,
  6104. target=target_id,
  6105. properties=rel["properties"],
  6106. )
  6107. if neighbor_internal_id in visited_node_ids:
  6108. # Add backward edge if neighbor node is already visited
  6109. if (
  6110. edge_id not in visited_edges
  6111. and sorted_pair not in visited_edge_pairs
  6112. ):
  6113. result.edges.append(edge)
  6114. visited_edges.add(edge_id)
  6115. visited_edge_pairs.add(sorted_pair)
  6116. else:
  6117. if len(visited_node_ids) < max_nodes and current_depth < max_depth:
  6118. # Add new node to result and queue
  6119. result.nodes.append(neighbor_node)
  6120. visited_nodes.add(neighbor_entity_id)
  6121. visited_node_ids.add(neighbor_internal_id)
  6122. # Add node to queue with incremented depth
  6123. queue.append((neighbor_node, current_depth + 1))
  6124. # Add forward edge
  6125. if (
  6126. edge_id not in visited_edges
  6127. and sorted_pair not in visited_edge_pairs
  6128. ):
  6129. result.edges.append(edge)
  6130. visited_edges.add(edge_id)
  6131. visited_edge_pairs.add(sorted_pair)
  6132. else:
  6133. if current_depth < max_depth:
  6134. result.is_truncated = True
  6135. return result
  6136. async def get_knowledge_graph(
  6137. self,
  6138. node_label: str,
  6139. max_depth: int = 3,
  6140. max_nodes: int = None,
  6141. ) -> KnowledgeGraph:
  6142. """
  6143. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
  6144. Args:
  6145. node_label: Label of the starting node, * means all nodes
  6146. max_depth: Maximum depth of the subgraph, Defaults to 3
  6147. max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes
  6148. Returns:
  6149. KnowledgeGraph object containing nodes and edges, with an is_truncated flag
  6150. indicating whether the graph was truncated due to max_nodes limit
  6151. """
  6152. # Use global_config max_graph_nodes as default if max_nodes is None
  6153. if max_nodes is None:
  6154. max_nodes = self.global_config.get("max_graph_nodes", 1000)
  6155. else:
  6156. # Limit max_nodes to not exceed global_config max_graph_nodes
  6157. max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))
  6158. kg = KnowledgeGraph()
  6159. # Handle wildcard query - get all nodes
  6160. if node_label == "*":
  6161. # First check total node count to determine if graph should be truncated
  6162. count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
  6163. MATCH (n:base)
  6164. RETURN count(distinct n) AS total_nodes
  6165. $$) AS (total_nodes bigint)"""
  6166. count_result = await self._query(count_query)
  6167. total_nodes = count_result[0]["total_nodes"] if count_result else 0
  6168. is_truncated = total_nodes > max_nodes
  6169. # Get max_nodes with highest degrees
  6170. query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
  6171. MATCH (n:base)
  6172. OPTIONAL MATCH (n)-[r]->()
  6173. RETURN id(n) as node_id, count(r) as degree
  6174. $$) AS (node_id BIGINT, degree BIGINT)
  6175. ORDER BY degree DESC
  6176. LIMIT {max_nodes}"""
  6177. node_results = await self._query(query_nodes)
  6178. node_ids = [str(result["node_id"]) for result in node_results]
  6179. logger.info(
  6180. f"[{self.workspace}] Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}"
  6181. )
  6182. if node_ids:
  6183. formatted_ids = ", ".join(node_ids)
  6184. # Construct batch query for subgraph within max_nodes
  6185. query = f"""SELECT * FROM cypher('{self.graph_name}', $$
  6186. WITH [{formatted_ids}] AS node_ids
  6187. MATCH (a)
  6188. WHERE id(a) IN node_ids
  6189. OPTIONAL MATCH (a)-[r]->(b)
  6190. WHERE id(b) IN node_ids
  6191. RETURN a, r, b
  6192. $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
  6193. results = await self._query(query)
  6194. # Process query results, deduplicate nodes and edges
  6195. nodes_dict = {}
  6196. edges_dict = {}
  6197. for result in results:
  6198. # Process node a
  6199. if result.get("a") and isinstance(result["a"], dict):
  6200. node_a = result["a"]
  6201. node_id = str(node_a["id"])
  6202. if node_id not in nodes_dict and "properties" in node_a:
  6203. nodes_dict[node_id] = KnowledgeGraphNode(
  6204. id=node_id,
  6205. labels=[node_a["properties"]["entity_id"]],
  6206. properties=node_a["properties"],
  6207. )
  6208. # Process node b
  6209. if result.get("b") and isinstance(result["b"], dict):
  6210. node_b = result["b"]
  6211. node_id = str(node_b["id"])
  6212. if node_id not in nodes_dict and "properties" in node_b:
  6213. nodes_dict[node_id] = KnowledgeGraphNode(
  6214. id=node_id,
  6215. labels=[node_b["properties"]["entity_id"]],
  6216. properties=node_b["properties"],
  6217. )
  6218. # Process edge r
  6219. if result.get("r") and isinstance(result["r"], dict):
  6220. edge = result["r"]
  6221. edge_id = str(edge["id"])
  6222. if edge_id not in edges_dict:
  6223. edges_dict[edge_id] = KnowledgeGraphEdge(
  6224. id=edge_id,
  6225. type=edge["label"],
  6226. source=str(edge["start_id"]),
  6227. target=str(edge["end_id"]),
  6228. properties=edge["properties"],
  6229. )
  6230. kg = KnowledgeGraph(
  6231. nodes=list(nodes_dict.values()),
  6232. edges=list(edges_dict.values()),
  6233. is_truncated=is_truncated,
  6234. )
  6235. else:
  6236. # For single node query, use BFS algorithm
  6237. kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
  6238. logger.info(
  6239. f"[{self.workspace}] Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
  6240. )
  6241. else:
  6242. # For non-wildcard queries, use the BFS algorithm
  6243. kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
  6244. logger.info(
  6245. f"[{self.workspace}] Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
  6246. )
  6247. return kg
  6248. async def get_all_nodes(self) -> list[dict]:
  6249. """Get all nodes in the graph.
  6250. Returns:
  6251. A list of all nodes, where each node is a dictionary of its properties
  6252. """
  6253. # Use native SQL to avoid Cypher wrapper overhead
  6254. # Original: SELECT * FROM cypher(...) with MATCH (n:base)
  6255. # Optimized: Direct table access for better performance
  6256. query = f"""
  6257. SELECT properties
  6258. FROM {self.graph_name}.base
  6259. """
  6260. results = await self._query(query)
  6261. nodes = []
  6262. for result in results:
  6263. if result.get("properties"):
  6264. node_dict = result["properties"]
  6265. # Process string result, parse it to JSON dictionary
  6266. if isinstance(node_dict, str):
  6267. try:
  6268. node_dict = json.loads(node_dict)
  6269. except json.JSONDecodeError:
  6270. logger.warning(
  6271. f"[{self.workspace}] Failed to parse node string: {node_dict}"
  6272. )
  6273. continue
  6274. # Add node id (entity_id) to the dictionary for easier access
  6275. node_dict["id"] = node_dict.get("entity_id")
  6276. nodes.append(node_dict)
  6277. return nodes
  6278. async def get_all_edges(self) -> list[dict]:
  6279. """Get all edges in the graph.
  6280. Returns:
  6281. A list of all edges, where each edge is a dictionary of its properties
  6282. (If 2 directional edges exist between the same pair of nodes, deduplication must be handled by the caller)
  6283. """
  6284. # Use native SQL to avoid Cartesian product (N×N) in Cypher MATCH
  6285. # Original Cypher: MATCH (a:base)-[r]-(b:base) creates ~50 billion row combinations
  6286. # Optimized: Start from edges table, join to nodes only to get entity_id
  6287. # Performance: O(E) instead of O(N²), ~50,000x faster for large graphs
  6288. query = f"""
  6289. SELECT DISTINCT
  6290. (ag_catalog.agtype_access_operator(VARIADIC ARRAY[a.properties, '"entity_id"'::agtype]))::text AS source,
  6291. (ag_catalog.agtype_access_operator(VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]))::text AS target,
  6292. r.properties
  6293. FROM {self.graph_name}."DIRECTED" r
  6294. JOIN {self.graph_name}.base a ON r.start_id = a.id
  6295. JOIN {self.graph_name}.base b ON r.end_id = b.id
  6296. """
  6297. results = await self._query(query)
  6298. edges = []
  6299. for result in results:
  6300. edge_properties = result["properties"]
  6301. # Process string result, parse it to JSON dictionary
  6302. if isinstance(edge_properties, str):
  6303. try:
  6304. edge_properties = json.loads(edge_properties)
  6305. except json.JSONDecodeError:
  6306. logger.warning(
  6307. f"[{self.workspace}] Failed to parse edge properties string: {edge_properties}"
  6308. )
  6309. edge_properties = {}
  6310. edge_properties["source"] = result["source"]
  6311. edge_properties["target"] = result["target"]
  6312. edges.append(edge_properties)
  6313. return edges
  6314. async def get_popular_labels(self, limit: int = 300) -> list[str]:
  6315. """Get popular labels by node degree (most connected entities) using native SQL for performance."""
  6316. try:
  6317. # Native SQL query to calculate node degrees directly from AGE's underlying tables
  6318. # This is significantly faster than using the cypher() function wrapper
  6319. query = f"""
  6320. WITH node_degrees AS (
  6321. SELECT
  6322. node_id,
  6323. COUNT(*) AS degree
  6324. FROM (
  6325. SELECT start_id AS node_id FROM {self.graph_name}._ag_label_edge
  6326. UNION ALL
  6327. SELECT end_id AS node_id FROM {self.graph_name}._ag_label_edge
  6328. ) AS all_edges
  6329. GROUP BY node_id
  6330. )
  6331. SELECT
  6332. (ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]))::text AS label
  6333. FROM
  6334. node_degrees d
  6335. JOIN
  6336. {self.graph_name}._ag_label_vertex v ON d.node_id = v.id
  6337. WHERE
  6338. ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]) IS NOT NULL
  6339. ORDER BY
  6340. d.degree DESC,
  6341. label ASC
  6342. LIMIT $1;
  6343. """
  6344. results = await self._query(query, params={"limit": limit})
  6345. labels = [
  6346. result["label"] for result in results if result and "label" in result
  6347. ]
  6348. logger.debug(
  6349. f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
  6350. )
  6351. return labels
  6352. except Exception as e:
  6353. logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
  6354. return []
  6355. async def search_labels(self, query: str, limit: int = 50) -> list[str]:
  6356. """Search labels with fuzzy matching using native, parameterized SQL for performance and security."""
  6357. query_lower = query.lower().strip()
  6358. if not query_lower:
  6359. return []
  6360. try:
  6361. # Re-implementing with the correct agtype access operator and full scoring logic.
  6362. sql_query = f"""
  6363. WITH ranked_labels AS (
  6364. SELECT
  6365. (ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text AS label,
  6366. LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) AS label_lower
  6367. FROM
  6368. {self.graph_name}._ag_label_vertex
  6369. WHERE
  6370. ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL
  6371. AND LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) ILIKE $1
  6372. )
  6373. SELECT
  6374. label
  6375. FROM (
  6376. SELECT
  6377. label,
  6378. CASE
  6379. WHEN label_lower = $2 THEN 1000
  6380. WHEN label_lower LIKE $3 THEN 500
  6381. ELSE (100 - LENGTH(label))
  6382. END +
  6383. CASE
  6384. WHEN label_lower LIKE $4 OR label_lower LIKE $5 THEN 50
  6385. ELSE 0
  6386. END AS score
  6387. FROM
  6388. ranked_labels
  6389. ) AS scored_labels
  6390. ORDER BY
  6391. score DESC,
  6392. label ASC
  6393. LIMIT $6;
  6394. """
  6395. params = (
  6396. f"%{query_lower}%", # For the main ILIKE clause ($1)
  6397. query_lower, # For exact match ($2)
  6398. f"{query_lower}%", # For prefix match ($3)
  6399. f"% {query_lower}%", # For word boundary (space) ($4)
  6400. f"%_{query_lower}%", # For word boundary (underscore) ($5)
  6401. limit, # For LIMIT ($6)
  6402. )
  6403. results = await self._query(sql_query, params=dict(enumerate(params, 1)))
  6404. labels = [
  6405. result["label"] for result in results if result and "label" in result
  6406. ]
  6407. logger.debug(
  6408. f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
  6409. )
  6410. return labels
  6411. except Exception as e:
  6412. logger.error(
  6413. f"[{self.workspace}] Error searching labels with query '{query}': {str(e)}"
  6414. )
  6415. return []
  6416. async def drop(self) -> dict[str, str]:
  6417. """Drop the storage"""
  6418. try:
  6419. drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
  6420. MATCH (n)
  6421. DETACH DELETE n
  6422. $$) AS (result agtype)"""
  6423. await self._query(drop_query, readonly=False)
  6424. return {
  6425. "status": "success",
  6426. "message": f"workspace '{self.workspace}' graph data dropped",
  6427. }
  6428. except Exception as e:
  6429. logger.error(f"[{self.workspace}] Error dropping graph: {e}")
  6430. return {"status": "error", "message": str(e)}
  6431. # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before
  6432. # more general ones (e.g., "entities") because is_namespace() uses endswith() matching
  6433. NAMESPACE_TABLE_MAP = {
  6434. NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
  6435. NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
  6436. NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES",
  6437. NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS",
  6438. NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS",
  6439. NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS",
  6440. NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
  6441. NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
  6442. NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
  6443. NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
  6444. NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
  6445. }
  6446. def namespace_to_table_name(namespace: str) -> str:
  6447. for k, v in NAMESPACE_TABLE_MAP.items():
  6448. if is_namespace(namespace, k):
  6449. return v
  6450. TABLES = {
  6451. "LIGHTRAG_DOC_FULL": {
  6452. "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
  6453. id VARCHAR(255),
  6454. workspace VARCHAR(255),
  6455. doc_name VARCHAR(1024),
  6456. content TEXT,
  6457. meta JSONB,
  6458. sidecar_location TEXT NULL,
  6459. parse_format VARCHAR(32) NULL DEFAULT 'raw',
  6460. -- content_hash is TEXT (not VARCHAR(N)) so the column is
  6461. -- agnostic to the hash algorithm. Today's pipeline writes
  6462. -- 64-char SHA-256 hex; future algos (SHA-512, base64) do
  6463. -- not require a schema change.
  6464. content_hash TEXT NULL,
  6465. -- process_options is an opaque selector string emitted by
  6466. -- sanitize_process_options() (e.g. "Fi").
  6467. process_options TEXT NULL,
  6468. chunk_options JSONB NULL DEFAULT '{}'::jsonb,
  6469. parse_engine VARCHAR(32) NULL,
  6470. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6471. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6472. CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
  6473. )"""
  6474. },
  6475. "LIGHTRAG_DOC_CHUNKS": {
  6476. "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
  6477. id VARCHAR(255),
  6478. workspace VARCHAR(255),
  6479. full_doc_id VARCHAR(256),
  6480. chunk_order_index INTEGER,
  6481. tokens INTEGER,
  6482. content TEXT,
  6483. file_path TEXT NULL,
  6484. llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
  6485. heading JSONB NULL DEFAULT '{}'::jsonb,
  6486. sidecar JSONB NULL DEFAULT '{}'::jsonb,
  6487. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6488. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6489. CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
  6490. )"""
  6491. },
  6492. "LIGHTRAG_VDB_CHUNKS": {
  6493. "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
  6494. id VARCHAR(255),
  6495. workspace VARCHAR(255),
  6496. full_doc_id VARCHAR(256),
  6497. chunk_order_index INTEGER,
  6498. tokens INTEGER,
  6499. content TEXT,
  6500. content_vector VECTOR(dimension),
  6501. file_path TEXT NULL,
  6502. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6503. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6504. CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
  6505. )"""
  6506. },
  6507. "LIGHTRAG_VDB_ENTITY": {
  6508. "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
  6509. id VARCHAR(255),
  6510. workspace VARCHAR(255),
  6511. entity_name VARCHAR(512),
  6512. content TEXT,
  6513. content_vector VECTOR(dimension),
  6514. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6515. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6516. chunk_ids VARCHAR(255)[] NULL,
  6517. file_path TEXT NULL,
  6518. CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
  6519. )"""
  6520. },
  6521. "LIGHTRAG_VDB_RELATION": {
  6522. "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
  6523. id VARCHAR(255),
  6524. workspace VARCHAR(255),
  6525. source_id VARCHAR(512),
  6526. target_id VARCHAR(512),
  6527. content TEXT,
  6528. content_vector VECTOR(dimension),
  6529. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6530. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6531. chunk_ids VARCHAR(255)[] NULL,
  6532. file_path TEXT NULL,
  6533. CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
  6534. )"""
  6535. },
  6536. "LIGHTRAG_LLM_CACHE": {
  6537. "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
  6538. workspace varchar(255) NOT NULL,
  6539. id varchar(255) NOT NULL,
  6540. original_prompt TEXT,
  6541. return_value TEXT,
  6542. chunk_id VARCHAR(255) NULL,
  6543. cache_type VARCHAR(32),
  6544. queryparam JSONB NULL,
  6545. create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  6546. update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  6547. CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
  6548. )"""
  6549. },
  6550. "LIGHTRAG_DOC_STATUS": {
  6551. "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
  6552. workspace varchar(255) NOT NULL,
  6553. id varchar(255) NOT NULL,
  6554. content_summary varchar(255) NULL,
  6555. content_length int4 NULL,
  6556. chunks_count int4 NULL,
  6557. status varchar(64) NULL,
  6558. file_path TEXT NULL,
  6559. chunks_list JSONB NULL DEFAULT '[]'::jsonb,
  6560. track_id varchar(255) NULL,
  6561. metadata JSONB NULL DEFAULT '{}'::jsonb,
  6562. error_msg TEXT NULL,
  6563. -- content_hash is TEXT (not VARCHAR(N)) so the column is
  6564. -- agnostic to the hash algorithm. Today's pipeline writes
  6565. -- 64-char SHA-256 hex; future algos (SHA-512, base64) do
  6566. -- not require a schema change.
  6567. content_hash TEXT NULL,
  6568. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  6569. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  6570. CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
  6571. )"""
  6572. },
  6573. "LIGHTRAG_FULL_ENTITIES": {
  6574. "ddl": """CREATE TABLE LIGHTRAG_FULL_ENTITIES (
  6575. id VARCHAR(255),
  6576. workspace VARCHAR(255),
  6577. entity_names JSONB,
  6578. count INTEGER,
  6579. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6580. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6581. CONSTRAINT LIGHTRAG_FULL_ENTITIES_PK PRIMARY KEY (workspace, id)
  6582. )"""
  6583. },
  6584. "LIGHTRAG_FULL_RELATIONS": {
  6585. "ddl": """CREATE TABLE LIGHTRAG_FULL_RELATIONS (
  6586. id VARCHAR(255),
  6587. workspace VARCHAR(255),
  6588. relation_pairs JSONB,
  6589. count INTEGER,
  6590. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6591. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6592. CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id)
  6593. )"""
  6594. },
  6595. "LIGHTRAG_ENTITY_CHUNKS": {
  6596. "ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS (
  6597. id VARCHAR(512),
  6598. workspace VARCHAR(255),
  6599. chunk_ids JSONB,
  6600. count INTEGER,
  6601. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6602. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6603. CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id)
  6604. )"""
  6605. },
  6606. "LIGHTRAG_RELATION_CHUNKS": {
  6607. "ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS (
  6608. id VARCHAR(512),
  6609. workspace VARCHAR(255),
  6610. chunk_ids JSONB,
  6611. count INTEGER,
  6612. create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6613. update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
  6614. CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id)
  6615. )"""
  6616. },
  6617. }
  6618. SQL_TEMPLATES = {
  6619. # SQL for KVStorage
  6620. "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content,
  6621. COALESCE(doc_name, '') as file_path,
  6622. sidecar_location,
  6623. parse_format,
  6624. content_hash,
  6625. process_options,
  6626. COALESCE(chunk_options, '{}'::jsonb) as chunk_options,
  6627. parse_engine
  6628. FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
  6629. """,
  6630. "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
  6631. chunk_order_index, full_doc_id, file_path,
  6632. COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
  6633. COALESCE(heading, '{}'::jsonb) as heading,
  6634. COALESCE(sidecar, '{}'::jsonb) as sidecar,
  6635. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6636. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6637. FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
  6638. """,
  6639. "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
  6640. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6641. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6642. FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
  6643. """,
  6644. "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
  6645. COALESCE(doc_name, '') as file_path,
  6646. sidecar_location,
  6647. parse_format,
  6648. content_hash,
  6649. process_options,
  6650. COALESCE(chunk_options, '{}'::jsonb) as chunk_options,
  6651. parse_engine
  6652. FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
  6653. """,
  6654. "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
  6655. chunk_order_index, full_doc_id, file_path,
  6656. COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
  6657. COALESCE(heading, '{}'::jsonb) as heading,
  6658. COALESCE(sidecar, '{}'::jsonb) as sidecar,
  6659. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6660. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6661. FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
  6662. """,
  6663. "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
  6664. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6665. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6666. FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
  6667. """,
  6668. "get_by_id_full_entities": """SELECT id, entity_names, count,
  6669. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6670. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6671. FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id=$2
  6672. """,
  6673. "get_by_id_full_relations": """SELECT id, relation_pairs, count,
  6674. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6675. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6676. FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id=$2
  6677. """,
  6678. "get_by_ids_full_entities": """SELECT id, entity_names, count,
  6679. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6680. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6681. FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
  6682. """,
  6683. "get_by_ids_full_relations": """SELECT id, relation_pairs, count,
  6684. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6685. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6686. FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
  6687. """,
  6688. "get_by_id_entity_chunks": """SELECT id, chunk_ids, count,
  6689. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6690. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6691. FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2
  6692. """,
  6693. "get_by_id_relation_chunks": """SELECT id, chunk_ids, count,
  6694. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6695. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6696. FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2
  6697. """,
  6698. "get_by_ids_entity_chunks": """SELECT id, chunk_ids, count,
  6699. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6700. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6701. FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2)
  6702. """,
  6703. "get_by_ids_relation_chunks": """SELECT id, chunk_ids, count,
  6704. EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
  6705. EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
  6706. FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2)
  6707. """,
  6708. "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
  6709. # Pipeline-derived columns (sidecar_location / parse_format / content_hash /
  6710. # process_options / chunk_options / parse_engine) are guarded with COALESCE
  6711. # so a partial upsert (e.g. a caller writing only ``content`` + ``doc_name``)
  6712. # does not silently overwrite metadata recorded by _persist_parsed_full_docs.
  6713. # ``content`` and ``doc_name`` themselves are always overwritten — they are
  6714. # the primary payload, never a candidate for preservation.
  6715. # For the string columns we use NULLIF('', ...) so that an empty string from
  6716. # a default-bearing caller is treated as "no value, preserve existing".
  6717. # For chunk_options (JSONB) we treat NULL or the empty-object literal as
  6718. # "no value, preserve existing".
  6719. "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace,
  6720. sidecar_location, parse_format, content_hash,
  6721. process_options, chunk_options, parse_engine)
  6722. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
  6723. ON CONFLICT (workspace,id) DO UPDATE
  6724. SET content = EXCLUDED.content,
  6725. doc_name = EXCLUDED.doc_name,
  6726. sidecar_location = COALESCE(
  6727. NULLIF(EXCLUDED.sidecar_location, ''),
  6728. LIGHTRAG_DOC_FULL.sidecar_location
  6729. ),
  6730. parse_format = COALESCE(
  6731. NULLIF(EXCLUDED.parse_format, ''),
  6732. LIGHTRAG_DOC_FULL.parse_format
  6733. ),
  6734. content_hash = COALESCE(
  6735. NULLIF(EXCLUDED.content_hash, ''),
  6736. LIGHTRAG_DOC_FULL.content_hash
  6737. ),
  6738. process_options = COALESCE(
  6739. NULLIF(EXCLUDED.process_options, ''),
  6740. LIGHTRAG_DOC_FULL.process_options
  6741. ),
  6742. chunk_options = CASE
  6743. WHEN EXCLUDED.chunk_options IS NULL
  6744. OR EXCLUDED.chunk_options = '{}'::jsonb
  6745. THEN LIGHTRAG_DOC_FULL.chunk_options
  6746. ELSE EXCLUDED.chunk_options
  6747. END,
  6748. parse_engine = COALESCE(
  6749. NULLIF(EXCLUDED.parse_engine, ''),
  6750. LIGHTRAG_DOC_FULL.parse_engine
  6751. ),
  6752. update_time = CURRENT_TIMESTAMP
  6753. """,
  6754. "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam)
  6755. VALUES ($1, $2, $3, $4, $5, $6, $7)
  6756. ON CONFLICT (workspace,id) DO UPDATE
  6757. SET original_prompt = EXCLUDED.original_prompt,
  6758. return_value=EXCLUDED.return_value,
  6759. chunk_id=EXCLUDED.chunk_id,
  6760. cache_type=EXCLUDED.cache_type,
  6761. queryparam=EXCLUDED.queryparam,
  6762. update_time = CURRENT_TIMESTAMP
  6763. """,
  6764. "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
  6765. chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
  6766. heading, sidecar, create_time, update_time)
  6767. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
  6768. ON CONFLICT (workspace,id) DO UPDATE
  6769. SET tokens=EXCLUDED.tokens,
  6770. chunk_order_index=EXCLUDED.chunk_order_index,
  6771. full_doc_id=EXCLUDED.full_doc_id,
  6772. content = EXCLUDED.content,
  6773. file_path=EXCLUDED.file_path,
  6774. llm_cache_list=EXCLUDED.llm_cache_list,
  6775. heading=EXCLUDED.heading,
  6776. sidecar=EXCLUDED.sidecar,
  6777. update_time = EXCLUDED.update_time
  6778. """,
  6779. "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
  6780. create_time, update_time)
  6781. VALUES ($1, $2, $3, $4, $5, $6)
  6782. ON CONFLICT (workspace,id) DO UPDATE
  6783. SET entity_names=EXCLUDED.entity_names,
  6784. count=EXCLUDED.count,
  6785. update_time = EXCLUDED.update_time
  6786. """,
  6787. "upsert_full_relations": """INSERT INTO LIGHTRAG_FULL_RELATIONS (workspace, id, relation_pairs, count,
  6788. create_time, update_time)
  6789. VALUES ($1, $2, $3, $4, $5, $6)
  6790. ON CONFLICT (workspace,id) DO UPDATE
  6791. SET relation_pairs=EXCLUDED.relation_pairs,
  6792. count=EXCLUDED.count,
  6793. update_time = EXCLUDED.update_time
  6794. """,
  6795. "upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count,
  6796. create_time, update_time)
  6797. VALUES ($1, $2, $3, $4, $5, $6)
  6798. ON CONFLICT (workspace,id) DO UPDATE
  6799. SET chunk_ids=EXCLUDED.chunk_ids,
  6800. count=EXCLUDED.count,
  6801. update_time = EXCLUDED.update_time
  6802. """,
  6803. "upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count,
  6804. create_time, update_time)
  6805. VALUES ($1, $2, $3, $4, $5, $6)
  6806. ON CONFLICT (workspace,id) DO UPDATE
  6807. SET chunk_ids=EXCLUDED.chunk_ids,
  6808. count=EXCLUDED.count,
  6809. update_time = EXCLUDED.update_time
  6810. """,
  6811. # SQL for VectorStorage
  6812. "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
  6813. chunk_order_index, full_doc_id, content, content_vector, file_path,
  6814. create_time, update_time)
  6815. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
  6816. ON CONFLICT (workspace,id) DO UPDATE
  6817. SET tokens=EXCLUDED.tokens,
  6818. chunk_order_index=EXCLUDED.chunk_order_index,
  6819. full_doc_id=EXCLUDED.full_doc_id,
  6820. content = EXCLUDED.content,
  6821. content_vector=EXCLUDED.content_vector,
  6822. file_path=EXCLUDED.file_path,
  6823. update_time = EXCLUDED.update_time
  6824. """,
  6825. "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
  6826. content_vector, chunk_ids, file_path, create_time, update_time)
  6827. VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
  6828. ON CONFLICT (workspace,id) DO UPDATE
  6829. SET entity_name=EXCLUDED.entity_name,
  6830. content=EXCLUDED.content,
  6831. content_vector=EXCLUDED.content_vector,
  6832. chunk_ids=EXCLUDED.chunk_ids,
  6833. file_path=EXCLUDED.file_path,
  6834. update_time=EXCLUDED.update_time
  6835. """,
  6836. "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
  6837. target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
  6838. VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
  6839. ON CONFLICT (workspace,id) DO UPDATE
  6840. SET source_id=EXCLUDED.source_id,
  6841. target_id=EXCLUDED.target_id,
  6842. content=EXCLUDED.content,
  6843. content_vector=EXCLUDED.content_vector,
  6844. chunk_ids=EXCLUDED.chunk_ids,
  6845. file_path=EXCLUDED.file_path,
  6846. update_time = EXCLUDED.update_time
  6847. """,
  6848. "relationships": """
  6849. SELECT source_id AS src_id,
  6850. target_id AS tgt_id,
  6851. EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
  6852. FROM {table_name}
  6853. WHERE workspace = $1
  6854. AND content_vector <=> $4::{vector_cast} < $2
  6855. ORDER BY content_vector <=> $4::{vector_cast}
  6856. LIMIT $3;
  6857. """,
  6858. "entities": """
  6859. SELECT entity_name,
  6860. EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
  6861. FROM {table_name}
  6862. WHERE workspace = $1
  6863. AND content_vector <=> $4::{vector_cast} < $2
  6864. ORDER BY content_vector <=> $4::{vector_cast}
  6865. LIMIT $3;
  6866. """,
  6867. "chunks": """
  6868. SELECT id,
  6869. content,
  6870. file_path,
  6871. EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
  6872. FROM {table_name}
  6873. WHERE workspace = $1
  6874. AND content_vector <=> $4::{vector_cast} < $2
  6875. ORDER BY content_vector <=> $4::{vector_cast}
  6876. LIMIT $3;
  6877. """,
  6878. # DROP tables
  6879. "drop_specifiy_table_workspace": """
  6880. DELETE FROM {table_name} WHERE workspace=$1
  6881. """,
  6882. }