test_opensearch_storage.py 179 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
7427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358
  1. """
  2. Unit tests for OpenSearch storage implementations.
  3. All tests use mocks — no running OpenSearch instance required.
  4. Run with: pytest tests/kg/opensearch_impl/test_opensearch_storage.py -v
  5. """
  6. import asyncio
  7. import math
  8. import pytest
  9. from contextlib import asynccontextmanager
  10. from unittest.mock import AsyncMock, patch
  11. import numpy as np
  12. pytest.importorskip(
  13. "opensearchpy",
  14. reason="opensearchpy is required for OpenSearch storage tests",
  15. )
  16. from opensearchpy.exceptions import NotFoundError, OpenSearchException # type: ignore
  17. from lightrag.kg.opensearch_impl import (
  18. OpenSearchKVStorage,
  19. OpenSearchDocStatusStorage,
  20. OpenSearchGraphStorage,
  21. OpenSearchVectorDBStorage,
  22. ClientManager,
  23. _build_index_name,
  24. _resolve_workspace,
  25. _sanitize_index_name,
  26. _verify_mirrored_id_mapping,
  27. )
  28. from lightrag.base import DocStatus, DocProcessingStatus
  29. pytestmark = pytest.mark.offline
  30. # ---------------------------------------------------------------------------
  31. # Mock the shared storage lock so tests don't need full LightRAG init
  32. # ---------------------------------------------------------------------------
  33. @asynccontextmanager
  34. async def _mock_lock():
  35. yield
  36. def _mock_lock_factory():
  37. return _mock_lock()
  38. def _missing_index_error() -> NotFoundError:
  39. return NotFoundError(404, "index_not_found_exception", "no such index")
  40. @pytest.fixture(autouse=True)
  41. def patch_data_init_lock():
  42. """Patch get_data_init_lock globally so initialize() works without shared storage."""
  43. with patch(
  44. "lightrag.kg.opensearch_impl.get_data_init_lock", side_effect=_mock_lock_factory
  45. ):
  46. yield
  47. @pytest.fixture(autouse=True)
  48. def patch_namespace_lock():
  49. """Patch get_namespace_lock to return real asyncio.Lock instances.
  50. Returning a real Lock (not a no-op) preserves the in-process blocking
  51. semantics the storage relies on, so concurrent flush / read / write
  52. tests can observe actual serialization. Locks are cached per
  53. (namespace, workspace) tuple so multiple calls from the same storage
  54. pick up the same Lock instance.
  55. """
  56. cache: dict[tuple[str, str | None], asyncio.Lock] = {}
  57. def factory(namespace, workspace=None, enable_logging=False):
  58. key = (namespace, workspace or "")
  59. lock = cache.get(key)
  60. if lock is None:
  61. lock = asyncio.Lock()
  62. cache[key] = lock
  63. return lock
  64. with patch("lightrag.kg.opensearch_impl.get_namespace_lock", side_effect=factory):
  65. yield
  66. @pytest.fixture(autouse=True)
  67. def patch_shard_doc_supported():
  68. """Default tests to OpenSearch >= 3.3.0 so the __mirrored_id verification is a no-op.
  69. Tests covering the < 3.3.0 fallback should override this with their own patch.
  70. """
  71. with patch("lightrag.kg.opensearch_impl._shard_doc_supported", True):
  72. yield
  73. # ---------------------------------------------------------------------------
  74. # Fixtures
  75. # ---------------------------------------------------------------------------
  76. class MockEmbeddingFunc:
  77. """Mock embedding function that returns random vectors."""
  78. def __init__(self, dim=128):
  79. self.embedding_dim = dim
  80. self.max_token_size = 512
  81. self.model_name = "mock-embed"
  82. async def __call__(self, texts, **kwargs):
  83. return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  84. class CountingEmbeddingFunc(MockEmbeddingFunc):
  85. """Embedding test double that records calls and can fail a fixed number of times."""
  86. def __init__(self, dim=128, fail_times=0):
  87. super().__init__(dim=dim)
  88. self.fail_times = fail_times
  89. self.call_count = 0
  90. self.batches: list[list[str]] = []
  91. self.texts: list[str] = []
  92. async def __call__(self, texts, **kwargs):
  93. self.call_count += 1
  94. batch = list(texts)
  95. self.batches.append(batch)
  96. self.texts.extend(batch)
  97. if self.fail_times > 0:
  98. self.fail_times -= 1
  99. raise RuntimeError("embedding failed")
  100. return await super().__call__(texts, **kwargs)
  101. @pytest.fixture
  102. def global_config():
  103. """Standard global config fixture for all storage tests."""
  104. return {
  105. "embedding_batch_num": 10,
  106. "max_graph_nodes": 1000,
  107. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  108. }
  109. @pytest.fixture
  110. def embed_func():
  111. """Mock embedding function fixture."""
  112. return MockEmbeddingFunc()
  113. def _make_client():
  114. """Create a fully-mocked AsyncOpenSearch client with spec validation."""
  115. from opensearchpy import AsyncOpenSearch
  116. client = AsyncMock(spec=AsyncOpenSearch)
  117. # indices sub-client
  118. client.indices = AsyncMock()
  119. client.indices.exists = AsyncMock(return_value=False)
  120. client.indices.create = AsyncMock()
  121. client.indices.delete = AsyncMock()
  122. client.indices.refresh = AsyncMock()
  123. client.indices.get_mapping = AsyncMock(return_value={})
  124. # transport for PPL
  125. client.transport = AsyncMock()
  126. client.transport.perform_request = AsyncMock(
  127. side_effect=Exception("PPL not available")
  128. )
  129. # document operations
  130. client.exists = AsyncMock(return_value=False)
  131. client.index = AsyncMock()
  132. client.delete = AsyncMock()
  133. client.delete_by_query = AsyncMock()
  134. client.get = AsyncMock(
  135. return_value={
  136. "_id": "doc1",
  137. "_source": {"content": "hello", "create_time": 0, "update_time": 0},
  138. }
  139. )
  140. client.mget = AsyncMock(
  141. return_value={
  142. "docs": [
  143. {"_id": "id1", "found": True, "_source": {"content": "c1"}},
  144. {"_id": "id2", "found": True, "_source": {"content": "c2"}},
  145. ]
  146. }
  147. )
  148. client.count = AsyncMock(return_value={"count": 5})
  149. client.search = AsyncMock(
  150. return_value={
  151. "hits": {"hits": [], "total": {"value": 0}},
  152. "aggregations": {
  153. "status_counts": {"buckets": []},
  154. "src": {"buckets": []},
  155. "tgt": {"buckets": []},
  156. "source_degrees": {"buckets": []},
  157. "target_degrees": {"buckets": []},
  158. },
  159. }
  160. )
  161. # PIT operations
  162. client.create_pit = AsyncMock(return_value={"pit_id": "mock_pit_id_123"})
  163. client.delete_pit = AsyncMock()
  164. return client
  165. @pytest.fixture
  166. def mock_client():
  167. """Fully-mocked AsyncOpenSearch client fixture."""
  168. return _make_client()
  169. # ---------------------------------------------------------------------------
  170. # Helper utilities
  171. # ---------------------------------------------------------------------------
  172. class TestHelpers:
  173. """Tests for module-level helper functions (_build_index_name, _resolve_workspace, _sanitize_index_name)."""
  174. def test_build_index_name_with_workspace(self):
  175. ws, ns, idx = _build_index_name("myws", "text_chunks")
  176. assert ws == "myws"
  177. assert ns == "myws_text_chunks"
  178. assert idx == _sanitize_index_name("myws_text_chunks")
  179. def test_build_index_name_no_workspace(self):
  180. ws, ns, idx = _build_index_name("", "chunks")
  181. assert ws == ""
  182. assert idx == _sanitize_index_name("chunks")
  183. def test_resolve_workspace_env_override(self):
  184. with patch.dict("os.environ", {"OPENSEARCH_WORKSPACE": "forced"}):
  185. assert _resolve_workspace("original", "ns") == "forced"
  186. def test_resolve_workspace_fallback(self):
  187. with patch.dict("os.environ", {}, clear=True):
  188. assert _resolve_workspace("original", "ns") == "original"
  189. def test_sanitize_index_name(self):
  190. assert _sanitize_index_name("Hello_World") == "hello_world"
  191. assert _sanitize_index_name("-bad") == "x-bad"
  192. assert _sanitize_index_name("a.b/c") == "a_b_c"
  193. # ---------------------------------------------------------------------------
  194. # ClientManager
  195. # ---------------------------------------------------------------------------
  196. class TestClientManager:
  197. """Tests for ClientManager singleton pattern and reference counting."""
  198. @staticmethod
  199. def _stub_client(version: str = "3.3.0") -> AsyncMock:
  200. """Build an AsyncMock client with a concrete .info() payload.
  201. Without this stub, _detect_shard_doc_support's chained .get(...) calls
  202. on an AsyncMock would leak un-awaited coroutines.
  203. """
  204. client = AsyncMock()
  205. client.info = AsyncMock(return_value={"version": {"number": version}})
  206. return client
  207. @pytest.mark.asyncio
  208. async def test_singleton_and_refcount(self):
  209. ClientManager._instances = {"client": None, "ref_count": 0}
  210. with patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls:
  211. mock_cls.return_value = self._stub_client()
  212. c1 = await ClientManager.get_client()
  213. c2 = await ClientManager.get_client()
  214. assert c1 is c2
  215. assert ClientManager._instances["ref_count"] == 2
  216. await ClientManager.release_client(c1)
  217. assert ClientManager._instances["ref_count"] == 1
  218. await ClientManager.release_client(c2)
  219. assert ClientManager._instances["ref_count"] == 0
  220. assert ClientManager._instances["client"] is None
  221. @pytest.mark.asyncio
  222. async def test_close_called_on_last_release(self):
  223. ClientManager._instances = {"client": None, "ref_count": 0}
  224. with patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls:
  225. inner = self._stub_client()
  226. mock_cls.return_value = inner
  227. c = await ClientManager.get_client()
  228. await ClientManager.release_client(c)
  229. inner.close.assert_awaited_once()
  230. # ---------------------------------------------------------------------------
  231. # _verify_mirrored_id_mapping helper
  232. # ---------------------------------------------------------------------------
  233. class TestMirroredIdVerification:
  234. """Tests for the _verify_mirrored_id_mapping fail-fast helper."""
  235. @pytest.mark.asyncio
  236. async def test_skipped_on_modern_cluster(self, mock_client):
  237. """On OpenSearch >= 3.3.0 the mapping check is short-circuited."""
  238. # _shard_doc_supported is True via autouse fixture.
  239. await _verify_mirrored_id_mapping(mock_client, "any_index")
  240. mock_client.indices.get_mapping.assert_not_awaited()
  241. @pytest.mark.asyncio
  242. async def test_passes_when_mapping_present(self, mock_client):
  243. """On OpenSearch < 3.3.0 a mapping containing __mirrored_id is accepted."""
  244. mock_client.indices.get_mapping = AsyncMock(
  245. return_value={
  246. "my_index": {
  247. "mappings": {"properties": {"__mirrored_id": {"type": "keyword"}}}
  248. }
  249. }
  250. )
  251. with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
  252. await _verify_mirrored_id_mapping(mock_client, "my_index")
  253. @pytest.mark.asyncio
  254. async def test_fails_fast_when_mapping_missing(self, mock_client):
  255. """On OpenSearch < 3.3.0 a legacy index without __mirrored_id raises."""
  256. mock_client.indices.get_mapping = AsyncMock(
  257. return_value={
  258. "my_index": {
  259. "mappings": {"properties": {"other_field": {"type": "text"}}}
  260. }
  261. }
  262. )
  263. with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
  264. with pytest.raises(RuntimeError, match="__mirrored_id"):
  265. await _verify_mirrored_id_mapping(mock_client, "my_index")
  266. @pytest.mark.asyncio
  267. async def test_swallows_get_mapping_error(self, mock_client):
  268. """Mapping-fetch failures should not block initialization."""
  269. mock_client.indices.get_mapping = AsyncMock(
  270. side_effect=OpenSearchException("transport error")
  271. )
  272. with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
  273. await _verify_mirrored_id_mapping(mock_client, "my_index")
  274. # ---------------------------------------------------------------------------
  275. # KV Storage
  276. # ---------------------------------------------------------------------------
  277. class TestKVStorage:
  278. """Tests for OpenSearchKVStorage CRUD operations, timestamps, refresh behavior."""
  279. def _make(self, global_config, embed_func, workspace="test"):
  280. return OpenSearchKVStorage(
  281. namespace="text_chunks",
  282. global_config=global_config,
  283. embedding_func=embed_func,
  284. workspace=workspace,
  285. )
  286. @pytest.mark.asyncio
  287. async def test_index_name(self, global_config, embed_func):
  288. s = self._make(global_config, embed_func, workspace="proj_a")
  289. assert s._index_name == "proj_a_text_chunks"
  290. @pytest.mark.asyncio
  291. async def test_initialize_creates_index(
  292. self, global_config, embed_func, mock_client
  293. ):
  294. with patch.object(ClientManager, "get_client", return_value=mock_client):
  295. s = self._make(global_config, embed_func)
  296. await s.initialize()
  297. mock_client.indices.exists.assert_awaited_once()
  298. mock_client.indices.create.assert_awaited_once()
  299. @pytest.mark.asyncio
  300. async def test_initialize_skips_existing_index(
  301. self, global_config, embed_func, mock_client
  302. ):
  303. mock_client.indices.exists = AsyncMock(return_value=True)
  304. with patch.object(ClientManager, "get_client", return_value=mock_client):
  305. s = self._make(global_config, embed_func)
  306. await s.initialize()
  307. mock_client.indices.create.assert_not_awaited()
  308. @pytest.mark.asyncio
  309. async def test_initialize_fails_on_legacy_index_without_mirrored_id(
  310. self, global_config, embed_func, mock_client
  311. ):
  312. """On OpenSearch < 3.3.0, an existing index lacking __mirrored_id must fail-fast."""
  313. mock_client.indices.exists = AsyncMock(return_value=True)
  314. mock_client.indices.get_mapping = AsyncMock(
  315. return_value={
  316. "test_text_chunks": {
  317. "mappings": {"properties": {"content": {"type": "text"}}}
  318. }
  319. }
  320. )
  321. with (
  322. patch.object(ClientManager, "get_client", return_value=mock_client),
  323. patch("lightrag.kg.opensearch_impl._shard_doc_supported", False),
  324. ):
  325. s = self._make(global_config, embed_func)
  326. with pytest.raises(RuntimeError, match="__mirrored_id"):
  327. await s.initialize()
  328. mock_client.indices.create.assert_not_awaited()
  329. @pytest.mark.asyncio
  330. async def test_get_by_id(self, global_config, embed_func, mock_client):
  331. mock_client.mget = AsyncMock(
  332. return_value={
  333. "docs": [
  334. {
  335. "_id": "doc1",
  336. "found": True,
  337. "_source": {
  338. "content": "hello",
  339. "create_time": 0,
  340. "update_time": 0,
  341. },
  342. }
  343. ]
  344. }
  345. )
  346. with patch.object(ClientManager, "get_client", return_value=mock_client):
  347. s = self._make(global_config, embed_func)
  348. await s.initialize()
  349. doc = await s.get_by_id("doc1")
  350. assert doc is not None
  351. assert doc["content"] == "hello"
  352. assert doc["_id"] == "doc1"
  353. mock_client.mget.assert_awaited_once_with(
  354. index=s._index_name, body={"ids": ["doc1"]}
  355. )
  @pytest.mark.asyncio
  async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
      """A doc that mget reports as not found yields None, without a ``get`` fallback."""
      not_found_response = {"docs": [{"_id": "missing", "found": False}]}
      mock_client.mget = AsyncMock(return_value=not_found_response)
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          result = await storage.get_by_id("missing")
          assert result is None
          mock_client.get.assert_not_awaited()
  @pytest.mark.asyncio
  async def test_get_by_ids_preserves_order(
      self, global_config, embed_func, mock_client
  ):
      """get_by_ids returns results in the same order as the requested ids.

      NOTE(review): this test does not set ``mock_client.mget``. The
      ``c1``/``c2`` payloads presumably come from the ``mock_client``
      fixture's default response (not visible here); confirm there.
      """
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          results = await storage.get_by_ids(["id1", "id2"])
          first, second = results[0], results[1]
          assert first["content"] == "c1"
          assert second["content"] == "c2"
  @pytest.mark.asyncio
  async def test_filter_keys(self, global_config, embed_func, mock_client):
      """filter_keys keeps only the ids that mget reports as not found."""
      mget_response = {
          "docs": [
              {"_id": "a", "found": True},
              {"_id": "b", "found": False},
          ]
      }
      mock_client.mget = AsyncMock(return_value=mget_response)
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          missing = await storage.filter_keys({"a", "b"})
          assert missing == {"b"}
  @pytest.mark.asyncio
  async def test_upsert_no_per_operation_refresh(
      self, global_config, embed_func, mock_client
  ):
      """The flush (during index_done_callback) must not request per-op refresh.

      upsert() only buffers. The async_bulk call issued by
      index_done_callback() must carry no ``refresh`` keyword argument.
      """
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch(
              "lightrag.kg.opensearch_impl.helpers.async_bulk",
              new_callable=AsyncMock,
          ) as mock_bulk,
      ):
          mock_bulk.return_value = (1, [])
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          await storage.upsert({"k1": {"content": "v1"}})
          # Nothing reaches OpenSearch until the flush.
          mock_bulk.assert_not_awaited()
          await storage.index_done_callback()
          bulk_kwargs = mock_bulk.call_args.kwargs
          assert "refresh" not in bulk_kwargs
  @pytest.mark.asyncio
  async def test_upsert_sets_timestamps(self, global_config, embed_func, mock_client):
      """Buffered docs carry create_time / update_time set eagerly during upsert.

      The timestamps must be present both in the pending buffer right after
      upsert() and in the ``_source`` of the action sent on flush.
      """
      stamp_fields = ("create_time", "update_time")
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch(
              "lightrag.kg.opensearch_impl.helpers.async_bulk",
              new_callable=AsyncMock,
          ) as mock_bulk,
      ):
          mock_bulk.return_value = (1, [])
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          await storage.upsert({"k1": {"content": "v1"}})
          pending_doc = storage._pending_upserts["k1"]
          for field in stamp_fields:
              assert field in pending_doc
          await storage.index_done_callback()
          flushed_source = mock_bulk.call_args.args[1][0]["_source"]
          for field in stamp_fields:
              assert field in flushed_source
  @pytest.mark.asyncio
  async def test_is_empty(self, global_config, embed_func, mock_client):
      """is_empty() is True when the server-side count is zero."""
      mock_client.count = AsyncMock(return_value={"count": 0})
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          empty = await storage.is_empty()
          assert empty is True
  @pytest.mark.asyncio
  async def test_delete(self, global_config, embed_func, mock_client):
      """delete() buffers tombstones; the bulk delete fires on flush."""
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch(
              "lightrag.kg.opensearch_impl.helpers.async_bulk",
              new_callable=AsyncMock,
          ) as mock_bulk,
      ):
          mock_bulk.return_value = (2, [])
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          await storage.delete(["a", "b"])
          # Tombstones are only recorded locally until the flush.
          mock_bulk.assert_not_awaited()
          assert storage._pending_kv_deletes == {"a", "b"}
          await storage.index_done_callback()
          sent = mock_bulk.call_args.args[1]
          assert len(sent) == 2
          assert {action["_op_type"] for action in sent} == {"delete"}
  @pytest.mark.asyncio
  async def test_drop(self, global_config, embed_func, mock_client):
      """drop() deletes the index exactly once and reports success."""
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          outcome = await storage.drop()
          assert outcome["status"] == "success"
          mock_client.indices.delete.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index(
      self, global_config, embed_func, mock_client
  ):
      """A failing indices.delete makes drop() report an error.

      The failure also clears ``_index_ready``, so the next upsert() re-runs
      ``_create_index_if_not_exists`` exactly once.
      """
      mock_client.indices.delete = AsyncMock(
          side_effect=OpenSearchException("drop failed")
      )
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch(
              "lightrag.kg.opensearch_impl.helpers.async_bulk",
              new_callable=AsyncMock,
          ) as mock_bulk,
      ):
          mock_bulk.return_value = (1, [])
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          with patch.object(
              storage, "_create_index_if_not_exists", new_callable=AsyncMock
          ) as mock_create:
              outcome = await storage.drop()
              assert outcome["status"] == "error"
              assert storage._index_ready is False
              await storage.upsert({"k1": {"content": "v1"}})
              mock_create.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_upsert_after_drop_recreates_index(
      self, global_config, embed_func, mock_client
  ):
      """After a successful drop(), the next upsert() re-creates the index once."""
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch(
              "lightrag.kg.opensearch_impl.helpers.async_bulk",
              new_callable=AsyncMock,
          ) as mock_bulk,
      ):
          mock_bulk.return_value = (1, [])
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          with patch.object(
              storage, "_create_index_if_not_exists", new_callable=AsyncMock
          ) as mock_create:
              await storage.drop()
              await storage.upsert({"k1": {"content": "v1"}})
              mock_create.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_reads_short_circuit_after_drop(
      self, global_config, embed_func, mock_client
  ):
      """Once dropped, reads answer locally without hitting OpenSearch.

      get_by_id, get_by_ids and is_empty must return empty results without
      awaiting mget or count.
      """
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          await storage.drop()
          assert await storage.get_by_id("doc1") is None
          assert await storage.get_by_ids(["doc1", "doc2"]) == [None, None]
          assert await storage.is_empty() is True
          for server_call in (mock_client.mget, mock_client.count):
              server_call.assert_not_awaited()
  @pytest.mark.asyncio
  async def test_read_missing_index_demotes_readiness(
      self, global_config, embed_func, mock_client
  ):
      """A missing-index error on read returns None and clears ``_index_ready``.

      A second read must then short-circuit, so mget is awaited only once.
      """
      mock_client.mget = AsyncMock(side_effect=_missing_index_error())
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          for _attempt in range(2):
              assert await storage.get_by_id("doc1") is None
          assert storage._index_ready is False
          assert mock_client.mget.await_count == 1
  @pytest.mark.asyncio
  async def test_iter_raw_docs_uses_pit_and_search_after(
      self, global_config, embed_func, mock_client
  ):
      """_iter_raw_docs pages with a PIT and feeds each page's last sort key
      into the next request's ``search_after``.

      The first search carries no ``search_after``; the second carries
      ``[2]``, the sort key of the first page's last hit. The PIT is opened
      and closed exactly once.
      """
      first_page = {
          "hits": {
              "hits": [
                  {"_id": "d1", "_source": {"content": "a"}, "sort": [1]},
                  {"_id": "d2", "_source": {"content": "b"}, "sort": [2]},
              ]
          }
      }
      second_page = {
          "hits": {
              "hits": [
                  {"_id": "d3", "_source": {"content": "c"}, "sort": [3]}
              ]
          }
      }
      mock_client.search = AsyncMock(side_effect=[first_page, second_page])
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          batches = [batch async for batch in storage._iter_raw_docs(batch_size=2)]
          batch_ids = [[doc["_id"] for doc in batch] for batch in batches]
          assert batch_ids == [["d1", "d2"], ["d3"]]
          search_calls = mock_client.search.await_args_list
          assert "search_after" not in search_calls[0].kwargs["body"]
          assert search_calls[1].kwargs["body"]["search_after"] == [2]
          mock_client.create_pit.assert_awaited_once()
          mock_client.delete_pit.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_iter_raw_docs_missing_index_demotes_readiness(
      self, global_config, embed_func, mock_client
  ):
      """A missing-index error during the PIT scan has three effects:

      - no batches are yielded;
      - ``_index_ready`` is cleared;
      - the PIT that was opened is still deleted.
      """
      mock_client.search = AsyncMock(side_effect=_missing_index_error())
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          collected = []
          async for batch in storage._iter_raw_docs(batch_size=2):
              collected.append(batch)
          assert collected == []
          assert storage._index_ready is False
          mock_client.create_pit.assert_awaited_once()
          mock_client.delete_pit.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_finalize(self, global_config, embed_func, mock_client):
      """finalize() hands the client back to ClientManager and clears ``client``."""
      with (
          patch.object(ClientManager, "get_client", return_value=mock_client),
          patch.object(
              ClientManager, "release_client", new_callable=AsyncMock
          ) as mock_release,
      ):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          await storage.finalize()
          mock_release.assert_awaited_once()
          assert storage.client is None
  588. # ---------------------------------------------------------------------------
  589. # KV storage write batching (derived from issue #2785 / PR #2822)
  590. # ---------------------------------------------------------------------------
  class TestKVStorageBatching:
      """Tests for the buffered upsert/delete + flush behaviour.

      upsert()/delete() only record work in ``_pending_upserts`` /
      ``_pending_kv_deletes``. The bulk request is issued on flush
      (index_done_callback() or finalize()). ``helpers.async_bulk`` is
      patched so each test can inspect exactly what a flush sends.
      """

      # Patch target for the bulk helper used by every flush.
      _BULK = "lightrag.kg.opensearch_impl.helpers.async_bulk"

      def _make(self, global_config, embed_func, workspace="test"):
          """Build an OpenSearchKVStorage for the ``text_chunks`` namespace."""
          return OpenSearchKVStorage(
              namespace="text_chunks",
              global_config=global_config,
              embedding_func=embed_func,
              workspace=workspace,
          )

      @pytest.mark.asyncio
      async def test_repeated_kv_upserts_flush_in_single_bulk_call(
          self, global_config, embed_func, mock_client
      ):
          """Many small upsert() calls collapse to one async_bulk on flush."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (5, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              for i in range(5):
                  await s.upsert({f"k{i}": {"content": f"doc {i}"}})
              mock_bulk.assert_not_awaited()
              await s.index_done_callback()
              mock_bulk.assert_awaited_once()
              actions = mock_bulk.call_args.args[1]
              assert len(actions) == 5
              assert {a["_id"] for a in actions} == {f"k{i}" for i in range(5)}

      @pytest.mark.asyncio
      async def test_kv_upsert_overwrites_pending_doc_for_same_id(
          self, global_config, embed_func, mock_client
      ):
          """Upserting the same id twice keeps only the latest payload."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (1, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "first"}})
              await s.upsert({"k1": {"content": "second"}})
              await s.index_done_callback()
              actions = mock_bulk.call_args.args[1]
              assert len(actions) == 1
              assert actions[0]["_source"]["content"] == "second"

      @pytest.mark.asyncio
      async def test_kv_delete_cancels_pending_upsert(
          self, global_config, embed_func, mock_client
      ):
          """A delete after a buffered upsert removes the upsert from the buffer.

          Without this, the flush would re-index the doc and silently
          resurrect a logically-deleted key.
          """
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (1, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "doomed"}})
              await s.delete(["k1"])
              assert "k1" not in s._pending_upserts
              assert "k1" in s._pending_kv_deletes
              await s.index_done_callback()
              actions = mock_bulk.call_args.args[1]
              assert len(actions) == 1
              assert actions[0]["_op_type"] == "delete"

      @pytest.mark.asyncio
      async def test_kv_upsert_cancels_pending_delete(
          self, global_config, embed_func, mock_client
      ):
          """An upsert after a buffered delete removes the tombstone."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (1, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.delete(["k1"])
              await s.upsert({"k1": {"content": "resurrected"}})
              assert "k1" not in s._pending_kv_deletes
              assert "k1" in s._pending_upserts
              await s.index_done_callback()
              actions = mock_bulk.call_args.args[1]
              assert len(actions) == 1
              assert actions[0]["_op_type"] == "index"

      @pytest.mark.asyncio
      async def test_kv_delete_works_when_index_not_ready(
          self, global_config, embed_func, mock_client
      ):
          """delete() must invalidate pending upserts even if the index has
          been marked missing -- otherwise the next flush would resurrect
          the logically-deleted key.
          """
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              # Safety net only: this test never flushes.
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (1, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}})
              s._mark_index_missing()
              await s.delete(["k1"])
              # Buffer invariants hold regardless of _index_ready.
              assert "k1" not in s._pending_upserts
              assert "k1" in s._pending_kv_deletes

      @pytest.mark.asyncio
      async def test_kv_get_by_id_reads_pending_buffer(
          self, global_config, embed_func, mock_client
      ):
          """Buffered upserts are visible to get_by_id without hitting OpenSearch."""
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "buffered"}})
              doc = await s.get_by_id("k1")
              assert doc is not None
              assert doc["_id"] == "k1"
              assert doc["content"] == "buffered"
              mock_client.mget.assert_not_awaited()

      @pytest.mark.asyncio
      async def test_kv_get_by_id_returns_none_for_pending_delete(
          self, global_config, embed_func, mock_client
      ):
          """A pending tombstone shadows any persisted doc, without mget RTT."""
          mock_client.mget = AsyncMock()
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.delete(["k1"])
              assert await s.get_by_id("k1") is None
              mock_client.mget.assert_not_awaited()

      @pytest.mark.asyncio
      async def test_kv_get_by_id_strips_mirrored_id_from_buffer_path(
          self, global_config, embed_func, mock_client
      ):
          """Buffered docs internally carry __mirrored_id (used for PIT sort);
          the returned dict must NOT expose it, matching the mget read path."""
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}})
              # Sanity: the buffer entry itself carries __mirrored_id.
              assert s._pending_upserts["k1"]["__mirrored_id"] == "k1"
              doc = await s.get_by_id("k1")
              assert doc is not None
              assert "__mirrored_id" not in doc
              assert doc["_id"] == "k1"

      @pytest.mark.asyncio
      async def test_kv_get_by_ids_merges_buffer_and_mget(
          self, global_config, embed_func, mock_client
      ):
          """get_by_ids returns buffered docs and falls back to mget for the rest."""
          mock_client.mget = AsyncMock(
              return_value={
                  "docs": [
                      {
                          "_id": "k2",
                          "found": True,
                          "_source": {"content": "from_index"},
                      },
                  ]
              }
          )
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "buffered"}})
              docs = await s.get_by_ids(["k1", "k2"])
              assert docs[0]["content"] == "buffered"
              assert "__mirrored_id" not in docs[0]
              assert docs[1]["content"] == "from_index"
              mock_client.mget.assert_awaited_once_with(
                  index=s._index_name, body={"ids": ["k2"]}
              )

      @pytest.mark.asyncio
      async def test_kv_filter_keys_excludes_buffered_upserts(
          self, global_config, embed_func, mock_client
      ):
          """Buffered upserts shadow OpenSearch: filter_keys treats them as
          existing and never queries them via mget."""
          mock_client.mget = AsyncMock(
              return_value={"docs": [{"_id": "k2", "found": False}]}
          )
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}})
              missing = await s.filter_keys({"k1", "k2"})
              assert missing == {"k2"}
              # Only the unbuffered id is queried server-side, in a single
              # mget. assert_awaited_once gives a clear failure if mget was
              # skipped or repeated.
              mock_client.mget.assert_awaited_once()
              assert mock_client.mget.await_args.kwargs["body"] == {"ids": ["k2"]}

      @pytest.mark.asyncio
      async def test_kv_filter_keys_treats_buffered_deletes_as_missing(
          self, global_config, embed_func, mock_client
      ):
          """A persisted-but-pending-delete key must be reported as missing
          AND must NOT be looked up via mget (otherwise the still-persisted
          row would be misclassified as existing)."""
          mock_client.mget = AsyncMock(
              return_value={"docs": [{"_id": "k3", "found": True}]}
          )
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.delete(["k1"])  # tombstone
              missing = await s.filter_keys({"k1", "k3"})
              assert "k1" in missing  # tombstoned key counts as missing
              assert "k3" not in missing  # exists on server
              # The tombstone id was NOT sent to mget; exactly one lookup ran.
              mock_client.mget.assert_awaited_once()
              assert mock_client.mget.await_args.kwargs["body"] == {"ids": ["k3"]}

      @pytest.mark.asyncio
      async def test_kv_is_empty_returns_false_with_pending_upsert(
          self, global_config, embed_func, mock_client
      ):
          """is_empty short-circuits to False when the buffer has pending
          upserts -- avoiding the counterintuitive "I just upserted but
          is_empty returned True" outcome."""
          with patch.object(ClientManager, "get_client", return_value=mock_client):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}})
              assert await s.is_empty() is False
              mock_client.count.assert_not_awaited()

      @pytest.mark.asyncio
      async def test_kv_finalize_flushes_pending(
          self, global_config, embed_func, mock_client
      ):
          """finalize() flushes the buffer before releasing the client."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (1, [])
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "to flush"}})
              await s.finalize()
              mock_bulk.assert_awaited_once()
              assert s.client is None

      @pytest.mark.asyncio
      async def test_kv_finalize_raises_when_retryable_buffer_remains(
          self, global_config, embed_func, mock_client
      ):
          """finalize() must raise RuntimeError if retryable bulk failures
          left rows buffered.

          Otherwise the upstream finalize_storages() call would log the
          storage as successfully finalized while writes are silently lost.
          The client is still released so we don't leak a connection on
          shutdown.
          """
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch.object(
                  ClientManager, "release_client", new_callable=AsyncMock
              ) as mock_release,
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              # 503 is retryable; flush keeps it in the buffer.
              mock_bulk.return_value = (
                  0,
                  [{"index": {"_id": "k1", "status": 503, "error": "down"}}],
              )
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "stuck"}})
              with pytest.raises(RuntimeError, match="pending upserts"):
                  await s.finalize()
              # Client released regardless of the failure.
              mock_release.assert_awaited_once()
              assert s.client is None

      @pytest.mark.asyncio
      async def test_kv_finalize_propagates_flush_exception(
          self, global_config, embed_func, mock_client
      ):
          """If async_bulk itself raises, finalize() still releases the client.

          The original error is wrapped in a RuntimeError that names the
          unflushed buffer counts.
          """
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch.object(
                  ClientManager, "release_client", new_callable=AsyncMock
              ) as mock_release,
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.side_effect = OpenSearchException("connection reset")
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "stuck"}})
              with pytest.raises(RuntimeError) as exc_info:
                  await s.finalize()
              # Wrapped: cause is the original OpenSearchException.
              assert isinstance(exc_info.value.__cause__, OpenSearchException)
              mock_release.assert_awaited_once()
              assert s.client is None

      @pytest.mark.asyncio
      async def test_kv_finalize_propagates_cancellation(
          self, global_config, embed_func, mock_client
      ):
          """asyncio.CancelledError raised during the final flush must
          propagate UN-wrapped so the shutdown sequence honours the
          cancellation signal.

          The client is still released (finally block) before the
          cancellation continues.
          """
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch.object(
                  ClientManager, "release_client", new_callable=AsyncMock
              ) as mock_release,
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.side_effect = asyncio.CancelledError()
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "stuck"}})
              with pytest.raises(asyncio.CancelledError):
                  await s.finalize()
              # finally block still released the client.
              mock_release.assert_awaited_once()
              assert s.client is None

      @pytest.mark.asyncio
      async def test_kv_drop_discards_buffers_and_serialises_with_flush(
          self, global_config, embed_func, mock_client
      ):
          """drop() drops both buffers and is serialised with any in-flight
          flush so indices.delete cannot land mid-bulk.

          async_bulk is replaced by a coroutine that waits on an event. That
          holds the flush open while drop() is scheduled. indices.delete must
          not start until the flush is released.
          """
          flush_started = asyncio.Event()
          flush_can_finish = asyncio.Event()
          drop_delete_started = asyncio.Event()

          async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
              flush_started.set()
              await flush_can_finish.wait()
              return (len(actions), [])

          async def watch_indices_delete(**kwargs):
              drop_delete_started.set()

          mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new=slow_bulk),
          ):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}})
              await s.delete(["k2"])
              flush_task = asyncio.create_task(s.index_done_callback())
              await flush_started.wait()
              drop_task = asyncio.create_task(s.drop())
              # Yield several times so drop_task runs as far as it can.
              for _ in range(5):
                  await asyncio.sleep(0)
              assert (
                  not drop_delete_started.is_set()
              ), "indices.delete should be blocked behind the flush lock"
              assert not drop_task.done()
              flush_can_finish.set()
              await flush_task
              await drop_task
              assert drop_delete_started.is_set()
              # The flush drained k1/k2 before drop() got the lock, so both
              # buffers were already empty when drop() ran. This only confirms
              # drop() leaves them empty and copes with an empty buffer.
              assert s._pending_upserts == {}
              assert s._pending_kv_deletes == set()

      @pytest.mark.asyncio
      async def test_kv_failed_flush_retains_retryable(
          self, global_config, embed_func, mock_client
      ):
          """Transient (5xx) per-doc failures stay buffered for the next flush."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (
                  1,
                  [{"index": {"_id": "k2", "status": 503, "error": "down"}}],
              )
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "ok"}, "k2": {"content": "boom"}})
              await s.index_done_callback()
              assert "k1" not in s._pending_upserts
              assert "k2" in s._pending_upserts

      @pytest.mark.asyncio
      async def test_kv_failed_flush_drops_non_retryable(
          self, global_config, embed_func, mock_client
      ):
          """Permanent (4xx, e.g. mapping error) failures are cleared from
          the buffer rather than retried forever."""
          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new_callable=AsyncMock) as mock_bulk,
          ):
              mock_bulk.return_value = (
                  0,
                  [
                      {
                          "index": {
                              "_id": "k1",
                              "status": 400,
                              "error": {
                                  "type": "mapper_parsing_exception",
                                  "reason": "bad",
                              },
                          }
                      },
                      {"index": {"_id": "k2", "status": 503, "error": "down"}},
                  ],
              )
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "x"}, "k2": {"content": "y"}})
              await s.index_done_callback()
              assert "k1" not in s._pending_upserts
              assert "k2" in s._pending_upserts

      @pytest.mark.asyncio
      async def test_kv_concurrent_upsert_during_flush_blocked(
          self, global_config, embed_func, mock_client
      ):
          """A concurrent upsert that arrives while async_bulk is in flight is
          blocked by the namespace lock. It lands in the buffer only after
          the flush completes."""
          flush_started = asyncio.Event()
          flush_can_finish = asyncio.Event()

          async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
              flush_started.set()
              await flush_can_finish.wait()
              return (len(actions), [])

          with (
              patch.object(ClientManager, "get_client", return_value=mock_client),
              patch(self._BULK, new=slow_bulk),
          ):
              s = self._make(global_config, embed_func)
              await s.initialize()
              await s.upsert({"k1": {"content": "first"}})
              flush_task = asyncio.create_task(s.index_done_callback())
              await flush_started.wait()
              concurrent_task = asyncio.create_task(
                  s.upsert({"k2": {"content": "concurrent"}})
              )
              # Yield several times so the concurrent upsert runs as far as it can.
              for _ in range(5):
                  await asyncio.sleep(0)
              assert (
                  not concurrent_task.done()
              ), "concurrent upsert should be blocked by the flush lock"
              assert "k2" not in s._pending_upserts
              flush_can_finish.set()
              await flush_task
              await concurrent_task
              # k1 flushed and cleared; k2 added after flush released.
              assert "k1" not in s._pending_upserts
              assert "k2" in s._pending_upserts
  1049. # ---------------------------------------------------------------------------
  1050. # DocStatus Storage
  1051. # ---------------------------------------------------------------------------
  1052. class TestDocStatusStorage:
  1053. """Tests for OpenSearchDocStatusStorage including aggregations, pagination, and data normalization."""
  def _make(self, global_config, embed_func, workspace="test"):
      """Build an OpenSearchDocStatusStorage bound to the ``doc_status`` namespace."""
      storage_kwargs = dict(
          namespace="doc_status",
          global_config=global_config,
          embedding_func=embed_func,
          workspace=workspace,
      )
      return OpenSearchDocStatusStorage(**storage_kwargs)
  @pytest.mark.asyncio
  async def test_index_name(self, global_config, embed_func):
      """With the default ``test`` workspace the index is ``test_doc_status``."""
      storage = self._make(global_config, embed_func)
      expected_name = "test_doc_status"
      assert storage._index_name == expected_name
  @pytest.mark.asyncio
  async def test_initialize_creates_index(
      self, global_config, embed_func, mock_client
  ):
      """initialize() issues exactly one indices.create call.

      NOTE(review): relies on the ``mock_client`` fixture reporting the index
      as absent by default (not visible here).
      """
      client_patch = patch.object(
          ClientManager, "get_client", return_value=mock_client
      )
      with client_patch:
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          mock_client.indices.create.assert_awaited_once()
  @pytest.mark.asyncio
  async def test_get_by_id(self, global_config, embed_func, mock_client):
      """get_by_id reads a doc-status record through a single-id mget."""
      status_doc = {
          "_id": "doc-abc",
          "found": True,
          "_source": {"status": "processed", "file_path": "/a.txt"},
      }
      mock_client.mget = AsyncMock(return_value={"docs": [status_doc]})
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          record = await storage.get_by_id("doc-abc")
          assert record["status"] == "processed"
          assert record["_id"] == "doc-abc"
          mock_client.mget.assert_awaited_once_with(
              index=storage._index_name, body={"ids": ["doc-abc"]}
          )
  @pytest.mark.asyncio
  async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
      """A record that mget reports as not found yields None, with no ``get`` fallback."""
      not_found_response = {"docs": [{"_id": "missing", "found": False}]}
      mock_client.mget = AsyncMock(return_value=not_found_response)
      with patch.object(ClientManager, "get_client", return_value=mock_client):
          storage = self._make(global_config, embed_func)
          await storage.initialize()
          record = await storage.get_by_id("missing")
          assert record is None
          mock_client.get.assert_not_awaited()
  1105. @pytest.mark.asyncio
  1106. async def test_upsert_sets_chunks_list_default(
  1107. self, global_config, embed_func, mock_client
  1108. ):
  1109. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1110. with patch(
  1111. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  1112. ) as mock_bulk:
  1113. mock_bulk.return_value = (1, [])
  1114. s = self._make(global_config, embed_func)
  1115. await s.initialize()
  1116. await s.upsert({"d1": {"status": "pending"}})
  1117. actions = mock_bulk.call_args[0][1]
  1118. assert actions[0]["_source"]["chunks_list"] == []
  1119. @pytest.mark.asyncio
  1120. async def test_get_status_counts(self, global_config, embed_func, mock_client):
  1121. mock_client.search = AsyncMock(
  1122. return_value={
  1123. "hits": {"hits": [], "total": {"value": 0}},
  1124. "aggregations": {
  1125. "status_counts": {
  1126. "buckets": [
  1127. {"key": "processed", "doc_count": 3},
  1128. {"key": "pending", "doc_count": 1},
  1129. ]
  1130. }
  1131. },
  1132. }
  1133. )
  1134. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1135. s = self._make(global_config, embed_func)
  1136. await s.initialize()
  1137. counts = await s.get_status_counts()
  1138. assert counts == {"processed": 3, "pending": 1}
  1139. @pytest.mark.asyncio
  1140. async def test_get_all_status_counts_includes_all(
  1141. self, global_config, embed_func, mock_client
  1142. ):
  1143. mock_client.search = AsyncMock(
  1144. return_value={
  1145. "hits": {"hits": [], "total": {"value": 0}},
  1146. "aggregations": {
  1147. "status_counts": {
  1148. "buckets": [
  1149. {"key": "processed", "doc_count": 5},
  1150. {"key": "failed", "doc_count": 2},
  1151. ]
  1152. }
  1153. },
  1154. }
  1155. )
  1156. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1157. s = self._make(global_config, embed_func)
  1158. await s.initialize()
  1159. counts = await s.get_all_status_counts()
  1160. assert counts["all"] == 7
  1161. assert counts["processed"] == 5
  1162. @pytest.mark.asyncio
  1163. async def test_get_docs_by_status(self, global_config, embed_func, mock_client):
  1164. mock_client.search = AsyncMock(
  1165. return_value={
  1166. "hits": {
  1167. "hits": [
  1168. {
  1169. "_id": "d1",
  1170. "_source": {
  1171. "status": "processed",
  1172. "file_path": "/a.txt",
  1173. "content_summary": "s",
  1174. "content_length": 10,
  1175. "chunks_count": 1,
  1176. "created_at": 100,
  1177. "updated_at": 200,
  1178. },
  1179. "sort": ["d1"],
  1180. },
  1181. ],
  1182. "total": {"value": 1},
  1183. },
  1184. }
  1185. )
  1186. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1187. s = self._make(global_config, embed_func)
  1188. await s.initialize()
  1189. result = await s.get_docs_by_status(DocStatus.PROCESSED)
  1190. assert "d1" in result
  1191. assert isinstance(result["d1"], DocProcessingStatus)
  1192. @pytest.mark.asyncio
  1193. async def test_get_docs_paginated(self, global_config, embed_func, mock_client):
  1194. """Page 1 returns results directly without search_after."""
  1195. mock_client.count = AsyncMock(return_value={"count": 50})
  1196. mock_client.search = AsyncMock(
  1197. return_value={
  1198. "hits": {
  1199. "hits": [
  1200. {
  1201. "_id": "d1",
  1202. "_source": {
  1203. "status": "processed",
  1204. "file_path": "/a.txt",
  1205. "content_summary": "s",
  1206. "content_length": 10,
  1207. "chunks_count": 1,
  1208. "created_at": 100,
  1209. "updated_at": 200,
  1210. },
  1211. "sort": [200, "d1"],
  1212. },
  1213. ],
  1214. "total": {"value": 50},
  1215. },
  1216. }
  1217. )
  1218. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1219. s = self._make(global_config, embed_func)
  1220. await s.initialize()
  1221. docs, total = await s.get_docs_paginated(page=1, page_size=10)
  1222. assert total == 50
  1223. assert len(docs) == 1
  1224. assert docs[0][0] == "d1"
  1225. # Page 1: no search_after needed, single search call
  1226. assert mock_client.search.await_count == 1
  1227. body = mock_client.search.call_args.kwargs.get(
  1228. "body"
  1229. ) or mock_client.search.call_args[1].get("body", {})
  1230. assert "search_after" not in body
  1231. @pytest.mark.asyncio
  1232. async def test_get_docs_paginated_page2_uses_search_after(
  1233. self, global_config, embed_func, mock_client
  1234. ):
  1235. """Page 2 skips page 1 results via search_after."""
  1236. mock_client.count = AsyncMock(return_value={"count": 50})
  1237. call_count = {"n": 0}
  1238. async def search_side_effect(*args, **kwargs):
  1239. call_count["n"] += 1
  1240. body = kwargs.get("body", {})
  1241. if "search_after" not in body:
  1242. # First call: skip batch
  1243. return {
  1244. "hits": {
  1245. "hits": [
  1246. {
  1247. "_id": f"skip{i}",
  1248. "_source": {
  1249. "status": "processed",
  1250. "file_path": f"/{i}.txt",
  1251. "content_summary": "s",
  1252. "content_length": 1,
  1253. "chunks_count": 1,
  1254. "created_at": 100,
  1255. "updated_at": 100 + i,
  1256. },
  1257. "sort": [100 + i, f"skip{i}"],
  1258. }
  1259. for i in range(10)
  1260. ],
  1261. "total": {"value": 50},
  1262. }
  1263. }
  1264. else:
  1265. # Second call: actual page
  1266. return {
  1267. "hits": {
  1268. "hits": [
  1269. {
  1270. "_id": "page2_doc",
  1271. "_source": {
  1272. "status": "pending",
  1273. "file_path": "/p2.txt",
  1274. "content_summary": "s",
  1275. "content_length": 1,
  1276. "chunks_count": 1,
  1277. "created_at": 200,
  1278. "updated_at": 300,
  1279. },
  1280. "sort": [300, "page2_doc"],
  1281. }
  1282. ],
  1283. "total": {"value": 50},
  1284. }
  1285. }
  1286. mock_client.search = AsyncMock(side_effect=search_side_effect)
  1287. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1288. s = self._make(global_config, embed_func)
  1289. await s.initialize()
  1290. docs, total = await s.get_docs_paginated(page=2, page_size=10)
  1291. assert total == 50
  1292. assert len(docs) == 1
  1293. assert docs[0][0] == "page2_doc"
  1294. # 2 search calls: 1 skip + 1 fetch
  1295. assert mock_client.search.await_count == 2
  1296. @pytest.mark.asyncio
  1297. async def test_get_docs_paginated_empty_index(
  1298. self, global_config, embed_func, mock_client
  1299. ):
  1300. """Empty index returns empty list with total 0."""
  1301. mock_client.count = AsyncMock(return_value={"count": 0})
  1302. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1303. s = self._make(global_config, embed_func)
  1304. await s.initialize()
  1305. docs, total = await s.get_docs_paginated(page=1, page_size=10)
  1306. assert total == 0
  1307. assert docs == []
  1308. mock_client.search.assert_not_awaited()
  1309. @pytest.mark.asyncio
  1310. async def test_get_docs_paginated_page_beyond_total(
  1311. self, global_config, embed_func, mock_client
  1312. ):
  1313. """Requesting a page beyond total docs returns empty list."""
  1314. mock_client.count = AsyncMock(return_value={"count": 5})
  1315. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1316. s = self._make(global_config, embed_func)
  1317. await s.initialize()
  1318. docs, total = await s.get_docs_paginated(page=100, page_size=10)
  1319. assert total == 5
  1320. assert docs == []
  1321. @pytest.mark.asyncio
  1322. async def test_get_docs_paginated_with_status_filter(
  1323. self, global_config, embed_func, mock_client
  1324. ):
  1325. """Status filter is passed as term query."""
  1326. mock_client.count = AsyncMock(return_value={"count": 3})
  1327. mock_client.search = AsyncMock(
  1328. return_value={
  1329. "hits": {"hits": [], "total": {"value": 3}},
  1330. }
  1331. )
  1332. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1333. s = self._make(global_config, embed_func)
  1334. await s.initialize()
  1335. docs, total = await s.get_docs_paginated(
  1336. status_filter=DocStatus.PROCESSED, page=1, page_size=10
  1337. )
  1338. assert total == 3
  1339. # Verify count query used the status filter
  1340. count_body = mock_client.count.call_args.kwargs.get("body", {})
  1341. assert count_body["query"] == {"term": {"status": "processed"}}
  1342. @pytest.mark.asyncio
  1343. async def test_get_docs_paginated_with_status_filters(
  1344. self, global_config, embed_func, mock_client
  1345. ):
  1346. """Multi-status filters are passed as terms query and override status_filter."""
  1347. mock_client.count = AsyncMock(return_value={"count": 2})
  1348. mock_client.search = AsyncMock(
  1349. return_value={
  1350. "hits": {"hits": [], "total": {"value": 2}},
  1351. }
  1352. )
  1353. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1354. s = self._make(global_config, embed_func)
  1355. await s.initialize()
  1356. docs, total = await s.get_docs_paginated(
  1357. status_filter=DocStatus.PROCESSED,
  1358. status_filters=[DocStatus.PARSING, DocStatus.ANALYZING],
  1359. page=1,
  1360. page_size=10,
  1361. )
  1362. assert total == 2
  1363. assert docs == []
  1364. count_body = mock_client.count.call_args.kwargs.get("body", {})
  1365. assert count_body["query"] == {
  1366. "terms": {"status": ["analyzing", "parsing"]}
  1367. }
  1368. @pytest.mark.asyncio
  1369. async def test_get_doc_by_file_path(self, global_config, embed_func, mock_client):
  1370. mock_client.search = AsyncMock(
  1371. return_value={
  1372. "hits": {
  1373. "hits": [
  1374. {
  1375. "_id": "d1",
  1376. "_source": {
  1377. "file_path": "/test.txt",
  1378. "status": "processed",
  1379. },
  1380. },
  1381. ],
  1382. "total": {"value": 1},
  1383. },
  1384. }
  1385. )
  1386. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1387. s = self._make(global_config, embed_func)
  1388. await s.initialize()
  1389. doc = await s.get_doc_by_file_path("/test.txt")
  1390. assert doc is not None
  1391. assert doc["_id"] == "d1"
  1392. @pytest.mark.asyncio
  1393. async def test_get_doc_by_file_path_not_found(
  1394. self, global_config, embed_func, mock_client
  1395. ):
  1396. mock_client.search = AsyncMock(
  1397. return_value={
  1398. "hits": {"hits": [], "total": {"value": 0}},
  1399. }
  1400. )
  1401. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1402. s = self._make(global_config, embed_func)
  1403. await s.initialize()
  1404. assert await s.get_doc_by_file_path("/nope.txt") is None
  1405. @pytest.mark.asyncio
  1406. async def test_get_doc_by_file_basename_returns_tuple_on_hit(
  1407. self, global_config, embed_func, mock_client
  1408. ):
  1409. mock_client.search = AsyncMock(
  1410. return_value={
  1411. "hits": {
  1412. "hits": [
  1413. {
  1414. "_id": "doc-1",
  1415. "_source": {
  1416. "file_path": "report.pdf",
  1417. "status": "processed",
  1418. },
  1419. },
  1420. ],
  1421. "total": {"value": 1},
  1422. },
  1423. }
  1424. )
  1425. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1426. s = self._make(global_config, embed_func)
  1427. await s.initialize()
  1428. result = await s.get_doc_by_file_basename("report.pdf")
  1429. assert result is not None
  1430. doc_id, doc = result
  1431. assert doc_id == "doc-1"
  1432. assert doc["file_path"] == "report.pdf"
  1433. body = mock_client.search.call_args.kwargs.get(
  1434. "body"
  1435. ) or mock_client.search.call_args[1].get("body", {})
  1436. assert body["query"] == {"term": {"file_path": "report.pdf"}}
  1437. @pytest.mark.asyncio
  1438. async def test_get_doc_by_file_basename_empty_short_circuits(
  1439. self, global_config, embed_func, mock_client
  1440. ):
  1441. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1442. s = self._make(global_config, embed_func)
  1443. await s.initialize()
  1444. mock_client.search.reset_mock()
  1445. assert await s.get_doc_by_file_basename("") is None
  1446. mock_client.search.assert_not_awaited()
  1447. @pytest.mark.asyncio
  1448. async def test_get_doc_by_file_basename_unknown_source_sentinel(
  1449. self, global_config, embed_func, mock_client
  1450. ):
  1451. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1452. s = self._make(global_config, embed_func)
  1453. await s.initialize()
  1454. mock_client.search.reset_mock()
  1455. assert await s.get_doc_by_file_basename("unknown_source") is None
  1456. mock_client.search.assert_not_awaited()
  1457. @pytest.mark.asyncio
  1458. async def test_get_doc_by_file_basename_miss_returns_none(
  1459. self, global_config, embed_func, mock_client
  1460. ):
  1461. mock_client.search = AsyncMock(
  1462. return_value={"hits": {"hits": [], "total": {"value": 0}}}
  1463. )
  1464. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1465. s = self._make(global_config, embed_func)
  1466. await s.initialize()
  1467. assert await s.get_doc_by_file_basename("missing.pdf") is None
  1468. @pytest.mark.asyncio
  1469. async def test_get_doc_by_content_hash_returns_tuple_on_hit(
  1470. self, global_config, embed_func, mock_client
  1471. ):
  1472. mock_client.search = AsyncMock(
  1473. return_value={
  1474. "hits": {
  1475. "hits": [
  1476. {
  1477. "_id": "doc-1",
  1478. "_source": {
  1479. "file_path": "report.pdf",
  1480. "content_hash": "abc123",
  1481. "status": "processed",
  1482. },
  1483. },
  1484. ],
  1485. "total": {"value": 1},
  1486. },
  1487. }
  1488. )
  1489. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1490. s = self._make(global_config, embed_func)
  1491. await s.initialize()
  1492. result = await s.get_doc_by_content_hash("abc123")
  1493. assert result is not None
  1494. doc_id, doc = result
  1495. assert doc_id == "doc-1"
  1496. assert doc["content_hash"] == "abc123"
  1497. body = mock_client.search.call_args.kwargs.get(
  1498. "body"
  1499. ) or mock_client.search.call_args[1].get("body", {})
  1500. assert body["query"] == {"term": {"content_hash": "abc123"}}
  1501. @pytest.mark.asyncio
  1502. async def test_get_doc_by_content_hash_empty_short_circuits(
  1503. self, global_config, embed_func, mock_client
  1504. ):
  1505. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1506. s = self._make(global_config, embed_func)
  1507. await s.initialize()
  1508. mock_client.search.reset_mock()
  1509. assert await s.get_doc_by_content_hash("") is None
  1510. mock_client.search.assert_not_awaited()
  1511. @pytest.mark.asyncio
  1512. async def test_get_doc_by_content_hash_miss_returns_none(
  1513. self, global_config, embed_func, mock_client
  1514. ):
  1515. mock_client.search = AsyncMock(
  1516. return_value={"hits": {"hits": [], "total": {"value": 0}}}
  1517. )
  1518. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1519. s = self._make(global_config, embed_func)
  1520. await s.initialize()
  1521. assert await s.get_doc_by_content_hash("zzz999") is None
  1522. @pytest.mark.asyncio
  1523. async def test_ensure_content_hash_mapping_added_when_missing(
  1524. self, global_config, embed_func, mock_client
  1525. ):
  1526. """Pre-existing indices without content_hash mapping should get one added."""
  1527. mock_client.indices.exists = AsyncMock(return_value=True)
  1528. mock_client.indices.get_mapping = AsyncMock(
  1529. return_value={
  1530. "test_doc_status": {
  1531. "mappings": {
  1532. "properties": {
  1533. "__mirrored_id": {"type": "keyword"},
  1534. "status": {"type": "keyword"},
  1535. "file_path": {"type": "keyword"},
  1536. }
  1537. }
  1538. }
  1539. }
  1540. )
  1541. mock_client.indices.put_mapping = AsyncMock()
  1542. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1543. s = self._make(global_config, embed_func)
  1544. await s.initialize()
  1545. mock_client.indices.put_mapping.assert_awaited_once()
  1546. kwargs = mock_client.indices.put_mapping.call_args.kwargs
  1547. assert kwargs["body"] == {
  1548. "properties": {"content_hash": {"type": "keyword"}}
  1549. }
  1550. @pytest.mark.asyncio
  1551. async def test_ensure_content_hash_mapping_skipped_when_present(
  1552. self, global_config, embed_func, mock_client
  1553. ):
  1554. """Indices that already have content_hash mapping should not be touched."""
  1555. mock_client.indices.exists = AsyncMock(return_value=True)
  1556. mock_client.indices.get_mapping = AsyncMock(
  1557. return_value={
  1558. "test_doc_status": {
  1559. "mappings": {
  1560. "properties": {
  1561. "__mirrored_id": {"type": "keyword"},
  1562. "content_hash": {"type": "keyword"},
  1563. }
  1564. }
  1565. }
  1566. }
  1567. )
  1568. mock_client.indices.put_mapping = AsyncMock()
  1569. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1570. s = self._make(global_config, embed_func)
  1571. await s.initialize()
  1572. mock_client.indices.put_mapping.assert_not_awaited()
  1573. @pytest.mark.asyncio
  1574. async def test_prepare_doc_status_data(self, global_config, embed_func):
  1575. s = self._make(global_config, embed_func)
  1576. raw = {"_id": "x", "status": "processed", "error": "oops"}
  1577. data = s._prepare_doc_status_data(raw)
  1578. assert "_id" not in data
  1579. assert data["error_msg"] == "oops"
  1580. assert "error" not in data
  1581. assert data["file_path"] == "no-file-path"
  1582. assert data["metadata"] == {}
  1583. @pytest.mark.asyncio
  1584. async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index(
  1585. self, global_config, embed_func, mock_client
  1586. ):
  1587. mock_client.indices.delete = AsyncMock(
  1588. side_effect=OpenSearchException("drop failed")
  1589. )
  1590. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1591. with patch(
  1592. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  1593. ) as mock_bulk:
  1594. mock_bulk.return_value = (1, [])
  1595. s = self._make(global_config, embed_func)
  1596. await s.initialize()
  1597. with patch.object(
  1598. s, "_create_index_if_not_exists", new_callable=AsyncMock
  1599. ) as mock_create:
  1600. result = await s.drop()
  1601. assert result["status"] == "error"
  1602. assert s._index_ready is False
  1603. await s.upsert({"d1": {"status": "pending"}})
  1604. mock_create.assert_awaited_once()
  1605. @pytest.mark.asyncio
  1606. async def test_upsert_after_drop_recreates_index(
  1607. self, global_config, embed_func, mock_client
  1608. ):
  1609. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1610. with patch(
  1611. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  1612. ) as mock_bulk:
  1613. mock_bulk.return_value = (1, [])
  1614. s = self._make(global_config, embed_func)
  1615. await s.initialize()
  1616. with patch.object(
  1617. s, "_create_index_if_not_exists", new_callable=AsyncMock
  1618. ) as mock_create:
  1619. await s.drop()
  1620. await s.upsert({"d1": {"status": "pending"}})
  1621. mock_create.assert_awaited_once()
  1622. @pytest.mark.asyncio
  1623. async def test_reads_short_circuit_after_drop(
  1624. self, global_config, embed_func, mock_client
  1625. ):
  1626. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1627. s = self._make(global_config, embed_func)
  1628. await s.initialize()
  1629. await s.drop()
  1630. assert await s.get_all_status_counts() == {}
  1631. assert await s.get_docs_paginated(page=1, page_size=10) == ([], 0)
  1632. assert await s.get_doc_by_file_path("/a.txt") is None
  1633. assert await s.get_docs_by_status(DocStatus.PROCESSED) == {}
  1634. mock_client.count.assert_not_awaited()
  1635. mock_client.search.assert_not_awaited()
  1636. mock_client.create_pit.assert_not_awaited()
  1637. @pytest.mark.asyncio
  1638. async def test_read_missing_index_demotes_readiness(
  1639. self, global_config, embed_func, mock_client
  1640. ):
  1641. mock_client.search = AsyncMock(side_effect=_missing_index_error())
  1642. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1643. s = self._make(global_config, embed_func)
  1644. await s.initialize()
  1645. assert await s.get_all_status_counts() == {}
  1646. assert await s.get_all_status_counts() == {}
  1647. assert s._index_ready is False
  1648. assert mock_client.search.await_count == 1
  1649. # ---------------------------------------------------------------------------
  1650. # Graph Storage
  1651. # ---------------------------------------------------------------------------
  1652. class TestGraphStorage:
  1653. """Tests for OpenSearchGraphStorage node/edge CRUD, batch ops, BFS, and label queries."""
  1654. def _make(self, global_config, embed_func, workspace="test"):
  1655. return OpenSearchGraphStorage(
  1656. namespace="chunk_entity_relation",
  1657. global_config=global_config,
  1658. embedding_func=embed_func,
  1659. workspace=workspace,
  1660. )
  1661. @pytest.mark.asyncio
  1662. async def test_index_names(self, global_config, embed_func):
  1663. s = self._make(global_config, embed_func)
  1664. assert s._nodes_index == "test_chunk_entity_relation-nodes"
  1665. assert s._edges_index == "test_chunk_entity_relation-edges"
  1666. @pytest.mark.asyncio
  1667. async def test_initialize_creates_both_indices(
  1668. self, global_config, embed_func, mock_client
  1669. ):
  1670. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1671. s = self._make(global_config, embed_func)
  1672. await s.initialize()
  1673. assert mock_client.indices.create.await_count == 2
  1674. @pytest.mark.asyncio
  1675. async def test_has_node_true(self, global_config, embed_func, mock_client):
  1676. mock_client.exists = AsyncMock(return_value=True)
  1677. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1678. s = self._make(global_config, embed_func)
  1679. await s.initialize()
  1680. assert await s.has_node("Alice") is True
  1681. @pytest.mark.asyncio
  1682. async def test_has_node_false(self, global_config, embed_func, mock_client):
  1683. mock_client.exists = AsyncMock(return_value=False)
  1684. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1685. s = self._make(global_config, embed_func)
  1686. await s.initialize()
  1687. assert await s.has_node("Nobody") is False
  1688. @pytest.mark.asyncio
  1689. async def test_has_edge(self, global_config, embed_func, mock_client):
  1690. mock_client.search = AsyncMock(
  1691. return_value={
  1692. "hits": {"hits": [], "total": {"value": 1}},
  1693. "aggregations": {
  1694. "status_counts": {"buckets": []},
  1695. "src": {"buckets": []},
  1696. "tgt": {"buckets": []},
  1697. "source_degrees": {"buckets": []},
  1698. "target_degrees": {"buckets": []},
  1699. },
  1700. }
  1701. )
  1702. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1703. s = self._make(global_config, embed_func)
  1704. await s.initialize()
  1705. assert await s.has_edge("A", "B") is True
  1706. @pytest.mark.asyncio
  1707. async def test_node_degree(self, global_config, embed_func, mock_client):
  1708. mock_client.count = AsyncMock(return_value={"count": 3})
  1709. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1710. s = self._make(global_config, embed_func)
  1711. await s.initialize()
  1712. assert await s.node_degree("A") == 3
  1713. @pytest.mark.asyncio
  1714. async def test_get_node(self, global_config, embed_func, mock_client):
  1715. mock_client.mget = AsyncMock(
  1716. return_value={
  1717. "docs": [
  1718. {
  1719. "_id": "Alice",
  1720. "found": True,
  1721. "_source": {
  1722. "entity_type": "person",
  1723. "description": "A researcher",
  1724. },
  1725. }
  1726. ]
  1727. }
  1728. )
  1729. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1730. s = self._make(global_config, embed_func)
  1731. await s.initialize()
  1732. node = await s.get_node("Alice")
  1733. assert node["entity_type"] == "person"
  1734. assert node["_id"] == "Alice"
  1735. mock_client.mget.assert_awaited_once_with(
  1736. index=s._nodes_index, body={"ids": ["Alice"]}
  1737. )
  1738. @pytest.mark.asyncio
  1739. async def test_get_node_not_found(self, global_config, embed_func, mock_client):
  1740. mock_client.mget = AsyncMock(
  1741. return_value={"docs": [{"_id": "Nobody", "found": False}]}
  1742. )
  1743. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1744. s = self._make(global_config, embed_func)
  1745. await s.initialize()
  1746. assert await s.get_node("Nobody") is None
  1747. mock_client.get.assert_not_awaited()
  1748. @pytest.mark.asyncio
  1749. async def test_get_edge(self, global_config, embed_func, mock_client):
  1750. # get_edge now uses mget (translog real-time) instead of search.
  1751. mock_client.mget = AsyncMock(
  1752. return_value={
  1753. "docs": [
  1754. {
  1755. "_id": "e1",
  1756. "found": True,
  1757. "_source": {
  1758. "source_node_id": "A",
  1759. "target_node_id": "B",
  1760. "weight": 1.0,
  1761. },
  1762. },
  1763. {
  1764. "_id": "e2",
  1765. "found": False,
  1766. },
  1767. ]
  1768. }
  1769. )
  1770. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1771. s = self._make(global_config, embed_func)
  1772. await s.initialize()
  1773. edge = await s.get_edge("A", "B")
  1774. assert edge is not None
  1775. assert edge["weight"] == 1.0
  1776. @pytest.mark.asyncio
  1777. async def test_get_node_edges(self, global_config, embed_func, mock_client):
  1778. mock_client.search = AsyncMock(
  1779. return_value={
  1780. "hits": {
  1781. "hits": [
  1782. {
  1783. "_id": "e1",
  1784. "_source": {"source_node_id": "A", "target_node_id": "B"},
  1785. "sort": [1],
  1786. },
  1787. {
  1788. "_id": "e2",
  1789. "_source": {"source_node_id": "C", "target_node_id": "A"},
  1790. "sort": [2],
  1791. },
  1792. ],
  1793. "total": {"value": 2},
  1794. },
  1795. }
  1796. )
  1797. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1798. s = self._make(global_config, embed_func)
  1799. await s.initialize()
  1800. edges = await s.get_node_edges("A")
  1801. assert len(edges) == 2
  1802. assert ("A", "B") in edges
  1803. @pytest.mark.asyncio
  1804. async def test_get_nodes_batch(self, global_config, embed_func, mock_client):
  1805. mock_client.mget = AsyncMock(
  1806. return_value={
  1807. "docs": [
  1808. {"_id": "A", "found": True, "_source": {"entity_type": "person"}},
  1809. {"_id": "B", "found": False},
  1810. ]
  1811. }
  1812. )
  1813. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1814. s = self._make(global_config, embed_func)
  1815. await s.initialize()
  1816. result = await s.get_nodes_batch(["A", "B"])
  1817. assert "A" in result
  1818. assert "B" not in result
  1819. @pytest.mark.asyncio
  1820. async def test_node_degrees_batch(self, global_config, embed_func, mock_client):
  1821. mock_client.search = AsyncMock(
  1822. return_value={
  1823. "hits": {"hits": [], "total": {"value": 0}},
  1824. "aggregations": {
  1825. "source_degrees": {"buckets": [{"key": "A", "doc_count": 2}]},
  1826. "target_degrees": {
  1827. "buckets": [
  1828. {"key": "A", "doc_count": 1},
  1829. {"key": "B", "doc_count": 3},
  1830. ]
  1831. },
  1832. "status_counts": {"buckets": []},
  1833. "src": {"buckets": []},
  1834. "tgt": {"buckets": []},
  1835. },
  1836. }
  1837. )
  1838. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1839. s = self._make(global_config, embed_func)
  1840. await s.initialize()
  1841. degrees = await s.node_degrees_batch(["A", "B"])
  1842. assert degrees["A"] == 3 # 2 + 1
  1843. assert degrees["B"] == 3
  1844. @pytest.mark.asyncio
  1845. async def test_upsert_node(self, global_config, embed_func, mock_client):
  1846. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1847. s = self._make(global_config, embed_func)
  1848. await s.initialize()
  1849. await s.upsert_node(
  1850. "Alice", {"entity_type": "person", "source_id": "c1<SEP>c2"}
  1851. )
  1852. mock_client.index.assert_awaited()
  1853. call_kwargs = mock_client.index.call_args
  1854. assert call_kwargs.kwargs["id"] == "Alice"
  1855. body = call_kwargs.kwargs["body"]
  1856. assert body["source_ids"] == ["c1", "c2"]
  1857. assert body["entity_id"] == "Alice"
  1858. @pytest.mark.asyncio
  1859. async def test_upsert_edge(self, global_config, embed_func, mock_client):
  1860. mock_client.exists = AsyncMock(return_value=False)
  1861. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1862. s = self._make(global_config, embed_func)
  1863. await s.initialize()
  1864. await s.upsert_edge("A", "B", {"weight": "1.0", "description": "knows"})
  1865. # Should call index twice: once for ensuring source node, once for edge
  1866. assert mock_client.index.await_count == 2
  1867. @pytest.mark.asyncio
  1868. async def test_upsert_edges_batch_reuses_id_for_reciprocal_edges(
  1869. self, global_config, embed_func, mock_client
  1870. ):
  1871. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1872. s = self._make(global_config, embed_func)
  1873. await s.initialize()
  1874. bulk_calls = []
  1875. async def capture_bulk(_client, actions, *args, **kwargs):
  1876. bulk_calls.append(list(actions))
  1877. return (len(bulk_calls[-1]), [])
  1878. mock_client.mget = AsyncMock(
  1879. side_effect=[
  1880. {"docs": []},
  1881. {"docs": [{"_id": "edge-ba", "found": False}] * 2},
  1882. ]
  1883. )
  1884. with patch(
  1885. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  1886. new=AsyncMock(side_effect=capture_bulk),
  1887. ):
  1888. await s.upsert_edges_batch(
  1889. [
  1890. ("A", "B", {"weight": "1.0"}),
  1891. ("B", "A", {"weight": "2.0"}),
  1892. ]
  1893. )
  1894. edge_actions = bulk_calls[-1]
  1895. assert len(edge_actions) == 2
  1896. assert edge_actions[0]["_id"] == edge_actions[1]["_id"]
  1897. @pytest.mark.asyncio
  1898. async def test_upsert_after_drop_recreates_indices(
  1899. self, global_config, embed_func, mock_client
  1900. ):
  1901. mock_client.exists = AsyncMock(return_value=False)
  1902. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1903. s = self._make(global_config, embed_func)
  1904. with patch.object(
  1905. s, "_create_indices_if_not_exist", new_callable=AsyncMock
  1906. ) as mock_create:
  1907. await s.initialize()
  1908. mock_create.reset_mock()
  1909. await s.drop()
  1910. await s.upsert_edge("A", "B", {"weight": "1.0"})
  1911. mock_create.assert_awaited_once()
  1912. assert mock_client.index.await_count == 2
  1913. @pytest.mark.asyncio
  1914. async def test_reads_short_circuit_after_drop(
  1915. self, global_config, embed_func, mock_client
  1916. ):
  1917. mock_client.transport = AsyncMock()
  1918. mock_client.transport.perform_request = AsyncMock(
  1919. side_effect=Exception("PPL not available")
  1920. )
  1921. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1922. s = self._make(global_config, embed_func)
  1923. await s.initialize()
  1924. await s.drop()
  1925. graph = await s.get_knowledge_graph("A", max_depth=2)
  1926. assert await s.get_node("A") is None
  1927. assert await s.get_all_labels() == []
  1928. assert await s.has_edge("A", "B") is False
  1929. assert await s.node_degree("A") == 0
  1930. assert len(graph.nodes) == 0
  1931. assert len(graph.edges) == 0
  1932. mock_client.mget.assert_not_awaited()
  1933. mock_client.search.assert_not_awaited()
  1934. mock_client.create_pit.assert_not_awaited()
  1935. mock_client.count.assert_not_awaited()
  1936. @pytest.mark.asyncio
  1937. async def test_read_missing_index_demotes_readiness(
  1938. self, global_config, embed_func, mock_client
  1939. ):
  1940. mock_client.transport = AsyncMock()
  1941. mock_client.transport.perform_request = AsyncMock(
  1942. side_effect=Exception("PPL not available")
  1943. )
  1944. mock_client.mget = AsyncMock(side_effect=_missing_index_error())
  1945. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1946. s = self._make(global_config, embed_func)
  1947. await s.initialize()
  1948. assert await s.get_node("A") is None
  1949. assert await s.get_node("A") is None
  1950. assert s._indices_ready is False
  1951. assert mock_client.mget.await_count == 1
  1952. @pytest.mark.asyncio
  1953. async def test_delete_node(self, global_config, embed_func, mock_client):
  1954. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1955. s = self._make(global_config, embed_func)
  1956. await s.initialize()
  1957. await s.delete_node("Alice")
  1958. mock_client.delete_by_query.assert_awaited_once()
  1959. mock_client.delete.assert_awaited_once()
  1960. @pytest.mark.asyncio
  1961. async def test_remove_nodes(self, global_config, embed_func, mock_client):
  1962. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1963. with patch(
  1964. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  1965. ) as mock_bulk:
  1966. mock_bulk.return_value = (2, [])
  1967. s = self._make(global_config, embed_func)
  1968. await s.initialize()
  1969. await s.remove_nodes(["A", "B"])
  1970. mock_client.delete_by_query.assert_awaited_once()
  1971. mock_bulk.assert_awaited_once()
  1972. @pytest.mark.asyncio
  1973. async def test_remove_edges(self, global_config, embed_func, mock_client):
  1974. # remove_edges now uses bulk delete with deterministic IDs instead of
  1975. # delete_by_query, so mock bulk as AsyncMock.
  1976. mock_client.bulk = AsyncMock(return_value={"errors": False, "items": []})
  1977. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1978. s = self._make(global_config, embed_func)
  1979. await s.initialize()
  1980. await s.remove_edges([("A", "B"), ("C", "D")])
  1981. # 2 edges × 2 candidate directions = 4 delete actions in one bulk call
  1982. mock_client.bulk.assert_awaited_once()
  1983. call_body = mock_client.bulk.call_args.kwargs["body"]
  1984. assert len(call_body) == 4
  1985. @pytest.mark.asyncio
  1986. async def test_get_all_labels(self, global_config, embed_func, mock_client):
  1987. mock_client.search = AsyncMock(
  1988. return_value={
  1989. "hits": {
  1990. "hits": [
  1991. {"_id": "Alice", "sort": ["Alice"]},
  1992. {"_id": "Bob", "sort": ["Bob"]},
  1993. ],
  1994. "total": {"value": 2},
  1995. },
  1996. }
  1997. )
  1998. with patch.object(ClientManager, "get_client", return_value=mock_client):
  1999. s = self._make(global_config, embed_func)
  2000. await s.initialize()
  2001. labels = await s.get_all_labels()
  2002. assert labels == ["Alice", "Bob"]
  2003. @pytest.mark.asyncio
  2004. async def test_get_popular_labels(self, global_config, embed_func, mock_client):
  2005. mock_client.search = AsyncMock(
  2006. return_value={
  2007. "hits": {"hits": [], "total": {"value": 0}},
  2008. "aggregations": {
  2009. "src": {
  2010. "buckets": [
  2011. {"key": "A", "doc_count": 5},
  2012. {"key": "B", "doc_count": 2},
  2013. ]
  2014. },
  2015. "tgt": {"buckets": [{"key": "A", "doc_count": 3}]},
  2016. "status_counts": {"buckets": []},
  2017. },
  2018. }
  2019. )
  2020. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2021. s = self._make(global_config, embed_func)
  2022. await s.initialize()
  2023. labels = await s.get_popular_labels(limit=10)
  2024. assert labels[0] == "A" # degree 8 > B degree 2
  2025. @pytest.mark.asyncio
  2026. async def test_get_knowledge_graph_all_backfills_isolated_nodes_when_truncated(
  2027. self, global_config, embed_func, mock_client
  2028. ):
  2029. mock_client.count = AsyncMock(return_value={"count": 5})
  2030. mock_client.search = AsyncMock(
  2031. side_effect=[
  2032. {
  2033. "hits": {"hits": [], "total": {"value": 1}},
  2034. "aggregations": {
  2035. "src": {"buckets": [{"key": "A", "doc_count": 1}]},
  2036. "tgt": {"buckets": [{"key": "B", "doc_count": 1}]},
  2037. "status_counts": {"buckets": []},
  2038. },
  2039. },
  2040. {
  2041. "hits": {
  2042. "hits": [
  2043. {"_id": "A", "sort": [1]},
  2044. {"_id": "B", "sort": [2]},
  2045. {"_id": "C", "sort": [3]},
  2046. {"_id": "D", "sort": [4]},
  2047. {"_id": "E", "sort": [5]},
  2048. ],
  2049. "total": {"value": 5},
  2050. }
  2051. },
  2052. {
  2053. "hits": {
  2054. "hits": [
  2055. {
  2056. "_id": "edge-ab",
  2057. "_source": {
  2058. "source_node_id": "A",
  2059. "target_node_id": "B",
  2060. "relationship": "knows",
  2061. },
  2062. }
  2063. ],
  2064. "total": {"value": 1},
  2065. }
  2066. },
  2067. ]
  2068. )
  2069. mock_client.mget = AsyncMock(
  2070. return_value={
  2071. "docs": [
  2072. {"_id": "A", "found": True, "_source": {"entity_type": "person"}},
  2073. {"_id": "B", "found": True, "_source": {"entity_type": "person"}},
  2074. {"_id": "C", "found": True, "_source": {"entity_type": "person"}},
  2075. {"_id": "D", "found": True, "_source": {"entity_type": "person"}},
  2076. ]
  2077. }
  2078. )
  2079. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2080. s = self._make(global_config, embed_func)
  2081. await s.initialize()
  2082. result = await s.get_knowledge_graph("*", max_nodes=4)
  2083. assert result.is_truncated is True
  2084. assert [node.id for node in result.nodes] == ["A", "B", "C", "D"]
  2085. assert len(result.edges) == 1
  2086. assert result.edges[0].source == "A"
  2087. assert result.edges[0].target == "B"
  2088. assert mock_client.create_pit.await_count == 2
  2089. @pytest.mark.asyncio
  2090. async def test_get_knowledge_graph_all_paginates_edges_between_selected_nodes(
  2091. self, global_config, embed_func, mock_client
  2092. ):
  2093. mock_client.count = AsyncMock(return_value={"count": 2})
  2094. first_edge_page = [
  2095. {
  2096. "_id": f"edge-{i}",
  2097. "_source": {
  2098. "source_node_id": "A",
  2099. "target_node_id": "B",
  2100. "relationship": "knows",
  2101. },
  2102. "sort": [i],
  2103. }
  2104. for i in range(10000)
  2105. ]
  2106. mock_client.search = AsyncMock(
  2107. side_effect=[
  2108. {
  2109. "hits": {
  2110. "hits": [
  2111. {"_id": "A"},
  2112. {"_id": "B"},
  2113. ],
  2114. "total": {"value": 2},
  2115. }
  2116. },
  2117. {"hits": {"hits": first_edge_page, "total": {"value": 10001}}},
  2118. {
  2119. "hits": {
  2120. "hits": [
  2121. {
  2122. "_id": "edge-last",
  2123. "_source": {
  2124. "source_node_id": "B",
  2125. "target_node_id": "A",
  2126. "relationship": "knows",
  2127. },
  2128. "sort": [10000],
  2129. }
  2130. ],
  2131. "total": {"value": 10001},
  2132. }
  2133. },
  2134. ]
  2135. )
  2136. mock_client.mget = AsyncMock(
  2137. return_value={
  2138. "docs": [
  2139. {"_id": "A", "found": True, "_source": {"entity_type": "person"}},
  2140. {"_id": "B", "found": True, "_source": {"entity_type": "person"}},
  2141. ]
  2142. }
  2143. )
  2144. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2145. s = self._make(global_config, embed_func)
  2146. await s.initialize()
  2147. result = await s.get_knowledge_graph("*", max_nodes=2)
  2148. assert len(result.nodes) == 2
  2149. assert len(result.edges) == 2
  2150. assert {(edge.source, edge.target) for edge in result.edges} == {
  2151. ("A", "B"),
  2152. ("B", "A"),
  2153. }
  2154. assert mock_client.search.await_count == 3
  2155. @pytest.mark.asyncio
  2156. async def test_search_labels_empty_query(
  2157. self, global_config, embed_func, mock_client
  2158. ):
  2159. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2160. s = self._make(global_config, embed_func)
  2161. await s.initialize()
  2162. assert await s.search_labels("") == []
  2163. @pytest.mark.asyncio
  2164. async def test_drop(self, global_config, embed_func, mock_client):
  2165. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2166. s = self._make(global_config, embed_func)
  2167. await s.initialize()
  2168. result = await s.drop()
  2169. assert result["status"] == "success"
  2170. assert mock_client.indices.delete.await_count == 2
  2171. @pytest.mark.asyncio
  2172. async def test_drop_partial_error_marks_indices_not_ready_and_next_upsert_recreates_indices(
  2173. self, global_config, embed_func, mock_client
  2174. ):
  2175. mock_client.exists = AsyncMock(return_value=False)
  2176. mock_client.indices.delete = AsyncMock(
  2177. side_effect=[None, OpenSearchException("edges drop failed")]
  2178. )
  2179. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2180. s = self._make(global_config, embed_func)
  2181. await s.initialize()
  2182. with patch.object(
  2183. s, "_create_indices_if_not_exist", new_callable=AsyncMock
  2184. ) as mock_create:
  2185. result = await s.drop()
  2186. assert result["status"] == "error"
  2187. assert "edges drop failed" in result["message"]
  2188. assert s._indices_ready is False
  2189. await s.upsert_edge("A", "B", {"weight": "1.0"})
  2190. mock_create.assert_awaited_once()
  2191. @pytest.mark.asyncio
  2192. async def test_drop_treats_missing_graph_indices_as_success(
  2193. self, global_config, embed_func, mock_client
  2194. ):
  2195. mock_client.indices.delete = AsyncMock(
  2196. side_effect=[_missing_index_error(), _missing_index_error()]
  2197. )
  2198. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2199. s = self._make(global_config, embed_func)
  2200. await s.initialize()
  2201. result = await s.drop()
  2202. assert result["status"] == "success"
  2203. assert s._indices_ready is False
  2204. @pytest.mark.asyncio
  2205. async def test_construct_graph_node(self, global_config, embed_func):
  2206. s = self._make(global_config, embed_func)
  2207. node = s._construct_graph_node(
  2208. "Alice",
  2209. {
  2210. "entity_type": "person",
  2211. "description": "A researcher",
  2212. "_id": "Alice",
  2213. "entity_id": "Alice",
  2214. },
  2215. )
  2216. assert node.id == "Alice"
  2217. assert "entity_type" in node.properties
  2218. assert "_id" not in node.properties
  2219. assert "entity_id" not in node.properties
  2220. @pytest.mark.asyncio
  2221. async def test_construct_graph_edge(self, global_config, embed_func):
  2222. s = self._make(global_config, embed_func)
  2223. edge = s._construct_graph_edge(
  2224. "e1",
  2225. {
  2226. "source_node_id": "A",
  2227. "target_node_id": "B",
  2228. "relationship": "knows",
  2229. "weight": 1.0,
  2230. },
  2231. )
  2232. assert edge.source == "A"
  2233. assert edge.target == "B"
  2234. assert edge.type == "knows"
  2235. assert "source_node_id" not in edge.properties
  2236. @pytest.mark.asyncio
  2237. async def test_bfs_subgraph_start_not_found(
  2238. self, global_config, embed_func, mock_client
  2239. ):
  2240. mock_client.mget = AsyncMock(
  2241. return_value={"docs": [{"_id": "NonExistent", "found": False}]}
  2242. )
  2243. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2244. s = self._make(global_config, embed_func)
  2245. await s.initialize()
  2246. result = await s.get_knowledge_graph("NonExistent", max_depth=2)
  2247. assert len(result.nodes) == 0
  2248. assert len(result.edges) == 0
  2249. class TestGraphPPLDetection:
  2250. """Tests for PPL graphlookup detection and server-side BFS."""
  2251. def _make(self, global_config, embed_func, workspace="test"):
  2252. return OpenSearchGraphStorage(
  2253. namespace="chunk_entity_relation",
  2254. global_config=global_config,
  2255. embedding_func=embed_func,
  2256. workspace=workspace,
  2257. )
  2258. @pytest.mark.asyncio
  2259. async def test_ppl_detected_when_available(
  2260. self, global_config, embed_func, mock_client
  2261. ):
  2262. """When PPL endpoint responds successfully, graphlookup should be detected."""
  2263. mock_client.transport = AsyncMock()
  2264. mock_client.transport.perform_request = AsyncMock(
  2265. return_value={"datarows": [], "schema": []}
  2266. )
  2267. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2268. s = self._make(global_config, embed_func)
  2269. await s.initialize()
  2270. assert s._ppl_graphlookup_available is True
  2271. @pytest.mark.asyncio
  2272. async def test_ppl_not_detected_when_endpoint_fails(
  2273. self, global_config, embed_func, mock_client
  2274. ):
  2275. """When PPL endpoint fails, should fall back to client-side BFS."""
  2276. mock_client.transport = AsyncMock()
  2277. mock_client.transport.perform_request = AsyncMock(
  2278. side_effect=Exception("PPL not supported")
  2279. )
  2280. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2281. s = self._make(global_config, embed_func)
  2282. await s.initialize()
  2283. assert s._ppl_graphlookup_available is False
  2284. @pytest.mark.asyncio
  2285. async def test_env_override_true(self, global_config, embed_func, mock_client):
  2286. with patch.dict("os.environ", {"OPENSEARCH_USE_PPL_GRAPHLOOKUP": "true"}):
  2287. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2288. s = self._make(global_config, embed_func)
  2289. await s.initialize()
  2290. assert s._ppl_graphlookup_available is True
  2291. # Should NOT have called transport.perform_request for detection
  2292. mock_client.transport.perform_request.assert_not_awaited()
  2293. @pytest.mark.asyncio
  2294. async def test_env_override_false(self, global_config, embed_func, mock_client):
  2295. mock_client.transport = AsyncMock()
  2296. mock_client.transport.perform_request = AsyncMock(
  2297. return_value={"datarows": [], "schema": []}
  2298. )
  2299. with patch.dict("os.environ", {"OPENSEARCH_USE_PPL_GRAPHLOOKUP": "false"}):
  2300. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2301. s = self._make(global_config, embed_func)
  2302. await s.initialize()
  2303. assert s._ppl_graphlookup_available is False
  2304. @pytest.mark.asyncio
  2305. async def test_ppl_bfs_calls_ppl_endpoint(
  2306. self, global_config, embed_func, mock_client
  2307. ):
  2308. """When PPL is available, get_knowledge_graph should use PPL endpoint."""
  2309. mock_client.transport = AsyncMock()
  2310. # PPL response: connected_edges contains dicts with source_node_id/target_node_id
  2311. ppl_response = {
  2312. "schema": [
  2313. {"name": "entity_id", "type": "string"},
  2314. {"name": "connected_edges", "type": "struct"},
  2315. ],
  2316. "datarows": [
  2317. [
  2318. "A",
  2319. [ # connected_edges array
  2320. {
  2321. "source_node_id": "A",
  2322. "target_node_id": "B",
  2323. "weight": 1.0,
  2324. "_depth": 0,
  2325. },
  2326. {
  2327. "source_node_id": "B",
  2328. "target_node_id": "C",
  2329. "weight": 0.5,
  2330. "_depth": 1,
  2331. },
  2332. ],
  2333. ]
  2334. ],
  2335. }
  2336. mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
  2337. # get_node for start node verification
  2338. mock_client.get = AsyncMock(
  2339. return_value={
  2340. "_id": "A",
  2341. "_source": {"entity_type": "person", "description": "Node A"},
  2342. }
  2343. )
  2344. # mget for batch node fetch (only B and C, A is already added)
  2345. mock_client.mget = AsyncMock(
  2346. return_value={
  2347. "docs": [
  2348. {"_id": "B", "found": True, "_source": {"entity_type": "person"}},
  2349. {"_id": "C", "found": True, "_source": {"entity_type": "person"}},
  2350. ]
  2351. }
  2352. )
  2353. # search for final edge fetch
  2354. mock_client.search = AsyncMock(
  2355. return_value={
  2356. "hits": {
  2357. "hits": [
  2358. {
  2359. "_id": "e1",
  2360. "_source": {
  2361. "source_node_id": "A",
  2362. "target_node_id": "B",
  2363. "relationship": "knows",
  2364. },
  2365. },
  2366. {
  2367. "_id": "e2",
  2368. "_source": {
  2369. "source_node_id": "B",
  2370. "target_node_id": "C",
  2371. "relationship": "knows",
  2372. },
  2373. },
  2374. ],
  2375. "total": {"value": 2},
  2376. },
  2377. "aggregations": {
  2378. "status_counts": {"buckets": []},
  2379. "src": {"buckets": []},
  2380. "tgt": {"buckets": []},
  2381. },
  2382. }
  2383. )
  2384. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2385. s = self._make(global_config, embed_func)
  2386. await s.initialize()
  2387. assert s._ppl_graphlookup_available is True
  2388. result = await s.get_knowledge_graph("A", max_depth=2)
  2389. assert len(result.nodes) == 3
  2390. assert len(result.edges) == 2
  2391. # Verify PPL was called (2 for detection + 1 for actual query)
  2392. assert mock_client.transport.perform_request.await_count == 3
  2393. # Verify the PPL query uses nodes index as source
  2394. actual_query = mock_client.transport.perform_request.call_args_list[2]
  2395. ppl_body = actual_query.kwargs.get("body") or actual_query[1].get(
  2396. "body", {}
  2397. )
  2398. if isinstance(ppl_body, dict):
  2399. assert s._nodes_index in ppl_body.get("query", "")
  2400. @pytest.mark.asyncio
  2401. async def test_ppl_bfs_falls_back_on_query_failure(
  2402. self, global_config, embed_func, mock_client
  2403. ):
  2404. """If PPL query fails at runtime, should fall back to client-side BFS."""
  2405. call_count = {"n": 0}
  2406. async def ppl_side_effect(*args, **kwargs):
  2407. call_count["n"] += 1
  2408. if call_count["n"] <= 2:
  2409. # Detection calls succeed
  2410. return {"datarows": [], "schema": []}
  2411. # Actual query fails
  2412. raise Exception("PPL query timeout")
  2413. mock_client.transport = AsyncMock()
  2414. mock_client.transport.perform_request = AsyncMock(side_effect=ppl_side_effect)
  2415. mock_client.mget = AsyncMock(
  2416. return_value={"docs": [{"_id": "A", "found": False}]}
  2417. )
  2418. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2419. s = self._make(global_config, embed_func)
  2420. await s.initialize()
  2421. assert s._ppl_graphlookup_available is True
  2422. # Should fall back to _bfs_subgraph, which returns empty (node not found)
  2423. result = await s.get_knowledge_graph("A", max_depth=2)
  2424. assert len(result.nodes) == 0
  2425. @pytest.mark.asyncio
  2426. async def test_escape_ppl(self, global_config, embed_func):
  2427. s = self._make(global_config, embed_func)
  2428. assert s._escape_ppl("it's") == "it\\'s"
  2429. assert s._escape_ppl("normal") == "normal"
  2430. assert s._escape_ppl("back\\slash") == "back\\\\slash"
  2431. assert s._escape_ppl("both\\and'quote") == "both\\\\and\\'quote"
  2432. @pytest.mark.asyncio
  2433. async def test_ppl_bfs_depth_zero_returns_start_only(
  2434. self, global_config, embed_func, mock_client
  2435. ):
  2436. """max_depth=0 should return only the start node without PPL query."""
  2437. mock_client.transport = AsyncMock()
  2438. mock_client.transport.perform_request = AsyncMock(
  2439. return_value={"datarows": [], "schema": []}
  2440. )
  2441. mock_client.get = AsyncMock(
  2442. return_value={"_id": "A", "_source": {"entity_type": "person"}}
  2443. )
  2444. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2445. s = self._make(global_config, embed_func)
  2446. await s.initialize()
  2447. assert s._ppl_graphlookup_available is True
  2448. result = await s.get_knowledge_graph("A", max_depth=0)
  2449. assert len(result.nodes) == 1
  2450. assert result.nodes[0].id == "A"
  2451. assert len(result.edges) == 0
  2452. # PPL query should NOT have been called for the actual traversal (only 2 detection calls)
  2453. assert mock_client.transport.perform_request.await_count == 2
  2454. @pytest.mark.asyncio
  2455. async def test_ppl_bfs_empty_connected_edges(
  2456. self, global_config, embed_func, mock_client
  2457. ):
  2458. """PPL returns no connected edges — should return only start node."""
  2459. mock_client.transport = AsyncMock()
  2460. ppl_response = {
  2461. "schema": [
  2462. {"name": "entity_id", "type": "string"},
  2463. {"name": "connected_edges", "type": "struct"},
  2464. ],
  2465. "datarows": [["A", []]],
  2466. }
  2467. mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
  2468. mock_client.get = AsyncMock(
  2469. return_value={"_id": "A", "_source": {"entity_type": "person"}}
  2470. )
  2471. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2472. s = self._make(global_config, embed_func)
  2473. await s.initialize()
  2474. result = await s.get_knowledge_graph("A", max_depth=2)
  2475. assert len(result.nodes) == 1
  2476. assert result.nodes[0].id == "A"
  2477. @pytest.mark.asyncio
  2478. async def test_ppl_bfs_truncates_nodes_by_depth_then_weight(
  2479. self, global_config, embed_func, mock_client
  2480. ):
  2481. mock_client.transport = AsyncMock()
  2482. ppl_response = {
  2483. "schema": [
  2484. {"name": "entity_id", "type": "string"},
  2485. {"name": "connected_edges", "type": "struct"},
  2486. ],
  2487. "datarows": [
  2488. [
  2489. "A",
  2490. [
  2491. {
  2492. "source_node_id": "A",
  2493. "target_node_id": "C",
  2494. "weight": 1.0,
  2495. "_depth": 1,
  2496. },
  2497. {
  2498. "source_node_id": "B",
  2499. "target_node_id": "D",
  2500. "weight": 10.0,
  2501. "_depth": 1,
  2502. },
  2503. {
  2504. "source_node_id": "A",
  2505. "target_node_id": "B",
  2506. "weight": 1.0,
  2507. "_depth": 0,
  2508. },
  2509. ],
  2510. ]
  2511. ],
  2512. }
  2513. mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
  2514. mock_client.mget = AsyncMock(
  2515. side_effect=[
  2516. {
  2517. "docs": [
  2518. {
  2519. "_id": "A",
  2520. "found": True,
  2521. "_source": {"entity_type": "person"},
  2522. }
  2523. ]
  2524. },
  2525. {
  2526. "docs": [
  2527. {
  2528. "_id": "B",
  2529. "found": True,
  2530. "_source": {"entity_type": "person"},
  2531. },
  2532. {
  2533. "_id": "D",
  2534. "found": True,
  2535. "_source": {"entity_type": "person"},
  2536. },
  2537. ]
  2538. },
  2539. ]
  2540. )
  2541. mock_client.search = AsyncMock(
  2542. return_value={
  2543. "hits": {
  2544. "hits": [
  2545. {
  2546. "_id": "e1",
  2547. "_source": {
  2548. "source_node_id": "A",
  2549. "target_node_id": "B",
  2550. "relationship": "knows",
  2551. },
  2552. "sort": [1],
  2553. },
  2554. {
  2555. "_id": "e2",
  2556. "_source": {
  2557. "source_node_id": "B",
  2558. "target_node_id": "D",
  2559. "relationship": "knows",
  2560. },
  2561. "sort": [2],
  2562. },
  2563. ],
  2564. "total": {"value": 2},
  2565. }
  2566. }
  2567. )
  2568. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2569. s = self._make(global_config, embed_func)
  2570. await s.initialize()
  2571. result = await s.get_knowledge_graph("A", max_depth=2, max_nodes=3)
  2572. assert [node.id for node in result.nodes] == ["A", "B", "D"]
  2573. assert result.is_truncated is True
  2574. assert {(edge.source, edge.target) for edge in result.edges} == {
  2575. ("A", "B"),
  2576. ("B", "D"),
  2577. }
  2578. @pytest.mark.asyncio
  2579. async def test_upsert_node_adds_entity_id(
  2580. self, global_config, embed_func, mock_client
  2581. ):
  2582. """upsert_node should always include entity_id field for PPL compatibility."""
  2583. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2584. s = self._make(global_config, embed_func)
  2585. await s.initialize()
  2586. await s.upsert_node("TestNode", {"description": "test"})
  2587. body = mock_client.index.call_args.kwargs["body"]
  2588. assert body["entity_id"] == "TestNode"
  2589. assert body["description"] == "test"
  2590. @pytest.mark.asyncio
  2591. async def test_node_degree_uses_count_api(
  2592. self, global_config, embed_func, mock_client
  2593. ):
  2594. """node_degree should use the count API, not search."""
  2595. mock_client.count = AsyncMock(return_value={"count": 7})
  2596. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2597. s = self._make(global_config, embed_func)
  2598. await s.initialize()
  2599. degree = await s.node_degree("X")
  2600. assert degree == 7
  2601. # Verify count was called on the edges index
  2602. mock_client.count.assert_awaited()
  2603. call_kwargs = mock_client.count.call_args
  2604. assert s._edges_index in str(call_kwargs)
  2605. # ---------------------------------------------------------------------------
  2606. # Vector Storage
  2607. # ---------------------------------------------------------------------------
  2608. class TestVectorStorage:
  2609. """Tests for OpenSearchVectorDBStorage k-NN index, embeddings, cosine conversion, and entity deletion."""
  2610. def _make(self, global_config, embed_func, workspace="test"):
  2611. return OpenSearchVectorDBStorage(
  2612. namespace="entities",
  2613. global_config=global_config,
  2614. embedding_func=embed_func,
  2615. workspace=workspace,
  2616. meta_fields={"content", "entity_name", "src_id", "tgt_id"},
  2617. )
  2618. @pytest.mark.asyncio
  2619. async def test_index_name(self, global_config, embed_func):
  2620. s = self._make(global_config, embed_func)
  2621. assert s._index_name == "test_entities"
  2622. @pytest.mark.asyncio
  2623. async def test_cosine_threshold_required(self, embed_func):
  2624. with pytest.raises(ValueError, match="cosine_better_than_threshold"):
  2625. OpenSearchVectorDBStorage(
  2626. namespace="v",
  2627. global_config={
  2628. "embedding_batch_num": 10,
  2629. "vector_db_storage_cls_kwargs": {},
  2630. },
  2631. embedding_func=embed_func,
  2632. )
  2633. @pytest.mark.asyncio
  2634. async def test_initialize_creates_knn_index(
  2635. self, global_config, embed_func, mock_client
  2636. ):
  2637. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2638. s = self._make(global_config, embed_func)
  2639. await s.initialize()
  2640. mock_client.indices.create.assert_awaited_once()
  2641. body = mock_client.indices.create.call_args.kwargs["body"]
  2642. assert body["settings"]["index"]["knn"] is True
  2643. assert body["mappings"]["properties"]["vector"]["dimension"] == 128
  2644. assert (
  2645. body["mappings"]["properties"]["vector"]["method"]["engine"] == "lucene"
  2646. )
  2647. @pytest.mark.asyncio
  2648. async def test_upsert_generates_embeddings(
  2649. self, global_config, embed_func, mock_client
  2650. ):
  2651. """Embeddings are deferred until flush; upsert only buffers payloads."""
  2652. embed_func = CountingEmbeddingFunc()
  2653. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2654. with patch(
  2655. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  2656. ) as mock_bulk:
  2657. mock_bulk.return_value = (2, [])
  2658. s = self._make(global_config, embed_func)
  2659. await s.initialize()
  2660. await s.upsert(
  2661. {
  2662. "v1": {"content": "hello"},
  2663. "v2": {"content": "world"},
  2664. }
  2665. )
  2666. # Upsert buffers; no bulk write yet.
  2667. mock_bulk.assert_not_awaited()
  2668. assert embed_func.call_count == 0
  2669. assert set(s._pending_vector_docs.keys()) == {"v1", "v2"}
  2670. assert s._pending_vector_docs["v1"].vector is None
  2671. # Flush embeds and triggers a single bulk call with both docs.
  2672. await s.index_done_callback()
  2673. assert embed_func.call_count == 1
  2674. mock_bulk.assert_awaited_once()
  2675. actions = mock_bulk.call_args[0][1]
  2676. assert len(actions) == 2
  2677. assert all(a["_op_type"] == "index" for a in actions)
  2678. assert all("vector" in a["_source"] for a in actions)
  2679. @pytest.mark.asyncio
  2680. async def test_query_cosine_score_conversion(
  2681. self, global_config, embed_func, mock_client
  2682. ):
  2683. """Test that scores are used directly and threshold filtering works."""
  2684. mock_client.search = AsyncMock(
  2685. return_value={
  2686. "hits": {
  2687. "hits": [
  2688. {
  2689. "_id": "v1",
  2690. "_score": 0.85,
  2691. "_source": {"content": "match", "entity_name": "E1"},
  2692. },
  2693. ],
  2694. "total": {"value": 1},
  2695. },
  2696. "aggregations": {
  2697. "status_counts": {"buckets": []},
  2698. "src": {"buckets": []},
  2699. "tgt": {"buckets": []},
  2700. },
  2701. }
  2702. )
  2703. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2704. s = self._make(global_config, embed_func)
  2705. await s.initialize()
  2706. results = await s.query("test", top_k=5)
  2707. assert len(results) == 1
  2708. assert results[0]["distance"] == 0.85
  2709. @pytest.mark.asyncio
  2710. async def test_query_filters_below_threshold(
  2711. self, global_config, embed_func, mock_client
  2712. ):
  2713. """Low scores should be filtered out."""
  2714. # score 0.15 < threshold 0.2
  2715. mock_client.search = AsyncMock(
  2716. return_value={
  2717. "hits": {
  2718. "hits": [
  2719. {
  2720. "_id": "v1",
  2721. "_score": 0.15,
  2722. "_source": {"content": "weak match"},
  2723. },
  2724. ],
  2725. "total": {"value": 1},
  2726. },
  2727. "aggregations": {
  2728. "status_counts": {"buckets": []},
  2729. "src": {"buckets": []},
  2730. "tgt": {"buckets": []},
  2731. },
  2732. }
  2733. )
  2734. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2735. s = self._make(global_config, embed_func)
  2736. await s.initialize()
  2737. results = await s.query("test", top_k=5)
  2738. assert len(results) == 0
  2739. @pytest.mark.asyncio
  2740. async def test_query_with_provided_embedding(
  2741. self, global_config, embed_func, mock_client
  2742. ):
  2743. mock_client.search = AsyncMock(
  2744. return_value={
  2745. "hits": {
  2746. "hits": [
  2747. {"_id": "v1", "_score": 1.0, "_source": {"content": "exact"}},
  2748. ],
  2749. "total": {"value": 1},
  2750. },
  2751. "aggregations": {
  2752. "status_counts": {"buckets": []},
  2753. "src": {"buckets": []},
  2754. "tgt": {"buckets": []},
  2755. },
  2756. }
  2757. )
  2758. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2759. s = self._make(global_config, embed_func)
  2760. await s.initialize()
  2761. vec = np.random.rand(128).astype(np.float32)
  2762. results = await s.query("test", top_k=5, query_embedding=vec)
  2763. assert len(results) == 1
  2764. assert results[0]["distance"] == 1.0
  2765. @pytest.mark.asyncio
  2766. async def test_get_by_id(self, global_config, embed_func, mock_client):
  2767. mock_client.mget = AsyncMock(
  2768. return_value={
  2769. "docs": [
  2770. {
  2771. "_id": "v1",
  2772. "found": True,
  2773. "_source": {"content": "hello", "vector": [0.1] * 128},
  2774. }
  2775. ]
  2776. }
  2777. )
  2778. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2779. s = self._make(global_config, embed_func)
  2780. await s.initialize()
  2781. doc = await s.get_by_id("v1")
  2782. assert doc["id"] == "v1"
  2783. assert doc["content"] == "hello"
  2784. # vector field is stripped on the mget path to match NanoVectorDB
  2785. assert "vector" not in doc
  2786. mock_client.mget.assert_awaited_once_with(
  2787. index=s._index_name,
  2788. body={"ids": ["v1"]},
  2789. _source_excludes=["vector"],
  2790. )
  2791. @pytest.mark.asyncio
  2792. async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
  2793. mock_client.mget = AsyncMock(
  2794. return_value={"docs": [{"_id": "missing", "found": False}]}
  2795. )
  2796. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2797. s = self._make(global_config, embed_func)
  2798. await s.initialize()
  2799. assert await s.get_by_id("missing") is None
  2800. mock_client.get.assert_not_awaited()
  2801. @pytest.mark.asyncio
  2802. async def test_get_by_ids(self, global_config, embed_func, mock_client):
  2803. mock_client.mget = AsyncMock(
  2804. return_value={
  2805. "docs": [
  2806. {"_id": "v1", "found": True, "_source": {"content": "a"}},
  2807. {"_id": "v2", "found": True, "_source": {"content": "b"}},
  2808. ]
  2809. }
  2810. )
  2811. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2812. s = self._make(global_config, embed_func)
  2813. await s.initialize()
  2814. docs = await s.get_by_ids(["v1", "v2"])
  2815. assert docs[0]["id"] == "v1"
  2816. assert docs[1]["id"] == "v2"
  2817. @pytest.mark.asyncio
  2818. async def test_get_vectors_by_ids(self, global_config, embed_func, mock_client):
  2819. vec = [0.1] * 128
  2820. mock_client.mget = AsyncMock(
  2821. return_value={
  2822. "docs": [
  2823. {"_id": "v1", "found": True, "_source": {"vector": vec}},
  2824. {"_id": "v2", "found": False},
  2825. ]
  2826. }
  2827. )
  2828. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2829. s = self._make(global_config, embed_func)
  2830. await s.initialize()
  2831. result = await s.get_vectors_by_ids(["v1", "v2"])
  2832. assert "v1" in result
  2833. assert "v2" not in result
  2834. assert result["v1"] == vec
  2835. @pytest.mark.asyncio
  2836. async def test_delete(self, global_config, embed_func, mock_client):
  2837. """delete() buffers ids; the actual bulk delete fires on flush."""
  2838. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2839. with patch(
  2840. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  2841. ) as mock_bulk:
  2842. mock_bulk.return_value = (2, [])
  2843. s = self._make(global_config, embed_func)
  2844. await s.initialize()
  2845. await s.delete(["v1", "v2"])
  2846. mock_bulk.assert_not_awaited()
  2847. assert s._pending_vector_deletes == {"v1", "v2"}
  2848. await s.index_done_callback()
  2849. mock_bulk.assert_awaited_once()
  2850. actions = mock_bulk.call_args[0][1]
  2851. assert len(actions) == 2
  2852. assert all(a["_op_type"] == "delete" for a in actions)
  2853. @pytest.mark.asyncio
  2854. async def test_delete_entity(self, global_config, embed_func, mock_client):
  2855. """delete_entity buffers a tombstone for the computed mdhash id."""
  2856. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2857. s = self._make(global_config, embed_func)
  2858. await s.initialize()
  2859. await s.delete_entity("Alice")
  2860. # No direct client.delete call -- delete is buffered for batched flush.
  2861. mock_client.delete.assert_not_awaited()
  2862. assert len(s._pending_vector_deletes) == 1
  2863. @pytest.mark.asyncio
  2864. async def test_delete_entity_relation(self, global_config, embed_func, mock_client):
  2865. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2866. s = self._make(global_config, embed_func)
  2867. await s.initialize()
  2868. await s.delete_entity_relation("Alice")
  2869. mock_client.delete_by_query.assert_awaited_once()
  2870. @pytest.mark.asyncio
  2871. async def test_drop_recreates_index(self, global_config, embed_func, mock_client):
  2872. # After drop, _create_knn_index_if_not_exists is called again.
  2873. # First call (init): exists=False -> create. Second call (after drop): exists=False -> create again.
  2874. mock_client.indices.exists = AsyncMock(return_value=False)
  2875. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2876. s = self._make(global_config, embed_func)
  2877. await s.initialize()
  2878. result = await s.drop()
  2879. assert result["status"] == "success"
  2880. mock_client.indices.delete.assert_awaited_once()
  2881. # create called twice: once during init, once during drop recreate
  2882. assert mock_client.indices.create.await_count == 2
  2883. @pytest.mark.asyncio
  2884. async def test_drop_delete_error_marks_index_not_ready(
  2885. self, global_config, embed_func, mock_client
  2886. ):
  2887. mock_client.indices.delete = AsyncMock(
  2888. side_effect=OpenSearchException("delete failed")
  2889. )
  2890. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2891. s = self._make(global_config, embed_func)
  2892. await s.initialize()
  2893. result = await s.drop()
  2894. assert result["status"] == "error"
  2895. assert s._index_ready is False
  2896. @pytest.mark.asyncio
  2897. async def test_drop_recreate_error_marks_index_not_ready(
  2898. self, global_config, embed_func, mock_client
  2899. ):
  2900. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2901. s = self._make(global_config, embed_func)
  2902. await s.initialize()
  2903. with patch.object(
  2904. s,
  2905. "_create_knn_index_if_not_exists",
  2906. new=AsyncMock(side_effect=OpenSearchException("recreate failed")),
  2907. ):
  2908. result = await s.drop()
  2909. assert result["status"] == "error"
  2910. assert s._index_ready is False
  2911. @pytest.mark.asyncio
  2912. async def test_drop_recreates_index_when_missing(
  2913. self, global_config, embed_func, mock_client
  2914. ):
  2915. mock_client.indices.exists = AsyncMock(return_value=False)
  2916. mock_client.indices.delete = AsyncMock(
  2917. side_effect=NotFoundError(404, "not found")
  2918. )
  2919. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2920. s = self._make(global_config, embed_func)
  2921. await s.initialize()
  2922. result = await s.drop()
  2923. assert result["status"] == "success"
  2924. assert mock_client.indices.create.await_count == 2
  2925. @pytest.mark.asyncio
  2926. async def test_reads_short_circuit_when_index_not_ready(
  2927. self, global_config, embed_func, mock_client
  2928. ):
  2929. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2930. s = self._make(global_config, embed_func)
  2931. await s.initialize()
  2932. s._index_ready = False
  2933. assert await s.query("test", top_k=5) == []
  2934. assert await s.get_by_id("v1") is None
  2935. assert await s.get_vectors_by_ids(["v1"]) == {}
  2936. mock_client.search.assert_not_awaited()
  2937. mock_client.mget.assert_not_awaited()
  2938. @pytest.mark.asyncio
  2939. async def test_read_missing_index_demotes_readiness(
  2940. self, global_config, embed_func, mock_client
  2941. ):
  2942. mock_client.search = AsyncMock(side_effect=_missing_index_error())
  2943. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2944. s = self._make(global_config, embed_func)
  2945. await s.initialize()
  2946. assert await s.query("test", top_k=5) == []
  2947. assert await s.query("test", top_k=5) == []
  2948. assert s._index_ready is False
  2949. assert mock_client.search.await_count == 1
  2950. # ---------------------------------------------------------------------------
  2951. # Vector storage write batching (issue #2785)
  2952. # ---------------------------------------------------------------------------
  2953. class TestVectorStorageBatching:
  2954. """Tests for the buffered upsert/delete + flush behaviour added for #2785."""
  2955. def _make(self, global_config, embed_func, workspace="test"):
  2956. return OpenSearchVectorDBStorage(
  2957. namespace="entities",
  2958. global_config=global_config,
  2959. embedding_func=embed_func,
  2960. workspace=workspace,
  2961. meta_fields={"content", "entity_name", "src_id", "tgt_id"},
  2962. )
  2963. @pytest.mark.asyncio
  2964. async def test_repeated_upserts_flush_in_single_bulk_call(
  2965. self, global_config, embed_func, mock_client
  2966. ):
  2967. """Many small upsert() calls collapse to one async_bulk on flush."""
  2968. embed_func = CountingEmbeddingFunc()
  2969. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2970. with patch(
  2971. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  2972. ) as mock_bulk:
  2973. mock_bulk.return_value = (5, [])
  2974. s = self._make(global_config, embed_func)
  2975. await s.initialize()
  2976. for i in range(5):
  2977. await s.upsert({f"v{i}": {"content": f"doc {i}"}})
  2978. mock_bulk.assert_not_awaited()
  2979. assert embed_func.call_count == 0
  2980. await s.index_done_callback()
  2981. assert embed_func.call_count == 1
  2982. assert embed_func.batches == [[f"doc {i}" for i in range(5)]]
  2983. mock_bulk.assert_awaited_once()
  2984. actions = mock_bulk.call_args[0][1]
  2985. assert len(actions) == 5
  2986. assert {a["_id"] for a in actions} == {f"v{i}" for i in range(5)}
  2987. @pytest.mark.asyncio
  2988. async def test_deferred_embeddings_respect_batch_size(
  2989. self, global_config, embed_func, mock_client
  2990. ):
  2991. """Flush batches deferred embeddings by embedding_batch_num."""
  2992. embed_func = CountingEmbeddingFunc()
  2993. config = {**global_config, "embedding_batch_num": 2}
  2994. with patch.object(ClientManager, "get_client", return_value=mock_client):
  2995. with patch(
  2996. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  2997. ) as mock_bulk:
  2998. mock_bulk.return_value = (5, [])
  2999. s = self._make(config, embed_func)
  3000. await s.initialize()
  3001. for i in range(5):
  3002. await s.upsert({f"v{i}": {"content": f"doc {i}"}})
  3003. await s.index_done_callback()
  3004. assert embed_func.batches == [
  3005. ["doc 0", "doc 1"],
  3006. ["doc 2", "doc 3"],
  3007. ["doc 4"],
  3008. ]
  3009. mock_bulk.assert_awaited_once()
  3010. @pytest.mark.asyncio
  3011. async def test_upsert_overwrites_pending_doc_for_same_id(
  3012. self, global_config, embed_func, mock_client
  3013. ):
  3014. """Upserting the same id twice keeps only the latest payload."""
  3015. embed_func = CountingEmbeddingFunc()
  3016. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3017. with patch(
  3018. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3019. ) as mock_bulk:
  3020. mock_bulk.return_value = (1, [])
  3021. s = self._make(global_config, embed_func)
  3022. await s.initialize()
  3023. await s.upsert({"v1": {"content": "first"}})
  3024. await s.upsert({"v1": {"content": "second"}})
  3025. await s.index_done_callback()
  3026. assert embed_func.call_count == 1
  3027. assert embed_func.texts == ["second"]
  3028. actions = mock_bulk.call_args[0][1]
  3029. assert len(actions) == 1
  3030. assert actions[0]["_source"]["content"] == "second"
  3031. @pytest.mark.asyncio
  3032. async def test_delete_cancels_pending_upsert(
  3033. self, global_config, embed_func, mock_client
  3034. ):
  3035. """A delete after a buffered upsert removes the upsert from the buffer."""
  3036. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3037. with patch(
  3038. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3039. ) as mock_bulk:
  3040. mock_bulk.return_value = (1, [])
  3041. s = self._make(global_config, embed_func)
  3042. await s.initialize()
  3043. await s.upsert({"v1": {"content": "doomed"}})
  3044. await s.delete(["v1"])
  3045. assert "v1" not in s._pending_vector_docs
  3046. assert "v1" in s._pending_vector_deletes
  3047. await s.index_done_callback()
  3048. actions = mock_bulk.call_args[0][1]
  3049. assert len(actions) == 1
  3050. assert actions[0]["_op_type"] == "delete"
  3051. @pytest.mark.asyncio
  3052. async def test_upsert_cancels_pending_delete(
  3053. self, global_config, embed_func, mock_client
  3054. ):
  3055. """An upsert after a buffered delete removes the tombstone."""
  3056. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3057. with patch(
  3058. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3059. ) as mock_bulk:
  3060. mock_bulk.return_value = (1, [])
  3061. s = self._make(global_config, embed_func)
  3062. await s.initialize()
  3063. await s.delete(["v1"])
  3064. await s.upsert({"v1": {"content": "resurrected"}})
  3065. assert "v1" not in s._pending_vector_deletes
  3066. assert "v1" in s._pending_vector_docs
  3067. await s.index_done_callback()
  3068. actions = mock_bulk.call_args[0][1]
  3069. assert len(actions) == 1
  3070. assert actions[0]["_op_type"] == "index"
  3071. @pytest.mark.asyncio
  3072. async def test_get_by_id_reads_pending_buffer(
  3073. self, global_config, embed_func, mock_client
  3074. ):
  3075. """Buffered upserts are visible to get_by_id without hitting OpenSearch."""
  3076. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3077. s = self._make(global_config, embed_func)
  3078. await s.initialize()
  3079. await s.upsert({"v1": {"content": "buffered"}})
  3080. doc = await s.get_by_id("v1")
  3081. assert doc is not None
  3082. assert doc["id"] == "v1"
  3083. assert doc["content"] == "buffered"
  3084. # Vector field is hidden from get_by_id results, mirroring the
  3085. # _source excludes used by query().
  3086. assert "vector" not in doc
  3087. mock_client.mget.assert_not_awaited()
  3088. @pytest.mark.asyncio
  3089. async def test_get_by_id_returns_none_for_pending_delete(
  3090. self, global_config, embed_func, mock_client
  3091. ):
  3092. """A pending tombstone shadows any persisted doc."""
  3093. mock_client.mget = AsyncMock() # would be wrong to invoke
  3094. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3095. s = self._make(global_config, embed_func)
  3096. await s.initialize()
  3097. await s.delete(["v1"])
  3098. assert await s.get_by_id("v1") is None
  3099. mock_client.mget.assert_not_awaited()
  3100. @pytest.mark.asyncio
  3101. async def test_get_by_ids_merges_buffer_and_index(
  3102. self, global_config, embed_func, mock_client
  3103. ):
  3104. """get_by_ids returns buffered docs and falls back to mget for the rest."""
  3105. mock_client.mget = AsyncMock(
  3106. return_value={
  3107. "docs": [
  3108. {"_id": "v2", "found": True, "_source": {"content": "from_index"}},
  3109. ]
  3110. }
  3111. )
  3112. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3113. s = self._make(global_config, embed_func)
  3114. await s.initialize()
  3115. await s.upsert({"v1": {"content": "buffered"}})
  3116. docs = await s.get_by_ids(["v1", "v2"])
  3117. assert docs[0]["content"] == "buffered"
  3118. assert docs[1]["content"] == "from_index"
  3119. # Only the unbuffered id is requested from OpenSearch,
  3120. # and vector is excluded server-side.
  3121. mock_client.mget.assert_awaited_once_with(
  3122. index=s._index_name,
  3123. body={"ids": ["v2"]},
  3124. _source_excludes=["vector"],
  3125. )
  3126. @pytest.mark.asyncio
  3127. async def test_get_vectors_by_ids_uses_buffer(
  3128. self, global_config, embed_func, mock_client
  3129. ):
  3130. """get_vectors_by_ids returns buffered embeddings without an mget roundtrip."""
  3131. embed_func = CountingEmbeddingFunc()
  3132. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3133. s = self._make(global_config, embed_func)
  3134. await s.initialize()
  3135. await s.upsert({"v1": {"content": "x"}})
  3136. assert embed_func.call_count == 0
  3137. vecs = await s.get_vectors_by_ids(["v1"])
  3138. assert "v1" in vecs
  3139. assert len(vecs["v1"]) == 128
  3140. assert embed_func.call_count == 1
  3141. assert s._pending_vector_docs["v1"].vector == vecs["v1"]
  3142. mock_client.mget.assert_not_awaited()
  3143. @pytest.mark.asyncio
  3144. async def test_lazy_get_vectors_cache_is_reused_by_flush(
  3145. self, global_config, embed_func, mock_client
  3146. ):
  3147. """A lazy pending-vector read should not force a second embedding during flush."""
  3148. embed_func = CountingEmbeddingFunc()
  3149. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3150. with patch(
  3151. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3152. ) as mock_bulk:
  3153. mock_bulk.return_value = (1, [])
  3154. s = self._make(global_config, embed_func)
  3155. await s.initialize()
  3156. await s.upsert({"v1": {"content": "x"}})
  3157. vecs = await s.get_vectors_by_ids(["v1"])
  3158. await s.index_done_callback()
  3159. assert embed_func.call_count == 1
  3160. actions = mock_bulk.call_args[0][1]
  3161. assert actions[0]["_source"]["vector"] == vecs["v1"]
  3162. @pytest.mark.asyncio
  3163. async def test_finalize_flushes_pending_ops(
  3164. self, global_config, embed_func, mock_client
  3165. ):
  3166. """finalize() flushes buffered writes before releasing the client."""
  3167. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3168. with patch(
  3169. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3170. ) as mock_bulk:
  3171. mock_bulk.return_value = (1, [])
  3172. s = self._make(global_config, embed_func)
  3173. await s.initialize()
  3174. await s.upsert({"v1": {"content": "to flush"}})
  3175. await s.finalize()
  3176. mock_bulk.assert_awaited_once()
  3177. assert s.client is None
  3178. @pytest.mark.asyncio
  3179. async def test_vector_finalize_raises_when_retryable_buffer_remains(
  3180. self, global_config, embed_func, mock_client
  3181. ):
  3182. """finalize() must surface a RuntimeError when retryable bulk
  3183. failures left vector rows buffered, otherwise the upstream
  3184. finalize_storages() call would log the storage as successfully
  3185. finalized while writes are silently lost.
  3186. The client is still released regardless to avoid connection leak.
  3187. """
  3188. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3189. with patch.object(
  3190. ClientManager, "release_client", new_callable=AsyncMock
  3191. ) as mock_release:
  3192. with patch(
  3193. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3194. new_callable=AsyncMock,
  3195. ) as mock_bulk:
  3196. mock_bulk.return_value = (
  3197. 0,
  3198. [{"index": {"_id": "v1", "status": 503, "error": "down"}}],
  3199. )
  3200. s = self._make(global_config, embed_func)
  3201. await s.initialize()
  3202. await s.upsert({"v1": {"content": "stuck"}})
  3203. with pytest.raises(RuntimeError, match="pending upserts"):
  3204. await s.finalize()
  3205. mock_release.assert_awaited_once()
  3206. assert s.client is None
  3207. @pytest.mark.asyncio
  3208. async def test_vector_finalize_propagates_flush_exception(
  3209. self, global_config, embed_func, mock_client
  3210. ):
  3211. """If async_bulk raises during the final flush, finalize() still
  3212. releases the client and wraps the original error in a RuntimeError
  3213. that names the unflushed buffer counts.
  3214. """
  3215. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3216. with patch.object(
  3217. ClientManager, "release_client", new_callable=AsyncMock
  3218. ) as mock_release:
  3219. with patch(
  3220. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3221. new_callable=AsyncMock,
  3222. ) as mock_bulk:
  3223. mock_bulk.side_effect = OpenSearchException("connection reset")
  3224. s = self._make(global_config, embed_func)
  3225. await s.initialize()
  3226. await s.upsert({"v1": {"content": "stuck"}})
  3227. with pytest.raises(RuntimeError) as exc_info:
  3228. await s.finalize()
  3229. assert isinstance(exc_info.value.__cause__, OpenSearchException)
  3230. mock_release.assert_awaited_once()
  3231. assert s.client is None
  3232. @pytest.mark.asyncio
  3233. async def test_vector_finalize_propagates_cancellation(
  3234. self, global_config, embed_func, mock_client
  3235. ):
  3236. """asyncio.CancelledError raised during the final flush must
  3237. propagate UN-wrapped so the shutdown sequence honours the
  3238. cancellation signal. The client is still released (finally
  3239. block) before the cancellation continues.
  3240. """
  3241. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3242. with patch.object(
  3243. ClientManager, "release_client", new_callable=AsyncMock
  3244. ) as mock_release:
  3245. with patch(
  3246. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3247. new_callable=AsyncMock,
  3248. ) as mock_bulk:
  3249. mock_bulk.side_effect = asyncio.CancelledError()
  3250. s = self._make(global_config, embed_func)
  3251. await s.initialize()
  3252. await s.upsert({"v1": {"content": "stuck"}})
  3253. with pytest.raises(asyncio.CancelledError):
  3254. await s.finalize()
  3255. mock_release.assert_awaited_once()
  3256. assert s.client is None
  3257. @pytest.mark.asyncio
  3258. async def test_drop_discards_pending_buffers(
  3259. self, global_config, embed_func, mock_client
  3260. ):
  3261. """drop() throws away pending writes; nothing is flushed to a deleted index."""
  3262. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3263. with patch(
  3264. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3265. ) as mock_bulk:
  3266. s = self._make(global_config, embed_func)
  3267. await s.initialize()
  3268. await s.upsert({"v1": {"content": "doomed"}})
  3269. await s.delete(["v2"])
  3270. await s.drop()
  3271. assert s._pending_vector_docs == {}
  3272. assert s._pending_vector_deletes == set()
  3273. mock_bulk.assert_not_awaited()
  3274. @pytest.mark.asyncio
  3275. async def test_failed_flush_entries_retained_for_retry(
  3276. self, global_config, embed_func, mock_client
  3277. ):
  3278. """Transient (5xx) per-doc failures stay buffered for the next flush."""
  3279. embed_func = CountingEmbeddingFunc()
  3280. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3281. with patch(
  3282. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3283. ) as mock_bulk:
  3284. # First flush: v1 succeeds, v2 fails with 503 (retryable).
  3285. mock_bulk.side_effect = [
  3286. (
  3287. 1,
  3288. [{"index": {"_id": "v2", "status": 503, "error": "down"}}],
  3289. ),
  3290. (1, []),
  3291. ]
  3292. s = self._make(global_config, embed_func)
  3293. await s.initialize()
  3294. await s.upsert(
  3295. {
  3296. "v1": {"content": "ok"},
  3297. "v2": {"content": "boom"},
  3298. }
  3299. )
  3300. await s.index_done_callback()
  3301. # v1 cleared, v2 retained for retry.
  3302. assert "v1" not in s._pending_vector_docs
  3303. assert "v2" in s._pending_vector_docs
  3304. assert s._pending_vector_docs["v2"].vector is not None
  3305. assert embed_func.call_count == 1
  3306. await s.index_done_callback()
  3307. assert "v2" not in s._pending_vector_docs
  3308. assert embed_func.call_count == 1
  3309. assert mock_bulk.await_count == 2
  3310. @pytest.mark.asyncio
  3311. async def test_embedding_failure_leaves_pending_for_retry(
  3312. self, global_config, embed_func, mock_client
  3313. ):
  3314. """Embedding failures behave like flush failures: buffers stay intact."""
  3315. embed_func = CountingEmbeddingFunc(fail_times=1)
  3316. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3317. with patch(
  3318. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3319. ) as mock_bulk:
  3320. mock_bulk.return_value = (1, [])
  3321. s = self._make(global_config, embed_func)
  3322. await s.initialize()
  3323. await s.upsert({"v1": {"content": "retry me"}})
  3324. with pytest.raises(RuntimeError, match="embedding failed"):
  3325. await s.index_done_callback()
  3326. mock_bulk.assert_not_awaited()
  3327. assert "v1" in s._pending_vector_docs
  3328. assert s._pending_vector_docs["v1"].vector is None
  3329. await s.index_done_callback()
  3330. mock_bulk.assert_awaited_once()
  3331. assert "v1" not in s._pending_vector_docs
  3332. assert embed_func.call_count == 2
  3333. @pytest.mark.asyncio
  3334. async def test_finalize_wraps_embedding_failure(
  3335. self, global_config, embed_func, mock_client
  3336. ):
  3337. """finalize() reports pending buffers when deferred embedding fails."""
  3338. embed_func = CountingEmbeddingFunc(fail_times=1)
  3339. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3340. with patch.object(
  3341. ClientManager, "release_client", new_callable=AsyncMock
  3342. ) as mock_release:
  3343. with patch(
  3344. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3345. new_callable=AsyncMock,
  3346. ) as mock_bulk:
  3347. s = self._make(global_config, embed_func)
  3348. await s.initialize()
  3349. await s.upsert({"v1": {"content": "stuck"}})
  3350. with pytest.raises(RuntimeError, match="pending upserts"):
  3351. await s.finalize()
  3352. mock_bulk.assert_not_awaited()
  3353. mock_release.assert_awaited_once()
  3354. assert s.client is None
  3355. assert "v1" in s._pending_vector_docs
  3356. assert s._pending_vector_docs["v1"].vector is None
  3357. @pytest.mark.asyncio
  3358. async def test_delete_entity_relation_prunes_pending_buffer(
  3359. self, global_config, embed_func, mock_client
  3360. ):
  3361. """Pending docs whose src_id/tgt_id match the entity are dropped before delete_by_query."""
  3362. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3363. with patch(
  3364. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3365. ) as mock_bulk:
  3366. mock_bulk.return_value = (1, [])
  3367. s = self._make(global_config, embed_func)
  3368. await s.initialize()
  3369. await s.upsert(
  3370. {
  3371. "rel-1": {
  3372. "content": "Alice -> Bob",
  3373. "src_id": "Alice",
  3374. "tgt_id": "Bob",
  3375. },
  3376. "rel-2": {
  3377. "content": "Carol -> Dave",
  3378. "src_id": "Carol",
  3379. "tgt_id": "Dave",
  3380. },
  3381. }
  3382. )
  3383. await s.delete_entity_relation("Alice")
  3384. assert "rel-1" not in s._pending_vector_docs
  3385. assert "rel-2" in s._pending_vector_docs
  3386. mock_client.delete_by_query.assert_awaited_once()
  3387. def test_extract_bulk_failed_ids_classifies_by_status(self):
  3388. from lightrag.kg.opensearch_impl import _extract_bulk_failed_ids
  3389. # No failures -> empty containers.
  3390. retryable, non_retryable = _extract_bulk_failed_ids(None)
  3391. assert retryable == set()
  3392. assert non_retryable == []
  3393. retryable, non_retryable = _extract_bulk_failed_ids([])
  3394. assert retryable == set()
  3395. assert non_retryable == []
  3396. retryable, non_retryable = _extract_bulk_failed_ids(
  3397. [
  3398. # Retryable: 5xx server error.
  3399. {"index": {"_id": "r-500", "status": 500}},
  3400. # Retryable: rate-limited.
  3401. {"index": {"_id": "r-429", "status": 429}},
  3402. # Retryable: missing status (network / parse failure).
  3403. {"create": {"_id": "r-none"}},
  3404. # Non-retryable: bad request with dict-shape error.
  3405. {
  3406. "index": {
  3407. "_id": "n-400",
  3408. "status": 400,
  3409. "error": {
  3410. "type": "mapper_parsing_exception",
  3411. "reason": "vector must be array",
  3412. },
  3413. }
  3414. },
  3415. # Non-retryable: not found on update (doc disappeared).
  3416. {"update": {"_id": "n-404", "status": 404, "error": "not found"}},
  3417. # Special case: delete of missing doc -> dropped from BOTH
  3418. # sets, since the row is already gone.
  3419. {"delete": {"_id": "drop-404", "status": 404}},
  3420. # Malformed entries are skipped silently.
  3421. "garbage",
  3422. {"update": {}},
  3423. ]
  3424. )
  3425. assert retryable == {"r-500", "r-429", "r-none"}
  3426. non_retryable_ids = {op.doc_id for op in non_retryable}
  3427. assert non_retryable_ids == {"n-400", "n-404"}
  3428. by_id = {op.doc_id: op for op in non_retryable}
  3429. # dict-shape error is summarised via "reason"
  3430. assert by_id["n-400"].op == "index"
  3431. assert by_id["n-400"].status == 400
  3432. assert "vector must be array" in by_id["n-400"].error
  3433. # string-shape error is passed through
  3434. assert by_id["n-404"].op == "update"
  3435. assert by_id["n-404"].status == 404
  3436. assert by_id["n-404"].error == "not found"
  3437. def test_extract_bulk_failed_ids_truncates_long_errors(self):
  3438. from lightrag.kg.opensearch_impl import (
  3439. _extract_bulk_failed_ids,
  3440. _BULK_ERROR_SUMMARY_MAX_LEN,
  3441. )
  3442. long_reason = "x" * 1000
  3443. _, non_retryable = _extract_bulk_failed_ids(
  3444. [
  3445. {
  3446. "index": {
  3447. "_id": "n-400",
  3448. "status": 400,
  3449. "error": {"reason": long_reason},
  3450. }
  3451. }
  3452. ]
  3453. )
  3454. assert len(non_retryable) == 1
  3455. assert len(non_retryable[0].error) <= _BULK_ERROR_SUMMARY_MAX_LEN
  3456. assert non_retryable[0].error.endswith("...")
  3457. @pytest.mark.asyncio
  3458. async def test_failed_flush_drops_non_retryable_entries(
  3459. self, global_config, embed_func, mock_client
  3460. ):
  3461. """4xx (non-429) failures are dropped, not perpetually retried."""
  3462. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3463. with patch(
  3464. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3465. ) as mock_bulk:
  3466. # v1 fails permanently (400 mapping error); v2 fails
  3467. # transiently (503).
  3468. mock_bulk.return_value = (
  3469. 0,
  3470. [
  3471. {"index": {"_id": "v1", "status": 400, "error": "bad mapping"}},
  3472. {"index": {"_id": "v2", "status": 503, "error": "down"}},
  3473. ],
  3474. )
  3475. s = self._make(global_config, embed_func)
  3476. await s.initialize()
  3477. await s.upsert(
  3478. {"v1": {"content": "bad"}, "v2": {"content": "transient"}}
  3479. )
  3480. await s.index_done_callback()
  3481. # v1 is dropped (non-retryable), v2 is retained (retryable).
  3482. assert "v1" not in s._pending_vector_docs
  3483. assert "v2" in s._pending_vector_docs
  3484. @pytest.mark.asyncio
  3485. async def test_concurrent_writes_during_flush_are_serialised(
  3486. self, global_config, embed_func, mock_client
  3487. ):
  3488. """All buffer writes acquire the namespace lock, so an upsert issued
  3489. while a flush is in flight is blocked until the flush completes and
  3490. then lands in the live buffer for the next flush.
  3491. """
  3492. flush_started = asyncio.Event()
  3493. flush_can_finish = asyncio.Event()
  3494. async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
  3495. flush_started.set()
  3496. await flush_can_finish.wait()
  3497. return (len(actions), [])
  3498. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3499. with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
  3500. s = self._make(global_config, embed_func)
  3501. await s.initialize()
  3502. await s.upsert({"v1": {"content": "first"}})
  3503. flush_task = asyncio.create_task(s.index_done_callback())
  3504. await flush_started.wait()
  3505. # The flush is holding the lock and awaiting async_bulk.
  3506. # Issue a concurrent upsert via create_task so we can
  3507. # assert it is blocked (a direct await would deadlock the
  3508. # single-threaded event loop on the lock acquisition).
  3509. concurrent_task = asyncio.create_task(
  3510. s.upsert({"v2": {"content": "concurrent"}})
  3511. )
  3512. # Yield so the concurrent task gets a chance to start its
  3513. # embedding computation and arrive at the lock.
  3514. for _ in range(5):
  3515. await asyncio.sleep(0)
  3516. assert (
  3517. not concurrent_task.done()
  3518. ), "concurrent upsert should be blocked by the flush lock"
  3519. # v2 must not be visible in the buffer yet.
  3520. assert "v2" not in s._pending_vector_docs
  3521. # Release the bulk call; flush completes and the concurrent
  3522. # upsert then finally writes v2 into the (now-empty) buffer.
  3523. flush_can_finish.set()
  3524. await flush_task
  3525. await concurrent_task
  3526. assert "v1" not in s._pending_vector_docs
  3527. assert "v2" in s._pending_vector_docs
  3528. @pytest.mark.asyncio
  3529. async def test_concurrent_delete_during_flush_supersedes_retried_upsert(
  3530. self, global_config, embed_func, mock_client
  3531. ):
  3532. """A delete that lands after a flush retains a transient failure
  3533. wins over the retried upsert for the same id.
  3534. Under the lock-everywhere model the delete runs strictly after the
  3535. flush; the merge-back of the retryable v1 upsert is then cancelled
  3536. by the delete in a single, sequential pass.
  3537. """
  3538. flush_started = asyncio.Event()
  3539. flush_can_finish = asyncio.Event()
  3540. async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
  3541. flush_started.set()
  3542. await flush_can_finish.wait()
  3543. # Report v1's upsert as a transient failure so the flush
  3544. # leaves it in the buffer for retry.
  3545. return (
  3546. 0,
  3547. [{"index": {"_id": "v1", "status": 503, "error": "down"}}],
  3548. )
  3549. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3550. with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
  3551. s = self._make(global_config, embed_func)
  3552. await s.initialize()
  3553. await s.upsert({"v1": {"content": "first"}})
  3554. flush_task = asyncio.create_task(s.index_done_callback())
  3555. await flush_started.wait()
  3556. # Issue the concurrent delete; it queues behind the lock.
  3557. delete_task = asyncio.create_task(s.delete(["v1"]))
  3558. for _ in range(5):
  3559. await asyncio.sleep(0)
  3560. assert (
  3561. not delete_task.done()
  3562. ), "concurrent delete should be blocked by the flush lock"
  3563. flush_can_finish.set()
  3564. await flush_task
  3565. await delete_task
  3566. # The retry left v1 in the docs buffer; the subsequent
  3567. # delete then cancelled that upsert and replaced it with a
  3568. # tombstone.
  3569. assert "v1" not in s._pending_vector_docs
  3570. assert "v1" in s._pending_vector_deletes
  3571. @pytest.mark.asyncio
  3572. async def test_get_by_id_strips_vector_from_mget_path(
  3573. self, global_config, embed_func, mock_client
  3574. ):
  3575. """The mget fallback path returns the same shape as NanoVectorDB:
  3576. no ``vector`` key, and the server-side _source_excludes is set so the
  3577. embedding never crosses the wire in the first place.
  3578. """
  3579. mock_client.mget = AsyncMock(
  3580. return_value={
  3581. "docs": [
  3582. {
  3583. "_id": "v1",
  3584. "found": True,
  3585. # defensive: server-side excludes might be ignored
  3586. # in misconfigured indices; we still pop client-side.
  3587. "_source": {"content": "from_index", "vector": [0.1] * 128},
  3588. }
  3589. ]
  3590. }
  3591. )
  3592. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3593. s = self._make(global_config, embed_func)
  3594. await s.initialize()
  3595. # No upsert: buffer empty, falls through to mget.
  3596. doc = await s.get_by_id("v1")
  3597. assert doc is not None
  3598. assert doc["id"] == "v1"
  3599. assert doc["content"] == "from_index"
  3600. assert "vector" not in doc
  3601. mock_client.mget.assert_awaited_once_with(
  3602. index=s._index_name,
  3603. body={"ids": ["v1"]},
  3604. _source_excludes=["vector"],
  3605. )
  3606. @pytest.mark.asyncio
  3607. async def test_get_by_ids_strips_vector_from_mget_path(
  3608. self, global_config, embed_func, mock_client
  3609. ):
  3610. """get_by_ids strips vector on the fallback path and forwards
  3611. _source_excludes to mget."""
  3612. mock_client.mget = AsyncMock(
  3613. return_value={
  3614. "docs": [
  3615. {
  3616. "_id": "v1",
  3617. "found": True,
  3618. "_source": {"content": "a", "vector": [0.1] * 128},
  3619. },
  3620. {
  3621. "_id": "v2",
  3622. "found": True,
  3623. "_source": {"content": "b", "vector": [0.2] * 128},
  3624. },
  3625. ]
  3626. }
  3627. )
  3628. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3629. s = self._make(global_config, embed_func)
  3630. await s.initialize()
  3631. docs = await s.get_by_ids(["v1", "v2"])
  3632. assert all(d is not None for d in docs)
  3633. assert all("vector" not in d for d in docs)
  3634. assert docs[0]["content"] == "a"
  3635. assert docs[1]["content"] == "b"
  3636. mock_client.mget.assert_awaited_once_with(
  3637. index=s._index_name,
  3638. body={"ids": ["v1", "v2"]},
  3639. _source_excludes=["vector"],
  3640. )
  3641. @pytest.mark.asyncio
  3642. async def test_non_retryable_logs_sample_ids(
  3643. self, global_config, embed_func, mock_client, caplog
  3644. ):
  3645. """Non-retryable bulk failures log a sample with id/status/error."""
  3646. import logging as _logging
  3647. failed = [
  3648. {
  3649. "index": {
  3650. "_id": f"v{i}",
  3651. "status": 400,
  3652. "error": {
  3653. "type": "mapper_parsing_exception",
  3654. "reason": f"bad field {i}",
  3655. },
  3656. }
  3657. }
  3658. for i in range(6)
  3659. ]
  3660. # lightrag logger has propagate=False, so caplog's root handler
  3661. # would miss these records. Re-enable propagation just for this
  3662. # test so caplog can capture the warning we emit.
  3663. lightrag_logger = _logging.getLogger("lightrag")
  3664. original_propagate = lightrag_logger.propagate
  3665. lightrag_logger.propagate = True
  3666. try:
  3667. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3668. with patch(
  3669. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3670. new_callable=AsyncMock,
  3671. ) as mock_bulk:
  3672. mock_bulk.return_value = (0, failed)
  3673. s = self._make(global_config, embed_func)
  3674. await s.initialize()
  3675. await s.upsert({f"v{i}": {"content": f"d{i}"} for i in range(6)})
  3676. with caplog.at_level("WARNING", logger="lightrag"):
  3677. await s.index_done_callback()
  3678. finally:
  3679. lightrag_logger.propagate = original_propagate
  3680. warning_text = "\n".join(
  3681. rec.message for rec in caplog.records if rec.levelname == "WARNING"
  3682. )
  3683. # Sample contains the first 5 ids with op/status/reason text.
  3684. for i in range(5):
  3685. assert f"v{i}" in warning_text
  3686. assert "status=400" in warning_text
  3687. assert "bad field" in warning_text
  3688. # 6 permanent failures reported in aggregate.
  3689. assert "6 vector ops" in warning_text
  3690. @pytest.mark.asyncio
  3691. async def test_index_done_callback_flushes_when_index_recreated(
  3692. self, global_config, embed_func, mock_client
  3693. ):
  3694. """If the index was marked missing after writes were buffered, the
  3695. callback must still flush — _flush_pending_vector_ops recreates the
  3696. index via _ensure_index_ready before issuing the bulk call.
  3697. """
  3698. # Sequence the indices.exists results so the second _create
  3699. # invocation actually creates the index again.
  3700. exists_responses = [False, False]
  3701. mock_client.indices.exists = AsyncMock(
  3702. side_effect=lambda **kw: exists_responses.pop(0)
  3703. if exists_responses
  3704. else False
  3705. )
  3706. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3707. with patch(
  3708. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3709. ) as mock_bulk:
  3710. mock_bulk.return_value = (1, [])
  3711. s = self._make(global_config, embed_func)
  3712. await s.initialize()
  3713. await s.upsert({"v1": {"content": "ok"}})
  3714. # Simulate the index disappearing (e.g. via a read 404)
  3715. # AFTER the write was buffered.
  3716. s._mark_index_missing()
  3717. await s.index_done_callback()
  3718. # The buffer was flushed, even though _index_ready was
  3719. # False at callback entry.
  3720. mock_bulk.assert_awaited_once()
  3721. assert s._pending_vector_docs == {}
  3722. # The index was recreated as part of flush.
  3723. assert mock_client.indices.create.await_count >= 2
  3724. @pytest.mark.asyncio
  3725. async def test_delete_entity_relation_serialised_with_flush(
  3726. self, global_config, embed_func, mock_client
  3727. ):
  3728. """delete_entity_relation runs entirely under the flush lock, so it
  3729. cannot race with an in-flight bulk indexing operation."""
  3730. flush_started = asyncio.Event()
  3731. flush_can_finish = asyncio.Event()
  3732. delete_started = asyncio.Event()
  3733. async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
  3734. flush_started.set()
  3735. await flush_can_finish.wait()
  3736. return (len(actions), [])
  3737. async def watch_delete_by_query(**kwargs):
  3738. delete_started.set()
  3739. return {"deleted": 0}
  3740. mock_client.delete_by_query = AsyncMock(side_effect=watch_delete_by_query)
  3741. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3742. with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
  3743. s = self._make(global_config, embed_func)
  3744. await s.initialize()
  3745. await s.upsert(
  3746. {
  3747. "rel-1": {
  3748. "content": "X",
  3749. "src_id": "Alice",
  3750. "tgt_id": "Bob",
  3751. }
  3752. }
  3753. )
  3754. flush_task = asyncio.create_task(s.index_done_callback())
  3755. await flush_started.wait()
  3756. # delete_by_query must NOT fire while bulk is still in flight.
  3757. rel_task = asyncio.create_task(s.delete_entity_relation("Alice"))
  3758. for _ in range(5):
  3759. await asyncio.sleep(0)
  3760. assert (
  3761. not delete_started.is_set()
  3762. ), "delete_by_query should be blocked behind the flush lock"
  3763. assert not rel_task.done()
  3764. flush_can_finish.set()
  3765. await flush_task
  3766. await rel_task
  3767. assert delete_started.is_set()
  3768. @pytest.mark.asyncio
  3769. async def test_drop_serialised_with_flush(
  3770. self, global_config, embed_func, mock_client
  3771. ):
  3772. """drop must serialise with an in-flight flush; the index delete
  3773. cannot land while bulk indexing is mid-request.
  3774. """
  3775. flush_started = asyncio.Event()
  3776. flush_can_finish = asyncio.Event()
  3777. drop_delete_started = asyncio.Event()
  3778. async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
  3779. flush_started.set()
  3780. await flush_can_finish.wait()
  3781. return (len(actions), [])
  3782. async def watch_indices_delete(**kwargs):
  3783. drop_delete_started.set()
  3784. mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)
  3785. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3786. with patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
  3787. s = self._make(global_config, embed_func)
  3788. await s.initialize()
  3789. await s.upsert({"v1": {"content": "x"}})
  3790. flush_task = asyncio.create_task(s.index_done_callback())
  3791. await flush_started.wait()
  3792. drop_task = asyncio.create_task(s.drop())
  3793. for _ in range(5):
  3794. await asyncio.sleep(0)
  3795. assert (
  3796. not drop_delete_started.is_set()
  3797. ), "indices.delete should be blocked behind the flush lock"
  3798. assert not drop_task.done()
  3799. flush_can_finish.set()
  3800. await flush_task
  3801. await drop_task
  3802. assert drop_delete_started.is_set()
  3803. @pytest.mark.asyncio
  3804. async def test_drop_serialised_with_flush_embedding_phase(
  3805. self, global_config, mock_client
  3806. ):
  3807. """drop must also wait while deferred embedding runs under the flush lock."""
  3808. embedding_started = asyncio.Event()
  3809. embedding_can_finish = asyncio.Event()
  3810. drop_delete_started = asyncio.Event()
  3811. class GatedEmbeddingFunc(MockEmbeddingFunc):
  3812. async def __call__(self, texts, **kwargs):
  3813. embedding_started.set()
  3814. await embedding_can_finish.wait()
  3815. return await super().__call__(texts, **kwargs)
  3816. async def watch_indices_delete(**kwargs):
  3817. drop_delete_started.set()
  3818. mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)
  3819. embed_func = GatedEmbeddingFunc()
  3820. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3821. with patch(
  3822. "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
  3823. ) as mock_bulk:
  3824. mock_bulk.return_value = (1, [])
  3825. s = self._make(global_config, embed_func)
  3826. await s.initialize()
  3827. await s.upsert({"v1": {"content": "x"}})
  3828. flush_task = asyncio.create_task(s.index_done_callback())
  3829. await embedding_started.wait()
  3830. drop_task = asyncio.create_task(s.drop())
  3831. for _ in range(5):
  3832. await asyncio.sleep(0)
  3833. assert (
  3834. not drop_delete_started.is_set()
  3835. ), "indices.delete should be blocked during deferred embedding"
  3836. assert not drop_task.done()
  3837. embedding_can_finish.set()
  3838. await flush_task
  3839. await drop_task
  3840. assert drop_delete_started.is_set()
  3841. # ---------------------------------------------------------------------------
  3842. # Cosine score edge cases
  3843. # ---------------------------------------------------------------------------
  3844. class TestScoreThreshold:
  3845. """Verify that raw OpenSearch scores are compared directly against threshold."""
  3846. def test_above_threshold(self):
  3847. assert 0.85 >= 0.2
  3848. def test_below_threshold(self):
  3849. assert 0.15 < 0.2
  3850. def test_exact_threshold(self):
  3851. assert 0.2 >= 0.2
  3852. # ---------------------------------------------------------------------------
  3853. # Why raising EMBEDDING_BATCH_NUM does not lower the embedding call count
  3854. # ---------------------------------------------------------------------------
  3855. class TestEmbeddingBatchNumDiagnosis:
  3856. """Pin down why bumping EMBEDDING_BATCH_NUM leaves the embedding call
  3857. count (get_embedding_queue_status -> submitted_total) unchanged for
  3858. entities/relations.
  3859. ``merge_nodes_and_edges`` upserts entities/relations ONE id at a time:
  3860. ``_merge_nodes_then_upsert`` calls ``entity_vdb.upsert({single})`` and
  3861. ``_merge_edges_then_upsert`` calls ``relationships_vdb.upsert({single})``
  3862. (lightrag/operate.py). ``EMBEDDING_BATCH_NUM`` only slices the items
  3863. *within one embedding pass* (``contents[i:i+batch]``). So the call count
  3864. is governed by how many items reach a single embedding pass, not by the
  3865. batch size -- raising the batch size only helps once >= 2 items are
  3866. embedded together.
  3867. """
  3868. def _make(self, batch_num, embed_func, workspace="diag"):
  3869. config = {
  3870. "embedding_batch_num": batch_num,
  3871. "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.2},
  3872. }
  3873. return OpenSearchVectorDBStorage(
  3874. namespace="entities",
  3875. global_config=config,
  3876. embedding_func=embed_func,
  3877. workspace=workspace,
  3878. meta_fields={"content", "entity_name"},
  3879. )
  3880. @staticmethod
  3881. def _fake_bulk(_client, actions, *_args, **_kwargs):
  3882. # async_bulk(raise_on_error=False) -> (success_count, failed_list).
  3883. # Empty failed list = every buffered action persisted.
  3884. return (len(actions), [])
  3885. async def _run_per_item(self, batch_num, *, flush_each, n=100):
  3886. """Upsert ``n`` entities one-at-a-time, mirroring the merge path.
  3887. flush_each=True -> embed right after each single-item upsert, so every
  3888. embedding pass sees exactly 1 item. This is the
  3889. pre-defer / eager behaviour where ``upsert`` embeds
  3890. inline.
  3891. flush_each=False -> buffer every single-item upsert and flush once, i.e.
  3892. the deferred-embedding design on this branch.
  3893. """
  3894. embed = CountingEmbeddingFunc()
  3895. mock_client = _make_client()
  3896. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3897. with patch(
  3898. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3899. new_callable=AsyncMock,
  3900. ) as mock_bulk:
  3901. mock_bulk.side_effect = self._fake_bulk
  3902. s = self._make(batch_num, embed)
  3903. await s.initialize()
  3904. for i in range(n):
  3905. await s.upsert(
  3906. {f"ent-{i}": {"content": f"entity {i}", "entity_name": f"E{i}"}}
  3907. )
  3908. if flush_each:
  3909. await s.index_done_callback()
  3910. if not flush_each:
  3911. await s.index_done_callback()
  3912. return embed
  3913. @pytest.mark.asyncio
  3914. async def test_per_item_embedding_makes_batch_num_a_noop(self):
  3915. """Eager pattern: embedding happens once per single-item upsert.
  3916. Reproduces the billing observation -- every embedding call carries
  3917. exactly ONE item (~one entity's tokens) -- and bumping
  3918. EMBEDDING_BATCH_NUM from 16 to 32 changes nothing.
  3919. """
  3920. embed16 = await self._run_per_item(16, flush_each=True)
  3921. embed32 = await self._run_per_item(32, flush_each=True)
  3922. assert embed16.call_count == 100
  3923. assert embed32.call_count == 100
  3924. # Each embedding pass saw exactly one item, regardless of batch size.
  3925. assert all(len(b) == 1 for b in embed16.batches)
  3926. assert all(len(b) == 1 for b in embed32.batches)
  3927. # The crux: raising the batch size did not reduce the call count.
  3928. assert embed16.call_count == embed32.call_count
  3929. @pytest.mark.asyncio
  3930. async def test_deferred_flush_makes_batch_num_effective(self):
  3931. """Deferred pattern: buffer all single-item upserts, flush once.
  3932. Now EMBEDDING_BATCH_NUM finally governs the count:
  3933. ceil(100/16)=7 vs ceil(100/32)=4.
  3934. """
  3935. embed16 = await self._run_per_item(16, flush_each=False)
  3936. embed32 = await self._run_per_item(32, flush_each=False)
  3937. assert embed16.call_count == math.ceil(100 / 16) == 7
  3938. assert embed32.call_count == math.ceil(100 / 32) == 4
  3939. assert embed16.call_count != embed32.call_count
  3940. # Every flushed batch respects the configured cap, and nothing is lost.
  3941. assert all(len(b) <= 16 for b in embed16.batches)
  3942. assert all(len(b) <= 32 for b in embed32.batches)
  3943. assert len(embed16.texts) == 100
  3944. assert len(embed32.texts) == 100
  3945. @pytest.mark.asyncio
  3946. async def test_single_multiitem_upsert_is_batched_like_chunks_vdb(self):
  3947. """Contrast: chunks_vdb upserts a whole document's chunks in ONE call.
  3948. When many items arrive in a single upsert/embedding pass,
  3949. EMBEDDING_BATCH_NUM works as expected even with an immediate flush --
  3950. proving the determining factor is items-per-embedding-pass, not the
  3951. storage backend. This is why batch_num visibly affects chunks but not
  3952. per-id entity/relation upserts.
  3953. """
  3954. embed16 = CountingEmbeddingFunc()
  3955. embed32 = CountingEmbeddingFunc()
  3956. for batch_num, embed in ((16, embed16), (32, embed32)):
  3957. mock_client = _make_client()
  3958. with patch.object(ClientManager, "get_client", return_value=mock_client):
  3959. with patch(
  3960. "lightrag.kg.opensearch_impl.helpers.async_bulk",
  3961. new_callable=AsyncMock,
  3962. ) as mock_bulk:
  3963. mock_bulk.side_effect = self._fake_bulk
  3964. s = self._make(batch_num, embed)
  3965. await s.initialize()
  3966. # chunks_vdb.upsert(chunks): one call carrying 100 items.
  3967. await s.upsert(
  3968. {
  3969. f"chunk-{i}": {"content": f"chunk {i}", "entity_name": ""}
  3970. for i in range(100)
  3971. }
  3972. )
  3973. await s.index_done_callback()
  3974. assert embed16.call_count == math.ceil(100 / 16) == 7
  3975. assert embed32.call_count == math.ceil(100 / 32) == 4
  3976. assert embed16.call_count != embed32.call_count