| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766
77767786779678067816782678367846785678667876788678967906791679267936794679567966797679867996800680168026803680468056806680768086809681068116812681368146815681668176818681968206821682268236824682568266827682868296830683168326833683468356836683768386839684068416842684368446845684668476848684968506851685268536854685568566857685868596860686168626863686468656866686768686869687068716872687368746875687668776878687968806881688268836884688568866887688868896890689168926893689468956896689768986899690069016902690369046905690669076908690969106911691269136914691569166917691869196920692169226923692469256926692769286929693069316932693369346935693669376938693969406941694269436944694569466947694869496950695169526953695469556956695769586959696069616962696369646965696669676968696969706971697269736974697569766977697869796980698169826983698469856986698769886989699069916992699369946995699669976998699970007001700270037004700570067007700870097010701170127013701470157016701770187019702070217022702370247025702670277028702970307031703270337034703570367037703870397040704170427043704470457046704770487049705070517052705370547055705670577058705970607061706270637064706570667067706870697070707170727073707470757076707770787079708070817082708370847085708670877088708970907091709270937094709570967097709870997100710171027103710471057106710771087109711071117112711371147115711671177118711971207121712271237124712571267127712871297130713171327133713471357136713771387139714071417142714371447145714671477148714971507151715271537154715571567157715871597160716171627163716471657166716771687169717071717172717371747175717671777178717971807181718271837184718571867187718871897190719171927193719471957196719771987199720072017202720372047205720672077208720972107211721272137214721572167217721872197220722172227223722472257226722772287229723072317232723372347235723672377238723972407241724272437244724572467247724872497250725172527253725472557256725772587259726072617262726372647265726672677268726972707271727272737274727572767
277727872797280728172827283728472857286728772887289729072917292729372947295729672977298729973007301730273037304730573067307730873097310731173127313731473157316731773187319732073217322732373247325732673277328732973307331733273337334733573367337733873397340734173427343734473457346734773487349735073517352735373547355735673577358735973607361736273637364736573667367736873697370737173727373737473757376737773787379738073817382738373847385738673877388738973907391739273937394739573967397739873997400740174027403740474057406740774087409741074117412741374147415741674177418741974207421742274237424742574267427742874297430743174327433743474357436743774387439744074417442744374447445744674477448744974507451745274537454745574567457745874597460746174627463746474657466746774687469747074717472747374747475747674777478747974807481748274837484748574867487748874897490749174927493749474957496749774987499750075017502750375047505750675077508750975107511751275137514751575167517751875197520752175227523752475257526752775287529753075317532753375347535753675377538753975407541754275437544754575467547754875497550755175527553755475557556755775587559756075617562756375647565756675677568756975707571757275737574757575767577757875797580758175827583758475857586758775887589759075917592759375947595759675977598759976007601760276037604760576067607760876097610761176127613761476157616761776187619762076217622762376247625762676277628762976307631763276337634763576367637763876397640764176427643764476457646764776487649765076517652765376547655765676577658765976607661766276637664766576667667766876697670767176727673767476757676767776787679768076817682768376847685768676877688768976907691 |
- import asyncio
- import time
- import hashlib
- import json
- import os
- import re
- import datetime
- from datetime import timezone
- from dataclasses import dataclass, field
- from typing import Any, Awaitable, Callable, TypeVar, Union, final
- import numpy as np
- import configparser
- import ssl
- import itertools
- from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
- from tenacity import (
- AsyncRetrying,
- RetryCallState,
- retry,
- retry_if_exception,
- retry_if_exception_type,
- stop_after_attempt,
- wait_exponential,
- wait_fixed,
- )
- from ..base import (
- BaseGraphStorage,
- BaseKVStorage,
- BaseVectorStorage,
- DocProcessingStatus,
- DocStatus,
- DocStatusStorage,
- )
- from ..exceptions import DataMigrationError
- from ..namespace import NameSpace, is_namespace
- from ..utils import (
- logger,
- compute_mdhash_id,
- _cooperative_yield,
- performance_timing_log,
- )
- from ..kg.shared_storage import get_data_init_lock, get_namespace_lock
- import pipmaster as pm
- if not pm.is_installed("asyncpg"):
- pm.install("asyncpg")
- if not pm.is_installed("pgvector"):
- pm.install("pgvector")
- import asyncpg # type: ignore
- from asyncpg import Pool # type: ignore
- from pgvector.asyncpg import register_vector # type: ignore
- from dotenv import load_dotenv
- # use the .env that is inside the current folder
- # allows to use different .env file for each lightrag instance
- # the OS environment variables take precedence over the .env file
- load_dotenv(dotenv_path=".env", override=False)
- T = TypeVar("T")
- # PostgreSQL identifier length limit (in bytes)
- PG_MAX_IDENTIFIER_LENGTH = 63
- # All known vector index suffixes, used to drop conflicting indexes when switching types
- _VECTOR_INDEX_SUFFIXES = [
- "hnsw_cosine",
- "hnsw_halfvec_cosine",
- "ivfflat_cosine",
- "vchordrq_cosine",
- ]
- def _safe_index_name(table_name: str, index_suffix: str) -> str:
- """
- Generate a PostgreSQL-safe index name that won't be truncated.
- PostgreSQL silently truncates identifiers to 63 bytes. This function
- ensures index names stay within that limit by hashing long table names.
- Args:
- table_name: The table name (may be long with model suffix)
- index_suffix: The index type suffix (e.g., 'hnsw_cosine', 'id', 'workspace_id')
- Returns:
- A deterministic index name that fits within 63 bytes
- """
- # Construct the full index name
- full_name = f"idx_{table_name.lower()}_{index_suffix}"
- # If it fits within the limit, use it as-is
- if len(full_name.encode("utf-8")) <= PG_MAX_IDENTIFIER_LENGTH:
- return full_name
- # Otherwise, hash the table name to create a shorter unique identifier
- # Keep 'idx_' prefix and suffix readable, hash the middle
- hash_input = table_name.lower().encode("utf-8")
- table_hash = hashlib.md5(hash_input).hexdigest()[:12] # 12 hex chars
- # Format: idx_{hash}_{suffix} - guaranteed to fit
- # Maximum: idx_ (4) + hash (12) + _ (1) + suffix (variable) = 17 + suffix
- shortened_name = f"idx_{table_hash}_{index_suffix}"
- return shortened_name
- def _timing_details_suffix(**details: Any) -> str:
- parts = [f"{key}={value}" for key, value in details.items()]
- return f" {' '.join(parts)}" if parts else ""
- def _dollar_quote(s: str, tag_prefix: str = "AGE") -> str:
- """
- Generate a PostgreSQL dollar-quoted string with a unique tag.
- PostgreSQL dollar-quoting uses $tag$ as delimiters. If the content contains
- the same delimiter (e.g., $$ or $AGE1$), it will break the query.
- This function finds a unique tag that doesn't conflict with the content.
- Args:
- s: The string to quote
- tag_prefix: Prefix for generating unique tags (default: "AGE")
- Returns:
- The dollar-quoted string with a unique tag, e.g., $AGE1$content$AGE1$
- Example:
- >>> _dollar_quote("hello")
- '$AGE1$hello$AGE1$'
- >>> _dollar_quote("$AGE1$ test")
- '$AGE2$$AGE1$ test$AGE2$'
- >>> _dollar_quote("$$$") # Content with dollar signs
- '$AGE1$$$$AGE1$'
- """
- s = "" if s is None else str(s)
- for i in itertools.count(1):
- tag = f"{tag_prefix}{i}"
- wrapper = f"${tag}$"
- if wrapper not in s:
- return f"{wrapper}{s}{wrapper}"
- class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
    """Store connection settings taken from ``config``; no connection is opened.

    Args:
        config: Settings mapping. ``host``, ``port``, ``user``, ``password``,
            ``database``, ``workspace``, ``max_connections``,
            ``connection_retry_attempts``, ``connection_retry_backoff``,
            ``connection_retry_backoff_max`` and ``pool_close_timeout`` are
            read with ``config[...]`` and must be present. SSL, vector-index,
            ``server_settings`` and ``statement_cache_size`` entries are
            optional and default to None (``enable_vector`` defaults to True).
        **kwargs: Accepted but not used by this constructor.

    Raises:
        KeyError: If a required key is missing from ``config``.
        ValueError: If user, password or database is None.
    """
    # Required connection coordinates
    self.host = config["host"]
    self.port = config["port"]
    self.user = config["user"]
    self.password = config["password"]
    self.database = config["database"]
    self.workspace = config["workspace"]
    self.max = int(config["max_connections"])
    self.increment = 1
    self.pool: Pool | None = None

    # SSL configuration
    self.ssl_mode = config.get("ssl_mode")
    self.ssl_cert = config.get("ssl_cert")
    self.ssl_key = config.get("ssl_key")
    self.ssl_root_cert = config.get("ssl_root_cert")
    self.ssl_crl = config.get("ssl_crl")

    # Vector configuration.
    # enable_vector defaults to True for backward compatibility and can be set
    # to False to disable vector features; non-bool values are parsed as text.
    raw_enable_vector = config.get("enable_vector", True)
    if isinstance(raw_enable_vector, bool):
        self.enable_vector = raw_enable_vector
    else:
        self.enable_vector = str(raw_enable_vector).lower() in (
            "true",
            "1",
            "yes",
            "on",
        )
    self.vector_index_type = config.get("vector_index_type")
    self.hnsw_m = config.get("hnsw_m")
    self.hnsw_ef = config.get("hnsw_ef")
    self.ivfflat_lists = config.get("ivfflat_lists")
    self.vchordrq_build_options = config.get("vchordrq_build_options")
    self.vchordrq_probes = config.get("vchordrq_probes")
    self.vchordrq_epsilon = config.get("vchordrq_epsilon")

    # Server settings
    self.server_settings = config.get("server_settings")

    # Statement LRU cache size (kept as-is; None means "not configured")
    self.statement_cache_size = config.get("statement_cache_size")

    credentials = (self.user, self.password, self.database)
    if any(item is None for item in credentials):
        raise ValueError("Missing database user, password, or database")

    # Guard concurrent pool resets
    self._pool_reconnect_lock = asyncio.Lock()

    # Exception types treated as transient, i.e. worth a retry
    self._transient_exceptions = (
        asyncio.TimeoutError,
        TimeoutError,
        ConnectionError,
        OSError,
        asyncpg.exceptions.InterfaceError,
        asyncpg.exceptions.TooManyConnectionsError,
        asyncpg.exceptions.CannotConnectNowError,
        asyncpg.exceptions.PostgresConnectionError,
        asyncpg.exceptions.ConnectionDoesNotExistError,
        asyncpg.exceptions.ConnectionFailureError,
    )

    # Connection retry configuration; the backoff ceiling is never allowed to
    # be lower than the base backoff
    base_backoff = config["connection_retry_backoff"]
    self.connection_retry_attempts = config["connection_retry_attempts"]
    self.connection_retry_backoff = base_backoff
    self.connection_retry_backoff_max = max(
        base_backoff,
        config["connection_retry_backoff_max"],
    )
    self.pool_close_timeout = config["pool_close_timeout"]

    logger.info(
        "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
        self.connection_retry_attempts,
        self.connection_retry_backoff,
        self.connection_retry_backoff_max,
        self.pool_close_timeout,
    )
- def _create_ssl_context(self) -> ssl.SSLContext | None:
- """Create SSL context based on configuration parameters."""
- if not self.ssl_mode:
- return None
- ssl_mode = self.ssl_mode.lower()
- # For simple modes that don't require custom context
- if ssl_mode in ["disable", "allow", "prefer", "require"]:
- if ssl_mode == "disable":
- return None
- elif ssl_mode in ["require", "prefer", "allow"]:
- # Return None for simple SSL requirement, handled in initdb
- return None
- # For modes that require certificate verification
- if ssl_mode in ["verify-ca", "verify-full"]:
- try:
- context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
- # Configure certificate verification
- if ssl_mode == "verify-ca":
- context.check_hostname = False
- elif ssl_mode == "verify-full":
- context.check_hostname = True
- # Load root certificate if provided
- if self.ssl_root_cert:
- if os.path.exists(self.ssl_root_cert):
- context.load_verify_locations(cafile=self.ssl_root_cert)
- logger.info(
- f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
- )
- else:
- logger.warning(
- f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
- )
- # Load client certificate and key if provided
- if self.ssl_cert and self.ssl_key:
- if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
- context.load_cert_chain(self.ssl_cert, self.ssl_key)
- logger.info(
- f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
- )
- else:
- logger.warning(
- "PostgreSQL, SSL client certificate or key file not found"
- )
- # Load certificate revocation list if provided
- if self.ssl_crl:
- if os.path.exists(self.ssl_crl):
- context.load_verify_locations(crlfile=self.ssl_crl)
- logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
- else:
- logger.warning(
- f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
- )
- return context
- except Exception as e:
- logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
- raise ValueError(f"SSL configuration error: {e}")
- # Unknown SSL mode
- logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
- return None
async def initdb(self):
    """Create the asyncpg connection pool and store it in ``self.pool``.

    Builds the connection parameters from the instance configuration
    (credentials, pool size, optional statement cache size, SSL and server
    settings), then creates the pool. Pool creation is retried for the
    exception types in ``self._transient_exceptions`` up to
    ``self.connection_retry_attempts`` times, calling ``self._before_sleep``
    between attempts.

    Raises:
        Exception: The error of the last attempt (or the first non-transient
            error) is logged and re-raised.
    """
    # Prepare connection parameters
    connection_params = {
        "user": self.user,
        "password": self.password,
        "database": self.database,
        "host": self.host,
        "port": self.port,
        "min_size": 1,
        "max_size": self.max,
    }
    # Only add statement_cache_size if it's configured
    if self.statement_cache_size is not None:
        connection_params["statement_cache_size"] = int(self.statement_cache_size)
        logger.info(
            f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
        )
    # Add SSL configuration if provided
    ssl_context = self._create_ssl_context()
    if ssl_context is not None:
        connection_params["ssl"] = ssl_context
        logger.info("PostgreSQL, SSL configuration applied")
    elif self.ssl_mode:
        # Handle simple SSL modes without custom context.
        # Any other mode (e.g. "allow" or an unrecognised value) leaves the
        # "ssl" parameter unset.
        if self.ssl_mode.lower() in ["require", "prefer"]:
            connection_params["ssl"] = True
        elif self.ssl_mode.lower() == "disable":
            connection_params["ssl"] = False
        logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
    # Add server settings if provided
    if self.server_settings:
        try:
            settings = {}
            # The format is expected to be a query string, e.g., "key1=value1&key2=value2"
            # Pairs without "=" are skipped; keys and values are used verbatim.
            pairs = self.server_settings.split("&")
            for pair in pairs:
                if "=" in pair:
                    key, value = pair.split("=", 1)
                    settings[key] = value
            if settings:
                connection_params["server_settings"] = settings
                logger.info(f"PostgreSQL, Server settings applied: {settings}")
        except Exception as e:
            logger.warning(
                f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
            )
    # Exponential backoff between attempts, or no wait at all when the
    # configured backoff is not positive
    wait_strategy = (
        wait_exponential(
            multiplier=self.connection_retry_backoff,
            min=self.connection_retry_backoff,
            max=self.connection_retry_backoff_max,
        )
        if self.connection_retry_backoff > 0
        else wait_fixed(0)
    )

    async def _init_connection(connection: asyncpg.Connection) -> None:
        """Initialize each new connection with pgvector codec and VCHORDRQ session params.

        Called once per physical connection creation (not on pool reuse).
        register_vector is a Python-level codec registration that survives
        asyncpg's RESET ALL; VCHORDRQ GUCs do not — they are re-applied in
        _reset_connection after each pool release.
        """
        if self.enable_vector:
            await register_vector(connection)
        if self.enable_vector and self.vector_index_type == "VCHORDRQ":
            await self.configure_vchordrq(connection)

    async def _reset_connection(connection: asyncpg.Connection) -> None:
        """Run the default asyncpg cleanup, then re-apply VCHORDRQ session GUCs.

        When a custom reset= callback is registered with create_pool(), asyncpg
        calls Connection._reset() (private — clears listeners and rolls back open
        transactions if any) and then this function. It does NOT call the public
        Connection.reset(), which is the method that calls _reset() and then
        executes the cleanup query returned by get_reset_query() — the exact SQL
        depends on detected server capabilities and typically includes
        pg_advisory_unlock_all(), CLOSE ALL, UNLISTEN *, and RESET ALL.

        We must therefore run that cleanup ourselves via get_reset_query() before
        restoring VCHORDRQ GUCs. Skipping this step leaks session state across
        pool checkouts — for example configure_age() sets search_path and that
        modified path would persist into the next non-AGE connection checkout.

        register_vector is NOT repeated here: it is a Python-side encoder/decoder
        registration on the asyncpg Connection object and is unaffected by RESET ALL.
        Note that set_type_codec() clears the statement cache, which is naturally
        repopulated on subsequent queries.
        """
        try:
            # Run the default cleanup that asyncpg would otherwise handle.
            reset_query = connection.get_reset_query()
            if reset_query:
                await connection.execute(reset_query)
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Pool reset cleanup query failed — connection "
                f"will be terminated and removed from pool: {e}"
            )
            raise
        # RESET ALL clears session GUCs; restore VCHORDRQ values afterward.
        if self.enable_vector and self.vector_index_type == "VCHORDRQ":
            try:
                await self.configure_vchordrq(connection)
            except asyncpg.exceptions.UndefinedObjectError:
                logger.error(
                    f"[{self.workspace}] VCHORDRQ extension is not installed. "
                    "Install the extension or set vector_index_type to a supported value. "
                    "Connection will be terminated and removed from pool."
                )
                raise
            except asyncpg.exceptions.InvalidParameterValueError as e:
                logger.error(
                    f"[{self.workspace}] Invalid VCHORDRQ GUC parameter — "
                    f"check vchordrq_probes and vchordrq_epsilon config. "
                    f"Connection will be terminated: {e}"
                )
                raise
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] VCHORDRQ session configuration failed "
                    f"after pool reset — connection will be terminated: {e}"
                )
                raise

    async def _create_pool_once() -> None:
        """Run one pool-creation attempt; ``self.pool`` is assigned only on success."""
        # STEP 1: Bootstrap - ensure vector extension exists BEFORE pool creation.
        # On a fresh database, register_vector() in _init_connection will fail
        # if the vector extension doesn't exist yet, because the 'vector' type
        # won't be found in pg_catalog. We must create the extension first
        # using a standalone bootstrap connection.
        # Skip this step if vector support is not enabled.
        if self.enable_vector:
            # Only the credentials and the ssl setting are passed here; the
            # server settings and statement cache size are not.
            bootstrap_conn = await asyncpg.connect(
                user=self.user,
                password=self.password,
                database=self.database,
                host=self.host,
                port=self.port,
                ssl=connection_params.get("ssl"),
            )
            try:
                await self.configure_vector_extension(bootstrap_conn)
            finally:
                await bootstrap_conn.close()
        # STEP 2: Now safe to create pool with register_vector callback.
        # The vector extension is guaranteed to exist at this point (if enabled).
        pool = await asyncpg.create_pool(
            **connection_params,
            init=_init_connection,  # register pgvector codec on new connections
            reset=_reset_connection,  # re-apply VCHORDRQ GUCs after RESET ALL
        )  # type: ignore
        self.pool = pool

    try:
        # reraise=True: once attempts are exhausted the original exception is
        # raised instead of tenacity's RetryError
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(self.connection_retry_attempts),
            retry=retry_if_exception_type(self._transient_exceptions),
            wait=wait_strategy,
            before_sleep=self._before_sleep,
            reraise=True,
        ):
            with attempt:
                await _create_pool_once()
        ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
        logger.info(
            f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
        )
    except Exception as e:
        logger.error(
            f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
        )
        raise
- async def _ensure_pool(self) -> None:
- """Ensure the connection pool is initialised."""
- if self.pool is None:
- async with self._pool_reconnect_lock:
- if self.pool is None:
- await self.initdb()
- async def _reset_pool(self) -> None:
- async with self._pool_reconnect_lock:
- if self.pool is not None:
- try:
- await asyncio.wait_for(
- self.pool.close(), timeout=self.pool_close_timeout
- )
- except asyncio.TimeoutError:
- logger.error(
- "PostgreSQL, Timed out closing connection pool after %.2fs",
- self.pool_close_timeout,
- )
- except Exception as close_error: # pragma: no cover - defensive logging
- logger.warning(
- f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
- )
- self.pool = None
- async def _before_sleep(self, retry_state: RetryCallState) -> None:
- """Hook invoked by tenacity before sleeping between retries."""
- exc = retry_state.outcome.exception() if retry_state.outcome else None
- logger.warning(
- "PostgreSQL transient connection issue on attempt %s/%s: %r",
- retry_state.attempt_number,
- self.connection_retry_attempts,
- exc,
- )
- await self._reset_pool()
- async def _run_with_retry(
- self,
- operation: Callable[[asyncpg.Connection], Awaitable[T]],
- *,
- with_age: bool = False,
- graph_name: str | None = None,
- timing_label: str | None = None,
- ) -> T:
- """
- Execute a database operation with automatic retry for transient failures.
- Args:
- operation: Async callable that receives an active connection.
- with_age: Whether to configure Apache AGE on the connection.
- graph_name: AGE graph name; required when with_age is True.
- Returns:
- The result returned by the operation.
- Raises:
- Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
- """
- wait_strategy = (
- wait_exponential(
- multiplier=self.connection_retry_backoff,
- min=self.connection_retry_backoff,
- max=self.connection_retry_backoff_max,
- )
- if self.connection_retry_backoff > 0
- else wait_fixed(0)
- )
- async for attempt in AsyncRetrying(
- stop=stop_after_attempt(self.connection_retry_attempts),
- retry=retry_if_exception_type(self._transient_exceptions),
- wait=wait_strategy,
- before_sleep=self._before_sleep,
- reraise=True,
- ):
- with attempt:
- await self._ensure_pool()
- assert self.pool is not None
- if timing_label:
- pool_snapshot_before = self._get_pool_snapshot()
- performance_timing_log(
- "[%s] pool.acquire waiting %s",
- timing_label,
- pool_snapshot_before,
- )
- acquire_start = time.perf_counter()
- async with self.pool.acquire() as connection: # type: ignore[arg-type]
- acquire_elapsed = time.perf_counter() - acquire_start
- if timing_label:
- pool_snapshot_after = self._get_pool_snapshot()
- performance_timing_log(
- "[%s] pool.acquire completed in %.4fs %s",
- timing_label,
- acquire_elapsed,
- pool_snapshot_after,
- )
- if with_age and graph_name:
- await self.configure_age(connection, graph_name)
- elif with_age and not graph_name:
- raise ValueError("Graph name is required when with_age is True")
- return await operation(connection)
- def _get_pool_snapshot(self) -> str:
- """Best-effort snapshot of asyncpg pool state for diagnostics.
- Uses asyncpg private attributes defensively; if a field is unavailable in the
- installed asyncpg version, return '?' for that metric instead of failing.
- """
- pool = self.pool
- if pool is None:
- return "pool_state=uninitialized"
- holders = getattr(pool, "_holders", None)
- queue = getattr(pool, "_queue", None)
- max_size = getattr(pool, "_maxsize", None)
- min_size = getattr(pool, "_minsize", None)
- total_holders = len(holders) if holders is not None else "?"
- idle_count: int | str = "?"
- acquired_count: int | str = "?"
- if holders is not None:
- idle_count = 0
- acquired_count = 0
- for holder in holders:
- # asyncpg holder uses _in_use Future/Event-like marker; treat present value as acquired
- in_use_marker = getattr(holder, "_in_use", None)
- if in_use_marker:
- acquired_count += 1
- else:
- idle_count += 1
- waiting_count: int | str = "?"
- if queue is not None:
- getters = getattr(queue, "_getters", None)
- if getters is not None:
- waiting_count = len(getters)
- return (
- f"pool_state[min={min_size}, max={max_size}, holders={total_holders}, "
- f"acquired={acquired_count}, idle={idle_count}, waiting={waiting_count}]"
- )
async def configure_vector_extension(self, connection: asyncpg.Connection) -> None:
    """Create the VECTOR extension if it doesn't exist yet.

    A failure to create the extension is logged as a warning and swallowed so
    that start-up can continue without it. When vector_index_type is
    HNSW_HALFVEC, the installed pgvector version is additionally validated.

    Raises:
        RuntimeError: With HNSW_HALFVEC, if the extension version cannot be
            read or parsed, or is older than 0.7.0.
    """
    try:
        await connection.execute("CREATE EXTENSION IF NOT EXISTS vector")  # type: ignore
        logger.info("PostgreSQL, VECTOR extension enabled")
    except Exception as e:
        logger.warning(f"Could not create VECTOR extension: {e}")
        # Don't raise - let the system continue without vector extension
        return

    # The version requirement only applies to the halfvec index type
    if getattr(self, "vector_index_type", None) != "HNSW_HALFVEC":
        return

    row = await connection.fetchrow(
        "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
    )
    if not row or not row["extversion"]:
        raise RuntimeError(
            "POSTGRES_VECTOR_INDEX_TYPE=HNSW_HALFVEC requires the pgvector "
            "extension. Ensure it is installed and CREATE EXTENSION vector succeeded."
        )

    raw_version = row["extversion"]
    try:
        # Keep at most three components and pad missing ones with zeros
        numbers = [int(piece) for piece in str(raw_version).split(".")[:3]]
        numbers.extend([0] * (3 - len(numbers)))
        version_tuple = (numbers[0], numbers[1], numbers[2])
    except (ValueError, IndexError):
        raise RuntimeError(
            f"Could not parse pgvector version {raw_version!r}. "
            "HNSW_HALFVEC requires pgvector >= 0.7.0."
        ) from None

    if version_tuple < (0, 7, 0):
        raise RuntimeError(
            f"POSTGRES_VECTOR_INDEX_TYPE=HNSW_HALFVEC requires pgvector >= 0.7.0, "
            f"but installed version is {raw_version}. Upgrade the pgvector extension "
            "or use a different index type (e.g. HNSW with embeddings <= 2000 dimensions)."
        )
- @staticmethod
- async def configure_age_extension(connection: asyncpg.Connection) -> None:
- """Create AGE extension if it doesn't exist for graph operations."""
- try:
- await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE") # type: ignore
- logger.info("PostgreSQL, AGE extension enabled")
- except Exception as e:
- logger.warning(f"Could not create AGE extension: {e}")
- # Don't raise - let the system continue without AGE extension
- @staticmethod
- async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
- """Set the Apache AGE environment and creates a graph if it does not exist.
- This method:
- - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
- - Attempts to create a new graph with the provided `graph_name` if it does not already exist.
- - Silently ignores errors related to the graph already existing.
- """
- try:
- await connection.execute( # type: ignore
- 'SET search_path = ag_catalog, "$user", public'
- )
- await connection.execute( # type: ignore
- f"select create_graph('{graph_name}')"
- )
- except (
- asyncpg.exceptions.InvalidSchemaNameError,
- asyncpg.exceptions.UniqueViolationError,
- ):
- pass
async def configure_vchordrq(self, connection: asyncpg.Connection) -> None:
    """Apply the configured VCHORDRQ session parameters to a connection.

    ``vchordrq.probes`` is set only when a non-blank value is configured;
    ``vchordrq.epsilon`` is set whenever the value is not None.

    Raises:
        asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed
        asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid

    Note:
        This method does not catch exceptions. Configuration errors will fail-fast,
        while transient connection errors will be retried by _run_with_retry.
    """
    # probes: skip None, empty and whitespace-only values
    probes = self.vchordrq_probes
    if probes and str(probes).strip():
        await connection.execute(f"SET vchordrq.probes TO '{probes}'")
        logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}")

    # epsilon: handled independently; compared with None so 0.0 stays valid
    epsilon = self.vchordrq_epsilon
    if epsilon is None:
        return
    await connection.execute(f"SET vchordrq.epsilon TO {epsilon}")
    logger.debug(f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}")
