| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
277427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358 |
- """
- Unit tests for OpenSearch storage implementations.
- All tests use mocks — no running OpenSearch instance required.
- Run with: pytest tests/kg/opensearch_impl/test_opensearch_storage.py -v
- """
- import asyncio
- import math
- import pytest
- from contextlib import asynccontextmanager
- from unittest.mock import AsyncMock, patch
- import numpy as np
- pytest.importorskip(
- "opensearchpy",
- reason="opensearchpy is required for OpenSearch storage tests",
- )
- from opensearchpy.exceptions import NotFoundError, OpenSearchException # type: ignore
- from lightrag.kg.opensearch_impl import (
- OpenSearchKVStorage,
- OpenSearchDocStatusStorage,
- OpenSearchGraphStorage,
- OpenSearchVectorDBStorage,
- ClientManager,
- _build_index_name,
- _resolve_workspace,
- _sanitize_index_name,
- _verify_mirrored_id_mapping,
- )
- from lightrag.base import DocStatus, DocProcessingStatus
- pytestmark = pytest.mark.offline
- # ---------------------------------------------------------------------------
- # Mock the shared storage lock so tests don't need full LightRAG init
- # ---------------------------------------------------------------------------
@asynccontextmanager
async def _mock_lock():
    """Async context manager that does nothing on enter or exit.

    Stands in for the shared-storage lock so tests can run ``async with``
    blocks without the real locking machinery.
    """
    yield None
def _mock_lock_factory():
    """Return a brand-new no-op lock context manager on every call."""
    fresh_lock = _mock_lock()
    return fresh_lock
def _missing_index_error() -> NotFoundError:
    """Build a NotFoundError carrying a 404 ``index_not_found_exception`` payload."""
    status_code = 404
    error_type = "index_not_found_exception"
    detail = "no such index"
    return NotFoundError(status_code, error_type, detail)
@pytest.fixture(autouse=True)
def patch_data_init_lock():
    """Swap ``get_data_init_lock`` for the no-op lock factory during each test.

    Applied automatically so that ``initialize()`` can be exercised without
    setting up shared storage.
    """
    target = "lightrag.kg.opensearch_impl.get_data_init_lock"
    with patch(target, side_effect=_mock_lock_factory):
        yield
@pytest.fixture(autouse=True)
def patch_namespace_lock():
    """Route ``get_namespace_lock`` to a factory handing out real asyncio.Lock objects.

    Real locks (rather than no-op stand-ins) keep in-process blocking
    intact, so tests that exercise concurrent flush / read / write can
    observe actual serialization. One lock is kept per
    (namespace, workspace) pair, with a missing workspace normalised to
    ``""``, so repeated lookups for the same pair share one instance.
    """
    locks_by_key: dict[tuple[str, str | None], asyncio.Lock] = {}

    def factory(namespace, workspace=None, enable_logging=False):
        key = (namespace, workspace or "")
        if locks_by_key.get(key) is None:
            # First request for this pair: create and remember its lock.
            locks_by_key[key] = asyncio.Lock()
        return locks_by_key[key]

    target = "lightrag.kg.opensearch_impl.get_namespace_lock"
    with patch(target, side_effect=factory):
        yield
@pytest.fixture(autouse=True)
def patch_shard_doc_supported():
    """Force ``_shard_doc_supported`` to True for every test by default.

    This models OpenSearch >= 3.3.0, where the ``__mirrored_id``
    verification is a no-op. Tests that cover the < 3.3.0 fallback should
    apply their own patch on top of this one.
    """
    flag_path = "lightrag.kg.opensearch_impl._shard_doc_supported"
    with patch(flag_path, True):
        yield
- # ---------------------------------------------------------------------------
- # Fixtures
- # ---------------------------------------------------------------------------
class MockEmbeddingFunc:
    """Embedding stand-in that produces random float32 vectors."""

    def __init__(self, dim=128):
        self.model_name = "mock-embed"
        self.max_token_size = 512
        self.embedding_dim = dim

    async def __call__(self, texts, **kwargs):
        """Return one random row of ``embedding_dim`` values per input text."""
        row_count = len(texts)
        vectors = np.random.rand(row_count, self.embedding_dim)
        return vectors.astype(np.float32)
class CountingEmbeddingFunc(MockEmbeddingFunc):
    """Embedding double that logs every call and can fail its first N calls."""

    def __init__(self, dim=128, fail_times=0):
        super().__init__(dim=dim)
        self.call_count = 0
        self.fail_times = fail_times
        self.texts: list[str] = []
        self.batches: list[list[str]] = []

    async def __call__(self, texts, **kwargs):
        """Record the batch, then either raise or delegate to the parent.

        While ``fail_times`` is positive each call decrements it and raises
        RuntimeError; the call is still counted and recorded beforehand.
        """
        self.call_count += 1
        received = [*texts]
        self.batches.append(received)
        self.texts.extend(received)

        if self.fail_times <= 0:
            return await super().__call__(texts, **kwargs)

        self.fail_times -= 1
        raise RuntimeError("embedding failed")
@pytest.fixture
def global_config():
    """Provide the baseline global config dict shared by the storage tests."""
    vector_kwargs = {"cosine_better_than_threshold": 0.2}
    config = {
        "embedding_batch_num": 10,
        "max_graph_nodes": 1000,
        "vector_db_storage_cls_kwargs": vector_kwargs,
    }
    return config
@pytest.fixture
def embed_func():
    """Provide a MockEmbeddingFunc built with its default settings."""
    embedder = MockEmbeddingFunc()
    return embedder
def _make_client():
    """Assemble an AsyncMock constrained to the AsyncOpenSearch spec.

    Every operation the tests touch is pre-wired with a canned response:
    index existence checks return False, the raw transport raises, and
    document / search / PIT calls return small fixed payloads.
    """
    from opensearchpy import AsyncOpenSearch

    client = AsyncMock(spec=AsyncOpenSearch)

    # Index-management sub-client: no index exists, mappings are empty.
    client.indices = AsyncMock()
    indices = client.indices
    indices.exists = AsyncMock(return_value=False)
    indices.create = AsyncMock()
    indices.delete = AsyncMock()
    indices.refresh = AsyncMock()
    indices.get_mapping = AsyncMock(return_value={})

    # Raw transport (used for PPL requests): always raises.
    client.transport = AsyncMock()
    client.transport.perform_request = AsyncMock(
        side_effect=Exception("PPL not available")
    )

    # Single-document operations.
    stored_doc = {
        "_id": "doc1",
        "_source": {"content": "hello", "create_time": 0, "update_time": 0},
    }
    client.exists = AsyncMock(return_value=False)
    client.index = AsyncMock()
    client.delete = AsyncMock()
    client.delete_by_query = AsyncMock()
    client.get = AsyncMock(return_value=stored_doc)

    # Multi-get returns two found documents, in request order.
    mget_docs = [
        {"_id": doc_id, "found": True, "_source": {"content": content}}
        for doc_id, content in (("id1", "c1"), ("id2", "c2"))
    ]
    client.mget = AsyncMock(return_value={"docs": mget_docs})
    client.count = AsyncMock(return_value={"count": 5})

    # Search returns no hits; each aggregation gets its own empty bucket list.
    aggregation_names = (
        "status_counts",
        "src",
        "tgt",
        "source_degrees",
        "target_degrees",
    )
    empty_search = {
        "hits": {"hits": [], "total": {"value": 0}},
        "aggregations": {name: {"buckets": []} for name in aggregation_names},
    }
    client.search = AsyncMock(return_value=empty_search)

    # Point-in-time (PIT) operations.
    client.create_pit = AsyncMock(return_value={"pit_id": "mock_pit_id_123"})
    client.delete_pit = AsyncMock()

    return client
@pytest.fixture
def mock_client():
    """Provide a freshly built, fully mocked AsyncOpenSearch client."""
    client = _make_client()
    return client
- # ---------------------------------------------------------------------------
- # Helper utilities
- # ---------------------------------------------------------------------------
class TestHelpers:
    """Checks for the module-level helpers: _build_index_name, _resolve_workspace and _sanitize_index_name."""

    def test_build_index_name_with_workspace(self):
        """A non-empty workspace is prefixed onto the namespace."""
        workspace, namespace, index_name = _build_index_name("myws", "text_chunks")
        assert workspace == "myws"
        assert namespace == "myws_text_chunks"
        assert index_name == _sanitize_index_name("myws_text_chunks")

    def test_build_index_name_no_workspace(self):
        """An empty workspace leaves the index derived from the namespace alone."""
        workspace, _namespace, index_name = _build_index_name("", "chunks")
        assert workspace == ""
        assert index_name == _sanitize_index_name("chunks")

    def test_resolve_workspace_env_override(self):
        """OPENSEARCH_WORKSPACE in the environment wins over the argument."""
        forced_env = {"OPENSEARCH_WORKSPACE": "forced"}
        with patch.dict("os.environ", forced_env):
            resolved = _resolve_workspace("original", "ns")
            assert resolved == "forced"

    def test_resolve_workspace_fallback(self):
        """With an empty environment the passed workspace is returned."""
        with patch.dict("os.environ", {}, clear=True):
            resolved = _resolve_workspace("original", "ns")
            assert resolved == "original"

    def test_sanitize_index_name(self):
        """Sanitizing lowercases, prefixes a leading '-', and replaces '.' and '/'."""
        expectations = {
            "Hello_World": "hello_world",
            "-bad": "x-bad",
            "a.b/c": "a_b_c",
        }
        for raw_name, cleaned_name in expectations.items():
            assert _sanitize_index_name(raw_name) == cleaned_name
- # ---------------------------------------------------------------------------
- # ClientManager
- # ---------------------------------------------------------------------------
class TestClientManager:
    """Tests for ClientManager singleton pattern and reference counting.

    Each test swaps ``ClientManager._instances`` for a fresh dict through
    ``patch.object`` so the class-level state is restored afterwards, even
    when an assertion fails halfway through and leaves a client registered.
    """

    @staticmethod
    def _stub_client(version: str = "3.3.0") -> AsyncMock:
        """Build an AsyncMock client with a concrete .info() payload.

        Without this stub, _detect_shard_doc_support's chained .get(...) calls
        on an AsyncMock would leak un-awaited coroutines.

        Args:
            version: Version string reported under ``version.number``.

        Returns:
            An AsyncMock whose ``info()`` resolves to the version payload.
        """
        client = AsyncMock()
        client.info = AsyncMock(return_value={"version": {"number": version}})
        return client

    @pytest.mark.asyncio
    async def test_singleton_and_refcount(self):
        """Two get_client() calls share one client; releases count back down to zero."""
        with (
            # Scoped reset: the original _instances dict is put back on exit.
            patch.object(
                ClientManager, "_instances", {"client": None, "ref_count": 0}
            ),
            patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls,
        ):
            mock_cls.return_value = self._stub_client()
            c1 = await ClientManager.get_client()
            c2 = await ClientManager.get_client()
            assert c1 is c2
            assert ClientManager._instances["ref_count"] == 2
            await ClientManager.release_client(c1)
            assert ClientManager._instances["ref_count"] == 1
            await ClientManager.release_client(c2)
            assert ClientManager._instances["ref_count"] == 0
            assert ClientManager._instances["client"] is None

    @pytest.mark.asyncio
    async def test_close_called_on_last_release(self):
        """Releasing the only reference awaits close() on the underlying client."""
        with (
            # Scoped reset: the original _instances dict is put back on exit.
            patch.object(
                ClientManager, "_instances", {"client": None, "ref_count": 0}
            ),
            patch("lightrag.kg.opensearch_impl.AsyncOpenSearch") as mock_cls,
        ):
            inner = self._stub_client()
            mock_cls.return_value = inner
            c = await ClientManager.get_client()
            await ClientManager.release_client(c)
            inner.close.assert_awaited_once()
- # ---------------------------------------------------------------------------
- # _verify_mirrored_id_mapping helper
- # ---------------------------------------------------------------------------
class TestMirroredIdVerification:
    """Checks for the fail-fast helper _verify_mirrored_id_mapping."""

    @staticmethod
    def _mapping_for(properties):
        """Wrap *properties* in a get_mapping response for ``my_index``."""
        return {"my_index": {"mappings": {"properties": properties}}}

    @pytest.mark.asyncio
    async def test_skipped_on_modern_cluster(self, mock_client):
        """With _shard_doc_supported True (autouse fixture) no mapping is fetched."""
        await _verify_mirrored_id_mapping(mock_client, "any_index")
        mock_client.indices.get_mapping.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_passes_when_mapping_present(self, mock_client):
        """With the flag off, a mapping that declares __mirrored_id is accepted."""
        response = self._mapping_for({"__mirrored_id": {"type": "keyword"}})
        mock_client.indices.get_mapping = AsyncMock(return_value=response)
        with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
            await _verify_mirrored_id_mapping(mock_client, "my_index")

    @pytest.mark.asyncio
    async def test_fails_fast_when_mapping_missing(self, mock_client):
        """With the flag off, a mapping without __mirrored_id raises RuntimeError."""
        response = self._mapping_for({"other_field": {"type": "text"}})
        mock_client.indices.get_mapping = AsyncMock(return_value=response)
        with (
            patch("lightrag.kg.opensearch_impl._shard_doc_supported", False),
            pytest.raises(RuntimeError, match="__mirrored_id"),
        ):
            await _verify_mirrored_id_mapping(mock_client, "my_index")

    @pytest.mark.asyncio
    async def test_swallows_get_mapping_error(self, mock_client):
        """An OpenSearchException from get_mapping must not propagate."""
        transport_failure = OpenSearchException("transport error")
        mock_client.indices.get_mapping = AsyncMock(side_effect=transport_failure)
        with patch("lightrag.kg.opensearch_impl._shard_doc_supported", False):
            await _verify_mirrored_id_mapping(mock_client, "my_index")
- # ---------------------------------------------------------------------------
- # KV Storage
- # ---------------------------------------------------------------------------
- class TestKVStorage:
- """Tests for OpenSearchKVStorage CRUD operations, timestamps, refresh behavior."""
def _make(self, global_config, embed_func, workspace="test"):
    """Construct an OpenSearchKVStorage for the ``text_chunks`` namespace."""
    storage_kwargs = {
        "namespace": "text_chunks",
        "global_config": global_config,
        "embedding_func": embed_func,
        "workspace": workspace,
    }
    return OpenSearchKVStorage(**storage_kwargs)
@pytest.mark.asyncio
async def test_index_name(self, global_config, embed_func):
    """The index name is the workspace joined to the namespace."""
    storage = self._make(global_config, embed_func, workspace="proj_a")
    expected_index = "proj_a_text_chunks"
    assert storage._index_name == expected_index
@pytest.mark.asyncio
async def test_initialize_creates_index(self, global_config, embed_func, mock_client):
    """initialize() checks existence once and creates the index once."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        indices = mock_client.indices
        indices.exists.assert_awaited_once()
        indices.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_initialize_skips_existing_index(
    self, global_config, embed_func, mock_client
):
    """When exists() reports True, initialize() never calls create()."""
    mock_client.indices.exists = AsyncMock(return_value=True)
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        mock_client.indices.create.assert_not_awaited()
@pytest.mark.asyncio
async def test_initialize_fails_on_legacy_index_without_mirrored_id(
    self, global_config, embed_func, mock_client
):
    """With _shard_doc_supported False, an existing index lacking __mirrored_id makes initialize() raise."""
    legacy_mapping = {
        "test_text_chunks": {
            "mappings": {"properties": {"content": {"type": "text"}}}
        }
    }
    mock_client.indices.exists = AsyncMock(return_value=True)
    mock_client.indices.get_mapping = AsyncMock(return_value=legacy_mapping)

    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    legacy_cluster = patch("lightrag.kg.opensearch_impl._shard_doc_supported", False)
    with client_patch, legacy_cluster:
        storage = self._make(global_config, embed_func)
        with pytest.raises(RuntimeError, match="__mirrored_id"):
            await storage.initialize()
        # The failure must happen before any index creation is attempted.
        mock_client.indices.create.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_by_id(self, global_config, embed_func, mock_client):
    """get_by_id() issues a single-id mget and returns the source plus _id."""
    found_hit = {
        "_id": "doc1",
        "found": True,
        "_source": {"content": "hello", "create_time": 0, "update_time": 0},
    }
    mock_client.mget = AsyncMock(return_value={"docs": [found_hit]})

    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        fetched = await storage.get_by_id("doc1")
        assert fetched is not None
        assert fetched["content"] == "hello"
        assert fetched["_id"] == "doc1"
        expected_body = {"ids": ["doc1"]}
        mock_client.mget.assert_awaited_once_with(
            index=storage._index_name, body=expected_body
        )
@pytest.mark.asyncio
async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
    """A ``found: False`` mget entry yields None and client.get is never awaited."""
    missing_hit = {"_id": "missing", "found": False}
    mock_client.mget = AsyncMock(return_value={"docs": [missing_hit]})

    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        fetched = await storage.get_by_id("missing")
        assert fetched is None
        mock_client.get.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_by_ids_preserves_order(
    self, global_config, embed_func, mock_client
):
    """Results for ["id1", "id2"] come back as c1 then c2 (default mget payload)."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        requested_ids = ["id1", "id2"]
        docs = await storage.get_by_ids(requested_ids)
        first_doc = docs[0]
        assert first_doc["content"] == "c1"
        second_doc = docs[1]
        assert second_doc["content"] == "c2"
