utils.py 142 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915
  1. from __future__ import annotations
  2. import weakref
  3. import sys
  4. import asyncio
  5. import html
  6. import csv
  7. import inspect
  8. import json
  9. import logging
  10. import logging.handlers
  11. import os
  12. import re
  13. import time
  14. import uuid
  15. import warnings
  16. from dataclasses import dataclass
  17. from datetime import datetime
  18. from functools import wraps
  19. from hashlib import md5
  20. from pathlib import Path
  21. from typing import (
  22. Any,
  23. Protocol,
  24. Callable,
  25. TYPE_CHECKING,
  26. List,
  27. Optional,
  28. Iterable,
  29. Sequence,
  30. Collection,
  31. )
  32. import numpy as np
  33. from dotenv import load_dotenv
  34. from lightrag.constants import (
  35. DEFAULT_LOG_MAX_BYTES,
  36. DEFAULT_LOG_BACKUP_COUNT,
  37. DEFAULT_LOG_FILENAME,
  38. GRAPH_FIELD_SEP,
  39. DEFAULT_MAX_TOTAL_TOKENS,
  40. DEFAULT_SOURCE_IDS_LIMIT_METHOD,
  41. VALID_SOURCE_IDS_LIMIT_METHODS,
  42. SOURCE_IDS_LIMIT_METHOD_FIFO,
  43. PARSED_DIR_NAME,
  44. )
# Precompiled regex patterns for JSON sanitization (module-level, compiled once).
# Matches UTF-16 surrogate code points (U+D800-U+DFFF) plus U+FFFE and U+FFFF.
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
# Matches the control characters 0x00-0x1F except tab (0x09), line feed (0x0A)
# and carriage return (0x0D), and additionally DEL (0x7F).
_CONTROL_CHAR_PATTERN_ALL = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]")
  48. class SafeStreamHandler(logging.StreamHandler):
  49. """StreamHandler that gracefully handles closed streams during shutdown.
  50. This handler prevents "ValueError: I/O operation on closed file" errors
  51. that can occur when pytest or other test frameworks close stdout/stderr
  52. before Python's logging cleanup runs.
  53. """
  54. def flush(self):
  55. """Flush the stream, ignoring errors if the stream is closed."""
  56. try:
  57. super().flush()
  58. except (ValueError, OSError):
  59. # Stream is closed or otherwise unavailable, silently ignore
  60. pass
  61. def close(self):
  62. """Close the handler, ignoring errors if the stream is already closed."""
  63. try:
  64. super().close()
  65. except (ValueError, OSError):
  66. # Stream is closed or otherwise unavailable, silently ignore
  67. pass
# Module-wide "lightrag" logger with a basic default configuration.
logger = logging.getLogger("lightrag")
logger.propagate = False  # keep records from also being sent to the root logger
logger.setLevel(logging.INFO)
# Attach a console handler only when the logger has no handlers yet, so a
# logger that was already configured is left untouched.
if not logger.handlers:
    console_handler = SafeStreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(levelname)s: %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
# Raise the httpx logger's level to WARNING.
logging.getLogger("httpx").setLevel(logging.WARNING)
  81. def _patch_ascii_colors_console_handler() -> None:
  82. """Prevent ascii_colors from printing flush errors during interpreter exit."""
  83. try:
  84. from ascii_colors import ConsoleHandler
  85. except ImportError:
  86. return
  87. if getattr(ConsoleHandler, "_lightrag_patched", False):
  88. return
  89. original_handle_error = ConsoleHandler.handle_error
  90. def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
  91. exc_type, _, _ = sys.exc_info()
  92. if exc_type in (ValueError, OSError) and "close" in message.lower():
  93. return
  94. original_handle_error(self, message)
  95. ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
  96. ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
  97. _patch_ascii_colors_console_handler()
# Optional dependency: try to import pypinyin once at module load and record
# the outcome in _PYPINYIN_AVAILABLE. When the import fails, `pypinyin` is
# bound to None and a warning is logged.
try:
    import pypinyin

    _PYPINYIN_AVAILABLE = True
    # logger.info("pypinyin loaded successfully for Chinese pinyin sorting")
except ImportError:
    pypinyin = None
    _PYPINYIN_AVAILABLE = False
    logger.warning(
        "pypinyin is not installed. Chinese pinyin sorting will use simple string sorting."
    )
  109. async def safe_vdb_operation_with_exception(
  110. operation: Callable,
  111. operation_name: str,
  112. entity_name: str = "",
  113. max_retries: int = 3,
  114. retry_delay: float = 0.2,
  115. logger_func: Optional[Callable] = None,
  116. timeout_seconds: float | None = None,
  117. log_start: bool = False,
  118. success_log_threshold_seconds: float = 10.0,
  119. ) -> None:
  120. """
  121. Safely execute vector database operations with retry mechanism and exception handling.
  122. This function ensures that VDB operations are executed with proper error handling
  123. and retry logic. If all retries fail, it raises an exception to maintain data consistency.
  124. Args:
  125. operation: The async operation to execute
  126. operation_name: Operation name for logging purposes
  127. entity_name: Entity name for logging purposes
  128. max_retries: Maximum number of retry attempts
  129. retry_delay: Delay between retries in seconds
  130. logger_func: Logger function to use for error messages
  131. timeout_seconds: Optional timeout for a single operation attempt
  132. log_start: Whether to emit start/success logs for each attempt
  133. success_log_threshold_seconds: Log successful attempts when duration exceeds this threshold
  134. Raises:
  135. Exception: When operation fails after all retry attempts
  136. """
  137. log_func = logger_func or logger.warning
  138. for attempt in range(max_retries):
  139. start_ts = time.perf_counter()
  140. attempt_label = f"{attempt + 1}/{max_retries}"
  141. try:
  142. if log_start:
  143. logger.info(
  144. "VDB %s start for %s (attempt %s, timeout=%s)",
  145. operation_name,
  146. entity_name or "<unknown>",
  147. attempt_label,
  148. f"{timeout_seconds:.1f}s"
  149. if timeout_seconds is not None
  150. else "none",
  151. )
  152. if timeout_seconds is not None and timeout_seconds > 0:
  153. await asyncio.wait_for(operation(), timeout=timeout_seconds)
  154. else:
  155. await operation()
  156. elapsed = time.perf_counter() - start_ts
  157. if log_start or elapsed >= success_log_threshold_seconds:
  158. logger.info(
  159. "VDB %s success for %s in %.2fs (attempt %s)",
  160. operation_name,
  161. entity_name or "<unknown>",
  162. elapsed,
  163. attempt_label,
  164. )
  165. return # Success, return immediately
  166. except asyncio.TimeoutError as e:
  167. elapsed = time.perf_counter() - start_ts
  168. timeout_msg = (
  169. f"VDB {operation_name} timeout for {entity_name or '<unknown>'} "
  170. f"after {elapsed:.2f}s (attempt {attempt_label}, timeout={timeout_seconds}s)"
  171. )
  172. if attempt >= max_retries - 1:
  173. log_func(timeout_msg)
  174. raise TimeoutError(timeout_msg) from e
  175. log_func(f"{timeout_msg}, retrying...")
  176. if retry_delay > 0:
  177. await asyncio.sleep(retry_delay)
  178. except Exception as e:
  179. elapsed = time.perf_counter() - start_ts
  180. if attempt >= max_retries - 1:
  181. error_msg = (
  182. f"VDB {operation_name} failed for {entity_name or '<unknown>'} "
  183. f"after {max_retries} attempts in {elapsed:.2f}s: {e}"
  184. )
  185. log_func(error_msg)
  186. raise Exception(error_msg) from e
  187. else:
  188. log_func(
  189. f"VDB {operation_name} attempt {attempt + 1} failed for "
  190. f"{entity_name or '<unknown>'} after {elapsed:.2f}s: {e}, retrying..."
  191. )
  192. if retry_delay > 0:
  193. await asyncio.sleep(retry_delay)
  194. def parse_optional_float(raw: str | None) -> float | None:
  195. """Decode env strings (or any text) into ``float | None``.
  196. Empty string and the literal ``"None"`` (case-insensitive) collapse
  197. to ``None`` so users can leave a knob un-set in ``.env`` and have
  198. the consuming code fall back to its own default. Any other
  199. non-numeric value raises :class:`ValueError` so misconfigured envs
  200. fail loudly at parse time rather than silently downstream.
  201. """
  202. if raw is None:
  203. return None
  204. stripped = raw.strip()
  205. if not stripped or stripped.lower() == "none":
  206. return None
  207. return float(stripped)
  208. def get_env_value(
  209. env_key: str, default: any, value_type: type = str, special_none: bool = False
  210. ) -> any:
  211. """
  212. Get value from environment variable with type conversion
  213. Args:
  214. env_key (str): Environment variable key
  215. default (any): Default value if env variable is not set
  216. value_type (type): Type to convert the value to
  217. special_none (bool): If True, return None when value is "None"
  218. Returns:
  219. any: Converted value from environment or default
  220. """
  221. value = os.getenv(env_key)
  222. if value is None:
  223. return default
  224. # Handle special case for "None" string
  225. if special_none and value == "None":
  226. return None
  227. if value_type is bool:
  228. return value.lower() in ("true", "1", "yes", "t", "on")
  229. # Handle list type with JSON parsing
  230. if value_type is list:
  231. try:
  232. import json
  233. parsed_value = json.loads(value)
  234. # Ensure the parsed value is actually a list
  235. if isinstance(parsed_value, list):
  236. return parsed_value
  237. else:
  238. logger.warning(
  239. f"Environment variable {env_key} is not a valid JSON list, using default"
  240. )
  241. return default
  242. except (json.JSONDecodeError, ValueError) as e:
  243. logger.warning(
  244. f"Failed to parse {env_key} as JSON list: {e}, using default"
  245. )
  246. return default
  247. try:
  248. return value_type(value)
  249. except (ValueError, TypeError):
  250. return default
# Import these names only under TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
    from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam
# Load the .env file that is inside the current folder, which allows a
# different .env file for each lightrag instance. With override=False the OS
# environment variables take precedence over the .env file.
load_dotenv(dotenv_path=".env", override=False)
# True only when the VERBOSE environment variable equals "true" (case-insensitive).
VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true"
# True only when LIGHTRAG_PERFORMANCE_TIMING_LOGS equals "true" (case-insensitive).
PERFORMANCE_TIMING_LOGS = (
    os.getenv("LIGHTRAG_PERFORMANCE_TIMING_LOGS", "false").lower() == "true"
)
  262. def verbose_debug(msg: str, *args, **kwargs):
  263. """Function for outputting detailed debug information.
  264. When VERBOSE_DEBUG=True, outputs the complete message.
  265. When VERBOSE_DEBUG=False, outputs only the first 50 characters.
  266. Args:
  267. msg: The message format string
  268. *args: Arguments to be formatted into the message
  269. **kwargs: Keyword arguments passed to logger.debug()
  270. """
  271. if VERBOSE_DEBUG:
  272. logger.debug(msg, *args, **kwargs)
  273. else:
  274. # Format the message with args first
  275. if args:
  276. formatted_msg = msg % args
  277. else:
  278. formatted_msg = msg
  279. # Then truncate the formatted message
  280. truncated_msg = (
  281. formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg
  282. )
  283. # Remove consecutive newlines
  284. truncated_msg = re.sub(r"\n+", "\n", truncated_msg)
  285. logger.debug(truncated_msg, **kwargs)
def set_verbose_debug(enabled: bool) -> None:
    """Enable or disable verbose debug output.

    Rebinds the module-level VERBOSE_DEBUG flag to ``enabled``.
    """
    global VERBOSE_DEBUG
    VERBOSE_DEBUG = enabled
  290. def performance_timing_log(msg: str, *args, **kwargs):
  291. """Emit targeted performance timing logs only when explicitly enabled."""
  292. if PERFORMANCE_TIMING_LOGS:
  293. logger.info(msg, *args, **kwargs)
