neo4j_impl.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019
  1. import os
  2. import re
  3. from dataclasses import dataclass
  4. from typing import final
  5. import configparser
  6. from tenacity import (
  7. retry,
  8. stop_after_attempt,
  9. wait_exponential,
  10. retry_if_exception_type,
  11. )
  12. import logging
  13. from ..utils import logger
  14. from ..base import BaseGraphStorage
  15. from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  16. from ..kg.shared_storage import get_data_init_lock
  17. import pipmaster as pm
  18. if not pm.is_installed("neo4j"):
  19. pm.install("neo4j")
  20. from neo4j import ( # type: ignore
  21. AsyncGraphDatabase,
  22. exceptions as neo4jExceptions,
  23. AsyncDriver,
  24. AsyncManagedTransaction,
  25. )
  26. from dotenv import load_dotenv
  27. # use the .env that is inside the current folder
  28. # allows to use different .env file for each lightrag instance
  29. # the OS environment variables take precedence over the .env file
  30. load_dotenv(dotenv_path=".env", override=False)
  31. config = configparser.ConfigParser()
  32. config.read("config.ini", "utf-8")
  33. # Set neo4j logger level to ERROR to suppress warning logs
  34. logging.getLogger("neo4j").setLevel(logging.ERROR)
  # Exception types that trigger a retry of a read-only query.
  # ServiceUnavailable / TransientError / SessionExpired come from the Neo4j
  # driver; ConnectionResetError and OSError are socket-level failures.
  # NOTE(review): AttributeError is also retried -- presumably to cover the
  # driver being reset while a call is in flight; confirm this is intended,
  # since it will also retry genuine programming errors.
  READ_RETRY_EXCEPTIONS = (
      neo4jExceptions.ServiceUnavailable,
      neo4jExceptions.TransientError,
      neo4jExceptions.SessionExpired,
      ConnectionResetError,
      OSError,
      AttributeError,
  )

  # Decorator for read methods: at most 3 attempts, exponential back-off
  # bounded to 4..10 seconds, and the last exception is re-raised unchanged
  # (reraise=True) instead of being wrapped in tenacity's RetryError.
  READ_RETRY = retry(
      retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      stop=stop_after_attempt(3),
      reraise=True,
  )
  49. @final
  50. @dataclass
  51. class Neo4JStorage(BaseGraphStorage):
  def __init__(self, namespace, global_config, embedding_func, workspace=None):
      """Resolve the effective workspace and initialise the base storage.

      Workspace precedence: a non-blank ``NEO4J_WORKSPACE`` environment
      variable, then the ``workspace`` argument, then the literal ``"base"``.

      Args:
          namespace: Storage namespace, forwarded to the base class.
          global_config: Global configuration, forwarded to the base class.
          embedding_func: Embedding function, forwarded to the base class.
          workspace: Requested workspace; may be overridden by the environment.
      """
      requested_workspace = workspace  # kept only for the log message below
      env_workspace = os.environ.get("NEO4J_WORKSPACE")
      env_overrides = bool(env_workspace and env_workspace.strip())

      if env_overrides:
          workspace = env_workspace
      # Neither the argument nor the environment supplied a usable value.
      if not (workspace and str(workspace).strip()):
          workspace = "base"

      super().__init__(
          namespace=namespace,
          workspace=workspace,
          global_config=global_config,
          embedding_func=embedding_func,
      )

      # Logged only after the base initialiser has run.
      if env_overrides:
          logger.info(
              f"Using NEO4J_WORKSPACE environment variable: '{env_workspace}' (overriding '{requested_workspace}/{namespace}')"
          )

      # The driver is created lazily by initialize().
      self._driver = None
  73. def _get_workspace_label(self) -> str:
  74. """Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries.
  75. Escapes backticks by doubling them to prevent Cypher injection
  76. via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping
  77. for all other characters. The returned value is intended to be used
  78. inside backticks (for example, MATCH (n:`{label}`)) and is not
  79. validated as a standalone unquoted identifier.
  80. """
  81. workspace = self.workspace.strip()
  82. if not workspace:
  83. return "base"
  84. return workspace.replace("`", "``")
  85. def _normalize_index_suffix(self, workspace_label: str) -> str:
  86. """Normalize workspace label for safe use in index names."""
  87. normalized = re.sub(r"[^A-Za-z0-9_]+", "_", workspace_label).strip("_")
  88. if not normalized:
  89. normalized = "base"
  90. if not re.match(r"[A-Za-z_]", normalized[0]):
  91. normalized = f"ws_{normalized}"
  92. return normalized
  93. def _get_fulltext_index_name(self, workspace_label: str) -> str:
  94. """Return a full-text index name derived from the normalized workspace label."""
  95. suffix = self._normalize_index_suffix(workspace_label)
  96. return f"entity_id_fulltext_idx_{suffix}"
  97. def _is_chinese_text(self, text: str) -> bool:
  98. """Check if text contains Chinese/CJK characters.
  99. Covers:
  100. - CJK Unified Ideographs (U+4E00-U+9FFF)
  101. - CJK Extension A (U+3400-U+4DBF)
  102. - CJK Compatibility Ideographs (U+F900-U+FAFF)
  103. - CJK Extension B-F (U+20000-U+2FA1F) - supplementary planes
  104. """
  105. cjk_pattern = re.compile(
  106. r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]|[\U00020000-\U0002fa1f]"
  107. )
  108. return bool(cjk_pattern.search(text))
  async def initialize(self):
      """Open the Neo4j driver, select a usable database and ensure indexes.

      Each setting is read from its ``NEO4J_*`` environment variable, falling
      back to the ``[neo4j]`` section of ``config.ini`` and then to the
      default given below. The configured database is tried first; if it
      cannot be used, the driver's default database (``database=None``) is
      tried. On the first database that connects (or is created), an index
      and a full-text index on ``entity_id`` are ensured for this workspace's
      label.

      Raises:
          neo4jExceptions.ServiceUnavailable: The server cannot be reached.
          neo4jExceptions.AuthError: Authentication failed.
          neo4jExceptions.ClientError, neo4jExceptions.DatabaseError: Creating
              the database failed while already on the default-database
              attempt.
      """

      def setting(env_key, ini_key, fallback=None):
          # The environment variable wins; config.ini only supplies the default.
          return os.environ.get(
              env_key, config.get("neo4j", ini_key, fallback=fallback)
          )

      async with get_data_init_lock():
          uri = setting("NEO4J_URI", "uri")
          username = setting("NEO4J_USERNAME", "username")
          password = setting("NEO4J_PASSWORD", "password")
          max_connection_pool_size = int(
              setting("NEO4J_MAX_CONNECTION_POOL_SIZE", "connection_pool_size", 100)
          )
          connection_timeout = float(
              setting("NEO4J_CONNECTION_TIMEOUT", "connection_timeout", 30.0)
          )
          connection_acquisition_timeout = float(
              setting(
                  "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
                  "connection_acquisition_timeout",
                  30.0,
              )
          )
          max_transaction_retry_time = float(
              setting(
                  "NEO4J_MAX_TRANSACTION_RETRY_TIME",
                  "max_transaction_retry_time",
                  30.0,
              )
          )
          max_connection_lifetime = float(
              setting("NEO4J_MAX_CONNECTION_LIFETIME", "max_connection_lifetime", 300.0)
          )
          liveness_check_timeout = float(
              setting("NEO4J_LIVENESS_CHECK_TIMEOUT", "liveness_check_timeout", 30.0)
          )
          keep_alive = setting("NEO4J_KEEP_ALIVE", "keep_alive", "true").lower() in (
              "true",
              "1",
              "yes",
              "on",
          )
          # Deriving the default database name from the namespace is kept only
          # for compatibility with legacy practice.
          configured_database = os.environ.get(
              "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
          )

          self._driver: AsyncDriver = AsyncGraphDatabase.driver(
              uri,
              auth=(username, password),
              max_connection_pool_size=max_connection_pool_size,
              connection_timeout=connection_timeout,
              connection_acquisition_timeout=connection_acquisition_timeout,
              max_transaction_retry_time=max_transaction_retry_time,
              max_connection_lifetime=max_connection_lifetime,
              liveness_check_timeout=liveness_check_timeout,
              keep_alive=keep_alive,
          )

          # Try the configured database first, then the server default.
          for database in (configured_database, None):
              self._DATABASE = database
              connected = False
              try:
                  async with self._driver.session(database=database) as session:
                      try:
                          result = await session.run("MATCH (n) RETURN n LIMIT 0")
                          await result.consume()  # Ensure result is consumed
                          logger.info(
                              f"[{self.workspace}] Connected to {database} at {uri}"
                          )
                          connected = True
                      except neo4jExceptions.ServiceUnavailable:
                          logger.error(
                              f"[{self.workspace}] Database {database} at {uri} is not available"
                          )
                          raise
              except neo4jExceptions.AuthError:
                  logger.error(
                      f"[{self.workspace}] Authentication failed for {database} at {uri}"
                  )
                  raise
              except neo4jExceptions.ClientError as e:
                  if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                      logger.info(
                          f"[{self.workspace}] Database {database} at {uri} not found. Try to create specified database."
                      )
                      try:
                          async with self._driver.session() as session:
                              result = await session.run(
                                  f"CREATE DATABASE `{database}` IF NOT EXISTS"
                              )
                              await result.consume()  # Ensure result is consumed
                              logger.info(
                                  f"[{self.workspace}] Database {database} at {uri} created"
                              )
                              connected = True
                      except (
                          neo4jExceptions.ClientError,
                          neo4jExceptions.DatabaseError,
                      ) as create_error:
                          unsupported = create_error.code in (
                              "Neo.ClientError.Statement.UnsupportedAdministrationCommand",
                              "Neo.DatabaseError.Statement.ExecutionFailed",
                          )
                          if unsupported and database is not None:
                              logger.warning(
                                  f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                              )
                          if database is None:
                              # Nothing left to fall back to.
                              logger.error(
                                  f"[{self.workspace}] Failed to create {database} at {uri}"
                              )
                              raise
                  else:
                      # Previously swallowed silently; still non-fatal so the
                      # fallback database is tried, but now visible in logs.
                      logger.warning(
                          f"[{self.workspace}] Unexpected client error for database {database} at {uri}: {str(e)}"
                      )

              if connected:
                  workspace_label = self._get_workspace_label()
                  # Create an index on entity_id for faster lookups
                  try:
                      async with self._driver.session(database=database) as session:
                          result = await session.run(
                              f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
                          )
                          # Consume explicitly so that a failure is raised
                          # here rather than while the session is closing.
                          await result.consume()
                          logger.info(
                              f"[{self.workspace}] Ensured B-Tree index on entity_id for {workspace_label} in {database}"
                          )
                  except Exception as e:
                      # Best effort: queries still work without the index.
                      logger.warning(
                          f"[{self.workspace}] Failed to create B-Tree index: {str(e)}"
                      )
                  # Create full-text index for entity_id for faster text searches
                  await self._create_fulltext_index(
                      self._driver, self._DATABASE, workspace_label
                  )
                  break
          else:
              # No candidate database could be used and nothing raised.
              logger.error(
                  f"[{self.workspace}] Could not connect to any Neo4j database at {uri}"
              )
  async def _create_fulltext_index(
      self, driver: AsyncDriver, database: str, workspace_label: str
  ):
      """Ensure a full-text index on ``entity_id`` exists for the workspace.

      Steps:
      1. List the full-text indexes and look for this workspace's index and
         for the legacy shared index ``entity_id_fulltext_idx``.
      2. If only the legacy index exists, drop it.
      3. If the workspace index exists and is ``ONLINE``, do nothing.
      4. Otherwise drop a non-online index and create a new one, first with
         the ``cjk`` analyzer and, if that fails, with the default analyzer.

      All failures are logged and never raised to the caller.

      Args:
          driver: Open async driver used to create the session.
          database: Database to operate on.
          workspace_label: Backtick-escaped workspace label.
      """
      index_name = self._get_fulltext_index_name(workspace_label)
      legacy_index_name = "entity_id_fulltext_idx"
      try:
          async with driver.session(database=database) as session:
              # Look up the current and the legacy index in one listing.
              listing = await session.run("SHOW FULLTEXT INDEXES")
              indexes = await listing.data()
              await listing.consume()

              existing_index = None
              legacy_index = None
              for entry in indexes:
                  if entry["name"] == index_name:
                      existing_index = entry
                  elif entry["name"] == legacy_index_name:
                      legacy_index = entry
                  if existing_index and legacy_index:
                      break  # both found, nothing more to learn

              # Legacy migration: the shared index is replaced by a
              # per-workspace one.
              if legacy_index and not existing_index:
                  logger.info(
                      f"[{self.workspace}] Found legacy index '{legacy_index_name}'. Migrating to '{index_name}'."
                  )
                  try:
                      dropped = await session.run(
                          f"DROP INDEX {legacy_index_name} IF EXISTS"
                      )
                      await dropped.consume()
                      logger.info(
                          f"[{self.workspace}] Dropped legacy index '{legacy_index_name}'"
                      )
                  except Exception as drop_error:
                      logger.warning(
                          f"[{self.workspace}] Failed to drop legacy index: {str(drop_error)}"
                      )

              if not existing_index:
                  logger.info(
                      f"[{self.workspace}] No existing index '{index_name}' found. Creating new index."
                  )
              else:
                  index_state = existing_index.get("state", "UNKNOWN")
                  logger.info(
                      f"[{self.workspace}] Found existing index '{index_name}' with state: {index_state}"
                  )
                  if index_state == "ONLINE":
                      logger.info(
                          f"[{self.workspace}] Full-text index '{index_name}' already exists and is online. Skipping recreation."
                      )
                      return
                  logger.warning(
                      f"[{self.workspace}] Existing index '{index_name}' is not online (state: {index_state}). Will recreate."
                  )
                  # A present but unusable index must go before re-creation.
                  try:
                      dropped = await session.run(
                          f"DROP INDEX {index_name} IF EXISTS"
                      )
                      await dropped.consume()
                      logger.info(
                          f"[{self.workspace}] Dropped existing index '{index_name}'"
                      )
                  except Exception as drop_error:
                      logger.warning(
                          f"[{self.workspace}] Failed to drop existing index: {str(drop_error)}"
                      )

              # From here on the index has to be (re)created.
              logger.info(
                  f"[{self.workspace}] Creating full-text index '{index_name}' with Chinese tokenizer support."
              )
              try:
                  cjk_query = f"""
                  CREATE FULLTEXT INDEX {index_name}
                  FOR (n:`{workspace_label}`) ON EACH [n.entity_id]
                  OPTIONS {{
                      indexConfig: {{
                          `fulltext.analyzer`: 'cjk',
                          `fulltext.eventually_consistent`: true
                      }}
                  }}
                  """
                  created = await session.run(cjk_query)
                  await created.consume()
                  logger.info(
                      f"[{self.workspace}] Successfully created full-text index '{index_name}' with CJK analyzer."
                  )
              except Exception as cjk_error:
                  # Second attempt without analyzer options.
                  logger.warning(
                      f"[{self.workspace}] CJK analyzer not supported: {str(cjk_error)}. "
                      "Falling back to standard analyzer."
                  )
                  standard_query = f"""
                  CREATE FULLTEXT INDEX {index_name}
                  FOR (n:`{workspace_label}`) ON EACH [n.entity_id]
                  """
                  created = await session.run(standard_query)
                  await created.consume()
                  logger.info(
                      f"[{self.workspace}] Successfully created full-text index '{index_name}' with standard analyzer."
                  )
      except Exception as e:
          message = str(e)
          # Servers without full-text index support reject the commands above.
          if "Unknown command" in message or "invalid syntax" in message.lower():
              logger.warning(
                  f"[{self.workspace}] Could not create or verify full-text index '{index_name}'. "
                  "This might be because you are using a Neo4j version that does not support it. "
                  "Search functionality will fall back to slower, non-indexed queries."
              )
          else:
              logger.error(
                  f"[{self.workspace}] Failed to create or verify full-text index '{index_name}': {str(e)}"
              )
  381. async def finalize(self):
  382. """Close the Neo4j driver and release all resources"""
  383. if self._driver:
  384. await self._driver.close()
  385. self._driver = None
  386. async def __aexit__(self, exc_type, exc, tb):
  387. """Ensure driver is closed when context manager exits"""
  388. await self.finalize()
  389. async def index_done_callback(self) -> None:
  390. # Neo4J handles persistence automatically
  391. pass
  @READ_RETRY
  async def has_node(self, node_id: str) -> bool:
      """Report whether a node with the given ``entity_id`` exists in the workspace.

      Args:
          node_id: ``entity_id`` value of the node to look for.

      Returns:
          bool: True if at least one matching node exists, False otherwise.

      Raises:
          Exception: Any error from the query is logged and re-raised.
      """
      label = self._get_workspace_label()
      cypher = f"MATCH (n:`{label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          result = None  # stays None if session.run() itself fails
          try:
              result = await session.run(cypher, entity_id=node_id)
              row = await result.single()
              await result.consume()  # drain before the session closes
              return row["node_exists"]
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
              )
              if result is not None:
                  # Drain whatever is left even though the query failed.
                  await result.consume()
              raise
  @READ_RETRY
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
      """Report whether any relationship connects the two nodes.

      The match is undirected, so the order of the two ids does not matter.

      Args:
          source_node_id: ``entity_id`` of one endpoint.
          target_node_id: ``entity_id`` of the other endpoint.

      Returns:
          bool: True if at least one relationship exists, False otherwise.

      Raises:
          Exception: Any error from the query is logged and re-raised.
      """
      label = self._get_workspace_label()
      cypher = (
          f"MATCH (a:`{label}` {{entity_id: $source_entity_id}})-[r]-(b:`{label}` {{entity_id: $target_entity_id}}) "
          "RETURN COUNT(r) > 0 AS edgeExists"
      )
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          result = None  # stays None if session.run() itself fails
          try:
              result = await session.run(
                  cypher,
                  source_entity_id=source_node_id,
                  target_entity_id=target_node_id,
              )
              row = await result.single()
              await result.consume()  # drain before the session closes
              return row["edgeExists"]
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
              )
              if result is not None:
                  # Drain whatever is left even though the query failed.
                  await result.consume()
              raise
  @READ_RETRY
  async def get_node(self, node_id: str) -> dict[str, str] | None:
      """Fetch the properties of the node with the given ``entity_id``.

      At most two records are fetched so that duplicates can be detected;
      when more than one node matches, a warning is logged and the first one
      is used.

      Args:
          node_id: ``entity_id`` value of the node to look up.

      Returns:
          dict: The node's properties, with the workspace label removed from
              a ``labels`` property if the node has one.
          None: If no node matches.

      Raises:
          Exception: Any error from the query is logged and re-raised.
      """
      label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          try:
              result = await session.run(
                  f"MATCH (n:`{label}` {{entity_id: $entity_id}}) RETURN n",
                  entity_id=node_id,
              )
              try:
                  matches = await result.fetch(2)  # two are enough to spot duplicates
                  if not matches:
                      return None
                  if len(matches) > 1:
                      logger.warning(
                          f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
                      )
                  properties = dict(matches[0]["n"])
                  if "labels" in properties:
                      # Hide the internal workspace label from callers.
                      properties["labels"] = [
                          name for name in properties["labels"] if name != label
                      ]
                  return properties
              finally:
                  await result.consume()  # always drain the result
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
              )
              raise
  @READ_RETRY
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
      """Fetch several nodes with a single ``UNWIND`` query.

      Args:
          node_ids: ``entity_id`` values of the nodes to fetch.

      Returns:
          A dictionary mapping each found ``entity_id`` to the node's
          properties. Ids without a matching node are absent from the result.
      """
      label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cypher = f"""
              UNWIND $node_ids AS id
              MATCH (n:`{label}` {{entity_id: id}})
              RETURN n.entity_id AS entity_id, n
              """
          result = await session.run(cypher, node_ids=node_ids)
          found = {}
          async for row in result:
              key = row["entity_id"]
              properties = dict(row["n"])
              if "labels" in properties:
                  # Hide the internal workspace label from callers.
                  properties["labels"] = [
                      name for name in properties["labels"] if name != label
                  ]
              found[key] = properties
          await result.consume()  # drain before the session closes
          return found
  @READ_RETRY
  async def node_degree(self, node_id: str) -> int:
      """Count the relationships attached to the node with the given ``entity_id``.

      Args:
          node_id: ``entity_id`` value of the node.

      Returns:
          int: The ``degree`` value returned by the query, or 0 when the
              query yields no record.

      Raises:
          Exception: Any error from the query is logged and re-raised.
      """
      label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          try:
              cypher = f"""
                  MATCH (n:`{label}` {{entity_id: $entity_id}})
                  OPTIONAL MATCH (n)-[r]-()
                  RETURN COUNT(r) AS degree
              """
              result = await session.run(cypher, entity_id=node_id)
              try:
                  row = await result.single()
                  if row:
                      return row["degree"]
                  logger.warning(
                      f"[{self.workspace}] No node found with label '{node_id}'"
                  )
                  return 0
              finally:
                  await result.consume()  # always drain the result
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
              )
              raise
  @READ_RETRY
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
      """Fetch the degree of several nodes with a single ``UNWIND`` query.

      Args:
          node_ids: ``entity_id`` values of the nodes to look up.

      Returns:
          A dictionary mapping every requested id to its number of
          relationships. Ids without a matching node are logged and mapped
          to 0.
      """
      label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cypher = f"""
                  UNWIND $node_ids AS id
                  MATCH (n:`{label}` {{entity_id: id}})
                  RETURN n.entity_id AS entity_id, count {{ (n)--() }} AS degree;
              """
          result = await session.run(cypher, node_ids=node_ids)
          degrees = {row["entity_id"]: row["degree"] async for row in result}
          await result.consume()  # drain before the session closes

          # Requested ids the query did not return count as degree 0.
          for requested in node_ids:
              if requested in degrees:
                  continue
              logger.warning(
                  f"[{self.workspace}] No node found with label '{requested}'"
              )
              degrees[requested] = 0
          return degrees
  620. async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  621. """Get the total degree (sum of relationships) of two nodes.
  622. Args:
  623. src_id: Label of the source node
  624. tgt_id: Label of the target node
  625. Returns:
  626. int: Sum of the degrees of both nodes
  627. """
  628. src_degree = await self.node_degree(src_id)
  629. trg_degree = await self.node_degree(tgt_id)
  630. # Convert None to 0 for addition
  631. src_degree = 0 if src_degree is None else src_degree
  632. trg_degree = 0 if trg_degree is None else trg_degree
  633. degrees = int(src_degree) + int(trg_degree)
  634. return degrees
  @READ_RETRY
  async def edge_degrees_batch(
      self, edge_pairs: list[tuple[str, str]]
  ) -> dict[tuple[str, str], int]:
      """Compute the combined endpoint degree for many edges at once.

      All distinct endpoint ids are resolved with one call to
      ``node_degrees_batch``; each edge's value is the sum of its two
      endpoints' degrees.

      Args:
          edge_pairs: List of ``(src, tgt)`` tuples.

      Returns:
          A dictionary mapping each ``(src, tgt)`` tuple to the sum of the
          two node degrees (a missing degree counts as 0).
      """
      # Every endpoint is looked up only once, however often it occurs.
      endpoint_ids = set()
      endpoint_ids.update(src for src, _ in edge_pairs)
      endpoint_ids.update(tgt for _, tgt in edge_pairs)

      degree_of = await self.node_degrees_batch(list(endpoint_ids))

      return {
          (src, tgt): degree_of.get(src, 0) + degree_of.get(tgt, 0)
          for src, tgt in edge_pairs
      }
  @READ_RETRY
  async def get_edge(
      self, source_node_id: str, target_node_id: str
  ) -> dict[str, str] | None:
      """Get edge properties between two nodes.

      The relationship is matched in either direction between the two
      workspace nodes. At most two rows are fetched; when more than one
      relationship exists a warning is logged and the first one is used.

      Args:
          source_node_id: entity_id of the source node
          target_node_id: entity_id of the target node

      Returns:
          dict | None: The edge properties, completed with defaults for any of
          ``weight``, ``source_id``, ``description`` and ``keywords`` that are
          missing. ``None`` when no relationship is found. If the returned
          properties cannot be converted to a dict, the error is logged and
          the default properties are returned instead.

      Raises:
          Exception: Any error raised while opening the session or running the
              query is logged and re-raised.
      """
      # Keys every returned edge dict must carry, with their fallback values
      edge_defaults = {
          "weight": 1.0,
          "source_id": None,
          "description": None,
          "keywords": None,
      }
      workspace_label = self._get_workspace_label()
      try:
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              query = f"""
              MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
              RETURN properties(r) as edge_properties
              """
              result = await session.run(
                  query,
                  source_entity_id=source_node_id,
                  target_entity_id=target_node_id,
              )
              try:
                  # Two rows are enough to tell "one edge" from "several"
                  records = await result.fetch(2)
                  if not records:
                      # No relationship between the two nodes
                      return None

                  if len(records) > 1:
                      logger.warning(
                          f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                      )

                  try:
                      edge_result = dict(records[0]["edge_properties"])
                  except (KeyError, TypeError, ValueError) as e:
                      logger.error(
                          f"[{self.workspace}] Error processing edge properties between {source_node_id} "
                          f"and {target_node_id}: {str(e)}"
                      )
                      # Fall back to a fresh copy of the default properties
                      return dict(edge_defaults)

                  # Ensure required keys exist with defaults
                  for key, default_value in edge_defaults.items():
                      if key not in edge_result:
                          edge_result[key] = default_value
                          logger.warning(
                              f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} "
                              f"missing {key}, using default: {default_value}"
                          )
                  return edge_result
              finally:
                  await result.consume()  # Ensure result is fully consumed
      except Exception as e:
          logger.error(
              f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
          )
          raise
  @READ_RETRY
  async def get_edges_batch(
      self, pairs: list[dict[str, str]]
  ) -> dict[tuple[str, str], dict]:
      """
      Retrieve edge properties for multiple (src, tgt) pairs in one query.

      Args:
          pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

      Returns:
          A dictionary mapping (src, tgt) tuples to their edge properties, with
          defaults filled in for missing ``weight``, ``source_id``,
          ``description`` and ``keywords`` keys. When several ``DIRECTED``
          relationships connect a pair, the first collected one is used.
          The query uses a plain ``MATCH``, so a pair without a matching
          relationship produces no row and is absent from the result.
          An empty ``pairs`` list yields an empty dictionary without querying.
      """
      # Nothing to look up: skip the database round-trip
      if not pairs:
          return {}

      # Keys every returned edge dict must carry, with their fallback values
      edge_defaults = {
          "weight": 1.0,
          "source_id": None,
          "description": None,
          "keywords": None,
      }
      workspace_label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          query = f"""
          UNWIND $pairs AS pair
          MATCH (start:`{workspace_label}` {{entity_id: pair.src}})-[r:DIRECTED]-(end:`{workspace_label}` {{entity_id: pair.tgt}})
          RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
          """
          result = await session.run(query, pairs=pairs)
          edges_dict = {}
          async for record in result:
              edges = record["edges"]
              if edges:
                  # Choose the first if multiple exist, then backfill defaults
                  edge_props = edges[0]
                  for key, default in edge_defaults.items():
                      edge_props.setdefault(key, default)
              else:
                  # Row without any collected edge: use default edge properties
                  edge_props = dict(edge_defaults)
              edges_dict[(record["src_id"], record["tgt_id"])] = edge_props
          await result.consume()
          return edges_dict
  @READ_RETRY
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
      """Collect the edges touching one node as pairs of entity_ids.

      The node is matched by ``entity_id`` inside the workspace label and its
      relationships are followed in both directions to other workspace nodes.

      Args:
          source_node_id: entity_id of the node whose edges are wanted

      Returns:
          list[tuple[str, str]]: One (queried entity_id, connected entity_id)
          tuple per returned row in which both ids are truthy. The list is
          empty when no such row exists.

      Raises:
          Exception: Any error is logged (by both the inner and the outer
              handler) and re-raised.
      """
      try:
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              results = None
              try:
                  workspace_label = self._get_workspace_label()
                  query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
                          OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
                          WHERE connected.entity_id IS NOT NULL
                          RETURN n, r, connected"""
                  results = await session.run(query, entity_id=source_node_id)

                  edges = []
                  async for record in results:
                      origin = record["n"]
                      neighbour = record["connected"]
                      # Rows without a connected node come from the OPTIONAL MATCH
                      if origin and neighbour:
                          origin_id = origin.get("entity_id")
                          neighbour_id = neighbour.get("entity_id")
                          if origin_id and neighbour_id:
                              edges.append((origin_id, neighbour_id))

                  await results.consume()  # Ensure results are consumed
                  return edges
              except Exception as e:
                  logger.error(
                      f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
                  )
                  # Release the result even when processing failed
                  if results is not None:
                      await results.consume()
                  raise
      except Exception as e:
          logger.error(
              f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
          )
          raise
  @READ_RETRY
  async def get_nodes_edges_batch(
      self, node_ids: list[str]
  ) -> dict[str, list[tuple[str, str]]]:
      """
      Batch retrieve edges for multiple nodes in one query using UNWIND.

      For each node, relationships are followed in both directions and each
      one is reported in its stored direction.

      Args:
          node_ids: List of node IDs (entity_id) for which to retrieve edges.

      Returns:
          A dictionary mapping each node ID to its list of edge tuples (source, target).
          For each node, the list includes both:
          - Outgoing edges: (queried_node, connected_node)
          - Incoming edges: (connected_node, queried_node)
          Every requested ID is present as a key, with an empty list when no
          edge was returned for it. An empty ``node_ids`` yields an empty
          dictionary without querying.
      """
      # Nothing to look up: skip the database round-trip
      if not node_ids:
          return {}

      workspace_label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          # Query to get both outgoing and incoming edges
          query = f"""
          UNWIND $node_ids AS id
          MATCH (n:`{workspace_label}` {{entity_id: id}})
          OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
          RETURN id AS queried_id, n.entity_id AS node_entity_id,
                 connected.entity_id AS connected_entity_id,
                 startNode(r).entity_id AS start_entity_id
          """
          result = await session.run(query, node_ids=node_ids)

          # Start every requested node with its own empty edge list
          edges_dict = {node_id: [] for node_id in node_ids}

          async for record in result:
              node_entity_id = record["node_entity_id"]
              connected_entity_id = record["connected_entity_id"]

              # Skip rows where either endpoint is missing (e.g. no relationship)
              if not node_entity_id or not connected_entity_id:
                  continue

              # The relationship's start node decides the reported direction
              if record["start_entity_id"] == node_entity_id:
                  # Outgoing edge: (queried_node -> connected_node)
                  edge = (node_entity_id, connected_entity_id)
              else:
                  # Incoming edge: (connected_node -> queried_node)
                  edge = (connected_entity_id, node_entity_id)
              edges_dict[record["queried_id"]].append(edge)

          await result.consume()  # Ensure results are fully consumed
          return edges_dict
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
      """
      Create or update one node in the workspace.

      The node is merged on ``entity_id = node_id`` under the workspace label
      and ``node_data`` is then added onto its properties.

      Args:
          node_id: Value matched against the node's entity_id
          node_data: Properties to set; must contain an 'entity_id' key

      Raises:
          ValueError: If ``node_data`` has no 'entity_id' key
          Exception: Any error from the write is logged and re-raised
      """
      workspace_label = self._get_workspace_label()
      if "entity_id" not in node_data:
          raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

      merge_query = f"""
      MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
      SET n += $properties
      """

      async def _merge_node(tx: AsyncManagedTransaction):
          outcome = await tx.run(
              merge_query, entity_id=node_id, properties=node_data
          )
          await outcome.consume()  # Ensure result is fully consumed

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_merge_node)
      except Exception as e:
          logger.error(f"[{self.workspace}] Error during upsert: {str(e)}")
          raise
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
      """Create or update many nodes with a single UNWIND Cypher query.

      All merges are sent in one write transaction instead of one call per
      node. An empty ``nodes`` list is a no-op.

      Args:
          nodes: List of (node_id, node_data) tuples.

      Raises:
          ValueError: If any ``node_data`` has no 'entity_id' key
          Exception: Any error from the write is logged and re-raised
      """
      if not nodes:
          return

      workspace_label = self._get_workspace_label()

      # Validate and reshape each tuple into the row format the query unwinds
      rows = []
      for node_id, node_data in nodes:
          if "entity_id" not in node_data:
              raise ValueError(
                  "Neo4j: node properties must contain an 'entity_id' field"
              )
          rows.append(dict(entity_id=node_id, props=node_data))

      batch_query = f"""
      UNWIND $nodes AS row
      MERGE (n:`{workspace_label}` {{entity_id: row.entity_id}})
      SET n += row.props
      """

      async def _merge_rows(tx: AsyncManagedTransaction):
          outcome = await tx.run(batch_query, nodes=rows)
          await outcome.consume()

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_merge_rows)
      except Exception as e:
          logger.error(f"[{self.workspace}] Error during batch node upsert: {str(e)}")
          raise
  @READ_RETRY
  async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
      """Report which of the given nodes exist, using one UNWIND query.

      Args:
          node_ids: List of node IDs (entity_id values) to check.

      Returns:
          Set of the entity_ids returned by the query, i.e. those found in the
          workspace. An empty input returns an empty set without querying.

      Raises:
          Exception: Any error from the query is logged and re-raised
      """
      if not node_ids:
          return set()

      workspace_label = self._get_workspace_label()
      existence_query = f"""
      UNWIND $ids AS id
      MATCH (n:`{workspace_label}` {{entity_id: id}})
      RETURN n.entity_id AS entity_id
      """
      try:
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              result = await session.run(existence_query, ids=node_ids)
              rows = await result.data()
              await result.consume()
              return {row["entity_id"] for row in rows}
      except Exception as e:
          logger.error(
              f"[{self.workspace}] Error during batch node existence check: {str(e)}"
          )
          raise
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def upsert_edges_batch(
      self, edges: list[tuple[str, str, dict[str, str]]]
  ) -> None:
      """Create or update many edges with a single UNWIND Cypher query.

      For every row both endpoints are matched by entity_id inside the
      workspace label, a ``DIRECTED`` relationship between them is merged and
      the row's properties are added onto it. An empty list is a no-op.

      Args:
          edges: List of (source_node_id, target_node_id, edge_data) tuples.

      Raises:
          Exception: Any error from the write is logged and re-raised
      """
      if not edges:
          return

      workspace_label = self._get_workspace_label()

      # Reshape each tuple into the row format the query unwinds
      rows = []
      for src, tgt, edge_data in edges:
          rows.append(dict(src=src, tgt=tgt, props=edge_data))

      batch_query = f"""
      UNWIND $edges AS row
      MATCH (source:`{workspace_label}` {{entity_id: row.src}})
      WITH source, row
      MATCH (target:`{workspace_label}` {{entity_id: row.tgt}})
      MERGE (source)-[r:DIRECTED]-(target)
      SET r += row.props
      """

      async def _merge_rows(tx: AsyncManagedTransaction):
          outcome = await tx.run(batch_query, edges=rows)
          await outcome.consume()

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_merge_rows)
      except Exception as e:
          logger.error(f"[{self.workspace}] Error during batch edge upsert: {str(e)}")
          raise
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def upsert_edge(
      self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  ) -> None:
      """
      Create or update the edge between two existing nodes.

      Both endpoints are matched by entity_id inside the workspace label, a
      ``DIRECTED`` relationship between them is merged and ``edge_data`` is
      added onto its properties. The query only MATCHes the endpoints, so it
      does not create missing nodes.

      Args:
          source_node_id (str): entity_id of the source node
          target_node_id (str): entity_id of the target node
          edge_data (dict): Dictionary of properties to set on the edge

      Raises:
          Exception: Any error from the write is logged and re-raised
      """

      async def _merge_edge(tx: AsyncManagedTransaction):
          workspace_label = self._get_workspace_label()
          merge_query = f"""
          MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
          WITH source
          MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
          MERGE (source)-[r:DIRECTED]-(target)
          SET r += $properties
          RETURN r, source, target
          """
          outcome = await tx.run(
              merge_query,
              source_entity_id=source_node_id,
              target_entity_id=target_node_id,
              properties=edge_data,
          )
          try:
              # The fetched rows themselves are not used
              await outcome.fetch(2)
          finally:
              await outcome.consume()  # Ensure result is consumed

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_merge_edge)
      except Exception as e:
          logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}")
          raise
  async def get_knowledge_graph(
      self,
      node_label: str,
      max_depth: int = 3,
      max_nodes: int | None = None,
  ) -> KnowledgeGraph:
      """
      Retrieve a connected subgraph starting from the node whose entity_id is `node_label`.

      With ``node_label == "*"`` the highest-degree nodes of the workspace are
      returned together with the relationships among them. Otherwise the
      subgraph is expanded from the matching node with
      ``apoc.path.subgraphAll``, first without a node limit and, if that
      exceeds ``max_nodes``, again with the limit applied. If a query raises a
      Neo4j ``ClientError`` for a specific label, the search is delegated to
      ``self._robust_fallback``; for the wildcard the graph collected so far
      is returned.

      Args:
          node_label: entity_id of the starting node, * means all nodes
          max_depth: Maximum depth of the subgraph, defaults to 3
          max_nodes: Maximum number of nodes to return. When omitted, the
              ``max_graph_nodes`` entry of ``global_config`` is used (1000 if
              unset); an explicit value is capped at that same setting.

      Returns:
          KnowledgeGraph object containing nodes and edges, with an is_truncated flag
          indicating whether the graph was truncated due to max_nodes limit
      """
      # The configured ceiling is both the default and the upper bound
      configured_limit = self.global_config.get("max_graph_nodes", 1000)
      if max_nodes is None:
          max_nodes = configured_limit
      else:
          max_nodes = min(max_nodes, configured_limit)

      workspace_label = self._get_workspace_label()
      result = KnowledgeGraph()
      seen_nodes = set()
      seen_edges = set()

      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          try:
              if node_label == "*":
                  # First check total node count to determine if graph is truncated
                  count_query = (
                      f"MATCH (n:`{workspace_label}`) RETURN count(n) as total"
                  )
                  count_result = None
                  try:
                      count_result = await session.run(count_query)
                      count_record = await count_result.single()

                      if count_record and count_record["total"] > max_nodes:
                          result.is_truncated = True
                          logger.info(
                              f"[{self.workspace}] Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                          )
                  finally:
                      if count_result is not None:
                          await count_result.consume()

                  # Run main query to get nodes with highest degree
                  main_query = f"""
                  MATCH (n:`{workspace_label}`)
                  OPTIONAL MATCH (n)-[r]-()
                  WITH n, COALESCE(count(r), 0) AS degree
                  ORDER BY degree DESC
                  LIMIT $max_nodes
                  WITH collect({{node: n}}) AS filtered_nodes
                  UNWIND filtered_nodes AS node_info
                  WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                  OPTIONAL MATCH (a)-[r]-(b)
                  WHERE a IN kept_nodes AND b IN kept_nodes
                  RETURN filtered_nodes AS node_info,
                         collect(DISTINCT r) AS relationships
                  """
                  result_set = None
                  try:
                      result_set = await session.run(
                          main_query,
                          {"max_nodes": max_nodes},
                      )
                      record = await result_set.single()
                  finally:
                      if result_set is not None:
                          await result_set.consume()

              else:
                  # First try without limit to check if we need to truncate
                  full_query = f"""
                  MATCH (start:`{workspace_label}`)
                  WHERE start.entity_id = $entity_id
                  WITH start
                  CALL apoc.path.subgraphAll(start, {{
                      relationshipFilter: '',
                      labelFilter: '{workspace_label}',
                      minLevel: 0,
                      maxLevel: $max_depth,
                      bfs: true
                  }})
                  YIELD nodes, relationships
                  WITH nodes, relationships, size(nodes) AS total_nodes
                  UNWIND nodes AS node
                  WITH collect({{node: node}}) AS node_info, relationships, total_nodes
                  RETURN node_info, relationships, total_nodes
                  """

                  # Try to get full result
                  full_result = None
                  try:
                      full_result = await session.run(
                          full_query,
                          {
                              "entity_id": node_label,
                              "max_depth": max_depth,
                          },
                      )
                      full_record = await full_result.single()

                      # If no record found, return empty KnowledgeGraph
                      if not full_record:
                          logger.debug(
                              f"[{self.workspace}] No nodes found for entity_id: {node_label}"
                          )
                          return result

                      # If record found, check node count
                      total_nodes = full_record["total_nodes"]

                      if total_nodes <= max_nodes:
                          # If node count is within limit, use full result directly
                          logger.debug(
                              f"[{self.workspace}] Using full result with {total_nodes} nodes (no truncation needed)"
                          )
                          record = full_record
                      else:
                          # If node count exceeds limit, set truncated flag and run limited query
                          result.is_truncated = True
                          logger.info(
                              f"[{self.workspace}] Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
                          )

                          # Run limited query
                          limited_query = f"""
                          MATCH (start:`{workspace_label}`)
                          WHERE start.entity_id = $entity_id
                          WITH start
                          CALL apoc.path.subgraphAll(start, {{
                              relationshipFilter: '',
                              labelFilter: '{workspace_label}',
                              minLevel: 0,
                              maxLevel: $max_depth,
                              limit: $max_nodes,
                              bfs: true
                          }})
                          YIELD nodes, relationships
                          UNWIND nodes AS node
                          WITH collect({{node: node}}) AS node_info, relationships
                          RETURN node_info, relationships
                          """
                          result_set = None
                          try:
                              result_set = await session.run(
                                  limited_query,
                                  {
                                      "entity_id": node_label,
                                      "max_depth": max_depth,
                                      "max_nodes": max_nodes,
                                  },
                              )
                              record = await result_set.single()
                          finally:
                              if result_set is not None:
                                  await result_set.consume()
                  finally:
                      if full_result is not None:
                          await full_result.consume()

              if record:
                  # Handle nodes (compatible with multi-label cases)
                  for node_info in record["node_info"]:
                      node = node_info["node"]
                      node_id = node.id
                      if node_id not in seen_nodes:
                          result.nodes.append(
                              KnowledgeGraphNode(
                                  id=f"{node_id}",
                                  labels=[node.get("entity_id")],
                                  properties=dict(node),
                              )
                          )
                          seen_nodes.add(node_id)

                  # Handle relationships (including direction information)
                  for rel in record["relationships"]:
                      edge_id = rel.id
                      if edge_id not in seen_edges:
                          start = rel.start_node
                          end = rel.end_node
                          result.edges.append(
                              KnowledgeGraphEdge(
                                  id=f"{edge_id}",
                                  type=rel.type,
                                  source=f"{start.id}",
                                  target=f"{end.id}",
                                  properties=dict(rel),
                              )
                          )
                          seen_edges.add(edge_id)

                  logger.info(
                      f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                  )

          except neo4jExceptions.ClientError as e:
              logger.warning(f"[{self.workspace}] APOC plugin error: {str(e)}")
              if node_label != "*":
                  logger.warning(
                      f"[{self.workspace}] Neo4j: falling back to basic Cypher recursive search..."
                  )
                  return await self._robust_fallback(node_label, max_depth, max_nodes)
              else:
                  logger.warning(
                      f"[{self.workspace}] Neo4j: APOC plugin error with wildcard query, returning empty result"
                  )

      return result
  async def _robust_fallback(
      self, node_label: str, max_depth: int, max_nodes: int
  ) -> KnowledgeGraph:
      """
      Build the subgraph around ``node_label`` with plain Cypher only.

      Used by get_knowledge_graph when its APOC-based query fails. The graph is
      explored breadth-first from the start node with one neighbour query per
      visited node, down to ``max_depth`` and up to ``max_nodes`` nodes.
      Nodes and edges in the result are identified by entity_id and by the
      stringified relationship id respectively. At most 1000 neighbour rows
      are read per node.

      Args:
          node_label: entity_id of the starting node
          max_depth: Deepest level whose nodes are still added to the result
          max_nodes: Node count at which the search stops and the result is
              flagged as truncated

      Returns:
          KnowledgeGraph with the collected nodes and edges; empty when the
          start node is not found
      """
      from collections import deque

      result = KnowledgeGraph()
      visited_nodes = set()
      visited_edges = set()
      visited_edge_pairs = set()

      # Fetch the starting node
      workspace_label = self._get_workspace_label()
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          query = f"""
          MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
          RETURN id(n) as node_id, n
          """
          node_result = await session.run(query, entity_id=node_label)
          try:
              node_record = await node_result.single()
              if not node_record:
                  return result

              found = node_record["n"]
              start_node = KnowledgeGraphNode(
                  id=f"{found.get('entity_id')}",
                  labels=[found.get("entity_id")],
                  properties=dict(found._properties),
              )
          finally:
              await node_result.consume()  # Ensure results are consumed

      # BFS queue of (node, edge, depth); edge is None for the starting node
      queue = deque([(start_node, None, 0)])

      while queue and len(visited_nodes) < max_nodes:
          current_node, current_edge, current_depth = queue.popleft()

          # A node can be queued several times before its first visit
          if current_node.id in visited_nodes:
              continue

          if current_depth > max_depth:
              logger.debug(
                  f"[{self.workspace}] Skipping node at depth {current_depth} (max_depth: {max_depth})"
              )
              continue

          result.nodes.append(current_node)
          visited_nodes.add(current_node.id)

          # Record the edge travelling with the node, if any and not yet recorded
          if current_edge and current_edge.id not in visited_edges:
              result.edges.append(current_edge)
              visited_edges.add(current_edge.id)

          # Stop if we've reached the node limit
          if len(visited_nodes) >= max_nodes:
              result.is_truncated = True
              logger.info(
                  f"[{self.workspace}] Graph truncated: breadth-first search limited to: {max_nodes} nodes"
              )
              break

          # Neighbours are fetched even at max_depth so edges to known nodes are kept
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              workspace_label = self._get_workspace_label()
              query = f"""
              MATCH (a:`{workspace_label}` {{entity_id: $entity_id}})-[r]-(b)
              WITH r, b, id(r) as edge_id, id(b) as target_id
              RETURN r, b, edge_id, target_id
              """
              results = await session.run(query, entity_id=current_node.id)

              # Read the neighbour rows up front, then release the result
              records = await results.fetch(1000)  # Max neighbor nodes we can handle
              await results.consume()  # Ensure results are consumed

              can_go_deeper = current_depth < max_depth
              for record in records:
                  rel = record["r"]
                  edge_id = str(record["edge_id"])
                  if edge_id in visited_edges:
                      continue

                  b_node = record["b"]
                  target_id = b_node.get("entity_id")
                  if not target_id:
                      logger.warning(
                          f"[{self.workspace}] Skipping edge {edge_id} due to missing entity_id on target node"
                      )
                      continue

                  target_node = KnowledgeGraphNode(
                      id=f"{target_id}",
                      labels=[target_id],
                      properties=dict(b_node._properties),
                  )
                  target_edge = KnowledgeGraphEdge(
                      id=f"{edge_id}",
                      type=rel.type,
                      source=f"{current_node.id}",
                      target=f"{target_id}",
                      properties=dict(rel),
                  )

                  # Order-independent key so (A,B) and (B,A) count as the same edge
                  sorted_pair = tuple(sorted([current_node.id, target_id]))
                  target_seen = target_id in visited_nodes

                  # Keep the edge only if its target is in the result or may still be added
                  if sorted_pair not in visited_edge_pairs and (
                      target_seen or can_go_deeper
                  ):
                      result.edges.append(target_edge)
                      visited_edges.add(edge_id)
                      visited_edge_pairs.add(sorted_pair)

                  if target_seen:
                      logger.debug(
                          f"[{self.workspace}] Node {target_id} already visited, edge added but node not queued"
                      )
                  elif can_go_deeper:
                      # Edge handling is done above, so the node is queued without one
                      queue.append((target_node, None, current_depth + 1))
                  else:
                      logger.debug(
                          f"[{self.workspace}] Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                      )

      logger.info(
          f"[{self.workspace}] BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
      )
      return result
  async def get_all_labels(self) -> list[str]:
      """
      List every distinct entity_id stored under the workspace label.

      Returns:
          The non-null entity_id values, ordered by the query's ORDER BY,
          e.g. ["Company", "Person", ...]
      """
      workspace_label = self._get_workspace_label()
      labels_query = f"""
      MATCH (n:`{workspace_label}`)
      WHERE n.entity_id IS NOT NULL
      RETURN DISTINCT n.entity_id AS label
      ORDER BY label
      """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          result = await session.run(labels_query)
          try:
              return [record["label"] async for record in result]
          finally:
              # Ensure results are consumed even if processing fails
              await result.consume()
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def delete_node(self, node_id: str) -> None:
      """Delete the workspace node whose entity_id equals ``node_id``.

      The node is removed with ``DETACH DELETE``, so its relationships are
      deleted along with it. The call is wrapped by the ``@retry`` decorator
      above (three attempts, exponential wait) for the listed exception types.

      Args:
          node_id: entity_id of the node to delete.

      Raises:
          Exception: whatever the session or transaction raised, after it has
              been logged.
      """

      async def _detach_delete(tx: AsyncManagedTransaction):
          # Transaction function handed to execute_write.
          label = self._get_workspace_label()
          cypher = f"""
              MATCH (n:`{label}` {{entity_id: $entity_id}})
              DETACH DELETE n
          """
          outcome = await tx.run(cypher, entity_id=node_id)
          logger.debug(f"[{self.workspace}] Deleted node with label '{node_id}'")
          # Ensure the result is fully consumed before the transaction ends.
          await outcome.consume()

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_detach_delete)
      except Exception as err:
          logger.error(f"[{self.workspace}] Error during node deletion: {str(err)}")
          raise
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def remove_nodes(self, nodes: list[str]):
      """Delete several nodes, one ``delete_node`` call per entry.

      Nodes are processed sequentially in list order; an exception from any
      single deletion propagates and stops the remaining ones.

      Args:
          nodes: entity_ids of the nodes to delete.
      """
      for entity_name in nodes:
          await self.delete_node(entity_name)
  @retry(
      stop=stop_after_attempt(3),
      wait=wait_exponential(multiplier=1, min=4, max=10),
      retry=retry_if_exception_type(
          (
              neo4jExceptions.ServiceUnavailable,
              neo4jExceptions.TransientError,
              neo4jExceptions.WriteServiceUnavailable,
              neo4jExceptions.ClientError,
              neo4jExceptions.SessionExpired,
              ConnectionResetError,
              OSError,
          )
      ),
  )
  async def remove_edges(self, edges: list[tuple[str, str]]):
      """Delete the relationships between each given pair of nodes.

      Every pair is handled in its own session and write transaction. The
      Cypher pattern is undirected (``-[r]-``), so relationships in either
      direction between the two entity_ids are deleted.

      Args:
          edges: (source, target) entity_id pairs whose relationships should
              be removed.

      Raises:
          Exception: whatever the session or transaction raised, after it has
              been logged; later pairs are then not processed.
      """

      async def _delete_between(tx: AsyncManagedTransaction, src: str, dst: str):
          # Transaction function; the endpoints arrive as explicit arguments
          # forwarded by execute_write.
          label = self._get_workspace_label()
          cypher = f"""
              MATCH (source:`{label}` {{entity_id: $source_entity_id}})-[r]-(target:`{label}` {{entity_id: $target_entity_id}})
              DELETE r
          """
          outcome = await tx.run(
              cypher, source_entity_id=src, target_entity_id=dst
          )
          logger.debug(f"[{self.workspace}] Deleted edge from '{src}' to '{dst}'")
          # Ensure the result is fully consumed before the transaction ends.
          await outcome.consume()

      for source, target in edges:
          try:
              async with self._driver.session(database=self._DATABASE) as session:
                  await session.execute_write(_delete_between, source, target)
          except Exception as err:
              logger.error(
                  f"[{self.workspace}] Error during edge deletion: {str(err)}"
              )
              raise
  async def get_all_nodes(self) -> list[dict]:
      """Get all nodes of the current workspace.

      Returns:
          One dictionary per node holding the node's properties, plus an
          "id" key that mirrors the node's entity_id (None when the node has
          no entity_id property).

      Raises:
          Exception: any error raised while running the query or reading
              records propagates to the caller.
      """
      workspace_label = self._get_workspace_label()
      query = f"""
          MATCH (n:`{workspace_label}`)
          RETURN n
      """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          result = await session.run(query)
          nodes: list[dict] = []
          try:
              async for record in result:
                  properties = dict(record["n"])
                  # Add node id (entity_id) to the dictionary for easier access.
                  properties["id"] = properties.get("entity_id")
                  nodes.append(properties)
          finally:
              # Consume the result even if reading a record raised, as
              # get_all_labels does.
              await result.consume()
          return nodes
  async def get_all_edges(self) -> list[dict]:
      """Get all edges between nodes of the current workspace.

      The Cypher pattern is undirected (``-[r]-``), so the query can yield a
      relationship once per orientation; DISTINCT only removes rows whose
      source, target and properties are all equal.

      Returns:
          One dictionary per returned row holding the relationship's
          properties plus "source" and "target" keys set to the entity_ids of
          the two endpoint nodes.

      Raises:
          Exception: any error raised while running the query or reading
              records propagates to the caller.
      """
      workspace_label = self._get_workspace_label()
      query = f"""
          MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
          RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
      """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          result = await session.run(query)
          edges: list[dict] = []
          try:
              async for record in result:
                  # Copy so the record's own property map is left untouched.
                  edge = dict(record["properties"])
                  edge["source"] = record["source"]
                  edge["target"] = record["target"]
                  edges.append(edge)
          finally:
              # Consume the result even if reading a record raised, as
              # get_all_labels does.
              await result.consume()
          return edges
  async def get_popular_labels(self, limit: int = 300) -> list[str]:
      """Return the entity names of the most connected nodes in this workspace.

      Degree is the number of relationships attached to a node, counted in
      either direction; ties are broken alphabetically by entity name.

      Args:
          limit: Maximum number of labels to return.

      Returns:
          Entity names ordered by degree, highest first.

      Raises:
          Exception: any query or streaming error, after it has been logged.
      """
      label_for_workspace = self._get_workspace_label()
      cypher = f"""
          MATCH (n:`{label_for_workspace}`)
          WHERE n.entity_id IS NOT NULL
          OPTIONAL MATCH (n)-[r]-()
          WITH n.entity_id AS label, count(r) AS degree
          ORDER BY degree DESC, label ASC
          LIMIT $limit
          RETURN label
      """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cursor = None
          try:
              cursor = await session.run(cypher, limit=limit)
              popular = [row["label"] async for row in cursor]
              await cursor.consume()
              logger.debug(
                  f"[{self.workspace}] Retrieved {len(popular)} popular labels (limit: {limit})"
              )
              return popular
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Error getting popular labels: {str(e)}"
              )
              # Consume the result if the query got as far as producing one.
              if cursor is not None:
                  await cursor.consume()
              raise
  async def search_labels(self, query: str, limit: int = 50) -> list[str]:
      """Search entity names, preferring the full-text index over a scan.

      The full-text index is tried first. If that attempt raises for any
      reason, a warning is logged and a slower ``CONTAINS`` search over all
      workspace nodes is used instead. Text detected as Chinese by
      ``_is_chinese_text`` is matched verbatim (no wildcard, no lower-casing);
      other text is lower-cased and searched with a trailing ``*`` wildcard.

      Ranking: exact matches score highest, then prefix matches (full-text
      Chinese path: substring matches). The non-Chinese full-text path adds a
      smaller boost when the term follows a space or underscore. In the
      fallback, remaining matches score ``100 - size(label)``, so shorter
      labels rank higher.

      Args:
          query: Search text; surrounding whitespace is ignored.
          limit: Maximum number of labels to return.

      Returns:
          Matching entity names, best match first; an empty list when the
          query is blank.
      """
      ws_label = self._get_workspace_label()
      needle = query.strip()
      if not needle:
          return []

      needle_lower = needle.lower()
      is_chinese = self._is_chinese_text(needle)
      index_name = self._get_fulltext_index_name(ws_label)
      script_name = "Chinese" if is_chinese else "Latin"

      # --- Preferred path: full-text index ---
      if is_chinese:
          # No wildcard: it may not work properly with the CJK analyzer.
          lucene_query = needle
          fulltext_cypher = f"""
              CALL db.index.fulltext.queryNodes($index_name, $search_query) YIELD node, score
              WITH node, score
              WHERE node:`{ws_label}`
              WITH node.entity_id AS label, score
              WITH label, score,
                  CASE
                      WHEN label = $query_strip THEN score + 1000
                      WHEN label CONTAINS $query_strip THEN score + 500
                      ELSE score
                  END AS final_score
              RETURN label
              ORDER BY final_score DESC, label ASC
              LIMIT $limit
          """
      else:
          lucene_query = f"{needle}*"
          fulltext_cypher = f"""
              CALL db.index.fulltext.queryNodes($index_name, $search_query) YIELD node, score
              WITH node, score
              WHERE node:`{ws_label}`
              WITH node.entity_id AS label, toLower(node.entity_id) AS label_lower, score
              WITH label, label_lower, score,
                  CASE
                      WHEN label_lower = $query_lower THEN score + 1000
                      WHEN label_lower STARTS WITH $query_lower THEN score + 500
                      WHEN label_lower CONTAINS ' ' + $query_lower OR label_lower CONTAINS '_' + $query_lower THEN score + 50
                      ELSE score
                  END AS final_score
              RETURN label
              ORDER BY final_score DESC, label ASC
              LIMIT $limit
          """

      try:
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              cursor = await session.run(
                  fulltext_cypher,
                  index_name=index_name,
                  search_query=lucene_query,
                  query_lower=needle_lower,
                  query_strip=needle,
                  limit=limit,
              )
              hits = [row["label"] async for row in cursor]
              await cursor.consume()
              logger.debug(
                  f"[{self.workspace}] Full-text search ({script_name}) for '{query}' returned {len(hits)} results (limit: {limit})"
              )
              return hits
      except Exception as e:
          # Any failure of the indexed search drops through to the scan below.
          logger.warning(
              f"[{self.workspace}] Full-text search failed with error: {str(e)}. "
              "Falling back to slower, non-indexed search."
          )

      # --- Fallback path: non-indexed CONTAINS scan ---
      if is_chinese:
          # Direct CONTAINS without case conversion.
          fallback_params = {"query_strip": needle}
          fallback_cypher = f"""
              MATCH (n:`{ws_label}`)
              WHERE n.entity_id IS NOT NULL
              WITH n.entity_id AS label
              WHERE label CONTAINS $query_strip
              WITH label,
                  CASE
                      WHEN label = $query_strip THEN 1000
                      WHEN label STARTS WITH $query_strip THEN 500
                      ELSE 100 - size(label)
                  END AS score
              ORDER BY score DESC, label ASC
              LIMIT $limit
              RETURN label
          """
      else:
          fallback_params = {"query_lower": needle_lower}
          fallback_cypher = f"""
              MATCH (n:`{ws_label}`)
              WHERE n.entity_id IS NOT NULL
              WITH n.entity_id AS label, toLower(n.entity_id) AS label_lower
              WHERE label_lower CONTAINS $query_lower
              WITH label, label_lower,
                  CASE
                      WHEN label_lower = $query_lower THEN 1000
                      WHEN label_lower STARTS WITH $query_lower THEN 500
                      ELSE 100 - size(label)
                  END AS score
              ORDER BY score DESC, label ASC
              LIMIT $limit
              RETURN label
          """

      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cursor = await session.run(fallback_cypher, limit=limit, **fallback_params)
          hits = [row["label"] async for row in cursor]
          await cursor.consume()
          logger.debug(
              f"[{self.workspace}] Fallback search ({script_name}) for '{query}' returned {len(hits)} results (limit: {limit})"
          )
          return hits
  async def drop(self) -> dict[str, str]:
      """Delete every node and relationship belonging to the current workspace.

      Only nodes carrying the workspace label are matched; they are removed
      with ``DETACH DELETE``. Errors are logged and reported in the return
      value instead of being raised.

      Returns:
          dict[str, str]: Operation status and message
          - On success: {"status": "success", "message": "workspace '<label>' data dropped"}
          - On failure: {"status": "error", "message": "<error details>"}
      """
      workspace_label = self._get_workspace_label()
      try:
          async with self._driver.session(database=self._DATABASE) as session:
              cursor = await session.run(
                  f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
              )
              # Ensure the result is fully consumed before the session closes.
              await cursor.consume()
      except Exception as e:
          logger.error(
              f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
          )
          return {"status": "error", "message": str(e)}

      return {
          "status": "success",
          "message": f"workspace '{workspace_label}' data dropped",
      }