- async def _migrate_llm_cache_schema(self):
- """Migrate LLM cache schema: add new columns and remove deprecated mode field"""
- try:
- # Check if all columns exist
- check_columns_sql = """
- SELECT column_name
- FROM information_schema.columns
- WHERE table_name = 'lightrag_llm_cache'
- AND column_name IN ('chunk_id', 'cache_type', 'queryparam', 'mode')
- """
- existing_columns = await self.query(check_columns_sql, multirows=True)
- existing_column_names = (
- {col["column_name"] for col in existing_columns}
- if existing_columns
- else set()
- )
- # Add missing chunk_id column
- if "chunk_id" not in existing_column_names:
- logger.info("Adding chunk_id column to LIGHTRAG_LLM_CACHE table")
- add_chunk_id_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- ADD COLUMN chunk_id VARCHAR(255) NULL
- """
- await self.execute(add_chunk_id_sql)
- logger.info(
- "Successfully added chunk_id column to LIGHTRAG_LLM_CACHE table"
- )
- else:
- logger.info(
- "chunk_id column already exists in LIGHTRAG_LLM_CACHE table"
- )
- # Add missing cache_type column
- if "cache_type" not in existing_column_names:
- logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
- add_cache_type_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- ADD COLUMN cache_type VARCHAR(32) NULL
- """
- await self.execute(add_cache_type_sql)
- logger.info(
- "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
- )
- # Migrate existing data using optimized regex pattern
- logger.info(
- "Migrating existing LLM cache data to populate cache_type field (optimized)"
- )
- optimized_update_sql = """
- UPDATE LIGHTRAG_LLM_CACHE
- SET cache_type = CASE
- WHEN id ~ '^[^:]+:[^:]+:' THEN split_part(id, ':', 2)
- ELSE 'extract'
- END
- WHERE cache_type IS NULL
- """
- await self.execute(optimized_update_sql)
- logger.info("Successfully migrated existing LLM cache data")
- else:
- logger.info(
- "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
- )
- # Add missing queryparam column
- if "queryparam" not in existing_column_names:
- logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
- add_queryparam_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- ADD COLUMN queryparam JSONB NULL
- """
- await self.execute(add_queryparam_sql)
- logger.info(
- "Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
- )
- else:
- logger.info(
- "queryparam column already exists in LIGHTRAG_LLM_CACHE table"
- )
- # Remove deprecated mode field if it exists
- if "mode" in existing_column_names:
- logger.info(
- "Removing deprecated mode column from LIGHTRAG_LLM_CACHE table"
- )
- # First, drop the primary key constraint that includes mode
- drop_pk_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- DROP CONSTRAINT IF EXISTS LIGHTRAG_LLM_CACHE_PK
- """
- await self.execute(drop_pk_sql)
- logger.info("Dropped old primary key constraint")
- # Drop the mode column
- drop_mode_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- DROP COLUMN mode
- """
- await self.execute(drop_mode_sql)
- logger.info(
- "Successfully removed mode column from LIGHTRAG_LLM_CACHE table"
- )
- # Create new primary key constraint without mode
- add_pk_sql = """
- ALTER TABLE LIGHTRAG_LLM_CACHE
- ADD CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
- """
- await self.execute(add_pk_sql)
- logger.info("Created new primary key constraint (workspace, id)")
- else:
- logger.info("mode column does not exist in LIGHTRAG_LLM_CACHE table")
- except Exception as e:
- logger.warning(f"Failed to migrate LLM cache schema: {e}")
- async def _migrate_timestamp_columns(self):
- """Migrate timestamp columns in tables to witimezone-free types, assuming original data is in UTC time"""
- # Tables and columns that need migration
- tables_to_migrate = {
- "LIGHTRAG_VDB_ENTITY": ["create_time", "update_time"],
- "LIGHTRAG_VDB_RELATION": ["create_time", "update_time"],
- "LIGHTRAG_DOC_CHUNKS": ["create_time", "update_time"],
- "LIGHTRAG_DOC_STATUS": ["created_at", "updated_at"],
- }
- try:
- # Filter out tables that don't exist (e.g., legacy vector tables may not exist)
- existing_tables = {}
- for table_name, columns in tables_to_migrate.items():
- if await self.check_table_exists(table_name):
- existing_tables[table_name] = columns
- else:
- logger.debug(
- f"Table {table_name} does not exist, skipping timestamp migration"
- )
- # Skip if no tables to migrate
- if not existing_tables:
- logger.debug("No tables found for timestamp migration")
- return
- # Use filtered tables for migration
- tables_to_migrate = existing_tables
- # Optimization: Batch check all columns in one query instead of 8 separate queries
- table_names_lower = [t.lower() for t in tables_to_migrate.keys()]
- all_column_names = list(
- set(col for cols in tables_to_migrate.values() for col in cols)
- )
- check_all_columns_sql = """
- SELECT table_name, column_name, data_type
- FROM information_schema.columns
- WHERE table_name = ANY($1)
- AND column_name = ANY($2)
- """
- all_columns_result = await self.query(
- check_all_columns_sql,
- [table_names_lower, all_column_names],
- multirows=True,
- )
- # Build lookup dict: (table_name, column_name) -> data_type
- column_types = {}
- if all_columns_result:
- column_types = {
- (row["table_name"].upper(), row["column_name"]): row["data_type"]
- for row in all_columns_result
- }
- # Now iterate and migrate only what's needed
- for table_name, columns in tables_to_migrate.items():
- for column_name in columns:
- try:
- data_type = column_types.get((table_name, column_name))
- if not data_type:
- logger.warning(
- f"Column {table_name}.{column_name} does not exist, skipping migration"
- )
- continue
- # Check column type
- if data_type == "timestamp without time zone":
- logger.debug(
- f"Column {table_name}.{column_name} is already witimezone-free, no migration needed"
- )
- continue
- # Execute migration, explicitly specifying UTC timezone for interpreting original data
- logger.info(
- f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type"
- )
- migration_sql = f"""
- ALTER TABLE {table_name}
- ALTER COLUMN {column_name} TYPE TIMESTAMP(0),
- ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP
- """
- await self.execute(migration_sql)
- logger.info(
- f"Successfully migrated {table_name}.{column_name} to timezone-free type"
- )
- except Exception as e:
- # Log error but don't interrupt the process
- logger.warning(
- f"Failed to migrate {table_name}.{column_name}: {e}"
- )
- except Exception as e:
- logger.error(f"Failed to batch check timestamp columns: {e}")
- async def _migrate_doc_chunks_to_vdb_chunks(self):
- """
- Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
- This migration is intended for users who are upgrading and have an older table structure
- where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
- """
- try:
- # 0. Check if both tables exist before proceeding
- vdb_chunks_exists = await self.check_table_exists("LIGHTRAG_VDB_CHUNKS")
- doc_chunks_exists = await self.check_table_exists("LIGHTRAG_DOC_CHUNKS")
- if not vdb_chunks_exists:
- logger.debug(
- "Skipping migration: LIGHTRAG_VDB_CHUNKS table does not exist"
- )
- return
- if not doc_chunks_exists:
- logger.debug(
- "Skipping migration: LIGHTRAG_DOC_CHUNKS table does not exist"
- )
- return
- # 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
- vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
- vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
- if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
- logger.info(
- "Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
- )
- return
- # 2. Check if `content_vector` column exists in the old table
- check_column_sql = """
- SELECT 1 FROM information_schema.columns
- WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
- """
- column_exists = await self.query(check_column_sql)
- if not column_exists:
- logger.info(
- "Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
- )
- return
- # 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
- doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
- doc_chunks_count_result = await self.query(doc_chunks_count_sql)
- if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
- logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
- return
- # 4. Perform the migration
- logger.info(
- "Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
- )
- migration_sql = """
- INSERT INTO LIGHTRAG_VDB_CHUNKS (
- id, workspace, full_doc_id, chunk_order_index, tokens, content,
- content_vector, file_path, create_time, update_time
- )
- SELECT
- id, workspace, full_doc_id, chunk_order_index, tokens, content,
- content_vector, file_path, create_time, update_time
- FROM LIGHTRAG_DOC_CHUNKS
- ON CONFLICT (workspace, id) DO NOTHING;
- """
- await self.execute(migration_sql)
- logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
- except Exception as e:
- logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
- # Do not re-raise, to allow the application to start
- async def _check_llm_cache_needs_migration(self):
- """Check if LLM cache data needs migration by examining any record with old format"""
- try:
- # Optimized query: directly check for old format records without sorting
- check_sql = """
- SELECT 1 FROM LIGHTRAG_LLM_CACHE
- WHERE id NOT LIKE '%:%'
- LIMIT 1
- """
- result = await self.query(check_sql)
- # If any old format record exists, migration is needed
- return result is not None
- except Exception as e:
- logger.warning(f"Failed to check LLM cache migration status: {e}")
- return False
async def _migrate_llm_cache_to_flattened_keys(self):
    """Rewrite legacy LLM cache ids into the flattened ``mode:cache_type:md5`` form.

    Steps, in order:
      1. Count legacy rows (ids without ``:``); stop when there are none.
      2. Count new-format rows whose id equals an id the migration would
         produce, and log a warning if any exist (the migration continues).
      3. Run a single UPDATE over every legacy row, setting ``id``,
         ``cache_type`` and ``update_time``.
      4. Re-count legacy rows and log success or the number left over.

    Raises:
        Exception: whatever the underlying query/execute call raised; the
            failure is logged at error level before being re-raised.
    """
    try:
        legacy_count_row = await self.query(
            """
            SELECT COUNT(*) as count FROM LIGHTRAG_LLM_CACHE
            WHERE id NOT LIKE '%:%'
            """
        )
        if not legacy_count_row or legacy_count_row["count"] == 0:
            logger.info("No old format LLM cache data found, skipping migration")
            return

        old_count = legacy_count_row["count"]
        logger.info(f"Found {old_count} old format cache records")

        # Advisory check only: collisions are reported, not resolved.
        collision_row = await self.query(
            """
            WITH new_ids AS (
                SELECT
                    workspace,
                    mode,
                    id as old_id,
                    mode || ':' ||
                    CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' ||
                    md5(original_prompt) as new_id
                FROM LIGHTRAG_LLM_CACHE
                WHERE id NOT LIKE '%:%'
            )
            SELECT COUNT(*) as conflicts
            FROM new_ids n1
            JOIN LIGHTRAG_LLM_CACHE existing
            ON existing.workspace = n1.workspace
            AND existing.mode = n1.mode
            AND existing.id = n1.new_id
            WHERE existing.id LIKE '%:%' -- Only check conflicts with existing new format records
            """
        )
        if collision_row and collision_row["conflicts"] > 0:
            logger.warning(
                f"Found {collision_row['conflicts']} potential ID conflicts with existing records"
            )

        logger.info("Starting optimized LLM cache migration...")
        await self.execute(
            """
            UPDATE LIGHTRAG_LLM_CACHE
            SET
                id = mode || ':' ||
                CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' ||
                md5(original_prompt),
                cache_type = CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END,
                update_time = CURRENT_TIMESTAMP
            WHERE id NOT LIKE '%:%'
            """
        )

        leftover_row = await self.query(
            """
            SELECT COUNT(*) as remaining_old FROM LIGHTRAG_LLM_CACHE
            WHERE id NOT LIKE '%:%'
            """
        )
        # -1 marks "verification query returned nothing".
        remaining = leftover_row["remaining_old"] if leftover_row else -1
        if remaining != 0:
            logger.warning(
                f"⚠️ Migration completed but {remaining} old format records remain"
            )
        else:
            logger.info(
                f"✅ Successfully migrated {old_count} LLM cache records to flattened format"
            )
    except Exception as e:
        logger.error(f"Optimized LLM cache migration failed: {e}")
        raise
async def _migrate_doc_status_add_chunks_list(self):
    """Ensure LIGHTRAG_DOC_STATUS has a ``chunks_list`` JSONB column.

    The column is looked up in ``information_schema.columns`` and added
    (nullable, default ``'[]'``) only when the lookup returns nothing.
    Any failure is logged as a warning and not re-raised.
    """
    lookup_sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'lightrag_doc_status'
        AND column_name = 'chunks_list'
    """
    try:
        if await self.query(lookup_sql):
            logger.info(
                "chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
            )
            return

        logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
        await self.execute(
            """
            ALTER TABLE LIGHTRAG_DOC_STATUS
            ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
            """
        )
        logger.info(
            "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
        )
    except Exception as e:
        logger.warning(
            f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
        )
async def _migrate_text_chunks_add_llm_cache_list(self):
    """Ensure LIGHTRAG_DOC_CHUNKS has an ``llm_cache_list`` JSONB column.

    The column is looked up in ``information_schema.columns`` and added
    (nullable, default ``'[]'``) only when the lookup returns nothing.
    Any failure is logged as a warning and not re-raised.
    """
    lookup_sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'lightrag_doc_chunks'
        AND column_name = 'llm_cache_list'
    """
    try:
        if await self.query(lookup_sql):
            logger.info(
                "llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
            )
            return

        logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
        await self.execute(
            """
            ALTER TABLE LIGHTRAG_DOC_CHUNKS
            ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
            """
        )
        logger.info(
            "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
        )
    except Exception as e:
        logger.warning(
            f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
        )
async def _migrate_doc_status_add_track_id(self):
    """Ensure LIGHTRAG_DOC_STATUS has a ``track_id`` column and an index on it.

    Both steps share one ``try`` block: the column is checked/added first,
    then the ``idx_lightrag_doc_status_track_id`` index is checked/created.
    A failure in either step is logged as a warning, skips whatever was
    left, and is not re-raised.
    """
    try:
        # Step 1: the column itself.
        column_row = await self.query(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'lightrag_doc_status'
            AND column_name = 'track_id'
            """
        )
        if column_row:
            logger.info(
                "track_id column already exists in LIGHTRAG_DOC_STATUS table"
            )
        else:
            logger.info("Adding track_id column to LIGHTRAG_DOC_STATUS table")
            await self.execute(
                """
                ALTER TABLE LIGHTRAG_DOC_STATUS
                ADD COLUMN track_id VARCHAR(255) NULL
                """
            )
            logger.info(
                "Successfully added track_id column to LIGHTRAG_DOC_STATUS table"
            )

        # Step 2: the supporting index.
        index_row = await self.query(
            """
            SELECT indexname
            FROM pg_indexes
            WHERE tablename = 'lightrag_doc_status'
            AND indexname = 'idx_lightrag_doc_status_track_id'
            """
        )
        if index_row:
            logger.info(
                "Index on track_id column already exists for LIGHTRAG_DOC_STATUS table"
            )
        else:
            logger.info(
                "Creating index on track_id column for LIGHTRAG_DOC_STATUS table"
            )
            await self.execute(
                """
                CREATE INDEX idx_lightrag_doc_status_track_id ON LIGHTRAG_DOC_STATUS (track_id)
                """
            )
            logger.info(
                "Successfully created index on track_id column for LIGHTRAG_DOC_STATUS table"
            )
    except Exception as e:
        logger.warning(
            f"Failed to add track_id column or index to LIGHTRAG_DOC_STATUS: {e}"
        )
async def _migrate_doc_status_add_metadata_error_msg(self):
    """Ensure LIGHTRAG_DOC_STATUS has ``metadata`` and ``error_msg`` columns.

    The columns are handled in that order inside a single ``try`` block,
    so a failure while processing ``metadata`` also skips ``error_msg``.
    Failures are logged as a warning and not re-raised.
    """
    # (column name, column definition) in processing order.
    wanted_columns = (
        ("metadata", "JSONB NULL DEFAULT '{}'::jsonb"),
        ("error_msg", "TEXT NULL"),
    )
    try:
        for column, definition in wanted_columns:
            present = await self.query(
                f"""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = 'lightrag_doc_status'
                AND column_name = '{column}'
                """
            )
            if present:
                logger.info(
                    f"{column} column already exists in LIGHTRAG_DOC_STATUS table"
                )
                continue

            logger.info(f"Adding {column} column to LIGHTRAG_DOC_STATUS table")
            await self.execute(
                f"""
                ALTER TABLE LIGHTRAG_DOC_STATUS
                ADD COLUMN {column} {definition}
                """
            )
            logger.info(
                f"Successfully added {column} column to LIGHTRAG_DOC_STATUS table"
            )
    except Exception as e:
        logger.warning(
            f"Failed to add metadata/error_msg columns to LIGHTRAG_DOC_STATUS: {e}"
        )
async def _migrate_doc_full_add_pipeline_fields(self):
    """Add pipeline-derived fields to LIGHTRAG_DOC_FULL if they don't exist.

    Each ALTER is guarded individually so a single failure does not abort
    the remaining columns; the migration is idempotent and retried on
    every startup until all columns are present.

    When the column inspection query fails, every column is attempted.
    ``ADD COLUMN IF NOT EXISTS`` keeps that fallback a no-op for columns
    that are already present instead of raising (and logging an error) for
    each of them.
    """
    # content_hash uses TEXT (not VARCHAR(N)) so the column stays
    # algorithm-agnostic; future SHA-512 / base64 hashes do not require a
    # schema change. process_options is an opaque selector string emitted
    # by sanitize_process_options() (e.g. "Fi").
    columns_to_add = [
        ("sidecar_location", "TEXT NULL"),
        ("parse_format", "VARCHAR(32) NULL DEFAULT 'raw'"),
        ("content_hash", "TEXT NULL"),
        ("process_options", "TEXT NULL"),
        ("chunk_options", "JSONB NULL DEFAULT '{}'::jsonb"),
        ("parse_engine", "VARCHAR(32) NULL"),
    ]

    try:
        existing = await self.query(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'lightrag_doc_full'
            AND column_name = ANY($1)
            """,
            [[c for c, _ in columns_to_add]],
            multirows=True,
        )
        existing_names = {row["column_name"] for row in (existing or [])}
    except Exception as e:
        logger.warning(
            f"Failed to inspect LIGHTRAG_DOC_FULL columns for migration: {e}"
        )
        # Unknown state: attempt every column and rely on IF NOT EXISTS.
        existing_names = set()

    for col_name, col_type in columns_to_add:
        if col_name in existing_names:
            logger.debug(f"Column {col_name} already exists in LIGHTRAG_DOC_FULL")
            continue

        try:
            alter_sql = (
                "ALTER TABLE LIGHTRAG_DOC_FULL "
                f"ADD COLUMN IF NOT EXISTS {col_name} {col_type}"
            )
            logger.info(f"Adding {col_name} column to LIGHTRAG_DOC_FULL table")
            await self.execute(alter_sql)
            logger.info(
                f"Successfully added {col_name} column to LIGHTRAG_DOC_FULL table"
            )
        except Exception as e:
            logger.error(
                f"Failed to add column {col_name} to LIGHTRAG_DOC_FULL: {e}"
            )
async def _migrate_doc_status_add_content_hash(self):
    """Ensure LIGHTRAG_DOC_STATUS has a ``content_hash`` column and its partial index.

    The column step and the index step run in separate ``try`` blocks, so
    the index step is attempted even when the column step failed. Failures
    in either step are logged at error level and not re-raised.
    """
    # --- column ---------------------------------------------------------
    try:
        column_row = await self.query(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'lightrag_doc_status'
            AND column_name = 'content_hash'
            """
        )
        if column_row:
            logger.debug(
                "content_hash column already exists in LIGHTRAG_DOC_STATUS table"
            )
        else:
            logger.info("Adding content_hash column to LIGHTRAG_DOC_STATUS table")
            # TEXT (not VARCHAR(N)) so the column is agnostic to the hash
            # algorithm; today the pipeline writes 64-char SHA-256 hex.
            await self.execute(
                "ALTER TABLE LIGHTRAG_DOC_STATUS ADD COLUMN content_hash TEXT NULL"
            )
            logger.info(
                "Successfully added content_hash column to LIGHTRAG_DOC_STATUS table"
            )
    except Exception as e:
        logger.error(
            f"Failed to add content_hash column to LIGHTRAG_DOC_STATUS: {e}"
        )

    # --- partial index over rows that actually carry a hash -------------
    try:
        index_row = await self.query(
            """
            SELECT indexname FROM pg_indexes
            WHERE tablename = 'lightrag_doc_status'
            AND indexname = 'idx_lightrag_doc_status_workspace_content_hash'
            """
        )
        if index_row:
            return

        logger.info(
            "Creating partial index idx_lightrag_doc_status_workspace_content_hash"
        )
        await self.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_lightrag_doc_status_workspace_content_hash
            ON LIGHTRAG_DOC_STATUS (workspace, content_hash)
            WHERE content_hash IS NOT NULL AND content_hash <> ''
            """
        )
    except Exception as e:
        logger.error(
            f"Failed to create partial content_hash index on LIGHTRAG_DOC_STATUS: {e}"
        )
async def _migrate_text_chunks_add_heading_sidecar(self):
    """Add heading and sidecar JSONB columns to LIGHTRAG_DOC_CHUNKS if missing.

    Columns already reported by ``information_schema.columns`` are skipped.
    When that inspection query fails, every column is attempted;
    ``ADD COLUMN IF NOT EXISTS`` keeps the fallback a no-op for columns that
    are already present instead of raising (and logging an error) for each.
    Each ALTER is guarded individually, and failures are logged at error
    level without being re-raised.
    """
    columns_to_add = [
        ("heading", "JSONB NULL DEFAULT '{}'::jsonb"),
        ("sidecar", "JSONB NULL DEFAULT '{}'::jsonb"),
    ]

    try:
        existing = await self.query(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'lightrag_doc_chunks'
            AND column_name = ANY($1)
            """,
            [[c for c, _ in columns_to_add]],
            multirows=True,
        )
        existing_names = {row["column_name"] for row in (existing or [])}
    except Exception as e:
        logger.warning(
            f"Failed to inspect LIGHTRAG_DOC_CHUNKS columns for migration: {e}"
        )
        # Unknown state: attempt every column and rely on IF NOT EXISTS.
        existing_names = set()

    for col_name, col_type in columns_to_add:
        if col_name in existing_names:
            logger.debug(f"Column {col_name} already exists in LIGHTRAG_DOC_CHUNKS")
            continue

        try:
            alter_sql = (
                "ALTER TABLE LIGHTRAG_DOC_CHUNKS "
                f"ADD COLUMN IF NOT EXISTS {col_name} {col_type}"
            )
            logger.info(f"Adding {col_name} column to LIGHTRAG_DOC_CHUNKS table")
            await self.execute(alter_sql)
            logger.info(
                f"Successfully added {col_name} column to LIGHTRAG_DOC_CHUNKS table"
            )
        except Exception as e:
            logger.error(
                f"Failed to add column {col_name} to LIGHTRAG_DOC_CHUNKS: {e}"
            )
async def _migrate_field_lengths(self):
    """Widen entity_name, source_id, target_id and file_path columns where needed.

    Specs whose table does not exist are dropped first. The remaining
    columns are inspected with one ``information_schema.columns`` query and
    altered only when they still have the old shape:

    * ``entity_name`` with max length 255
    * ``source_id`` / ``target_id`` with max length 256
    * ``file_path`` still typed ``character varying``

    A failure on one column is logged as a warning and the loop continues;
    a failure outside the per-column loop is logged at error level. Nothing
    is re-raised.
    """
    spec_fields = ("table", "column", "old_type", "new_type", "description")
    field_migrations = [
        dict(zip(spec_fields, values))
        for values in (
            (
                "LIGHTRAG_VDB_ENTITY",
                "entity_name",
                "character varying(255)",
                "VARCHAR(512)",
                "entity_name from 255 to 512",
            ),
            (
                "LIGHTRAG_VDB_RELATION",
                "source_id",
                "character varying(256)",
                "VARCHAR(512)",
                "source_id from 256 to 512",
            ),
            (
                "LIGHTRAG_VDB_RELATION",
                "target_id",
                "character varying(256)",
                "VARCHAR(512)",
                "target_id from 256 to 512",
            ),
            (
                "LIGHTRAG_DOC_CHUNKS",
                "file_path",
                "character varying(256)",
                "TEXT",
                "file_path to TEXT NULL",
            ),
            (
                "LIGHTRAG_VDB_CHUNKS",
                "file_path",
                "character varying(256)",
                "TEXT",
                "file_path to TEXT NULL",
            ),
        )
    ]

    try:
        # Legacy vector tables may be absent; keep only specs whose table exists.
        applicable = []
        for spec in field_migrations:
            if not await self.check_table_exists(spec["table"]):
                logger.debug(
                    f"Table {spec['table']} does not exist, skipping field length migration for {spec['column']}"
                )
                continue
            applicable.append(spec)

        if not applicable:
            logger.debug("No tables found for field length migration")
            return

        # One catalog query for every (table, column) pair of interest.
        table_names = list({spec["table"].lower() for spec in applicable})
        column_names = list({spec["column"] for spec in applicable})
        catalog_rows = await self.query(
            """
            SELECT table_name, column_name, data_type, character_maximum_length, is_nullable
            FROM information_schema.columns
            WHERE table_name = ANY($1)
            AND column_name = ANY($2)
            """,
            [table_names, column_names],
            multirows=True,
        )
        # Catalog names are upper-cased so they match the spec table names.
        catalog = {
            (row["table_name"].upper(), row["column_name"]): row
            for row in (catalog_rows or [])
        }

        for spec in applicable:
            table, column = spec["table"], spec["column"]
            try:
                info = catalog.get((table, column))
                if not info:
                    logger.warning(
                        f"Column {table}.{column} does not exist, skipping migration"
                    )
                    continue

                current_type = info.get("data_type", "").lower()
                current_length = info.get("character_maximum_length")

                if column == "entity_name":
                    needs_migration = current_length == 255
                elif column in ("source_id", "target_id"):
                    needs_migration = current_length == 256
                elif column == "file_path":
                    needs_migration = current_type == "character varying"
                else:
                    needs_migration = False

                if not needs_migration:
                    logger.debug(
                        f"Column {table}.{column} already has correct type, no migration needed"
                    )
                    continue

                logger.info(f"Migrating {table}.{column}: {spec['description']}")
                await self.execute(
                    f"""
                    ALTER TABLE {table}
                    ALTER COLUMN {column} TYPE {spec["new_type"]}
                    """
                )
                logger.info(f"Successfully migrated {table}.{column}")
            except Exception as e:
                # One bad column must not stop the remaining ones.
                logger.warning(f"Failed to migrate {table}.{column}: {e}")
    except Exception as e:
        logger.error(f"Failed to batch check field lengths: {e}")
async def check_tables(self):
    """Create missing tables and indexes, then run every schema/data migration.

    Phases:
      1. For each entry in ``TABLES`` (vector tables excluded, they are
         created by PGVectorStorage.setup_table()), probe the table with a
         SELECT and run its DDL when the probe fails.
      2. Fetch existing index names in one query and create the ``id`` and
         ``(workspace, id)`` indexes that are missing.
      3. Run the migrations one by one; each failure is logged at error
         level and the next migration still runs.

    Raises:
        Exception: only from phase 1, when a table cannot be created.
    """
    # Vector tables carry an embedding model / dimension suffix for data
    # isolation, so they are owned by PGVectorStorage.setup_table().
    vector_tables_to_skip = {
        "LIGHTRAG_VDB_CHUNKS",
        "LIGHTRAG_VDB_ENTITY",
        "LIGHTRAG_VDB_RELATION",
    }

    # Phase 1: tables.
    for table, definition in TABLES.items():
        if table in vector_tables_to_skip:
            continue
        try:
            await self.query(f"SELECT 1 FROM {table} LIMIT 1")
        except Exception:
            try:
                logger.info(f"PostgreSQL, Try Creating table {table} in database")
                await self.execute(definition["ddl"])
                logger.info(
                    f"PostgreSQL, Creation success table {table} in PostgreSQL database"
                )
            except Exception as e:
                logger.error(
                    f"PostgreSQL, Failed to create table {table} in database, Please verify the connection with PostgreSQL database, Got: {e}"
                )
                raise

    # Phase 2: basic indexes, checked with a single catalog query.
    try:
        managed_tables = [t for t in TABLES.keys() if t not in vector_tables_to_skip]
        index_rows = await self.query(
            """
            SELECT indexname, tablename
            FROM pg_indexes
            WHERE tablename = ANY($1)
            """,
            [[t.lower() for t in managed_tables]],
            multirows=True,
        )
        existing_indexes = (
            {row["indexname"] for row in index_rows} if index_rows else set()
        )

        for table in managed_tables:
            lowered = table.lower()
            wanted = (
                (f"idx_{lowered}_id", "id", "index"),
                (f"idx_{lowered}_workspace_id", "workspace, id", "composite index"),
            )
            for index_name, columns, label in wanted:
                if index_name in existing_indexes:
                    continue
                try:
                    create_index_sql = f"CREATE INDEX {index_name} ON {table}({columns})"
                    logger.info(
                        f"PostgreSQL, Creating {label} {index_name} on table {table}"
                    )
                    await self.execute(create_index_sql)
                except Exception as e:
                    logger.error(
                        f"PostgreSQL, Failed to create {label} {index_name}, Got: {e}"
                    )
    except Exception as e:
        logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")

    # NOTE: Vector index creation moved to PGVectorStorage.setup_table()
    # Each vector storage instance creates its own index with correct embedding_dim

    async def _flatten_llm_cache_if_needed():
        # Only migrate cache keys when legacy-format rows are still present.
        if await self._check_llm_cache_needs_migration():
            await self._migrate_llm_cache_to_flattened_keys()

    # Phase 3: migrations, in order. The lambdas defer the attribute lookup
    # and the call into the guarded section below.
    migration_steps = (
        (
            lambda: self._migrate_timestamp_columns(),
            "PostgreSQL, Failed to migrate timestamp columns",
        ),
        (
            lambda: self._migrate_llm_cache_schema(),
            "PostgreSQL, Failed to migrate LLM cache schema",
        ),
        (
            lambda: self._migrate_doc_chunks_to_vdb_chunks(),
            "PostgreSQL, Failed to migrate doc_chunks to vdb_chunks",
        ),
        (
            lambda: _flatten_llm_cache_if_needed(),
            "PostgreSQL, LLM cache migration failed",
        ),
        (
            lambda: self._migrate_doc_status_add_chunks_list(),
            "PostgreSQL, Failed to migrate doc status chunks_list field",
        ),
        (
            lambda: self._migrate_text_chunks_add_llm_cache_list(),
            "PostgreSQL, Failed to migrate text chunks llm_cache_list field",
        ),
        (
            lambda: self._migrate_field_lengths(),
            "PostgreSQL, Failed to migrate field lengths",
        ),
        (
            lambda: self._migrate_doc_status_add_track_id(),
            "PostgreSQL, Failed to migrate doc status track_id field",
        ),
        (
            lambda: self._migrate_doc_status_add_metadata_error_msg(),
            "PostgreSQL, Failed to migrate doc status metadata/error_msg fields",
        ),
        (
            lambda: self._create_pagination_indexes(),
            "PostgreSQL, Failed to create pagination indexes",
        ),
        (
            lambda: self._migrate_create_full_entities_relations_tables(),
            "PostgreSQL, Failed to create full entities/relations tables",
        ),
        # sidecar_location / parse_format / content_hash / process_options /
        # chunk_options / parse_engine (JSON storage parity)
        (
            lambda: self._migrate_doc_full_add_pipeline_fields(),
            "PostgreSQL, Failed to migrate LIGHTRAG_DOC_FULL pipeline fields",
        ),
        # content_hash column for content dedup queries
        (
            lambda: self._migrate_doc_status_add_content_hash(),
            "PostgreSQL, Failed to migrate LIGHTRAG_DOC_STATUS content_hash field",
        ),
        (
            lambda: self._migrate_text_chunks_add_heading_sidecar(),
            "PostgreSQL, Failed to migrate LIGHTRAG_DOC_CHUNKS heading/sidecar fields",
        ),
    )
    for start_step, failure_prefix in migration_steps:
        try:
            await start_step()
        except Exception as e:
            # Don't throw: initialization continues with the next step.
            logger.error(f"{failure_prefix}: {e}")