@pytest.mark.asyncio
async def test_filter_keys(self, global_config, embed_func, mock_client):
    """filter_keys() returns only the keys that mget reports as not found."""
    mget_docs = [
        {"_id": "a", "found": True},
        {"_id": "b", "found": False},
    ]
    mock_client.mget = AsyncMock(return_value={"docs": mget_docs})

    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        missing_keys = await storage.filter_keys({"a", "b"})
        assert missing_keys == {"b"}
@pytest.mark.asyncio
async def test_upsert_no_per_operation_refresh(
    self, global_config, embed_func, mock_client
):
    """The bulk call made by index_done_callback() carries no ``refresh`` kwarg."""
    with (
        patch.object(ClientManager, "get_client", return_value=mock_client),
        patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk,
    ):
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        await storage.upsert({"k1": {"content": "v1"}})
        # upsert() only buffers; nothing is sent until the flush below.
        mock_bulk.assert_not_awaited()

        await storage.index_done_callback()
        _positional, bulk_kwargs = mock_bulk.call_args
        assert "refresh" not in bulk_kwargs
@pytest.mark.asyncio
async def test_upsert_sets_timestamps(self, global_config, embed_func, mock_client):
    """create_time / update_time exist in the buffer and in the flushed source."""
    timestamp_fields = ("create_time", "update_time")
    with patch.object(ClientManager, "get_client", return_value=mock_client), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"k1": {"content": "v1"}})
        # The timestamps are already present before any flush happens.
        for field in timestamp_fields:
            assert field in storage._pending_upserts["k1"]
        await storage.index_done_callback()
        positional, _ = mock_bulk.call_args
        flushed_source = positional[1][0]["_source"]
        for field in timestamp_fields:
            assert field in flushed_source
@pytest.mark.asyncio
async def test_is_empty(self, global_config, embed_func, mock_client):
    """is_empty is True when the count API reports zero documents."""
    mock_client.count = AsyncMock(return_value={"count": 0})
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        empty = await storage.is_empty()
        assert empty is True
@pytest.mark.asyncio
async def test_delete(self, global_config, embed_func, mock_client):
    """delete() records tombstones; the bulk delete is only sent on flush."""
    with patch.object(ClientManager, "get_client", return_value=mock_client), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (2, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.delete(["a", "b"])
        # Nothing is sent yet; both ids sit in the tombstone set.
        mock_bulk.assert_not_awaited()
        assert storage._pending_kv_deletes == {"a", "b"}
        await storage.index_done_callback()
        positional, _ = mock_bulk.call_args
        actions = positional[1]
        assert len(actions) == 2
        assert all(action["_op_type"] == "delete" for action in actions)
@pytest.mark.asyncio
async def test_drop(self, global_config, embed_func, mock_client):
    """drop() reports success and deletes the index exactly once."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        outcome = await storage.drop()
        assert outcome["status"] == "success"
        mock_client.indices.delete.assert_awaited_once()
@pytest.mark.asyncio
async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index(
    self, global_config, embed_func, mock_client
):
    """A failing indices.delete makes drop() return an error status, clears
    _index_ready, and the next upsert awaits _create_index_if_not_exists once."""
    drop_failure = OpenSearchException("drop failed")
    mock_client.indices.delete = AsyncMock(side_effect=drop_failure)
    with patch.object(ClientManager, "get_client", return_value=mock_client), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Patched only after initialize() so the initial creation is not counted.
        with patch.object(
            storage, "_create_index_if_not_exists", new_callable=AsyncMock
        ) as mock_create:
            outcome = await storage.drop()
            assert outcome["status"] == "error"
            assert storage._index_ready is False
            await storage.upsert({"k1": {"content": "v1"}})
            mock_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_upsert_after_drop_recreates_index(
    self, global_config, embed_func, mock_client
):
    """An upsert following drop() awaits _create_index_if_not_exists once."""
    with patch.object(ClientManager, "get_client", return_value=mock_client), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Patched only after initialize() so the initial creation is not counted.
        with patch.object(
            storage, "_create_index_if_not_exists", new_callable=AsyncMock
        ) as mock_create:
            await storage.drop()
            await storage.upsert({"k1": {"content": "v1"}})
            mock_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_reads_short_circuit_after_drop(
    self, global_config, embed_func, mock_client
):
    """After drop(), the read APIs answer without awaiting mget or count."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.drop()
        single = await storage.get_by_id("doc1")
        assert single is None
        several = await storage.get_by_ids(["doc1", "doc2"])
        assert several == [None, None]
        empty = await storage.is_empty()
        assert empty is True
        mock_client.mget.assert_not_awaited()
        mock_client.count.assert_not_awaited()
@pytest.mark.asyncio
async def test_read_missing_index_demotes_readiness(
    self, global_config, embed_func, mock_client
):
    """When mget raises the error built by _missing_index_error(), reads
    return None, _index_ready becomes False, and the second read does not
    await mget again."""
    mock_client.mget = AsyncMock(side_effect=_missing_index_error())
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        first_read = await storage.get_by_id("doc1")
        assert first_read is None
        second_read = await storage.get_by_id("doc1")
        assert second_read is None
        assert storage._index_ready is False
        # Only the first read reached the client.
        assert mock_client.mget.await_count == 1
@pytest.mark.asyncio
async def test_iter_raw_docs_uses_pit_and_search_after(
    self, global_config, embed_func, mock_client
):
    """_iter_raw_docs pages through two search replies: the first request has
    no search_after, the second resumes from the last sort value, and a PIT
    is created and deleted exactly once."""
    first_page = {
        "hits": {
            "hits": [
                {"_id": "d1", "_source": {"content": "a"}, "sort": [1]},
                {"_id": "d2", "_source": {"content": "b"}, "sort": [2]},
            ]
        }
    }
    second_page = {
        "hits": {
            "hits": [
                {"_id": "d3", "_source": {"content": "c"}, "sort": [3]}
            ]
        }
    }
    mock_client.search = AsyncMock(side_effect=[first_page, second_page])
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        batches = [chunk async for chunk in storage._iter_raw_docs(batch_size=2)]
        ids_per_batch = [[doc["_id"] for doc in chunk] for chunk in batches]
        assert ids_per_batch == [["d1", "d2"], ["d3"]]
        first_body = mock_client.search.await_args_list[0].kwargs["body"]
        assert "search_after" not in first_body
        second_body = mock_client.search.await_args_list[1].kwargs["body"]
        # Resumes from the sort key of the last hit on the first page.
        assert second_body["search_after"] == [2]
        mock_client.create_pit.assert_awaited_once()
        mock_client.delete_pit.assert_awaited_once()
@pytest.mark.asyncio
async def test_iter_raw_docs_missing_index_demotes_readiness(
    self, global_config, embed_func, mock_client
):
    """When search raises the error built by _missing_index_error(), the
    iterator yields nothing, _index_ready becomes False, and the PIT is still
    created and deleted once."""
    mock_client.search = AsyncMock(side_effect=_missing_index_error())
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        batches = [chunk async for chunk in storage._iter_raw_docs(batch_size=2)]
        assert batches == []
        assert storage._index_ready is False
        mock_client.create_pit.assert_awaited_once()
        mock_client.delete_pit.assert_awaited_once()
@pytest.mark.asyncio
async def test_finalize(self, global_config, embed_func, mock_client):
    """finalize() awaits ClientManager.release_client once and clears .client."""
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch.object(
        ClientManager, "release_client", new_callable=AsyncMock
    ) as mock_release:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.finalize()
        mock_release.assert_awaited_once()
        assert storage.client is None
- # ---------------------------------------------------------------------------
- # KV storage write batching (derived from issue #2785 / PR #2822)
- # ---------------------------------------------------------------------------
class TestKVStorageBatching:
    """Covers OpenSearchKVStorage write buffering: upsert()/delete() collect
    changes in memory and index_done_callback()/finalize() flush them."""

    def _make(self, global_config, embed_func, workspace="test"):
        """Build an OpenSearchKVStorage for the text_chunks namespace."""
        storage_kwargs = {
            "namespace": "text_chunks",
            "global_config": global_config,
            "embedding_func": embed_func,
            "workspace": workspace,
        }
        return OpenSearchKVStorage(**storage_kwargs)

    @pytest.mark.asyncio
    async def test_repeated_kv_upserts_flush_in_single_bulk_call(
        self, global_config, embed_func, mock_client
    ):
        """Five separate upsert() calls become a single async_bulk on flush."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (5, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            for n in range(5):
                await storage.upsert({f"k{n}": {"content": f"doc {n}"}})
            mock_bulk.assert_not_awaited()
            await storage.index_done_callback()
            mock_bulk.assert_awaited_once()
            positional, _ = mock_bulk.call_args
            actions = positional[1]
            assert len(actions) == 5
            flushed_ids = {action["_id"] for action in actions}
            assert flushed_ids == {f"k{n}" for n in range(5)}

    @pytest.mark.asyncio
    async def test_kv_upsert_overwrites_pending_doc_for_same_id(
        self, global_config, embed_func, mock_client
    ):
        """A second upsert of the same id replaces the buffered payload."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (1, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            for payload in ("first", "second"):
                await storage.upsert({"k1": {"content": payload}})
            await storage.index_done_callback()
            positional, _ = mock_bulk.call_args
            actions = positional[1]
            assert len(actions) == 1
            assert actions[0]["_source"]["content"] == "second"

    @pytest.mark.asyncio
    async def test_kv_delete_cancels_pending_upsert(
        self, global_config, embed_func, mock_client
    ):
        """Deleting an id with a buffered upsert drops that upsert.

        Otherwise the flush would re-index the document and silently bring
        back a key that was logically deleted.
        """
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (1, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "doomed"}})
            await storage.delete(["k1"])
            assert "k1" not in storage._pending_upserts
            assert "k1" in storage._pending_kv_deletes
            await storage.index_done_callback()
            positional, _ = mock_bulk.call_args
            actions = positional[1]
            assert len(actions) == 1
            assert actions[0]["_op_type"] == "delete"

    @pytest.mark.asyncio
    async def test_kv_upsert_cancels_pending_delete(
        self, global_config, embed_func, mock_client
    ):
        """Upserting an id with a buffered tombstone drops the tombstone."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (1, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete(["k1"])
            await storage.upsert({"k1": {"content": "resurrected"}})
            assert "k1" not in storage._pending_kv_deletes
            assert "k1" in storage._pending_upserts
            await storage.index_done_callback()
            positional, _ = mock_bulk.call_args
            actions = positional[1]
            assert len(actions) == 1
            assert actions[0]["_op_type"] == "index"

    @pytest.mark.asyncio
    async def test_kv_delete_works_when_index_not_ready(
        self, global_config, embed_func, mock_client
    ):
        """delete() still cancels buffered upserts after the index was marked
        missing; otherwise the next flush would resurrect the deleted key."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (1, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}})
            storage._mark_index_missing()
            await storage.delete(["k1"])
            # The buffers stay consistent whatever _index_ready says.
            assert "k1" not in storage._pending_upserts
            assert "k1" in storage._pending_kv_deletes

    @pytest.mark.asyncio
    async def test_kv_get_by_id_reads_pending_buffer(
        self, global_config, embed_func, mock_client
    ):
        """get_by_id serves a buffered upsert without awaiting mget."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "buffered"}})
            fetched = await storage.get_by_id("k1")
            assert fetched is not None
            assert fetched["_id"] == "k1"
            assert fetched["content"] == "buffered"
            mock_client.mget.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_get_by_id_returns_none_for_pending_delete(
        self, global_config, embed_func, mock_client
    ):
        """A buffered tombstone makes get_by_id return None without mget."""
        mock_client.mget = AsyncMock()
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete(["k1"])
            fetched = await storage.get_by_id("k1")
            assert fetched is None
            mock_client.mget.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_get_by_id_strips_mirrored_id_from_buffer_path(
        self, global_config, embed_func, mock_client
    ):
        """The buffer entry carries __mirrored_id, but the dict returned by
        get_by_id must not expose it."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}})
            # Precondition: the internal field exists on the buffered entry.
            buffered_entry = storage._pending_upserts["k1"]
            assert buffered_entry["__mirrored_id"] == "k1"
            fetched = await storage.get_by_id("k1")
            assert fetched is not None
            assert "__mirrored_id" not in fetched
            assert fetched["_id"] == "k1"

    @pytest.mark.asyncio
    async def test_kv_get_by_ids_merges_buffer_and_mget(
        self, global_config, embed_func, mock_client
    ):
        """get_by_ids answers buffered ids locally and sends only the
        remaining ids to mget."""
        indexed_doc = {
            "_id": "k2",
            "found": True,
            "_source": {"content": "from_index"},
        }
        mock_client.mget = AsyncMock(return_value={"docs": [indexed_doc]})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "buffered"}})
            fetched = await storage.get_by_ids(["k1", "k2"])
            from_buffer, from_index = fetched[0], fetched[1]
            assert from_buffer["content"] == "buffered"
            assert "__mirrored_id" not in from_buffer
            assert from_index["content"] == "from_index"
            mock_client.mget.assert_awaited_once_with(
                index=storage._index_name, body={"ids": ["k2"]}
            )

    @pytest.mark.asyncio
    async def test_kv_filter_keys_excludes_buffered_upserts(
        self, global_config, embed_func, mock_client
    ):
        """filter_keys counts buffered upserts as existing and does not send
        their ids to mget."""
        mget_reply = {"docs": [{"_id": "k2", "found": False}]}
        mock_client.mget = AsyncMock(return_value=mget_reply)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}})
            missing = await storage.filter_keys({"k1", "k2"})
            assert missing == {"k2"}
            # The server was asked about the unbuffered id only.
            ((_positional, mget_kwargs),) = mock_client.mget.await_args_list[0:1]
            assert mget_kwargs["body"] == {"ids": ["k2"]}

    @pytest.mark.asyncio
    async def test_kv_filter_keys_treats_buffered_deletes_as_missing(
        self, global_config, embed_func, mock_client
    ):
        """A tombstoned id is reported as missing and is not sent to mget;
        looking it up would misclassify the still-persisted row as existing."""
        mget_reply = {"docs": [{"_id": "k3", "found": True}]}
        mock_client.mget = AsyncMock(return_value=mget_reply)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete(["k1"])  # leaves a tombstone for k1
            missing = await storage.filter_keys({"k1", "k3"})
            assert "k1" in missing  # tombstone counts as missing
            assert "k3" not in missing  # reported as found by mget
            # Only the non-tombstoned id went to the server.
            first_call = mock_client.mget.await_args_list[0]
            assert first_call.kwargs["body"] == {"ids": ["k3"]}

    @pytest.mark.asyncio
    async def test_kv_is_empty_returns_false_with_pending_upsert(
        self, global_config, embed_func, mock_client
    ):
        """With a buffered upsert, is_empty is False and count is never
        awaited, so a fresh upsert cannot be followed by "empty"."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}})
            empty = await storage.is_empty()
            assert empty is False
            mock_client.count.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_kv_finalize_flushes_pending(
        self, global_config, embed_func, mock_client
    ):
        """finalize() sends the buffered writes and clears the client."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (1, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "to flush"}})
            await storage.finalize()
            mock_bulk.assert_awaited_once()
            assert storage.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_raises_when_retryable_buffer_remains(
        self, global_config, embed_func, mock_client
    ):
        """finalize() raises RuntimeError when a retryable bulk failure leaves
        rows in the buffer.

        Without the error, the caller would treat the storage as cleanly
        finalized while writes were lost. The client is released anyway so
        no connection leaks on shutdown.
        """
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch.object(
            ClientManager, "release_client", new_callable=AsyncMock
        ) as mock_release, patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as mock_bulk:
            # 503 is retryable; flush keeps it in the buffer.
            retryable_error = {"index": {"_id": "k1", "status": 503, "error": "down"}}
            mock_bulk.return_value = (0, [retryable_error])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "stuck"}})
            with pytest.raises(RuntimeError, match="pending upserts"):
                await storage.finalize()
            # The failure must not prevent the client release.
            mock_release.assert_awaited_once()
            assert storage.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_propagates_flush_exception(
        self, global_config, embed_func, mock_client
    ):
        """When async_bulk raises, finalize() raises a RuntimeError whose
        __cause__ is the original exception, and still releases the client."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch.object(
            ClientManager, "release_client", new_callable=AsyncMock
        ) as mock_release, patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as mock_bulk:
            mock_bulk.side_effect = OpenSearchException("connection reset")
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "stuck"}})
            with pytest.raises(RuntimeError) as exc_info:
                await storage.finalize()
            # The original OpenSearchException is chained as the cause.
            assert isinstance(exc_info.value.__cause__, OpenSearchException)
            mock_release.assert_awaited_once()
            assert storage.client is None

    @pytest.mark.asyncio
    async def test_kv_finalize_propagates_cancellation(
        self, global_config, embed_func, mock_client
    ):
        """asyncio.CancelledError from the final flush propagates unwrapped,
        and the client is still released before it continues upward."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch.object(
            ClientManager, "release_client", new_callable=AsyncMock
        ) as mock_release, patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as mock_bulk:
            mock_bulk.side_effect = asyncio.CancelledError()
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "stuck"}})
            with pytest.raises(asyncio.CancelledError):
                await storage.finalize()
            # Release happened despite the cancellation.
            mock_release.assert_awaited_once()
            assert storage.client is None

    @pytest.mark.asyncio
    async def test_kv_drop_discards_buffers_and_serialises_with_flush(
        self, global_config, embed_func, mock_client
    ):
        """drop() waits for an in-flight flush before calling indices.delete
        and leaves both buffers empty."""
        flush_started = asyncio.Event()
        flush_can_finish = asyncio.Event()
        drop_delete_started = asyncio.Event()

        async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
            # Signal that the flush is running, then park until released.
            flush_started.set()
            await flush_can_finish.wait()
            return (len(actions), [])

        async def watch_indices_delete(**kwargs):
            drop_delete_started.set()

        mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}})
            await storage.delete(["k2"])
            flush_task = asyncio.create_task(storage.index_done_callback())
            await flush_started.wait()
            drop_task = asyncio.create_task(storage.drop())
            # Give the drop task several loop turns to (wrongly) advance.
            for _ in range(5):
                await asyncio.sleep(0)
            assert not drop_delete_started.is_set(), (
                "indices.delete should be blocked behind the flush lock"
            )
            assert not drop_task.done()
            flush_can_finish.set()
            await flush_task
            await drop_task
            assert drop_delete_started.is_set()
            # The flush already drained k1/k2, so drop() clears buffers that
            # are empty; these checks confirm it copes with that state.
            assert storage._pending_upserts == {}
            assert storage._pending_kv_deletes == set()

    @pytest.mark.asyncio
    async def test_kv_failed_flush_retains_retryable(
        self, global_config, embed_func, mock_client
    ):
        """A per-document 503 keeps that document buffered for a later flush."""
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            transient_error = {"index": {"_id": "k2", "status": 503, "error": "down"}}
            mock_bulk.return_value = (1, [transient_error])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert(
                {"k1": {"content": "ok"}, "k2": {"content": "boom"}}
            )
            await storage.index_done_callback()
            assert "k1" not in storage._pending_upserts
            assert "k2" in storage._pending_upserts

    @pytest.mark.asyncio
    async def test_kv_failed_flush_drops_non_retryable(
        self, global_config, embed_func, mock_client
    ):
        """A 400 (mapping error) failure is removed from the buffer instead
        of being retried forever; the 503 one stays buffered."""
        permanent_error = {
            "index": {
                "_id": "k1",
                "status": 400,
                "error": {
                    "type": "mapper_parsing_exception",
                    "reason": "bad",
                },
            }
        }
        transient_error = {"index": {"_id": "k2", "status": 503, "error": "down"}}
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        ) as mock_bulk:
            mock_bulk.return_value = (0, [permanent_error, transient_error])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "x"}, "k2": {"content": "y"}})
            await storage.index_done_callback()
            assert "k1" not in storage._pending_upserts
            assert "k2" in storage._pending_upserts

    @pytest.mark.asyncio
    async def test_kv_concurrent_upsert_during_flush_blocked(
        self, global_config, embed_func, mock_client
    ):
        """An upsert issued while async_bulk is in flight does not complete,
        and does not reach the buffer, until the flush has finished."""
        flush_started = asyncio.Event()
        flush_can_finish = asyncio.Event()

        async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
            # Signal that the flush is running, then park until released.
            flush_started.set()
            await flush_can_finish.wait()
            return (len(actions), [])

        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert({"k1": {"content": "first"}})
            flush_task = asyncio.create_task(storage.index_done_callback())
            await flush_started.wait()
            concurrent_task = asyncio.create_task(
                storage.upsert({"k2": {"content": "concurrent"}})
            )
            # Give the concurrent upsert several loop turns to (wrongly) finish.
            for _ in range(5):
                await asyncio.sleep(0)
            assert not concurrent_task.done(), (
                "concurrent upsert should be blocked by the flush lock"
            )
            assert "k2" not in storage._pending_upserts
            flush_can_finish.set()
            await flush_task
            await concurrent_task
            # k1 was flushed away; k2 arrived once the flush let go.
            assert "k1" not in storage._pending_upserts
            assert "k2" in storage._pending_upserts