# Module-level counters, all starting at zero.
# NOTE(review): presumably incremented elsewhere for LLM calls, LLM cache hits
# and embedding calls -- confirm against the call sites.
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
  295. class LightragPathFilter(logging.Filter):
  296. """Filter for lightrag logger to filter out frequent path access logs"""
  297. def __init__(self):
  298. super().__init__()
  299. # Define paths to be filtered
  300. self.filtered_paths = [
  301. "/documents",
  302. "/documents/paginated",
  303. "/health",
  304. "/webui/",
  305. "/documents/pipeline_status",
  306. ]
  307. # self.filtered_paths = ["/health", "/webui/"]
  308. def filter(self, record):
  309. try:
  310. # Check if record has the required attributes for an access log
  311. if not hasattr(record, "args") or not isinstance(record.args, tuple):
  312. return True
  313. if len(record.args) < 5:
  314. return True
  315. # Extract method, path and status from the record args
  316. method = record.args[1]
  317. path = record.args[2]
  318. status = record.args[4]
  319. # Filter out successful GET/POST requests to filtered paths
  320. if (
  321. (method == "GET" or method == "POST")
  322. and (status == 200 or status == 304)
  323. and path in self.filtered_paths
  324. ):
  325. return False
  326. return True
  327. except Exception:
  328. # In case of any error, let the message through
  329. return True
  330. def setup_logger(
  331. logger_name: str,
  332. level: str = "INFO",
  333. add_filter: bool = False,
  334. log_file_path: str | None = None,
  335. enable_file_logging: bool = True,
  336. ):
  337. """Set up a logger with console and optionally file handlers
  338. Args:
  339. logger_name: Name of the logger to set up
  340. level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
  341. add_filter: Whether to add LightragPathFilter to the logger
  342. log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
  343. enable_file_logging: Whether to enable logging to a file (defaults to True)
  344. """
  345. # Configure formatters
  346. detailed_formatter = logging.Formatter(
  347. "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  348. )
  349. simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
  350. logger_instance = logging.getLogger(logger_name)
  351. logger_instance.setLevel(level)
  352. logger_instance.handlers = [] # Clear existing handlers
  353. logger_instance.propagate = False
  354. # Add console handler with safe stream handling
  355. console_handler = SafeStreamHandler()
  356. console_handler.setFormatter(simple_formatter)
  357. console_handler.setLevel(level)
  358. logger_instance.addHandler(console_handler)
  359. # Add file handler by default unless explicitly disabled
  360. if enable_file_logging:
  361. # Get log file path
  362. if log_file_path is None:
  363. log_dir = os.getenv("LOG_DIR", os.getcwd())
  364. log_file_path = os.path.abspath(os.path.join(log_dir, DEFAULT_LOG_FILENAME))
  365. # Ensure log directory exists
  366. os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
  367. # Get log file max size and backup count from environment variables
  368. log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int)
  369. log_backup_count = get_env_value(
  370. "LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int
  371. )
  372. try:
  373. # Add file handler
  374. file_handler = logging.handlers.RotatingFileHandler(
  375. filename=log_file_path,
  376. maxBytes=log_max_bytes,
  377. backupCount=log_backup_count,
  378. encoding="utf-8",
  379. )
  380. file_handler.setFormatter(detailed_formatter)
  381. file_handler.setLevel(level)
  382. logger_instance.addHandler(file_handler)
  383. except PermissionError as e:
  384. logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
  385. logger.warning("Continuing with console logging only")
  386. # Add path filter if requested
  387. if add_filter:
  388. path_filter = LightragPathFilter()
  389. logger_instance.addFilter(path_filter)
  390. class UnlimitedSemaphore:
  391. """A context manager that allows unlimited access."""
  392. async def __aenter__(self):
  393. pass
  394. async def __aexit__(self, exc_type, exc, tb):
  395. pass
  396. @dataclass
  397. class TaskState:
  398. """Task state tracking for priority queue management"""
  399. future: asyncio.Future
  400. start_time: float
  401. execution_start_time: float = None
  402. worker_started: bool = False
  403. cancellation_requested: bool = False
  404. cleanup_done: bool = False
  405. @dataclass
  406. class EmbeddingFunc:
  407. """Embedding function wrapper with dimension validation
  408. This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
  409. If wrapped multiple times, the inner wrappers will be automatically unwrapped to prevent
  410. configuration conflicts where inner wrapper settings would override outer wrapper settings.
  411. Using functools.partial for parameter binding:
  412. A common pattern is to use functools.partial to pre-bind model and host parameters
  413. to an embedding function. When the base embedding function is already decorated with
  414. @wrap_embedding_func_with_attrs (e.g., ollama_embed), use `.func` to access the
  415. original unwrapped function to avoid double wrapping:
  416. Example:
  417. from functools import partial
  418. # ❌ Wrong - causes double wrapping (inner EmbeddingFunc still executes)
  419. func=partial(ollama_embed, embed_model="bge-m3:latest", host="http://localhost:11434")
  420. # ✅ Correct - access the unwrapped function via .func
  421. func=partial(ollama_embed.func, embed_model="bge-m3:latest", host="http://localhost:11434")
  422. Context-aware embedding:
  423. The wrapper supports passing a 'context' parameter to distinguish between query and document
  424. embeddings. This allows wrapped functions to apply different processing (e.g., prefixes,
  425. different models) based on the context:
  426. Example:
  427. embeddings = await embed_func(texts, context="document") # For indexing
  428. embeddings = await embed_func([query], context="query") # For search
  429. Args:
  430. embedding_dim: Expected dimension of the embeddings(For dimension checking and workspace data isolation in vector DB)
  431. func: The actual embedding function to wrap
  432. max_token_size: Enable embedding token limit checking for description summarization(Set embedding_token_limit in LightRAG)
  433. send_dimensions: Whether to inject embedding_dim argument to underlying function
  434. model_name: Model name for implementing workspace data isolation in vector DB
  435. supports_asymmetric: Whether the underlying function supports context parameter so it can be injected
  436. """
  437. embedding_dim: int
  438. func: callable
  439. max_token_size: int | None = None
  440. send_dimensions: bool = False
  441. model_name: str | None = (
  442. None # Model name for implementing workspace data isolation in vector DB
  443. )
  444. supports_asymmetric: bool = (
  445. False # Whether underlying function accepts context parameter
  446. )
  447. def __post_init__(self):
  448. """Unwrap nested EmbeddingFunc to prevent double wrapping issues.
  449. When an EmbeddingFunc wraps another EmbeddingFunc, the inner wrapper's
  450. __call__ preprocessing would override the outer wrapper's settings.
  451. This method detects and unwraps nested EmbeddingFunc instances to ensure
  452. that only the outermost wrapper's configuration is applied.
  453. """
  454. # Check if func is already an EmbeddingFunc instance and unwrap it
  455. max_unwrap_depth = 3 # Safety limit to prevent infinite loops
  456. unwrap_count = 0
  457. while isinstance(self.func, EmbeddingFunc):
  458. unwrap_count += 1
  459. if unwrap_count > max_unwrap_depth:
  460. raise ValueError(
  461. f"EmbeddingFunc unwrap depth exceeded {max_unwrap_depth}. "
  462. "Possible circular reference detected."
  463. )
  464. # Unwrap to get the original function
  465. self.func = self.func.func
  466. if unwrap_count > 0:
  467. logger.warning(
  468. f"Detected nested EmbeddingFunc wrapping (depth: {unwrap_count}), "
  469. "auto-unwrapped to prevent configuration conflicts. "
  470. "Consider using .func to access the unwrapped function directly."
  471. )
  472. async def __call__(self, *args, **kwargs) -> np.ndarray:
  473. # Only inject embedding_dim when send_dimensions is True
  474. if self.send_dimensions:
  475. # Check if user provided embedding_dim parameter
  476. if "embedding_dim" in kwargs:
  477. user_provided_dim = kwargs["embedding_dim"]
  478. # If user's value differs from class attribute, output warning
  479. if (
  480. user_provided_dim is not None
  481. and user_provided_dim != self.embedding_dim
  482. ):
  483. logger.warning(
  484. f"Ignoring user-provided embedding_dim={user_provided_dim}, "
  485. f"using declared embedding_dim={self.embedding_dim} from decorator"
  486. )
  487. # Inject embedding_dim from decorator
  488. kwargs["embedding_dim"] = self.embedding_dim
  489. # Remove context parameter if underlying function does not support asymmetric embedding
  490. if "context" in kwargs and not self.supports_asymmetric:
  491. # Log when a user-provided context is ignored due to lack of support
  492. logger.debug(
  493. "Context parameter was provided but supports_asymmetric=False. The context value has been ignored."
  494. )
  495. kwargs.pop("context")
  496. # Check if underlying function supports max_token_size and inject if not provided
  497. if self.max_token_size is not None and "max_token_size" not in kwargs:
  498. sig = inspect.signature(self.func)
  499. if "max_token_size" in sig.parameters:
  500. kwargs["max_token_size"] = self.max_token_size
  501. # Call the actual embedding function
  502. result = await self.func(*args, **kwargs)
  503. # Validate embedding dimensions using total element count
  504. total_elements = result.size # Total number of elements in the numpy array
  505. expected_dim = self.embedding_dim
  506. # Check if total elements can be evenly divided by embedding_dim
  507. if total_elements % expected_dim != 0:
  508. raise ValueError(
  509. f"Embedding dimension mismatch detected: "
  510. f"total elements ({total_elements}) cannot be evenly divided by "
  511. f"expected dimension ({expected_dim}). "
  512. )
  513. # Optional: Verify vector count matches input text count
  514. actual_vectors = total_elements // expected_dim
  515. if args and isinstance(args[0], (list, tuple)):
  516. expected_vectors = len(args[0])
  517. if actual_vectors != expected_vectors:
  518. raise ValueError(
  519. f"Vector count mismatch: "
  520. f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
  521. )
  522. return result
  523. def compute_args_hash(*args: Any) -> str:
  524. """Compute a hash for the given arguments with safe Unicode handling.
  525. Args:
  526. *args: Arguments to hash
  527. Returns:
  528. str: Hash string
  529. """
  530. # Convert all arguments to strings and join them
  531. args_str = "".join([str(arg) for arg in args])
  532. # Use 'replace' error handling to safely encode problematic Unicode characters
  533. # This replaces invalid characters with Unicode replacement character (U+FFFD)
  534. try:
  535. return md5(args_str.encode("utf-8")).hexdigest()
  536. except UnicodeEncodeError:
  537. # Handle surrogate characters and other encoding issues
  538. safe_bytes = args_str.encode("utf-8", errors="replace")
  539. return md5(safe_bytes).hexdigest()
  540. def _serialize_cache_variant(value: Any) -> str:
  541. """Serialize cache-affecting options to a stable string for hash inputs."""
  542. if value is None:
  543. return ""
  544. if hasattr(value, "model_dump") and callable(value.model_dump):
  545. try:
  546. value = value.model_dump(mode="json")
  547. except TypeError:
  548. value = value.model_dump()
  549. if hasattr(value, "model_json_schema") and callable(value.model_json_schema):
  550. value = value.model_json_schema()
  551. try:
  552. return json.dumps(
  553. value,
  554. ensure_ascii=False,
  555. sort_keys=True,
  556. separators=(",", ":"),
  557. default=repr,
  558. )
  559. except (TypeError, ValueError):
  560. return repr(value)
  561. def get_llm_cache_identity(
  562. global_config: dict[str, Any] | None,
  563. role: str,
  564. ) -> dict[str, Any]:
  565. """Get the non-secret LLM identity used to partition LLM cache keys.
  566. Includes ``role``, ``binding``, ``model``, and ``host``. Deliberately excludes
  567. ``api_key`` and ``provider_options`` so cache keys remain non-secret and safe
  568. to persist.
  569. """
  570. config = global_config or {}
  571. identities = config.get("llm_cache_identities")
  572. if isinstance(identities, dict):
  573. identity = identities.get(role)
  574. if isinstance(identity, dict):
  575. return dict(identity)
  576. return {
  577. "role": role,
  578. "binding": None,
  579. "model": config.get("llm_model_name"),
  580. "host": None,
  581. }
  582. def serialize_llm_cache_identity(identity: Any) -> str:
  583. """Serialize an LLM cache identity for inclusion in hash inputs."""
  584. return _serialize_cache_variant(identity)
  585. def _validate_cached_response_format(response_format: Any | None) -> None:
  586. """Reject structured-output modes that the cache wrapper does not support."""
  587. if response_format is None:
  588. return
  589. if (
  590. isinstance(response_format, dict)
  591. and response_format.get("type") == "json_object"
  592. ):
  593. return
  594. raise ValueError(
  595. "use_llm_func_with_cache only supports response_format={'type': 'json_object'}; "
  596. "json_schema and typed response_format values must not be passed through the cache wrapper."
  597. )
  598. def compute_mdhash_id(content: str, prefix: str = "") -> str:
  599. """
  600. Compute a unique ID for a given content string.
  601. The ID is a combination of the given prefix and the MD5 hash of the content string.
  602. """
  603. return prefix + compute_args_hash(content)
  604. def get_unique_filename_in_parsed(target_dir: Path, original_name: str) -> str:
  605. """Generate a unique filename in target_dir, adding numeric suffixes on conflict.
  606. Tries the original name first, then `{stem}_001{ext}` ... `{stem}_999{ext}`,
  607. falling back to a timestamp-suffixed name if all numeric slots are taken.
  608. """
  609. original_path = Path(original_name)
  610. base_name = original_path.stem
  611. extension = original_path.suffix
  612. if not (target_dir / original_name).exists():
  613. return original_name
  614. for i in range(1, 1000):
  615. new_name = f"{base_name}_{i:03d}{extension}"
  616. if not (target_dir / new_name).exists():
  617. return new_name
  618. return f"{base_name}_{int(time.time())}{extension}"
  619. async def move_file_to_parsed_dir(
  620. file_path: Path,
  621. *,
  622. skip_if_already_parsed: bool = False,
  623. ) -> Path | None:
  624. """Move a processed source file into its sibling __parsed__ directory.
  625. Returns the new path on success, the input path if `skip_if_already_parsed`
  626. is set and the file already lives in a `__parsed__` directory, or None if
  627. the source no longer exists.
  628. """
  629. if not file_path.exists() or not file_path.is_file():
  630. return None
  631. if skip_if_already_parsed and file_path.parent.name == PARSED_DIR_NAME:
  632. return file_path
  633. parsed_dir = file_path.parent / PARSED_DIR_NAME
  634. await asyncio.to_thread(parsed_dir.mkdir, parents=True, exist_ok=True)
  635. unique_filename = get_unique_filename_in_parsed(parsed_dir, file_path.name)
  636. target_path = parsed_dir / unique_filename
  637. await asyncio.to_thread(file_path.rename, target_path)
  638. logger.debug(
  639. f"Moved file to parsed directory: {file_path.name} -> {unique_filename}"
  640. )
  641. return target_path
  642. def make_relation_vdb_ids(src_entity: str, tgt_entity: str) -> list[str]:
  643. """Return candidate relation VDB IDs for an undirected edge.
  644. The normalized ID is returned first for all new writes. The reverse-order ID is
  645. kept as a compatibility fallback for historical custom-KG imports that hashed
  646. the relation using the original endpoint order.
  647. """
  648. normalized_src, normalized_tgt = sorted((src_entity, tgt_entity))
  649. relation_ids = [compute_mdhash_id(normalized_src + normalized_tgt, prefix="rel-")]
  650. reverse_relation_id = compute_mdhash_id(
  651. normalized_tgt + normalized_src, prefix="rel-"
  652. )
  653. if reverse_relation_id not in relation_ids:
  654. relation_ids.append(reverse_relation_id)
  655. return relation_ids
  656. def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
  657. """Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
  658. Args:
  659. mode: Cache mode (e.g., 'default', 'local', 'global')
  660. cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
  661. hash_value: Hash value from compute_args_hash
  662. Returns:
  663. str: Flattened cache key
  664. """
  665. return f"{mode}:{cache_type}:{hash_value}"
  666. def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
  667. """Parse a flattened cache key back into its components
  668. Args:
  669. cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
  670. Returns:
  671. tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
  672. """
  673. parts = cache_key.split(":", 2)
  674. if len(parts) == 3:
  675. return parts[0], parts[1], parts[2]
  676. return None
  677. # Custom exception classes
  678. class QueueFullError(Exception):
  679. """Raised when the queue is full and the wait times out"""
  680. pass
  681. class WorkerTimeoutError(Exception):
  682. """Worker-level timeout exception with specific timeout information"""
  683. def __init__(self, timeout_value: float, timeout_type: str = "execution"):
  684. self.timeout_value = timeout_value
  685. self.timeout_type = timeout_type
  686. super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s")
  687. class HealthCheckTimeoutError(Exception):
  688. """Health Check-level timeout exception"""
  689. def __init__(self, timeout_value: float, execution_duration: float):
  690. self.timeout_value = timeout_value
  691. self.execution_duration = execution_duration
  692. super().__init__(
  693. f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)"
  694. )