async def _migrate_create_full_entities_relations_tables(self):
    """Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS when absent.

    Each table is looked up in ``information_schema.tables`` (schema
    ``public``). A missing table is created from its ``TABLES`` DDL and then
    given an ``id`` index and a ``(workspace, id)`` index. Index failures
    are logged as warnings, table failures at error level; neither is
    re-raised and the other table is still processed.
    """
    targets = [
        (name, TABLES[name]["ddl"], description)
        for name, description in (
            ("LIGHTRAG_FULL_ENTITIES", "Full entities storage table"),
            ("LIGHTRAG_FULL_RELATIONS", "Full relations storage table"),
        )
    ]

    for table_name, ddl, description in targets:
        try:
            found = await self.query(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_name = $1
                AND table_schema = 'public'
                """,
                [table_name.lower()],
            )
            if found:
                logger.debug(f"Table {table_name} already exists")
                continue

            logger.info(f"Creating table {table_name}")
            await self.execute(ddl)
            logger.info(f"Successfully created {description}: {table_name}")

            # Basic indexes for the freshly created table.
            try:
                index_name = f"idx_{table_name.lower()}_id"
                await self.execute(f"CREATE INDEX {index_name} ON {table_name}(id)")
                logger.info(f"Created index {index_name} on table {table_name}")

                composite_index_name = f"idx_{table_name.lower()}_workspace_id"
                await self.execute(
                    f"CREATE INDEX {composite_index_name} ON {table_name}(workspace, id)"
                )
                logger.info(
                    f"Created composite index {composite_index_name} on table {table_name}"
                )
            except Exception as e:
                logger.warning(
                    f"Failed to create indexes for table {table_name}: {e}"
                )
        except Exception as e:
            logger.error(f"Failed to create table {table_name}: {e}")
async def _create_pagination_indexes(self):
    """Create indexes to optimize pagination queries for LIGHTRAG_DOC_STATUS.

    Existing index names are fetched with one ``pg_indexes`` query; when
    that query fails with ``asyncpg.PostgresError`` every index is
    attempted. Each index is created with
    ``CREATE INDEX CONCURRENTLY IF NOT EXISTS``; a ``PostgresError`` on one
    index is logged as a warning and the remaining indexes are still tried.
    """
    # (index name suffix, indexed columns, description)
    index_layout = (
        (
            "workspace_status_updated_at",
            "workspace, status, updated_at DESC",
            "Composite index for workspace + status + updated_at pagination",
        ),
        (
            "workspace_status_created_at",
            "workspace, status, created_at DESC",
            "Composite index for workspace + status + created_at pagination",
        ),
        (
            "workspace_updated_at",
            "workspace, updated_at DESC",
            "Index for workspace + updated_at pagination (all statuses)",
        ),
        (
            "workspace_created_at",
            "workspace, created_at DESC",
            "Index for workspace + created_at pagination (all statuses)",
        ),
        (
            "workspace_id",
            "workspace, id",
            "Index for workspace + id sorting",
        ),
        (
            "workspace_file_path",
            "workspace, file_path",
            "Index for workspace + file_path sorting",
        ),
    )
    indexes = []
    for suffix, columns, description in index_layout:
        name = f"idx_lightrag_doc_status_{suffix}"
        indexes.append(
            {
                "name": name,
                "sql": f"CREATE INDEX CONCURRENTLY IF NOT EXISTS {name} ON LIGHTRAG_DOC_STATUS ({columns})",
                "description": description,
            }
        )

    # One catalog query instead of one existence check per index.
    try:
        rows = await self.query(
            """
            SELECT indexname FROM pg_indexes
            WHERE tablename = 'lightrag_doc_status'
            AND indexname = ANY($1)
            """,
            [[entry["name"] for entry in indexes]],
            multirows=True,
        )
        existing_names = {row["indexname"] for row in (rows or [])}
    except asyncpg.PostgresError as e:
        logger.warning(
            f"[{self.workspace}] Failed to query existing pagination indexes "
            f"({type(e).__name__}), will attempt to create all: {e}"
        )
        existing_names = set()

    for entry in indexes:
        if entry["name"] in existing_names:
            logger.debug(f"Index already exists: {entry['name']}")
            continue
        try:
            logger.info(f"Creating pagination index: {entry['description']}")
            await self.execute(entry["sql"])
            logger.info(f"Successfully created index: {entry['name']}")
        except asyncpg.PostgresError as e:
            logger.warning(
                f"Failed to create index {entry['name']} ({type(e).__name__}): {e}"
            )
async def _create_vector_index(self, table_name: str, embedding_dim: int):
    """
    Create the configured vector index on a table if it is not there yet.

    Nothing happens when ``self.vector_index_type`` is empty; an unsupported
    type only logs a warning. When the index is missing, indexes named for
    the other suffixes in ``_VECTOR_INDEX_SUFFIXES`` are dropped,
    ``content_vector`` is retyped to ``VECTOR(dim)`` (``HALFVEC(dim)`` for
    HNSW_HALFVEC) and the index is created. Failures in that sequence are
    logged at error level and not re-raised.

    Args:
        table_name: Name of the table to create index on
        embedding_dim: Embedding dimension for the vector column
    """
    index_type = self.vector_index_type
    if not index_type:
        return

    # index type -> (access method, operator class)
    index_specs = {
        "HNSW": ("hnsw", "vector_cosine_ops"),
        "HNSW_HALFVEC": ("hnsw", "halfvec_cosine_ops"),
        "IVFFLAT": ("ivfflat", "vector_cosine_ops"),
        "VCHORDRQ": ("vchordrq", "vector_cosine_ops"),
    }
    if index_type not in index_specs:
        logger.warning(
            f"Unsupported vector index type: {self.vector_index_type}. "
            "Supported types: HNSW, HNSW_HALFVEC, IVFFLAT, VCHORDRQ"
        )
        return
    access_method, operator_class = index_specs[index_type]

    # Only the storage parameters of the selected index type are read.
    if access_method == "hnsw":
        with_clause = f"WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})"
    elif access_method == "ivfflat":
        with_clause = f"WITH (lists = {self.ivfflat_lists})"
    elif self.vchordrq_build_options:
        with_clause = f"WITH (options = $${self.vchordrq_build_options}$$)"
    else:
        with_clause = ""

    k = table_name
    # Use _safe_index_name to avoid PostgreSQL's 63-byte identifier truncation
    index_suffix = f"{index_type.lower()}_cosine"
    vector_index_name = _safe_index_name(k, index_suffix)

    # The statement is assembled in a single f-string pass. Running
    # str.format() over text that already contains the build options would
    # parse any '{' or '}' inside them as replacement fields.
    create_index_sql = f"""
        CREATE INDEX {vector_index_name}
        ON {k} USING {access_method} (content_vector {operator_class})
        {with_clause}
    """

    check_vector_index_sql = f"""
        SELECT 1 FROM pg_indexes
        WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
    """
    column_type = "HALFVEC" if index_type == "HNSW_HALFVEC" else "VECTOR"

    try:
        vector_index_exists = await self.query(check_vector_index_sql)
        if vector_index_exists:
            logger.info(
                f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
            )
            return

        # Remove indexes left behind by a previously configured index type.
        for suffix in _VECTOR_INDEX_SUFFIXES:
            if suffix == index_suffix:
                continue
            old_name = _safe_index_name(k, suffix)
            await self.execute(f"DROP INDEX IF EXISTS {old_name}")

        alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE {column_type}({embedding_dim})"
        await self.execute(alter_sql)
        logger.debug(f"Ensured vector dimension for {k}")

        logger.info(
            f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
        )
        await self.execute(create_index_sql)
        logger.info(
            f"Successfully created vector index {vector_index_name} on table {k}"
        )
    except Exception as e:
        logger.error(f"Failed to create vector index on table {k}, Got: {e}")
async def query(
    self,
    sql: str,
    params: list[Any] | None = None,
    multirows: bool = False,
    with_age: bool = False,
    graph_name: str | None = None,
    timing_label: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
    """Run a fetch statement and convert the rows to plain dicts.

    Args:
        sql: Statement text passed to ``connection.fetch``.
        params: Positional bind values; ``None`` or empty means no binds.
        multirows: When true, return a list with one dict per row (empty
            list for no rows); otherwise return the first row as a dict,
            or ``None`` when there are no rows.
        with_age: Forwarded to ``self._run_with_retry``.
        graph_name: Forwarded to ``self._run_with_retry``.
        timing_label: Forwarded to ``self._run_with_retry``; when set,
            fetch and conversion durations are also reported through
            ``performance_timing_log``.

    Raises:
        Exception: anything raised by ``self._run_with_retry`` is logged
            at error level and re-raised.
    """

    async def _operation(connection: asyncpg.Connection) -> Any:
        bind_values = tuple(params) if params else ()

        fetch_started = time.perf_counter()
        # An empty tuple unpacks to no extra arguments, i.e. fetch(sql).
        records = await connection.fetch(sql, *bind_values)
        fetch_elapsed = time.perf_counter() - fetch_started
        if timing_label:
            performance_timing_log(
                "[%s] connection.fetch completed in %.4fs row_count=%s",
                timing_label,
                fetch_elapsed,
                len(records),
            )

        conversion_started = time.perf_counter()
        if multirows:
            if records:
                # Column names are taken from the first record and reused.
                column_names = list(records[0].keys())
                converted = [dict(zip(column_names, record)) for record in records]
            else:
                converted = []
        elif records:
            converted = dict(zip(records[0].keys(), records[0]))
        else:
            converted = None

        if timing_label:
            performance_timing_log(
                "[%s] result conversion completed in %.4fs multirows=%s",
                timing_label,
                time.perf_counter() - conversion_started,
                True if multirows else False,
            )
        return converted

    try:
        return await self._run_with_retry(
            _operation,
            with_age=with_age,
            graph_name=graph_name,
            timing_label=timing_label,
        )
    except Exception as e:
        logger.error(f"PostgreSQL database, error:{e}")
        raise
async def check_table_exists(self, table_name: str) -> bool:
    """Check if a table exists in PostgreSQL database

    The name is lower-cased before it is compared against
    ``information_schema.tables``.

    Args:
        table_name: Name of the table to check

    Returns:
        bool: True if table exists, False otherwise
    """
    existence_sql = """
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_name = $1
        )
    """
    row = await self.query(existence_sql, [table_name.lower()])
    if not row:
        return False
    return row.get("exists", False)
async def execute(
    self,
    sql: str,
    data: dict[str, Any] | None = None,
    upsert: bool = False,
    ignore_if_exists: bool = False,
    with_age: bool = False,
    graph_name: str | None = None,
    timing_label: str | None = None,
):
    """Run a statement that returns no rows.

    Args:
        sql: Statement handed to ``connection.execute``.
        data: Bind values; only the dict's values are used, in insertion
            order, as positional arguments.
        upsert: When true, a duplicate/already-exists error is logged at
            info level and treated as success.
        ignore_if_exists: When true, a duplicate/already-exists error is
            logged at debug level and treated as success (checked before
            ``upsert``).
        with_age: Forwarded unchanged to ``_run_with_retry``.
        graph_name: Forwarded unchanged to ``_run_with_retry``.
        timing_label: When set, durations are reported through
            ``performance_timing_log``; also forwarded to ``_run_with_retry``.

    Raises:
        Exception: Whatever ``_run_with_retry`` raises is logged together
        with the SQL and data, then re-raised.
    """

    async def _operation(connection: asyncpg.Connection) -> Any:
        bound_values = tuple(data.values()) if data else ()
        started = time.perf_counter()
        try:
            # Unpacking an empty tuple is the same call as execute(sql).
            status = await connection.execute(sql, *bound_values)
        except (
            asyncpg.exceptions.UniqueViolationError,
            asyncpg.exceptions.DuplicateTableError,
            asyncpg.exceptions.DuplicateObjectError,
            asyncpg.exceptions.InvalidSchemaNameError,
        ) as e:
            if not (ignore_if_exists or upsert):
                raise
            if ignore_if_exists:
                logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
            else:
                logger.info(
                    "PostgreSQL, duplicate detected but treated as upsert success: %r",
                    e,
                )
            status = None
        except Exception:
            if timing_label:
                performance_timing_log(
                    "[%s] connection.execute failed after %.4fs",
                    timing_label,
                    time.perf_counter() - started,
                )
            raise

        if timing_label:
            performance_timing_log(
                "[%s] connection.execute completed in %.4fs result=%s",
                timing_label,
                time.perf_counter() - started,
                status,
            )
        return status

    retry_options = {
        "with_age": with_age,
        "graph_name": graph_name,
        "timing_label": timing_label,
    }
    try:
        await self._run_with_retry(_operation, **retry_options)
    except Exception as e:
        logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
        raise
- class ClientManager:
- """Manage the process-wide PostgreSQL client pool shared by PG storages.
- The first successful initialization defines the pool configuration for the
- lifetime of the shared client. Reusing the pool with a different vector
- storage setup is not supported and will raise a fail-fast error.
- """
- _instances: dict[str, Any] = {
- "db": None,
- "ref_count": 0,
- "vector_signature": None,
- }
- _lock = asyncio.Lock()
- @staticmethod
- def get_config(vector_storage: str | None = None) -> dict[str, Any]:
- config = configparser.ConfigParser()
- config.read("config.ini", "utf-8")
- return {
- "host": os.environ.get(
- "POSTGRES_HOST",
- config.get("postgres", "host", fallback="localhost"),
- ),
- "port": os.environ.get(
- "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
- ),
- "user": os.environ.get(
- "POSTGRES_USER", config.get("postgres", "user", fallback="postgres")
- ),
- "password": os.environ.get(
- "POSTGRES_PASSWORD",
- config.get("postgres", "password", fallback=None),
- ),
- "database": os.environ.get(
- "POSTGRES_DATABASE",
- config.get("postgres", "database", fallback="postgres"),
- ),
- "workspace": os.environ.get(
- "POSTGRES_WORKSPACE",
- config.get("postgres", "workspace", fallback=None),
- ),
- "max_connections": os.environ.get(
- "POSTGRES_MAX_CONNECTIONS",
- config.get("postgres", "max_connections", fallback=50),
- ),
- # SSL configuration
- "ssl_mode": os.environ.get(
- "POSTGRES_SSL_MODE",
- config.get("postgres", "ssl_mode", fallback=None),
- ),
- "ssl_cert": os.environ.get(
- "POSTGRES_SSL_CERT",
- config.get("postgres", "ssl_cert", fallback=None),
- ),
- "ssl_key": os.environ.get(
- "POSTGRES_SSL_KEY",
- config.get("postgres", "ssl_key", fallback=None),
- ),
- "ssl_root_cert": os.environ.get(
- "POSTGRES_SSL_ROOT_CERT",
- config.get("postgres", "ssl_root_cert", fallback=None),
- ),
- "ssl_crl": os.environ.get(
- "POSTGRES_SSL_CRL",
- config.get("postgres", "ssl_crl", fallback=None),
- ),
- # Vector configuration: derived from the vector storage backend in use.
- # PGVectorStorage requires pgvector; all other backends do not.
- "enable_vector": vector_storage == "PGVectorStorage"
- if vector_storage is not None
- else True,
- "vector_index_type": os.environ.get(
- "POSTGRES_VECTOR_INDEX_TYPE",
- config.get("postgres", "vector_index_type", fallback="HNSW"),
- ),
- "hnsw_m": int(
- os.environ.get(
- "POSTGRES_HNSW_M",
- config.get("postgres", "hnsw_m", fallback="16"),
- )
- ),
- "hnsw_ef": int(
- os.environ.get(
- "POSTGRES_HNSW_EF",
- config.get("postgres", "hnsw_ef", fallback="64"),
- )
- ),
- "ivfflat_lists": int(
- os.environ.get(
- "POSTGRES_IVFFLAT_LISTS",
- config.get("postgres", "ivfflat_lists", fallback="100"),
- )
- ),
- "vchordrq_build_options": os.environ.get(
- "POSTGRES_VCHORDRQ_BUILD_OPTIONS",
- config.get("postgres", "vchordrq_build_options", fallback=""),
- ),
- "vchordrq_probes": os.environ.get(
- "POSTGRES_VCHORDRQ_PROBES",
- config.get("postgres", "vchordrq_probes", fallback=""),
- ),
- "vchordrq_epsilon": float(
- os.environ.get(
- "POSTGRES_VCHORDRQ_EPSILON",
- config.get("postgres", "vchordrq_epsilon", fallback="1.9"),
- )
- ),
- # Server settings for Supabase
- "server_settings": os.environ.get(
- "POSTGRES_SERVER_SETTINGS",
- config.get("postgres", "server_options", fallback=None),
- ),
- "statement_cache_size": os.environ.get(
- "POSTGRES_STATEMENT_CACHE_SIZE",
- config.get("postgres", "statement_cache_size", fallback=None),
- ),
- # Connection retry configuration
- "connection_retry_attempts": min(
- 100, # Increased from 10 to 100 for long-running operations
- int(
- os.environ.get(
- "POSTGRES_CONNECTION_RETRIES",
- config.get("postgres", "connection_retries", fallback=10),
- )
- ),
- ),
- "connection_retry_backoff": min(
- 300.0, # Increased from 5.0 to 300.0 (5 minutes) for PG switchover scenarios
- float(
- os.environ.get(
- "POSTGRES_CONNECTION_RETRY_BACKOFF",
- config.get(
- "postgres", "connection_retry_backoff", fallback=3.0
- ),
- )
- ),
- ),
- "connection_retry_backoff_max": min(
- 600.0, # Increased from 60.0 to 600.0 (10 minutes) for PG switchover scenarios
- float(
- os.environ.get(
- "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
- config.get(
- "postgres",
- "connection_retry_backoff_max",
- fallback=30.0,
- ),
- )
- ),
- ),
- "pool_close_timeout": min(
- 30.0,
- float(
- os.environ.get(
- "POSTGRES_POOL_CLOSE_TIMEOUT",
- config.get("postgres", "pool_close_timeout", fallback=5.0),
- )
- ),
- ),
- }
- @classmethod
- def _build_vector_signature(
- cls, config: dict[str, Any], vector_storage: str | None
- ) -> dict[str, Any]:
- signature = {
- "vector_storage": vector_storage,
- "enable_vector": config["enable_vector"],
- }
- if config["enable_vector"]:
- signature.update(
- {
- "vector_index_type": config["vector_index_type"],
- "hnsw_m": config["hnsw_m"],
- "hnsw_ef": config["hnsw_ef"],
- "ivfflat_lists": config["ivfflat_lists"],
- "vchordrq_build_options": config["vchordrq_build_options"],
- "vchordrq_probes": config["vchordrq_probes"],
- "vchordrq_epsilon": config["vchordrq_epsilon"],
- }
- )
- return signature
- @classmethod
- def _assert_compatible_vector_signature(
- cls, requested_signature: dict[str, Any]
- ) -> None:
- active_signature = cls._instances["vector_signature"]
- if active_signature is None or active_signature == requested_signature:
- return
- raise RuntimeError(
- "PostgreSQL client pool is process-wide and already initialized with "
- f"vector settings {active_signature}. Received incompatible settings "
- f"{requested_signature}. Multiple LightRAG instances with different "
- "PostgreSQL/vector storage configurations are not supported in the "
- "same process."
- )
- @classmethod
- async def get_client(cls, vector_storage: str | None = None) -> PostgreSQLDB:
- """Return the shared PostgreSQL client for all PG storages in this process.
- The first caller fixes the vector-related pool configuration. Later calls
- must provide a compatible vector storage setup or a RuntimeError is raised.
- """
- async with cls._lock:
- config = ClientManager.get_config(vector_storage=vector_storage)
- requested_signature = cls._build_vector_signature(config, vector_storage)
- if cls._instances["db"] is None:
- db = PostgreSQLDB(config)
- await db.initdb()
- await db.check_tables()
- cls._instances["db"] = db
- cls._instances["ref_count"] = 0
- cls._instances["vector_signature"] = requested_signature
- else:
- cls._assert_compatible_vector_signature(requested_signature)
- cls._instances["ref_count"] += 1
- return cls._instances["db"]
- @classmethod
- async def release_client(cls, db: PostgreSQLDB):
- async with cls._lock:
- if db is not None:
- if db is cls._instances["db"]:
- cls._instances["ref_count"] -= 1
- if cls._instances["ref_count"] == 0:
- if db.pool is not None:
- await db.pool.close()
- logger.info("Closed PostgreSQL database connection pool")
- cls._instances["db"] = None
- cls._instances["vector_signature"] = None
- else:
- if db.pool is not None:
- await db.pool.close()
@final
@dataclass
class PGKVStorage(BaseKVStorage):
    """Key-value storage backed by PostgreSQL, scoped by workspace and namespace.

    SQL statements come from ``SQL_TEMPLATES`` keyed by namespace. Rows read
    back are normalized by ``_normalize_row`` (JSON text columns decoded,
    ``update_time`` backfilled) before they are returned to the caller.
    """

    db: PostgreSQLDB = field(default=None)

    def __post_init__(self):
        # DB batch size, independent of embedding batch size
        self._max_batch_size = 200

    async def initialize(self):
        """Acquire the shared client if needed and settle the effective workspace.

        Workspace priority: PostgreSQLDB.workspace > self.workspace > "default".
        """
        async with get_data_init_lock():
            if self.db is None:
                self.db = await ClientManager.get_client(
                    vector_storage=self.global_config.get("vector_storage")
                )

            if self.db.workspace:
                # Use PostgreSQLDB's workspace (highest priority).
                # NOTE(review): the message says PG_WORKSPACE while the config
                # loader reads POSTGRES_WORKSPACE - confirm which name is intended.
                logger.info(
                    f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
                )
                self.workspace = self.db.workspace
            elif not getattr(self, "workspace", None):
                # Neither source provided a workspace: use "default" for compatibility.
                self.workspace = "default"

    async def finalize(self):
        """Release the shared client, if one is held."""
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    ################ ROW NORMALIZATION HELPERS ################

    @staticmethod
    def _load_list_column(row: dict[str, Any], key: str) -> None:
        """Decode a JSON-text column in place; undecodable text becomes ``[]``.

        A missing key becomes ``[]``; non-string values are stored back unchanged.
        """
        value = row.get(key, [])
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                value = []
        row[key] = value

    @staticmethod
    def _load_dict_column(row: dict[str, Any], key: str) -> None:
        """Decode a JSON-text column in place; anything that is not a dict becomes ``{}``."""
        value = row.get(key)
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                value = {}
        if not isinstance(value, dict):
            value = {}
        row[key] = value

    @staticmethod
    def _backfill_update_time(row: dict[str, Any]) -> None:
        """Default ``create_time`` to 0 and use it for ``update_time`` when that is 0/missing."""
        create_time = row.get("create_time", 0)
        update_time = row.get("update_time", 0)
        row["create_time"] = create_time
        row["update_time"] = create_time if update_time == 0 else update_time

    def _normalize_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Apply the namespace-specific post-processing to one fetched row.

        The row is modified in place, except for the LLM response cache
        namespace where a new dict with the compatibility field names is built.
        """
        namespace = self.namespace

        if is_namespace(namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            self._load_list_column(row, "llm_cache_list")
            self._load_dict_column(row, "heading")
            self._load_dict_column(row, "sidecar")
            self._backfill_update_time(row)

        if is_namespace(namespace, NameSpace.KV_STORE_FULL_DOCS):
            self._load_dict_column(row, "chunk_options")

        # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
        if is_namespace(namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            create_time = row.get("create_time", 0)
            update_time = row.get("update_time", 0)
            queryparam = row.get("queryparam")
            if isinstance(queryparam, str):
                try:
                    queryparam = json.loads(queryparam)
                except json.JSONDecodeError:
                    queryparam = None
            # Map field names for compatibility (mode field removed)
            row = {
                **row,
                "return": row.get("return_value", ""),
                "cache_type": row.get("cache_type"),
                "original_prompt": row.get("original_prompt", ""),
                "chunk_id": row.get("chunk_id"),
                "queryparam": queryparam,
                "create_time": create_time,
                "update_time": create_time if update_time == 0 else update_time,
            }

        # Namespaces that carry exactly one JSON list column plus timestamps.
        for base_namespace, column in (
            (NameSpace.KV_STORE_FULL_ENTITIES, "entity_names"),
            (NameSpace.KV_STORE_FULL_RELATIONS, "relation_pairs"),
            (NameSpace.KV_STORE_ENTITY_CHUNKS, "chunk_ids"),
            (NameSpace.KV_STORE_RELATION_CHUNKS, "chunk_ids"),
        ):
            if is_namespace(namespace, base_namespace):
                self._load_list_column(row, column)
                self._backfill_update_time(row)

        return row

    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get data by id.

        Returns:
            The normalized row, or ``None`` when no row matched.
        """
        sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
        response = await self.db.query(sql, [self.workspace, id])
        if not response:
            return None
        return self._normalize_row(response)

    # Query by id
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get data by ids.

        Returns:
            One entry per requested id, in the caller's order; ids without a
            matching row yield ``None``. An empty ``ids`` yields ``[]``.
        """
        if not ids:
            return []

        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
        rows = await self.db.query(sql, [self.workspace, ids], multirows=True)
        if not rows:
            return [None for _ in ids]

        # Index the normalized rows by id so the caller's ordering can be restored.
        rows_by_id: dict[str, dict[str, Any]] = {}
        for row in rows:
            if row is None:
                continue
            normalized = self._normalize_row(row)
            row_id = normalized.get("id")
            if row_id is not None:
                rows_by_id[str(row_id)] = normalized

        return [rows_by_id.get(str(requested_id)) for requested_id in ids]

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of ``keys`` that is not yet stored for this workspace.

        Raises:
            Exception: Query errors are logged and re-raised.
        """
        if not keys:
            return set()

        table_name = namespace_to_table_name(self.namespace)
        sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
        params = {"workspace": self.workspace, "ids": list(keys)}
        try:
            res = await self.db.query(sql, list(params.values()), multirows=True)
            # A set makes the difference linear instead of scanning a list per key.
            existing_ids = {row["id"] for row in res} if res else set()
            return set(keys) - existing_ids
        except Exception as e:
            logger.error(
                f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
            )
            raise

    ################ INSERT METHODS ################

    def _upsert_plan(self, current_time: Any) -> tuple[str, Any]:
        """Pick the SQL template key and the row-tuple builder for this namespace.

        Args:
            current_time: Timestamp written to ``create_time``/``update_time``
                by the builders that store them.

        Raises:
            ValueError: The namespace has no upsert statement.
        """
        namespace = self.namespace
        workspace = self.workspace

        def text_chunk_row(k: str, v: dict[str, Any]) -> tuple:
            # Tuple order must match SQL: (workspace, id, tokens, chunk_order_index,
            # full_doc_id, content, file_path, llm_cache_list, heading, sidecar,
            # create_time, update_time)
            return (
                workspace,
                k,
                v["tokens"],
                v["chunk_order_index"],
                v["full_doc_id"],
                v["content"],
                v["file_path"],
                json.dumps(v.get("llm_cache_list", [])),
                json.dumps(v.get("heading") or {}),
                json.dumps(v.get("sidecar") or {}),
                current_time,
                current_time,
            )

        def full_doc_row(k: str, v: dict[str, Any]) -> tuple:
            # Tuple order must match SQL: (id, content, doc_name, workspace,
            # sidecar_location, parse_format, content_hash, process_options,
            # chunk_options, parse_engine)
            #
            # All pipeline-derived fields pass through untouched so the
            # SQL-level COALESCE guard in upsert_doc_full can distinguish
            # "caller did not supply" (None/'') from "caller supplied a
            # real value". The 'raw' default for parse_format is provided
            # by the column DDL on initial insert; do NOT default it here
            # or the COALESCE guard never triggers on subsequent partial
            # writes.
            return (
                k,
                v["content"],
                v.get("file_path", ""),
                workspace,
                v.get("sidecar_location"),
                v.get("parse_format"),
                v.get("content_hash"),
                v.get("process_options"),
                json.dumps(v.get("chunk_options") or {}),
                v.get("parse_engine"),
            )

        def llm_cache_row(k: str, v: dict[str, Any]) -> tuple:
            # Tuple order must match SQL: (workspace, id, original_prompt, return_value,
            # chunk_id, cache_type, queryparam)
            return (
                workspace,
                k,
                v["original_prompt"],
                v["return"],
                v.get("chunk_id"),
                v.get("cache_type", "extract"),
                json.dumps(v.get("queryparam")) if v.get("queryparam") else None,
            )

        def counted_list_row(list_field: str) -> Any:
            def build(k: str, v: dict[str, Any]) -> tuple:
                # Tuple order must match SQL: (workspace, id, <list_field>, count,
                # create_time, update_time)
                return (
                    workspace,
                    k,
                    json.dumps(v[list_field]),
                    v["count"],
                    current_time,
                    current_time,
                )

            return build

        if is_namespace(namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            return "upsert_text_chunk", text_chunk_row
        if is_namespace(namespace, NameSpace.KV_STORE_FULL_DOCS):
            return "upsert_doc_full", full_doc_row
        if is_namespace(namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            return "upsert_llm_response_cache", llm_cache_row
        if is_namespace(namespace, NameSpace.KV_STORE_FULL_ENTITIES):
            return "upsert_full_entities", counted_list_row("entity_names")
        if is_namespace(namespace, NameSpace.KV_STORE_FULL_RELATIONS):
            return "upsert_full_relations", counted_list_row("relation_pairs")
        if is_namespace(namespace, NameSpace.KV_STORE_ENTITY_CHUNKS):
            return "upsert_entity_chunks", counted_list_row("chunk_ids")
        if is_namespace(namespace, NameSpace.KV_STORE_RELATION_CHUNKS):
            return "upsert_relation_chunks", counted_list_row("chunk_ids")

        logger.error(f"Unknown namespace: {self.namespace}")
        raise ValueError(f"Unknown namespace: {self.namespace}")

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Insert or update records, sent in sub-batches of ``_max_batch_size`` rows.

        Args:
            data: Mapping of record id to its field dict; the required fields
                depend on the namespace (see ``_upsert_plan``).

        Raises:
            ValueError: The namespace has no upsert statement.
            KeyError: A required field is missing from a record.
        """
        logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
        if not data:
            return

        timing_label = f"{self.workspace} PGKVStorage.upsert[{self.namespace}]"
        total_start = time.perf_counter()
        performance_timing_log(
            "[%s] start records=%s max_batch_size=%s",
            timing_label,
            len(data),
            self._max_batch_size,
        )

        batch_values: list[tuple] = []
        batch_values_build_start = time.perf_counter()

        # Get current UTC time and convert to naive datetime for database storage
        current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
        template_key, build_row = self._upsert_plan(current_time)
        upsert_sql = SQL_TEMPLATES[template_key]
        for i, (k, v) in enumerate(data.items(), start=1):
            batch_values.append(build_row(k, v))
            await _cooperative_yield(i)

        performance_timing_log(
            "[%s] batch_values build completed in %.4fs records=%s%s",
            timing_label,
            time.perf_counter() - batch_values_build_start,
            len(batch_values),
            _timing_details_suffix(namespace=self.namespace),
        )

        if batch_values:
            # Split into sub-batches to prevent database overload
            num_batches = (
                len(batch_values) + self._max_batch_size - 1
            ) // self._max_batch_size
            for batch_index, i in enumerate(
                range(0, len(batch_values), self._max_batch_size), start=1
            ):
                sub_batch = batch_values[i : i + self._max_batch_size]

                # Default arguments bind the current loop values to this closure.
                async def _batch_upsert(
                    connection: asyncpg.Connection,
                    _sql: str = upsert_sql,
                    _data: list[tuple] = sub_batch,
                    _batch_index: int = batch_index,
                    _num_batches: int = num_batches,
                ) -> None:
                    execute_start = time.perf_counter()
                    await connection.executemany(_sql, _data)
                    performance_timing_log(
                        "[%s] sub-batch %s/%s executemany completed in %.4fs batch_size=%s",
                        timing_label,
                        _batch_index,
                        _num_batches,
                        time.perf_counter() - execute_start,
                        len(_data),
                    )

                await self.db._run_with_retry(_batch_upsert, timing_label=timing_label)

            logger.debug(
                f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace} "
                f"in {num_batches} sub-batches"
            )

        performance_timing_log(
            "[%s] total complete in %.4fs records=%s",
            timing_label,
            time.perf_counter() - total_start,
            len(batch_values),
        )

    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass

    async def is_empty(self) -> bool:
        """Check if the storage is empty for the current workspace and namespace

        Returns:
            bool: True if storage is empty, False otherwise. An unknown
            namespace or a query error is logged and also reported as True.
        """
        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(
                f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
            )
            return True

        sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
        try:
            result = await self.db.query(sql, [self.workspace])
            return not result.get("has_data", False) if result else True
        except Exception as e:
            logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
            return True

    async def delete(self, ids: list[str]) -> None:
        """Delete specific records from storage by their IDs

        Failures are logged and not raised (best effort).

        Args:
            ids (list[str]): List of document IDs to be deleted from storage

        Returns:
            None
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        if not table_name:
            logger.error(
                f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
            )
            return

        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
        try:
            await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
            logger.debug(
                f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error while deleting records from {self.namespace}: {e}"
            )

    async def drop(self) -> dict[str, str]:
        """Drop the storage: remove this workspace's rows from the namespace table.

        Returns:
            ``{"status": "success", ...}`` on success, otherwise
            ``{"status": "error", "message": ...}``; errors are not raised.
        """
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}
- @dataclass
- class _PendingPGVectorDoc:
- """Buffered PG vector upsert awaiting embedding and batched flush.
- ``vector`` is stored as a numpy ndarray (typically float32 from the
- embedding function) once embedded; pgvector's asyncpg codec accepts
- ndarray directly so no per-flush conversion is needed.
- """
- item: dict[str, Any]
- created_at: datetime.datetime
- vector: np.ndarray | None = None
- @final
- @dataclass
- class PGVectorStorage(BaseVectorStorage):
- db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
    """Validate the configuration and derive the table names for this storage.

    Raises:
        ValueError: ``cosine_better_than_threshold`` is missing from
            ``vector_db_storage_cls_kwargs``, the namespace maps to no
            table, or the resulting table name is longer than
            ``PG_MAX_IDENTIFIER_LENGTH``.
    """
    self._validate_embedding_func()
    self._max_batch_size = self.global_config["embedding_batch_num"]

    storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
    threshold = storage_kwargs.get("cosine_better_than_threshold")
    if threshold is None:
        raise ValueError(
            "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
        )
    self.cosine_better_than_threshold = threshold

    # The model suffix isolates tables that belong to different embedding models.
    self.model_suffix = self._generate_collection_suffix()

    unsuffixed_table = namespace_to_table_name(self.namespace)
    if not unsuffixed_table:
        raise ValueError(f"Unknown namespace: {self.namespace}")

    if not self.model_suffix:
        # No suffix available: fall back to the bare base table name.
        self.table_name = unsuffixed_table
        logger.warning(
            f"PostgreSQL table: {self.table_name} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation."
        )
    else:
        self.table_name = f"{unsuffixed_table}_{self.model_suffix}"
        logger.info(f"PostgreSQL table: {self.table_name}")

    # The un-suffixed name is kept as the legacy table name (for migration).
    self.legacy_table_name = unsuffixed_table

    # PostgreSQL identifier limit is 63 characters.
    if len(self.table_name) > PG_MAX_IDENTIFIER_LENGTH:
        raise ValueError(
            f"PostgreSQL table name exceeds {PG_MAX_IDENTIFIER_LENGTH} character limit: '{self.table_name}' "
            f"(length: {len(self.table_name)}). "
            f"Consider using a shorter embedding model name or workspace name."
        )

    # upsert() and delete() queue their work in these buffers until
    # _flush_pending_vector_ops() runs from index_done_callback() /
    # finalize(). Mirrors OpenSearchVectorDBStorage / NanoVectorDBStorage.
    self._pending_vector_docs: dict[str, _PendingPGVectorDoc] = {}
    self._pending_vector_deletes: set[str] = set()
    # Namespace-keyed lock; created in initialize() after workspace is final.
    self._flush_lock = None
@staticmethod
async def _pg_create_table(
    db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
) -> None:
    """Create ``table_name`` from the DDL template of ``base_table``.

    The template's table name and vector column type/dimension are
    substituted, the statement is made idempotent, and indexes on ``id``
    and ``(workspace, id)`` are then created. Index creation failures are
    logged and do not abort the call.

    Args:
        db: PostgreSQLDB instance
        table_name: Name of the new table to create
        base_table: Base table name for DDL template lookup
        embedding_dim: Embedding dimension for vector column

    Raises:
        ValueError: if ``base_table`` has no entry in ``TABLES``.
    """
    if base_table not in TABLES:
        raise ValueError(f"No DDL template found for table: {base_table}")

    # HALFVEC columns are used only when the HNSW_HALFVEC index is selected.
    wants_halfvec = getattr(db, "vector_index_type", None) == "HNSW_HALFVEC"
    column_type = "HALFVEC" if wants_halfvec else "VECTOR"

    ddl = (
        TABLES[base_table]["ddl"]
        # Fill in the dimension placeholder, if the template has one.
        .replace("VECTOR(dimension)", f"{column_type}({embedding_dim})")
        # Point the template at the new table name.
        .replace(base_table, table_name)
        # Idempotent creation tolerates restarts and concurrent creators.
        .replace("CREATE TABLE ", "CREATE TABLE IF NOT EXISTS ", 1)
    )
    await db.execute(ddl)

    # (index-name suffix, indexed columns, label used in log messages)
    index_specs = (
        ("id", "id", "index"),
        ("workspace_id", "workspace, id", "composite index"),
    )
    for name_suffix, indexed_columns, label in index_specs:
        index_name = _safe_index_name(table_name, name_suffix)
        try:
            create_index_sql = f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}({indexed_columns})"
            logger.info(
                f"PostgreSQL, Creating {label} {index_name} on table {table_name}"
            )
            await db.execute(create_index_sql)
        except Exception as exc:
            # Best effort: the failure is reported and the call continues.
            logger.error(
                f"PostgreSQL, Failed to create {label} {index_name}, Got: {exc}"
            )
@staticmethod
async def _pg_migrate_workspace_data(
    db: PostgreSQLDB,
    legacy_table_name: str,
    new_table_name: str,
    workspace: str,
    expected_count: int,
    embedding_dim: int,
) -> int:
    """Copy rows from the legacy table into the new table in batches.

    Rows are read with keyset pagination (``ORDER BY id`` plus an
    ``id > last seen id`` cursor) so each legacy row is visited once in a
    stable order, and each page is written with a single ``executemany``
    using ``ON CONFLICT (workspace, id) DO NOTHING``.

    Args:
        db: PostgreSQLDB instance
        legacy_table_name: Name of the legacy table to migrate from
        new_table_name: Name of the new table to migrate to
        workspace: Workspace to filter records for migration; when falsy
            no workspace filter is applied
        expected_count: Expected number of records (used for progress logs)
        embedding_dim: Embedding dimension for vector column (not used by
            this function's body)

    Returns:
        Number of legacy rows read and submitted for insertion.
    """
    page_size = 500
    rows_read = 0
    cursor_id: str | None = None
    workspace_info = f" for workspace '{workspace}'" if workspace else ""

    def _row_values(record, column_names):
        # When the read connection has no pgvector codec registered, the
        # vector column arrives as text such as "[0.1,0.2,...]"; convert
        # it to a float32 ndarray before handing it to executemany.
        fields = dict(record)
        raw_vector = fields.get("content_vector")
        if isinstance(raw_vector, str):
            stripped = raw_vector.strip("[]")
            fields["content_vector"] = (
                np.array(
                    [float(part) for part in stripped.split(",")],
                    dtype=np.float32,
                )
                if stripped
                else None
            )
        # Values follow column order so they line up with the placeholders.
        return tuple(fields[name] for name in column_names)

    while True:
        # Assemble the page query; positional parameters are numbered in
        # the order their values are appended.
        predicates: list[str] = []
        query_params: list = []
        if workspace:
            query_params.append(workspace)
            predicates.append(f"workspace = ${len(query_params)}")
        if cursor_id is not None:
            query_params.append(cursor_id)
            predicates.append(f"id > ${len(query_params)}")
        query_params.append(page_size)
        where_clause = f" WHERE {' AND '.join(predicates)}" if predicates else ""
        select_query = f"SELECT * FROM {legacy_table_name}{where_clause} ORDER BY id LIMIT ${len(query_params)}"
        rows = await db.query(select_query, query_params, multirows=True)

        if not rows:
            return rows_read

        # Advance the keyset cursor to the last id of this page.
        cursor_id = rows[-1]["id"]

        # Column list is taken from the first row of the page.
        column_names = list(dict(rows[0]).keys())
        columns_str = ", ".join(column_names)
        placeholders = ", ".join(
            f"${position}" for position in range(1, len(column_names) + 1)
        )
        insert_query = f"""
            INSERT INTO {new_table_name} ({columns_str})
            VALUES ({placeholders})
            ON CONFLICT (workspace, id) DO NOTHING
        """

        batch_values = [_row_values(record, column_names) for record in rows]

        # One executemany per page keeps database round-trips low.
        async def _batch_insert(connection: asyncpg.Connection) -> None:
            await connection.executemany(insert_query, batch_values)

        await db._run_with_retry(_batch_insert)

        rows_read += len(rows)
        logger.info(
            f"PostgreSQL: {rows_read}/{expected_count} records migrated{workspace_info}"
        )
@staticmethod
async def setup_table(
    db: PostgreSQLDB,
    table_name: str,
    workspace: str,
    embedding_dim: int,
    legacy_table_name: str,
    base_table: str,
):
    """
    Set up the PostgreSQL vector table, migrating from a legacy table if needed.

    Behaviour:
      * Ensures the final table has its vector index.
      * Checks vector dimension compatibility before creating a new table.
      * Drops the legacy table if it exists and is completely empty.
      * Migrates this workspace's rows from the legacy table only when the
        legacy table holds rows for the workspace and the new table holds
        none for it.

    Args:
        db: PostgreSQLDB instance
        table_name: Name of the new table
        workspace: Workspace to filter records for migration
        embedding_dim: Embedding dimension for vector column
        legacy_table_name: Name of the legacy table to check for migration
        base_table: Base table name for DDL template lookup

    Raises:
        ValueError: if ``workspace`` is empty.
        DataMigrationError: on a vector dimension mismatch, when the legacy
            dimension cannot be verified, or when migration fails or its
            post-migration row count does not match.
    """
    if not workspace:
        raise ValueError("workspace must be provided")

    new_table_exists = await db.check_table_exists(table_name)
    legacy_exists = legacy_table_name and await db.check_table_exists(
        legacy_table_name
    )

    # Case 1: Only new table exists or new table is the same as legacy table
    #         No data migration needed, ensuring index is created then return
    if (new_table_exists and not legacy_exists) or (
        new_table_exists and (table_name.lower() == legacy_table_name.lower())
    ):
        await db._create_vector_index(table_name, embedding_dim)

        workspace_count_query = (
            f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
        )
        workspace_count_result = await db.query(workspace_count_query, [workspace])
        workspace_count = (
            workspace_count_result.get("count", 0) if workspace_count_result else 0
        )
        if workspace_count == 0 and not (
            table_name.lower() == legacy_table_name.lower()
        ):
            logger.warning(
                f"PostgreSQL: workspace data in table '{table_name}' is empty. "
                f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
            )
        return

    legacy_count = None
    if not new_table_exists:
        # Check vector dimension compatibility before creating new table
        if legacy_exists:
            count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
            count_result = await db.query(count_query, [workspace])
            legacy_count = count_result.get("count", 0) if count_result else 0

            if legacy_count > 0:
                legacy_dim = None
                try:
                    sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1"
                    sample_result = await db.query(sample_query, [workspace])
                    # Use 'is not None' rather than truthiness to avoid the
                    # NumPy array boolean ambiguity error.
                    if (
                        sample_result
                        and sample_result.get("content_vector") is not None
                    ):
                        vector_data = sample_result["content_vector"]
                        if hasattr(vector_data, "__len__") and not isinstance(
                            vector_data, str
                        ):
                            # Lists, tuples, NumPy arrays and other sized
                            # array-like objects.
                            legacy_dim = len(vector_data)
                        elif hasattr(vector_data, "dimensions") and callable(
                            vector_data.dimensions
                        ):
                            # pgvector HalfVector / SparseVector expose dimensions()
                            legacy_dim = vector_data.dimensions()
                        elif isinstance(vector_data, str):
                            import json

                            vector_list = json.loads(vector_data)
                            legacy_dim = len(vector_list)

                    if legacy_dim and legacy_dim != embedding_dim:
                        logger.error(
                            f"PostgreSQL: Dimension mismatch detected! "
                            f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
                            f"but new embedding model expects {embedding_dim}d."
                        )
                        raise DataMigrationError(
                            f"Dimension mismatch between legacy table '{legacy_table_name}' "
                            f"and new embedding model. Expected {embedding_dim}d but got {legacy_dim}d."
                        )
                except DataMigrationError:
                    # Re-raise DataMigrationError as-is to preserve specific error messages
                    raise
                except Exception as e:
                    # Chain the original failure so its traceback is kept
                    # as the explicit cause.
                    raise DataMigrationError(
                        f"Could not verify legacy table vector dimension: {e}. "
                        f"Proceeding with caution..."
                    ) from e

        await PGVectorStorage._pg_create_table(
            db, table_name, base_table, embedding_dim
        )
        logger.info(f"PostgreSQL: New table '{table_name}' created successfully")

        if not legacy_exists:
            await db._create_vector_index(table_name, embedding_dim)
            logger.info(
                "Ensure this new table creation is caused by new workspace setup and not an unexpected embedding model change."
            )
            return

    # Ensure vector index is created
    await db._create_vector_index(table_name, embedding_dim)

    # Case 2: Legacy table exist
    if legacy_exists:
        workspace_info = f" for workspace '{workspace}'"

        # Only drop legacy table if entire table is empty
        total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
        total_count_result = await db.query(total_count_query, [])
        total_count = (
            total_count_result.get("count", 0) if total_count_result else 0
        )
        if total_count == 0:
            drop_query = f"DROP TABLE {legacy_table_name}"
            await db.execute(drop_query, None)
            # Report success only after the DROP has actually executed.
            logger.info(
                f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully"
            )
            return

        # No data migration needed if legacy workspace is empty
        if legacy_count is None:
            count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
            count_result = await db.query(count_query, [workspace])
            legacy_count = count_result.get("count", 0) if count_result else 0

        if legacy_count == 0:
            logger.info(
                f"PostgreSQL: No records{workspace_info} found in legacy table. "
                f"No data migration needed."
            )
            return

        new_count_query = (
            f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
        )
        new_count_result = await db.query(new_count_query, [workspace])
        new_table_workspace_count = (
            new_count_result.get("count", 0) if new_count_result else 0
        )
        if new_table_workspace_count > 0:
            logger.warning(
                f"PostgreSQL: Both new and legacy collection have data. "
                f"{legacy_count} records in {legacy_table_name} require manual deletion after migration verification."
            )
            return

        # Case 3: Legacy has workspace data and new table is empty for workspace
        logger.info(
            f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}."
        )
        logger.info(
            f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'"
        )

        try:
            migrated_count = await PGVectorStorage._pg_migrate_workspace_data(
                db,
                legacy_table_name,
                table_name,
                workspace,
                legacy_count,
                embedding_dim,
            )
            if migrated_count != legacy_count:
                logger.warning(
                    "PostgreSQL: Read %s legacy records%s during migration, expected %s.",
                    migrated_count,
                    workspace_info,
                    legacy_count,
                )

            # Verify by counting what actually landed in the new table.
            new_count_result = await db.query(new_count_query, [workspace])
            new_table_count_after = (
                new_count_result.get("count", 0) if new_count_result else 0
            )
            inserted_count = new_table_count_after - new_table_workspace_count

            if inserted_count != legacy_count:
                error_msg = (
                    "PostgreSQL: Migration verification failed, "
                    f"expected {legacy_count} inserted records, got {inserted_count}."
                )
                logger.error(error_msg)
                raise DataMigrationError(error_msg)

        except DataMigrationError:
            # Re-raise DataMigrationError as-is to preserve specific error messages
            raise
        except Exception as e:
            logger.error(
                f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}"
            )
            raise DataMigrationError(
                f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'"
            ) from e

        logger.info(
            f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
        )
        logger.warning(
            "PostgreSQL: Manual deletion is required after data migration verification."
        )
async def initialize(self):
    """Acquire the shared PG client, settle the workspace and set up the table.

    Workspace priority: ``PostgreSQLDB.workspace`` > ``self.workspace`` >
    ``"default"``. All steps run while holding the data-init lock.
    """
    async with get_data_init_lock():
        if self.db is None:
            self.db = await ClientManager.get_client(
                vector_storage=self.global_config.get("vector_storage")
            )

        if self.db.workspace:
            # Highest priority: the workspace configured on the PG client.
            logger.info(
                f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
            )
            self.workspace = self.db.workspace
        elif not (hasattr(self, "workspace") and self.workspace):
            # Lowest priority: neither source supplied a workspace.
            self.workspace = "default"
        # Otherwise the storage's own workspace is kept unchanged.

        # Create the table if needed and handle legacy-table migration;
        # migration is filtered by the resolved workspace.
        await PGVectorStorage.setup_table(
            self.db,
            self.table_name,
            self.workspace,
            embedding_dim=self.embedding_func.embedding_dim,
            legacy_table_name=self.legacy_table_name,
            base_table=self.legacy_table_name,  # also the DDL template key
        )

        # Created only once the workspace is final.
        if self._flush_lock is None:
            self._flush_lock = get_namespace_lock(
                self.namespace, workspace=self.workspace
            )
async def finalize(self):
    """Flush buffered vector operations, then release the shared PG client.

    A regular ``Exception`` raised by the flush is captured and, once the
    client has been released, re-raised as a ``RuntimeError`` that names
    how many upserts and deletes were left buffered. ``BaseException``
    subclasses that are not ``Exception`` (``CancelledError``,
    ``KeyboardInterrupt``, ``SystemExit``) are deliberately not captured:
    they propagate through the ``finally`` clause and the buffer-count
    reporting is skipped.

    Idempotency:
        When the client has already been released (``self.db is None``)
        no flush is attempted; the call returns silently if both buffers
        are empty and raises ``RuntimeError`` otherwise, so the data-loss
        signal is repeated.
    """

    def _buffered_counts() -> tuple[int, int]:
        return len(self._pending_vector_docs), len(self._pending_vector_deletes)

    if self.db is None:
        upserts_left, deletes_left = _buffered_counts()
        if upserts_left or deletes_left:
            raise RuntimeError(
                f"[{self.workspace}] PGVectorStorage.finalize() re-entry: "
                f"client already released; {upserts_left} pending upserts "
                f"and {deletes_left} pending deletes cannot be flushed"
            )
        return

    captured_failure: Exception | None = None
    try:
        try:
            await self._flush_pending_vector_ops()
        except Exception as exc:
            captured_failure = exc
    finally:
        # The client is released whether or not the flush succeeded.
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    upserts_left, deletes_left = _buffered_counts()
    if captured_failure is not None:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.finalize() flush raised; "
            f"{upserts_left} pending upserts and {deletes_left} pending "
            f"deletes were left buffered (client released, data lost)"
        ) from captured_failure
    if upserts_left or deletes_left:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.finalize() left "
            f"{upserts_left} pending upserts and {deletes_left} "
            f"pending deletes buffered after final flush attempt"
        )
def _upsert_chunks(
    self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, tuple[Any, ...]]:
    """Build the chunk upsert statement and its positional parameters.

    Returns:
        Tuple of (SQL template, values tuple for executemany)

    Raises:
        Any exception hit while formatting the SQL or reading ``item``
        (for example a missing key) is logged and re-raised.
    """
    # Item keys in the order of SQL parameters $2..$8; "__vector__" is a
    # numpy array handled by the pgvector codec.
    item_keys = (
        "__id__",
        "tokens",
        "chunk_order_index",
        "full_doc_id",
        "content",
        "__vector__",
        "file_path",
    )
    try:
        statement = SQL_TEMPLATES["upsert_chunk"].format(table_name=self.table_name)
        row: tuple[Any, ...] = (
            self.workspace,  # $1
            *(item[key] for key in item_keys),  # $2..$8
            current_time,  # $9
            current_time,  # $10
        )
    except Exception as exc:
        logger.error(
            f"[{self.workspace}] Error to prepare upsert,\nerror: {exc}\nitem: {item}"
        )
        raise
    return statement, row
def _upsert_entities(
    self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, tuple[Any, ...]]:
    """Build the entity upsert statement and its positional parameters.

    ``item["source_id"]`` is turned into a list of chunk ids: a string is
    split on ``"<SEP>"``, any other value is wrapped in a one-element list.

    Returns:
        Tuple of (SQL template, values tuple for executemany)
    """
    statement = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)

    raw_source = item["source_id"]
    # str.split returns [raw_source] when the separator is absent, so this
    # one expression covers both the single-id and multi-id string cases.
    chunk_ids = (
        raw_source.split("<SEP>") if isinstance(raw_source, str) else [raw_source]
    )

    # Positional order matches the SQL parameters ($1, $2, ...).
    row: tuple[Any, ...] = (
        self.workspace,  # $1
        item["__id__"],  # $2
        item["entity_name"],  # $3
        item["content"],  # $4
        item["__vector__"],  # $5 - numpy array, handled by pgvector codec
        chunk_ids,  # $6
        item.get("file_path", None),  # $7
        current_time,  # $8
        current_time,  # $9
    )
    return statement, row
def _upsert_relationships(
    self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, tuple[Any, ...]]:
    """Build the relationship upsert statement and its positional parameters.

    ``item["source_id"]`` is turned into a list of chunk ids: a string is
    split on ``"<SEP>"``, any other value is wrapped in a one-element list.

    Returns:
        Tuple of (SQL template, values tuple for executemany)
    """
    statement = SQL_TEMPLATES["upsert_relationship"].format(
        table_name=self.table_name
    )

    raw_source = item["source_id"]
    # str.split returns [raw_source] when the separator is absent, so this
    # one expression covers both the single-id and multi-id string cases.
    chunk_ids = (
        raw_source.split("<SEP>") if isinstance(raw_source, str) else [raw_source]
    )

    # Positional order matches the SQL parameters ($1, $2, ...).
    row: tuple[Any, ...] = (
        self.workspace,  # $1
        item["__id__"],  # $2
        item["src_id"],  # $3
        item["tgt_id"],  # $4
        item["content"],  # $5
        item["__vector__"],  # $6 - numpy array, handled by pgvector codec
        chunk_ids,  # $7
        item.get("file_path", None),  # $8
        current_time,  # $9
        current_time,  # $10
    )
    return statement, row
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Buffer vector docs for embedding and batched flush.

    Nothing is written to PostgreSQL here; the documents are queued in the
    process-local pending buffer and persisted by
    ``_flush_pending_vector_ops`` (called from ``index_done_callback()`` /
    ``finalize()``).

    Correctness premise:
        LightRAG's pipeline is the normal write path for graph/vector
        mutations and guarantees a single writer process per workspace.
        The pending buffer is process-local: committed PG rows are
        immediately visible across workers, but *buffered* writes are
        not — readers in other workers will not see them until the
        writing worker calls index_done_callback(). Non-pipeline writers
        must provide equivalent single-writer serialization and must
        flush explicitly before depending on reads from another worker.

    Memory expectation:
        Pending docs (raw ``content`` strings, plus cached vectors once
        embedded) accumulate in process memory until the next
        ``index_done_callback()`` / ``finalize()``. Callers performing
        very large ingests should flush periodically to cap working-set
        size.

    Args:
        data: Mapping of document id to its payload. An empty mapping is
            a no-op.

    Raises:
        RuntimeError: if called with data before ``initialize()`` has
            created the flush lock.
    """
    if not data:
        return

    # Fail with an explicit message instead of the opaque error that
    # `async with None` would raise; mirrors the pre-initialize guards on
    # delete_entity() and _flush_pending_vector_ops().
    if self._flush_lock is None:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.upsert called before "
            f"initialize(); {len(data)} upserts cannot be buffered"
        )

    logger.debug(
        f"[{self.workspace}] Buffering {len(data)} vectors for {self.namespace}"
    )

    # Build pending docs outside the lock; a naive UTC datetime is what
    # the _upsert_* helpers later pass to asyncpg as the timestamp.
    current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
    pending_docs: list[tuple[str, _PendingPGVectorDoc]] = []
    for i, (k, v) in enumerate(data.items(), start=1):
        pending_docs.append(
            (
                k,
                _PendingPGVectorDoc(
                    item={"__id__": k, **v},
                    created_at=current_time,
                ),
            )
        )
        await _cooperative_yield(i)

    async with self._flush_lock:
        for doc_id, pending_doc in pending_docs:
            # Invariant: a later upsert wins over an earlier delete; the
            # unconditional dict assignment also discards any cached
            # stale vector from a prior upsert of the same id.
            self._pending_vector_deletes.discard(doc_id)
            self._pending_vector_docs[doc_id] = pending_doc
async def _flush_pending_vector_ops(self) -> None:
    """Flush buffered PG vector upserts and deletes in one transaction.

    Concurrency:
        All buffer reads/writes and destructive server mutations on
        this storage run under ``self._flush_lock``. Embedding stays
        inside that lock so a destructive operation cannot interleave
        between embedding and the PG write in the same process.

    Failure handling:
        PG cannot expose per-document statuses, so flush is
        all-or-nothing:

        * If embedding fails the buffers stay intact (next flush
          retries; cached vectors are reused).
        * If ``_run_with_retry`` raises the transaction rolls back
          and the buffers stay intact. Cached vectors stay attached
          to pending docs so the next flush does not re-embed.
        * On success both buffers are cleared.

    Post-finalize / pre-initialize:
        Calling this after ``finalize()`` (``self.db is None``) or
        before ``initialize()`` (``self._flush_lock is None``) with a
        non-empty buffer raises ``RuntimeError`` — silently dropping
        buffered writes would defeat the data-loss visibility that
        ``finalize()`` provides. An empty-buffer call is a no-op.

    Raises:
        RuntimeError: for the pre-initialize / post-release cases above,
            when the number of returned embeddings differs from the
            number of documents embedded, or when a pending document has
            no vector after the embedding phase.
        ValueError: if the namespace is not one of the supported vector
            store namespaces.
    """
    # Pre-initialize: no lock exists yet, so nothing can be flushed.
    if self._flush_lock is None:
        pending_docs = len(self._pending_vector_docs)
        pending_deletes = len(self._pending_vector_deletes)
        if pending_docs or pending_deletes:
            raise RuntimeError(
                f"[{self.workspace}] PGVectorStorage._flush_pending_vector_ops "
                f"called before initialize(); {pending_docs} pending upserts "
                f"and {pending_deletes} pending deletes cannot be flushed"
            )
        return

    async with self._flush_lock:
        # Empty buffers: nothing to do, regardless of client state.
        if not self._pending_vector_docs and not self._pending_vector_deletes:
            return

        # Post-finalize: buffered work exists but the client is gone.
        if self.db is None:
            pending_docs = len(self._pending_vector_docs)
            pending_deletes = len(self._pending_vector_deletes)
            raise RuntimeError(
                f"[{self.workspace}] PGVectorStorage._flush_pending_vector_ops "
                f"called after client release; {pending_docs} pending upserts "
                f"and {pending_deletes} pending deletes cannot be flushed"
            )

        timing_label = f"{self.workspace} PGVectorStorage.flush[{self.namespace}]"
        total_start = time.perf_counter()
        performance_timing_log(
            "[%s] start upserts=%s deletes=%s max_batch_size=%s",
            timing_label,
            len(self._pending_vector_docs),
            len(self._pending_vector_deletes),
            self._max_batch_size,
        )

        # --- Embedding phase ---------------------------------------------
        # Only docs without a cached vector are embedded; vectors cached
        # by an earlier, failed flush are reused as-is.
        docs_to_embed = [
            (doc_id, pending_doc)
            for doc_id, pending_doc in self._pending_vector_docs.items()
            if pending_doc.vector is None
        ]
        if docs_to_embed:
            contents = [
                pending_doc.item["content"] for _, pending_doc in docs_to_embed
            ]
            # Split the contents into chunks of at most _max_batch_size.
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            logger.info(
                f"[{self.workspace}] {self.namespace} flush: embedding "
                f"{len(docs_to_embed)} vectors in {len(batches)} batch(es) "
                f"(batch_num={self._max_batch_size})"
            )
            embedding_start = time.perf_counter()
            try:
                # All batches are submitted concurrently; gather returns
                # their results in submission order.
                embeddings_list = await asyncio.gather(
                    *[
                        self.embedding_func(batch, context="document")
                        for batch in batches
                    ]
                )
            except Exception as e:
                logger.error(
                    f"[{self.workspace}] Error embedding pending vector ops "
                    f"(upserts={len(docs_to_embed)}): {e}"
                )
                raise
            performance_timing_log(
                "[%s] embedding completed in %.4fs docs=%s batches=%s",
                timing_label,
                time.perf_counter() - embedding_start,
                len(docs_to_embed),
                len(batches),
            )
            embeddings = np.concatenate(embeddings_list)
            # Explicit check: a count mismatch under `python -O` would
            # silently truncate via zip(), mispairing vectors with docs.
            if len(embeddings) != len(docs_to_embed):
                raise RuntimeError(
                    f"[{self.workspace}] Embedding count mismatch: "
                    f"expected {len(docs_to_embed)}, got {len(embeddings)}"
                )
            # Cache each vector on its pending doc so a later retry does
            # not have to embed it again.
            for i, ((_, pending_doc), embedding) in enumerate(
                zip(docs_to_embed, embeddings), start=1
            ):
                pending_doc.vector = embedding
                await _cooperative_yield(i)

        # --- Build batch tuples ------------------------------------------
        # Pick the row builder that matches this storage's namespace.
        if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
            build_tuple = self._upsert_chunks
        elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
            build_tuple = self._upsert_entities
        elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
            build_tuple = self._upsert_relationships
        else:
            raise ValueError(f"{self.namespace} is not supported")

        batch_values: list[tuple[Any, ...]] = []
        upsert_sql: str | None = None
        for i, (doc_id, pending_doc) in enumerate(
            self._pending_vector_docs.items(), start=1
        ):
            if pending_doc.vector is None:
                # Should not happen: every pending doc was embedded above
                # or had a cached vector from a previous lazy embed.
                raise RuntimeError(
                    f"[{self.workspace}] Pending vector for id={doc_id} "
                    f"missing after embedding phase"
                )
            # Coerce to float32 ndarray if not already (defensive; the
            # embedding func typically returns float32 but a custom
            # provider may return float64 — pgvector wants float32).
            # A shallow copy keeps the buffered item free of "__vector__".
            item = dict(pending_doc.item)
            vector = pending_doc.vector
            if not isinstance(vector, np.ndarray) or vector.dtype != np.float32:
                vector = np.asarray(vector, dtype=np.float32)
            item["__vector__"] = vector
            upsert_sql, values = build_tuple(item, pending_doc.created_at)
            batch_values.append(values)
            await _cooperative_yield(i)

        # Snapshot the delete ids as a list for the ANY($2) parameter.
        pending_delete_ids = list(self._pending_vector_deletes)

        # --- Persistence -------------------------------------------------
        async def _flush_batch(connection: asyncpg.Connection) -> None:
            # Upserts and deletes share one transaction, so they are
            # committed or rolled back together.
            async with connection.transaction():
                if batch_values and upsert_sql:
                    execute_start = time.perf_counter()
                    await connection.executemany(upsert_sql, batch_values)
                    performance_timing_log(
                        "[%s] executemany completed in %.4fs batch_size=%s",
                        timing_label,
                        time.perf_counter() - execute_start,
                        len(batch_values),
                    )
                if pending_delete_ids:
                    delete_sql = (
                        f"DELETE FROM {self.table_name} "
                        "WHERE workspace=$1 AND id = ANY($2)"
                    )
                    await connection.execute(
                        delete_sql, self.workspace, pending_delete_ids
                    )

        try:
            await self.db._run_with_retry(_flush_batch, timing_label=timing_label)
        except Exception as e:
            # Buffers are left untouched so the next flush can retry.
            logger.error(
                f"[{self.workspace}] Error flushing vector ops "
                f"(upserts={len(batch_values)}, "
                f"deletes={len(pending_delete_ids)}): {e}"
            )
            raise

        # Success: clear committed buffers. Cached vectors live on
        # those records and are GC'd with them.
        self._pending_vector_docs.clear()
        self._pending_vector_deletes.clear()
        performance_timing_log(
            "[%s] total complete in %.4fs upserts=%s deletes=%s",
            timing_label,
            time.perf_counter() - total_start,
            len(batch_values),
            len(pending_delete_ids),
        )
- #################### query method ###############
async def query(
    self, query: str, top_k: int, query_embedding: list[float] = None
) -> list[dict[str, Any]]:
    """Run a vector similarity search for ``query`` in this workspace.

    Args:
        query: Query text; embedded only when ``query_embedding`` is None.
        top_k: Value bound to the statement's ``top_k`` parameter.
        query_embedding: Optional precomputed embedding, used as-is.

    Returns:
        The rows returned by ``self.db.query`` for the namespace's
        search statement.
    """
    if query_embedding is None:
        # Query embeddings are requested with a higher priority.
        query_vectors = await self.embedding_func(
            [query], context="query", _priority=5
        )
        search_vector = query_vectors[0]
    else:
        search_vector = query_embedding

    # The cast in the SQL must match the column type used by the index.
    uses_halfvec = getattr(self.db, "vector_index_type", None) == "HNSW_HALFVEC"
    sql = SQL_TEMPLATES[self.namespace].format(
        table_name=self.table_name,
        vector_cast="halfvec" if uses_halfvec else "vector",
    )

    # Positional parameters: workspace, closer_than_threshold, top_k,
    # embedding. The embedding is passed as a bound parameter rather than
    # interpolated into the SQL text.
    positional_params = [
        self.workspace,
        1 - self.cosine_better_than_threshold,
        top_k,
        search_vector,
    ]
    return await self.db.query(sql, params=positional_params, multirows=True)
async def index_done_callback(self) -> None:
    """Persist buffered upserts and deletes by running the pending-ops flush."""
    await self._flush_pending_vector_ops()
async def delete(self, ids: list[str]) -> None:
    """Buffer vector deletes for batched flush.

    A delete cancels any pending upsert for the same id. The actual PG
    delete is performed by ``_flush_pending_vector_ops`` during the next
    ``index_done_callback`` / ``finalize`` call.

    Args:
        ids: Ids to delete (a ``set`` is also accepted). An empty
            collection is a no-op.

    Raises:
        RuntimeError: if called with ids before ``initialize()`` has
            created the flush lock.
    """
    if not ids:
        return
    if isinstance(ids, set):
        ids = list(ids)

    # Fail with an explicit message instead of the opaque error that
    # `async with None` would raise; mirrors the pre-initialize guards on
    # delete_entity() and _flush_pending_vector_ops().
    if self._flush_lock is None:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.delete called before "
            f"initialize(); {len(ids)} deletes cannot be buffered"
        )

    async with self._flush_lock:
        for doc_id in ids:
            # Drop any queued upsert first so the flush does not re-insert
            # a row that is about to be deleted.
            self._pending_vector_docs.pop(doc_id, None)
            self._pending_vector_deletes.add(doc_id)
    logger.debug(
        f"[{self.workspace}] Buffered delete for {len(ids)} vectors in {self.namespace}"
    )