- # ---------------------------------------------------------------------------
- # DocStatus Storage
- # ---------------------------------------------------------------------------
- class TestDocStatusStorage:
- """Tests for OpenSearchDocStatusStorage including aggregations, pagination, and data normalization."""
def _make(self, global_config, embed_func, workspace="test"):
    """Build an OpenSearchDocStatusStorage for the doc_status namespace."""
    storage_kwargs = {
        "namespace": "doc_status",
        "global_config": global_config,
        "embedding_func": embed_func,
        "workspace": workspace,
    }
    return OpenSearchDocStatusStorage(**storage_kwargs)
@pytest.mark.asyncio
async def test_index_name(self, global_config, embed_func):
    """Workspace "test" with namespace doc_status gives index test_doc_status."""
    storage = self._make(global_config, embed_func)
    expected_index = "test_doc_status"
    assert storage._index_name == expected_index
@pytest.mark.asyncio
async def test_initialize_creates_index(
    self, global_config, embed_func, mock_client
):
    """initialize() awaits indices.create exactly once."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        mock_client.indices.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_by_id(self, global_config, embed_func, mock_client):
    """get_by_id returns the mget _source plus _id, using one mget call."""
    found_doc = {
        "_id": "doc-abc",
        "found": True,
        "_source": {"status": "processed", "file_path": "/a.txt"},
    }
    mock_client.mget = AsyncMock(return_value={"docs": [found_doc]})
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        fetched = await storage.get_by_id("doc-abc")
        assert fetched["status"] == "processed"
        assert fetched["_id"] == "doc-abc"
        mock_client.mget.assert_awaited_once_with(
            index=storage._index_name, body={"ids": ["doc-abc"]}
        )
@pytest.mark.asyncio
async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
    """get_by_id returns None when mget reports the id as not found."""
    not_found_reply = {"docs": [{"_id": "missing", "found": False}]}
    mock_client.mget = AsyncMock(return_value=not_found_reply)
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        fetched = await storage.get_by_id("missing")
        assert fetched is None
        # The single-document get API must stay untouched.
        mock_client.get.assert_not_awaited()
@pytest.mark.asyncio
async def test_upsert_sets_chunks_list_default(
    self, global_config, embed_func, mock_client
):
    """A doc upserted without chunks_list is sent with chunks_list == []."""
    with patch.object(ClientManager, "get_client", return_value=mock_client), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"d1": {"status": "pending"}})
        positional, _ = mock_bulk.call_args
        first_action = positional[1][0]
        assert first_action["_source"]["chunks_list"] == []
@pytest.mark.asyncio
async def test_get_status_counts(self, global_config, embed_func, mock_client):
    """get_status_counts maps each aggregation bucket key to its doc_count."""
    buckets = [
        {"key": "processed", "doc_count": 3},
        {"key": "pending", "doc_count": 1},
    ]
    search_reply = {
        "hits": {"hits": [], "total": {"value": 0}},
        "aggregations": {"status_counts": {"buckets": buckets}},
    }
    mock_client.search = AsyncMock(return_value=search_reply)
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        counts = await storage.get_status_counts()
        assert counts == {"processed": 3, "pending": 1}
@pytest.mark.asyncio
async def test_get_all_status_counts_includes_all(
    self, global_config, embed_func, mock_client
):
    """get_all_status_counts adds an "all" entry equal to the bucket total."""
    buckets = [
        {"key": "processed", "doc_count": 5},
        {"key": "failed", "doc_count": 2},
    ]
    search_reply = {
        "hits": {"hits": [], "total": {"value": 0}},
        "aggregations": {"status_counts": {"buckets": buckets}},
    }
    mock_client.search = AsyncMock(return_value=search_reply)
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        counts = await storage.get_all_status_counts()
        assert counts["all"] == 7  # 5 processed + 2 failed
        assert counts["processed"] == 5
@pytest.mark.asyncio
async def test_get_docs_by_status(self, global_config, embed_func, mock_client):
    """get_docs_by_status returns DocProcessingStatus objects keyed by doc id."""
    hit = {
        "_id": "d1",
        "_source": {
            "status": "processed",
            "file_path": "/a.txt",
            "content_summary": "s",
            "content_length": 10,
            "chunks_count": 1,
            "created_at": 100,
            "updated_at": 200,
        },
        "sort": ["d1"],
    }
    search_reply = {"hits": {"hits": [hit], "total": {"value": 1}}}
    mock_client.search = AsyncMock(return_value=search_reply)
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        by_id = await storage.get_docs_by_status(DocStatus.PROCESSED)
        assert "d1" in by_id
        assert isinstance(by_id["d1"], DocProcessingStatus)
@pytest.mark.asyncio
async def test_get_docs_paginated(self, global_config, embed_func, mock_client):
    """Page 1 is served by a single search whose body has no search_after."""
    hit = {
        "_id": "d1",
        "_source": {
            "status": "processed",
            "file_path": "/a.txt",
            "content_summary": "s",
            "content_length": 10,
            "chunks_count": 1,
            "created_at": 100,
            "updated_at": 200,
        },
        "sort": [200, "d1"],
    }
    mock_client.count = AsyncMock(return_value={"count": 50})
    mock_client.search = AsyncMock(
        return_value={"hits": {"hits": [hit], "total": {"value": 50}}}
    )
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        docs, total = await storage.get_docs_paginated(page=1, page_size=10)
        assert total == 50
        assert len(docs) == 1
        assert docs[0][0] == "d1"
        # The first page needs no skip pass: exactly one search call.
        assert mock_client.search.await_count == 1
        body = mock_client.search.call_args.kwargs.get("body", {})
        assert "search_after" not in body
@pytest.mark.asyncio
async def test_get_docs_paginated_page2_uses_search_after(
    self, global_config, embed_func, mock_client
):
    """Page 2 takes two searches: a skip pass, then a search_after fetch.

    The stubbed search returns ten "skip" hits when the request body has no
    search_after and a single "page2_doc" hit when it does. The number of
    searches is checked through the mock's own await_count.
    """
    mock_client.count = AsyncMock(return_value={"count": 50})

    async def search_side_effect(*args, **kwargs):
        body = kwargs.get("body", {})
        if "search_after" in body:
            # Follow-up request: the page that was actually asked for.
            return {
                "hits": {
                    "hits": [
                        {
                            "_id": "page2_doc",
                            "_source": {
                                "status": "pending",
                                "file_path": "/p2.txt",
                                "content_summary": "s",
                                "content_length": 1,
                                "chunks_count": 1,
                                "created_at": 200,
                                "updated_at": 300,
                            },
                            "sort": [300, "page2_doc"],
                        }
                    ],
                    "total": {"value": 50},
                }
            }
        # First request: the batch that gets skipped over.
        return {
            "hits": {
                "hits": [
                    {
                        "_id": f"skip{i}",
                        "_source": {
                            "status": "processed",
                            "file_path": f"/{i}.txt",
                            "content_summary": "s",
                            "content_length": 1,
                            "chunks_count": 1,
                            "created_at": 100,
                            "updated_at": 100 + i,
                        },
                        "sort": [100 + i, f"skip{i}"],
                    }
                    for i in range(10)
                ],
                "total": {"value": 50},
            }
        }

    mock_client.search = AsyncMock(side_effect=search_side_effect)
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        docs, total = await storage.get_docs_paginated(page=2, page_size=10)
        assert total == 50
        assert len(docs) == 1
        assert docs[0][0] == "page2_doc"
        # 2 search calls: 1 skip + 1 fetch
        assert mock_client.search.await_count == 2
@pytest.mark.asyncio
async def test_get_docs_paginated_empty_index(
    self, global_config, embed_func, mock_client
):
    """A count of 0 gives ([], 0) and search is never awaited."""
    mock_client.count = AsyncMock(return_value={"count": 0})
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        page = await storage.get_docs_paginated(page=1, page_size=10)
        docs, total = page
        assert total == 0
        assert docs == []
        mock_client.search.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_docs_paginated_page_beyond_total(
    self, global_config, embed_func, mock_client
):
    """Page 100 of a 5-document index is empty while total stays 5."""
    mock_client.count = AsyncMock(return_value={"count": 5})
    client_patch = patch.object(
        ClientManager, "get_client", return_value=mock_client
    )
    with client_patch:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        page_docs, total = await storage.get_docs_paginated(
            page=100, page_size=10
        )

        assert total == 5
        assert page_docs == []
@pytest.mark.asyncio
async def test_get_docs_paginated_with_status_filter(
    self, global_config, embed_func, mock_client
):
    """A single status filter reaches the count request as a ``term`` query."""
    empty_result = {"hits": {"hits": [], "total": {"value": 3}}}
    mock_client.count = AsyncMock(return_value={"count": 3})
    mock_client.search = AsyncMock(return_value=empty_result)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        _docs, total = await storage.get_docs_paginated(
            status_filter=DocStatus.PROCESSED, page=1, page_size=10
        )

        assert total == 3
        # The count request must carry the status filter.
        count_kwargs = mock_client.count.call_args.kwargs
        count_body = count_kwargs.get("body", {})
        assert count_body["query"] == {"term": {"status": "processed"}}
@pytest.mark.asyncio
async def test_get_docs_paginated_with_status_filters(
    self, global_config, embed_func, mock_client
):
    """``status_filters`` becomes a ``terms`` query and wins over ``status_filter``."""
    empty_result = {"hits": {"hits": [], "total": {"value": 2}}}
    mock_client.count = AsyncMock(return_value={"count": 2})
    mock_client.search = AsyncMock(return_value=empty_result)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        page_docs, total = await storage.get_docs_paginated(
            status_filter=DocStatus.PROCESSED,
            status_filters=[DocStatus.PARSING, DocStatus.ANALYZING],
            page=1,
            page_size=10,
        )

        assert total == 2
        assert page_docs == []
        count_kwargs = mock_client.count.call_args.kwargs
        count_body = count_kwargs.get("body", {})
        expected_query = {"terms": {"status": ["analyzing", "parsing"]}}
        assert count_body["query"] == expected_query
@pytest.mark.asyncio
async def test_get_doc_by_file_path(self, global_config, embed_func, mock_client):
    """A single search hit is returned as a document carrying its ``_id``."""
    matching_hit = {
        "_id": "d1",
        "_source": {
            "file_path": "/test.txt",
            "status": "processed",
        },
    }
    mock_client.search = AsyncMock(
        return_value={"hits": {"hits": [matching_hit], "total": {"value": 1}}}
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        found = await storage.get_doc_by_file_path("/test.txt")

        assert found is not None
        assert found["_id"] == "d1"
@pytest.mark.asyncio
async def test_get_doc_by_file_path_not_found(
    self, global_config, embed_func, mock_client
):
    """An empty hit list makes ``get_doc_by_file_path`` return ``None``."""
    no_hits = {"hits": {"hits": [], "total": {"value": 0}}}
    mock_client.search = AsyncMock(return_value=no_hits)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        found = await storage.get_doc_by_file_path("/nope.txt")
        assert found is None
@pytest.mark.asyncio
async def test_get_doc_by_file_basename_returns_tuple_on_hit(
    self, global_config, embed_func, mock_client
):
    """A hit yields ``(doc_id, doc)`` and the search is a ``term`` on ``file_path``."""
    matching_hit = {
        "_id": "doc-1",
        "_source": {
            "file_path": "report.pdf",
            "status": "processed",
        },
    }
    mock_client.search = AsyncMock(
        return_value={"hits": {"hits": [matching_hit], "total": {"value": 1}}}
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        result = await storage.get_doc_by_file_basename("report.pdf")

        assert result is not None
        doc_id, doc = result
        assert doc_id == "doc-1"
        assert doc["file_path"] == "report.pdf"

        last_call = mock_client.search.call_args
        body = last_call.kwargs.get("body") or last_call[1].get("body", {})
        assert body["query"] == {"term": {"file_path": "report.pdf"}}
@pytest.mark.asyncio
async def test_get_doc_by_file_basename_empty_short_circuits(
    self, global_config, embed_func, mock_client
):
    """An empty basename returns ``None`` without awaiting ``search``."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Discard any search calls recorded before the lookup under test.
        mock_client.search.reset_mock()

        found = await storage.get_doc_by_file_basename("")

        assert found is None
        mock_client.search.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_doc_by_file_basename_unknown_source_sentinel(
    self, global_config, embed_func, mock_client
):
    """The ``"unknown_source"`` basename returns ``None`` without a search."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Discard any search calls recorded before the lookup under test.
        mock_client.search.reset_mock()

        found = await storage.get_doc_by_file_basename("unknown_source")

        assert found is None
        mock_client.search.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_doc_by_file_basename_miss_returns_none(
    self, global_config, embed_func, mock_client
):
    """A search with no hits makes the basename lookup return ``None``."""
    no_hits = {"hits": {"hits": [], "total": {"value": 0}}}
    mock_client.search = AsyncMock(return_value=no_hits)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        found = await storage.get_doc_by_file_basename("missing.pdf")
        assert found is None
@pytest.mark.asyncio
async def test_get_doc_by_content_hash_returns_tuple_on_hit(
    self, global_config, embed_func, mock_client
):
    """A hit yields ``(doc_id, doc)`` and the search is a ``term`` on ``content_hash``."""
    matching_hit = {
        "_id": "doc-1",
        "_source": {
            "file_path": "report.pdf",
            "content_hash": "abc123",
            "status": "processed",
        },
    }
    mock_client.search = AsyncMock(
        return_value={"hits": {"hits": [matching_hit], "total": {"value": 1}}}
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        result = await storage.get_doc_by_content_hash("abc123")

        assert result is not None
        doc_id, doc = result
        assert doc_id == "doc-1"
        assert doc["content_hash"] == "abc123"

        last_call = mock_client.search.call_args
        body = last_call.kwargs.get("body") or last_call[1].get("body", {})
        assert body["query"] == {"term": {"content_hash": "abc123"}}
@pytest.mark.asyncio
async def test_get_doc_by_content_hash_empty_short_circuits(
    self, global_config, embed_func, mock_client
):
    """An empty content hash returns ``None`` without awaiting ``search``."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Discard any search calls recorded before the lookup under test.
        mock_client.search.reset_mock()

        found = await storage.get_doc_by_content_hash("")

        assert found is None
        mock_client.search.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_doc_by_content_hash_miss_returns_none(
    self, global_config, embed_func, mock_client
):
    """A search with no hits makes the content-hash lookup return ``None``."""
    no_hits = {"hits": {"hits": [], "total": {"value": 0}}}
    mock_client.search = AsyncMock(return_value=no_hits)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        found = await storage.get_doc_by_content_hash("zzz999")
        assert found is None