def priority_limit_async_func_call(
    max_size: int,
    llm_timeout: float | None = None,
    max_execution_timeout: float | None = None,
    max_task_duration: float | None = None,
    max_queue_size: int = 1000,
    cleanup_timeout: float = 2.0,
    queue_name: str = "limit_async",
):
    """
    Enhanced priority-limited asynchronous function call decorator with robust timeout handling

    This decorator provides a comprehensive solution for managing concurrent LLM requests with:
    - Multi-layer timeout protection (LLM -> Worker -> Health Check -> User)
    - Task state tracking to prevent race conditions
    - Enhanced health check system with stuck task detection
    - Proper resource cleanup and error recovery

    The decorated function gains two attributes: ``shutdown`` (coroutine function
    that stops the workers) and ``get_queue_stats`` (coroutine function returning
    a snapshot dict of queue/worker counters).

    Args:
        max_size: Maximum number of concurrent calls
        max_queue_size: Maximum queue capacity to prevent memory overflow
        llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts
        max_execution_timeout: Maximum time for worker to execute function
            (defaults to llm_timeout * 2 when llm_timeout is given; otherwise no worker-level timeout)
        max_task_duration: Maximum time before health check intervenes
            (defaults to llm_timeout * 2 + 15 when llm_timeout is given; otherwise stuck-task detection is disabled)
        cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s)
        queue_name: Optional queue name for logging identification (defaults to "limit_async")

    Returns:
        Decorator function
    """

    def final_decro(func):
        # Ensure func is callable
        if not callable(func):
            raise TypeError(f"Expected a callable object, got {type(func)}")

        # Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation)
        if llm_timeout is not None:
            nonlocal max_execution_timeout, max_task_duration
            if max_execution_timeout is None:
                max_execution_timeout = (
                    llm_timeout * 2
                )  # Reserved timeout buffer for low-level retry
            if max_task_duration is None:
                max_task_duration = (
                    llm_timeout * 2 + 15
                )  # Reserved timeout buffer for health check phase

        # Queue entries are tuples (priority, count, task_id, args, kwargs);
        # the priority queue hands out the lowest tuple first.
        queue = asyncio.PriorityQueue(maxsize=max_queue_size)
        tasks = set()  # worker asyncio.Task objects currently alive
        initialization_lock = asyncio.Lock()
        counter = 0  # monotonically increasing tie-breaker for equal priorities
        shutdown_event = asyncio.Event()
        initialized = False
        accepting_new_tasks = True
        worker_health_check_task = None

        # Enhanced task state management
        task_states = {}  # task_id -> TaskState
        task_states_lock = asyncio.Lock()
        active_futures = weakref.WeakSet()
        reinit_count = 0
        # Counters reported by get_queue_stats()
        submitted_total = 0
        completed_total = 0
        failed_total = 0
        cancelled_total = 0
        rejected_total = 0

        async def worker():
            """Enhanced worker that processes tasks with proper timeout and state management"""
            try:
                while not shutdown_event.is_set():
                    try:
                        # Get task from queue with timeout for shutdown checking
                        try:
                            (
                                priority,
                                count,
                                task_id,
                                args,
                                kwargs,
                            ) = await asyncio.wait_for(queue.get(), timeout=1.0)
                        except asyncio.TimeoutError:
                            continue

                        # Get task state and mark worker as started
                        async with task_states_lock:
                            if task_id not in task_states:
                                # State already removed (the caller finished or
                                # timed out, or shutdown cleared it): nothing to run.
                                queue.task_done()
                                continue
                            task_state = task_states[task_id]
                            task_state.worker_started = True
                            # Record execution start time when worker actually begins processing
                            task_state.execution_start_time = (
                                asyncio.get_event_loop().time()
                            )

                        # Check if task was cancelled before worker started
                        if (
                            task_state.cancellation_requested
                            or task_state.future.cancelled()
                        ):
                            async with task_states_lock:
                                task_states.pop(task_id, None)
                            queue.task_done()
                            continue

                        try:
                            # Execute function with timeout protection
                            if max_execution_timeout is not None:
                                result = await asyncio.wait_for(
                                    func(*args, **kwargs), timeout=max_execution_timeout
                                )
                            else:
                                result = await func(*args, **kwargs)

                            # Set result if future is still valid
                            if not task_state.future.done():
                                task_state.future.set_result(result)

                        except asyncio.TimeoutError:
                            # Worker-level timeout (max_execution_timeout exceeded)
                            logger.warning(
                                f"{queue_name}: Worker timeout for task {task_id} after {max_execution_timeout}s"
                            )
                            if not task_state.future.done():
                                task_state.future.set_exception(
                                    WorkerTimeoutError(
                                        max_execution_timeout, "execution"
                                    )
                                )
                        except asyncio.CancelledError:
                            # Task was cancelled during execution
                            # NOTE(review): the CancelledError is not re-raised, so a
                            # cancelled worker keeps looping until shutdown_event is set.
                            if not task_state.future.done():
                                task_state.future.cancel()
                            logger.debug(
                                f"{queue_name}: Task {task_id} cancelled during execution"
                            )
                        except Exception as e:
                            # Function execution error
                            logger.error(
                                f"{queue_name}: Error in decorated function for task {task_id}: {str(e)}"
                            )
                            if not task_state.future.done():
                                task_state.future.set_exception(e)
                        finally:
                            # Clean up task state
                            async with task_states_lock:
                                task_states.pop(task_id, None)

                            queue.task_done()

                    except Exception as e:
                        # Critical error in worker loop
                        logger.error(
                            f"{queue_name}: Critical error in worker: {str(e)}"
                        )
                        # Brief pause so a persistent failure does not spin the loop.
                        await asyncio.sleep(0.1)
            finally:
                logger.debug(f"{queue_name}: Worker exiting")

        async def enhanced_health_check():
            """Enhanced health check with stuck task detection and recovery"""
            nonlocal initialized
            try:
                while not shutdown_event.is_set():
                    await asyncio.sleep(5)  # Check every 5 seconds

                    current_time = asyncio.get_event_loop().time()

                    # Detect and handle stuck tasks based on execution start time
                    if max_task_duration is not None:
                        stuck_tasks = []
                        async with task_states_lock:
                            for task_id, task_state in list(task_states.items()):
                                # Only check tasks that have started execution
                                if (
                                    task_state.worker_started
                                    and task_state.execution_start_time is not None
                                    and current_time - task_state.execution_start_time
                                    > max_task_duration
                                ):
                                    stuck_tasks.append(
                                        (
                                            task_id,
                                            current_time
                                            - task_state.execution_start_time,
                                        )
                                    )

                        # Force cleanup of stuck tasks
                        # Only the caller-facing future is failed here; the worker
                        # coroutine that is running the function is not cancelled.
                        for task_id, execution_duration in stuck_tasks:
                            logger.warning(
                                f"{queue_name}: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup"
                            )
                            async with task_states_lock:
                                if task_id in task_states:
                                    task_state = task_states[task_id]
                                    if not task_state.future.done():
                                        task_state.future.set_exception(
                                            HealthCheckTimeoutError(
                                                max_task_duration, execution_duration
                                            )
                                        )
                                    task_states.pop(task_id, None)

                    # Worker recovery logic
                    current_tasks = set(tasks)
                    done_tasks = {t for t in current_tasks if t.done()}
                    tasks.difference_update(done_tasks)

                    active_tasks_count = len(tasks)
                    workers_needed = max_size - active_tasks_count

                    if workers_needed > 0:
                        logger.info(
                            f"{queue_name}: Creating {workers_needed} new workers"
                        )
                        new_tasks = set()
                        for _ in range(workers_needed):
                            task = asyncio.create_task(worker())
                            new_tasks.add(task)
                            task.add_done_callback(tasks.discard)
                        tasks.update(new_tasks)

            except Exception as e:
                logger.error(f"{queue_name}: Error in enhanced health check: {str(e)}")
            finally:
                logger.debug(f"{queue_name}: Enhanced health check task exiting")
                # Clearing the flag makes the next wrapped call go through
                # ensure_workers() again.
                initialized = False

        async def ensure_workers():
            """Ensure worker system is initialized with enhanced error handling"""
            nonlocal initialized, worker_health_check_task, tasks, reinit_count

            # Fast path without taking the lock.
            if initialized:
                return

            async with initialization_lock:
                # Re-check: another caller may have initialized while we waited.
                if initialized:
                    return

                if reinit_count > 0:
                    reinit_count += 1
                    logger.warning(
                        f"{queue_name}: Reinitializing system (count: {reinit_count})"
                    )
                else:
                    reinit_count = 1

                # Clean up completed tasks
                current_tasks = set(tasks)
                done_tasks = {t for t in current_tasks if t.done()}
                tasks.difference_update(done_tasks)

                active_tasks_count = len(tasks)
                if active_tasks_count > 0 and reinit_count > 1:
                    logger.warning(
                        f"{queue_name}: {active_tasks_count} tasks still running during reinitialization"
                    )

                # Create worker tasks
                workers_needed = max_size - active_tasks_count
                for _ in range(workers_needed):
                    task = asyncio.create_task(worker())
                    tasks.add(task)
                    task.add_done_callback(tasks.discard)

                # Start enhanced health check
                worker_health_check_task = asyncio.create_task(enhanced_health_check())

                initialized = True
                # Log dynamic timeout configuration
                timeout_info = []
                if llm_timeout is not None:
                    timeout_info.append(f"Func: {llm_timeout}s")
                if max_execution_timeout is not None:
                    timeout_info.append(f"Worker: {max_execution_timeout}s")
                if max_task_duration is not None:
                    timeout_info.append(f"Health Check: {max_task_duration}s")

                timeout_str = (
                    f"(Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
                )
                logger.info(
                    f"{queue_name}: {workers_needed} new workers initialized {timeout_str}"
                )

        async def get_queue_stats():
            """Return a best-effort snapshot of queue and worker state."""
            async with task_states_lock:
                # "running" = picked up by a worker and not yet resolved.
                running = sum(
                    1
                    for task_state in task_states.values()
                    if task_state.worker_started and not task_state.future.done()
                )
                in_flight = len(task_states)

            active_workers = len([task for task in tasks if not task.done()])
            return {
                "queue_name": queue_name,
                "max_async": max_size,
                "max_queue_size": max_queue_size,
                "queued": queue.qsize(),
                "running": running,
                "in_flight": in_flight,
                "worker_count": active_workers,
                "initialized": initialized,
                "submitted_total": submitted_total,
                "completed_total": completed_total,
                "failed_total": failed_total,
                "cancelled_total": cancelled_total,
                "rejected_total": rejected_total,
            }

        async def shutdown(graceful: bool = True, timeout: float | None = None):
            """Shut down workers and cleanup resources.

            Graceful mode stops new submissions and drains queued/running
            work; if the drain exceeds ``timeout`` (defaulting to
            ``max_task_duration`` or 30s), it falls through to forced
            cancellation so shutdown never blocks indefinitely.
            """
            nonlocal accepting_new_tasks, initialized, worker_health_check_task
            logger.info(f"{queue_name}: Shutting down priority queue workers")

            # NOTE(review): accepting_new_tasks and shutdown_event are never reset
            # in this function, so the wrapper rejects calls after shutdown.
            accepting_new_tasks = False
            drain_timed_out = False

            if graceful:
                effective_timeout = timeout
                if effective_timeout is None:
                    effective_timeout = (
                        max_task_duration if max_task_duration is not None else 30.0
                    )
                try:
                    await asyncio.wait_for(queue.join(), timeout=effective_timeout)
                except asyncio.TimeoutError:
                    drain_timed_out = True
                    logger.warning(
                        f"{queue_name}: Graceful drain timed out after "
                        f"{effective_timeout}s; cancelling pending work"
                    )

            if not graceful or drain_timed_out:
                # Cancel all active futures
                for future in list(active_futures):
                    if not future.done():
                        future.cancel()

                # Cancel all pending tasks
                async with task_states_lock:
                    for task_id, task_state in list(task_states.items()):
                        if not task_state.future.done():
                            task_state.future.cancel()
                    task_states.clear()

                # Discard whatever is still queued, keeping the queue's
                # unfinished-task accounting balanced.
                while True:
                    try:
                        queue.get_nowait()
                        queue.task_done()
                    except asyncio.QueueEmpty:
                        break

            shutdown_event.set()

            # Cancel worker tasks
            for task in list(tasks):
                if not task.done():
                    task.cancel()

            # Wait for all tasks to complete
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)

            # Cancel health check task
            if worker_health_check_task and not worker_health_check_task.done():
                worker_health_check_task.cancel()
                try:
                    await worker_health_check_task
                except asyncio.CancelledError:
                    pass

            worker_health_check_task = None
            initialized = False
            logger.info(f"{queue_name}: Priority queue workers shutdown complete")

        @wraps(func)
        async def wait_func(
            *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
        ):
            """
            Execute function with enhanced priority-based concurrency control and timeout handling

            Args:
                *args: Positional arguments passed to the function
                _priority: Call priority (lower values have higher priority)
                _timeout: Maximum time to wait for completion (in seconds; None means the wait is bounded only by the queue's max_execution_timeout / max_task_duration, if set)
                _queue_timeout: Maximum time to wait for entering the queue (in seconds)
                **kwargs: Keyword arguments passed to the function

            Returns:
                The result of the function call

            Raises:
                TimeoutError: If the function call times out at any level
                QueueFullError: If the queue is full and waiting times out
                RuntimeError: If the queue is shutting down and no longer accepts tasks
                Any exception raised by the decorated function
            """
            nonlocal submitted_total, completed_total, cancelled_total, failed_total
            nonlocal rejected_total
            if not accepting_new_tasks:
                rejected_total += 1
                raise RuntimeError(f"{queue_name}: Queue is shutting down")
            await ensure_workers()

            # Generate unique task ID
            task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}"
            future = asyncio.Future()

            # Create task state
            task_state = TaskState(
                future=future, start_time=asyncio.get_event_loop().time()
            )

            try:
                # Register task state
                async with task_states_lock:
                    task_states[task_id] = task_state

                active_futures.add(future)

                # Get counter for FIFO ordering
                nonlocal counter
                async with initialization_lock:
                    current_count = counter
                    counter += 1

                # Queue the task with timeout handling
                try:
                    # Re-check: shutdown may have started while we awaited above.
                    if not accepting_new_tasks:
                        rejected_total += 1
                        raise RuntimeError(f"{queue_name}: Queue is shutting down")
                    if _queue_timeout is not None:
                        await asyncio.wait_for(
                            queue.put(
                                (_priority, current_count, task_id, args, kwargs)
                            ),
                            timeout=_queue_timeout,
                        )
                    else:
                        await queue.put(
                            (_priority, current_count, task_id, args, kwargs)
                        )
                    submitted_total += 1
                except asyncio.TimeoutError:
                    raise QueueFullError(
                        f"{queue_name}: Queue full, timeout after {_queue_timeout} seconds"
                    )
                except Exception as e:
                    # Clean up on queue error
                    # NOTE(review): the future is not awaited after this point, so
                    # the stored exception may be reported as never retrieved.
                    if not future.done():
                        future.set_exception(e)
                    raise

                # Wait for result with timeout handling
                try:
                    if _timeout is not None:
                        result = await asyncio.wait_for(future, _timeout)
                    else:
                        result = await future
                    completed_total += 1
                    return result
                except asyncio.TimeoutError:
                    # This is user-level timeout (asyncio.wait_for caused)
                    # Mark cancellation request
                    async with task_states_lock:
                        if task_id in task_states:
                            task_states[task_id].cancellation_requested = True

                    # Cancel future
                    if not future.done():
                        future.cancel()

                    # Wait for worker cleanup with timeout
                    cleanup_start = asyncio.get_event_loop().time()
                    while (
                        task_id in task_states
                        and asyncio.get_event_loop().time() - cleanup_start
                        < cleanup_timeout
                    ):
                        await asyncio.sleep(0.1)

                    cancelled_total += 1
                    raise TimeoutError(
                        f"{queue_name}: User timeout after {_timeout} seconds"
                    )
                except WorkerTimeoutError as e:
                    # This is Worker-level timeout, directly propagate exception information
                    failed_total += 1
                    raise TimeoutError(f"{queue_name}: {str(e)}")
                except HealthCheckTimeoutError as e:
                    # This is Health Check-level timeout, directly propagate exception information
                    failed_total += 1
                    raise TimeoutError(f"{queue_name}: {str(e)}")
                except asyncio.CancelledError:
                    cancelled_total += 1
                    raise
                except Exception:
                    failed_total += 1
                    raise

            finally:
                # Ensure cleanup
                active_futures.discard(future)
                async with task_states_lock:
                    task_states.pop(task_id, None)

        # Add shutdown method to decorated function
        wait_func.shutdown = shutdown
        wait_func.get_queue_stats = get_queue_stats

        return wait_func

    return final_decro
  1155. def wrap_embedding_func_with_attrs(**kwargs):
  1156. """Decorator to add embedding dimension and token limit attributes to embedding functions.
  1157. This decorator wraps an async embedding function and returns an EmbeddingFunc instance
  1158. that automatically handles dimension parameter injection and attribute management.
  1159. WARNING: DO NOT apply this decorator to wrapper functions that call other
  1160. decorated embedding functions. This will cause double decoration and parameter
  1161. injection conflicts.
  1162. Correct usage patterns:
  1163. 1. Direct decoration:
  1164. ```python
  1165. @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
  1166. async def my_embed(texts, embedding_dim=None):
  1167. # Direct implementation
  1168. return embeddings
  1169. ```
  1170. 2. Double decoration:
  1171. ```python
  1172. @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
  1173. @retry(...)
  1174. async def my_embed(texts, ...):
  1175. # Base implementation
  1176. pass
  1177. @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model")
  1178. # Note: No @retry here!
  1179. async def my_new_embed(texts, ...):
  1180. # CRITICAL: Call .func to access unwrapped function
  1181. return await my_embed.func(texts, ...) # ✅ Correct
  1182. # return await my_embed(texts, ...) # ❌ Wrong - double decoration!
  1183. ```
  1184. 3. Context-aware decoration:
  1185. ```python
  1186. @wrap_embedding_func_with_attrs(
  1187. embedding_dim=1536,
  1188. model_name="my_embedding_model",
  1189. supports_asymmetric=True
  1190. )
  1191. async def my_embed(texts, context="document"):
  1192. # Apply different prefixes based on context
  1193. if context == "query":
  1194. texts = ["search_query: " + t for t in texts]
  1195. elif context == "document":
  1196. texts = ["search_document: " + t for t in texts]
  1197. return embeddings
  1198. ```
  1199. The decorated function becomes an EmbeddingFunc instance with:
  1200. - embedding_dim: The embedding dimension
  1201. - max_token_size: Maximum token limit (optional)
  1202. - model_name: Model name (optional)
  1203. - supports_asymmetric: Whether context parameter is supported (optional)
  1204. - func: The original unwrapped function (access via .func)
  1205. - __call__: Wrapper that injects embedding_dim parameter and context
  1206. Args:
  1207. embedding_dim: The dimension of embedding vectors
  1208. max_token_size: Maximum number of tokens (optional)
  1209. send_dimensions: Whether to pass embedding_dim as a keyword argument (for models with configurable embedding dimensions).
  1210. supports_asymmetric: Whether the function supports context parameter (optional).
  1211. If omitted, this is auto-detected from the wrapped function's signature
  1212. (set to True iff the function accepts a ``context`` parameter).
  1213. Returns:
  1214. A decorator that wraps the function as an EmbeddingFunc instance
  1215. """
  1216. def final_decro(func) -> EmbeddingFunc:
  1217. embedding_kwargs = dict(kwargs)
  1218. # Auto-detect supports_asymmetric from the wrapped function's signature
  1219. # if the caller did not declare it explicitly. Without this, any user or
  1220. # third-party embed function that accepts a `context` parameter but
  1221. # forgets to set ``supports_asymmetric=True`` would have its `context`
  1222. # silently dropped by ``EmbeddingFunc.__call__``, defeating the
  1223. # task-aware embedding feature.
  1224. if "supports_asymmetric" not in embedding_kwargs:
  1225. try:
  1226. sig = inspect.signature(func)
  1227. embedding_kwargs["supports_asymmetric"] = "context" in sig.parameters
  1228. except (TypeError, ValueError):
  1229. # inspect.signature can fail for builtins; fall back to False.
  1230. embedding_kwargs["supports_asymmetric"] = False
  1231. new_func = EmbeddingFunc(**embedding_kwargs, func=func)
  1232. return new_func
  1233. return final_decro
  1234. def load_json(file_name):
  1235. if not os.path.exists(file_name):
  1236. return None
  1237. with open(file_name, encoding="utf-8-sig") as f:
  1238. return json.load(f)
  1239. def _sanitize_string_for_json(text: str) -> str:
  1240. """Remove characters that cannot be encoded in UTF-8 for JSON serialization.
  1241. Uses regex for optimal performance with zero-copy optimization for clean strings.
  1242. Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
  1243. Args:
  1244. text: String to sanitize
  1245. Returns:
  1246. Original string if clean (zero-copy), sanitized string if dirty
  1247. """
  1248. if not text:
  1249. return text
  1250. # Fast path: Check if sanitization is needed using C-level regex search
  1251. if not _SURROGATE_PATTERN.search(text):
  1252. return text # Zero-copy for clean strings - most common case
  1253. # Slow path: Remove problematic characters using C-level regex substitution
  1254. return _SURROGATE_PATTERN.sub("", text)
  1255. class SanitizingJSONEncoder(json.JSONEncoder):
  1256. """
  1257. Custom JSON encoder that sanitizes data during serialization.
  1258. This encoder cleans strings during the encoding process without creating
  1259. a full copy of the data structure, making it memory-efficient for large datasets.
  1260. """
  1261. def encode(self, o):
  1262. """Override encode method to handle simple string cases"""
  1263. if isinstance(o, str):
  1264. return json.encoder.encode_basestring(_sanitize_string_for_json(o))
  1265. return super().encode(o)
  1266. def iterencode(self, o, _one_shot=False):
  1267. """
  1268. Override iterencode to sanitize strings during serialization.
  1269. This is the core method that handles complex nested structures.
  1270. """
  1271. # Preprocess: sanitize all strings in the object
  1272. sanitized = self._sanitize_for_encoding(o)
  1273. # Call parent's iterencode with sanitized data
  1274. for chunk in super().iterencode(sanitized, _one_shot):
  1275. yield chunk
  1276. def _sanitize_for_encoding(self, obj):
  1277. """
  1278. Recursively sanitize strings in an object.
  1279. Creates new objects only when necessary to avoid deep copies.
  1280. Args:
  1281. obj: Object to sanitize
  1282. Returns:
  1283. Sanitized object with cleaned strings
  1284. """
  1285. if isinstance(obj, str):
  1286. return _sanitize_string_for_json(obj)
  1287. elif isinstance(obj, dict):
  1288. # Create new dict with sanitized keys and values
  1289. new_dict = {}
  1290. for k, v in obj.items():
  1291. clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
  1292. clean_v = self._sanitize_for_encoding(v)
  1293. new_dict[clean_k] = clean_v
  1294. return new_dict
  1295. elif isinstance(obj, (list, tuple)):
  1296. # Sanitize list/tuple elements
  1297. cleaned = [self._sanitize_for_encoding(item) for item in obj]
  1298. return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
  1299. else:
  1300. # Numbers, booleans, None, etc. remain unchanged
  1301. return obj
  1302. def write_json(json_obj, file_name):
  1303. """
  1304. Write JSON data to file with optimized sanitization strategy.
  1305. This function uses a two-stage approach:
  1306. 1. Fast path: Try direct serialization (works for clean data ~99% of time)
  1307. 2. Slow path: Use custom encoder that sanitizes during serialization
  1308. The custom encoder approach avoids creating a deep copy of the data,
  1309. making it memory-efficient. When sanitization occurs, the caller should
  1310. reload the cleaned data from the file to update shared memory.
  1311. Writes are atomic: both the fast path and the sanitizing fallback land
  1312. in the same per-writer tmp sibling, and only the final ``os.replace``
  1313. publishes the file. A crash mid-write leaves the prior snapshot intact.
  1314. Args:
  1315. json_obj: Object to serialize (may be a shallow copy from shared memory)
  1316. file_name: Output file path
  1317. Returns:
  1318. bool: True if sanitization was applied (caller should reload data),
  1319. False if direct write succeeded (no reload needed)
  1320. """
  1321. from lightrag.file_atomic import atomic_write
  1322. sanitized = False
  1323. def _do_write(tmp_path: str) -> None:
  1324. nonlocal sanitized
  1325. try:
  1326. # Strategy 1: Fast path - try direct serialization.
  1327. with open(tmp_path, "w", encoding="utf-8") as f:
  1328. json.dump(json_obj, f, indent=2, ensure_ascii=False)
  1329. except (UnicodeEncodeError, UnicodeDecodeError) as e:
  1330. logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
  1331. # Strategy 2: Use sanitizing encoder (zero-copy). Reusing the
  1332. # same tmp path keeps the operation single-rename even on the
  1333. # slow path.
  1334. with open(tmp_path, "w", encoding="utf-8") as f:
  1335. json.dump(
  1336. json_obj,
  1337. f,
  1338. indent=2,
  1339. ensure_ascii=False,
  1340. cls=SanitizingJSONEncoder,
  1341. )
  1342. sanitized = True
  1343. atomic_write(file_name, _do_write)
  1344. if sanitized:
  1345. logger.info(f"JSON sanitization applied during write: {file_name}")
  1346. return sanitized
  1347. class TokenizerInterface(Protocol):
  1348. """
  1349. Defines the interface for a tokenizer, requiring encode and decode methods.
  1350. """
  1351. def encode(self, content: str) -> List[int]:
  1352. """Encodes a string into a list of tokens."""
  1353. ...
  1354. def decode(self, tokens: List[int]) -> str:
  1355. """Decodes a list of tokens into a string."""
  1356. ...
  1357. class Tokenizer:
  1358. """
  1359. A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
  1360. """
  1361. def __init__(self, model_name: str, tokenizer: TokenizerInterface):
  1362. """
  1363. Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
  1364. Args:
  1365. model_name: The associated model name for the tokenizer.
  1366. tokenizer: An instance of a class implementing the TokenizerInterface.
  1367. """
  1368. self.model_name: str = model_name
  1369. self.tokenizer: TokenizerInterface = tokenizer
  1370. def encode(self, content: str) -> List[int]:
  1371. """
  1372. Encodes a string into a list of tokens using the underlying tokenizer.
  1373. Args:
  1374. content: The string to encode.
  1375. Returns:
  1376. A list of integer tokens.
  1377. """
  1378. try:
  1379. return self.tokenizer.encode(content)
  1380. except ValueError as e:
  1381. # tiktoken (and some other tokenizers) raise ValueError when the
  1382. # content contains literal special-token strings such as
  1383. # "<|endoftext|>", because by default disallowed_special is the
  1384. # full set of special tokens. This crashes document indexing on
  1385. # any user content that happens to contain those strings — common
  1386. # in documentation, notes, or model output captured in source
  1387. # corpora. Retry with disallowed_special=() so the tokens are
  1388. # encoded as ordinary text. Tokenizers that don't accept the
  1389. # kwarg fall through and re-raise the original error.
  1390. if "special token" not in str(e):
  1391. raise
  1392. try:
  1393. return self.tokenizer.encode(content, disallowed_special=())
  1394. except TypeError:
  1395. raise e
  1396. def decode(self, tokens: List[int]) -> str:
  1397. """
  1398. Decodes a list of tokens into a string using the underlying tokenizer.
  1399. Args:
  1400. tokens: A list of integer tokens to decode.
  1401. Returns:
  1402. The decoded string.
  1403. """
  1404. return self.tokenizer.decode(tokens)
  1405. class TiktokenTokenizer(Tokenizer):
  1406. """
  1407. A Tokenizer implementation using the tiktoken library.
  1408. """
  1409. def __init__(self, model_name: str = "gpt-4o-mini"):
  1410. """
  1411. Initializes the TiktokenTokenizer with a specified model name.
  1412. Args:
  1413. model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
  1414. Raises:
  1415. ImportError: If tiktoken is not installed.
  1416. ValueError: If the model_name is invalid.
  1417. """
  1418. try:
  1419. import tiktoken
  1420. except ImportError:
  1421. raise ImportError(
  1422. "tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
  1423. )
  1424. try:
  1425. tokenizer = tiktoken.encoding_for_model(model_name)
  1426. super().__init__(model_name=model_name, tokenizer=tokenizer)
  1427. except KeyError:
  1428. raise ValueError(f"Invalid model_name: {model_name}.")
  1429. def pack_user_ass_to_openai_messages(*args: str):
  1430. roles = ["user", "assistant"]
  1431. return [
  1432. {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
  1433. ]
  1434. def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
  1435. """Split a string by multiple markers"""
  1436. if not markers:
  1437. return [content]
  1438. content = content if content is not None else ""
  1439. results = re.split("|".join(re.escape(marker) for marker in markers), content)
  1440. return [r.strip() for r in results if r.strip()]
  1441. def is_float_regex(value: str) -> bool:
  1442. return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
  1443. def truncate_list_by_token_size(
  1444. list_data: list[Any],
  1445. key: Callable[[Any], str],
  1446. max_token_size: int,
  1447. tokenizer: Tokenizer,
  1448. ) -> list[int]:
  1449. """Truncate a list of data by token size"""
  1450. if max_token_size <= 0:
  1451. return []
  1452. tokens = 0
  1453. for i, data in enumerate(list_data):
  1454. tokens += len(tokenizer.encode(key(data)))
  1455. if tokens > max_token_size:
  1456. return list_data[:i]
  1457. return list_data
  1458. def normalize_string_list(raw_values: Any, context: str = "") -> list[str]:
  1459. """Return a list of non-empty strings from raw_values.
  1460. Non-string elements are dropped and logged as warnings. If raw_values is
  1461. not a list, an empty list is returned.
  1462. """
  1463. if not isinstance(raw_values, list):
  1464. return []
  1465. result = []
  1466. for i, value in enumerate(raw_values):
  1467. if isinstance(value, str) and value:
  1468. result.append(value)
  1469. else:
  1470. logger.warning(
  1471. "Non-string element dropped from list%s at index %d: %r",
  1472. f" ({context})" if context else "",
  1473. i,
  1474. value,
  1475. )
  1476. return result
  1477. def split_text_units_for_hard_fallback(text: str) -> list[str]:
  1478. """Split text into sentence/paragraph-like units for fallback chunking."""
  1479. if not text:
  1480. return []
  1481. units: list[str] = []
  1482. for para in text.split("\n\n"):
  1483. p = para.strip()
  1484. if not p:
  1485. continue
  1486. for sentence in re.split(r"(?<=[。!?;.!?])", p):
  1487. s = sentence.strip()
  1488. if s:
  1489. units.append(s)
  1490. return units if units else [text]
  1491. def split_text_by_token_limit(
  1492. text: str, tokenizer: Tokenizer, max_tokens: int
  1493. ) -> list[str]:
  1494. """Split text by token limit with sentence-first, token-window fallback."""
  1495. if not text:
  1496. return []
  1497. try:
  1498. total_tokens = len(tokenizer.encode(text))
  1499. except Exception:
  1500. total_tokens = 0
  1501. if total_tokens > 0 and total_tokens <= max_tokens:
  1502. return [text]
  1503. units = split_text_units_for_hard_fallback(text)
  1504. out: list[str] = []
  1505. cur_parts: list[str] = []
  1506. cur_tokens = 0
  1507. for unit in units:
  1508. try:
  1509. unit_tokens = len(tokenizer.encode(unit))
  1510. except Exception:
  1511. unit_tokens = 0
  1512. # Sentence itself is oversize: token-window split directly.
  1513. if unit_tokens > max_tokens:
  1514. if cur_parts:
  1515. out.append("\n\n".join(cur_parts))
  1516. cur_parts = []
  1517. cur_tokens = 0
  1518. token_ids = tokenizer.encode(unit)
  1519. for start in range(0, len(token_ids), max_tokens):
  1520. piece = tokenizer.decode(token_ids[start : start + max_tokens]).strip()
  1521. if piece:
  1522. out.append(piece)
  1523. continue
  1524. if cur_parts and cur_tokens + unit_tokens > max_tokens:
  1525. out.append("\n\n".join(cur_parts))
  1526. cur_parts = [unit]
  1527. cur_tokens = unit_tokens
  1528. else:
  1529. cur_parts.append(unit)
  1530. cur_tokens += unit_tokens
  1531. if cur_parts:
  1532. out.append("\n\n".join(cur_parts))
  1533. return [x for x in out if x.strip()]
  1534. def enforce_chunk_token_limit_before_embedding(
  1535. chunking_result: list[dict[str, Any]] | tuple[dict[str, Any], ...],
  1536. tokenizer: Tokenizer,
  1537. max_tokens: int,
  1538. ) -> list[dict[str, Any]]:
  1539. """Hard fallback split before embedding while preserving heading hierarchy."""
  1540. if max_tokens <= 0:
  1541. return list(chunking_result)
  1542. normalized: list[dict[str, Any]] = []
  1543. for dp in chunking_result:
  1544. if not isinstance(dp, dict):
  1545. continue
  1546. content = dp.get("content", "")
  1547. if not isinstance(content, str) or not content.strip():
  1548. continue
  1549. try:
  1550. token_count = len(tokenizer.encode(content))
  1551. except Exception:
  1552. token_count = (
  1553. dp.get("tokens", 0) if isinstance(dp.get("tokens"), int) else 0
  1554. )
  1555. if token_count <= max_tokens:
  1556. ndp = dict(dp)
  1557. ndp["tokens"] = token_count if token_count > 0 else ndp.get("tokens", 0)
  1558. normalized.append(ndp)
  1559. continue
  1560. pieces = split_text_by_token_limit(content, tokenizer, max_tokens)
  1561. if not pieces:
  1562. ndp = dict(dp)
  1563. ndp["tokens"] = token_count
  1564. normalized.append(ndp)
  1565. continue
  1566. base_chunk_id = dp.get("chunk_id")
  1567. total_parts = len(pieces)
  1568. for i, piece in enumerate(pieces, 1):
  1569. new_dp = dict(dp)
  1570. new_dp["content"] = piece
  1571. try:
  1572. new_dp["tokens"] = len(tokenizer.encode(piece))
  1573. except Exception:
  1574. new_dp["tokens"] = max(1, int(len(piece) * 0.5))
  1575. # Shallow-copy preserves the nested heading dict and sidecar
  1576. # block from the source chunk; only the payload (content/tokens
  1577. # /chunk_id) is rewritten per split slice.
  1578. if isinstance(base_chunk_id, str) and base_chunk_id.strip():
  1579. new_dp["chunk_id"] = f"{base_chunk_id}-s{i:02d}"
  1580. new_dp["split_type"] = "hard_fallback"
  1581. new_dp["split_part"] = i
  1582. new_dp["split_total"] = total_parts
  1583. normalized.append(new_dp)
  1584. # Rebuild order index to keep continuity after splitting.
  1585. for idx, item in enumerate(normalized):
  1586. item["chunk_order_index"] = idx
  1587. return normalized
  1588. def cosine_similarity(v1, v2):
  1589. """Calculate cosine similarity between two vectors"""
  1590. dot_product = np.dot(v1, v2)
  1591. norm1 = np.linalg.norm(v1)
  1592. norm2 = np.linalg.norm(v2)
  1593. return dot_product / (norm1 * norm2)
  1594. async def handle_cache(
  1595. hashing_kv,
  1596. args_hash,
  1597. prompt,
  1598. mode="default",
  1599. cache_type="unknown",
  1600. ) -> tuple[str, int] | None:
  1601. """Generic cache handling function with flattened cache keys
  1602. Returns:
  1603. tuple[str, int] | None: (content, create_time) if cache hit, None if cache miss
  1604. """
  1605. if hashing_kv is None:
  1606. return None
  1607. if mode != "default": # handle cache for all type of query
  1608. if not hashing_kv.global_config.get("enable_llm_cache"):
  1609. return None
  1610. else: # handle cache for entity extraction
  1611. if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
  1612. return None
  1613. # Use flattened cache key format: {mode}:{cache_type}:{hash}
  1614. flattened_key = generate_cache_key(mode, cache_type, args_hash)
  1615. cache_entry = await hashing_kv.get_by_id(flattened_key)
  1616. if cache_entry:
  1617. logger.debug(f"Flattened cache hit(key:{flattened_key})")
  1618. content = cache_entry["return"]
  1619. timestamp = cache_entry.get("create_time", 0)
  1620. return content, timestamp
  1621. logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
  1622. return None
@dataclass
class CacheData:
    """Everything ``save_to_cache`` needs to persist one LLM response."""

    # Hash of the call arguments; last component of the flattened cache key.
    args_hash: str
    # Response to cache; stored under the "return" field of the cache entry.
    content: str
    # Prompt that produced the response; stored as "original_prompt".
    prompt: str
    # First component of the flattened cache key.
    mode: str = "default"
    # Second component of the flattened cache key; also stored in the entry.
    cache_type: str = "query"
    # Optional chunk identifier stored alongside the response.
    chunk_id: str | None = None
    # Optional query parameters stored alongside the response.
    queryparam: dict | None = None
  1632. async def save_to_cache(hashing_kv, cache_data: CacheData):
  1633. """Save data to cache using flattened key structure.
  1634. Args:
  1635. hashing_kv: The key-value storage for caching
  1636. cache_data: The cache data to save
  1637. """
  1638. # Skip if storage is None or content is a streaming response
  1639. if hashing_kv is None or not cache_data.content:
  1640. return
  1641. # If content is a streaming response, don't cache it
  1642. if hasattr(cache_data.content, "__aiter__"):
  1643. logger.debug("Streaming response detected, skipping cache")
  1644. return
  1645. # Use flattened cache key format: {mode}:{cache_type}:{hash}
  1646. flattened_key = generate_cache_key(
  1647. cache_data.mode, cache_data.cache_type, cache_data.args_hash
  1648. )
  1649. # Check if we already have identical content cached
  1650. existing_cache = await hashing_kv.get_by_id(flattened_key)
  1651. if existing_cache:
  1652. existing_content = existing_cache.get("return")
  1653. if existing_content == cache_data.content:
  1654. logger.warning(
  1655. f"Cache duplication detected for {flattened_key}, skipping update"
  1656. )
  1657. return
  1658. # Create cache entry with flattened structure
  1659. cache_entry = {
  1660. "return": cache_data.content,
  1661. "cache_type": cache_data.cache_type,
  1662. "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
  1663. "original_prompt": cache_data.prompt,
  1664. "queryparam": cache_data.queryparam
  1665. if cache_data.queryparam is not None
  1666. else None,
  1667. }
  1668. logger.info(f" == LLM cache == saving: {flattened_key}")
  1669. # Save using flattened key
  1670. await hashing_kv.upsert({flattened_key: cache_entry})
  1671. def safe_unicode_decode(content):
  1672. # Regular expression to find all Unicode escape sequences of the form \uXXXX
  1673. unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")
  1674. # Function to replace the Unicode escape with the actual character
  1675. def replace_unicode_escape(match):
  1676. # Convert the matched hexadecimal value into the actual Unicode character
  1677. return chr(int(match.group(1), 16))
  1678. # Perform the substitution
  1679. decoded_content = unicode_escape_pattern.sub(
  1680. replace_unicode_escape, content.decode("utf-8")
  1681. )
  1682. return decoded_content
  1683. def exists_func(obj, func_name: str) -> bool:
  1684. """Check if a function exists in an object or not.
  1685. :param obj:
  1686. :param func_name:
  1687. :return: True / False
  1688. """
  1689. if callable(getattr(obj, func_name, None)):
  1690. return True
  1691. else:
  1692. return False
  1693. async def _cooperative_yield(iteration: int, every: int = 64) -> None:
  1694. """Periodically yield control to the event loop during CPU-heavy async loops.
  1695. Call inside long synchronous-style loops to prevent event loop starvation
  1696. in single-worker deployments. Yields every `every` iterations.
  1697. """
  1698. if iteration > 0 and iteration % every == 0:
  1699. await asyncio.sleep(0)
  1700. def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
  1701. """
  1702. Ensure that there is always an event loop available.
  1703. This function tries to get the current event loop. If the current event loop is closed or does not exist,
  1704. it creates a new event loop and sets it as the current event loop.
  1705. Returns:
  1706. asyncio.AbstractEventLoop: The current or newly created event loop.
  1707. """
  1708. try:
  1709. # Try to get the current event loop
  1710. current_loop = asyncio.get_event_loop()
  1711. if current_loop.is_closed():
  1712. raise RuntimeError("Event loop is closed.")
  1713. return current_loop
  1714. except RuntimeError:
  1715. # If no event loop exists or it is closed, create a new one
  1716. logger.info("Creating a new event loop in main thread.")
  1717. new_loop = asyncio.new_event_loop()
  1718. asyncio.set_event_loop(new_loop)
  1719. return new_loop
  1720. async def aexport_data(
  1721. chunk_entity_relation_graph,
  1722. entities_vdb,
  1723. relationships_vdb,
  1724. output_path: str,
  1725. file_format: str = "csv",
  1726. include_vector_data: bool = False,
  1727. ) -> None:
  1728. """
  1729. Asynchronously exports all entities, relations, and relationships to various formats.
  1730. Args:
  1731. chunk_entity_relation_graph: Graph storage instance for entities and relations
  1732. entities_vdb: Vector database storage for entities
  1733. relationships_vdb: Vector database storage for relationships
  1734. output_path: The path to the output file (including extension).
  1735. file_format: Output format - "csv", "excel", "md", "txt".
  1736. - csv: Comma-separated values file
  1737. - excel: Microsoft Excel file with multiple sheets
  1738. - md: Markdown tables
  1739. - txt: Plain text formatted output
  1740. include_vector_data: Whether to include data from the vector database.
  1741. """
  1742. # Collect data
  1743. entities_data = []
  1744. relations_data = []
  1745. relationships_data = []
  1746. # --- Entities ---
  1747. all_entities = await chunk_entity_relation_graph.get_all_labels()
  1748. for entity_name in all_entities:
  1749. # Get entity information from graph
  1750. node_data = await chunk_entity_relation_graph.get_node(entity_name)
  1751. source_id = node_data.get("source_id") if node_data else None
  1752. entity_info = {
  1753. "graph_data": node_data,
  1754. "source_id": source_id,
  1755. }
  1756. # Optional: Get vector database information
  1757. if include_vector_data:
  1758. entity_id = compute_mdhash_id(entity_name, prefix="ent-")
  1759. vector_data = await entities_vdb.get_by_id(entity_id)
  1760. entity_info["vector_data"] = vector_data
  1761. entity_row = {
  1762. "entity_name": entity_name,
  1763. "source_id": source_id,
  1764. "graph_data": str(
  1765. entity_info["graph_data"]
  1766. ), # Convert to string to ensure compatibility
  1767. }
  1768. if include_vector_data and "vector_data" in entity_info:
  1769. entity_row["vector_data"] = str(entity_info["vector_data"])
  1770. entities_data.append(entity_row)
  1771. # --- Relations ---
  1772. for src_entity in all_entities:
  1773. for tgt_entity in all_entities:
  1774. if src_entity == tgt_entity:
  1775. continue
  1776. edge_exists = await chunk_entity_relation_graph.has_edge(
  1777. src_entity, tgt_entity
  1778. )
  1779. if edge_exists:
  1780. # Get edge information from graph
  1781. edge_data = await chunk_entity_relation_graph.get_edge(
  1782. src_entity, tgt_entity
  1783. )
  1784. source_id = edge_data.get("source_id") if edge_data else None
  1785. relation_info = {
  1786. "graph_data": edge_data,
  1787. "source_id": source_id,
  1788. }
  1789. # Optional: Get vector database information
  1790. if include_vector_data:
  1791. vector_data = None
  1792. for rel_id in make_relation_vdb_ids(src_entity, tgt_entity):
  1793. vector_data = await relationships_vdb.get_by_id(rel_id)
  1794. if vector_data is not None:
  1795. break
  1796. relation_info["vector_data"] = vector_data
  1797. relation_row = {
  1798. "src_entity": src_entity,
  1799. "tgt_entity": tgt_entity,
  1800. "source_id": relation_info["source_id"],
  1801. "graph_data": str(relation_info["graph_data"]), # Convert to string
  1802. }
  1803. if include_vector_data and "vector_data" in relation_info:
  1804. relation_row["vector_data"] = str(relation_info["vector_data"])
  1805. relations_data.append(relation_row)
  1806. # --- Relationships (from VectorDB) ---
  1807. all_relationships = await relationships_vdb.client_storage
  1808. for rel in all_relationships["data"]:
  1809. relationships_data.append(
  1810. {
  1811. "relationship_id": rel["__id__"],
  1812. "data": str(rel), # Convert to string for compatibility
  1813. }
  1814. )
  1815. # Export based on format
  1816. if file_format == "csv":
  1817. # CSV export
  1818. with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
  1819. # Entities
  1820. if entities_data:
  1821. csvfile.write("# ENTITIES\n")
  1822. writer = csv.DictWriter(csvfile, fieldnames=entities_data[0].keys())
  1823. writer.writeheader()
  1824. writer.writerows(entities_data)
  1825. csvfile.write("\n\n")
  1826. # Relations
  1827. if relations_data:
  1828. csvfile.write("# RELATIONS\n")
  1829. writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
  1830. writer.writeheader()
  1831. writer.writerows(relations_data)
  1832. csvfile.write("\n\n")
  1833. # Relationships
  1834. if relationships_data:
  1835. csvfile.write("# RELATIONSHIPS\n")
  1836. writer = csv.DictWriter(
  1837. csvfile, fieldnames=relationships_data[0].keys()
  1838. )
  1839. writer.writeheader()
  1840. writer.writerows(relationships_data)
  1841. elif file_format == "excel":
  1842. # Excel export
  1843. import pandas as pd
  1844. entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
  1845. relations_df = (
  1846. pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
  1847. )
  1848. relationships_df = (
  1849. pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
  1850. )
  1851. with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
  1852. if not entities_df.empty:
  1853. entities_df.to_excel(writer, sheet_name="Entities", index=False)
  1854. if not relations_df.empty:
  1855. relations_df.to_excel(writer, sheet_name="Relations", index=False)
  1856. if not relationships_df.empty:
  1857. relationships_df.to_excel(
  1858. writer, sheet_name="Relationships", index=False
  1859. )
  1860. elif file_format == "md":
  1861. # Markdown export
  1862. with open(output_path, "w", encoding="utf-8") as mdfile:
  1863. mdfile.write("# LightRAG Data Export\n\n")
  1864. # Entities
  1865. mdfile.write("## Entities\n\n")
  1866. if entities_data:
  1867. # Write header
  1868. mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
  1869. mdfile.write(
  1870. "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
  1871. )
  1872. # Write rows
  1873. for entity in entities_data:
  1874. mdfile.write(
  1875. "| " + " | ".join(str(v) for v in entity.values()) + " |\n"
  1876. )
  1877. mdfile.write("\n\n")
  1878. else:
  1879. mdfile.write("*No entity data available*\n\n")
  1880. # Relations
  1881. mdfile.write("## Relations\n\n")
  1882. if relations_data:
  1883. # Write header
  1884. mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
  1885. mdfile.write(
  1886. "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
  1887. )
  1888. # Write rows
  1889. for relation in relations_data:
  1890. mdfile.write(
  1891. "| " + " | ".join(str(v) for v in relation.values()) + " |\n"
  1892. )
  1893. mdfile.write("\n\n")
  1894. else:
  1895. mdfile.write("*No relation data available*\n\n")
  1896. # Relationships
  1897. mdfile.write("## Relationships\n\n")
  1898. if relationships_data:
  1899. # Write header
  1900. mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
  1901. mdfile.write(
  1902. "| "
  1903. + " | ".join(["---"] * len(relationships_data[0].keys()))
  1904. + " |\n"
  1905. )
  1906. # Write rows
  1907. for relationship in relationships_data:
  1908. mdfile.write(
  1909. "| "
  1910. + " | ".join(str(v) for v in relationship.values())
  1911. + " |\n"
  1912. )
  1913. else:
  1914. mdfile.write("*No relationship data available*\n\n")
  1915. elif file_format == "txt":
  1916. # Plain text export
  1917. with open(output_path, "w", encoding="utf-8") as txtfile:
  1918. txtfile.write("LIGHTRAG DATA EXPORT\n")
  1919. txtfile.write("=" * 80 + "\n\n")
  1920. # Entities
  1921. txtfile.write("ENTITIES\n")
  1922. txtfile.write("-" * 80 + "\n")
  1923. if entities_data:
  1924. # Create fixed width columns
  1925. col_widths = {
  1926. k: max(len(k), max(len(str(e[k])) for e in entities_data))
  1927. for k in entities_data[0]
  1928. }
  1929. header = " ".join(k.ljust(col_widths[k]) for k in entities_data[0])
  1930. txtfile.write(header + "\n")
  1931. txtfile.write("-" * len(header) + "\n")
  1932. # Write rows
  1933. for entity in entities_data:
  1934. row = " ".join(
  1935. str(v).ljust(col_widths[k]) for k, v in entity.items()
  1936. )
  1937. txtfile.write(row + "\n")
  1938. txtfile.write("\n\n")
  1939. else:
  1940. txtfile.write("No entity data available\n\n")
  1941. # Relations
  1942. txtfile.write("RELATIONS\n")
  1943. txtfile.write("-" * 80 + "\n")
  1944. if relations_data:
  1945. # Create fixed width columns
  1946. col_widths = {
  1947. k: max(len(k), max(len(str(r[k])) for r in relations_data))
  1948. for k in relations_data[0]
  1949. }
  1950. header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
  1951. txtfile.write(header + "\n")
  1952. txtfile.write("-" * len(header) + "\n")
  1953. # Write rows
  1954. for relation in relations_data:
  1955. row = " ".join(
  1956. str(v).ljust(col_widths[k]) for k, v in relation.items()
  1957. )
  1958. txtfile.write(row + "\n")
  1959. txtfile.write("\n\n")
  1960. else:
  1961. txtfile.write("No relation data available\n\n")
  1962. # Relationships
  1963. txtfile.write("RELATIONSHIPS\n")
  1964. txtfile.write("-" * 80 + "\n")
  1965. if relationships_data:
  1966. # Create fixed width columns
  1967. col_widths = {
  1968. k: max(len(k), max(len(str(r[k])) for r in relationships_data))
  1969. for k in relationships_data[0]
  1970. }
  1971. header = " ".join(
  1972. k.ljust(col_widths[k]) for k in relationships_data[0]
  1973. )
  1974. txtfile.write(header + "\n")
  1975. txtfile.write("-" * len(header) + "\n")
  1976. # Write rows
  1977. for relationship in relationships_data:
  1978. row = " ".join(
  1979. str(v).ljust(col_widths[k]) for k, v in relationship.items()
  1980. )
  1981. txtfile.write(row + "\n")
  1982. else:
  1983. txtfile.write("No relationship data available\n\n")
  1984. else:
  1985. raise ValueError(
  1986. f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt"
  1987. )
  1988. if file_format is not None:
  1989. print(f"Data exported to: {output_path} with format: {file_format}")
  1990. else:
  1991. print("Data displayed as table format")
  1992. def export_data(
  1993. chunk_entity_relation_graph,
  1994. entities_vdb,
  1995. relationships_vdb,
  1996. output_path: str,
  1997. file_format: str = "csv",
  1998. include_vector_data: bool = False,
  1999. ) -> None:
  2000. """
  2001. Synchronously exports all entities, relations, and relationships to various formats.
  2002. Args:
  2003. chunk_entity_relation_graph: Graph storage instance for entities and relations
  2004. entities_vdb: Vector database storage for entities
  2005. relationships_vdb: Vector database storage for relationships
  2006. output_path: The path to the output file (including extension).
  2007. file_format: Output format - "csv", "excel", "md", "txt".
  2008. - csv: Comma-separated values file
  2009. - excel: Microsoft Excel file with multiple sheets
  2010. - md: Markdown tables
  2011. - txt: Plain text formatted output
  2012. include_vector_data: Whether to include data from the vector database.
  2013. """
  2014. try:
  2015. loop = asyncio.get_event_loop()
  2016. except RuntimeError:
  2017. loop = asyncio.new_event_loop()
  2018. asyncio.set_event_loop(loop)
  2019. loop.run_until_complete(
  2020. aexport_data(
  2021. chunk_entity_relation_graph,
  2022. entities_vdb,
  2023. relationships_vdb,
  2024. output_path,
  2025. file_format,
  2026. include_vector_data,
  2027. )
  2028. )
  2029. def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
  2030. """Lazily import a class from an external module based on the package of the caller."""
  2031. # Get the caller's module and package
  2032. import inspect
  2033. caller_frame = inspect.currentframe().f_back
  2034. module = inspect.getmodule(caller_frame)
  2035. package = module.__package__ if module else None
  2036. def import_class(*args: Any, **kwargs: Any):
  2037. import importlib
  2038. module = importlib.import_module(module_name, package=package)
  2039. cls = getattr(module, class_name)
  2040. return cls(*args, **kwargs)
  2041. return import_class
  2042. async def update_chunk_cache_list(
  2043. chunk_id: str,
  2044. text_chunks_storage: "BaseKVStorage",
  2045. cache_keys: list[str],
  2046. cache_scenario: str = "batch_update",
  2047. ) -> None:
  2048. """Update chunk's llm_cache_list with the given cache keys
  2049. Args:
  2050. chunk_id: Chunk identifier
  2051. text_chunks_storage: Text chunks storage instance
  2052. cache_keys: List of cache keys to add to the list
  2053. cache_scenario: Description of the cache scenario for logging
  2054. """
  2055. if not cache_keys:
  2056. return
  2057. try:
  2058. chunk_data = await text_chunks_storage.get_by_id(chunk_id)
  2059. if chunk_data:
  2060. # Ensure llm_cache_list exists
  2061. if "llm_cache_list" not in chunk_data:
  2062. chunk_data["llm_cache_list"] = []
  2063. # Add cache keys to the list if not already present
  2064. existing_keys = set(chunk_data["llm_cache_list"])
  2065. new_keys = [key for key in cache_keys if key not in existing_keys]
  2066. if new_keys:
  2067. chunk_data["llm_cache_list"].extend(new_keys)
  2068. # Update the chunk in storage
  2069. await text_chunks_storage.upsert({chunk_id: chunk_data})
  2070. logger.debug(
  2071. f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
  2072. )
  2073. except Exception as e:
  2074. logger.warning(
  2075. f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
  2076. )
  2077. def remove_think_tags(text: str) -> str:
  2078. """Remove <think>...</think> tags and their content from the text.
  2079. Handles two cases:
  2080. 1. Complete <think>...</think> blocks anywhere in the text.
  2081. 2. Orphaned </think> at the very start (e.g., from streaming that begins
  2082. mid-think-block), removing everything before and including it.
  2083. """
  2084. # First, remove orphaned </think> prefix (content before first </think>
  2085. # when there is no preceding <think> tag)
  2086. text = re.sub(r"^((?!<think>).)*?</think>", "", text, flags=re.DOTALL)
  2087. # Then remove all complete <think>...</think> blocks
  2088. text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
  2089. return text.strip()
  2090. async def use_llm_func_with_cache(
  2091. user_prompt: str,
  2092. use_llm_func: callable,
  2093. llm_response_cache: "BaseKVStorage | None" = None,
  2094. system_prompt: str | None = None,
  2095. max_tokens: int = None,
  2096. history_messages: list[dict[str, str]] = None,
  2097. cache_type: str = "extract",
  2098. chunk_id: str | None = None,
  2099. cache_keys_collector: list = None,
  2100. response_format: Any | None = None,
  2101. entity_extraction: bool = False,
  2102. llm_cache_identity: Any | None = None,
  2103. ) -> tuple[str, int]:
  2104. """Call LLM function with cache support and text sanitization
  2105. If cache is available and enabled (determined by handle_cache based on mode),
  2106. retrieve result from cache; otherwise call LLM function and save result to cache.
  2107. This function applies text sanitization to prevent UTF-8 encoding errors for all LLM providers.
  2108. Args:
  2109. input_text: Input text to send to LLM
  2110. use_llm_func: LLM function with higher priority
  2111. llm_response_cache: Cache storage instance
  2112. max_tokens: Maximum tokens for generation
  2113. history_messages: History messages list
  2114. cache_type: Type of cache
  2115. chunk_id: Chunk identifier to store in cache
  2116. text_chunks_storage: Text chunks storage to update llm_cache_list
  2117. cache_keys_collector: Optional list to collect cache keys for batch processing
  2118. response_format: Structured output control forwarded to the LLM provider.
  2119. Providers translate this to their native structured-output surface
  2120. (OpenAI response_format, Ollama format, Gemini response_mime_type/schema).
  2121. ``{"type": "json_object"}`` requests JSON output; typed/schema payloads
  2122. trigger schema-constrained output where supported; ``None`` leaves
  2123. output unconstrained. Providers that do not support structured output
  2124. safely strip this argument.
  2125. entity_extraction: Deprecated. When True and ``response_format`` is not
  2126. provided, maps to ``{"type": "json_object"}``. Prefer passing
  2127. ``response_format`` directly.
  2128. llm_cache_identity: Non-secret model/provider identity used to partition
  2129. cache entries across role model, binding, or host changes.
  2130. Returns:
  2131. tuple[str, int]: (LLM response text, timestamp)
  2132. - For cache hits: (content, cache_create_time)
  2133. - For cache misses: (content, current_timestamp)
  2134. """
  2135. if entity_extraction and response_format is None:
  2136. warnings.warn(
  2137. "use_llm_func_with_cache(entity_extraction=True) is deprecated; "
  2138. "pass response_format={'type': 'json_object'} instead.",
  2139. DeprecationWarning,
  2140. stacklevel=2,
  2141. )
  2142. response_format = {"type": "json_object"}
  2143. _validate_cached_response_format(response_format)
  2144. # Sanitize input text to prevent UTF-8 encoding errors for all LLM providers
  2145. safe_user_prompt = sanitize_text_for_encoding(user_prompt)
  2146. safe_system_prompt = (
  2147. sanitize_text_for_encoding(system_prompt) if system_prompt else None
  2148. )
  2149. # Sanitize history messages if provided
  2150. safe_history_messages = None
  2151. if history_messages:
  2152. safe_history_messages = []
  2153. for i, msg in enumerate(history_messages):
  2154. safe_msg = msg.copy()
  2155. if "content" in safe_msg:
  2156. safe_msg["content"] = sanitize_text_for_encoding(safe_msg["content"])
  2157. safe_history_messages.append(safe_msg)
  2158. history = json.dumps(safe_history_messages, ensure_ascii=False)
  2159. else:
  2160. history = None
  2161. if llm_response_cache:
  2162. prompt_parts = []
  2163. if safe_user_prompt:
  2164. prompt_parts.append(safe_user_prompt)
  2165. if safe_system_prompt:
  2166. prompt_parts.append(safe_system_prompt)
  2167. if history:
  2168. prompt_parts.append(history)
  2169. _prompt = "\n".join(prompt_parts)
  2170. response_format_key = _serialize_cache_variant(response_format)
  2171. llm_identity_key = serialize_llm_cache_identity(llm_cache_identity)
  2172. arg_hash = compute_args_hash(
  2173. _prompt,
  2174. "\n<response_format>\n",
  2175. response_format_key,
  2176. "\n<llm_identity>\n",
  2177. llm_identity_key,
  2178. )
  2179. # Generate cache key for this LLM call
  2180. cache_key = generate_cache_key("default", cache_type, arg_hash)
  2181. cached_result = await handle_cache(
  2182. llm_response_cache,
  2183. arg_hash,
  2184. _prompt,
  2185. "default",
  2186. cache_type=cache_type,
  2187. )
  2188. if cached_result:
  2189. content, timestamp = cached_result
  2190. logger.debug(f"Found cache for {arg_hash}")
  2191. statistic_data["llm_cache"] += 1
  2192. # Add cache key to collector if provided
  2193. if cache_keys_collector is not None:
  2194. cache_keys_collector.append(cache_key)
  2195. return content, timestamp
  2196. statistic_data["llm_call"] += 1
  2197. # Call LLM with sanitized input
  2198. kwargs = {}
  2199. if safe_history_messages:
  2200. kwargs["history_messages"] = safe_history_messages
  2201. if max_tokens is not None:
  2202. kwargs["max_tokens"] = max_tokens
  2203. if response_format is not None:
  2204. kwargs["response_format"] = response_format
  2205. res: str = await use_llm_func(
  2206. safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
  2207. )
  2208. res = remove_think_tags(res)
  2209. # Generate timestamp for cache miss (LLM call completion time)
  2210. current_timestamp = int(time.time())
  2211. if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
  2212. await save_to_cache(
  2213. llm_response_cache,
  2214. CacheData(
  2215. args_hash=arg_hash,
  2216. content=res,
  2217. prompt=_prompt,
  2218. cache_type=cache_type,
  2219. chunk_id=chunk_id,
  2220. ),
  2221. )
  2222. # Add cache key to collector if provided
  2223. if cache_keys_collector is not None:
  2224. cache_keys_collector.append(cache_key)
  2225. return res, current_timestamp
  2226. # When cache is disabled, directly call LLM with sanitized input
  2227. kwargs = {}
  2228. if safe_history_messages:
  2229. kwargs["history_messages"] = safe_history_messages
  2230. if max_tokens is not None:
  2231. kwargs["max_tokens"] = max_tokens
  2232. if response_format is not None:
  2233. kwargs["response_format"] = response_format
  2234. try:
  2235. res = await use_llm_func(
  2236. safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
  2237. )
  2238. except Exception as e:
  2239. # Add [LLM func] prefix to error message
  2240. error_msg = f"[LLM func] {str(e)}"
  2241. # Re-raise with the same exception type but modified message
  2242. raise type(e)(error_msg) from e
  2243. # Generate timestamp for non-cached LLM call
  2244. current_timestamp = int(time.time())
  2245. return remove_think_tags(res), current_timestamp
  2246. def get_content_summary(content: str, max_length: int = 250) -> str:
  2247. """Get summary of document content
  2248. Args:
  2249. content: Original document content
  2250. max_length: Maximum length of summary
  2251. Returns:
  2252. Truncated content with ellipsis if needed
  2253. """
  2254. content = content.strip()
  2255. if len(content) <= max_length:
  2256. return content
  2257. return content[:max_length] + "..."
  2258. def sanitize_and_normalize_extracted_text(
  2259. input_text: str, remove_inner_quotes=False
  2260. ) -> str:
  2261. """Santitize and normalize extracted text
  2262. Args:
  2263. input_text: text string to be processed
  2264. is_name: whether the input text is a entity or relation name
  2265. Returns:
  2266. Santitized and normalized text string
  2267. """
  2268. safe_input_text = sanitize_text_for_encoding(input_text)
  2269. if safe_input_text:
  2270. normalized_text = normalize_extracted_info(
  2271. safe_input_text, remove_inner_quotes=remove_inner_quotes
  2272. )
  2273. return normalized_text
  2274. return ""
  2275. def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
  2276. """Normalize entity/relation names and description with the following rules:
  2277. - Clean HTML tags (paragraph and line break tags)
  2278. - Convert Chinese symbols to English symbols
  2279. - Remove spaces between Chinese characters
  2280. - Remove spaces between Chinese characters and English letters/numbers
  2281. - Preserve spaces within English text and numbers
  2282. - Replace Chinese parentheses with English parentheses
  2283. - Replace Chinese dash with English dash
  2284. - Remove English quotation marks from the beginning and end of the text
  2285. - Remove English quotation marks in and around chinese
  2286. - Remove Chinese quotation marks
  2287. - Filter out short numeric-only text (length < 3 and only digits/dots)
  2288. - remove_inner_quotes = True
  2289. remove Chinese quotes
  2290. remove English quotes in and around chinese
  2291. Convert non-breaking spaces to regular spaces
  2292. Convert narrow non-breaking spaces after non-digits to regular spaces
  2293. Args:
  2294. name: Entity name to normalize
  2295. is_entity: Whether this is an entity name (affects quote handling)
  2296. Returns:
  2297. Normalized entity name
  2298. """
  2299. # Clean HTML tags - remove paragraph and line break tags
  2300. name = re.sub(r"</p\s*>|<p\s*>|<p/>", "", name, flags=re.IGNORECASE)
  2301. name = re.sub(r"</br\s*>|<br\s*>|<br/>", "", name, flags=re.IGNORECASE)
  2302. # Chinese full-width letters to half-width (A-Z, a-z)
  2303. name = name.translate(
  2304. str.maketrans(
  2305. "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
  2306. "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
  2307. )
  2308. )
  2309. # Chinese full-width numbers to half-width
  2310. name = name.translate(str.maketrans("0123456789", "0123456789"))
  2311. # Chinese full-width symbols to half-width
  2312. name = name.replace("-", "-") # Chinese minus
  2313. name = name.replace("+", "+") # Chinese plus
  2314. name = name.replace("/", "/") # Chinese slash
  2315. name = name.replace("*", "*") # Chinese asterisk
  2316. # Replace Chinese parentheses with English parentheses
  2317. name = name.replace("(", "(").replace(")", ")")
  2318. # Replace Chinese dash with English dash (additional patterns)
  2319. name = name.replace("—", "-").replace("-", "-")
  2320. # Chinese full-width space to regular space (after other replacements)
  2321. name = name.replace(" ", " ")
  2322. # Use regex to remove spaces between Chinese characters
  2323. # Regex explanation:
  2324. # (?<=[\u4e00-\u9fa5]): Positive lookbehind for Chinese character
  2325. # \s+: One or more whitespace characters
  2326. # (?=[\u4e00-\u9fa5]): Positive lookahead for Chinese character
  2327. name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])", "", name)
  2328. # Remove spaces between Chinese and English/numbers/symbols
  2329. name = re.sub(
  2330. r"(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])", "", name
  2331. )
  2332. name = re.sub(
  2333. r"(?<=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])\s+(?=[\u4e00-\u9fa5])", "", name
  2334. )
  2335. # Remove outer quotes
  2336. if len(name) >= 2:
  2337. # Handle double quotes
  2338. if name.startswith('"') and name.endswith('"'):
  2339. inner_content = name[1:-1]
  2340. if '"' not in inner_content: # No double quotes inside
  2341. name = inner_content
  2342. # Handle single quotes
  2343. if name.startswith("'") and name.endswith("'"):
  2344. inner_content = name[1:-1]
  2345. if "'" not in inner_content: # No single quotes inside
  2346. name = inner_content
  2347. # Handle Chinese-style double quotes
  2348. if name.startswith("“") and name.endswith("”"):
  2349. inner_content = name[1:-1]
  2350. if "“" not in inner_content and "”" not in inner_content:
  2351. name = inner_content
  2352. if name.startswith("‘") and name.endswith("’"):
  2353. inner_content = name[1:-1]
  2354. if "‘" not in inner_content and "’" not in inner_content:
  2355. name = inner_content
  2356. # Handle Chinese-style book title mark
  2357. if name.startswith("《") and name.endswith("》"):
  2358. inner_content = name[1:-1]
  2359. if "《" not in inner_content and "》" not in inner_content:
  2360. name = inner_content
  2361. if remove_inner_quotes:
  2362. # Remove Chinese quotes
  2363. name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "")
  2364. # Remove English queotes in and around chinese
  2365. name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name)
  2366. name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name)
  2367. # Convert non-breaking space to regular space
  2368. name = name.replace("\u00a0", " ")
  2369. # Convert narrow non-breaking space to regular space when after non-digits
  2370. name = re.sub(r"(?<=[^\d])\u202F", " ", name)
  2371. # Remove spaces from the beginning and end of the text
  2372. name = name.strip()
  2373. # Filter out pure numeric content with length < 3
  2374. if len(name) < 3 and re.match(r"^[0-9]+$", name):
  2375. return ""
  2376. def should_filter_by_dots(text):
  2377. """
  2378. Check if the string consists only of dots and digits, with at least one dot
  2379. Filter cases include: 1.2.3, 12.3, .123, 123., 12.3., .1.23 etc.
  2380. """
  2381. return all(c.isdigit() or c == "." for c in text) and "." in text
  2382. if len(name) < 6 and should_filter_by_dots(name):
  2383. # Filter out mixed numeric and dot content with length < 6, requiring at least one dot
  2384. return ""
  2385. return name
  2386. def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str:
  2387. """Sanitize text to ensure safe UTF-8 encoding by removing or replacing problematic characters.
  2388. This function handles:
  2389. - Surrogate characters (the main cause of encoding errors)
  2390. - Other invalid Unicode sequences
  2391. - Control characters that might cause issues
  2392. - Unescape HTML escapes
  2393. - Remove control characters
  2394. - Whitespace trimming
  2395. Args:
  2396. text: Input text to sanitize
  2397. replacement_char: Character to use for replacing invalid sequences
  2398. Returns:
  2399. Sanitized text that can be safely encoded as UTF-8
  2400. """
  2401. if not text:
  2402. return text
  2403. # First, strip whitespace
  2404. text = text.strip()
  2405. # Early return if text is empty after basic cleaning
  2406. if not text:
  2407. return text
  2408. # 1. html.unescape first to catch entities that might become surrogates or control chars
  2409. text = html.unescape(text)
  2410. # 2. Use pre-compiled regex to clean surrogates and non-characters in one pass
  2411. # This replaces the slow manual loop and initial .encode() check
  2412. text = _SURROGATE_PATTERN.sub(replacement_char, text)
  2413. # 3. Remove control characters but preserve common whitespace (\t, \n, \r)
  2414. text = _CONTROL_CHAR_PATTERN_ALL.sub(replacement_char, text)
  2415. return text.strip()
  2416. def check_storage_env_vars(storage_name: str) -> None:
  2417. """Check if all required environment variables for storage implementation exist
  2418. Args:
  2419. storage_name: Storage implementation name
  2420. Raises:
  2421. ValueError: If required environment variables are missing
  2422. """
  2423. from lightrag.kg import STORAGE_ENV_REQUIREMENTS
  2424. required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
  2425. missing_vars = [var for var in required_vars if var not in os.environ]
  2426. if missing_vars:
  2427. raise ValueError(
  2428. f"Storage implementation '{storage_name}' requires the following "
  2429. f"environment variables: {', '.join(missing_vars)}"
  2430. )
  2431. def pick_by_weighted_polling(
  2432. entities_or_relations: list[dict],
  2433. max_related_chunks: int,
  2434. min_related_chunks: int = 1,
  2435. ) -> list[str]:
  2436. """
  2437. Linear gradient weighted polling algorithm for text chunk selection.
  2438. This algorithm ensures that entities/relations with higher importance get more text chunks,
  2439. forming a linear decreasing allocation pattern.
  2440. Args:
  2441. entities_or_relations: List of entities or relations sorted by importance (high to low)
  2442. max_related_chunks: Expected number of text chunks for the highest importance entity/relation
  2443. min_related_chunks: Expected number of text chunks for the lowest importance entity/relation
  2444. Returns:
  2445. List of selected text chunk IDs
  2446. """
  2447. if not entities_or_relations:
  2448. return []
  2449. n = len(entities_or_relations)
  2450. if n == 1:
  2451. # Only one entity/relation, return its first max_related_chunks text chunks
  2452. entity_chunks = entities_or_relations[0].get("sorted_chunks", [])
  2453. return entity_chunks[:max_related_chunks]
  2454. # Calculate expected text chunk count for each position (linear decrease)
  2455. expected_counts = []
  2456. for i in range(n):
  2457. # Linear interpolation: from max_related_chunks to min_related_chunks
  2458. ratio = i / (n - 1) if n > 1 else 0
  2459. expected = max_related_chunks - ratio * (
  2460. max_related_chunks - min_related_chunks
  2461. )
  2462. expected_counts.append(int(round(expected)))
  2463. # First round allocation: allocate by expected values
  2464. selected_chunks = []
  2465. used_counts = [] # Track number of chunks used by each entity
  2466. total_remaining = 0 # Accumulate remaining quotas
  2467. for i, entity_rel in enumerate(entities_or_relations):
  2468. entity_chunks = entity_rel.get("sorted_chunks", [])
  2469. expected = expected_counts[i]
  2470. # Actual allocatable count
  2471. actual = min(expected, len(entity_chunks))
  2472. selected_chunks.extend(entity_chunks[:actual])
  2473. used_counts.append(actual)
  2474. # Accumulate remaining quota
  2475. remaining = expected - actual
  2476. if remaining > 0:
  2477. total_remaining += remaining
  2478. # Second round allocation: multi-round scanning to allocate remaining quotas
  2479. for _ in range(total_remaining):
  2480. allocated = False
  2481. # Scan entities one by one, allocate one chunk when finding unused chunks
  2482. for i, entity_rel in enumerate(entities_or_relations):
  2483. entity_chunks = entity_rel.get("sorted_chunks", [])
  2484. # Check if there are still unused chunks
  2485. if used_counts[i] < len(entity_chunks):
  2486. # Allocate one chunk
  2487. selected_chunks.append(entity_chunks[used_counts[i]])
  2488. used_counts[i] += 1
  2489. allocated = True
  2490. break
  2491. # If no chunks were allocated in this round, all entities are exhausted
  2492. if not allocated:
  2493. break
  2494. return selected_chunks
  2495. async def pick_by_vector_similarity(
  2496. query: str,
  2497. text_chunks_storage: "BaseKVStorage",
  2498. chunks_vdb: "BaseVectorStorage",
  2499. num_of_chunks: int,
  2500. entity_info: list[dict[str, Any]],
  2501. embedding_func: callable,
  2502. query_embedding=None,
  2503. ) -> list[str]:
  2504. """
  2505. Vector similarity-based text chunk selection algorithm.
  2506. This algorithm selects text chunks based on cosine similarity between
  2507. the query embedding and text chunk embeddings.
  2508. Args:
  2509. query: User's original query string
  2510. text_chunks_storage: Text chunks storage instance
  2511. chunks_vdb: Vector database storage for chunks
  2512. num_of_chunks: Number of chunks to select
  2513. entity_info: List of entity information containing chunk IDs
  2514. embedding_func: Embedding function to compute query embedding
  2515. Returns:
  2516. List of selected text chunk IDs sorted by similarity (highest first)
  2517. """
  2518. logger.debug(
  2519. f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}"
  2520. )
  2521. if not entity_info or num_of_chunks <= 0:
  2522. return []
  2523. # Collect all unique chunk IDs from entity info
  2524. all_chunk_ids = set()
  2525. for i, entity in enumerate(entity_info):
  2526. chunk_ids = entity.get("sorted_chunks", [])
  2527. all_chunk_ids.update(chunk_ids)
  2528. if not all_chunk_ids:
  2529. logger.warning(
  2530. "Vector similarity chunk selection: no chunk IDs found in entity_info"
  2531. )
  2532. return []
  2533. logger.debug(
  2534. f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected"
  2535. )
  2536. all_chunk_ids = list(all_chunk_ids)
  2537. try:
  2538. # Use pre-computed query embedding if provided, otherwise compute it
  2539. if query_embedding is None:
  2540. query_embedding = await embedding_func([query], context="query")
  2541. query_embedding = query_embedding[
  2542. 0
  2543. ] # Extract first embedding from batch result
  2544. logger.debug(
  2545. "Computed query embedding for vector similarity chunk selection"
  2546. )
  2547. else:
  2548. logger.debug(
  2549. "Using pre-computed query embedding for vector similarity chunk selection"
  2550. )
  2551. # Get chunk embeddings from vector database
  2552. chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
  2553. logger.debug(
  2554. f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved"
  2555. )
  2556. if not chunk_vectors or len(chunk_vectors) != len(all_chunk_ids):
  2557. if not chunk_vectors:
  2558. logger.warning(
  2559. "Vector similarity chunk selection: no vectors retrieved from chunks_vdb"
  2560. )
  2561. else:
  2562. logger.warning(
  2563. f"Vector similarity chunk selection: found {len(chunk_vectors)} but expecting {len(all_chunk_ids)}"
  2564. )
  2565. return []
  2566. # Calculate cosine similarities
  2567. similarities = []
  2568. valid_vectors = 0
  2569. for chunk_id in all_chunk_ids:
  2570. if chunk_id in chunk_vectors:
  2571. chunk_embedding = chunk_vectors[chunk_id]
  2572. try:
  2573. # Calculate cosine similarity
  2574. similarity = cosine_similarity(query_embedding, chunk_embedding)
  2575. similarities.append((chunk_id, similarity))
  2576. valid_vectors += 1
  2577. except Exception as e:
  2578. logger.warning(
  2579. f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}"
  2580. )
  2581. else:
  2582. logger.warning(
  2583. f"Vector similarity chunk selection: no vector found for chunk {chunk_id}"
  2584. )
  2585. # Sort by similarity (highest first) and select top num_of_chunks
  2586. similarities.sort(key=lambda x: x[1], reverse=True)
  2587. selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]]
  2588. logger.debug(
  2589. f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates"
  2590. )
  2591. return selected_chunks
  2592. except Exception as e:
  2593. logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}")
  2594. import traceback
  2595. logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}")
  2596. # Fallback to simple truncation
  2597. logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation")
  2598. return all_chunk_ids[:num_of_chunks]
  2599. class TokenTracker:
  2600. """Track token usage for LLM calls."""
  2601. def __init__(self):
  2602. self.reset()
  2603. def __enter__(self):
  2604. self.reset()
  2605. return self
  2606. def __exit__(self, exc_type, exc_val, exc_tb):
  2607. print(self)
  2608. def reset(self):
  2609. self.prompt_tokens = 0
  2610. self.completion_tokens = 0
  2611. self.total_tokens = 0
  2612. self.call_count = 0
  2613. def add_usage(self, token_counts):
  2614. """Add token usage from one LLM call.
  2615. Args:
  2616. token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
  2617. """
  2618. self.prompt_tokens += token_counts.get("prompt_tokens", 0)
  2619. self.completion_tokens += token_counts.get("completion_tokens", 0)
  2620. # If total_tokens is provided, use it directly; otherwise calculate the sum
  2621. if "total_tokens" in token_counts:
  2622. self.total_tokens += token_counts["total_tokens"]
  2623. else:
  2624. self.total_tokens += token_counts.get(
  2625. "prompt_tokens", 0
  2626. ) + token_counts.get("completion_tokens", 0)
  2627. self.call_count += 1
  2628. def get_usage(self):
  2629. """Get current usage statistics."""
  2630. return {
  2631. "prompt_tokens": self.prompt_tokens,
  2632. "completion_tokens": self.completion_tokens,
  2633. "total_tokens": self.total_tokens,
  2634. "call_count": self.call_count,
  2635. }
  2636. def __str__(self):
  2637. usage = self.get_usage()
  2638. return (
  2639. f"LLM call count: {usage['call_count']}, "
  2640. f"Prompt tokens: {usage['prompt_tokens']}, "
  2641. f"Completion tokens: {usage['completion_tokens']}, "
  2642. f"Total tokens: {usage['total_tokens']}"
  2643. )
  2644. async def apply_rerank_if_enabled(
  2645. query: str,
  2646. retrieved_docs: list[dict],
  2647. global_config: dict,
  2648. enable_rerank: bool = True,
  2649. top_n: int = None,
  2650. ) -> list[dict]:
  2651. """
  2652. Apply reranking to retrieved documents if rerank is enabled.
  2653. Args:
  2654. query: The search query
  2655. retrieved_docs: List of retrieved documents
  2656. global_config: Global configuration containing rerank settings
  2657. enable_rerank: Whether to enable reranking from query parameter
  2658. top_n: Number of top documents to return after reranking
  2659. Returns:
  2660. Reranked documents if rerank is enabled, otherwise original documents
  2661. """
  2662. if not enable_rerank or not retrieved_docs:
  2663. return retrieved_docs
  2664. rerank_func = global_config.get("rerank_model_func")
  2665. if not rerank_func:
  2666. logger.warning(
  2667. "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters."
  2668. )
  2669. return retrieved_docs
  2670. try:
  2671. # Extract document content for reranking
  2672. document_texts = []
  2673. for doc in retrieved_docs:
  2674. # Try multiple possible content fields
  2675. content = (
  2676. doc.get("content")
  2677. or doc.get("text")
  2678. or doc.get("chunk_content")
  2679. or doc.get("document")
  2680. or str(doc)
  2681. )
  2682. document_texts.append(content)
  2683. # Call the new rerank function that returns index-based results
  2684. rerank_results = await rerank_func(
  2685. query=query,
  2686. documents=document_texts,
  2687. top_n=top_n,
  2688. )
  2689. # Process rerank results based on return format
  2690. if rerank_results and len(rerank_results) > 0:
  2691. # Check if results are in the new index-based format
  2692. if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
  2693. # New format: [{"index": 0, "relevance_score": 0.85}, ...]
  2694. reranked_docs = []
  2695. for result in rerank_results:
  2696. index = result["index"]
  2697. relevance_score = result["relevance_score"]
  2698. # Get original document and add rerank score
  2699. if 0 <= index < len(retrieved_docs):
  2700. doc = retrieved_docs[index].copy()
  2701. doc["rerank_score"] = relevance_score
  2702. reranked_docs.append(doc)
  2703. logger.info(
  2704. f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
  2705. )
  2706. return reranked_docs
  2707. else:
  2708. # Legacy format: assume it's already reranked documents
  2709. logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
  2710. return rerank_results[:top_n] if top_n else rerank_results
  2711. else:
  2712. logger.warning("Rerank returned empty results, using original chunks")
  2713. return retrieved_docs
  2714. except Exception as e:
  2715. logger.error(f"Error during reranking: {e}, using original chunks")
  2716. return retrieved_docs
  2717. async def process_chunks_unified(
  2718. query: str,
  2719. unique_chunks: list[dict],
  2720. query_param: "QueryParam",
  2721. global_config: dict,
  2722. source_type: str = "mixed",
  2723. chunk_token_limit: int = None, # Add parameter for dynamic token limit
  2724. ) -> list[dict]:
  2725. """
  2726. Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
  2727. Args:
  2728. query: Search query for reranking
  2729. chunks: List of text chunks to process
  2730. query_param: Query parameters containing configuration
  2731. global_config: Global configuration dictionary
  2732. source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
  2733. chunk_token_limit: Dynamic token limit for chunks (if None, uses default)
  2734. Returns:
  2735. Processed and filtered list of text chunks
  2736. """
  2737. if not unique_chunks:
  2738. return []
  2739. origin_count = len(unique_chunks)
  2740. # 1. Apply reranking if enabled and query is provided
  2741. if query_param.enable_rerank and query and unique_chunks:
  2742. rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
  2743. unique_chunks = await apply_rerank_if_enabled(
  2744. query=query,
  2745. retrieved_docs=unique_chunks,
  2746. global_config=global_config,
  2747. enable_rerank=query_param.enable_rerank,
  2748. top_n=rerank_top_k,
  2749. )
  2750. # 2. Filter by minimum rerank score if reranking is enabled
  2751. if query_param.enable_rerank and unique_chunks:
  2752. min_rerank_score = global_config.get("min_rerank_score", 0.5)
  2753. if min_rerank_score > 0.0:
  2754. original_count = len(unique_chunks)
  2755. # Filter chunks with score below threshold
  2756. filtered_chunks = []
  2757. for chunk in unique_chunks:
  2758. rerank_score = chunk.get(
  2759. "rerank_score", 1.0
  2760. ) # Default to 1.0 if no score
  2761. if rerank_score >= min_rerank_score:
  2762. filtered_chunks.append(chunk)
  2763. unique_chunks = filtered_chunks
  2764. filtered_count = original_count - len(unique_chunks)
  2765. if filtered_count > 0:
  2766. logger.info(
  2767. f"Rerank filtering: {len(unique_chunks)} chunks remained (min rerank score: {min_rerank_score})"
  2768. )
  2769. if not unique_chunks:
  2770. return []
  2771. # 3. Apply chunk_top_k limiting if specified
  2772. if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
  2773. if len(unique_chunks) > query_param.chunk_top_k:
  2774. unique_chunks = unique_chunks[: query_param.chunk_top_k]
  2775. logger.debug(
  2776. f"Kept chunk_top-k: {len(unique_chunks)} chunks (deduplicated original: {origin_count})"
  2777. )
  2778. # 4. Token-based final truncation
  2779. tokenizer = global_config.get("tokenizer")
  2780. if tokenizer and unique_chunks:
  2781. # Set default chunk_token_limit if not provided
  2782. if chunk_token_limit is None:
  2783. # Get default from query_param or global_config
  2784. chunk_token_limit = getattr(
  2785. query_param,
  2786. "max_total_tokens",
  2787. global_config.get("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS),
  2788. )
  2789. original_count = len(unique_chunks)
  2790. unique_chunks = truncate_list_by_token_size(
  2791. unique_chunks,
  2792. key=lambda x: "\n".join(
  2793. json.dumps(item, ensure_ascii=False) for item in [x]
  2794. ),
  2795. max_token_size=chunk_token_limit,
  2796. tokenizer=tokenizer,
  2797. )
  2798. logger.debug(
  2799. f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
  2800. f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
  2801. )
  2802. # 5. add id field to each chunk
  2803. final_chunks = []
  2804. for i, chunk in enumerate(unique_chunks):
  2805. chunk_with_id = chunk.copy()
  2806. chunk_with_id["id"] = f"DC{i + 1}"
  2807. final_chunks.append(chunk_with_id)
  2808. return final_chunks
  2809. def normalize_source_ids_limit_method(method: str | None) -> str:
  2810. """Normalize the source ID limiting strategy and fall back to default when invalid."""
  2811. if not method:
  2812. return DEFAULT_SOURCE_IDS_LIMIT_METHOD
  2813. normalized = method.upper()
  2814. if normalized not in VALID_SOURCE_IDS_LIMIT_METHODS:
  2815. logger.warning(
  2816. "Unknown SOURCE_IDS_LIMIT_METHOD '%s', falling back to %s",
  2817. method,
  2818. DEFAULT_SOURCE_IDS_LIMIT_METHOD,
  2819. )
  2820. return DEFAULT_SOURCE_IDS_LIMIT_METHOD
  2821. return normalized
  2822. def merge_source_ids(
  2823. existing_ids: Iterable[str] | None, new_ids: Iterable[str] | None
  2824. ) -> list[str]:
  2825. """Merge two iterables of source IDs while preserving order and removing duplicates."""
  2826. merged: list[str] = []
  2827. seen: set[str] = set()
  2828. for sequence in (existing_ids, new_ids):
  2829. if not sequence:
  2830. continue
  2831. for source_id in sequence:
  2832. if not source_id:
  2833. continue
  2834. if source_id not in seen:
  2835. seen.add(source_id)
  2836. merged.append(source_id)
  2837. return merged
  2838. def apply_source_ids_limit(
  2839. source_ids: Sequence[str],
  2840. limit: int,
  2841. method: str,
  2842. *,
  2843. identifier: str | None = None,
  2844. ) -> list[str]:
  2845. """Apply a limit strategy to a sequence of source IDs."""
  2846. if limit <= 0:
  2847. return []
  2848. source_ids_list = list(source_ids)
  2849. if len(source_ids_list) <= limit:
  2850. return source_ids_list
  2851. normalized_method = normalize_source_ids_limit_method(method)
  2852. if normalized_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
  2853. truncated = source_ids_list[-limit:]
  2854. else: # IGNORE_NEW
  2855. truncated = source_ids_list[:limit]
  2856. if identifier and len(truncated) < len(source_ids_list):
  2857. logger.debug(
  2858. "Source_id truncated: %s | %s keeping %s of %s entries",
  2859. identifier,
  2860. normalized_method,
  2861. len(truncated),
  2862. len(source_ids_list),
  2863. )
  2864. return truncated
  2865. def compute_incremental_chunk_ids(
  2866. existing_full_chunk_ids: list[str],
  2867. old_chunk_ids: list[str],
  2868. new_chunk_ids: list[str],
  2869. ) -> list[str]:
  2870. """
  2871. Compute incrementally updated chunk IDs based on changes.
  2872. This function applies delta changes (additions and removals) to an existing
  2873. list of chunk IDs while maintaining order and ensuring deduplication.
  2874. Delta additions from new_chunk_ids are placed at the end.
  2875. Args:
  2876. existing_full_chunk_ids: Complete list of existing chunk IDs from storage
  2877. old_chunk_ids: Previous chunk IDs from source_id (chunks being replaced)
  2878. new_chunk_ids: New chunk IDs from updated source_id (chunks being added)
  2879. Returns:
  2880. Updated list of chunk IDs with deduplication
  2881. Example:
  2882. >>> existing = ['chunk-1', 'chunk-2', 'chunk-3']
  2883. >>> old = ['chunk-1', 'chunk-2']
  2884. >>> new = ['chunk-2', 'chunk-4']
  2885. >>> compute_incremental_chunk_ids(existing, old, new)
  2886. ['chunk-3', 'chunk-2', 'chunk-4']
  2887. """
  2888. # Calculate changes
  2889. chunks_to_remove = set(old_chunk_ids) - set(new_chunk_ids)
  2890. chunks_to_add = set(new_chunk_ids) - set(old_chunk_ids)
  2891. # Apply changes to full chunk_ids
  2892. # Step 1: Remove chunks that are no longer needed
  2893. updated_chunk_ids = [
  2894. cid for cid in existing_full_chunk_ids if cid not in chunks_to_remove
  2895. ]
  2896. # Step 2: Add new chunks (preserving order from new_chunk_ids)
  2897. # Note: 'cid not in updated_chunk_ids' check ensures deduplication
  2898. for cid in new_chunk_ids:
  2899. if cid in chunks_to_add and cid not in updated_chunk_ids:
  2900. updated_chunk_ids.append(cid)
  2901. return updated_chunk_ids
  2902. def subtract_source_ids(
  2903. source_ids: Iterable[str],
  2904. ids_to_remove: Collection[str],
  2905. ) -> list[str]:
  2906. """Remove a collection of IDs from an ordered iterable while preserving order."""
  2907. removal_set = set(ids_to_remove)
  2908. if not removal_set:
  2909. return [source_id for source_id in source_ids if source_id]
  2910. return [
  2911. source_id
  2912. for source_id in source_ids
  2913. if source_id and source_id not in removal_set
  2914. ]
  2915. def make_relation_chunk_key(src: str, tgt: str) -> str:
  2916. """Create a deterministic storage key for relation chunk tracking."""
  2917. return GRAPH_FIELD_SEP.join(sorted((src, tgt)))
  2918. def parse_relation_chunk_key(key: str) -> tuple[str, str]:
  2919. """Parse a relation chunk storage key back into its entity pair."""
  2920. parts = key.split(GRAPH_FIELD_SEP)
  2921. if len(parts) != 2:
  2922. raise ValueError(f"Invalid relation chunk key: {key}")
  2923. return parts[0], parts[1]
  2924. def generate_track_id(prefix: str = "upload") -> str:
  2925. """Generate a unique tracking ID with timestamp and UUID
  2926. Args:
  2927. prefix: Prefix for the track ID (e.g., 'upload', 'insert')
  2928. Returns:
  2929. str: Unique tracking ID in format: {prefix}_{timestamp}_{uuid}
  2930. """
  2931. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  2932. unique_id = str(uuid.uuid4())[:8] # Use first 8 characters of UUID
  2933. return f"{prefix}_{timestamp}_{unique_id}"
  2934. def get_pinyin_sort_key(text: str) -> str:
  2935. """Generate sort key for Chinese pinyin sorting
  2936. This function uses pypinyin for true Chinese pinyin sorting.
  2937. If pypinyin is not available, it falls back to simple lowercase string sorting.
  2938. Args:
  2939. text: Text to generate sort key for
  2940. Returns:
  2941. str: Sort key that can be used for comparison and sorting
  2942. """
  2943. if not text:
  2944. return ""
  2945. if _PYPINYIN_AVAILABLE:
  2946. try:
  2947. # Convert Chinese characters to pinyin, keep non-Chinese as-is
  2948. pinyin_list = pypinyin.lazy_pinyin(text, style=pypinyin.Style.NORMAL)
  2949. return "".join(pinyin_list).lower()
  2950. except Exception:
  2951. # Silently fall back to simple string sorting on any error
  2952. return text.lower()
  2953. else:
  2954. # pypinyin not available, use simple string sorting
  2955. return text.lower()
  2956. def fix_tuple_delimiter_corruption(
  2957. record: str, delimiter_core: str, tuple_delimiter: str
  2958. ) -> str:
  2959. """
  2960. Fix various forms of tuple_delimiter corruption from LLM output.
  2961. This function handles missing or replaced characters around the core delimiter.
  2962. It fixes common corruption patterns where the LLM output doesn't match the expected
  2963. tuple_delimiter format.
  2964. Args:
  2965. record: The text record to fix
  2966. delimiter_core: The core delimiter (e.g., "S" from "<|#|>")
  2967. tuple_delimiter: The complete tuple delimiter (e.g., "<|#|>")
  2968. Returns:
  2969. The corrected record with proper tuple_delimiter format
  2970. """
  2971. if not record or not delimiter_core or not tuple_delimiter:
  2972. return record
  2973. # Escape the delimiter core for regex use
  2974. escaped_delimiter_core = re.escape(delimiter_core)
  2975. # Fix: <|##|> -> <|#|>, <|#||#|> -> <|#|>, <|#|||#|> -> <|#|>
  2976. record = re.sub(
  2977. rf"<\|{escaped_delimiter_core}\|*?{escaped_delimiter_core}\|>",
  2978. tuple_delimiter,
  2979. record,
  2980. )
  2981. # Fix: <|\#|> -> <|#|>
  2982. record = re.sub(
  2983. rf"<\|\\{escaped_delimiter_core}\|>",
  2984. tuple_delimiter,
  2985. record,
  2986. )
  2987. # Fix: <|> -> <|#|>, <||> -> <|#|>
  2988. record = re.sub(
  2989. r"<\|+>",
  2990. tuple_delimiter,
  2991. record,
  2992. )
  2993. # Fix: <X|#|> -> <|#|>, <|#|Y> -> <|#|>, <X|#|Y> -> <|#|>, <||#||> -> <|#|> (one extra characters outside pipes)
  2994. record = re.sub(
  2995. rf"<.?\|{escaped_delimiter_core}\|.?>",
  2996. tuple_delimiter,
  2997. record,
  2998. )
  2999. # Fix: <#>, <#|>, <|#> -> <|#|> (missing one or both pipes)
  3000. record = re.sub(
  3001. rf"<\|?{escaped_delimiter_core}\|?>",
  3002. tuple_delimiter,
  3003. record,
  3004. )
  3005. # Fix: <X#|> -> <|#|>, <|#X> -> <|#|> (one pipe is replaced by other character)
  3006. record = re.sub(
  3007. rf"<[^|]{escaped_delimiter_core}\|>|<\|{escaped_delimiter_core}[^|]>",
  3008. tuple_delimiter,
  3009. record,
  3010. )
  3011. # Fix: <|#| -> <|#|>, <|#|| -> <|#|> (missing closing >)
  3012. record = re.sub(
  3013. rf"<\|{escaped_delimiter_core}\|+(?!>)",
  3014. tuple_delimiter,
  3015. record,
  3016. )
  3017. # Fix <|#: -> <|#|> (missing closing >)
  3018. record = re.sub(
  3019. rf"<\|{escaped_delimiter_core}:(?!>)",
  3020. tuple_delimiter,
  3021. record,
  3022. )
  3023. # Fix: <||#> -> <|#|> (double pipe at start, missing pipe at end)
  3024. record = re.sub(
  3025. rf"<\|+{escaped_delimiter_core}>",
  3026. tuple_delimiter,
  3027. record,
  3028. )
  3029. # Fix: <|| -> <|#|>
  3030. record = re.sub(
  3031. r"<\|\|(?!>)",
  3032. tuple_delimiter,
  3033. record,
  3034. )
  3035. # Fix: |#|> -> <|#|> (missing opening <)
  3036. record = re.sub(
  3037. rf"(?<!<)\|{escaped_delimiter_core}\|>",
  3038. tuple_delimiter,
  3039. record,
  3040. )
  3041. # Fix: <|#|>| -> <|#|> ( this is a fix for: <|#|| -> <|#|> )
  3042. record = re.sub(
  3043. rf"<\|{escaped_delimiter_core}\|>\|",
  3044. tuple_delimiter,
  3045. record,
  3046. )
  3047. # Fix: ||#|| -> <|#|> (double pipes on both sides without angle brackets)
  3048. record = re.sub(
  3049. rf"\|\|{escaped_delimiter_core}\|\|",
  3050. tuple_delimiter,
  3051. record,
  3052. )
  3053. return record
  3054. def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exception:
  3055. """
  3056. Safely create a prefixed exception that adapts to all error types.
  3057. Args:
  3058. original_exception: The original exception.
  3059. prefix: The prefix to add.
  3060. Returns:
  3061. A new exception with the prefix, maintaining the original exception type if possible.
  3062. """
  3063. try:
  3064. # Method 1: Try to reconstruct using original arguments.
  3065. if hasattr(original_exception, "args") and original_exception.args:
  3066. args = list(original_exception.args)
  3067. # Find the first string argument and prefix it. This is safer for
  3068. # exceptions like OSError where the first arg is an integer (errno).
  3069. found_str = False
  3070. for i, arg in enumerate(args):
  3071. if isinstance(arg, str):
  3072. args[i] = f"{prefix}: {arg}"
  3073. found_str = True
  3074. break
  3075. # If no string argument is found, prefix the first argument's string representation.
  3076. if not found_str:
  3077. args[0] = f"{prefix}: {args[0]}"
  3078. return type(original_exception)(*args)
  3079. else:
  3080. # Method 2: If no args, try single parameter construction.
  3081. return type(original_exception)(f"{prefix}: {str(original_exception)}")
  3082. except Exception:
  3083. # Method 3: If reconstruction fails for any reason, wrap it in a
  3084. # RuntimeError preserving the original type name and message. This is a
  3085. # defensive catch-all: most known failures already surface as TypeError
  3086. # (e.g. json.JSONDecodeError needs (msg, doc, pos) and
  3087. # openai.APIStatusError/BadRequestError need keyword-only
  3088. # (response, body), so rebuilding from args alone raises TypeError), but
  3089. # an exotic constructor could raise something else (KeyError, a
  3090. # validation error, ...). Catching `Exception` guarantees this helper
  3091. # never raises while prefixing — `KeyboardInterrupt`/`SystemExit` are
  3092. # BaseException and still propagate. The original exception and its full
  3093. # traceback are preserved by the caller's `raise ... from original`.
  3094. return RuntimeError(
  3095. f"{prefix}: {type(original_exception).__name__}: {str(original_exception)}"
  3096. )
  3097. def convert_to_user_format(
  3098. entities_context: list[dict],
  3099. relations_context: list[dict],
  3100. chunks: list[dict],
  3101. references: list[dict],
  3102. query_mode: str,
  3103. entity_id_to_original: dict = None,
  3104. relation_id_to_original: dict = None,
  3105. ) -> dict[str, Any]:
  3106. """Convert internal data format to user-friendly format using original database data"""
  3107. # Convert entities format using original data when available
  3108. formatted_entities = []
  3109. for entity in entities_context:
  3110. entity_name = entity.get("entity", "")
  3111. # Try to get original data first
  3112. original_entity = None
  3113. if entity_id_to_original and entity_name in entity_id_to_original:
  3114. original_entity = entity_id_to_original[entity_name]
  3115. if original_entity:
  3116. # Use original database data
  3117. formatted_entities.append(
  3118. {
  3119. "entity_name": original_entity.get("entity_name", entity_name),
  3120. "entity_type": original_entity.get("entity_type", "UNKNOWN"),
  3121. "description": original_entity.get("description", ""),
  3122. "source_id": original_entity.get("source_id", ""),
  3123. "file_path": original_entity.get("file_path", "unknown_source"),
  3124. "created_at": original_entity.get("created_at", ""),
  3125. }
  3126. )
  3127. else:
  3128. # Fallback to LLM context data (for backward compatibility)
  3129. formatted_entities.append(
  3130. {
  3131. "entity_name": entity_name,
  3132. "entity_type": entity.get("type", "UNKNOWN"),
  3133. "description": entity.get("description", ""),
  3134. "source_id": entity.get("source_id", ""),
  3135. "file_path": entity.get("file_path", "unknown_source"),
  3136. "created_at": entity.get("created_at", ""),
  3137. }
  3138. )
  3139. # Convert relationships format using original data when available
  3140. formatted_relationships = []
  3141. for relation in relations_context:
  3142. entity1 = relation.get("entity1", "")
  3143. entity2 = relation.get("entity2", "")
  3144. relation_key = (entity1, entity2)
  3145. # Try to get original data first
  3146. original_relation = None
  3147. if relation_id_to_original and relation_key in relation_id_to_original:
  3148. original_relation = relation_id_to_original[relation_key]
  3149. if original_relation:
  3150. # Use original database data
  3151. formatted_relationships.append(
  3152. {
  3153. "src_id": original_relation.get("src_id", entity1),
  3154. "tgt_id": original_relation.get("tgt_id", entity2),
  3155. "description": original_relation.get("description", ""),
  3156. "keywords": original_relation.get("keywords", ""),
  3157. "weight": original_relation.get("weight", 1.0),
  3158. "source_id": original_relation.get("source_id", ""),
  3159. "file_path": original_relation.get("file_path", "unknown_source"),
  3160. "created_at": original_relation.get("created_at", ""),
  3161. }
  3162. )
  3163. else:
  3164. # Fallback to LLM context data (for backward compatibility)
  3165. formatted_relationships.append(
  3166. {
  3167. "src_id": entity1,
  3168. "tgt_id": entity2,
  3169. "description": relation.get("description", ""),
  3170. "keywords": relation.get("keywords", ""),
  3171. "weight": relation.get("weight", 1.0),
  3172. "source_id": relation.get("source_id", ""),
  3173. "file_path": relation.get("file_path", "unknown_source"),
  3174. "created_at": relation.get("created_at", ""),
  3175. }
  3176. )
  3177. # Convert chunks format (chunks already contain complete data)
  3178. formatted_chunks = []
  3179. for i, chunk in enumerate(chunks):
  3180. chunk_data = {
  3181. "reference_id": chunk.get("reference_id", ""),
  3182. "content": chunk.get("content", ""),
  3183. "file_path": chunk.get("file_path", "unknown_source"),
  3184. "chunk_id": chunk.get("chunk_id", ""),
  3185. }
  3186. formatted_chunks.append(chunk_data)
  3187. logger.debug(
  3188. f"[convert_to_user_format] Formatted {len(formatted_chunks)}/{len(chunks)} chunks"
  3189. )
  3190. # Build basic metadata (metadata details will be added by calling functions)
  3191. metadata = {
  3192. "query_mode": query_mode,
  3193. "keywords": {
  3194. "high_level": [],
  3195. "low_level": [],
  3196. }, # Placeholder, will be set by calling functions
  3197. }
  3198. return {
  3199. "status": "success",
  3200. "message": "Query processed successfully",
  3201. "data": {
  3202. "entities": formatted_entities,
  3203. "relationships": formatted_relationships,
  3204. "chunks": formatted_chunks,
  3205. "references": references,
  3206. },
  3207. "metadata": metadata,
  3208. }
  3209. def generate_reference_list_from_chunks(
  3210. chunks: list[dict],
  3211. ) -> tuple[list[dict], list[dict]]:
  3212. """
  3213. Generate reference list from chunks, prioritizing by occurrence frequency.
  3214. This function extracts file_paths from chunks, counts their occurrences,
  3215. sorts by frequency and first appearance order, creates reference_id mappings,
  3216. and builds a reference_list structure.
  3217. Args:
  3218. chunks: List of chunk dictionaries with file_path information
  3219. Returns:
  3220. tuple: (reference_list, updated_chunks_with_reference_ids)
  3221. - reference_list: List of dicts with reference_id and file_path
  3222. - updated_chunks_with_reference_ids: Original chunks with reference_id field added
  3223. """
  3224. if not chunks:
  3225. return [], []
  3226. # 1. Extract all valid file_paths and count their occurrences
  3227. file_path_counts = {}
  3228. for chunk in chunks:
  3229. file_path = chunk.get("file_path", "")
  3230. if file_path and file_path != "unknown_source":
  3231. file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1
  3232. # 2. Sort file paths by frequency (descending), then by first appearance order
  3233. # Create a list of (file_path, count, first_index) tuples
  3234. file_path_with_indices = []
  3235. seen_paths = set()
  3236. for i, chunk in enumerate(chunks):
  3237. file_path = chunk.get("file_path", "")
  3238. if file_path and file_path != "unknown_source" and file_path not in seen_paths:
  3239. file_path_with_indices.append((file_path, file_path_counts[file_path], i))
  3240. seen_paths.add(file_path)
  3241. # Sort by count (descending), then by first appearance index (ascending)
  3242. sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2]))
  3243. unique_file_paths = [item[0] for item in sorted_file_paths]
  3244. # 3. Create mapping from file_path to reference_id (prioritized by frequency)
  3245. file_path_to_ref_id = {}
  3246. for i, file_path in enumerate(unique_file_paths):
  3247. file_path_to_ref_id[file_path] = str(i + 1)
  3248. # 4. Add reference_id field to each chunk
  3249. updated_chunks = []
  3250. for chunk in chunks:
  3251. chunk_copy = chunk.copy()
  3252. file_path = chunk_copy.get("file_path", "")
  3253. if file_path and file_path != "unknown_source":
  3254. chunk_copy["reference_id"] = file_path_to_ref_id[file_path]
  3255. else:
  3256. chunk_copy["reference_id"] = ""
  3257. updated_chunks.append(chunk_copy)
  3258. # 5. Build reference_list
  3259. reference_list = []
  3260. for i, file_path in enumerate(unique_file_paths):
  3261. reference_list.append({"reference_id": str(i + 1), "file_path": file_path})
  3262. return reference_list, updated_chunks