async def delete_entity(self, entity_name: str) -> None:
    """Delete the vector row(s) of ``entity_name`` and reconcile the buffers.

    The predicate delete (``WHERE workspace=$1 AND entity_name=$2``) is
    executed immediately while ``_flush_lock`` is held. Buffered state is
    pruned only after that statement succeeds, or when ``self.db`` is
    already ``None`` (the buffers are then the only state left). If the
    statement raises, the buffers stay untouched, the error is logged and
    the exception propagates so the caller can stop before a later flush
    writes the preserved upserts back.

    A predicate delete is used rather than deleting by hashed id so rows
    whose ``id`` differs from ``compute_mdhash_id(entity_name, prefix="ent-")``
    are matched as well.

    Raises:
        RuntimeError: if ``_flush_lock`` is still ``None``, i.e. the storage
            has not been initialized.
        Exception: whatever the delete raises, re-raised after logging.
    """
    if self._flush_lock is None:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.delete_entity called before "
            f"initialize(); call initialize_storages() on the LightRAG instance "
            f"before issuing destructive operations"
        )

    hashed_id = compute_mdhash_id(entity_name, prefix="ent-")
    try:
        async with self._flush_lock:
            storage_finalized = self.db is None
            if not storage_finalized:
                statement = (
                    f"DELETE FROM {self.table_name} "
                    "WHERE workspace=$1 AND entity_name=$2"
                )
                await self.db.execute(
                    statement,
                    {"workspace": self.workspace, "entity_name": entity_name},
                )

            # Only reached after a successful delete or on finalized storage:
            # drop the upsert keyed by the hashed id, then any buffered doc
            # whose payload names this entity, then the now-redundant
            # pending delete.
            self._pending_vector_docs.pop(hashed_id, None)
            stale_ids = [
                key
                for key, record in self._pending_vector_docs.items()
                if record.item.get("entity_name") == entity_name
            ]
            for stale_id in stale_ids:
                self._pending_vector_docs.pop(stale_id, None)
            self._pending_vector_deletes.discard(hashed_id)

            if storage_finalized:
                return
        logger.debug(f"[{self.workspace}] Successfully deleted entity {entity_name}")
    except Exception as e:
        # Surface the failure; the buffers were not pruned on the SQL path.
        logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
        raise