@pytest.mark.asyncio
async def test_ensure_content_hash_mapping_added_when_missing(
    self, global_config, embed_func, mock_client
):
    """An existing index lacking ``content_hash`` gets that keyword mapping added."""
    existing_properties = {
        "__mirrored_id": {"type": "keyword"},
        "status": {"type": "keyword"},
        "file_path": {"type": "keyword"},
    }
    mock_client.indices.exists = AsyncMock(return_value=True)
    mock_client.indices.get_mapping = AsyncMock(
        return_value={
            "test_doc_status": {"mappings": {"properties": existing_properties}}
        }
    )
    mock_client.indices.put_mapping = AsyncMock()

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        put_mapping = mock_client.indices.put_mapping
        put_mapping.assert_awaited_once()
        sent_kwargs = put_mapping.call_args.kwargs
        expected_body = {"properties": {"content_hash": {"type": "keyword"}}}
        assert sent_kwargs["body"] == expected_body
@pytest.mark.asyncio
async def test_ensure_content_hash_mapping_skipped_when_present(
    self, global_config, embed_func, mock_client
):
    """An index that already maps ``content_hash`` triggers no ``put_mapping``."""
    existing_properties = {
        "__mirrored_id": {"type": "keyword"},
        "content_hash": {"type": "keyword"},
    }
    mock_client.indices.exists = AsyncMock(return_value=True)
    mock_client.indices.get_mapping = AsyncMock(
        return_value={
            "test_doc_status": {"mappings": {"properties": existing_properties}}
        }
    )
    mock_client.indices.put_mapping = AsyncMock()

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        mock_client.indices.put_mapping.assert_not_awaited()
@pytest.mark.asyncio
async def test_prepare_doc_status_data(self, global_config, embed_func):
    """``_prepare_doc_status_data`` drops ``_id``, renames ``error``, fills defaults."""
    storage = self._make(global_config, embed_func)
    raw_record = {"_id": "x", "status": "processed", "error": "oops"}

    prepared = storage._prepare_doc_status_data(raw_record)

    # Keys that must be removed from the prepared record.
    assert "_id" not in prepared
    assert "error" not in prepared
    # ``error`` is carried over under the ``error_msg`` key.
    assert prepared["error_msg"] == "oops"
    # Defaults filled in for fields missing from the raw record.
    assert prepared["file_path"] == "no-file-path"
    assert prepared["metadata"] == {}
@pytest.mark.asyncio
async def test_drop_error_marks_index_not_ready_and_next_upsert_recreates_index(
    self, global_config, embed_func, mock_client
):
    """A failing drop reports an error, clears readiness, and upsert re-creates."""
    mock_client.indices.delete = AsyncMock(
        side_effect=OpenSearchException("drop failed")
    )
    client_patch = patch.object(
        ClientManager, "get_client", return_value=mock_client
    )
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        create_patch = patch.object(
            storage, "_create_index_if_not_exists", new_callable=AsyncMock
        )
        with create_patch as mock_create:
            outcome = await storage.drop()
            assert outcome["status"] == "error"
            assert storage._index_ready is False

            await storage.upsert({"d1": {"status": "pending"}})
            mock_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_upsert_after_drop_recreates_index(
    self, global_config, embed_func, mock_client
):
    """An upsert following ``drop`` awaits index creation exactly once."""
    client_patch = patch.object(
        ClientManager, "get_client", return_value=mock_client
    )
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        create_patch = patch.object(
            storage, "_create_index_if_not_exists", new_callable=AsyncMock
        )
        with create_patch as mock_create:
            await storage.drop()
            await storage.upsert({"d1": {"status": "pending"}})
            mock_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_reads_short_circuit_after_drop(
    self, global_config, embed_func, mock_client
):
    """After ``drop``, reads return empty results without touching the client."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.drop()

        status_counts = await storage.get_all_status_counts()
        assert status_counts == {}
        paginated = await storage.get_docs_paginated(page=1, page_size=10)
        assert paginated == ([], 0)
        by_path = await storage.get_doc_by_file_path("/a.txt")
        assert by_path is None
        by_status = await storage.get_docs_by_status(DocStatus.PROCESSED)
        assert by_status == {}

        # None of the read-side client calls may have been awaited.
        mock_client.count.assert_not_awaited()
        mock_client.search.assert_not_awaited()
        mock_client.create_pit.assert_not_awaited()
@pytest.mark.asyncio
async def test_read_missing_index_demotes_readiness(
    self, global_config, embed_func, mock_client
):
    """A missing-index error on read clears readiness; the next read skips search."""
    mock_client.search = AsyncMock(side_effect=_missing_index_error())

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()

        first_read = await storage.get_all_status_counts()
        assert first_read == {}
        second_read = await storage.get_all_status_counts()
        assert second_read == {}

        assert storage._index_ready is False
        # Only the first read reached the client.
        assert mock_client.search.await_count == 1
- # ---------------------------------------------------------------------------
- # Graph Storage
- # ---------------------------------------------------------------------------
class TestGraphStorage:
    """Exercise OpenSearchGraphStorage against a mocked OpenSearch client.

    Covers node/edge CRUD, batch operations, label queries, knowledge-graph
    retrieval, and drop/readiness handling.
    """

    def _make(self, global_config, embed_func, workspace="test"):
        """Build a graph storage for the ``chunk_entity_relation`` namespace."""
        return OpenSearchGraphStorage(
            workspace=workspace,
            namespace="chunk_entity_relation",
            global_config=global_config,
            embedding_func=embed_func,
        )

    @pytest.mark.asyncio
    async def test_index_names(self, global_config, embed_func):
        """Node and edge index names combine workspace, namespace, and a suffix."""
        storage = self._make(global_config, embed_func)
        assert storage._nodes_index == "test_chunk_entity_relation-nodes"
        assert storage._edges_index == "test_chunk_entity_relation-edges"

    @pytest.mark.asyncio
    async def test_initialize_creates_both_indices(
        self, global_config, embed_func, mock_client
    ):
        """``initialize`` awaits ``indices.create`` twice."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            assert mock_client.indices.create.await_count == 2

    @pytest.mark.asyncio
    async def test_has_node_true(self, global_config, embed_func, mock_client):
        """``has_node`` is True when the client's ``exists`` returns True."""
        mock_client.exists = AsyncMock(return_value=True)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            present = await storage.has_node("Alice")
            assert present is True

    @pytest.mark.asyncio
    async def test_has_node_false(self, global_config, embed_func, mock_client):
        """``has_node`` is False when the client's ``exists`` returns False."""
        mock_client.exists = AsyncMock(return_value=False)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            present = await storage.has_node("Nobody")
            assert present is False

    @pytest.mark.asyncio
    async def test_has_edge(self, global_config, embed_func, mock_client):
        """``has_edge`` is True when the search reports a total of one."""
        empty_buckets = {"buckets": []}
        search_result = {
            "hits": {"hits": [], "total": {"value": 1}},
            "aggregations": {
                "status_counts": dict(empty_buckets),
                "src": dict(empty_buckets),
                "tgt": dict(empty_buckets),
                "source_degrees": dict(empty_buckets),
                "target_degrees": dict(empty_buckets),
            },
        }
        mock_client.search = AsyncMock(return_value=search_result)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            connected = await storage.has_edge("A", "B")
            assert connected is True

    @pytest.mark.asyncio
    async def test_node_degree(self, global_config, embed_func, mock_client):
        """``node_degree`` returns the count reported by the client."""
        mock_client.count = AsyncMock(return_value={"count": 3})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            degree = await storage.node_degree("A")
            assert degree == 3

    @pytest.mark.asyncio
    async def test_get_node(self, global_config, embed_func, mock_client):
        """``get_node`` issues a single-id ``mget`` and returns source plus ``_id``."""
        alice_doc = {
            "_id": "Alice",
            "found": True,
            "_source": {
                "entity_type": "person",
                "description": "A researcher",
            },
        }
        mock_client.mget = AsyncMock(return_value={"docs": [alice_doc]})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            node = await storage.get_node("Alice")

            assert node["entity_type"] == "person"
            assert node["_id"] == "Alice"
            mock_client.mget.assert_awaited_once_with(
                index=storage._nodes_index, body={"ids": ["Alice"]}
            )

    @pytest.mark.asyncio
    async def test_get_node_not_found(self, global_config, embed_func, mock_client):
        """A ``found: False`` mget doc yields ``None`` and no ``get`` call."""
        missing_doc = {"_id": "Nobody", "found": False}
        mock_client.mget = AsyncMock(return_value={"docs": [missing_doc]})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            node = await storage.get_node("Nobody")
            assert node is None
            mock_client.get.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_get_edge(self, global_config, embed_func, mock_client):
        """``get_edge`` returns the source of the mget doc that was found."""
        # get_edge reads through mget (translog real-time) rather than search.
        found_doc = {
            "_id": "e1",
            "found": True,
            "_source": {
                "source_node_id": "A",
                "target_node_id": "B",
                "weight": 1.0,
            },
        }
        missing_doc = {
            "_id": "e2",
            "found": False,
        }
        mock_client.mget = AsyncMock(
            return_value={"docs": [found_doc, missing_doc]}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            edge = await storage.get_edge("A", "B")
            assert edge is not None
            assert edge["weight"] == 1.0

    @pytest.mark.asyncio
    async def test_get_node_edges(self, global_config, embed_func, mock_client):
        """``get_node_edges`` returns one (source, target) pair per search hit."""
        edge_hits = [
            {
                "_id": "e1",
                "_source": {"source_node_id": "A", "target_node_id": "B"},
                "sort": [1],
            },
            {
                "_id": "e2",
                "_source": {"source_node_id": "C", "target_node_id": "A"},
                "sort": [2],
            },
        ]
        mock_client.search = AsyncMock(
            return_value={"hits": {"hits": edge_hits, "total": {"value": 2}}}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            edges = await storage.get_node_edges("A")
            assert len(edges) == 2
            assert ("A", "B") in edges

    @pytest.mark.asyncio
    async def test_get_nodes_batch(self, global_config, embed_func, mock_client):
        """``get_nodes_batch`` includes found nodes and omits missing ones."""
        mget_docs = [
            {"_id": "A", "found": True, "_source": {"entity_type": "person"}},
            {"_id": "B", "found": False},
        ]
        mock_client.mget = AsyncMock(return_value={"docs": mget_docs})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            nodes = await storage.get_nodes_batch(["A", "B"])
            assert "A" in nodes
            assert "B" not in nodes

    @pytest.mark.asyncio
    async def test_node_degrees_batch(self, global_config, embed_func, mock_client):
        """Degrees are the sum of the source and target bucket counts per node."""
        aggregations = {
            "source_degrees": {"buckets": [{"key": "A", "doc_count": 2}]},
            "target_degrees": {
                "buckets": [
                    {"key": "A", "doc_count": 1},
                    {"key": "B", "doc_count": 3},
                ]
            },
            "status_counts": {"buckets": []},
            "src": {"buckets": []},
            "tgt": {"buckets": []},
        }
        mock_client.search = AsyncMock(
            return_value={
                "hits": {"hits": [], "total": {"value": 0}},
                "aggregations": aggregations,
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            degrees = await storage.node_degrees_batch(["A", "B"])
            # A appears twice as a source and once as a target.
            assert degrees["A"] == 3
            assert degrees["B"] == 3

    @pytest.mark.asyncio
    async def test_upsert_node(self, global_config, embed_func, mock_client):
        """``upsert_node`` indexes under the node id and splits ``source_id``."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            node_data = {"entity_type": "person", "source_id": "c1<SEP>c2"}
            await storage.upsert_node("Alice", node_data)

            mock_client.index.assert_awaited()
            index_kwargs = mock_client.index.call_args.kwargs
            assert index_kwargs["id"] == "Alice"
            indexed_body = index_kwargs["body"]
            assert indexed_body["source_ids"] == ["c1", "c2"]
            assert indexed_body["entity_id"] == "Alice"

    @pytest.mark.asyncio
    async def test_upsert_edge(self, global_config, embed_func, mock_client):
        """``upsert_edge`` awaits ``index`` twice when the node check is False."""
        mock_client.exists = AsyncMock(return_value=False)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            edge_data = {"weight": "1.0", "description": "knows"}
            await storage.upsert_edge("A", "B", edge_data)
            # Two index calls: one ensuring the source node, one for the edge.
            assert mock_client.index.await_count == 2

    @pytest.mark.asyncio
    async def test_upsert_edges_batch_reuses_id_for_reciprocal_edges(
        self, global_config, embed_func, mock_client
    ):
        """A->B and B->A in one batch produce bulk actions sharing one ``_id``."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()

            captured_batches = []

            async def capture_bulk(_client, actions, *args, **kwargs):
                # Record the actions of each bulk call and report all succeeded.
                captured_batches.append(list(actions))
                return (len(captured_batches[-1]), [])

            mock_client.mget = AsyncMock(
                side_effect=[
                    {"docs": []},
                    {"docs": [{"_id": "edge-ba", "found": False}] * 2},
                ]
            )
            bulk_patch = patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk",
                new=AsyncMock(side_effect=capture_bulk),
            )
            with bulk_patch:
                await storage.upsert_edges_batch(
                    [
                        ("A", "B", {"weight": "1.0"}),
                        ("B", "A", {"weight": "2.0"}),
                    ]
                )

            edge_actions = captured_batches[-1]
            assert len(edge_actions) == 2
            first_action, second_action = edge_actions
            assert first_action["_id"] == second_action["_id"]

    @pytest.mark.asyncio
    async def test_upsert_after_drop_recreates_indices(
        self, global_config, embed_func, mock_client
    ):
        """An edge upsert after ``drop`` awaits index creation once more."""
        mock_client.exists = AsyncMock(return_value=False)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            create_patch = patch.object(
                storage, "_create_indices_if_not_exist", new_callable=AsyncMock
            )
            with create_patch as mock_create:
                await storage.initialize()
                # Forget the creation performed by initialize().
                mock_create.reset_mock()
                await storage.drop()
                await storage.upsert_edge("A", "B", {"weight": "1.0"})

                mock_create.assert_awaited_once()
                assert mock_client.index.await_count == 2

    @pytest.mark.asyncio
    async def test_reads_short_circuit_after_drop(
        self, global_config, embed_func, mock_client
    ):
        """After ``drop``, graph reads return empty values without client calls."""
        mock_client.transport = AsyncMock()
        mock_client.transport.perform_request = AsyncMock(
            side_effect=Exception("PPL not available")
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.drop()

            graph = await storage.get_knowledge_graph("A", max_depth=2)
            node = await storage.get_node("A")
            assert node is None
            labels = await storage.get_all_labels()
            assert labels == []
            connected = await storage.has_edge("A", "B")
            assert connected is False
            degree = await storage.node_degree("A")
            assert degree == 0
            assert len(graph.nodes) == 0
            assert len(graph.edges) == 0

            # None of the read-side client calls may have been awaited.
            mock_client.mget.assert_not_awaited()
            mock_client.search.assert_not_awaited()
            mock_client.create_pit.assert_not_awaited()
            mock_client.count.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_read_missing_index_demotes_readiness(
        self, global_config, embed_func, mock_client
    ):
        """A missing-index error on read clears readiness; the next read skips mget."""
        mock_client.transport = AsyncMock()
        mock_client.transport.perform_request = AsyncMock(
            side_effect=Exception("PPL not available")
        )
        mock_client.mget = AsyncMock(side_effect=_missing_index_error())
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()

            first_read = await storage.get_node("A")
            assert first_read is None
            second_read = await storage.get_node("A")
            assert second_read is None

            assert storage._indices_ready is False
            # Only the first read reached the client.
            assert mock_client.mget.await_count == 1

    @pytest.mark.asyncio
    async def test_delete_node(self, global_config, embed_func, mock_client):
        """``delete_node`` awaits ``delete_by_query`` and ``delete`` once each."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete_node("Alice")
            mock_client.delete_by_query.assert_awaited_once()
            mock_client.delete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_remove_nodes(self, global_config, embed_func, mock_client):
        """``remove_nodes`` awaits ``delete_by_query`` and the bulk helper once each."""
        client_patch = patch.object(
            ClientManager, "get_client", return_value=mock_client
        )
        bulk_patch = patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        )
        with client_patch, bulk_patch as mock_bulk:
            mock_bulk.return_value = (2, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.remove_nodes(["A", "B"])
            mock_client.delete_by_query.assert_awaited_once()
            mock_bulk.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_remove_edges(self, global_config, embed_func, mock_client):
        """``remove_edges`` sends one bulk call carrying four actions for two edges."""
        # remove_edges uses bulk delete with deterministic IDs rather than
        # delete_by_query, so ``bulk`` is mocked as an AsyncMock.
        mock_client.bulk = AsyncMock(return_value={"errors": False, "items": []})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.remove_edges([("A", "B"), ("C", "D")])

            # 2 edges x 2 candidate directions = 4 delete actions in one call.
            mock_client.bulk.assert_awaited_once()
            bulk_body = mock_client.bulk.call_args.kwargs["body"]
            assert len(bulk_body) == 4

    @pytest.mark.asyncio
    async def test_get_all_labels(self, global_config, embed_func, mock_client):
        """``get_all_labels`` returns the ``_id`` of each search hit in order."""
        label_hits = [
            {"_id": "Alice", "sort": ["Alice"]},
            {"_id": "Bob", "sort": ["Bob"]},
        ]
        mock_client.search = AsyncMock(
            return_value={"hits": {"hits": label_hits, "total": {"value": 2}}}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            labels = await storage.get_all_labels()
            assert labels == ["Alice", "Bob"]

    @pytest.mark.asyncio
    async def test_get_popular_labels(self, global_config, embed_func, mock_client):
        """The label with the highest combined src+tgt count comes first."""
        aggregations = {
            "src": {
                "buckets": [
                    {"key": "A", "doc_count": 5},
                    {"key": "B", "doc_count": 2},
                ]
            },
            "tgt": {"buckets": [{"key": "A", "doc_count": 3}]},
            "status_counts": {"buckets": []},
        }
        mock_client.search = AsyncMock(
            return_value={
                "hits": {"hits": [], "total": {"value": 0}},
                "aggregations": aggregations,
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            labels = await storage.get_popular_labels(limit=10)
            # A has degree 8 (5 + 3), B has degree 2.
            assert labels[0] == "A"

    @pytest.mark.asyncio
    async def test_get_knowledge_graph_all_backfills_isolated_nodes_when_truncated(
        self, global_config, embed_func, mock_client
    ):
        """A truncated ``"*"`` graph is filled up to ``max_nodes`` from the node scan."""
        degree_response = {
            "hits": {"hits": [], "total": {"value": 1}},
            "aggregations": {
                "src": {"buckets": [{"key": "A", "doc_count": 1}]},
                "tgt": {"buckets": [{"key": "B", "doc_count": 1}]},
                "status_counts": {"buckets": []},
            },
        }
        node_scan_response = {
            "hits": {
                "hits": [
                    {"_id": name, "sort": [rank]}
                    for rank, name in enumerate("ABCDE", start=1)
                ],
                "total": {"value": 5},
            }
        }
        edge_response = {
            "hits": {
                "hits": [
                    {
                        "_id": "edge-ab",
                        "_source": {
                            "source_node_id": "A",
                            "target_node_id": "B",
                            "relationship": "knows",
                        },
                    }
                ],
                "total": {"value": 1},
            }
        }
        mock_client.count = AsyncMock(return_value={"count": 5})
        mock_client.search = AsyncMock(
            side_effect=[degree_response, node_scan_response, edge_response]
        )
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {
                        "_id": name,
                        "found": True,
                        "_source": {"entity_type": "person"},
                    }
                    for name in "ABCD"
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            result = await storage.get_knowledge_graph("*", max_nodes=4)

            assert result.is_truncated is True
            node_ids = [node.id for node in result.nodes]
            assert node_ids == ["A", "B", "C", "D"]
            assert len(result.edges) == 1
            only_edge = result.edges[0]
            assert only_edge.source == "A"
            assert only_edge.target == "B"
            assert mock_client.create_pit.await_count == 2

    @pytest.mark.asyncio
    async def test_get_knowledge_graph_all_paginates_edges_between_selected_nodes(
        self, global_config, embed_func, mock_client
    ):
        """Edge retrieval continues past a full 10000-hit page to a second page."""
        first_edge_page = [
            {
                "_id": f"edge-{i}",
                "_source": {
                    "source_node_id": "A",
                    "target_node_id": "B",
                    "relationship": "knows",
                },
                "sort": [i],
            }
            for i in range(10000)
        ]
        node_response = {
            "hits": {
                "hits": [
                    {"_id": "A"},
                    {"_id": "B"},
                ],
                "total": {"value": 2},
            }
        }
        first_page_response = {
            "hits": {"hits": first_edge_page, "total": {"value": 10001}}
        }
        last_page_response = {
            "hits": {
                "hits": [
                    {
                        "_id": "edge-last",
                        "_source": {
                            "source_node_id": "B",
                            "target_node_id": "A",
                            "relationship": "knows",
                        },
                        "sort": [10000],
                    }
                ],
                "total": {"value": 10001},
            }
        }
        mock_client.count = AsyncMock(return_value={"count": 2})
        mock_client.search = AsyncMock(
            side_effect=[node_response, first_page_response, last_page_response]
        )
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {
                        "_id": name,
                        "found": True,
                        "_source": {"entity_type": "person"},
                    }
                    for name in "AB"
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            result = await storage.get_knowledge_graph("*", max_nodes=2)

            assert len(result.nodes) == 2
            assert len(result.edges) == 2
            endpoints = {(edge.source, edge.target) for edge in result.edges}
            assert endpoints == {
                ("A", "B"),
                ("B", "A"),
            }
            assert mock_client.search.await_count == 3

    @pytest.mark.asyncio
    async def test_search_labels_empty_query(
        self, global_config, embed_func, mock_client
    ):
        """``search_labels("")`` returns an empty list."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            matches = await storage.search_labels("")
            assert matches == []

    @pytest.mark.asyncio
    async def test_drop(self, global_config, embed_func, mock_client):
        """``drop`` reports success after awaiting ``indices.delete`` twice."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            outcome = await storage.drop()
            assert outcome["status"] == "success"
            assert mock_client.indices.delete.await_count == 2

    @pytest.mark.asyncio
    async def test_drop_partial_error_marks_indices_not_ready_and_next_upsert_recreates_indices(
        self, global_config, embed_func, mock_client
    ):
        """A failure on the second delete reports an error and forces re-creation."""
        mock_client.exists = AsyncMock(return_value=False)
        # First delete succeeds, second raises.
        mock_client.indices.delete = AsyncMock(
            side_effect=[None, OpenSearchException("edges drop failed")]
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            create_patch = patch.object(
                storage, "_create_indices_if_not_exist", new_callable=AsyncMock
            )
            with create_patch as mock_create:
                outcome = await storage.drop()
                assert outcome["status"] == "error"
                assert "edges drop failed" in outcome["message"]
                assert storage._indices_ready is False

                await storage.upsert_edge("A", "B", {"weight": "1.0"})
                mock_create.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_drop_treats_missing_graph_indices_as_success(
        self, global_config, embed_func, mock_client
    ):
        """Missing-index errors from both deletes still yield a success status."""
        mock_client.indices.delete = AsyncMock(
            side_effect=[_missing_index_error(), _missing_index_error()]
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            outcome = await storage.drop()
            assert outcome["status"] == "success"
            assert storage._indices_ready is False

    @pytest.mark.asyncio
    async def test_construct_graph_node(self, global_config, embed_func):
        """``_construct_graph_node`` keeps data fields and drops ``_id``/``entity_id``."""
        storage = self._make(global_config, embed_func)
        source_doc = {
            "entity_type": "person",
            "description": "A researcher",
            "_id": "Alice",
            "entity_id": "Alice",
        }
        node = storage._construct_graph_node("Alice", source_doc)

        assert node.id == "Alice"
        assert "entity_type" in node.properties
        assert "_id" not in node.properties
        assert "entity_id" not in node.properties

    @pytest.mark.asyncio
    async def test_construct_graph_edge(self, global_config, embed_func):
        """``_construct_graph_edge`` maps endpoints and relationship onto the edge."""
        storage = self._make(global_config, embed_func)
        source_doc = {
            "source_node_id": "A",
            "target_node_id": "B",
            "relationship": "knows",
            "weight": 1.0,
        }
        edge = storage._construct_graph_edge("e1", source_doc)

        assert edge.source == "A"
        assert edge.target == "B"
        assert edge.type == "knows"
        assert "source_node_id" not in edge.properties

    @pytest.mark.asyncio
    async def test_bfs_subgraph_start_not_found(
        self, global_config, embed_func, mock_client
    ):
        """A start node that mget reports as not found yields an empty graph."""
        missing_doc = {"_id": "NonExistent", "found": False}
        mock_client.mget = AsyncMock(return_value={"docs": [missing_doc]})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            result = await storage.get_knowledge_graph("NonExistent", max_depth=2)
            assert len(result.nodes) == 0
            assert len(result.edges) == 0
- class TestGraphPPLDetection:
- """Tests for PPL graphlookup detection and server-side BFS."""
def _make(self, global_config, embed_func, workspace="test"):
    """Build a graph storage for the ``chunk_entity_relation`` namespace."""
    storage_kwargs = {
        "namespace": "chunk_entity_relation",
        "global_config": global_config,
        "embedding_func": embed_func,
        "workspace": workspace,
    }
    return OpenSearchGraphStorage(**storage_kwargs)
@pytest.mark.asyncio
async def test_ppl_detected_when_available(
    self, global_config, embed_func, mock_client
):
    """A successful transport response marks PPL graphlookup as available."""
    probe_response = {"datarows": [], "schema": []}
    mock_client.transport = AsyncMock()
    mock_client.transport.perform_request = AsyncMock(return_value=probe_response)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        assert storage._ppl_graphlookup_available is True
- @pytest.mark.asyncio
- async def test_ppl_not_detected_when_endpoint_fails(
- self, global_config, embed_func, mock_client
- ):
- """When PPL endpoint fails, should fall back to client-side BFS."""
- mock_client.transport = AsyncMock()
- mock_client.transport.perform_request = AsyncMock(
- side_effect=Exception("PPL not supported")
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is False
- @pytest.mark.asyncio
- async def test_env_override_true(self, global_config, embed_func, mock_client):
- with patch.dict("os.environ", {"OPENSEARCH_USE_PPL_GRAPHLOOKUP": "true"}):
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is True
- # Should NOT have called transport.perform_request for detection
- mock_client.transport.perform_request.assert_not_awaited()
- @pytest.mark.asyncio
- async def test_env_override_false(self, global_config, embed_func, mock_client):
- mock_client.transport = AsyncMock()
- mock_client.transport.perform_request = AsyncMock(
- return_value={"datarows": [], "schema": []}
- )
- with patch.dict("os.environ", {"OPENSEARCH_USE_PPL_GRAPHLOOKUP": "false"}):
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is False
- @pytest.mark.asyncio
- async def test_ppl_bfs_calls_ppl_endpoint(
- self, global_config, embed_func, mock_client
- ):
- """When PPL is available, get_knowledge_graph should use PPL endpoint."""
- mock_client.transport = AsyncMock()
- # PPL response: connected_edges contains dicts with source_node_id/target_node_id
- ppl_response = {
- "schema": [
- {"name": "entity_id", "type": "string"},
- {"name": "connected_edges", "type": "struct"},
- ],
- "datarows": [
- [
- "A",
- [ # connected_edges array
- {
- "source_node_id": "A",
- "target_node_id": "B",
- "weight": 1.0,
- "_depth": 0,
- },
- {
- "source_node_id": "B",
- "target_node_id": "C",
- "weight": 0.5,
- "_depth": 1,
- },
- ],
- ]
- ],
- }
- mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
- # get_node for start node verification
- mock_client.get = AsyncMock(
- return_value={
- "_id": "A",
- "_source": {"entity_type": "person", "description": "Node A"},
- }
- )
- # mget for batch node fetch (only B and C, A is already added)
- mock_client.mget = AsyncMock(
- return_value={
- "docs": [
- {"_id": "B", "found": True, "_source": {"entity_type": "person"}},
- {"_id": "C", "found": True, "_source": {"entity_type": "person"}},
- ]
- }
- )
- # search for final edge fetch
- mock_client.search = AsyncMock(
- return_value={
- "hits": {
- "hits": [
- {
- "_id": "e1",
- "_source": {
- "source_node_id": "A",
- "target_node_id": "B",
- "relationship": "knows",
- },
- },
- {
- "_id": "e2",
- "_source": {
- "source_node_id": "B",
- "target_node_id": "C",
- "relationship": "knows",
- },
- },
- ],
- "total": {"value": 2},
- },
- "aggregations": {
- "status_counts": {"buckets": []},
- "src": {"buckets": []},
- "tgt": {"buckets": []},
- },
- }
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is True
- result = await s.get_knowledge_graph("A", max_depth=2)
- assert len(result.nodes) == 3
- assert len(result.edges) == 2
- # Verify PPL was called (2 for detection + 1 for actual query)
- assert mock_client.transport.perform_request.await_count == 3
- # Verify the PPL query uses nodes index as source
- actual_query = mock_client.transport.perform_request.call_args_list[2]
- ppl_body = actual_query.kwargs.get("body") or actual_query[1].get(
- "body", {}
- )
- if isinstance(ppl_body, dict):
- assert s._nodes_index in ppl_body.get("query", "")
- @pytest.mark.asyncio
- async def test_ppl_bfs_falls_back_on_query_failure(
- self, global_config, embed_func, mock_client
- ):
- """If PPL query fails at runtime, should fall back to client-side BFS."""
- call_count = {"n": 0}
- async def ppl_side_effect(*args, **kwargs):
- call_count["n"] += 1
- if call_count["n"] <= 2:
- # Detection calls succeed
- return {"datarows": [], "schema": []}
- # Actual query fails
- raise Exception("PPL query timeout")
- mock_client.transport = AsyncMock()
- mock_client.transport.perform_request = AsyncMock(side_effect=ppl_side_effect)
- mock_client.mget = AsyncMock(
- return_value={"docs": [{"_id": "A", "found": False}]}
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is True
- # Should fall back to _bfs_subgraph, which returns empty (node not found)
- result = await s.get_knowledge_graph("A", max_depth=2)
- assert len(result.nodes) == 0
- @pytest.mark.asyncio
- async def test_escape_ppl(self, global_config, embed_func):
- s = self._make(global_config, embed_func)
- assert s._escape_ppl("it's") == "it\\'s"
- assert s._escape_ppl("normal") == "normal"
- assert s._escape_ppl("back\\slash") == "back\\\\slash"
- assert s._escape_ppl("both\\and'quote") == "both\\\\and\\'quote"
- @pytest.mark.asyncio
- async def test_ppl_bfs_depth_zero_returns_start_only(
- self, global_config, embed_func, mock_client
- ):
- """max_depth=0 should return only the start node without PPL query."""
- mock_client.transport = AsyncMock()
- mock_client.transport.perform_request = AsyncMock(
- return_value={"datarows": [], "schema": []}
- )
- mock_client.get = AsyncMock(
- return_value={"_id": "A", "_source": {"entity_type": "person"}}
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- assert s._ppl_graphlookup_available is True
- result = await s.get_knowledge_graph("A", max_depth=0)
- assert len(result.nodes) == 1
- assert result.nodes[0].id == "A"
- assert len(result.edges) == 0
- # PPL query should NOT have been called for the actual traversal (only 2 detection calls)
- assert mock_client.transport.perform_request.await_count == 2
- @pytest.mark.asyncio
- async def test_ppl_bfs_empty_connected_edges(
- self, global_config, embed_func, mock_client
- ):
- """PPL returns no connected edges — should return only start node."""
- mock_client.transport = AsyncMock()
- ppl_response = {
- "schema": [
- {"name": "entity_id", "type": "string"},
- {"name": "connected_edges", "type": "struct"},
- ],
- "datarows": [["A", []]],
- }
- mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
- mock_client.get = AsyncMock(
- return_value={"_id": "A", "_source": {"entity_type": "person"}}
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- result = await s.get_knowledge_graph("A", max_depth=2)
- assert len(result.nodes) == 1
- assert result.nodes[0].id == "A"
- @pytest.mark.asyncio
- async def test_ppl_bfs_truncates_nodes_by_depth_then_weight(
- self, global_config, embed_func, mock_client
- ):
- mock_client.transport = AsyncMock()
- ppl_response = {
- "schema": [
- {"name": "entity_id", "type": "string"},
- {"name": "connected_edges", "type": "struct"},
- ],
- "datarows": [
- [
- "A",
- [
- {
- "source_node_id": "A",
- "target_node_id": "C",
- "weight": 1.0,
- "_depth": 1,
- },
- {
- "source_node_id": "B",
- "target_node_id": "D",
- "weight": 10.0,
- "_depth": 1,
- },
- {
- "source_node_id": "A",
- "target_node_id": "B",
- "weight": 1.0,
- "_depth": 0,
- },
- ],
- ]
- ],
- }
- mock_client.transport.perform_request = AsyncMock(return_value=ppl_response)
- mock_client.mget = AsyncMock(
- side_effect=[
- {
- "docs": [
- {
- "_id": "A",
- "found": True,
- "_source": {"entity_type": "person"},
- }
- ]
- },
- {
- "docs": [
- {
- "_id": "B",
- "found": True,
- "_source": {"entity_type": "person"},
- },
- {
- "_id": "D",
- "found": True,
- "_source": {"entity_type": "person"},
- },
- ]
- },
- ]
- )
- mock_client.search = AsyncMock(
- return_value={
- "hits": {
- "hits": [
- {
- "_id": "e1",
- "_source": {
- "source_node_id": "A",
- "target_node_id": "B",
- "relationship": "knows",
- },
- "sort": [1],
- },
- {
- "_id": "e2",
- "_source": {
- "source_node_id": "B",
- "target_node_id": "D",
- "relationship": "knows",
- },
- "sort": [2],
- },
- ],
- "total": {"value": 2},
- }
- }
- )
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- result = await s.get_knowledge_graph("A", max_depth=2, max_nodes=3)
- assert [node.id for node in result.nodes] == ["A", "B", "D"]
- assert result.is_truncated is True
- assert {(edge.source, edge.target) for edge in result.edges} == {
- ("A", "B"),
- ("B", "D"),
- }
- @pytest.mark.asyncio
- async def test_upsert_node_adds_entity_id(
- self, global_config, embed_func, mock_client
- ):
- """upsert_node should always include entity_id field for PPL compatibility."""
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- await s.upsert_node("TestNode", {"description": "test"})
- body = mock_client.index.call_args.kwargs["body"]
- assert body["entity_id"] == "TestNode"
- assert body["description"] == "test"
- @pytest.mark.asyncio
- async def test_node_degree_uses_count_api(
- self, global_config, embed_func, mock_client
- ):
- """node_degree should use the count API, not search."""
- mock_client.count = AsyncMock(return_value={"count": 7})
- with patch.object(ClientManager, "get_client", return_value=mock_client):
- s = self._make(global_config, embed_func)
- await s.initialize()
- degree = await s.node_degree("X")
- assert degree == 7
- # Verify count was called on the edges index
- mock_client.count.assert_awaited()
- call_kwargs = mock_client.count.call_args
- assert s._edges_index in str(call_kwargs)
- # ---------------------------------------------------------------------------
- # Vector Storage
- # ---------------------------------------------------------------------------
class TestVectorStorage:
    """Tests for OpenSearchVectorDBStorage k-NN index, embeddings, cosine conversion, and entity deletion."""

    def _make(self, global_config, embed_func, workspace="test"):
        """Build the vector storage under test in the ``entities`` namespace."""
        storage_kwargs = {
            "namespace": "entities",
            "global_config": global_config,
            "embedding_func": embed_func,
            "workspace": workspace,
            "meta_fields": {"content", "entity_name", "src_id", "tgt_id"},
        }
        return OpenSearchVectorDBStorage(**storage_kwargs)

    @staticmethod
    def _single_hit_response(hit):
        """Wrap one hit in the search-response envelope fed to the mocked client."""
        return {
            "hits": {
                "hits": [hit],
                "total": {"value": 1},
            },
            "aggregations": {
                "status_counts": {"buckets": []},
                "src": {"buckets": []},
                "tgt": {"buckets": []},
            },
        }

    @pytest.mark.asyncio
    async def test_index_name(self, global_config, embed_func):
        """The index name for workspace "test" and namespace "entities" is "test_entities"."""
        storage = self._make(global_config, embed_func)
        assert storage._index_name == "test_entities"

    @pytest.mark.asyncio
    async def test_cosine_threshold_required(self, embed_func):
        """Constructing without a cosine threshold in the kwargs raises ValueError."""
        config_without_threshold = {
            "embedding_batch_num": 10,
            "vector_db_storage_cls_kwargs": {},
        }
        with pytest.raises(ValueError, match="cosine_better_than_threshold"):
            OpenSearchVectorDBStorage(
                namespace="v",
                global_config=config_without_threshold,
                embedding_func=embed_func,
            )

    @pytest.mark.asyncio
    async def test_initialize_creates_knn_index(
        self, global_config, embed_func, mock_client
    ):
        """initialize() creates one index with knn enabled, dimension 128 and the lucene engine."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            mock_client.indices.create.assert_awaited_once()
            create_body = mock_client.indices.create.call_args.kwargs["body"]
            assert create_body["settings"]["index"]["knn"] is True
            vector_mapping = create_body["mappings"]["properties"]["vector"]
            assert vector_mapping["dimension"] == 128
            assert vector_mapping["method"]["engine"] == "lucene"

    @pytest.mark.asyncio
    async def test_upsert_generates_embeddings(
        self, global_config, embed_func, mock_client
    ):
        """Embeddings are deferred until flush; upsert only buffers payloads."""
        embed_func = CountingEmbeddingFunc()
        client_patch = patch.object(
            ClientManager, "get_client", return_value=mock_client
        )
        bulk_patch = patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        )
        with client_patch, bulk_patch as mock_bulk:
            mock_bulk.return_value = (2, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            payload = {
                "v1": {"content": "hello"},
                "v2": {"content": "world"},
            }
            await storage.upsert(payload)
            # Upsert buffers; no bulk write yet.
            mock_bulk.assert_not_awaited()
            assert embed_func.call_count == 0
            assert set(storage._pending_vector_docs.keys()) == {"v1", "v2"}
            assert storage._pending_vector_docs["v1"].vector is None
            # Flush embeds and triggers a single bulk call with both docs.
            await storage.index_done_callback()
            assert embed_func.call_count == 1
            mock_bulk.assert_awaited_once()
            actions = mock_bulk.call_args[0][1]
            assert len(actions) == 2
            assert all(action["_op_type"] == "index" for action in actions)
            assert all("vector" in action["_source"] for action in actions)

    @pytest.mark.asyncio
    async def test_query_cosine_score_conversion(
        self, global_config, embed_func, mock_client
    ):
        """Test that scores are used directly and threshold filtering works."""
        strong_hit = {
            "_id": "v1",
            "_score": 0.85,
            "_source": {"content": "match", "entity_name": "E1"},
        }
        mock_client.search = AsyncMock(
            return_value=self._single_hit_response(strong_hit)
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            matches = await storage.query("test", top_k=5)
            assert len(matches) == 1
            assert matches[0]["distance"] == 0.85

    @pytest.mark.asyncio
    async def test_query_filters_below_threshold(
        self, global_config, embed_func, mock_client
    ):
        """Low scores should be filtered out."""
        # score 0.15 < threshold 0.2
        weak_hit = {
            "_id": "v1",
            "_score": 0.15,
            "_source": {"content": "weak match"},
        }
        mock_client.search = AsyncMock(
            return_value=self._single_hit_response(weak_hit)
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            matches = await storage.query("test", top_k=5)
            assert len(matches) == 0

    @pytest.mark.asyncio
    async def test_query_with_provided_embedding(
        self, global_config, embed_func, mock_client
    ):
        """query() accepts a caller-supplied embedding and returns the hit's score as distance."""
        exact_hit = {"_id": "v1", "_score": 1.0, "_source": {"content": "exact"}}
        mock_client.search = AsyncMock(
            return_value=self._single_hit_response(exact_hit)
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            query_vector = np.random.rand(128).astype(np.float32)
            matches = await storage.query(
                "test", top_k=5, query_embedding=query_vector
            )
            assert len(matches) == 1
            assert matches[0]["distance"] == 1.0

    @pytest.mark.asyncio
    async def test_get_by_id(self, global_config, embed_func, mock_client):
        """get_by_id goes through mget with the vector excluded and strips it from the result."""
        stored_doc = {
            "_id": "v1",
            "found": True,
            "_source": {"content": "hello", "vector": [0.1] * 128},
        }
        mock_client.mget = AsyncMock(return_value={"docs": [stored_doc]})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            doc = await storage.get_by_id("v1")
            assert doc["id"] == "v1"
            assert doc["content"] == "hello"
            # vector field is stripped on the mget path to match NanoVectorDB
            assert "vector" not in doc
            mock_client.mget.assert_awaited_once_with(
                index=storage._index_name,
                body={"ids": ["v1"]},
                _source_excludes=["vector"],
            )

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(self, global_config, embed_func, mock_client):
        """A not-found mget entry yields None and no single-document get is issued."""
        mock_client.mget = AsyncMock(
            return_value={"docs": [{"_id": "missing", "found": False}]}
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            assert await storage.get_by_id("missing") is None
            mock_client.get.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_get_by_ids(self, global_config, embed_func, mock_client):
        """get_by_ids returns documents in request order with their ids populated."""
        found_docs = [
            {"_id": "v1", "found": True, "_source": {"content": "a"}},
            {"_id": "v2", "found": True, "_source": {"content": "b"}},
        ]
        mock_client.mget = AsyncMock(return_value={"docs": found_docs})
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            docs = await storage.get_by_ids(["v1", "v2"])
            assert docs[0]["id"] == "v1"
            assert docs[1]["id"] == "v2"

    @pytest.mark.asyncio
    async def test_get_vectors_by_ids(self, global_config, embed_func, mock_client):
        """Only ids that mget reports as found appear in the returned vector mapping."""
        stored_vector = [0.1] * 128
        mock_client.mget = AsyncMock(
            return_value={
                "docs": [
                    {"_id": "v1", "found": True, "_source": {"vector": stored_vector}},
                    {"_id": "v2", "found": False},
                ]
            }
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            vectors = await storage.get_vectors_by_ids(["v1", "v2"])
            assert "v1" in vectors
            assert "v2" not in vectors
            assert vectors["v1"] == stored_vector

    @pytest.mark.asyncio
    async def test_delete(self, global_config, embed_func, mock_client):
        """delete() buffers ids; the actual bulk delete fires on flush."""
        client_patch = patch.object(
            ClientManager, "get_client", return_value=mock_client
        )
        bulk_patch = patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
        )
        with client_patch, bulk_patch as mock_bulk:
            mock_bulk.return_value = (2, [])
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete(["v1", "v2"])
            mock_bulk.assert_not_awaited()
            assert storage._pending_vector_deletes == {"v1", "v2"}
            await storage.index_done_callback()
            mock_bulk.assert_awaited_once()
            actions = mock_bulk.call_args[0][1]
            assert len(actions) == 2
            assert all(action["_op_type"] == "delete" for action in actions)

    @pytest.mark.asyncio
    async def test_delete_entity(self, global_config, embed_func, mock_client):
        """delete_entity buffers a tombstone for the computed mdhash id."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete_entity("Alice")
            # No direct client.delete call -- delete is buffered for batched flush.
            mock_client.delete.assert_not_awaited()
            assert len(storage._pending_vector_deletes) == 1

    @pytest.mark.asyncio
    async def test_delete_entity_relation(self, global_config, embed_func, mock_client):
        """delete_entity_relation issues exactly one delete_by_query."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.delete_entity_relation("Alice")
            mock_client.delete_by_query.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_drop_recreates_index(self, global_config, embed_func, mock_client):
        """drop() deletes the index once and creates it again afterwards."""
        # After drop, _create_knn_index_if_not_exists is called again.
        # First call (init): exists=False -> create. Second call (after drop): exists=False -> create again.
        mock_client.indices.exists = AsyncMock(return_value=False)
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            outcome = await storage.drop()
            assert outcome["status"] == "success"
            mock_client.indices.delete.assert_awaited_once()
            # create called twice: once during init, once during drop recreate
            assert mock_client.indices.create.await_count == 2

    @pytest.mark.asyncio
    async def test_drop_delete_error_marks_index_not_ready(
        self, global_config, embed_func, mock_client
    ):
        """A failing index delete makes drop() report an error and clears _index_ready."""
        mock_client.indices.delete = AsyncMock(
            side_effect=OpenSearchException("delete failed")
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            outcome = await storage.drop()
            assert outcome["status"] == "error"
            assert storage._index_ready is False

    @pytest.mark.asyncio
    async def test_drop_recreate_error_marks_index_not_ready(
        self, global_config, embed_func, mock_client
    ):
        """A failing index re-creation makes drop() report an error and clears _index_ready."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            failing_recreate = AsyncMock(
                side_effect=OpenSearchException("recreate failed")
            )
            with patch.object(
                storage, "_create_knn_index_if_not_exists", new=failing_recreate
            ):
                outcome = await storage.drop()
            assert outcome["status"] == "error"
            assert storage._index_ready is False

    @pytest.mark.asyncio
    async def test_drop_recreates_index_when_missing(
        self, global_config, embed_func, mock_client
    ):
        """A NotFoundError on delete is tolerated: drop() succeeds and still re-creates the index."""
        mock_client.indices.exists = AsyncMock(return_value=False)
        mock_client.indices.delete = AsyncMock(
            side_effect=NotFoundError(404, "not found")
        )
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            outcome = await storage.drop()
            assert outcome["status"] == "success"
            assert mock_client.indices.create.await_count == 2

    @pytest.mark.asyncio
    async def test_reads_short_circuit_when_index_not_ready(
        self, global_config, embed_func, mock_client
    ):
        """With _index_ready False, reads return empty results without touching the client."""
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            storage._index_ready = False
            assert await storage.query("test", top_k=5) == []
            assert await storage.get_by_id("v1") is None
            assert await storage.get_vectors_by_ids(["v1"]) == {}
            mock_client.search.assert_not_awaited()
            mock_client.mget.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_read_missing_index_demotes_readiness(
        self, global_config, embed_func, mock_client
    ):
        """A missing-index error on search clears _index_ready so the next query skips the client."""
        mock_client.search = AsyncMock(side_effect=_missing_index_error())
        with patch.object(ClientManager, "get_client", return_value=mock_client):
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            # Both queries come back empty, but only the first reaches search.
            assert await storage.query("test", top_k=5) == []
            assert await storage.query("test", top_k=5) == []
            assert storage._index_ready is False
            assert mock_client.search.await_count == 1
- # ---------------------------------------------------------------------------
- # Vector storage write batching (issue #2785)
- # ---------------------------------------------------------------------------
- class TestVectorStorageBatching:
- """Tests for the buffered upsert/delete + flush behaviour added for #2785."""
def _make(self, global_config, embed_func, workspace="test"):
    """Build the vector storage under test in the ``entities`` namespace."""
    storage_kwargs = {
        "namespace": "entities",
        "global_config": global_config,
        "embedding_func": embed_func,
        "workspace": workspace,
        "meta_fields": {"content", "entity_name", "src_id", "tgt_id"},
    }
    return OpenSearchVectorDBStorage(**storage_kwargs)
@pytest.mark.asyncio
async def test_repeated_upserts_flush_in_single_bulk_call(
    self, global_config, embed_func, mock_client
):
    """Many small upsert() calls collapse to one async_bulk on flush."""
    embed_func = CountingEmbeddingFunc()
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (5, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        for idx in range(5):
            await storage.upsert({f"v{idx}": {"content": f"doc {idx}"}})
        # Nothing is embedded or written until the flush.
        mock_bulk.assert_not_awaited()
        assert embed_func.call_count == 0
        await storage.index_done_callback()
        assert embed_func.call_count == 1
        assert embed_func.batches == [[f"doc {idx}" for idx in range(5)]]
        mock_bulk.assert_awaited_once()
        actions = mock_bulk.call_args[0][1]
        assert len(actions) == 5
        flushed_ids = {action["_id"] for action in actions}
        assert flushed_ids == {f"v{idx}" for idx in range(5)}
@pytest.mark.asyncio
async def test_deferred_embeddings_respect_batch_size(
    self, global_config, embed_func, mock_client
):
    """Flush batches deferred embeddings by embedding_batch_num."""
    embed_func = CountingEmbeddingFunc()
    config = {**global_config, "embedding_batch_num": 2}
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (5, [])
        storage = self._make(config, embed_func)
        await storage.initialize()
        for idx in range(5):
            await storage.upsert({f"v{idx}": {"content": f"doc {idx}"}})
        await storage.index_done_callback()
        # Five texts with a batch size of two: 2 + 2 + 1.
        expected_batches = [
            ["doc 0", "doc 1"],
            ["doc 2", "doc 3"],
            ["doc 4"],
        ]
        assert embed_func.batches == expected_batches
        mock_bulk.assert_awaited_once()
@pytest.mark.asyncio
async def test_upsert_overwrites_pending_doc_for_same_id(
    self, global_config, embed_func, mock_client
):
    """Upserting the same id twice keeps only the latest payload."""
    embed_func = CountingEmbeddingFunc()
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        for content in ("first", "second"):
            await storage.upsert({"v1": {"content": content}})
        await storage.index_done_callback()
        assert embed_func.call_count == 1
        assert embed_func.texts == ["second"]
        actions = mock_bulk.call_args[0][1]
        assert len(actions) == 1
        assert actions[0]["_source"]["content"] == "second"
@pytest.mark.asyncio
async def test_delete_cancels_pending_upsert(
    self, global_config, embed_func, mock_client
):
    """A delete after a buffered upsert removes the upsert from the buffer."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "doomed"}})
        await storage.delete(["v1"])
        assert "v1" not in storage._pending_vector_docs
        assert "v1" in storage._pending_vector_deletes
        await storage.index_done_callback()
        actions = mock_bulk.call_args[0][1]
        assert len(actions) == 1
        # Only the tombstone reaches the bulk call.
        assert actions[0]["_op_type"] == "delete"
@pytest.mark.asyncio
async def test_upsert_cancels_pending_delete(
    self, global_config, embed_func, mock_client
):
    """An upsert after a buffered delete removes the tombstone."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.delete(["v1"])
        await storage.upsert({"v1": {"content": "resurrected"}})
        assert "v1" not in storage._pending_vector_deletes
        assert "v1" in storage._pending_vector_docs
        await storage.index_done_callback()
        actions = mock_bulk.call_args[0][1]
        assert len(actions) == 1
        # Only the index action reaches the bulk call.
        assert actions[0]["_op_type"] == "index"
@pytest.mark.asyncio
async def test_get_by_id_reads_pending_buffer(
    self, global_config, embed_func, mock_client
):
    """Buffered upserts are visible to get_by_id without hitting OpenSearch."""
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "buffered"}})
        buffered_doc = await storage.get_by_id("v1")
        assert buffered_doc is not None
        assert buffered_doc["id"] == "v1"
        assert buffered_doc["content"] == "buffered"
        # Vector field is hidden from get_by_id results, mirroring the
        # _source excludes used by query().
        assert "vector" not in buffered_doc
        mock_client.mget.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_by_id_returns_none_for_pending_delete(
    self, global_config, embed_func, mock_client
):
    """A pending tombstone shadows any persisted doc."""
    mock_client.mget = AsyncMock()  # would be wrong to invoke
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.delete(["v1"])
        lookup = await storage.get_by_id("v1")
        assert lookup is None
        mock_client.mget.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_by_ids_merges_buffer_and_index(
    self, global_config, embed_func, mock_client
):
    """get_by_ids returns buffered docs and falls back to mget for the rest."""
    persisted_doc = {"_id": "v2", "found": True, "_source": {"content": "from_index"}}
    mock_client.mget = AsyncMock(return_value={"docs": [persisted_doc]})
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "buffered"}})
        docs = await storage.get_by_ids(["v1", "v2"])
        assert docs[0]["content"] == "buffered"
        assert docs[1]["content"] == "from_index"
        # Only the unbuffered id is requested from OpenSearch,
        # and vector is excluded server-side.
        mock_client.mget.assert_awaited_once_with(
            index=storage._index_name,
            body={"ids": ["v2"]},
            _source_excludes=["vector"],
        )
@pytest.mark.asyncio
async def test_get_vectors_by_ids_uses_buffer(
    self, global_config, embed_func, mock_client
):
    """get_vectors_by_ids returns buffered embeddings without an mget roundtrip."""
    embed_func = CountingEmbeddingFunc()
    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "x"}})
        # The upsert alone must not trigger an embedding call.
        assert embed_func.call_count == 0
        vectors = await storage.get_vectors_by_ids(["v1"])
        assert "v1" in vectors
        assert len(vectors["v1"]) == 128
        # Reading the vector embeds once and stores the result on the pending doc.
        assert embed_func.call_count == 1
        assert storage._pending_vector_docs["v1"].vector == vectors["v1"]
        mock_client.mget.assert_not_awaited()
@pytest.mark.asyncio
async def test_lazy_get_vectors_cache_is_reused_by_flush(
    self, global_config, embed_func, mock_client
):
    """A lazy pending-vector read should not force a second embedding during flush."""
    embed_func = CountingEmbeddingFunc()
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "x"}})
        vectors = await storage.get_vectors_by_ids(["v1"])
        await storage.index_done_callback()
        assert embed_func.call_count == 1
        actions = mock_bulk.call_args[0][1]
        # The flushed vector is the one already handed out by the read.
        assert actions[0]["_source"]["vector"] == vectors["v1"]