async def delete_entity_relation(self, entity_name: str) -> None:
    """Delete every relation vector that has ``entity_name`` as source or target.

    The predicate delete (``source_id=$2 OR target_id=$2`` within the
    workspace) runs immediately under ``_flush_lock``. Buffered relation
    upserts whose ``src_id`` or ``tgt_id`` equals ``entity_name`` are
    removed from ``_pending_vector_docs`` only after the statement
    succeeds, or when ``self.db`` is already ``None``. If the statement
    raises, the buffer is left intact, the error is logged and the
    exception is re-raised; the caller is expected to stop before
    triggering a flush that would write those upserts back.

    Raises:
        RuntimeError: if ``_flush_lock`` is still ``None``, i.e. the storage
            has not been initialized.
        Exception: whatever the delete raises, re-raised after logging.
    """
    if self._flush_lock is None:
        raise RuntimeError(
            f"[{self.workspace}] PGVectorStorage.delete_entity_relation called "
            f"before initialize(); call initialize_storages() on the LightRAG "
            f"instance before issuing destructive operations"
        )

    try:
        async with self._flush_lock:
            storage_finalized = self.db is None
            if not storage_finalized:
                statement = (
                    f"DELETE FROM {self.table_name} "
                    "WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"
                )
                await self.db.execute(
                    statement,
                    {"workspace": self.workspace, "entity_name": entity_name},
                )

            # Only reached after a successful delete or on finalized storage.
            # Collect the keys first so the dict is not mutated while iterated.
            stale_ids = []
            for key, record in self._pending_vector_docs.items():
                payload = record.item
                if (
                    payload.get("src_id") == entity_name
                    or payload.get("tgt_id") == entity_name
                ):
                    stale_ids.append(key)
            for stale_id in stale_ids:
                self._pending_vector_docs.pop(stale_id, None)

            if storage_finalized:
                return
        logger.debug(
            f"[{self.workspace}] Successfully deleted relations for entity {entity_name}"
        )
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error deleting relations for entity {entity_name}: {e}"
        )
        raise
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
- """Get vector data by its ID with read-your-writes against the buffer.
- ``__vector__`` and ``__id__`` are stripped from buffered results to
- match the other vector backends; callers needing embeddings must use
- ``get_vectors_by_ids``.
- Response shape:
- Buffered hits return ``{"id", "content", <payload fields>,
- "created_at"}`` only — no embedding column. SQL-fallback hits
- return the full row including ``content_vector`` (and any
- namespace-specific columns such as ``entity_name`` or
- ``source_id``). Callers that only read documented payload
- fields (``content``, ``id``, ``created_at``) are unaffected;
- consumers iterating over all keys must tolerate both shapes.
- """
- async with self._flush_lock:
- if id in self._pending_vector_deletes:
- return None
- pending = self._pending_vector_docs.get(id)
- if pending is not None:
- doc = {
- k: v
- for k, v in pending.item.items()
- if k not in ("__id__", "__vector__")
- }
- doc["id"] = id
- doc["created_at"] = int(pending.created_at.timestamp())
- return doc
- query = (
- f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at "
- f"FROM {self.table_name} WHERE workspace=$1 AND id=$2"
- )
- try:
- result = await self.db.query(query, [self.workspace, id])
- if result:
- return dict(result)
- return None
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
- )
- return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Look up several vector documents, keeping the caller's id order.

    Each requested id gets one slot in the result:
        * queued for deletion -> ``None``;
        * buffered upsert -> buffered payload without ``__id__`` /
          ``__vector__``, plus ``id`` and epoch-seconds ``created_at``;
        * anything else -> the row from a single ``id = ANY($2)`` query, or
          ``None`` when no row came back for it.

    Buffered hits and SQL hits have different key sets, since SQL rows
    come from ``SELECT *``.

    Returns:
        A list aligned with ``ids``. An empty list is returned for empty
        input and also when the SQL lookup raises (the error is logged).
    """
    if not ids:
        return []

    served: dict[str, dict[str, Any] | None] = {}
    to_fetch: list[str] = []
    async with self._flush_lock:
        for doc_id in ids:
            if doc_id in self._pending_vector_deletes:
                served[doc_id] = None
            elif (record := self._pending_vector_docs.get(doc_id)) is not None:
                payload = {
                    field: value
                    for field, value in record.item.items()
                    if field != "__id__" and field != "__vector__"
                }
                payload.update(
                    id=doc_id, created_at=int(record.created_at.timestamp())
                )
                served[doc_id] = payload
            else:
                to_fetch.append(doc_id)

    fetched: dict[str, dict[str, Any]] = {}
    if to_fetch:
        sql = (
            "SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at "
            f"FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
        )
        try:
            rows = await self.db.query(sql, [self.workspace, to_fetch], multirows=True)
            for row in rows or []:
                if row is None:
                    continue
                as_dict = dict(row)
                row_id = as_dict.get("id")
                if row_id is not None:
                    fetched[str(row_id)] = as_dict
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
            )
            return []

    # Buffer answers (including explicit ``None`` for deletes) take precedence.
    return [
        served[wanted] if wanted in served else fetched.get(str(wanted))
        for wanted in ids
    ]
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
    """Get vector embeddings by ID, with read-your-writes against the buffer.

    Lazily embeds pending docs whose vector has not been computed yet,
    caches the result on the pending record (so the next flush reuses
    it), and falls through to a parameterized SQL query for ids not in
    the buffer.

    Embedding I/O runs *outside* ``_flush_lock`` so a slow embedding
    provider cannot block concurrent ``upsert`` / ``delete`` / read
    calls on this storage. The lock is re-acquired briefly to cache
    the result, and the pending record's identity is re-checked
    first: if a concurrent ``upsert`` / ``delete`` / ``drop`` replaced
    or removed the record during the embedding window, that ID is
    dropped from the response entirely — we neither cache the stale
    vector on the new/missing record nor return it to the caller, so
    callers cannot observe an embedding that no longer matches the
    current buffer state. Affected callers should treat the missing
    key the same as the existing "id was deleted before the call"
    case and retry if needed.

    Args:
        ids: Ids whose embeddings are wanted; empty input returns ``{}``.

    Returns:
        Mapping of id to embedding as a plain list. Ids that are pending
        deletion, unknown, unparsable, or invalidated during the embedding
        window are simply absent.

    Raises:
        Exception: re-raised from the embedding function after logging.
        RuntimeError: when the number of embeddings returned differs from
            the number of documents submitted.
    """
    if not ids:
        return {}
    result: dict[str, list[float]] = {}
    remaining: list[str] = []
    docs_to_embed: list[tuple[str, _PendingPGVectorDoc]] = []
    # Phase 1 (under lock): classify each id as deleted (skipped), buffered
    # with a cached vector, buffered without a vector, or not buffered.
    async with self._flush_lock:
        for doc_id in ids:
            if doc_id in self._pending_vector_deletes:
                continue
            pending = self._pending_vector_docs.get(doc_id)
            if pending is not None:
                if pending.vector is None:
                    docs_to_embed.append((doc_id, pending))
                else:
                    result[doc_id] = pending.vector.tolist()
                continue
            remaining.append(doc_id)
    # Phase 2 (lock released): embed buffered docs that have no vector yet,
    # in batches of at most ``_max_batch_size`` run concurrently.
    if docs_to_embed:
        contents = [pending_doc.item["content"] for _, pending_doc in docs_to_embed]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        try:
            embeddings_list = await asyncio.gather(
                *[
                    self.embedding_func(batch, context="document")
                    for batch in batches
                ]
            )
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error lazily embedding pending vectors "
                f"(upserts={len(docs_to_embed)}): {e}"
            )
            raise
        # ``gather`` preserves batch order, so the concatenated rows line up
        # positionally with ``docs_to_embed``.
        embeddings = np.concatenate(embeddings_list)
        if len(embeddings) != len(docs_to_embed):
            raise RuntimeError(
                f"[{self.workspace}] Embedding count mismatch: "
                f"expected {len(docs_to_embed)}, got {len(embeddings)}"
            )
        # Re-acquire the lock just long enough to cache results on
        # the same record. The identity check gates BOTH the cache
        # write and the response entry: if the pending record was
        # swapped or removed during the embedding window (concurrent
        # upsert / delete / drop), the just-computed vector no longer
        # matches the current buffer state for this id, so we drop it
        # from the response rather than return a stale embedding.
        async with self._flush_lock:
            for i, ((doc_id, original_pending), embedding) in enumerate(
                zip(docs_to_embed, embeddings), start=1
            ):
                current = self._pending_vector_docs.get(doc_id)
                if current is original_pending:
                    current.vector = embedding
                    result[doc_id] = embedding.tolist()
                # NOTE(review): presumably yields to the event loop
                # periodically based on the 1-based counter — confirm
                # against ``_cooperative_yield``.
                await _cooperative_yield(i)
    if not remaining:
        return result
    # Phase 3: ids absent from the buffer are read from the table.
    query = (
        f"SELECT id, content_vector FROM {self.table_name} "
        f"WHERE workspace=$1 AND id = ANY($2)"
    )
    try:
        results = await self.db.query(
            query, [self.workspace, remaining], multirows=True
        )
        for row in results or []:
            if not row or "content_vector" not in row or "id" not in row:
                continue
            vector_data = row["content_vector"]
            # The column value may surface as a sequence, JSON text, or an
            # object exposing ``tolist`` / ``to_list``; any other type is
            # skipped without an entry.
            try:
                if isinstance(vector_data, (list, tuple)):
                    result[row["id"]] = list(vector_data)
                elif isinstance(vector_data, str):
                    parsed = json.loads(vector_data)
                    if isinstance(parsed, list):
                        result[row["id"]] = parsed
                elif hasattr(vector_data, "tolist"):
                    result[row["id"]] = vector_data.tolist()
                elif hasattr(vector_data, "to_list") and callable(
                    vector_data.to_list
                ):
                    result[row["id"]] = vector_data.to_list()
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(
                    f"[{self.workspace}] Failed to parse vector data for ID {row['id']}: {e}"
                )
    except Exception as e:
        # Best effort: a failed SQL read is logged and whatever was
        # gathered from the buffer is still returned.
        logger.error(f"[{self.workspace}] Error getting vectors: {e}")
    return result
async def drop(self) -> dict[str, str]:
    """Remove every row of this workspace from the storage table.

    The table itself is kept: the ``drop_specifiy_table_workspace``
    template is executed with the workspace as its only parameter. Both
    pending buffers are cleared first, under ``_flush_lock``, because
    queued work against rows that are about to vanish is meaningless.

    Concurrency:
        ``_flush_lock`` serialises only operations in this process. Per
        the original contract, buffers held by other workers are not
        covered and could flush rows back after this call; callers must
        ensure no other writer touches the workspace for the duration
        (inside the framework, by holding the ``destructive_busy``
        pipeline flag).

    Returns:
        ``{"status": "success", "message": "data dropped"}`` on success, or
        ``{"status": "error", "message": <exception text>}`` on failure.
        Failures are logged but never raised, so callers must inspect
        ``status``.
    """
    try:
        async with self._flush_lock:
            self._pending_vector_docs.clear()
            self._pending_vector_deletes.clear()
            statement = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=self.table_name
            )
            await self.db.execute(statement, {"workspace": self.workspace})
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error dropping vector storage "
            f"{self.namespace}: {e}"
        )
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "data dropped"}
- def _parse_doc_status_datetime(
- dt_str: Any,
- context: str = "",
- ) -> datetime.datetime | None:
- """Convert a datetime value to a naive UTC datetime for database storage.
- Accepts `datetime.datetime` objects, `datetime.date` objects, or ISO-format
- strings. Returns None on failure (which may trigger a NOT NULL constraint
- violation if the column does not allow nulls).
- The optional context string (e.g. "[workspace] doc <id> created_at") is
- included in the error log to help locate the offending record.
- """
- if dt_str is None:
- return None
- if isinstance(dt_str, datetime.datetime):
- if dt_str.tzinfo is None:
- dt_str = dt_str.replace(tzinfo=timezone.utc)
- return dt_str.astimezone(timezone.utc).replace(tzinfo=None)
- if isinstance(dt_str, datetime.date):
- return datetime.datetime(
- dt_str.year, dt_str.month, dt_str.day, tzinfo=timezone.utc
- ).replace(tzinfo=None)
- try:
- dt = datetime.datetime.fromisoformat(dt_str)
- if dt.tzinfo is None:
- dt = dt.replace(tzinfo=timezone.utc)
- return dt.astimezone(timezone.utc).replace(tzinfo=None)
- except (ValueError, TypeError):
- logger.error(
- f"Unable to parse doc status datetime string"
- f"{f' ({context})' if context else ''}: {dt_str!r}"
- )
- return None
- @final
- @dataclass
- class PGDocStatusStorage(DocStatusStorage):
- db: PostgreSQLDB = field(default=None)
- def _format_datetime_with_timezone(self, dt):
- """Convert datetime to ISO format string with timezone info"""
- if dt is None:
- return None
- # If no timezone info, assume it's UTC time (as stored in database)
- if dt.tzinfo is None:
- dt = dt.replace(tzinfo=timezone.utc)
- # If datetime already has timezone info, keep it as is
- return dt.isoformat()
async def initialize(self):
    """Acquire the shared database client and settle the effective workspace.

    Runs under ``get_data_init_lock()``. A client is obtained from
    ``ClientManager`` only when ``self.db`` is still ``None``.

    Workspace precedence:
        1. ``self.db.workspace`` when set (logged as an override);
        2. the storage's own non-empty ``workspace``;
        3. the literal ``"default"``.

    No table is created here; the original notes that table creation is
    handled by ``PostgreSQLDB.initdb()``.
    """
    async with get_data_init_lock():
        if self.db is None:
            self.db = await ClientManager.get_client(
                vector_storage=self.global_config.get("vector_storage")
            )

        if self.db.workspace:
            # The client-level workspace has the highest priority.
            logger.info(
                f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
            )
            self.workspace = self.db.workspace
        elif not (hasattr(self, "workspace") and self.workspace):
            # Neither the client nor the storage names one: compatibility default.
            self.workspace = "default"
async def finalize(self):
    """Hand the database client back to ``ClientManager`` and forget it.

    Does nothing when no client is held, so repeated calls are harmless.
    """
    if self.db is None:
        return
    await ClientManager.release_client(self.db)
    self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of ``keys`` that is not yet stored in this workspace.

    Args:
        keys: Candidate ids. Empty input short-circuits to an empty set
            without touching the database.

    Returns:
        The ids from ``keys`` for which no row exists.

    Raises:
        Exception: any database error is logged, with the SQL and
            parameters, and re-raised.
    """
    if not keys:
        return set()
    table_name = namespace_to_table_name(self.namespace)
    sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
    params = {"workspace": self.workspace, "ids": list(keys)}
    try:
        res = await self.db.query(sql, list(params.values()), multirows=True)
        # A set gives constant-time membership checks; the earlier list
        # made the filter quadratic for large batches.
        existing = {row["id"] for row in res or []}
        return {key for key in keys if key not in existing}
    except Exception as e:
        logger.error(
            f"[{self.workspace}] PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
        )
        raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
    """Fetch one document-status record by id within the current workspace.

    Returns:
        ``None`` when the query result is ``None`` or an empty list.
        Otherwise a dict built from the first row: ``chunks_list`` and
        ``metadata`` are decoded from JSON when stored as text (falling
        back to ``[]`` / ``{}`` on invalid JSON) and both timestamps are
        rendered via ``_format_datetime_with_timezone``.
    """
    rows = await self.db.query(
        "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2",
        [self.workspace, id],
        True,
    )
    if rows is None or rows == []:
        return None
    row = rows[0]

    # JSON columns may arrive as text; undecodable text degrades to an
    # empty container of the matching type.
    decoded = {}
    for column, empty in (("chunks_list", list), ("metadata", dict)):
        value = row.get(column, empty())
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                value = empty()
        decoded[column] = value

    created_at = self._format_datetime_with_timezone(row["created_at"])
    updated_at = self._format_datetime_with_timezone(row["updated_at"])
    return {
        "content_length": row["content_length"],
        "content_summary": row["content_summary"],
        "status": row["status"],
        "chunks_count": row["chunks_count"],
        "created_at": created_at,
        "updated_at": updated_at,
        "file_path": row["file_path"],
        "chunks_list": decoded["chunks_list"],
        "metadata": decoded["metadata"],
        "error_msg": row.get("error_msg"),
        "track_id": row.get("track_id"),
        "content_hash": row.get("content_hash"),
    }
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch several document-status records, aligned with the requested ids.

    Args:
        ids: Ids to look up; empty input returns ``[]`` without a query.

    Returns:
        ``[]`` when the query yields nothing at all. Otherwise one entry
        per requested id, in request order, with ``None`` for ids that had
        no row. ``chunks_list`` / ``metadata`` are JSON-decoded when stored
        as text and timestamps go through
        ``_format_datetime_with_timezone``.
    """
    if not ids:
        return []
    rows = await self.db.query(
        "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)",
        [self.workspace, ids],
        True,
    )
    if not rows:
        return []

    by_id: dict[str, dict[str, Any]] = {}
    for row in rows:
        # JSON columns may arrive as text; bad JSON becomes an empty container.
        decoded = {}
        for column, empty in (("chunks_list", list), ("metadata", dict)):
            value = row.get(column, empty())
            if isinstance(value, str):
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    value = empty()
            decoded[column] = value

        created_at = self._format_datetime_with_timezone(row["created_at"])
        updated_at = self._format_datetime_with_timezone(row["updated_at"])
        by_id[str(row.get("id"))] = dict(
            content_length=row["content_length"],
            content_summary=row["content_summary"],
            status=row["status"],
            chunks_count=row["chunks_count"],
            created_at=created_at,
            updated_at=updated_at,
            file_path=row["file_path"],
            chunks_list=decoded["chunks_list"],
            metadata=decoded["metadata"],
            error_msg=row.get("error_msg"),
            track_id=row.get("track_id"),
            content_hash=row.get("content_hash"),
        )

    # Rows come back in arbitrary order; re-align them with the request.
    return [by_id.get(str(wanted)) for wanted in ids]
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
    """Fetch a document-status record by its ``file_path``.

    Args:
        file_path: Exact file path to match within the current workspace.

    Returns:
        ``None`` when the query result is ``None`` or an empty list;
        otherwise a dict built from the first returned row, in the same
        shape as ``get_by_id``. The query has no ordering clause, so which
        row comes first when several match is up to the database.
    """
    rows = await self.db.query(
        "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and file_path=$2",
        [self.workspace, file_path],
        True,
    )
    if rows is None or rows == []:
        return None
    row = rows[0]

    # JSON columns may arrive as text; bad JSON becomes an empty container.
    decoded = {}
    for column, empty in (("chunks_list", list), ("metadata", dict)):
        value = row.get(column, empty())
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                value = empty()
        decoded[column] = value

    created_at = self._format_datetime_with_timezone(row["created_at"])
    updated_at = self._format_datetime_with_timezone(row["updated_at"])
    return {
        "content_length": row["content_length"],
        "content_summary": row["content_summary"],
        "status": row["status"],
        "chunks_count": row["chunks_count"],
        "created_at": created_at,
        "updated_at": updated_at,
        "file_path": row["file_path"],
        "chunks_list": decoded["chunks_list"],
        "metadata": decoded["metadata"],
        "error_msg": row.get("error_msg"),
        "track_id": row.get("track_id"),
        "content_hash": row.get("content_hash"),
    }
- async def get_doc_by_file_basename(
- self, basename: str
- ) -> tuple[str, dict[str, Any]] | None:
- """PG-native override of basename-based document lookup.
- Replaces the base-class full-table scan with a database-level query on
- the canonical ``file_path`` column. The caller is responsible for
- passing an already-canonical basename; storage performs an exact match
- only.
- """
- if not basename:
- return None
- if basename == "unknown_source":
- return None
- sql = (
- "SELECT * FROM LIGHTRAG_DOC_STATUS "
- "WHERE workspace=$1 AND file_path = $2 "
- "ORDER BY created_at ASC, id ASC LIMIT 1"
- )
- params = [self.workspace, basename]
- result = await self.db.query(sql, params, True)
- if not result:
- return None
- row = result[0]
- chunks_list = row.get("chunks_list", [])
- if isinstance(chunks_list, str):
- try:
- chunks_list = json.loads(chunks_list)
- except json.JSONDecodeError:
- chunks_list = []
- metadata = row.get("metadata", {})
- if isinstance(metadata, str):
- try:
- metadata = json.loads(metadata)
- except json.JSONDecodeError:
- metadata = {}
- created_at = self._format_datetime_with_timezone(row["created_at"])
- updated_at = self._format_datetime_with_timezone(row["updated_at"])
- doc = dict(
- content_length=row["content_length"],
- content_summary=row["content_summary"],
- status=row["status"],
- chunks_count=row["chunks_count"],
- created_at=created_at,
- updated_at=updated_at,
- file_path=row["file_path"],
- chunks_list=chunks_list,
- metadata=metadata,
- error_msg=row.get("error_msg"),
- track_id=row.get("track_id"),
- content_hash=row.get("content_hash"),
- )
- return str(row["id"]), doc
- async def get_doc_by_content_hash(
- self, content_hash: str
- ) -> tuple[str, dict[str, Any]] | None:
- """PG-native override of content-hash document lookup.
- Replaces the base-class full-table scan with an indexed query on
- ``workspace + content_hash``. Empty strings are treated as a miss
- to align with the partial-index predicate.
- """
- if not content_hash:
- return None
- sql = (
- "SELECT * FROM LIGHTRAG_DOC_STATUS "
- "WHERE workspace=$1 AND content_hash=$2 "
- "ORDER BY created_at ASC, id ASC LIMIT 1"
- )
- result = await self.db.query(sql, [self.workspace, content_hash], True)
- if not result:
- return None
- row = result[0]
- chunks_list = row.get("chunks_list", [])
- if isinstance(chunks_list, str):
- try:
- chunks_list = json.loads(chunks_list)
- except json.JSONDecodeError:
- chunks_list = []
- metadata = row.get("metadata", {})
- if isinstance(metadata, str):
- try:
- metadata = json.loads(metadata)
- except json.JSONDecodeError:
- metadata = {}
- created_at = self._format_datetime_with_timezone(row["created_at"])
- updated_at = self._format_datetime_with_timezone(row["updated_at"])
- doc = dict(
- content_length=row["content_length"],
- content_summary=row["content_summary"],
- status=row["status"],
- chunks_count=row["chunks_count"],
- created_at=created_at,
- updated_at=updated_at,
- file_path=row["file_path"],
- chunks_list=chunks_list,
- metadata=metadata,
- error_msg=row.get("error_msg"),
- track_id=row.get("track_id"),
- content_hash=row.get("content_hash"),
- )
- return str(row["id"]), doc
async def get_status_counts(self) -> dict[str, int]:
    """Count the workspace's documents per status.

    Returns:
        Mapping of status value to number of documents. An empty mapping
        is returned when the query yields no rows, including a ``None``
        result, which the other lookups in this class also tolerate.
    """
    sql = """SELECT status as "status", COUNT(1) as "count"
             FROM LIGHTRAG_DOC_STATUS
             where workspace=$1 GROUP BY STATUS
          """
    rows = await self.db.query(sql, [self.workspace], True)
    # ``rows or []`` guards against a ``None`` result instead of raising
    # ``TypeError`` while iterating.
    return {row["status"]: row["count"] for row in rows or []}
async def get_docs_by_status(
    self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
    """Return every document of the workspace that has the given status.

    Args:
        status: Status to match; its ``.value`` is bound into the query.

    Returns:
        Mapping of document id to ``DocProcessingStatus``. ``chunks_list``
        and ``metadata`` are JSON-decoded when stored as text, non-dict
        metadata is replaced by ``{}``, and a ``NULL`` ``file_path``
        becomes ``"no-file-path"``.
    """
    rows = await self.db.query(
        "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2",
        [self.workspace, status.value],
        True,
    )

    matched = {}
    for row in rows:
        # JSON columns may arrive as text; bad JSON becomes an empty container.
        decoded = {}
        for column, empty in (("chunks_list", list), ("metadata", dict)):
            value = row.get(column, empty())
            if isinstance(value, str):
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    value = empty()
            decoded[column] = value

        # Valid JSON that is not an object still must not leak through.
        metadata = decoded["metadata"] if isinstance(decoded["metadata"], dict) else {}

        file_path = row.get("file_path")
        if file_path is None:
            file_path = "no-file-path"

        created_at = self._format_datetime_with_timezone(row["created_at"])
        updated_at = self._format_datetime_with_timezone(row["updated_at"])

        matched[row["id"]] = DocProcessingStatus(
            content_summary=row["content_summary"],
            content_length=row["content_length"],
            status=row["status"],
            created_at=created_at,
            updated_at=updated_at,
            chunks_count=row["chunks_count"],
            file_path=file_path,
            chunks_list=decoded["chunks_list"],
            metadata=metadata,
            error_msg=row.get("error_msg"),
            track_id=row.get("track_id"),
            content_hash=row.get("content_hash"),
        )
    return matched
async def get_docs_by_statuses(
    self, statuses: list[DocStatus]
) -> dict[str, DocProcessingStatus]:
    """Return the documents whose status is any of ``statuses``, in one query.

    A single ``status = ANY($2)`` lookup replaces one round-trip per
    status when several are needed at once.

    Args:
        statuses: Statuses to match; an empty list returns ``{}`` without
            querying.

    Returns:
        Mapping of document id to ``DocProcessingStatus``. Rows that raise
        ``KeyError`` or ``TypeError`` while being converted are logged and
        skipped rather than failing the whole call.
    """
    if not statuses:
        return {}

    wanted_values = [member.value for member in statuses]
    rows = await self.db.query(
        "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND status = ANY($2)",
        [self.workspace, wanted_values],
        multirows=True,
    )

    docs: dict[str, DocProcessingStatus] = {}
    for row in rows or []:
        try:
            # JSON columns may arrive as text; bad JSON becomes an empty container.
            decoded = {}
            for column, empty in (("chunks_list", list), ("metadata", dict)):
                value = row.get(column, empty())
                if isinstance(value, str):
                    try:
                        value = json.loads(value)
                    except json.JSONDecodeError:
                        value = empty()
                decoded[column] = value

            metadata = decoded["metadata"]
            if not isinstance(metadata, dict):
                metadata = {}

            docs[row["id"]] = DocProcessingStatus(
                content_summary=row["content_summary"],
                content_length=row["content_length"],
                status=row["status"],
                created_at=self._format_datetime_with_timezone(row["created_at"]),
                updated_at=self._format_datetime_with_timezone(row["updated_at"]),
                chunks_count=row["chunks_count"],
                file_path=row.get("file_path") or "no-file-path",
                chunks_list=decoded["chunks_list"],
                metadata=metadata,
                error_msg=row.get("error_msg"),
                track_id=row.get("track_id"),
                content_hash=row.get("content_hash"),
            )
        except (KeyError, TypeError) as e:
            doc_id_hint = row.get("id", "<unknown>") if row else "<unknown>"
            logger.error(
                f"[{self.workspace}] Skipping document '{doc_id_hint}' — "
                f"required field missing or wrong type while parsing DB row: {e!r}"
            )
            continue
    return docs
async def get_docs_by_track_id(
    self, track_id: str
) -> dict[str, DocProcessingStatus]:
    """Return every document of the workspace that carries ``track_id``.

    Args:
        track_id: Tracking id to match exactly.

    Returns:
        Mapping of document id to ``DocProcessingStatus``. ``chunks_list``
        and ``metadata`` are JSON-decoded when stored as text, non-dict
        metadata is replaced by ``{}``, and a ``NULL`` ``file_path``
        becomes ``"no-file-path"``.
    """
    rows = await self.db.query(
        "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2",
        [self.workspace, track_id],
        True,
    )

    tracked = {}
    for row in rows:
        # JSON columns may arrive as text; bad JSON becomes an empty container.
        decoded = {}
        for column, empty in (("chunks_list", list), ("metadata", dict)):
            value = row.get(column, empty())
            if isinstance(value, str):
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    value = empty()
            decoded[column] = value

        # Valid JSON that is not an object still must not leak through.
        metadata = decoded["metadata"] if isinstance(decoded["metadata"], dict) else {}

        file_path = row.get("file_path")
        if file_path is None:
            file_path = "no-file-path"

        created_at = self._format_datetime_with_timezone(row["created_at"])
        updated_at = self._format_datetime_with_timezone(row["updated_at"])

        tracked[row["id"]] = DocProcessingStatus(
            content_summary=row["content_summary"],
            content_length=row["content_length"],
            status=row["status"],
            created_at=created_at,
            updated_at=updated_at,
            chunks_count=row["chunks_count"],
            file_path=file_path,
            chunks_list=decoded["chunks_list"],
            track_id=row.get("track_id"),
            metadata=metadata,
            error_msg=row.get("error_msg"),
            content_hash=row.get("content_hash"),
        )
    return tracked
async def get_docs_paginated(
    self,
    status_filter: DocStatus | None = None,
    status_filters: list[DocStatus] | None = None,
    page: int = 1,
    page_size: int = 50,
    sort_field: str = "updated_at",
    sort_direction: str = "desc",
) -> tuple[list[tuple[str, DocProcessingStatus]], int]:
    """Get documents with pagination support

    Args:
        status_filter: Single document status to filter by, None for no
            single-status filter
        status_filters: Optional list of statuses. Both filter arguments are
            handed to ``resolve_status_filter_values``; when that returns a
            non-None collection, rows are restricted to those statuses
        page: Page number (1-based); values below 1 are treated as 1
        page_size: Number of documents per page, clamped to 10-200
        sort_field: Field to sort by ('created_at', 'updated_at', 'id',
            'file_path'); anything else falls back to 'updated_at'
        sort_direction: Sort direction ('asc' or 'desc', case-insensitive);
            anything else falls back to 'desc'

    Returns:
        Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count).
        ``chunks_list`` is always ``[]`` in the returned statuses because the
        column is not fetched by this query.
    """
    start = time.perf_counter()
    status_filter_values = self.resolve_status_filter_values(
        status_filter=status_filter,
        status_filters=status_filters,
    )
    status_filter_value = status_filter.value if status_filter is not None else None
    performance_timing_log(
        "[%s] PGDocStatusStorage.get_docs_paginated start status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s",
        self.workspace,
        status_filter_value,
        page,
        page_size,
        sort_field,
        sort_direction,
    )
    # Validate parameters
    if page < 1:
        page = 1
    if page_size < 10:
        page_size = 10
    elif page_size > 200:
        page_size = 200
    # Whitelist validation for sort_field to prevent SQL injection
    allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
    if sort_field not in allowed_sort_fields:
        sort_field = "updated_at"
    # Whitelist validation for sort_direction to prevent SQL injection
    if sort_direction.lower() not in ["asc", "desc"]:
        sort_direction = "desc"
    else:
        sort_direction = sort_direction.lower()
    # Calculate offset
    offset = (page - 1) * page_size
    # Build parameterized query components. Insertion order of ``params``
    # defines the positional order ($1, $2, ...) sent to the database.
    params = {"workspace": self.workspace}
    param_count = 1
    # Build WHERE clause with parameterized query
    if status_filter_values is not None:
        param_count += 1
        where_clause = "WHERE workspace=$1 AND status = ANY($2)"
        params["status_filters"] = sorted(status_filter_values)
    else:
        where_clause = "WHERE workspace=$1"
    # Build ORDER BY clause using validated whitelist values.
    # NULLS LAST is applied in both the inner paged CTE and the outer query so
    # that the LIMIT/OFFSET slice boundary and the display order are identical.
    # Without it, DESC defaults to NULLS FIRST: nulls land on earlier pages but
    # are re-sorted to the end by the outer ORDER BY, dropping non-null rows.
    order_clause = f"ORDER BY {sort_field} {sort_direction.upper()} NULLS LAST"
    # Two-CTE query: total count + page data in a single round-trip.
    #
    # COUNT(*) OVER () was replaced because when the LIMIT/OFFSET clause yields
    # no rows (out-of-range page), there are no result rows to carry the window
    # function value — so total_count would not appear in the output at all,
    # making it impossible to distinguish "0 matching documents" from "non-empty
    # result set, page is past the end".
    #
    # The LEFT JOIN pattern fixes this: the `total` CTE always produces exactly
    # one row (the aggregate count over the full WHERE clause), and the outer
    # LEFT JOIN emits that one row even when `paged` is empty. Python then
    # skips rows where id IS NULL (the empty-page sentinel).
    #
    # chunks_list is intentionally excluded from the paged CTE SELECT list:
    # DocStatusResponse does not expose it, so transferring the full JSONB array
    # would be pure overhead. That is why the DocProcessingStatus constructor
    # further down passes chunks_list=[].
    params["limit"] = page_size
    params["offset"] = offset
    cte_sql = f"""
        WITH total AS (
            SELECT COUNT(*) AS _total_count
            FROM LIGHTRAG_DOC_STATUS
            {where_clause}
        ),
        paged AS (
            SELECT id, workspace, content_summary, content_length, chunks_count,
                   status, file_path, track_id, metadata, error_msg, content_hash,
                   created_at, updated_at
            FROM LIGHTRAG_DOC_STATUS
            {where_clause}
            {order_clause}
            LIMIT ${param_count + 1} OFFSET ${param_count + 2}
        )
        SELECT p.*, t._total_count
        FROM total t
        LEFT JOIN paged p ON true
        ORDER BY p.{sort_field} {sort_direction.upper()} NULLS LAST
    """
    query_timing_label = f"{self.workspace} PGDocStatusStorage.get_docs_paginated"
    result = await self.db.query(
        cte_sql,
        list(params.values()),
        True,
        timing_label=query_timing_label,
    )
    total_count = result[0]["_total_count"] if result else 0
    # Convert to (doc_id, DocProcessingStatus) tuples
    documents = []
    # ``result or []`` keeps the loop consistent with the total_count guard
    # above: a falsy result means there is nothing to convert.
    for element in result or []:
        if element["id"] is None:
            # Empty-page sentinel row from LEFT JOIN when paged has no rows.
            continue
        doc_id = element["id"]
        # Parse metadata JSON string back to dict
        metadata = element.get("metadata", {})
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                metadata = {}
        # Ensure metadata is a dict (NULL column or non-object JSON), matching
        # the other readers of this table.
        if not isinstance(metadata, dict):
            metadata = {}
        # Same NULL file_path placeholder the other readers of this table use.
        file_path = element.get("file_path")
        if file_path is None:
            file_path = "no-file-path"
        # Convert datetime objects to ISO format strings with timezone info
        created_at = self._format_datetime_with_timezone(element["created_at"])
        updated_at = self._format_datetime_with_timezone(element["updated_at"])
        doc_status = DocProcessingStatus(
            content_summary=element["content_summary"],
            content_length=element["content_length"],
            status=element["status"],
            created_at=created_at,
            updated_at=updated_at,
            chunks_count=element["chunks_count"],
            file_path=file_path,
            chunks_list=[],  # not fetched: unused by pagination response
            track_id=element.get("track_id"),
            metadata=metadata,
            error_msg=element.get("error_msg"),
            content_hash=element.get("content_hash"),
        )
        documents.append((doc_id, doc_status))
    elapsed = time.perf_counter() - start
    performance_timing_log(
        "[%s] PGDocStatusStorage.get_docs_paginated completed in %.4fs returned_rows=%s total_count=%s status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s",
        self.workspace,
        elapsed,
        len(documents),
        total_count,
        status_filter_value,
        page,
        page_size,
        sort_field,
        sort_direction,
    )
    return documents, total_count
async def get_all_status_counts(self) -> dict[str, int]:
    """Count the current workspace's documents per status.

    Returns:
        Dictionary keyed by status name with the row count for that status,
        plus an ``'all'`` entry holding the sum of every per-status count.
    """
    started_at = time.perf_counter()
    performance_timing_log(
        "[%s] PGDocStatusStorage.get_all_status_counts start", self.workspace
    )
    count_sql = """
        SELECT status, COUNT(*) as count
        FROM LIGHTRAG_DOC_STATUS
        WHERE workspace=$1
        GROUP BY status
    """
    rows = await self.db.query(
        count_sql,
        [self.workspace],
        True,
        timing_label=f"{self.workspace} PGDocStatusStorage.get_all_status_counts",
    )
    # One entry per GROUP BY row; the grand total is accumulated alongside.
    counts = {}
    grand_total = 0
    for status_row in rows:
        per_status = status_row["count"]
        counts[status_row["status"]] = per_status
        grand_total += per_status
    counts["all"] = grand_total
    performance_timing_log(
        "[%s] PGDocStatusStorage.get_all_status_counts completed in %.4fs counts=%s",
        self.workspace,
        time.perf_counter() - started_at,
        counts,
    )
    return counts
async def index_done_callback(self) -> None:
    """No-op hook: PG handles persistence automatically."""
    return None
async def is_empty(self) -> bool:
    """Report whether the namespace's table holds no rows for this workspace.

    Returns:
        bool: True when no row exists for the workspace. True is also
        returned when the namespace maps to no table or the query raises;
        both cases are logged as errors.
    """
    table_name = namespace_to_table_name(self.namespace)
    if not table_name:
        logger.error(
            f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}"
        )
        return True
    exists_sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data"
    try:
        row = await self.db.query(exists_sql, [self.workspace])
        if not row:
            # No result row at all is treated as "empty".
            return True
        return not row.get("has_data", False)
    except Exception as exc:
        logger.error(f"[{self.workspace}] Error checking if storage is empty: {exc}")
        return True
async def delete(self, ids: list[str]) -> None:
    """Remove the given record ids from this namespace's table.

    Only rows belonging to the current workspace are targeted. Failures are
    logged and not re-raised.

    Args:
        ids (list[str]): Record ids to delete; an empty list is a no-op.

    Returns:
        None
    """
    if not ids:
        return
    target_table = namespace_to_table_name(self.namespace)
    if not target_table:
        logger.error(
            f"[{self.workspace}] Unknown namespace for deletion: {self.namespace}"
        )
        return
    statement = f"DELETE FROM {target_table} WHERE workspace=$1 AND id = ANY($2)"
    try:
        await self.db.execute(statement, {"workspace": self.workspace, "ids": ids})
        logger.debug(
            f"[{self.workspace}] Successfully deleted {len(ids)} records from {self.namespace}"
        )
    except Exception as exc:
        logger.error(
            f"[{self.workspace}] Error while deleting records from {self.namespace}: {exc}"
        )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
    """Update or insert document status

    Each record is converted to a parameter tuple; records that raise
    ``KeyError``, ``TypeError`` or ``ValueError`` during conversion (for
    example a missing ``content_summary``, ``content_length``, ``status`` or
    ``file_path``) are logged and skipped. The remaining tuples are written
    with a single ``executemany`` inside one transaction. When no record
    survives validation, the database is not contacted at all.

    Args:
        data: dictionary of document IDs and their status data
    """
    logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
    if not data:
        return
    timing_label = f"{self.workspace} PGDocStatusStorage.upsert"
    total_start = time.perf_counter()
    performance_timing_log(
        "[%s] start records=%s",
        timing_label,
        len(data),
    )
    # NOTE: content_hash uses COALESCE(NULLIF(...,''), existing) rather than
    # a straight EXCLUDED overwrite. This gives write-once-after-set
    # semantics: once a non-empty content_hash is recorded, subsequent
    # upserts that omit it (or pass '' / NULL) will NOT clear it. Required
    # because pipeline state transitions (e.g. processing -> processed)
    # reuse the existing DocProcessingStatus payload without re-supplying
    # the hash, while _persist_parsed_full_docs patches the hash in a
    # separate upsert.
    #
    # This is a deliberate behavioral divergence from JsonDocStatusStorage,
    # which overwrites unconditionally. No caller today wants to clear a
    # content_hash, so the divergence is invisible — but if that ever
    # changes, this guard must be revisited.
    sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status,file_path,chunks_list,track_id,metadata,error_msg,content_hash,created_at,updated_at)
             values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)
              on conflict(id,workspace) do update set
              content_summary = EXCLUDED.content_summary,
              content_length = EXCLUDED.content_length,
              chunks_count = EXCLUDED.chunks_count,
              status = EXCLUDED.status,
              file_path = EXCLUDED.file_path,
              chunks_list = EXCLUDED.chunks_list,
              track_id = EXCLUDED.track_id,
              metadata = EXCLUDED.metadata,
              error_msg = EXCLUDED.error_msg,
              content_hash = COALESCE(
                  NULLIF(EXCLUDED.content_hash, ''),
                  LIGHTRAG_DOC_STATUS.content_hash
              ),
              created_at = EXCLUDED.created_at,
              updated_at = EXCLUDED.updated_at"""
    # Tuple order must match SQL: (workspace, id, content_summary, content_length,
    # chunks_count, status, file_path, chunks_list, track_id, metadata,
    # error_msg, content_hash, created_at, updated_at)
    batch: list[tuple] = []
    skipped: list[str] = []
    batch_build_start = time.perf_counter()
    for i, (k, v) in enumerate(data.items(), start=1):
        try:
            batch.append(
                (
                    self.workspace,
                    k,
                    v["content_summary"],
                    v["content_length"],
                    v.get("chunks_count", -1),
                    v["status"],
                    v["file_path"],
                    json.dumps(v.get("chunks_list", [])),
                    v.get("track_id"),
                    json.dumps(v.get("metadata", {})),
                    v.get("error_msg"),
                    v.get("content_hash"),
                    _parse_doc_status_datetime(
                        v.get("created_at"),
                        f"[{self.workspace}] doc {k} created_at",
                    ),
                    _parse_doc_status_datetime(
                        v.get("updated_at"),
                        f"[{self.workspace}] doc {k} updated_at",
                    ),
                )
            )
        except (KeyError, TypeError, ValueError) as e:
            logger.error(
                f"[{self.workspace}] Skipping document '{k}' in batch upsert — "
                f"invalid or missing field: {e!r}"
            )
            skipped.append(k)
        # Called once per record with its 1-based index, valid or not.
        await _cooperative_yield(i)
    if skipped:
        logger.warning(
            f"[{self.workspace}] {len(skipped)} document(s) skipped in batch upsert: {skipped}"
        )
    performance_timing_log(
        "[%s] batch validation/assembly completed in %.4fs valid_count=%s skipped_count=%s",
        timing_label,
        time.perf_counter() - batch_build_start,
        len(batch),
        len(skipped),
    )

    async def _batch_upsert(
        connection: asyncpg.Connection,
        _sql: str = sql,
        _data: list[tuple] = batch,
    ) -> None:
        execute_start = time.perf_counter()
        async with connection.transaction():
            await connection.executemany(_sql, _data)
        performance_timing_log(
            "[%s] transaction + executemany completed in %.4fs batch_size=%s",
            timing_label,
            time.perf_counter() - execute_start,
            len(_data),
        )

    if batch:
        await self.db._run_with_retry(_batch_upsert, timing_label=timing_label)
    # else: every record was rejected above, so there is nothing to write and
    # no reason to acquire a connection or open a transaction.
    logger.debug(
        f"[{self.workspace}] Batch upserted {len(batch)} records to {self.namespace}"
    )
    performance_timing_log(
        "[%s] total complete in %.4fs valid_count=%s skipped_count=%s",
        timing_label,
        time.perf_counter() - total_start,
        len(batch),
        len(skipped),
    )