@pytest.mark.asyncio
async def test_finalize_flushes_pending_ops(
    self, global_config, embed_func, mock_client
):
    """finalize() flushes buffered writes before releasing the client."""
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    )
    with client_patch, bulk_patch as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "to flush"}})
        await storage.finalize()
        mock_bulk.assert_awaited_once()
        assert storage.client is None
@pytest.mark.asyncio
async def test_vector_finalize_raises_when_retryable_buffer_remains(
    self, global_config, embed_func, mock_client
):
    """finalize() raises RuntimeError when a retryable bulk failure leaves rows buffered.

    Without the error, the upstream finalize_storages() call would log the
    storage as successfully finalized while writes are silently lost.
    The client is released regardless, to avoid a connection leak.
    """
    client_patch = patch.object(ClientManager, "get_client", return_value=mock_client)
    release_patch = patch.object(
        ClientManager, "release_client", new_callable=AsyncMock
    )
    bulk_patch = patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk",
        new_callable=AsyncMock,
    )
    with client_patch, release_patch as mock_release, bulk_patch as mock_bulk:
        # Zero successes plus one 503 item error for the buffered doc.
        item_errors = [{"index": {"_id": "v1", "status": 503, "error": "down"}}]
        mock_bulk.return_value = (0, item_errors)
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "stuck"}})
        with pytest.raises(RuntimeError, match="pending upserts"):
            await storage.finalize()
        mock_release.assert_awaited_once()
        assert storage.client is None