async def drop(self) -> dict[str, str]:
    """Drop this workspace's data from the namespace's table.

    Returns:
        ``{"status": "success", "message": "data dropped"}`` on success, or
        ``{"status": "error", "message": ...}`` when the namespace maps to no
        table or any exception is raised (the exception text is the message).
    """
    try:
        table_name = namespace_to_table_name(self.namespace)
        if table_name:
            template = SQL_TEMPLATES["drop_specifiy_table_workspace"]
            await self.db.execute(
                template.format(table_name=table_name),
                {"workspace": self.workspace},
            )
            return {"status": "success", "message": "data dropped"}
        return {
            "status": "error",
            "message": f"Unknown namespace: {self.namespace}",
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
class PGGraphQueryException(Exception):
    """Exception for the AGE queries.

    Accepts either a plain message string or a dict payload. For a dict,
    ``message`` is read from the ``"message"`` key and ``details`` from the
    ``"details"`` key, falling back to ``"detail"``; either attribute defaults
    to ``"unknown"`` when its key is absent.
    """

    def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
        # Keep the original payload as the exception's args so str()/repr()
        # still show everything the raiser supplied.
        super().__init__(exception)
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # The graph query wrapper in this module builds its payload with a
            # singular "detail" key; accept both spellings so the wrapped
            # error is not silently reported as "unknown".
            self.details = exception.get(
                "details", exception.get("detail", "unknown")
            )
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the message extracted at construction time."""
        return self.message

    def get_details(self) -> Any:
        """Return the details extracted at construction time."""
        return self.details
def _is_transient_graph_write_error(exc: BaseException) -> bool:
    """Tell whether ``exc`` is a PGGraphQueryException caused by a transient write error.

    Connection-level transient failures (pool reset, TCP failures, etc.) are
    already retried by the inner _run_with_retry. This predicate targets the
    query-level errors that get past that layer and arrive wrapped in a
    PGGraphQueryException: deadlocks, serialization conflicts, and
    lock-acquisition timeouts seen under concurrent document ingestion.

    Only the exception's direct ``__cause__`` is inspected.
    """
    if isinstance(exc, PGGraphQueryException) and exc.__cause__ is not None:
        transient_types = (
            asyncpg.exceptions.DeadlockDetectedError,
            asyncpg.exceptions.SerializationError,
            asyncpg.exceptions.LockNotAvailableError,
            asyncpg.exceptions.QueryCanceledError,
        )
        return isinstance(exc.__cause__, transient_types)
    return False
- @final
- @dataclass
- class PGGraphStorage(BaseGraphStorage):
- def __post_init__(self):
- # Graph name will be dynamically generated in initialize() based on workspace
- self.db: PostgreSQLDB | None = None
- def _get_workspace_graph_name(self) -> str:
- """
- Generate graph name based on workspace and namespace for data isolation.
- Rules:
- - If workspace is empty or "default": graph_name = namespace
- - If workspace has other value: graph_name = workspace_namespace
- Args:
- None
- Returns:
- str: The graph name for the current workspace
- """
- workspace = self.workspace
- namespace = self.namespace
- if workspace and workspace.strip() and workspace.strip().lower() != "default":
- # Ensure names comply with PostgreSQL identifier specifications
- safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
- safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
- return f"{safe_workspace}_{safe_namespace}"
- else:
- # When the workspace is "default", use the namespace directly (for backward compatibility with legacy implementations)
- return re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
- @staticmethod
- def _normalize_node_id(node_id: str) -> str:
- """
- Normalize node ID to ensure special characters are properly handled in Cypher queries.
- Used by write paths that still embed entity IDs in Cypher strings
- (delete_node, remove_nodes, remove_edges). The upsert paths now use
- parameterized Cypher instead.
- Within a Cypher double-quoted string the only recognised escape
- sequences are ``\\"`` and ``\\\\``. We also strip null bytes which
- could truncate the string in some PostgreSQL/AGE code paths.
- Args:
- node_id: The original node ID
- Returns:
- Normalized node ID suitable for embedding in a Cypher double-quoted string
- """
- # Strip null bytes that could truncate the string
- normalized_id = node_id.replace("\x00", "")
- # Escape backslashes first (order matters)
- normalized_id = normalized_id.replace("\\", "\\\\")
- # Escape double quotes
- normalized_id = normalized_id.replace('"', '\\"')
- return normalized_id
async def initialize(self):
    """Attach a database client, settle the workspace, and prepare the AGE graph.

    Runs under the data-init lock. Steps, in order:

    1. Obtain a client from ``ClientManager`` if none is attached.
    2. Resolve the workspace: the client's workspace wins, then the
       storage's own, then ``"default"``.
    3. Derive ``self.graph_name`` from the workspace.
    4. Configure the AGE extension through ``_run_with_retry``.
    5. Create the graph, its labels and indexes, one statement at a time,
       ignoring "already exists" errors.
    """
    async with get_data_init_lock():
        if self.db is None:
            self.db = await ClientManager.get_client(
                vector_storage=self.global_config.get("vector_storage")
            )

        # Workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
        if self.db.workspace:
            logger.info(
                f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
            )
            self.workspace = self.db.workspace
        elif not (hasattr(self, "workspace") and self.workspace):
            # Neither the client nor the storage supplies one.
            self.workspace = "default"

        # Graph name depends on the workspace resolved above.
        self.graph_name = self._get_workspace_graph_name()
        logger.info(
            f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
        )

        # Configure the AGE extension once at initialization. Going through
        # _run_with_retry means transient connection errors are retried and
        # pool=None is handled safely (unlike a bare pool.acquire() call).
        async def _do_configure_age_extension(
            connection: asyncpg.Connection,
        ) -> None:
            await PostgreSQLDB.configure_age_extension(connection)

        await self.db._run_with_retry(_do_configure_age_extension)

        graph = self.graph_name
        setup_statements = [
            f"SELECT create_graph('{graph}')",
            f"SELECT create_vlabel('{graph}', 'base');",
            f"SELECT create_elabel('{graph}', 'DIRECTED');",
            # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {graph}."_ag_label_vertex" (id)',
            f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {graph}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {graph}."_ag_label_edge" (id)',
            f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {graph}."_ag_label_edge" (start_id)',
            f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {graph}."_ag_label_edge" (end_id)',
            f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {graph}."_ag_label_edge" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY directed_p_idx ON {graph}."DIRECTED" (id)',
            f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {graph}."DIRECTED" (end_id)',
            f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {graph}."DIRECTED" (start_id)',
            f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {graph}."DIRECTED" (start_id,end_id)',
            f'CREATE INDEX CONCURRENTLY entity_p_idx ON {graph}."base" (id)',
            f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {graph}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))',
            f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {graph}."base" using gin(properties)',
            f'ALTER TABLE {graph}."DIRECTED" CLUSTER ON directed_sid_idx',
        ]
        # Each statement runs on its own; ignore_if_exists suppresses
        # "already exists" errors at the source to avoid log spam.
        for statement in setup_statements:
            await self.db.execute(
                statement,
                upsert=True,
                ignore_if_exists=True,
                with_age=True,
                graph_name=self.graph_name,
            )
async def finalize(self):
    """Hand the database client back to ``ClientManager`` and detach it."""
    if self.db is None:
        # Nothing attached, nothing to release.
        return
    await ClientManager.release_client(self.db)
    self.db = None
async def index_done_callback(self) -> None:
    """No-op hook: PG handles persistence automatically."""
    return None
- @staticmethod
- def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
- """
- Convert a record returned from an age query to a dictionary
- Args:
- record (): a record from an age query result
- Returns:
- dict[str, Any]: a dictionary representation of the record where
- the dictionary key is the field name and the value is the
- value converted to a python type
- """
- @staticmethod
- def parse_agtype_string(agtype_str: str) -> tuple[str, str]:
- """
- Parse agtype string precisely, separating JSON content and type identifier
- Args:
- agtype_str: String like '{"json": "content"}::vertex'
- Returns:
- (json_content, type_identifier)
- """
- if not isinstance(agtype_str, str) or "::" not in agtype_str:
- return agtype_str, ""
- # Find the last :: from the right, which is the start of type identifier
- last_double_colon = agtype_str.rfind("::")
- if last_double_colon == -1:
- return agtype_str, ""
- # Separate JSON content and type identifier
- json_content = agtype_str[:last_double_colon]
- type_identifier = agtype_str[last_double_colon + 2 :]
- return json_content, type_identifier
- @staticmethod
- def safe_json_parse(json_str: str, context: str = "") -> dict:
- """
- Safe JSON parsing with simplified error logging
- """
- try:
- return json.loads(json_str)
- except json.JSONDecodeError as e:
- logger.error(f"JSON parsing failed ({context}): {e}")
- logger.error(f"Raw data (first 100 chars): {repr(json_str[:100])}")
- logger.error(f"Error position: line {e.lineno}, column {e.colno}")
- return None
- # result holder
- d = {}
- # prebuild a mapping of vertex_id to vertex mappings to be used
- # later to build edges
- vertices = {}
- # First pass: preprocess vertices
- for k in record.keys():
- v = record[k]
- if isinstance(v, str) and "::" in v:
- if v.startswith("[") and v.endswith("]"):
- # Handle vertex arrays
- json_content, type_id = parse_agtype_string(v)
- if type_id == "vertex":
- vertexes = safe_json_parse(
- json_content, f"vertices array for {k}"
- )
- if vertexes:
- for vertex in vertexes:
- vertices[vertex["id"]] = vertex.get("properties")
- else:
- # Handle single vertex
- json_content, type_id = parse_agtype_string(v)
- if type_id == "vertex":
- vertex = safe_json_parse(json_content, f"single vertex for {k}")
- if vertex:
- vertices[vertex["id"]] = vertex.get("properties")
- # Second pass: process all fields
- for k in record.keys():
- v = record[k]
- if isinstance(v, str) and "::" in v:
- if v.startswith("[") and v.endswith("]"):
- # Handle array types
- json_content, type_id = parse_agtype_string(v)
- if type_id in ["vertex", "edge"]:
- parsed_data = safe_json_parse(
- json_content, f"array {type_id} for field {k}"
- )
- d[k] = parsed_data if parsed_data is not None else None
- else:
- logger.warning(f"Unknown array type: {type_id}")
- d[k] = None
- else:
- # Handle single objects
- json_content, type_id = parse_agtype_string(v)
- if type_id in ["vertex", "edge"]:
- parsed_data = safe_json_parse(
- json_content, f"single {type_id} for field {k}"
- )
- d[k] = parsed_data if parsed_data is not None else None
- else:
- # May be other types of agtype data, keep as is
- d[k] = v
- else:
- d[k] = v # Keep as string
- return d
- @staticmethod
- def _format_properties(
- properties: dict[str, Any], _id: Union[str, None] = None
- ) -> str:
- """
- Convert a dictionary of properties to a string representation that
- can be used in a cypher query insert/merge statement.
- Args:
- properties (dict[str,str]): a dictionary containing node/edge properties
- _id (Union[str, None]): the id of the node or None if none exists
- Returns:
- str: the properties dictionary as a properly formatted string
- """
- props = []
- # Wrap property keys in backticks and escape embedded backticks to
- # preserve the Cypher structure when property names contain specials.
- for k, v in properties.items():
- safe_key = str(k).replace("`", "``")
- prop = f"`{safe_key}`: {json.dumps(v, ensure_ascii=False)}"
- props.append(prop)
- if _id is not None and "id" not in properties:
- props.append(
- f"id: {json.dumps(_id, ensure_ascii=False)}"
- if isinstance(_id, str)
- else f"id: {_id}"
- )
- return "{" + ", ".join(props) + "}"
- async def _query(
- self,
- query: str,
- readonly: bool = True,
- upsert: bool = False,
- params: dict[str, Any] | None = None,
- timing_label: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Query the graph by taking a cypher query, converting it to an
- age compatible query, executing it and converting the result
- Args:
- query (str): a cypher query to be executed
- readonly (bool): if True, uses db.query; if False, uses db.execute.
- Both paths support the ``params`` argument.
- upsert (bool): passed through to db.execute for write operations.
- params (dict | None): AGE agtype parameters for parameterized Cypher
- (e.g. ``{"params": json.dumps({"entity_id": "..."})}``).
- Honoured for both read and write paths.
- timing_label (str | None): optional label for performance logging.
- Returns:
- list[dict[str, Any]]: a list of dictionaries containing the result set
- """
- try:
- if readonly:
- data = await self.db.query(
- query,
- list(params.values()) if params else None,
- multirows=True,
- with_age=True,
- graph_name=self.graph_name,
- timing_label=timing_label,
- )
- else:
- age_execute_start = time.perf_counter()
- data = await self.db.execute(
- query,
- data=params,
- upsert=upsert,
- with_age=True,
- graph_name=self.graph_name,
- timing_label=timing_label,
- )
- if timing_label:
- performance_timing_log(
- "[%s] AGE execute completed in %.4fs",
- timing_label,
- time.perf_counter() - age_execute_start,
- )
- except Exception as e:
- if timing_label and not readonly:
- performance_timing_log(
- "[%s] AGE execute failed after %.4fs",
- timing_label,
- time.perf_counter() - age_execute_start,
- )
- raise PGGraphQueryException(
- {
- "message": f"Error executing graph query: {query}",
- "wrapped": query,
- "detail": repr(e),
- "error_type": e.__class__.__name__,
- }
- ) from e
- if data is None:
- result = []
- # decode records
- else:
- result = [self._record_to_dict(d) for d in data]
- return result
async def has_node(self, node_id: str) -> bool:
    """Return True when a ``base`` row with this ``entity_id`` property exists.

    The id is bound as ``$1`` and compared, as agtype, against the row's
    ``entity_id`` property.
    """
    exists_sql = f"""
        SELECT EXISTS (
          SELECT 1
          FROM {self.graph_name}.base
          WHERE ag_catalog.agtype_access_operator(
                  VARIADIC ARRAY[properties, '"entity_id"'::agtype]
                ) = (to_json($1::text)::text)::agtype
          LIMIT 1
        ) AS node_exists;
    """
    rows = await self._query(exists_sql, params={"node_id": node_id})
    first_row = rows[0]
    return bool(first_row["node_exists"])
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Return True when a ``DIRECTED`` edge links the two entities, in either direction.

    Both ids are bound (``$1`` = source, ``$2`` = target) and resolved to
    vertex ids through the ``entity_id`` property before the edge lookup.
    """
    graph = self.graph_name
    edge_sql = f"""
        WITH a AS (
          SELECT id AS vid
          FROM {graph}.base
          WHERE ag_catalog.agtype_access_operator(
                  VARIADIC ARRAY[properties, '"entity_id"'::agtype]
                ) = (to_json($1::text)::text)::agtype
        ),
        b AS (
          SELECT id AS vid
          FROM {graph}.base
          WHERE ag_catalog.agtype_access_operator(
                  VARIADIC ARRAY[properties, '"entity_id"'::agtype]
                ) = (to_json($2::text)::text)::agtype
        )
        SELECT EXISTS (
          SELECT 1
          FROM {graph}."DIRECTED" d
          JOIN a ON d.start_id = a.vid
          JOIN b ON d.end_id   = b.vid
          LIMIT 1
        )
        OR EXISTS (
          SELECT 1
          FROM {graph}."DIRECTED" d
          JOIN a ON d.end_id   = a.vid
          JOIN b ON d.start_id = b.vid
          LIMIT 1
        ) AS edge_exists;
    """
    # Key order matters: the values are bound positionally as $1, $2.
    bound = {
        "source_node_id": source_node_id,
        "target_node_id": target_node_id,
    }
    rows = await self._query(edge_sql, params=bound)
    first_row = rows[0]
    return bool(first_row["edge_exists"])
- async def get_node(self, node_id: str) -> dict[str, str] | None:
- """Get node by its label identifier, return only node properties"""
- result = await self.get_nodes_batch(node_ids=[node_id])
- if result and node_id in result:
- return result[node_id]
- return None
async def node_degree(self, node_id: str) -> int:
    """Return the degree ``node_degrees_batch`` reports for the node, or 0 when absent."""
    degrees = await self.node_degrees_batch(node_ids=[node_id])
    if not degrees or node_id not in degrees:
        return 0
    return degrees[node_id]
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the degree ``edge_degrees_batch`` reports for ``(src_id, tgt_id)``, or 0 when absent."""
    pair = (src_id, tgt_id)
    degrees = await self.edge_degrees_batch(edges=[pair])
    if not degrees or pair not in degrees:
        return 0
    return degrees[pair]
- async def get_edge(
- self, source_node_id: str, target_node_id: str
- ) -> dict[str, str] | None:
- """Get edge properties between two nodes"""
- result = await self.get_edges_batch(
- [{"src": source_node_id, "tgt": target_node_id}]
- )
- if result and (source_node_id, target_node_id) in result:
- return result[(source_node_id, target_node_id)]
- return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
    """
    Retrieve the neighbours of the node whose ``entity_id`` is ``source_node_id``.

    The id is passed as a bound Cypher parameter; edges are matched in
    either direction.

    :return: list of ``(source_id, connected_id)`` tuples; rows where either
        side is empty (e.g. a node without neighbours) are left out
    """
    cypher_query = """MATCH (n:base {entity_id: $entity_id})
                  OPTIONAL MATCH (n)-[]-(connected:base)
                  RETURN n.entity_id AS source_id, connected.entity_id AS connected_id"""
    query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(cypher_query)}::cstring, $1::agtype) AS (source_id text, connected_id text)"
    cypher_args = json.dumps({"entity_id": source_node_id}, ensure_ascii=False)
    rows = await self._query(query, params={"params": cypher_args})
    return [
        (row["source_id"], row["connected_id"])
        for row in rows
        if row["source_id"] and row["connected_id"]
    ]
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception(_is_transient_graph_write_error),
    reraise=True,
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """
    Create or update a ``base`` vertex in this storage's AGE graph.

    The vertex is MERGEd on ``entity_id = node_id``; every other entry of
    ``node_data`` is then applied with ``SET n += {...}``.

    Args:
        node_id: Value bound as the vertex's ``entity_id`` in the MERGE
        node_data: Dictionary of node properties; must contain an
            ``entity_id`` key (its value is not written separately)

    Raises:
        ValueError: if ``node_data`` has no ``entity_id`` key
    """
    if "entity_id" not in node_data:
        raise ValueError(
            "PostgreSQL: node properties must contain an 'entity_id' field"
        )
    # AGE supports binding scalar values in Cypher parameters here, but not
    # using a bound agtype object on ``SET n += $props`` (verified on AGE 1.5.0).
    # So the node ID stays parameterized while the remaining properties are
    # inlined as a safely escaped map literal.
    extra_props = {
        key: value for key, value in node_data.items() if key != "entity_id"
    }
    props_literal = self._format_properties(extra_props)
    cypher_query = "\n".join(
        (
            "MERGE (n:base {entity_id: $entity_id})",
            f"SET n += {props_literal}",
            "RETURN n",
        )
    )
    query = (
        f"SELECT * FROM cypher("
        f"{_dollar_quote(self.graph_name)}::name, "
        f"{_dollar_quote(cypher_query)}::cstring, "
        f"$1::agtype) AS (n agtype)"
    )
    cypher_args = json.dumps({"entity_id": node_id}, ensure_ascii=False)
    pg_params = {"params": cypher_args}
    timing_label = f"{self.workspace} PGGraphStorage.upsert_node"
    total_start = time.perf_counter()
    performance_timing_log(
        "[%s] start node_id=%s",
        timing_label,
        node_id,
    )
    try:
        await self._query(
            query,
            readonly=False,
            upsert=True,
            params=pg_params,
            timing_label=timing_label,
        )
        performance_timing_log(
            "[%s] total complete in %.4fs node_id=%s",
            timing_label,
            time.perf_counter() - total_start,
            node_id,
        )
    except Exception:
        # Record the failure, then let the retry decorator / caller see it.
        performance_timing_log(
            "[%s] total failed after %.4fs node_id=%s",
            timing_label,
            time.perf_counter() - total_start,
            node_id,
        )
        logger.error(
            f"[{self.workspace}] POSTGRES, upsert_node error on node_id: `{node_id}`"
        )
        raise
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception(_is_transient_graph_write_error),
    reraise=True,
)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """
    Replace the ``DIRECTED`` edge between two existing ``base`` vertices.

    Any edge already linking the two vertices (in either direction) is deleted and
    a fresh ``source -> target`` edge carrying ``edge_data`` is created, all inside
    one transaction guarded by an advisory lock.

    Args:
        source_node_id (str): ``entity_id`` of the source vertex.
        target_node_id (str): ``entity_id`` of the target vertex.
        edge_data (dict): properties to store on the edge; an empty/None value
            yields an empty property map.

    Raises:
        PGGraphQueryException: any failure. Non-``PGGraphQueryException`` errors
            are wrapped (with the original as ``__cause__``) so the retry
            predicate on this method can inspect them.
    """
    # Why DELETE + CREATE instead of MERGE/SET:
    #   * AGE cannot bind a whole agtype map in ``SET r += $props``
    #     (verified on AGE 1.5.0);
    #   * the inlined form ``SET r += {map}`` is silently ignored for edges
    #     (it does work for nodes);
    #   * individual ``SET r.key = value`` assignments run without error but do
    #     not persist either.
    # Inlining the properties in a CREATE clause is the only reliable way to
    # write them, so an OPTIONAL MATCH removes any existing edge first to keep
    # the operation idempotent.
    #
    # Concurrency: OPTIONAL MATCH + DELETE + CREATE is not atomic against other
    # writers. Two transactions upserting the same pair could both see "no edge"
    # and both CREATE one, leaving duplicate DIRECTED rows that inflate degree
    # counts and duplicate relations. Each logical edge is therefore serialised
    # with a transaction-scoped advisory lock keyed on
    # (graph_name, ordered (src_id, tgt_id)), which means:
    #   - {A,B} and {B,A} contend for the same lock (the OPTIONAL MATCH is
    #     undirected), and
    #   - the same (A,B) pair in a different AGE graph / workspace does NOT
    #     contend. pg_advisory_xact_lock is database-wide and independent
    #     tenants should not serialise each other's ingestion.
    # AGE refuses to plan a join against a cypher() call containing CREATE
    # ("cypher create clause cannot be rescanned"), so the lock cannot live in a
    # CTE. Instead an explicit transaction runs two statements on the same
    # connection: lock first, cypher upsert second. The lock is released when
    # the transaction commits.
    props_literal = self._format_properties(edge_data) if edge_data else "{}"

    cypher_query = f"""MATCH (source:base {{entity_id: $src_id}})
                 WITH source
                 MATCH (target:base {{entity_id: $tgt_id}})
                 WITH source, target
                 OPTIONAL MATCH (source)-[old:DIRECTED]-(target)
                 DELETE old
                 WITH source, target
                 CREATE (source)-[r:DIRECTED {props_literal}]->(target)
                 RETURN r"""

    lock_sql = (
        "SELECT pg_advisory_xact_lock("
        " hashtextextended("
        " $1::text || E'\\x01' ||"
        " LEAST($2::text, $3::text) || E'\\x01' || GREATEST($2::text, $3::text),"
        " 0"
        " )"
        ")"
    )

    graph_arg = _dollar_quote(self.graph_name)
    cypher_arg = _dollar_quote(cypher_query)
    cypher_sql = (
        f"SELECT r FROM cypher({graph_arg}::name, {cypher_arg}::cstring, "
        f"$1::agtype) AS (r agtype)"
    )
    params_json = json.dumps(
        {"src_id": source_node_id, "tgt_id": target_node_id},
        ensure_ascii=False,
    )

    timing_label = f"{self.workspace} PGGraphStorage.upsert_edge"
    started_at = time.perf_counter()
    performance_timing_log(
        "[%s] start source_node_id=%s target_node_id=%s",
        timing_label,
        source_node_id,
        target_node_id,
    )

    async def _operation(connection: asyncpg.Connection) -> None:
        # Both statements share one transaction so the advisory lock covers the upsert.
        async with connection.transaction():
            await connection.execute(
                lock_sql, self.graph_name, source_node_id, target_node_id
            )
            await connection.execute(cypher_sql, params_json)

    try:
        await self.db._run_with_retry(
            _operation,
            with_age=True,
            graph_name=self.graph_name,
            timing_label=timing_label,
        )
        performance_timing_log(
            "[%s] total complete in %.4fs source_node_id=%s target_node_id=%s",
            timing_label,
            time.perf_counter() - started_at,
            source_node_id,
            target_node_id,
        )
    except Exception as exc:
        performance_timing_log(
            "[%s] total failed after %.4fs source_node_id=%s target_node_id=%s",
            timing_label,
            time.perf_counter() - started_at,
            source_node_id,
            target_node_id,
        )
        logger.error(
            f"[{self.workspace}] POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
        )
        # Surface everything as PGGraphQueryException so the outer @retry's
        # _is_transient_graph_write_error predicate can look at __cause__ and
        # retry on DeadlockDetectedError / SerializationError /
        # LockNotAvailableError / QueryCanceledError, mirroring what _query does
        # for upsert_node and the other AGE write paths. Without the wrapper,
        # transient errors from connection.execute would arrive as raw asyncpg
        # exceptions, fail the predicate's isinstance() check and skip retries.
        if not isinstance(exc, PGGraphQueryException):
            raise PGGraphQueryException(
                {
                    "message": (
                        f"Error executing graph upsert_edge: "
                        f"`{source_node_id}`-`{target_node_id}`"
                    ),
                    "wrapped": cypher_sql,
                    "detail": repr(exc),
                    "error_type": exc.__class__.__name__,
                }
            ) from exc
        raise