@pytest.mark.asyncio
async def test_vector_finalize_propagates_flush_exception(
    self, global_config, embed_func, mock_client
):
    """An ``async_bulk`` error during ``finalize()`` is wrapped, not swallowed.

    The raised ``RuntimeError`` must chain the original
    ``OpenSearchException`` as ``__cause__``; the client is released once
    and ``client`` is cleared regardless.
    """
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch.object(
        ClientManager, "release_client", new_callable=AsyncMock
    ) as mock_release, patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk",
        new_callable=AsyncMock,
    ) as mock_bulk:
        mock_bulk.side_effect = OpenSearchException("connection reset")
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "stuck"}})
        with pytest.raises(RuntimeError) as exc_info:
            await storage.finalize()
        wrapped_cause = exc_info.value.__cause__
        assert isinstance(wrapped_cause, OpenSearchException)
        mock_release.assert_awaited_once()
        assert storage.client is None
@pytest.mark.asyncio
async def test_vector_finalize_propagates_cancellation(
    self, global_config, embed_func, mock_client
):
    """``asyncio.CancelledError`` from the final flush escapes un-wrapped.

    Cancellation must reach the caller as ``CancelledError`` (not as a
    ``RuntimeError``), and the client must still be released once with
    ``client`` cleared.
    """
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch.object(
        ClientManager, "release_client", new_callable=AsyncMock
    ) as mock_release, patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk",
        new_callable=AsyncMock,
    ) as mock_bulk:
        mock_bulk.side_effect = asyncio.CancelledError()
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "stuck"}})
        with pytest.raises(asyncio.CancelledError):
            await storage.finalize()
        mock_release.assert_awaited_once()
        assert storage.client is None
@pytest.mark.asyncio
async def test_drop_discards_pending_buffers(
    self, global_config, embed_func, mock_client
):
    """``drop()`` empties both pending buffers without issuing a bulk call."""
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Populate both the upsert buffer and the delete buffer.
        await storage.upsert({"v1": {"content": "doomed"}})
        await storage.delete(["v2"])
        await storage.drop()
        assert storage._pending_vector_docs == {}
        assert storage._pending_vector_deletes == set()
        mock_bulk.assert_not_awaited()
@pytest.mark.asyncio
async def test_failed_flush_entries_retained_for_retry(
    self, global_config, embed_func, mock_client
):
    """A 503 per-document failure keeps that row buffered for the next flush.

    The retained row keeps its computed vector, so the second flush does
    not trigger another embedding call.
    """
    counting_embed = CountingEmbeddingFunc()
    # First bulk call: v2 is rejected with a retryable 503.
    first_flush = (1, [{"index": {"_id": "v2", "status": 503, "error": "down"}}])
    # Second bulk call: everything succeeds.
    second_flush = (1, [])
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.side_effect = [first_flush, second_flush]
        storage = self._make(global_config, counting_embed)
        await storage.initialize()
        rows = {"v1": {"content": "ok"}, "v2": {"content": "boom"}}
        await storage.upsert(rows)

        await storage.index_done_callback()
        # v1 left the buffer; v2 stays, already embedded.
        assert "v1" not in storage._pending_vector_docs
        assert "v2" in storage._pending_vector_docs
        assert storage._pending_vector_docs["v2"].vector is not None
        assert counting_embed.call_count == 1

        await storage.index_done_callback()
        assert "v2" not in storage._pending_vector_docs
        assert counting_embed.call_count == 1
        assert mock_bulk.await_count == 2
@pytest.mark.asyncio
async def test_embedding_failure_leaves_pending_for_retry(
    self, global_config, embed_func, mock_client
):
    """A failed embedding pass leaves the row buffered, un-embedded, for retry."""
    # The embedding double is configured with fail_times=1.
    flaky_embed = CountingEmbeddingFunc(fail_times=1)
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, flaky_embed)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "retry me"}})

        # First flush: embedding raises, so no bulk call is made.
        with pytest.raises(RuntimeError, match="embedding failed"):
            await storage.index_done_callback()
        mock_bulk.assert_not_awaited()
        assert "v1" in storage._pending_vector_docs
        assert storage._pending_vector_docs["v1"].vector is None

        # Second flush: embedding succeeds and the row is written.
        await storage.index_done_callback()
        mock_bulk.assert_awaited_once()
        assert "v1" not in storage._pending_vector_docs
        assert flaky_embed.call_count == 2
@pytest.mark.asyncio
async def test_finalize_wraps_embedding_failure(
    self, global_config, embed_func, mock_client
):
    """``finalize()`` raises a "pending upserts" error when embedding fails.

    No bulk call is made, the client is released once, and the row stays
    in the buffer without a vector.
    """
    flaky_embed = CountingEmbeddingFunc(fail_times=1)
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch.object(
        ClientManager, "release_client", new_callable=AsyncMock
    ) as mock_release, patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk",
        new_callable=AsyncMock,
    ) as mock_bulk:
        storage = self._make(global_config, flaky_embed)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "stuck"}})
        with pytest.raises(RuntimeError, match="pending upserts"):
            await storage.finalize()
        mock_bulk.assert_not_awaited()
        mock_release.assert_awaited_once()
        assert storage.client is None
        assert "v1" in storage._pending_vector_docs
        assert storage._pending_vector_docs["v1"].vector is None
@pytest.mark.asyncio
async def test_delete_entity_relation_prunes_pending_buffer(
    self, global_config, embed_func, mock_client
):
    """Buffered relations touching the entity are removed; others survive.

    ``delete_entity_relation("Alice")`` must drop the pending row whose
    ``src_id`` is "Alice", keep the unrelated row, and await
    ``delete_by_query`` exactly once.
    """
    pending_relations = {
        "rel-1": {
            "content": "Alice -> Bob",
            "src_id": "Alice",
            "tgt_id": "Bob",
        },
        "rel-2": {
            "content": "Carol -> Dave",
            "src_id": "Carol",
            "tgt_id": "Dave",
        },
    }
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert(pending_relations)
        await storage.delete_entity_relation("Alice")
        assert "rel-1" not in storage._pending_vector_docs
        assert "rel-2" in storage._pending_vector_docs
        mock_client.delete_by_query.assert_awaited_once()
def test_extract_bulk_failed_ids_classifies_by_status(self):
    """``_extract_bulk_failed_ids`` splits failures into retryable / permanent.

    Checks the empty inputs (``None`` and ``[]``) and then a mixed batch
    covering 5xx, 429, missing status, 4xx with dict- and string-shaped
    errors, a 404 on delete, and malformed entries.
    """
    from lightrag.kg.opensearch_impl import _extract_bulk_failed_ids

    # No failures -> empty containers, for both "nothing" spellings.
    for no_failures in (None, []):
        retryable, non_retryable = _extract_bulk_failed_ids(no_failures)
        assert retryable == set()
        assert non_retryable == []

    mixed_failures = [
        # Retryable: 5xx server error.
        {"index": {"_id": "r-500", "status": 500}},
        # Retryable: rate-limited.
        {"index": {"_id": "r-429", "status": 429}},
        # Retryable: missing status (network / parse failure).
        {"create": {"_id": "r-none"}},
        # Non-retryable: bad request with dict-shape error.
        {
            "index": {
                "_id": "n-400",
                "status": 400,
                "error": {
                    "type": "mapper_parsing_exception",
                    "reason": "vector must be array",
                },
            }
        },
        # Non-retryable: not found on update (doc disappeared).
        {"update": {"_id": "n-404", "status": 404, "error": "not found"}},
        # Special case: delete of a missing doc is dropped from BOTH
        # results, since the row is already gone.
        {"delete": {"_id": "drop-404", "status": 404}},
        # Malformed entries are skipped silently.
        "garbage",
        {"update": {}},
    ]
    retryable, non_retryable = _extract_bulk_failed_ids(mixed_failures)

    assert retryable == {"r-500", "r-429", "r-none"}
    by_id = {op.doc_id: op for op in non_retryable}
    assert {op.doc_id for op in non_retryable} == {"n-400", "n-404"}

    # A dict-shaped error is summarised via its "reason".
    bad_request = by_id["n-400"]
    assert bad_request.op == "index"
    assert bad_request.status == 400
    assert "vector must be array" in bad_request.error

    # A string-shaped error is passed through unchanged.
    missing_doc = by_id["n-404"]
    assert missing_doc.op == "update"
    assert missing_doc.status == 404
    assert missing_doc.error == "not found"
def test_extract_bulk_failed_ids_truncates_long_errors(self):
    """Over-long error reasons are capped and end with an ellipsis.

    A 1000-character ``reason`` must yield an error summary no longer than
    ``_BULK_ERROR_SUMMARY_MAX_LEN`` that ends with "...".
    """
    from lightrag.kg.opensearch_impl import (
        _extract_bulk_failed_ids,
        _BULK_ERROR_SUMMARY_MAX_LEN,
    )

    oversized_failure = {
        "index": {
            "_id": "n-400",
            "status": 400,
            "error": {"reason": "x" * 1000},
        }
    }
    _, non_retryable = _extract_bulk_failed_ids([oversized_failure])
    assert len(non_retryable) == 1
    summary = non_retryable[0].error
    assert len(summary) <= _BULK_ERROR_SUMMARY_MAX_LEN
    assert summary.endswith("...")
@pytest.mark.asyncio
async def test_failed_flush_drops_non_retryable_entries(
    self, global_config, embed_func, mock_client
):
    """A 400 failure is discarded from the buffer while a 503 is kept."""
    # v1: permanent mapping error (400). v2: transient outage (503).
    permanent = {"index": {"_id": "v1", "status": 400, "error": "bad mapping"}}
    transient = {"index": {"_id": "v2", "status": 503, "error": "down"}}
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (0, [permanent, transient])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert(
            {"v1": {"content": "bad"}, "v2": {"content": "transient"}}
        )
        await storage.index_done_callback()
        # Only the retryable row may remain buffered.
        assert "v1" not in storage._pending_vector_docs
        assert "v2" in storage._pending_vector_docs
@pytest.mark.asyncio
async def test_concurrent_writes_during_flush_are_serialised(
    self, global_config, embed_func, mock_client
):
    """An upsert issued mid-flush waits for the flush, then lands in the buffer.

    The bulk call is held open with an event. While it is pending, a
    concurrent upsert must neither complete nor become visible in the
    buffer; once the bulk call returns, ``v1`` is gone and ``v2`` is
    buffered.
    """
    bulk_entered = asyncio.Event()
    bulk_released = asyncio.Event()

    async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
        # Signal arrival, then park until the test lets the flush finish.
        bulk_entered.set()
        await bulk_released.wait()
        return (len(actions), [])

    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "first"}})

        flush_task = asyncio.create_task(storage.index_done_callback())
        await bulk_entered.wait()

        # The flush is now parked inside async_bulk. The competing upsert
        # runs as a task so that the test itself is never the one waiting
        # on it while the flush is still held open.
        concurrent_task = asyncio.create_task(
            storage.upsert({"v2": {"content": "concurrent"}})
        )
        # Hand control to the event loop a few times so the competing
        # task gets a chance to run as far as it can.
        for _ in range(5):
            await asyncio.sleep(0)
        assert (
            not concurrent_task.done()
        ), "concurrent upsert should be blocked by the flush lock"
        # v2 must not be visible in the buffer yet.
        assert "v2" not in storage._pending_vector_docs

        # Let the bulk call return; the flush finishes and the upsert
        # can then write v2 into the buffer.
        bulk_released.set()
        await flush_task
        await concurrent_task
        assert "v1" not in storage._pending_vector_docs
        assert "v2" in storage._pending_vector_docs