async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
    """Upsert several nodes one by one, keeping last-write-wins semantics.

    PostgreSQL/AGE write paths embed properties directly in Cypher strings and do
    not yet support parameterized UNWIND, so each node is written through
    ``upsert_node``. When a node ID occurs more than once, only its final payload
    is written, and it is written at the position of that final occurrence.

    Args:
        nodes: List of (node_id, node_data) tuples.
    """
    if not nodes:
        return

    latest: dict[str, dict[str, str]] = {}
    for node_id, node_data in nodes:
        # Drop the earlier entry so the re-insert lands at the end of the dict.
        if node_id in latest:
            del latest[node_id]
        latest[node_id] = node_data

    for node_id, node_data in latest.items():
        await self.upsert_node(node_id, node_data=node_data)
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Report which of the given node IDs exist, by delegating to ``get_nodes_batch``.

    Args:
        node_ids: List of node IDs to check.

    Returns:
        Set of the node IDs that ``get_nodes_batch`` returned data for; an empty
        set when ``node_ids`` is empty.
    """
    if not node_ids:
        return set()

    found = await self.get_nodes_batch(node_ids)
    return set(found)
async def upsert_edges_batch(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """Upsert several edges one by one, keeping the last update per endpoint pair.

    Edges are treated as undirected for de-duplication: ``(A, B)`` and ``(B, A)``
    share one slot, and the final occurrence in ``edges`` wins (its own
    source/target orientation is the one passed to ``upsert_edge``).

    Args:
        edges: List of (source_node_id, target_node_id, edge_data) tuples.
    """
    if not edges:
        return

    latest: dict[tuple[str, str], tuple[str, str, dict[str, str]]] = {}
    for src, tgt, edge_data in edges:
        pair_key = (src, tgt) if src <= tgt else (tgt, src)
        latest[pair_key] = (src, tgt, edge_data)

    # Walk the pairs in canonical (LEAST, GREATEST) order instead of insertion
    # order. upsert_edge opens its own transaction per call and releases its
    # advisory lock on commit, so this is not a deadlock fix; it simply makes
    # logs and replays reproducible and matches the dedup key used above.
    for pair_key in sorted(latest):
        src, tgt, edge_data = latest[pair_key]
        await self.upsert_edge(src, tgt, edge_data=edge_data)
async def delete_node(self, node_id: str) -> None:
    """
    Delete one node, together with its relationships (``DETACH DELETE``).

    Args:
        node_id (str): The ID of the node to delete; it is passed through
            ``_normalize_node_id`` before being placed in the Cypher text.

    Raises:
        Exception: any query failure is logged and re-raised.
    """
    entity = self._normalize_node_id(node_id)

    # Dollar-quoting is chosen dynamically so an entity_id containing $ sequences
    # cannot terminate the quoted Cypher text.
    cypher_query = f"""MATCH (n:base {{entity_id: "{entity}"}})
                DETACH DELETE n"""
    graph_arg = _dollar_quote(self.graph_name)
    cypher_arg = _dollar_quote(cypher_query)
    query = f"SELECT * FROM cypher({graph_arg}, {cypher_arg}) AS (n agtype)"

    try:
        await self._query(query, readonly=False)
    except Exception as exc:
        logger.error(f"[{self.workspace}] Error during node deletion: {exc}")
        raise
async def remove_nodes(self, node_ids: list[str]) -> None:
    """
    Remove multiple nodes (and their relationships) from the graph in one query.

    Args:
        node_ids (list[str]): A list of node IDs to remove. Each ID is passed
            through ``_normalize_node_id`` before being placed in the Cypher text.
            An empty list is a no-op and issues no query.

    Raises:
        Exception: any query failure is logged and re-raised.
    """
    # Nothing to delete: skip the write round-trip instead of running a
    # ``WHERE n.entity_id IN []`` statement that can never match.
    if not node_ids:
        return

    quoted_ids = ", ".join(
        f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids
    )

    # Dollar-quoting is chosen dynamically so an entity_id containing $ sequences
    # cannot terminate the quoted Cypher text.
    cypher_query = f"""MATCH (n:base)
                WHERE n.entity_id IN [{quoted_ids}]
                DETACH DELETE n"""
    query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (n agtype)"

    try:
        await self._query(query, readonly=False)
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during node removal: {e}")
        raise
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
    """
    Remove multiple edges from the graph, one query per edge.

    The relationship pattern has no arrow, so an edge between the two nodes is
    deleted whichever way it points.

    Args:
        edges (list[tuple[str, str]]): A list of edges to remove, where each edge
            is a tuple of (source_node_id, target_node_id).

    Raises:
        Exception: the first query failure is logged and re-raised; later edges
            in the list are not processed.
    """
    graph_arg = _dollar_quote(self.graph_name)

    for source, target in edges:
        src_entity = self._normalize_node_id(source)
        tgt_entity = self._normalize_node_id(target)

        # Dollar-quoting is chosen dynamically so an entity_id containing $
        # sequences cannot terminate the quoted Cypher text.
        cypher_query = f"""MATCH (a:base {{entity_id: "{src_entity}"}})-[r]-(b:base {{entity_id: "{tgt_entity}"}})
                DELETE r"""
        query = f"SELECT * FROM cypher({graph_arg}, {_dollar_quote(cypher_query)}) AS (r agtype)"

        try:
            await self._query(query, readonly=False)
            logger.debug(
                f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
            )
        except Exception as exc:
            logger.error(f"[{self.workspace}] Error during edge deletion: {str(exc)}")
            raise
async def get_nodes_batch(
    self, node_ids: list[str], batch_size: int = 1000
) -> dict[str, dict]:
    """
    Fetch the properties of many nodes with one SQL query per batch.

    The IDs are passed as a ``text[]`` parameter, expanded with ``unnest`` and
    joined against the graph's ``base`` vertex table on the ``entity_id`` property.

    Args:
        node_ids: List of node entity IDs to fetch. Duplicates are queried once.
        batch_size: Maximum number of IDs sent per query.

    Returns:
        A dictionary mapping each found node_id to its property dictionary.
        IDs with no matching vertex are simply absent from the result, as are
        rows whose properties arrive as a string that is not valid JSON.
    """
    if not node_ids:
        return {}

    # De-duplicate while keeping first-seen order.
    unique_ids: list[str] = list(dict.fromkeys(node_ids))
    requested: set[str] = set(unique_ids)

    # Map both the raw and the normalized spelling of every ID back to the
    # caller's original key, so rows can be matched whichever form comes back.
    lookup: dict[str, str] = {}
    for nid in unique_ids:
        lookup[nid] = nid
        lookup[self._normalize_node_id(nid)] = nid

    # The statement text does not depend on the batch, so build it once.
    query = f"""
        WITH input(v, ord) AS (
          SELECT v, ord
          FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
        ),
        ids(node_id, ord) AS (
          SELECT (to_json(v)::text)::agtype AS node_id, ord
          FROM input
        )
        SELECT i.node_id::text AS node_id,
               b.properties
        FROM {self.graph_name}.base AS b
        JOIN ids i
          ON ag_catalog.agtype_access_operator(
               VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
             ) = i.node_id
        ORDER BY i.ord;
    """

    nodes_dict: dict[str, dict] = {}
    for i in range(0, len(unique_ids), batch_size):
        batch = unique_ids[i : i + batch_size]
        results = await self._query(query, params={"ids": batch})

        for result in results:
            if not (result["node_id"] and result["properties"]):
                continue

            node_dict = result["properties"]
            # Properties may arrive as a JSON string; decode them to a dict.
            if isinstance(node_dict, str):
                try:
                    node_dict = json.loads(node_dict)
                except json.JSONDecodeError:
                    logger.warning(
                        f"[{self.workspace}] Failed to parse node string in batch: {node_dict}"
                    )
                    # Never hand the raw string back as if it were node data.
                    continue

            node_key = result["node_id"]
            original_key = lookup.get(node_key)
            if original_key is None:
                logger.warning(
                    f"[{self.workspace}] Node {node_key} not found in lookup map"
                )
                original_key = node_key
            if original_key in requested:
                nodes_dict[original_key] = node_dict

    return nodes_dict
async def node_degrees_batch(
    self, node_ids: list[str], batch_size: int = 500
) -> dict[str, int]:
    """
    Compute the degree of many nodes with one SQL query per batch.

    For every requested node the query counts rows of the ``DIRECTED`` edge table
    where the node is the start vertex (out-degree) and where it is the end
    vertex (in-degree); the reported degree is the sum of the two.

    Args:
        node_ids: List of node labels (entity_id values) to look up.
        batch_size: Maximum number of IDs sent per query.

    Returns:
        A dictionary mapping each node_id to its degree (total number of
        relationships). A node that is not found gets degree 0.
    """
    if not node_ids:
        return {}

    # First-seen order, duplicates removed.
    unique_ids: list[str] = list(dict.fromkeys(node_ids))
    requested: set[str] = set(unique_ids)

    # Raw and normalized spellings both resolve to the caller's original key.
    lookup: dict[str, str] = {}
    for nid in unique_ids:
        lookup[nid] = nid
        lookup[self._normalize_node_id(nid)] = nid

    outgoing: dict[str, int] = {}
    incoming: dict[str, int] = {}

    for offset in range(0, len(unique_ids), batch_size):
        chunk = unique_ids[offset : offset + batch_size]
        query = f"""
            WITH input(v, ord) AS (
              SELECT v, ord
              FROM unnest($1::text[]) WITH ORDINALITY AS t(v, ord)
            ),
            ids(node_id, ord) AS (
              SELECT (to_json(v)::text)::agtype AS node_id, ord
              FROM input
            ),
            vids AS (
              SELECT b.id AS vid, i.node_id, i.ord
              FROM {self.graph_name}.base AS b
              JOIN ids i
                ON ag_catalog.agtype_access_operator(
                     VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]
                   ) = i.node_id
            ),
            deg_out AS (
              SELECT d.start_id AS vid, COUNT(*)::bigint AS out_degree
              FROM {self.graph_name}."DIRECTED" AS d
              JOIN vids v ON v.vid = d.start_id
              GROUP BY d.start_id
            ),
            deg_in AS (
              SELECT d.end_id AS vid, COUNT(*)::bigint AS in_degree
              FROM {self.graph_name}."DIRECTED" AS d
              JOIN vids v ON v.vid = d.end_id
              GROUP BY d.end_id
            )
            SELECT v.node_id::text AS node_id,
                   COALESCE(o.out_degree, 0) AS out_degree,
                   COALESCE(n.in_degree, 0)  AS in_degree
            FROM vids v
            LEFT JOIN deg_out o ON o.vid = v.vid
            LEFT JOIN deg_in  n ON n.vid = v.vid
            ORDER BY v.ord;
        """
        rows = await self._query(query, params={"ids": chunk})

        for row in rows:
            node_key = row["node_id"]
            if not node_key:
                continue

            original_key = lookup.get(node_key)
            if original_key is None:
                logger.warning(
                    f"[{self.workspace}] Node {node_key} not found in lookup map"
                )
                original_key = node_key

            if original_key in requested:
                outgoing[original_key] = int(row.get("out_degree", 0) or 0)
                incoming[original_key] = int(row.get("in_degree", 0) or 0)

    # Every requested ID appears in the answer; unseen ones default to 0 + 0.
    return {
        node_id: outgoing.get(node_id, 0) + incoming.get(node_id, 0)
        for node_id in node_ids
    }
async def edge_degrees_batch(
    self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
    """
    Compute, for every edge, the sum of its two endpoint degrees.

    All endpoint degrees are obtained with a single ``node_degrees_batch`` call.

    Args:
        edges: List of (source_node_id, target_node_id) tuples

    Returns:
        Dictionary mapping edge tuples to their combined degrees; an endpoint
        missing from the degree lookup counts as 0.
    """
    if not edges:
        return {}

    # Gather each distinct endpoint once so the degree lookup is not repeated.
    endpoints = {node for edge in edges for node in edge}
    degree_of = await self.node_degrees_batch(list(endpoints))

    return {
        (src, tgt): degree_of.get(src, 0) + degree_of.get(tgt, 0)
        for src, tgt in edges
    }
async def get_edges_batch(
    self, pairs: list[dict[str, str]], batch_size: int = 500
) -> dict[tuple[str, str], dict]:
    """
    Retrieve edge properties for multiple (src, tgt) pairs.

    Each batch runs two Cypher queries, one matching ``(src)-[r]->(tgt)`` and one
    matching ``(src)<-[r]-(tgt)``. Results are merged under the requested
    ``(src, tgt)`` key; when both directions return a row for the same key, the
    backward one is kept because it is processed last.

    Args:
        pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...].
            Pairs whose normalized src or tgt is empty are ignored, and pairs
            that normalize to the same (src, tgt) are queried once.
        batch_size: Maximum number of pairs sent per query.

    Returns:
        A dictionary mapping (src, tgt) tuples to their edge properties. Pairs
        with no edge, with empty properties, or with unparseable properties are
        absent from the result.
    """
    if not pairs:
        return {}

    seen: set[tuple[str, str]] = set()
    uniq_pairs: list[dict[str, str]] = []
    for p in pairs:
        s = self._normalize_node_id(p["src"])
        t = self._normalize_node_id(p["tgt"])
        key = (s, t)
        if s and t and key not in seen:
            seen.add(key)
            uniq_pairs.append(p)

    # The query text is identical for every batch, so build it once up front.
    forward_cypher = """
                     UNWIND $pairs AS p
                     WITH p.src AS src_eid, p.tgt AS tgt_eid
                     MATCH (a:base {entity_id: src_eid})
                     MATCH (b:base {entity_id: tgt_eid})
                     MATCH (a)-[r]->(b)
                     RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
    backward_cypher = """
                      UNWIND $pairs AS p
                      WITH p.src AS src_eid, p.tgt AS tgt_eid
                      MATCH (a:base {entity_id: src_eid})
                      MATCH (b:base {entity_id: tgt_eid})
                      MATCH (a)<-[r]-(b)
                      RETURN src_eid AS source, tgt_eid AS target, properties(r) AS edge_properties"""
    sql_fwd = f"""
    SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name,
                         {_dollar_quote(forward_cypher)}::cstring,
                         $1::agtype)
      AS (source text, target text, edge_properties agtype)
    """
    sql_bwd = f"""
    SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name,
                         {_dollar_quote(backward_cypher)}::cstring,
                         $1::agtype)
      AS (source text, target text, edge_properties agtype)
    """

    edges_dict: dict[tuple[str, str], dict] = {}
    for i in range(0, len(uniq_pairs), batch_size):
        # Send only the two keys the Cypher reads, under a name that does not
        # shadow the ``pairs`` argument.
        batch_pairs = [
            {"src": p["src"], "tgt": p["tgt"]}
            for p in uniq_pairs[i : i + batch_size]
        ]
        pg_params = {
            "params": json.dumps({"pairs": batch_pairs}, ensure_ascii=False)
        }

        forward_results = await self._query(sql_fwd, params=pg_params)
        backward_results = await self._query(sql_bwd, params=pg_params)

        # Forward rows first, backward rows second: a backward hit overrides a
        # forward hit stored under the same (source, target) key.
        for rows in (forward_results, backward_results):
            for result in rows:
                if not (
                    result["source"]
                    and result["target"]
                    and result["edge_properties"]
                ):
                    continue

                edge_props = result["edge_properties"]
                # Properties may arrive as a JSON string; decode them to a dict.
                if isinstance(edge_props, str):
                    try:
                        edge_props = json.loads(edge_props)
                    except json.JSONDecodeError:
                        logger.warning(
                            f"[{self.workspace}] Failed to parse edge properties string: {edge_props}"
                        )
                        continue

                edges_dict[(result["source"], result["target"])] = edge_props

    return edges_dict
async def get_nodes_edges_batch(
    self, node_ids: list[str], batch_size: int = 500
) -> dict[str, list[tuple[str, str]]]:
    """
    List the edges of many nodes, querying outgoing and incoming edges per batch.

    Args:
        node_ids: List of node IDs to get edges for. Empty IDs are not queried.
        batch_size: Maximum number of IDs sent per query.

    Returns:
        Dictionary mapping every entry of ``node_ids`` to a list of
        (source, target) tuples. Outgoing edges are reported as
        ``(node, neighbour)`` and incoming ones as ``(neighbour, node)``;
        a node with no edges maps to an empty list.
    """
    if not node_ids:
        return {}

    # Non-empty IDs, duplicates removed, first-seen order preserved.
    unique_ids: list[str] = list(dict.fromkeys(nid for nid in node_ids if nid))
    collected: dict[str, list[tuple[str, str]]] = {nid: [] for nid in unique_ids}

    for offset in range(0, len(unique_ids), batch_size):
        chunk = unique_ids[offset : offset + batch_size]
        pg_params = {"params": json.dumps({"node_ids": chunk}, ensure_ascii=False)}

        outgoing_cypher = """UNWIND $node_ids AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)-[]->(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id"""
        incoming_cypher = """UNWIND $node_ids AS node_id
                     MATCH (n:base {entity_id: node_id})
                     OPTIONAL MATCH (n:base)<-[]-(connected:base)
                     RETURN node_id, connected.entity_id AS connected_id"""

        outgoing_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(outgoing_cypher)}::cstring, $1::agtype) AS (node_id text, connected_id text)"
        incoming_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}::name, {_dollar_quote(incoming_cypher)}::cstring, $1::agtype) AS (node_id text, connected_id text)"

        outgoing_rows = await self._query(outgoing_query, params=pg_params)
        incoming_rows = await self._query(incoming_query, params=pg_params)

        for row in outgoing_rows:
            if row["node_id"] and row["connected_id"]:
                collected[row["node_id"]].append(
                    (row["node_id"], row["connected_id"])
                )

        for row in incoming_rows:
            if row["node_id"] and row["connected_id"]:
                # Incoming edge: the neighbour is the source.
                collected[row["node_id"]].append(
                    (row["connected_id"], row["node_id"])
                )

    # Answer for every caller-supplied ID, including ones that were not queried.
    return {orig: collected.get(orig, []) for orig in node_ids}
async def get_all_labels(self) -> list[str]:
    """
    List the distinct ``entity_id`` values of all ``base`` nodes.

    Returns:
        list[str]: The labels, in the order produced by the query
        (``ORDER BY n.entity_id``).
    """
    query = (
        """SELECT * FROM cypher('%s', $$
                 MATCH (n:base)
                 WHERE n.entity_id IS NOT NULL
                 RETURN DISTINCT n.entity_id AS label
                 ORDER BY n.entity_id
                 $$) AS (label text)"""
        % self.graph_name
    )

    rows = await self._query(query)

    # Keep only well-formed rows: non-empty dicts that carry a "label" key.
    return [
        row["label"]
        for row in rows
        if row and isinstance(row, dict) and "label" in row
    ]
async def _bfs_subgraph(
    self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
    """
    Build a subgraph around one node with a level-by-level breadth-first search.

    Starting from the node whose ``entity_id`` equals ``node_label``, each level's
    nodes are expanded together with two queries (outgoing and incoming edges).
    New neighbours are admitted while the node budget and depth limit allow it.

    Args:
        node_label: ``entity_id`` of the starting node
        max_depth: Maximum depth of the subgraph
        max_nodes: Maximum number of nodes to return

    Returns:
        KnowledgeGraph object containing nodes and edges. It is empty when the
        start node is not found. ``is_truncated`` is set when a neighbour had to
        be left out although the depth limit would still have allowed it.
    """
    from collections import deque

    result = KnowledgeGraph()
    seen_entity_ids = set()
    seen_internal_ids = set()
    seen_edge_ids = set()
    seen_edge_pairs = set()

    def _keep_edge(edge, edge_id, pair_key):
        # Record an edge once per edge id AND once per unordered endpoint pair.
        if edge_id not in seen_edge_ids and pair_key not in seen_edge_pairs:
            result.edges.append(edge)
            seen_edge_ids.add(edge_id)
            seen_edge_pairs.add(pair_key)

    # Look up the starting node.
    label = self._normalize_node_id(node_label)
    # Dollar-quoting is chosen dynamically so an entity_id containing $ sequences
    # cannot terminate the quoted Cypher text.
    cypher_query = f"""MATCH (n:base {{entity_id: "{label}"}})
                RETURN id(n) as node_id, n"""
    query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(cypher_query)}) AS (node_id bigint, n agtype)"

    node_result = await self._query(query)
    if not node_result or not node_result[0].get("n"):
        return result

    start_node_data = node_result[0]["n"]
    entity_id = start_node_data["properties"]["entity_id"]
    internal_id = str(start_node_data["id"])
    start_node = KnowledgeGraphNode(
        id=internal_id,
        labels=[entity_id],
        properties=start_node_data["properties"],
    )

    # Queue entries are (node, depth) tuples.
    queue = deque([(start_node, 0)])
    seen_entity_ids.add(entity_id)
    seen_internal_ids.add(internal_id)
    result.nodes.append(start_node)
    result.is_truncated = False

    neighbor_columns = (
        "(current_id text, current_internal_id bigint, neighbor_internal_id bigint, "
        "neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"
    )

    while queue:
        # Pull every queued node that shares the depth of the queue head.
        current_depth = queue[0][1]
        current_level_nodes = []
        while queue and queue[0][1] == current_depth:
            node, depth = queue.popleft()
            if depth > max_depth:
                continue
            current_level_nodes.append(node)

        if not current_level_nodes:
            continue

        # Depth limit
        if current_depth > max_depth:
            continue

        level_entity_ids = [node.labels[0] for node in current_level_nodes]
        formatted_ids = ", ".join(
            [f'"{self._normalize_node_id(node_id)}"' for node_id in level_entity_ids]
        )

        # One query per direction for the whole level.
        outgoing_cypher = f"""UNWIND [{formatted_ids}] AS node_id
                    MATCH (n:base {{entity_id: node_id}})
                    OPTIONAL MATCH (n)-[r]->(neighbor:base)
                    RETURN node_id AS current_id,
                           id(n) AS current_internal_id,
                           id(neighbor) AS neighbor_internal_id,
                           neighbor.entity_id AS neighbor_id,
                           id(r) AS edge_id,
                           r,
                           neighbor,
                           true AS is_outgoing"""
        incoming_cypher = f"""UNWIND [{formatted_ids}] AS node_id
                    MATCH (n:base {{entity_id: node_id}})
                    OPTIONAL MATCH (n)<-[r]-(neighbor:base)
                    RETURN node_id AS current_id,
                           id(n) AS current_internal_id,
                           id(neighbor) AS neighbor_internal_id,
                           neighbor.entity_id AS neighbor_id,
                           id(r) AS edge_id,
                           r,
                           neighbor,
                           false AS is_outgoing"""

        outgoing_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(outgoing_cypher)}) AS {neighbor_columns}"
        incoming_query = f"SELECT * FROM cypher({_dollar_quote(self.graph_name)}, {_dollar_quote(incoming_cypher)}) AS {neighbor_columns}"

        outgoing_results = await self._query(outgoing_query)
        incoming_results = await self._query(incoming_query)
        neighbors = outgoing_results + incoming_results

        # entity_id -> node object for the level being expanded.
        node_map = {node.labels[0]: node for node in current_level_nodes}

        for record in neighbors:
            # OPTIONAL MATCH yields empty neighbour/edge columns for isolated nodes.
            if not record.get("neighbor") or not record.get("r"):
                continue

            current_entity_id = record["current_id"]
            current_node = node_map[current_entity_id]

            neighbor_entity_id = record["neighbor_id"]
            neighbor_internal_id = str(record["neighbor_internal_id"])

            # Orient the edge according to which query produced the row.
            if record["is_outgoing"]:
                source_id, target_id = current_node.id, neighbor_internal_id
            else:
                source_id, target_id = neighbor_internal_id, current_node.id

            if not neighbor_entity_id:
                continue

            b_node = record["neighbor"]
            rel = record["r"]
            edge_id = str(record["edge_id"])

            neighbor_node = KnowledgeGraphNode(
                id=neighbor_internal_id,
                labels=[neighbor_entity_id],
                properties=b_node["properties"],
            )

            # Sorting makes (A,B) and (B,A) count as the same edge.
            sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))

            edge = KnowledgeGraphEdge(
                id=edge_id,
                type=rel["label"],
                source=source_id,
                target=target_id,
                properties=rel["properties"],
            )

            if neighbor_internal_id in seen_internal_ids:
                # Neighbour already in the result: only the edge may be new.
                _keep_edge(edge, edge_id, sorted_pair)
            elif len(seen_internal_ids) < max_nodes and current_depth < max_depth:
                # Admit the new node and schedule it for the next level.
                result.nodes.append(neighbor_node)
                seen_entity_ids.add(neighbor_entity_id)
                seen_internal_ids.add(neighbor_internal_id)
                queue.append((neighbor_node, current_depth + 1))
                _keep_edge(edge, edge_id, sorted_pair)
            elif current_depth < max_depth:
                # Depth would have allowed this node, so it was dropped for budget.
                result.is_truncated = True

    return result
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: int = None,
) -> KnowledgeGraph:
    """
    Retrieve a subgraph, either for the whole graph (``"*"``) or around one node.

    Args:
        node_label: Label of the starting node, * means all nodes
        max_depth: Maximum depth of the subgraph, Defaults to 3
        max_nodes: Maximum nodes to return. Defaults to, and is capped by,
            ``global_config["max_graph_nodes"]`` (1000 when that key is absent).

    Returns:
        KnowledgeGraph object containing nodes and edges, with an is_truncated
        flag. For the wildcard query the flag is set when the graph holds more
        nodes than ``max_nodes``; otherwise it comes from ``_bfs_subgraph``.
    """
    node_cap = self.global_config.get("max_graph_nodes", 1000)
    max_nodes = node_cap if max_nodes is None else min(max_nodes, node_cap)

    kg = KnowledgeGraph()

    if node_label != "*":
        # A concrete start node: breadth-first expansion.
        kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
        logger.info(
            f"[{self.workspace}] Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
        )
        return kg

    # Wildcard: count all nodes first to know whether the answer is truncated.
    count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                MATCH (n:base)
                RETURN count(distinct n) AS total_nodes
                $$) AS (total_nodes bigint)"""
    count_result = await self._query(count_query)
    total_nodes = count_result[0]["total_nodes"] if count_result else 0
    is_truncated = total_nodes > max_nodes

    # Pick the max_nodes nodes ranked by the degree the query computes.
    query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
                MATCH (n:base)
                OPTIONAL MATCH (n)-[r]->()
                RETURN id(n) as node_id, count(r) as degree
            $$) AS (node_id BIGINT, degree BIGINT)
            ORDER BY degree DESC
            LIMIT {max_nodes}"""
    node_results = await self._query(query_nodes)
    node_ids = [str(row["node_id"]) for row in node_results]

    logger.info(
        f"[{self.workspace}] Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}"
    )

    if node_ids:
        formatted_ids = ", ".join(node_ids)
        # Fetch the selected nodes plus the edges running between them.
        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                    WITH [{formatted_ids}] AS node_ids
                    MATCH (a)
                    WHERE id(a) IN node_ids
                    OPTIONAL MATCH (a)-[r]->(b)
                    WHERE id(b) IN node_ids
                    RETURN a, r, b
                $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
        results = await self._query(query)

        nodes_dict = {}
        edges_dict = {}

        def _remember_node(raw):
            # Register a vertex once, keyed by its internal id.
            if raw and isinstance(raw, dict):
                node_id = str(raw["id"])
                if node_id not in nodes_dict and "properties" in raw:
                    nodes_dict[node_id] = KnowledgeGraphNode(
                        id=node_id,
                        labels=[raw["properties"]["entity_id"]],
                        properties=raw["properties"],
                    )

        for row in results:
            _remember_node(row.get("a"))
            _remember_node(row.get("b"))

            raw_edge = row.get("r")
            if raw_edge and isinstance(raw_edge, dict):
                edge_id = str(raw_edge["id"])
                if edge_id not in edges_dict:
                    edges_dict[edge_id] = KnowledgeGraphEdge(
                        id=edge_id,
                        type=raw_edge["label"],
                        source=str(raw_edge["start_id"]),
                        target=str(raw_edge["end_id"]),
                        properties=raw_edge["properties"],
                    )

        kg = KnowledgeGraph(
            nodes=list(nodes_dict.values()),
            edges=list(edges_dict.values()),
            is_truncated=is_truncated,
        )
    else:
        # No nodes were selected; fall back to the BFS routine with the label as given.
        kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)

    logger.info(
        f"[{self.workspace}] Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
    )
    return kg
async def get_all_nodes(self) -> list[dict]:
    """Fetch the property map of every vertex stored in the ``base`` label table.

    The rows are read with plain SQL from ``{graph_name}.base`` instead of
    going through the ``cypher()`` wrapper.

    Returns:
        One dictionary per node. Each dictionary is the node's property map
        with an extra ``"id"`` key that mirrors its ``entity_id`` value.
        Rows with empty properties, or whose properties string is not valid
        JSON, are left out.
    """
    rows = await self._query(
        f"""
        SELECT properties
        FROM {self.graph_name}.base
        """
    )

    collected = []
    for row in rows:
        props = row.get("properties")
        if not props:
            # Nothing usable in this row.
            continue

        if isinstance(props, str):
            # The driver may hand the properties back as a JSON string.
            try:
                props = json.loads(props)
            except json.JSONDecodeError:
                logger.warning(
                    f"[{self.workspace}] Failed to parse node string: {props}"
                )
                continue

        # Expose entity_id under "id" as well, for easier access by callers.
        props["id"] = props.get("entity_id")
        collected.append(props)

    return collected
async def get_all_edges(self) -> list[dict]:
    """Fetch every edge of the graph together with its endpoint entity ids.

    Returns:
        One dictionary per edge: the edge's property map extended with
        ``"source"`` and ``"target"`` keys holding the ``entity_id`` of the
        start and end vertex.
        (If 2 directional edges exist between the same pair of nodes,
        deduplication must be handled by the caller)
    """

    def _decode(raw):
        # Properties may arrive as a JSON string; an unparsable string
        # degrades to an empty mapping so the edge is still reported.
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(
                f"[{self.workspace}] Failed to parse edge properties string: {raw}"
            )
            return {}

    # Plain SQL that starts from the edge table and joins the vertex table
    # twice (start and end) only to resolve each endpoint's entity_id,
    # rather than matching node pairs through a Cypher pattern.
    edge_sql = f"""
        SELECT DISTINCT
            (ag_catalog.agtype_access_operator(VARIADIC ARRAY[a.properties, '"entity_id"'::agtype]))::text AS source,
            (ag_catalog.agtype_access_operator(VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]))::text AS target,
            r.properties
        FROM {self.graph_name}."DIRECTED" r
        JOIN {self.graph_name}.base a ON r.start_id = a.id
        JOIN {self.graph_name}.base b ON r.end_id = b.id
        """
    rows = await self._query(edge_sql)

    edges = []
    for row in rows:
        props = _decode(row["properties"])
        props["source"] = row["source"]
        props["target"] = row["target"]
        edges.append(props)
    return edges
async def get_popular_labels(self, limit: int = 300) -> list[str]:
    """Return the ``entity_id`` labels of the most connected vertices.

    Degree is computed in plain SQL over the graph's ``_ag_label_edge``
    table: every edge contributes one count to its start vertex and one to
    its end vertex. Results are ordered by degree (descending) and then by
    label (ascending).

    Args:
        limit: Maximum number of labels to return.

    Returns:
        The labels in ranking order, or an empty list if the query fails
        (the error is logged, not raised).
    """
    try:
        # Runs directly against AGE's underlying tables rather than through
        # the cypher() function wrapper.
        degree_sql = f"""
            WITH node_degrees AS (
                SELECT
                    node_id,
                    COUNT(*) AS degree
                FROM (
                    SELECT start_id AS node_id FROM {self.graph_name}._ag_label_edge
                    UNION ALL
                    SELECT end_id AS node_id FROM {self.graph_name}._ag_label_edge
                ) AS all_edges
                GROUP BY node_id
            )
            SELECT
                (ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]))::text AS label
            FROM
                node_degrees d
            JOIN
                {self.graph_name}._ag_label_vertex v ON d.node_id = v.id
            WHERE
                ag_catalog.agtype_access_operator(VARIADIC ARRAY[v.properties, '"entity_id"'::agtype]) IS NOT NULL
            ORDER BY
                d.degree DESC,
                label ASC
            LIMIT $1;
            """
        rows = await self._query(degree_sql, params={"limit": limit})

        labels = []
        for row in rows:
            # Skip empty rows and rows that carry no "label" column.
            if row and "label" in row:
                labels.append(row["label"])

        logger.debug(
            f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
        )
        return labels
    except Exception as e:
        logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
        return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
    """Search entity labels containing ``query`` and return them ranked.

    The search is a case-insensitive substring match on each vertex's
    ``entity_id``, executed as parameterized SQL. Matches are scored as:

    * 1000 for an exact (lower-cased) match,
    * 500 for a prefix match,
    * otherwise ``100 - LENGTH(label)`` (shorter labels rank higher),
    * plus 50 when the query appears right after a space or an underscore.

    Ties are broken by label in ascending order.

    Args:
        query: Text to look for. Surrounding whitespace is ignored; an empty
            query returns an empty list without touching the database.
        limit: Maximum number of labels to return.

    Returns:
        Matching labels in ranking order, or an empty list if the query
        fails (the error is logged, not raised).
    """
    query_lower = query.lower().strip()
    if not query_lower:
        return []

    # Escape LIKE metacharacters so the user's text is matched literally.
    # PostgreSQL uses backslash as the default LIKE escape character; the
    # backslash itself must be escaped first.
    like_safe = (
        query_lower.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )

    try:
        sql_query = f"""
            WITH ranked_labels AS (
                SELECT
                    (ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text AS label,
                    LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) AS label_lower
                FROM
                    {self.graph_name}._ag_label_vertex
                WHERE
                    ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]) IS NOT NULL
                    AND LOWER((ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '"entity_id"'::agtype]))::text) ILIKE $1
            )
            SELECT
                label
            FROM (
                SELECT
                    label,
                    CASE
                        WHEN label_lower = $2 THEN 1000
                        WHEN label_lower LIKE $3 THEN 500
                        ELSE (100 - LENGTH(label))
                    END +
                    CASE
                        WHEN label_lower LIKE $4 OR label_lower LIKE $5 THEN 50
                        ELSE 0
                    END AS score
                FROM
                    ranked_labels
            ) AS scored_labels
            ORDER BY
                score DESC,
                label ASC
            LIMIT $6;
            """
        params = (
            f"%{like_safe}%",  # For the main ILIKE clause ($1)
            query_lower,  # For exact match ($2) - compared with "=", so unescaped
            f"{like_safe}%",  # For prefix match ($3)
            f"% {like_safe}%",  # For word boundary (space) ($4)
            # For word boundary (underscore) ($5): the underscore is escaped,
            # otherwise LIKE treats it as "any single character" and nearly
            # every non-prefix match would receive the boundary bonus.
            f"%\\_{like_safe}%",
            limit,  # For LIMIT ($6)
        )
        results = await self._query(sql_query, params=dict(enumerate(params, 1)))
        labels = [
            result["label"] for result in results if result and "label" in result
        ]
        logger.debug(
            f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
        )
        return labels
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error searching labels with query '{query}': {str(e)}"
        )
        return []
async def drop(self) -> dict[str, str]:
    """Delete every vertex (and, via DETACH, its edges) from this graph.

    Returns:
        ``{"status": "success", "message": ...}`` when the delete query ran,
        or ``{"status": "error", "message": <exception text>}`` when it
        raised. Failures are logged and reported, never propagated.
    """
    try:
        wipe_sql = f"""SELECT * FROM cypher('{self.graph_name}', $$
                          MATCH (n)
                          DETACH DELETE n
                        $$) AS (result agtype)"""
        # This is a write, so the query must not run in read-only mode.
        await self._query(wipe_sql, readonly=False)
        outcome = {
            "status": "success",
            "message": f"workspace '{self.workspace}' graph data dropped",
        }
        return outcome
    except Exception as e:
        logger.error(f"[{self.workspace}] Error dropping graph: {e}")
        outcome = {"status": "error", "message": str(e)}
        return outcome
# Maps each storage namespace to the PostgreSQL table that backs it.
# namespace_to_table_name() walks this mapping in insertion order and returns
# the table of the first key that matches.
#
# Note: Order matters! More specific namespaces (e.g., "full_entities") must come before
# more general ones (e.g., "entities") because is_namespace() uses endswith() matching
NAMESPACE_TABLE_MAP = {
    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES",
    NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS",
    NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS",
    NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS",
    NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
    NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
}
def namespace_to_table_name(namespace: str) -> str:
    """Resolve a storage namespace to its backing table name.

    ``NAMESPACE_TABLE_MAP`` is scanned in insertion order and the table of
    the first key accepted by ``is_namespace(namespace, key)`` is returned.

    Args:
        namespace: Namespace string to resolve.

    Returns:
        The matching table name. When no key matches, the result is ``None``
        (despite the ``str`` annotation), so callers must handle that case.
    """
    candidates = (
        table
        for key, table in NAMESPACE_TABLE_MAP.items()
        if is_namespace(namespace, key)
    )
    return next(candidates, None)
# DDL for every table used by the PostgreSQL storages, keyed by table name.
# Each value is a dict whose "ddl" entry holds the CREATE TABLE statement.
# All tables use (workspace, id) as their primary key.
# NOTE(review): `VECTOR(dimension)` in the VDB tables contains the bare word
# `dimension`; presumably it is replaced with the embedding dimension before
# the DDL is executed - confirm against the code that consumes TABLES.
TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    doc_name VARCHAR(1024),
                    content TEXT,
                    meta JSONB,
                    sidecar_location TEXT NULL,
                    parse_format VARCHAR(32) NULL DEFAULT 'raw',
                    -- content_hash is TEXT (not VARCHAR(N)) so the column is
                    -- agnostic to the hash algorithm. Today's pipeline writes
                    -- 64-char SHA-256 hex; future algos (SHA-512, base64) do
                    -- not require a schema change.
                    content_hash TEXT NULL,
                    -- process_options is an opaque selector string emitted by
                    -- sanitize_process_options() (e.g. "Fi").
                    process_options TEXT NULL,
                    chunk_options JSONB NULL DEFAULT '{}'::jsonb,
                    parse_engine VARCHAR(32) NULL,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    full_doc_id VARCHAR(256),
                    chunk_order_index INTEGER,
                    tokens INTEGER,
                    content TEXT,
                    file_path TEXT NULL,
                    llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
                    heading JSONB NULL DEFAULT '{}'::jsonb,
                    sidecar JSONB NULL DEFAULT '{}'::jsonb,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    full_doc_id VARCHAR(256),
                    chunk_order_index INTEGER,
                    tokens INTEGER,
                    content TEXT,
                    content_vector VECTOR(dimension),
                    file_path TEXT NULL,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_ENTITY": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    entity_name VARCHAR(512),
                    content TEXT,
                    content_vector VECTOR(dimension),
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_RELATION": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    source_id VARCHAR(512),
                    target_id VARCHAR(512),
                    content TEXT,
                    content_vector VECTOR(dimension),
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    chunk_ids VARCHAR(255)[] NULL,
                    file_path TEXT NULL,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    original_prompt TEXT,
                    return_value TEXT,
                    chunk_id VARCHAR(255) NULL,
                    cache_type VARCHAR(32),
                    queryparam JSONB NULL,
                    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_STATUS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    content_summary varchar(255) NULL,
                    content_length int4 NULL,
                    chunks_count int4 NULL,
                    status varchar(64) NULL,
                    file_path TEXT NULL,
                    chunks_list JSONB NULL DEFAULT '[]'::jsonb,
                    track_id varchar(255) NULL,
                    metadata JSONB NULL DEFAULT '{}'::jsonb,
                    error_msg TEXT NULL,
                    -- content_hash is TEXT (not VARCHAR(N)) so the column is
                    -- agnostic to the hash algorithm. Today's pipeline writes
                    -- 64-char SHA-256 hex; future algos (SHA-512, base64) do
                    -- not require a schema change.
                    content_hash TEXT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_FULL_ENTITIES": {
        "ddl": """CREATE TABLE LIGHTRAG_FULL_ENTITIES (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    entity_names JSONB,
                    count INTEGER,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_FULL_ENTITIES_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_FULL_RELATIONS": {
        "ddl": """CREATE TABLE LIGHTRAG_FULL_RELATIONS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    relation_pairs JSONB,
                    count INTEGER,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_ENTITY_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS (
                    id VARCHAR(512),
                    workspace VARCHAR(255),
                    chunk_ids JSONB,
                    count INTEGER,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_RELATION_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS (
                    id VARCHAR(512),
                    workspace VARCHAR(255),
                    chunk_ids JSONB,
                    count INTEGER,
                    create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
                    CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
}
# Named SQL statements used by the PostgreSQL storages.
# Values use positional parameters ($1, $2, ...). Some also contain
# brace placeholders ({table_name}, {ids}, {vector_cast}).
# NOTE(review): the brace placeholders are presumably filled in with
# str.format() by the caller before execution - confirm at the call sites.
SQL_TEMPLATES = {
    # SQL for KVStorage
    # "get_by_id_*" templates select one row by (workspace=$1, id=$2);
    # "get_by_ids_*" templates select many rows with id = ANY($2).
    "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content,
                                COALESCE(doc_name, '') as file_path,
                                sidecar_location,
                                parse_format,
                                content_hash,
                                process_options,
                                COALESCE(chunk_options, '{}'::jsonb) as chunk_options,
                                parse_engine
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
                            """,
    "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path,
                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
                                COALESCE(heading, '{}'::jsonb) as heading,
                                COALESCE(sidecar, '{}'::jsonb) as sidecar,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                            """,
    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
                               """,
    "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
                                 COALESCE(doc_name, '') as file_path,
                                 sidecar_location,
                                 parse_format,
                                 content_hash,
                                 process_options,
                                 COALESCE(chunk_options, '{}'::jsonb) as chunk_options,
                                 parse_engine
                                 FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
                            """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                  chunk_order_index, full_doc_id, file_path,
                                  COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
                                  COALESCE(heading, '{}'::jsonb) as heading,
                                  COALESCE(sidecar, '{}'::jsonb) as sidecar,
                                  EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                  EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                   FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
                                """,
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
                                """,
    "get_by_id_full_entities": """SELECT id, entity_names, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id=$2
                               """,
    "get_by_id_full_relations": """SELECT id, relation_pairs, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id=$2
                               """,
    "get_by_ids_full_entities": """SELECT id, entity_names, count,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
                                """,
    "get_by_ids_full_relations": """SELECT id, relation_pairs, count,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
                                """,
    "get_by_id_entity_chunks": """SELECT id, chunk_ids, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2
                               """,
    "get_by_id_relation_chunks": """SELECT id, chunk_ids, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2
                               """,
    "get_by_ids_entity_chunks": """SELECT id, chunk_ids, count,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2)
                                """,
    "get_by_ids_relation_chunks": """SELECT id, chunk_ids, count,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2)
                                """,
    # NOTE(review): {ids} is spliced into the statement text rather than bound
    # as a parameter; verify that callers build it safely.
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
    # Pipeline-derived columns (sidecar_location / parse_format / content_hash /
    # process_options / chunk_options / parse_engine) are guarded with COALESCE
    # so a partial upsert (e.g. a caller writing only ``content`` + ``doc_name``)
    # does not silently overwrite metadata recorded by _persist_parsed_full_docs.
    # ``content`` and ``doc_name`` themselves are always overwritten — they are
    # the primary payload, never a candidate for preservation.
    # For the string columns we use NULLIF(EXCLUDED.<column>, '') so that an
    # empty string from a default-bearing caller becomes NULL and is treated as
    # "no value, preserve existing".
    # For chunk_options (JSONB) we treat NULL or the empty-object literal as
    # "no value, preserve existing".
    "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace,
                        sidecar_location, parse_format, content_hash,
                        process_options, chunk_options, parse_engine)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                        ON CONFLICT (workspace,id) DO UPDATE
                           SET content = EXCLUDED.content,
                               doc_name = EXCLUDED.doc_name,
                               sidecar_location = COALESCE(
                                   NULLIF(EXCLUDED.sidecar_location, ''),
                                   LIGHTRAG_DOC_FULL.sidecar_location
                               ),
                               parse_format = COALESCE(
                                   NULLIF(EXCLUDED.parse_format, ''),
                                   LIGHTRAG_DOC_FULL.parse_format
                               ),
                               content_hash = COALESCE(
                                   NULLIF(EXCLUDED.content_hash, ''),
                                   LIGHTRAG_DOC_FULL.content_hash
                               ),
                               process_options = COALESCE(
                                   NULLIF(EXCLUDED.process_options, ''),
                                   LIGHTRAG_DOC_FULL.process_options
                               ),
                               chunk_options = CASE
                                   WHEN EXCLUDED.chunk_options IS NULL
                                        OR EXCLUDED.chunk_options = '{}'::jsonb
                                   THEN LIGHTRAG_DOC_FULL.chunk_options
                                   ELSE EXCLUDED.chunk_options
                               END,
                               parse_engine = COALESCE(
                                   NULLIF(EXCLUDED.parse_engine, ''),
                                   LIGHTRAG_DOC_FULL.parse_engine
                               ),
                               update_time = CURRENT_TIMESTAMP
                       """,
    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam)
                                      VALUES ($1, $2, $3, $4, $5, $6, $7)
                                      ON CONFLICT (workspace,id) DO UPDATE
                                      SET original_prompt = EXCLUDED.original_prompt,
                                      return_value=EXCLUDED.return_value,
                                      chunk_id=EXCLUDED.chunk_id,
                                      cache_type=EXCLUDED.cache_type,
                                      queryparam=EXCLUDED.queryparam,
                                      update_time = CURRENT_TIMESTAMP
                                     """,
    # The remaining upserts take create_time/update_time as parameters; on
    # conflict create_time is left untouched and update_time is overwritten.
    "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
                      chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
                      heading, sidecar, create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET tokens=EXCLUDED.tokens,
                      chunk_order_index=EXCLUDED.chunk_order_index,
                      full_doc_id=EXCLUDED.full_doc_id,
                      content = EXCLUDED.content,
                      file_path=EXCLUDED.file_path,
                      llm_cache_list=EXCLUDED.llm_cache_list,
                      heading=EXCLUDED.heading,
                      sidecar=EXCLUDED.sidecar,
                      update_time = EXCLUDED.update_time
                     """,
    "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
                      create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET entity_names=EXCLUDED.entity_names,
                      count=EXCLUDED.count,
                      update_time = EXCLUDED.update_time
                     """,
    "upsert_full_relations": """INSERT INTO LIGHTRAG_FULL_RELATIONS (workspace, id, relation_pairs, count,
                      create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET relation_pairs=EXCLUDED.relation_pairs,
                      count=EXCLUDED.count,
                      update_time = EXCLUDED.update_time
                     """,
    "upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count,
                      create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET chunk_ids=EXCLUDED.chunk_ids,
                      count=EXCLUDED.count,
                      update_time = EXCLUDED.update_time
                     """,
    "upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count,
                      create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET chunk_ids=EXCLUDED.chunk_ids,
                      count=EXCLUDED.count,
                      update_time = EXCLUDED.update_time
                     """,
    # SQL for VectorStorage
    "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
                      chunk_order_index, full_doc_id, content, content_vector, file_path,
                      create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET tokens=EXCLUDED.tokens,
                      chunk_order_index=EXCLUDED.chunk_order_index,
                      full_doc_id=EXCLUDED.full_doc_id,
                      content = EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector,
                      file_path=EXCLUDED.file_path,
                      update_time = EXCLUDED.update_time
                     """,
    "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
                      content_vector, chunk_ids, file_path, create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET entity_name=EXCLUDED.entity_name,
                      content=EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector,
                      chunk_ids=EXCLUDED.chunk_ids,
                      file_path=EXCLUDED.file_path,
                      update_time=EXCLUDED.update_time
                     """,
    "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
                      target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
                      VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET source_id=EXCLUDED.source_id,
                      target_id=EXCLUDED.target_id,
                      content=EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector,
                      chunk_ids=EXCLUDED.chunk_ids,
                      file_path=EXCLUDED.file_path,
                      update_time = EXCLUDED.update_time
                     """,
    # Vector similarity queries. Parameters: $1 = workspace, $2 = upper bound
    # on the `<=>` distance, $3 = row limit, $4 = query vector (cast with
    # {vector_cast}). Rows are returned in ascending distance order.
    # NOTE(review): `<=>` is presumably pgvector's cosine-distance operator - confirm.
    "relationships": """
                     SELECT source_id AS src_id,
                            target_id AS tgt_id,
                            EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
                     FROM {table_name}
                     WHERE workspace = $1
                       AND content_vector <=> $4::{vector_cast} < $2
                     ORDER BY content_vector <=> $4::{vector_cast}
                     LIMIT $3;
                     """,
    "entities": """
                SELECT entity_name,
                       EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
                FROM {table_name}
                WHERE workspace = $1
                  AND content_vector <=> $4::{vector_cast} < $2
                ORDER BY content_vector <=> $4::{vector_cast}
                LIMIT $3;
                """,
    "chunks": """
              SELECT id,
                     content,
                     file_path,
                     EXTRACT(EPOCH FROM create_time)::BIGINT AS created_at
              FROM {table_name}
              WHERE workspace = $1
                AND content_vector <=> $4::{vector_cast} < $2
              ORDER BY content_vector <=> $4::{vector_cast}
              LIMIT $3;
              """,
    # Remove all rows of one workspace from a table. Despite the key's name,
    # this is a DELETE, not a DROP: the table itself is kept. The misspelt key
    # is the lookup name used at runtime and must not be renamed here.
    "drop_specifiy_table_workspace": """
        DELETE FROM {table_name} WHERE workspace=$1
       """,
}
|