@pytest.mark.asyncio
async def test_concurrent_delete_during_flush_supersedes_retried_upsert(
    self, global_config, embed_func, mock_client
):
    """A delete queued behind a flush beats the upsert that flush retained.

    The bulk call reports ``v1`` as a 503 failure, so the flush keeps it
    for retry. The delete, which was blocked while the flush was in
    flight, then runs and must leave ``v1`` only in the delete buffer.
    """
    bulk_entered = asyncio.Event()
    bulk_released = asyncio.Event()

    async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
        bulk_entered.set()
        await bulk_released.wait()
        # Report v1's upsert as a transient failure so the flush keeps
        # it in the buffer for retry.
        transient = {"index": {"_id": "v1", "status": 503, "error": "down"}}
        return (0, [transient])

    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "first"}})

        flush_task = asyncio.create_task(storage.index_done_callback())
        await bulk_entered.wait()

        # Issue the competing delete; it must not finish mid-flush.
        delete_task = asyncio.create_task(storage.delete(["v1"]))
        for _ in range(5):
            await asyncio.sleep(0)
        assert (
            not delete_task.done()
        ), "concurrent delete should be blocked by the flush lock"

        bulk_released.set()
        await flush_task
        await delete_task

        # The delete cancelled the retained upsert and recorded a
        # tombstone instead.
        assert "v1" not in storage._pending_vector_docs
        assert "v1" in storage._pending_vector_deletes
@pytest.mark.asyncio
async def test_get_by_id_strips_vector_from_mget_path(
    self, global_config, embed_func, mock_client
):
    """``get_by_id`` via mget returns no ``vector`` key and excludes it server-side.

    The mocked mget response deliberately still contains a ``vector`` so
    the client-side removal is exercised too; the call must also pass
    ``_source_excludes=["vector"]``.
    """
    # Simulates a server that ignored the exclude and returned the vector.
    indexed_source = {"content": "from_index", "vector": [0.1] * 128}
    mget_response = {
        "docs": [{"_id": "v1", "found": True, "_source": indexed_source}]
    }
    mock_client.mget = AsyncMock(return_value=mget_response)

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        # Nothing was upserted, so the lookup goes to mget.
        doc = await storage.get_by_id("v1")

    assert doc is not None
    assert doc["id"] == "v1"
    assert doc["content"] == "from_index"
    assert "vector" not in doc
    mock_client.mget.assert_awaited_once_with(
        index=storage._index_name,
        body={"ids": ["v1"]},
        _source_excludes=["vector"],
    )
@pytest.mark.asyncio
async def test_get_by_ids_strips_vector_from_mget_path(
    self, global_config, embed_func, mock_client
):
    """``get_by_ids`` via mget drops ``vector`` from every doc and forwards
    ``_source_excludes`` to the client."""

    def found_doc(doc_id, content, fill):
        # One mget hit whose _source still carries a 128-dim vector.
        return {
            "_id": doc_id,
            "found": True,
            "_source": {"content": content, "vector": [fill] * 128},
        }

    mock_client.mget = AsyncMock(
        return_value={
            "docs": [found_doc("v1", "a", 0.1), found_doc("v2", "b", 0.2)]
        }
    )

    with patch.object(ClientManager, "get_client", return_value=mock_client):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        docs = await storage.get_by_ids(["v1", "v2"])

    assert all(d is not None for d in docs)
    assert all("vector" not in d for d in docs)
    assert docs[0]["content"] == "a"
    assert docs[1]["content"] == "b"
    mock_client.mget.assert_awaited_once_with(
        index=storage._index_name,
        body={"ids": ["v1", "v2"]},
        _source_excludes=["vector"],
    )
@pytest.mark.asyncio
async def test_non_retryable_logs_sample_ids(
    self, global_config, embed_func, mock_client, caplog
):
    """Permanent bulk failures emit a WARNING with a sample and a total.

    Six rows fail with status 400. The captured warning text must name
    the first five ids, include the status and reason text, and report
    the aggregate count of six.
    """
    import logging as _logging

    def permanent_failure(i):
        return {
            "index": {
                "_id": f"v{i}",
                "status": 400,
                "error": {
                    "type": "mapper_parsing_exception",
                    "reason": f"bad field {i}",
                },
            }
        }

    failed = [permanent_failure(i) for i in range(6)]

    # The lightrag logger has propagate=False, so caplog's root handler
    # would not see its records. Propagation is switched on for the
    # duration of this test only and restored afterwards.
    lightrag_logger = _logging.getLogger("lightrag")
    saved_propagate = lightrag_logger.propagate
    lightrag_logger.propagate = True
    try:
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as mock_bulk:
            mock_bulk.return_value = (0, failed)
            storage = self._make(global_config, embed_func)
            await storage.initialize()
            await storage.upsert(
                {f"v{i}": {"content": f"d{i}"} for i in range(6)}
            )
            with caplog.at_level("WARNING", logger="lightrag"):
                await storage.index_done_callback()
    finally:
        lightrag_logger.propagate = saved_propagate

    warnings = [rec.message for rec in caplog.records if rec.levelname == "WARNING"]
    warning_text = "\n".join(warnings)
    # The sample lists the first 5 ids together with status/reason text.
    for i in range(5):
        assert f"v{i}" in warning_text
    assert "status=400" in warning_text
    assert "bad field" in warning_text
    # All 6 permanent failures are counted in the aggregate.
    assert "6 vector ops" in warning_text
@pytest.mark.asyncio
async def test_index_done_callback_flushes_when_index_recreated(
    self, global_config, embed_func, mock_client
):
    """A flush still happens after the index was marked missing.

    With ``indices.exists`` always answering ``False``, marking the index
    missing after a buffered write must not prevent the flush: the bulk
    call is awaited once, the buffer is emptied, and ``indices.create``
    has been awaited at least twice overall.
    """
    # Scripted answers for indices.exists; once they run out the stub
    # keeps answering False as well.
    scripted_answers = [False, False]

    def fake_exists(**_kwargs):
        return scripted_answers.pop(0) if scripted_answers else False

    mock_client.indices.exists = AsyncMock(side_effect=fake_exists)

    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "ok"}})
        # Simulate the index disappearing AFTER the write was buffered.
        storage._mark_index_missing()
        await storage.index_done_callback()
        # The buffered row was flushed despite the missing-index mark.
        mock_bulk.assert_awaited_once()
        assert storage._pending_vector_docs == {}
        # The index was created again as part of the flush.
        assert mock_client.indices.create.await_count >= 2
@pytest.mark.asyncio
async def test_delete_entity_relation_serialised_with_flush(
    self, global_config, embed_func, mock_client
):
    """``delete_entity_relation`` does not reach ``delete_by_query`` mid-flush.

    While the bulk call is held open, the relation delete must neither
    call ``delete_by_query`` nor complete; after the flush finishes it
    does.
    """
    bulk_entered = asyncio.Event()
    bulk_released = asyncio.Event()
    delete_started = asyncio.Event()

    async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
        bulk_entered.set()
        await bulk_released.wait()
        return (len(actions), [])

    async def watch_delete_by_query(**kwargs):
        # Records that the server-side delete was actually issued.
        delete_started.set()
        return {"deleted": 0}

    mock_client.delete_by_query = AsyncMock(side_effect=watch_delete_by_query)

    relation = {"content": "X", "src_id": "Alice", "tgt_id": "Bob"}
    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"rel-1": relation})

        flush_task = asyncio.create_task(storage.index_done_callback())
        await bulk_entered.wait()

        # delete_by_query must NOT fire while bulk is still in flight.
        rel_task = asyncio.create_task(storage.delete_entity_relation("Alice"))
        for _ in range(5):
            await asyncio.sleep(0)
        assert (
            not delete_started.is_set()
        ), "delete_by_query should be blocked behind the flush lock"
        assert not rel_task.done()

        bulk_released.set()
        await flush_task
        await rel_task
        assert delete_started.is_set()
@pytest.mark.asyncio
async def test_drop_serialised_with_flush(
    self, global_config, embed_func, mock_client
):
    """``drop()`` does not call ``indices.delete`` while a bulk call is pending.

    The drop task must stay unfinished, with ``indices.delete`` untouched,
    until the held bulk call is allowed to return.
    """
    bulk_entered = asyncio.Event()
    bulk_released = asyncio.Event()
    drop_delete_started = asyncio.Event()

    async def slow_bulk(client, actions, raise_on_error=False, **kwargs):
        bulk_entered.set()
        await bulk_released.wait()
        return (len(actions), [])

    async def watch_indices_delete(**kwargs):
        # Records that the index deletion was actually issued.
        drop_delete_started.set()

    mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)

    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch("lightrag.kg.opensearch_impl.helpers.async_bulk", new=slow_bulk):
        storage = self._make(global_config, embed_func)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "x"}})

        flush_task = asyncio.create_task(storage.index_done_callback())
        await bulk_entered.wait()

        drop_task = asyncio.create_task(storage.drop())
        for _ in range(5):
            await asyncio.sleep(0)
        assert (
            not drop_delete_started.is_set()
        ), "indices.delete should be blocked behind the flush lock"
        assert not drop_task.done()

        bulk_released.set()
        await flush_task
        await drop_task
        assert drop_delete_started.is_set()
@pytest.mark.asyncio
async def test_drop_serialised_with_flush_embedding_phase(
    self, global_config, mock_client
):
    """``drop()`` also waits while the flush is still computing embeddings.

    The embedding function is gated on an event. Until it is released,
    ``indices.delete`` must not be called and the drop task must not
    complete.
    """
    embedding_started = asyncio.Event()
    embedding_can_finish = asyncio.Event()
    drop_delete_started = asyncio.Event()

    class GatedEmbeddingFunc(MockEmbeddingFunc):
        async def __call__(self, texts, **kwargs):
            # Announce the embedding pass, then hold it until released.
            embedding_started.set()
            await embedding_can_finish.wait()
            return await super().__call__(texts, **kwargs)

    async def watch_indices_delete(**kwargs):
        drop_delete_started.set()

    mock_client.indices.delete = AsyncMock(side_effect=watch_indices_delete)
    gated_embed = GatedEmbeddingFunc()

    with patch.object(
        ClientManager, "get_client", return_value=mock_client
    ), patch(
        "lightrag.kg.opensearch_impl.helpers.async_bulk", new_callable=AsyncMock
    ) as mock_bulk:
        mock_bulk.return_value = (1, [])
        storage = self._make(global_config, gated_embed)
        await storage.initialize()
        await storage.upsert({"v1": {"content": "x"}})

        flush_task = asyncio.create_task(storage.index_done_callback())
        await embedding_started.wait()

        drop_task = asyncio.create_task(storage.drop())
        for _ in range(5):
            await asyncio.sleep(0)
        assert (
            not drop_delete_started.is_set()
        ), "indices.delete should be blocked during deferred embedding"
        assert not drop_task.done()

        embedding_can_finish.set()
        await flush_task
        await drop_task
        assert drop_delete_started.is_set()
- # ---------------------------------------------------------------------------
- # Cosine score edge cases
- # ---------------------------------------------------------------------------
class TestScoreThreshold:
    """Threshold comparison semantics for raw scores (``score >= threshold``).

    NOTE(review): these cases compare numeric literals only; they do not
    call into any storage code.
    """

    def test_above_threshold(self):
        score, threshold = 0.85, 0.2
        assert score >= threshold

    def test_below_threshold(self):
        score, threshold = 0.15, 0.2
        assert score < threshold

    def test_exact_threshold(self):
        # A score equal to the threshold counts as passing.
        score, threshold = 0.2, 0.2
        assert score >= threshold
- # ---------------------------------------------------------------------------
- # Why raising EMBEDDING_BATCH_NUM does not lower the embedding call count
- # ---------------------------------------------------------------------------
class TestEmbeddingBatchNumDiagnosis:
    """Show why raising EMBEDDING_BATCH_NUM may leave the embedding call
    count (get_embedding_queue_status -> submitted_total) unchanged for
    entities/relations.

    ``merge_nodes_and_edges`` upserts entities/relations ONE id at a time:
    ``_merge_nodes_then_upsert`` calls ``entity_vdb.upsert({single})`` and
    ``_merge_edges_then_upsert`` calls ``relationships_vdb.upsert({single})``
    (lightrag/operate.py). ``EMBEDDING_BATCH_NUM`` only slices the items
    *within one embedding pass* (``contents[i:i+batch]``). The call count is
    therefore driven by how many items reach a single embedding pass, not by
    the batch size -- a larger batch size only helps once >= 2 items are
    embedded together.
    """

    def _make(self, batch_num, embed_func, workspace="diag"):
        """Build an "entities" storage whose config sets ``embedding_batch_num``."""
        return OpenSearchVectorDBStorage(
            namespace="entities",
            global_config={
                "embedding_batch_num": batch_num,
                "vector_db_storage_cls_kwargs": {
                    "cosine_better_than_threshold": 0.2
                },
            },
            embedding_func=embed_func,
            workspace=workspace,
            meta_fields={"content", "entity_name"},
        )

    @staticmethod
    def _fake_bulk(_client, actions, *_args, **_kwargs):
        """Stand-in for ``async_bulk``: report every action as persisted."""
        # async_bulk(raise_on_error=False) -> (success_count, failed_list);
        # an empty failed list means nothing was rejected.
        persisted = len(actions)
        return (persisted, [])

    async def _run_per_item(self, batch_num, *, flush_each, n=100):
        """Upsert ``n`` entities one at a time and return the counting embedder.

        flush_each=True  -> flush right after each single-item upsert, so
                            every embedding pass sees exactly 1 item. This
                            mirrors the pre-defer / eager behaviour where
                            ``upsert`` embeds inline.
        flush_each=False -> buffer every single-item upsert and flush once,
                            i.e. the deferred-embedding design on this branch.
        """
        counter = CountingEmbeddingFunc()
        mock_client = _make_client()
        with patch.object(
            ClientManager, "get_client", return_value=mock_client
        ), patch(
            "lightrag.kg.opensearch_impl.helpers.async_bulk",
            new_callable=AsyncMock,
        ) as mock_bulk:
            mock_bulk.side_effect = self._fake_bulk
            storage = self._make(batch_num, counter)
            await storage.initialize()
            for i in range(n):
                single = {
                    f"ent-{i}": {"content": f"entity {i}", "entity_name": f"E{i}"}
                }
                await storage.upsert(single)
                if flush_each:
                    await storage.index_done_callback()
            if not flush_each:
                # Deferred mode: one flush for everything buffered above.
                await storage.index_done_callback()
        return counter

    @pytest.mark.asyncio
    async def test_per_item_embedding_makes_batch_num_a_noop(self):
        """Eager pattern: one embedding call per single-item upsert.

        Every embedding call carries exactly ONE item, and moving
        EMBEDDING_BATCH_NUM from 16 to 32 leaves the call count at 100.
        """
        embed16 = await self._run_per_item(16, flush_each=True)
        embed32 = await self._run_per_item(32, flush_each=True)

        for counter in (embed16, embed32):
            assert counter.call_count == 100
        # Each embedding pass saw exactly one item, regardless of batch size.
        for counter in (embed16, embed32):
            assert all(len(b) == 1 for b in counter.batches)
        # The crux: raising the batch size did not reduce the call count.
        assert embed16.call_count == embed32.call_count

    @pytest.mark.asyncio
    async def test_deferred_flush_makes_batch_num_effective(self):
        """Deferred pattern: buffer all single-item upserts, flush once.

        Here EMBEDDING_BATCH_NUM governs the count:
        ceil(100/16)=7 vs ceil(100/32)=4.
        """
        embed16 = await self._run_per_item(16, flush_each=False)
        embed32 = await self._run_per_item(32, flush_each=False)

        assert embed16.call_count == math.ceil(100 / 16) == 7
        assert embed32.call_count == math.ceil(100 / 32) == 4
        assert embed16.call_count != embed32.call_count
        # Every flushed batch respects the configured cap, and nothing is lost.
        assert all(len(b) <= 16 for b in embed16.batches)
        assert all(len(b) <= 32 for b in embed32.batches)
        assert len(embed16.texts) == 100
        assert len(embed32.texts) == 100

    @pytest.mark.asyncio
    async def test_single_multiitem_upsert_is_batched_like_chunks_vdb(self):
        """Contrast case: one upsert call carrying 100 items, as chunks_vdb does.

        When many items arrive in a single upsert/embedding pass,
        EMBEDDING_BATCH_NUM takes effect even with an immediate flush, which
        shows that the deciding factor is items-per-embedding-pass rather
        than the storage backend.
        """
        counters = {16: CountingEmbeddingFunc(), 32: CountingEmbeddingFunc()}
        for batch_num, counter in counters.items():
            mock_client = _make_client()
            with patch.object(
                ClientManager, "get_client", return_value=mock_client
            ), patch(
                "lightrag.kg.opensearch_impl.helpers.async_bulk",
                new_callable=AsyncMock,
            ) as mock_bulk:
                mock_bulk.side_effect = self._fake_bulk
                storage = self._make(batch_num, counter)
                await storage.initialize()
                # chunks_vdb.upsert(chunks): one call carrying 100 items.
                chunks = {
                    f"chunk-{i}": {"content": f"chunk {i}", "entity_name": ""}
                    for i in range(100)
                }
                await storage.upsert(chunks)
                await storage.index_done_callback()

        embed16, embed32 = counters[16], counters[32]
        assert embed16.call_count == math.ceil(100 / 16) == 7
        assert embed32.call_count == math.ceil(100 / 32) == 4
        assert embed16.call_count != embed32.call_count
|