memgraph_impl.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346
  1. import os
  2. import asyncio
  3. import random
  4. from dataclasses import dataclass
  5. from typing import final
  6. import configparser
  7. from ..utils import logger
  8. from ..base import BaseGraphStorage
  9. from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  10. from ..kg.shared_storage import get_data_init_lock
  11. import pipmaster as pm
  12. if not pm.is_installed("neo4j"):
  13. pm.install("neo4j")
  14. from neo4j import (
  15. AsyncGraphDatabase,
  16. AsyncManagedTransaction,
  17. )
  18. from neo4j.exceptions import TransientError, ResultFailedError
  19. from dotenv import load_dotenv
# Load environment variables from the .env file in the current working
# directory. override=False means variables already present in the process
# environment win over values from the file.
load_dotenv(dotenv_path=".env", override=False)

# Read from the MAX_GRAPH_NODES environment variable, defaulting to 1000.
# NOTE(review): presumably the cap on nodes returned by graph queries -- the
# consumer is not visible in this part of the file; confirm.
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

# Optional config.ini used as a fallback source for the [memgraph] connection
# settings; ConfigParser.read() silently skips a file that does not exist.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
  25. @final
  26. @dataclass
  27. class MemgraphStorage(BaseGraphStorage):
  28. def __init__(self, namespace, global_config, embedding_func, workspace=None):
  29. # Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base'
  30. memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
  31. original_workspace = workspace # Save original value for logging
  32. if memgraph_workspace and memgraph_workspace.strip():
  33. workspace = memgraph_workspace
  34. if not workspace or not str(workspace).strip():
  35. workspace = "base"
  36. super().__init__(
  37. namespace=namespace,
  38. workspace=workspace,
  39. global_config=global_config,
  40. embedding_func=embedding_func,
  41. )
  42. # Log after super().__init__() to ensure self.workspace is initialized
  43. if memgraph_workspace and memgraph_workspace.strip():
  44. logger.info(
  45. f"Using MEMGRAPH_WORKSPACE environment variable: '{memgraph_workspace}' (overriding '{original_workspace}/{namespace}')"
  46. )
  47. self._driver = None
  48. def _get_workspace_label(self) -> str:
  49. """Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries.
  50. Escapes backticks by doubling them to prevent Cypher injection
  51. via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping
  52. for all other characters. The returned value is intended to be used
  53. inside backticks (for example, MATCH (n:`{label}`)) and is not
  54. validated as a standalone unquoted identifier.
  55. """
  56. workspace = self.workspace.strip()
  57. if not workspace:
  58. return "base"
  59. return workspace.replace("`", "``")
  60. async def initialize(self):
  61. async with get_data_init_lock():
  62. URI = os.environ.get(
  63. "MEMGRAPH_URI",
  64. config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
  65. )
  66. USERNAME = os.environ.get(
  67. "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
  68. )
  69. PASSWORD = os.environ.get(
  70. "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
  71. )
  72. DATABASE = os.environ.get(
  73. "MEMGRAPH_DATABASE",
  74. config.get("memgraph", "database", fallback="memgraph"),
  75. )
  76. self._driver = AsyncGraphDatabase.driver(
  77. URI,
  78. auth=(USERNAME, PASSWORD),
  79. )
  80. self._DATABASE = DATABASE
  81. try:
  82. async with self._driver.session(database=DATABASE) as session:
  83. # Create index for base nodes on entity_id if it doesn't exist
  84. try:
  85. workspace_label = self._get_workspace_label()
  86. await session.run(
  87. f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
  88. )
  89. logger.info(
  90. f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
  91. )
  92. except Exception as e:
  93. # Index may already exist, which is not an error
  94. logger.warning(
  95. f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
  96. )
  97. await session.run("RETURN 1")
  98. logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
  99. except Exception as e:
  100. logger.error(
  101. f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
  102. )
  103. raise
  104. async def finalize(self):
  105. if self._driver is not None:
  106. await self._driver.close()
  107. self._driver = None
    async def __aexit__(self, exc_type, exc, tb):
        """Async context-manager exit hook: release the driver via finalize().

        The exception arguments are ignored and nothing is returned, so an
        exception raised inside the ``async with`` body still propagates.
        """
        await self.finalize()
  110. async def index_done_callback(self):
  111. # Memgraph handles persistence automatically
  112. pass
  113. async def has_node(self, node_id: str) -> bool:
  114. """
  115. Check if a node exists in the graph.
  116. Args:
  117. node_id: The ID of the node to check.
  118. Returns:
  119. bool: True if the node exists, False otherwise.
  120. Raises:
  121. Exception: If there is an error checking the node existence.
  122. """
  123. if self._driver is None:
  124. raise RuntimeError(
  125. "Memgraph driver is not initialized. Call 'await initialize()' first."
  126. )
  127. async with self._driver.session(
  128. database=self._DATABASE, default_access_mode="READ"
  129. ) as session:
  130. result = None
  131. try:
  132. workspace_label = self._get_workspace_label()
  133. query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
  134. result = await session.run(query, entity_id=node_id)
  135. single_result = await result.single()
  136. await result.consume() # Ensure result is fully consumed
  137. return (
  138. single_result["node_exists"] if single_result is not None else False
  139. )
  140. except Exception as e:
  141. logger.error(
  142. f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
  143. )
  144. if result is not None:
  145. await (
  146. result.consume()
  147. ) # Ensure the result is consumed even on error
  148. raise
  149. async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
  150. """
  151. Check if an edge exists between two nodes in the graph.
  152. Args:
  153. source_node_id: The ID of the source node.
  154. target_node_id: The ID of the target node.
  155. Returns:
  156. bool: True if the edge exists, False otherwise.
  157. Raises:
  158. Exception: If there is an error checking the edge existence.
  159. """
  160. if self._driver is None:
  161. raise RuntimeError(
  162. "Memgraph driver is not initialized. Call 'await initialize()' first."
  163. )
  164. async with self._driver.session(
  165. database=self._DATABASE, default_access_mode="READ"
  166. ) as session:
  167. result = None
  168. try:
  169. workspace_label = self._get_workspace_label()
  170. query = (
  171. f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
  172. "RETURN COUNT(r) > 0 AS edgeExists"
  173. )
  174. result = await session.run(
  175. query,
  176. source_entity_id=source_node_id,
  177. target_entity_id=target_node_id,
  178. ) # type: ignore
  179. single_result = await result.single()
  180. await result.consume() # Ensure result is fully consumed
  181. return (
  182. single_result["edgeExists"] if single_result is not None else False
  183. )
  184. except Exception as e:
  185. logger.error(
  186. f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
  187. )
  188. if result is not None:
  189. await (
  190. result.consume()
  191. ) # Ensure the result is consumed even on error
  192. raise
  193. async def get_node(self, node_id: str) -> dict[str, str] | None:
  194. """Get node by its label identifier, return only node properties
  195. Args:
  196. node_id: The node label to look up
  197. Returns:
  198. dict: Node properties if found
  199. None: If node not found
  200. Raises:
  201. Exception: If there is an error executing the query
  202. """
  203. if self._driver is None:
  204. raise RuntimeError(
  205. "Memgraph driver is not initialized. Call 'await initialize()' first."
  206. )
  207. async with self._driver.session(
  208. database=self._DATABASE, default_access_mode="READ"
  209. ) as session:
  210. try:
  211. workspace_label = self._get_workspace_label()
  212. query = (
  213. f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
  214. )
  215. result = await session.run(query, entity_id=node_id)
  216. try:
  217. records = await result.fetch(
  218. 2
  219. ) # Get 2 records for duplication check
  220. if len(records) > 1:
  221. logger.warning(
  222. f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
  223. )
  224. if records:
  225. node = records[0]["n"]
  226. node_dict = dict(node)
  227. # Remove workspace label from labels list if it exists
  228. if "labels" in node_dict:
  229. node_dict["labels"] = [
  230. label
  231. for label in node_dict["labels"]
  232. if label != workspace_label
  233. ]
  234. return node_dict
  235. return None
  236. finally:
  237. await result.consume() # Ensure result is fully consumed
  238. except Exception as e:
  239. logger.error(
  240. f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
  241. )
  242. raise
  243. async def node_degree(self, node_id: str) -> int:
  244. """Get the degree (number of relationships) of a node with the given label.
  245. If multiple nodes have the same label, returns the degree of the first node.
  246. If no node is found, returns 0.
  247. Args:
  248. node_id: The label of the node
  249. Returns:
  250. int: The number of relationships the node has, or 0 if no node found
  251. Raises:
  252. Exception: If there is an error executing the query
  253. """
  254. if self._driver is None:
  255. raise RuntimeError(
  256. "Memgraph driver is not initialized. Call 'await initialize()' first."
  257. )
  258. async with self._driver.session(
  259. database=self._DATABASE, default_access_mode="READ"
  260. ) as session:
  261. try:
  262. workspace_label = self._get_workspace_label()
  263. query = f"""
  264. MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
  265. OPTIONAL MATCH (n)-[r]-()
  266. RETURN COUNT(r) AS degree
  267. """
  268. result = await session.run(query, entity_id=node_id)
  269. try:
  270. record = await result.single()
  271. if not record:
  272. logger.warning(
  273. f"[{self.workspace}] No node found with label '{node_id}'"
  274. )
  275. return 0
  276. degree = record["degree"]
  277. return degree
  278. finally:
  279. await result.consume() # Ensure result is fully consumed
  280. except Exception as e:
  281. logger.error(
  282. f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
  283. )
  284. raise
  285. async def get_all_labels(self) -> list[str]:
  286. """
  287. Get all existing node labels(entity names) in the database
  288. Returns:
  289. ["Person", "Company", ...] # Alphabetically sorted label list
  290. Raises:
  291. Exception: If there is an error executing the query
  292. """
  293. if self._driver is None:
  294. raise RuntimeError(
  295. "Memgraph driver is not initialized. Call 'await initialize()' first."
  296. )
  297. async with self._driver.session(
  298. database=self._DATABASE, default_access_mode="READ"
  299. ) as session:
  300. result = None
  301. try:
  302. workspace_label = self._get_workspace_label()
  303. query = f"""
  304. MATCH (n:`{workspace_label}`)
  305. WHERE n.entity_id IS NOT NULL
  306. RETURN DISTINCT n.entity_id AS label
  307. ORDER BY label
  308. """
  309. result = await session.run(query)
  310. labels = []
  311. async for record in result:
  312. labels.append(record["label"])
  313. await result.consume()
  314. return labels
  315. except Exception as e:
  316. logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
  317. if result is not None:
  318. await (
  319. result.consume()
  320. ) # Ensure the result is consumed even on error
  321. raise
  322. async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
  323. """Retrieves all edges (relationships) for a particular node identified by its label.
  324. Args:
  325. source_node_id: Label of the node to get edges for
  326. Returns:
  327. list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
  328. None: If no edges found
  329. Raises:
  330. Exception: If there is an error executing the query
  331. """
  332. if self._driver is None:
  333. raise RuntimeError(
  334. "Memgraph driver is not initialized. Call 'await initialize()' first."
  335. )
  336. try:
  337. async with self._driver.session(
  338. database=self._DATABASE, default_access_mode="READ"
  339. ) as session:
  340. results = None
  341. try:
  342. workspace_label = self._get_workspace_label()
  343. query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
  344. OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
  345. WHERE connected.entity_id IS NOT NULL
  346. RETURN n.entity_id AS node_entity_id,
  347. connected.entity_id AS connected_entity_id,
  348. startNode(r).entity_id AS start_entity_id"""
  349. results = await session.run(query, entity_id=source_node_id)
  350. edges = []
  351. async for record in results:
  352. node_entity_id = record["node_entity_id"]
  353. connected_entity_id = record["connected_entity_id"]
  354. start_entity_id = record["start_entity_id"]
  355. if not node_entity_id or not connected_entity_id:
  356. continue
  357. # Preserve the original edge direction via startNode(r)
  358. if start_entity_id == node_entity_id:
  359. edges.append((node_entity_id, connected_entity_id))
  360. else:
  361. edges.append((connected_entity_id, node_entity_id))
  362. await results.consume() # Ensure results are consumed
  363. return edges
  364. except Exception as e:
  365. logger.error(
  366. f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
  367. )
  368. if results is not None:
  369. await (
  370. results.consume()
  371. ) # Ensure results are consumed even on error
  372. raise
  373. except Exception as e:
  374. logger.error(
  375. f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
  376. )
  377. raise
  378. async def get_edge(
  379. self, source_node_id: str, target_node_id: str
  380. ) -> dict[str, str] | None:
  381. """Get edge properties between two nodes.
  382. Args:
  383. source_node_id: Label of the source node
  384. target_node_id: Label of the target node
  385. Returns:
  386. dict: Edge properties if found, default properties if not found or on error
  387. Raises:
  388. Exception: If there is an error executing the query
  389. """
  390. if self._driver is None:
  391. raise RuntimeError(
  392. "Memgraph driver is not initialized. Call 'await initialize()' first."
  393. )
  394. async with self._driver.session(
  395. database=self._DATABASE, default_access_mode="READ"
  396. ) as session:
  397. result = None
  398. try:
  399. workspace_label = self._get_workspace_label()
  400. query = f"""
  401. MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
  402. RETURN properties(r) as edge_properties
  403. """
  404. result = await session.run(
  405. query,
  406. source_entity_id=source_node_id,
  407. target_entity_id=target_node_id,
  408. )
  409. records = await result.fetch(2)
  410. await result.consume()
  411. if records:
  412. edge_result = dict(records[0]["edge_properties"])
  413. for key, default_value in {
  414. "weight": 1.0,
  415. "source_id": None,
  416. "description": None,
  417. "keywords": None,
  418. }.items():
  419. if key not in edge_result:
  420. edge_result[key] = default_value
  421. logger.warning(
  422. f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
  423. )
  424. return edge_result
  425. return None
  426. except Exception as e:
  427. logger.error(
  428. f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
  429. )
  430. if result is not None:
  431. await (
  432. result.consume()
  433. ) # Ensure the result is consumed even on error
  434. raise
  435. async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
  436. """
  437. Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
  438. Args:
  439. node_id: The unique identifier for the node (used as label)
  440. node_data: Dictionary of node properties
  441. """
  442. if self._driver is None:
  443. raise RuntimeError(
  444. "Memgraph driver is not initialized. Call 'await initialize()' first."
  445. )
  446. properties = node_data
  447. if "entity_id" not in properties:
  448. raise ValueError(
  449. "Memgraph: node properties must contain an 'entity_id' field"
  450. )
  451. # Manual transaction-level retry following official Memgraph documentation
  452. max_retries = 100
  453. initial_wait_time = 0.2
  454. backoff_factor = 1.1
  455. jitter_factor = 0.1
  456. for attempt in range(max_retries):
  457. try:
  458. logger.debug(
  459. f"[{self.workspace}] Attempting node upsert, attempt {attempt + 1}/{max_retries}"
  460. )
  461. async with self._driver.session(database=self._DATABASE) as session:
  462. workspace_label = self._get_workspace_label()
  463. async def execute_upsert(tx: AsyncManagedTransaction):
  464. query = f"""
  465. MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
  466. SET n += $properties
  467. """
  468. result = await tx.run(
  469. query, entity_id=node_id, properties=properties
  470. )
  471. await result.consume() # Ensure result is fully consumed
  472. await session.execute_write(execute_upsert)
  473. break # Success - exit retry loop
  474. except (TransientError, ResultFailedError) as e:
  475. # Check if the root cause is a TransientError
  476. root_cause = e
  477. while hasattr(root_cause, "__cause__") and root_cause.__cause__:
  478. root_cause = root_cause.__cause__
  479. # Check if this is a transient error that should be retried
  480. is_transient = (
  481. isinstance(root_cause, TransientError)
  482. or isinstance(e, TransientError)
  483. or "TransientError" in str(e)
  484. or "Cannot resolve conflicting transactions" in str(e)
  485. )
  486. if is_transient:
  487. if attempt < max_retries - 1:
  488. # Calculate wait time with exponential backoff and jitter
  489. jitter = random.uniform(0, jitter_factor) * initial_wait_time
  490. wait_time = (
  491. initial_wait_time * (backoff_factor**attempt) + jitter
  492. )
  493. logger.warning(
  494. f"[{self.workspace}] Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
  495. )
  496. await asyncio.sleep(wait_time)
  497. else:
  498. logger.error(
  499. f"[{self.workspace}] Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
  500. )
  501. raise
  502. else:
  503. # Non-transient error, don't retry
  504. logger.error(
  505. f"[{self.workspace}] Non-transient error during node upsert: {str(e)}"
  506. )
  507. raise
  508. except Exception as e:
  509. logger.error(
  510. f"[{self.workspace}] Unexpected error during node upsert: {str(e)}"
  511. )
  512. raise
  513. async def upsert_edge(
  514. self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  515. ) -> None:
  516. """
  517. Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
  518. Ensures both source and target nodes exist and are unique before creating the edge.
  519. Uses entity_id property to uniquely identify nodes.
  520. Args:
  521. source_node_id (str): Label of the source node (used as identifier)
  522. target_node_id (str): Label of the target node (used as identifier)
  523. edge_data (dict): Dictionary of properties to set on the edge
  524. Raises:
  525. Exception: If there is an error executing the query
  526. """
  527. if self._driver is None:
  528. raise RuntimeError(
  529. "Memgraph driver is not initialized. Call 'await initialize()' first."
  530. )
  531. edge_properties = edge_data
  532. # Manual transaction-level retry following official Memgraph documentation
  533. max_retries = 100
  534. initial_wait_time = 0.2
  535. backoff_factor = 1.1
  536. jitter_factor = 0.1
  537. for attempt in range(max_retries):
  538. try:
  539. logger.debug(
  540. f"[{self.workspace}] Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
  541. )
  542. async with self._driver.session(database=self._DATABASE) as session:
  543. async def execute_upsert(tx: AsyncManagedTransaction):
  544. workspace_label = self._get_workspace_label()
  545. query = f"""
  546. MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
  547. WITH source
  548. MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
  549. MERGE (source)-[r:DIRECTED]-(target)
  550. SET r += $properties
  551. RETURN r, source, target
  552. """
  553. result = await tx.run(
  554. query,
  555. source_entity_id=source_node_id,
  556. target_entity_id=target_node_id,
  557. properties=edge_properties,
  558. )
  559. try:
  560. await result.fetch(2)
  561. finally:
  562. await result.consume() # Ensure result is consumed
  563. await session.execute_write(execute_upsert)
  564. break # Success - exit retry loop
  565. except (TransientError, ResultFailedError) as e:
  566. # Check if the root cause is a TransientError
  567. root_cause = e
  568. while hasattr(root_cause, "__cause__") and root_cause.__cause__:
  569. root_cause = root_cause.__cause__
  570. # Check if this is a transient error that should be retried
  571. is_transient = (
  572. isinstance(root_cause, TransientError)
  573. or isinstance(e, TransientError)
  574. or "TransientError" in str(e)
  575. or "Cannot resolve conflicting transactions" in str(e)
  576. )
  577. if is_transient:
  578. if attempt < max_retries - 1:
  579. # Calculate wait time with exponential backoff and jitter
  580. jitter = random.uniform(0, jitter_factor) * initial_wait_time
  581. wait_time = (
  582. initial_wait_time * (backoff_factor**attempt) + jitter
  583. )
  584. logger.warning(
  585. f"[{self.workspace}] Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
  586. )
  587. await asyncio.sleep(wait_time)
  588. else:
  589. logger.error(
  590. f"[{self.workspace}] Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
  591. )
  592. raise
  593. else:
  594. # Non-transient error, don't retry
  595. logger.error(
  596. f"[{self.workspace}] Non-transient error during edge upsert: {str(e)}"
  597. )
  598. raise
  599. except Exception as e:
  600. logger.error(
  601. f"[{self.workspace}] Unexpected error during edge upsert: {str(e)}"
  602. )
  603. raise
  604. async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
  605. """Batch insert/update multiple nodes using a single UNWIND Cypher query.
  606. Uses the same transient-error retry logic as upsert_node().
  607. Args:
  608. nodes: List of (node_id, node_data) tuples.
  609. """
  610. if not nodes:
  611. return
  612. if self._driver is None:
  613. raise RuntimeError(
  614. "Memgraph driver is not initialized. Call 'await initialize()' first."
  615. )
  616. workspace_label = self._get_workspace_label()
  617. nodes_data = []
  618. for node_id, node_data in nodes:
  619. if "entity_id" not in node_data:
  620. raise ValueError(
  621. "Memgraph: node properties must contain an 'entity_id' field"
  622. )
  623. nodes_data.append({"entity_id": node_id, "props": node_data})
  624. max_retries = 100
  625. initial_wait_time = 0.2
  626. backoff_factor = 1.1
  627. jitter_factor = 0.1
  628. for attempt in range(max_retries):
  629. try:
  630. async with self._driver.session(database=self._DATABASE) as session:
  631. async def execute_batch(tx: AsyncManagedTransaction):
  632. query = f"""
  633. UNWIND $nodes AS row
  634. MERGE (n:`{workspace_label}` {{entity_id: row.entity_id}})
  635. SET n += row.props
  636. """
  637. result = await tx.run(query, nodes=nodes_data)
  638. await result.consume()
  639. await session.execute_write(execute_batch)
  640. break
  641. except (TransientError, ResultFailedError) as e:
  642. root_cause = e
  643. while hasattr(root_cause, "__cause__") and root_cause.__cause__:
  644. root_cause = root_cause.__cause__
  645. is_transient = (
  646. isinstance(root_cause, TransientError)
  647. or isinstance(e, TransientError)
  648. or "TransientError" in str(e)
  649. or "Cannot resolve conflicting transactions" in str(e)
  650. )
  651. if is_transient:
  652. if attempt < max_retries - 1:
  653. jitter = random.uniform(0, jitter_factor) * initial_wait_time
  654. wait_time = (
  655. initial_wait_time * (backoff_factor**attempt) + jitter
  656. )
  657. logger.warning(
  658. f"[{self.workspace}] Batch node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f}s... Error: {str(e)}"
  659. )
  660. await asyncio.sleep(wait_time)
  661. else:
  662. logger.error(
  663. f"[{self.workspace}] Memgraph transient error during batch node upsert after {max_retries} retries: {str(e)}"
  664. )
  665. raise
  666. else:
  667. logger.error(
  668. f"[{self.workspace}] Non-transient error during batch node upsert: {str(e)}"
  669. )
  670. raise
  671. except Exception as e:
  672. logger.error(
  673. f"[{self.workspace}] Unexpected error during batch node upsert: {str(e)}"
  674. )
  675. raise
  676. async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
  677. """Check existence of multiple nodes in a single UNWIND query.
  678. Args:
  679. node_ids: List of node IDs to check.
  680. Returns:
  681. Set of node_ids that exist in the graph.
  682. """
  683. if not node_ids:
  684. return set()
  685. if self._driver is None:
  686. raise RuntimeError(
  687. "Memgraph driver is not initialized. Call 'await initialize()' first."
  688. )
  689. workspace_label = self._get_workspace_label()
  690. try:
  691. async with self._driver.session(
  692. database=self._DATABASE, default_access_mode="READ"
  693. ) as session:
  694. query = f"""
  695. UNWIND $ids AS id
  696. MATCH (n:`{workspace_label}` {{entity_id: id}})
  697. RETURN n.entity_id AS entity_id
  698. """
  699. result = await session.run(query, ids=node_ids)
  700. records = await result.data()
  701. await result.consume()
  702. return {r["entity_id"] for r in records}
  703. except Exception as e:
  704. logger.error(
  705. f"[{self.workspace}] Error during batch node existence check: {str(e)}"
  706. )
  707. raise
  async def upsert_edges_batch(
      self, edges: list[tuple[str, str, dict[str, str]]]
  ) -> None:
      """Insert or update many edges with a single UNWIND Cypher query.

      Uses the same transient-error retry logic as upsert_edge().

      Args:
          edges: List of (source_node_id, target_node_id, edge_data) tuples.

      Raises:
          RuntimeError: If the driver has not been initialized.
          Exception: Non-transient errors, and transient errors that persist
              through the final attempt, are logged and re-raised.
      """
      if not edges:
          return
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      label = self._get_workspace_label()
      rows = []
      for src, tgt, edge_data in edges:
          rows.append({"src": src, "tgt": tgt, "props": edge_data})

      max_retries = 100
      initial_wait_time = 0.2
      backoff_factor = 1.1
      jitter_factor = 0.1

      # The query text and the transaction function do not depend on the
      # attempt number, so they are built once and reused by every retry.
      cypher = f"""
          UNWIND $edges AS row
          MATCH (source:`{label}` {{entity_id: row.src}})
          WITH source, row
          MATCH (target:`{label}` {{entity_id: row.tgt}})
          MERGE (source)-[r:DIRECTED]-(target)
          SET r += row.props
          RETURN r
          """

      async def execute_batch(tx: AsyncManagedTransaction):
          outcome = await tx.run(cypher, edges=rows)
          await outcome.consume()

      for attempt in range(max_retries):
          try:
              async with self._driver.session(database=self._DATABASE) as session:
                  await session.execute_write(execute_batch)
              return
          except (TransientError, ResultFailedError) as e:
              # Walk the __cause__ chain down to the innermost exception.
              root_cause = e
              while getattr(root_cause, "__cause__", None):
                  root_cause = root_cause.__cause__

              is_transient = (
                  isinstance(root_cause, TransientError)
                  or isinstance(e, TransientError)
                  or "TransientError" in str(e)
                  or "Cannot resolve conflicting transactions" in str(e)
              )
              if not is_transient:
                  logger.error(
                      f"[{self.workspace}] Non-transient error during batch edge upsert: {str(e)}"
                  )
                  raise
              if attempt >= max_retries - 1:
                  logger.error(
                      f"[{self.workspace}] Memgraph transient error during batch edge upsert after {max_retries} retries: {str(e)}"
                  )
                  raise

              # Exponential backoff plus a small random jitter before retrying.
              jitter = random.uniform(0, jitter_factor) * initial_wait_time
              wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
              logger.warning(
                  f"[{self.workspace}] Batch edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f}s... Error: {str(e)}"
              )
              await asyncio.sleep(wait_time)
          except Exception as e:
              logger.error(
                  f"[{self.workspace}] Unexpected error during batch edge upsert: {str(e)}"
              )
              raise
  async def delete_node(self, node_id: str) -> None:
      """Detach-delete the node whose ``entity_id`` equals ``node_id``.

      Args:
          node_id: Entity ID of the node to remove. The node is matched under
              the workspace label and deleted with ``DETACH DELETE``.

      Raises:
          RuntimeError: If the driver has not been initialized.
          Exception: Any error raised while deleting is logged and re-raised.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      async def _detach_delete(tx: AsyncManagedTransaction):
          label = self._get_workspace_label()
          cypher = f"""
              MATCH (n:`{label}` {{entity_id: $entity_id}})
              DETACH DELETE n
              """
          outcome = await tx.run(cypher, entity_id=node_id)
          logger.debug(f"[{self.workspace}] Deleted node with label {node_id}")
          await outcome.consume()

      try:
          async with self._driver.session(database=self._DATABASE) as session:
              await session.execute_write(_detach_delete)
      except Exception as e:
          logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
          raise
  async def remove_nodes(self, nodes: list[str]):
      """Delete several nodes, one ``delete_node`` call at a time.

      Args:
          nodes: Entity IDs of the nodes to delete, processed in list order.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      # Sequential on purpose: each deletion is awaited before the next starts.
      for entity_id in nodes:
          await self.delete_node(entity_id)
  async def remove_edges(self, edges: list[tuple[str, str]]):
      """Delete the relationships between each given pair of nodes.

      Args:
          edges: (source, target) entity-ID pairs. For each pair, every
              relationship between the two nodes is deleted regardless of
              direction (the match pattern is undirected).

      Raises:
          RuntimeError: If the driver has not been initialized.
          Exception: The first error raised while deleting is logged and
              re-raised; later pairs are not processed.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      for source, target in edges:

          # Defined per pair so the transaction function closes over the
          # current source/target; it is executed within this same iteration.
          async def _do_delete_edge(tx: AsyncManagedTransaction):
              label = self._get_workspace_label()
              cypher = f"""
                  MATCH (source:`{label}` {{entity_id: $source_entity_id}})-[r]-(target:`{label}` {{entity_id: $target_entity_id}})
                  DELETE r
                  """
              outcome = await tx.run(
                  cypher, source_entity_id=source, target_entity_id=target
              )
              logger.debug(
                  f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
              )
              # Ensure the result is fully consumed before the transaction ends.
              await outcome.consume()

          try:
              async with self._driver.session(database=self._DATABASE) as session:
                  await session.execute_write(_do_delete_edge)
          except Exception as e:
              logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
              raise
  async def drop(self) -> dict[str, str]:
      """Delete every node (and its relationships) in the current workspace.

      Only nodes carrying this workspace's label are removed; they are
      deleted with ``DETACH DELETE`` so their relationships go with them.

      Returns:
          dict[str, str]: Operation status and message
              - On success: {"status": "success", "message": "workspace data dropped"}
              - On failure: {"status": "error", "message": "<error details>"}

      Raises:
          RuntimeError: If the driver has not been initialized. Errors raised
              while running the delete are not propagated; they are logged and
              reported through the returned dict.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      # Resolve the label before entering the try block: the error handler
      # formats it into its log message, so it must be bound even when
      # opening the session itself fails.
      workspace_label = self._get_workspace_label()
      try:
          async with self._driver.session(database=self._DATABASE) as session:
              query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
              result = await session.run(query)
              await result.consume()
              logger.info(
                  f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
              )
              return {"status": "success", "message": "workspace data dropped"}
      except Exception as e:
          logger.error(
              f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
          )
          return {"status": "error", "message": str(e)}
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
      """Return the combined degree of an edge's two endpoint nodes.

      Args:
          src_id: Entity ID of the source node.
          tgt_id: Entity ID of the target node.

      Returns:
          int: ``node_degree(src_id) + node_degree(tgt_id)``, where a degree
          reported as ``None`` counts as 0.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      # Look up both endpoints (source first) before doing any arithmetic.
      endpoint_degrees = [
          await self.node_degree(node_id) for node_id in (src_id, tgt_id)
      ]
      return sum(int(degree) for degree in endpoint_degrees if degree is not None)
  async def get_knowledge_graph(
      self,
      node_label: str,
      max_depth: int = 3,
      max_nodes: int = None,
  ) -> KnowledgeGraph:
      """
      Retrieve a connected subgraph starting from the node identified by `node_label`.

      Args:
          node_label: ``entity_id`` of the starting node; ``"*"`` selects the
              highest-degree nodes of the whole workspace instead.
          max_depth: Maximum BFS depth of the subgraph, Defaults to 3. Coerced
              to ``int`` because it is embedded in the query text.
          max_nodes: Maximum number of nodes to return. When omitted, uses
              ``global_config["max_graph_nodes"]`` (1000 if unset); an explicit
              value is capped at that same setting.

      Returns:
          KnowledgeGraph object containing nodes and edges, with an is_truncated flag
          indicating whether the graph was truncated due to max_nodes limit.
          Errors raised while querying are logged as warnings and whatever
          was collected so far is returned.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      # Same guard as the sibling methods; without it an uninitialized driver
      # surfaces as an AttributeError on ``None.session``.
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      # Get max_nodes from global_config if not provided
      if max_nodes is None:
          max_nodes = self.global_config.get("max_graph_nodes", 1000)
      else:
          # Limit max_nodes to not exceed global_config max_graph_nodes
          max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000))

      # The depth is formatted into the Cypher text rather than passed as a
      # query parameter, so force it to an integer before it gets there.
      depth = int(max_depth)

      workspace_label = self._get_workspace_label()
      result = KnowledgeGraph()
      seen_nodes = set()
      seen_edges = set()

      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          try:
              if node_label == "*":
                  # First check total node count to determine if graph is truncated
                  count_query = (
                      f"MATCH (n:`{workspace_label}`) RETURN count(n) as total"
                  )
                  count_result = None
                  try:
                      count_result = await session.run(count_query)
                      count_record = await count_result.single()
                      if count_record and count_record["total"] > max_nodes:
                          result.is_truncated = True
                          logger.info(
                              f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                          )
                  finally:
                      if count_result:
                          await count_result.consume()

                  # Run main query to get nodes with highest degree
                  main_query = f"""
                      MATCH (n:`{workspace_label}`)
                      OPTIONAL MATCH (n)-[r]-()
                      WITH n, COALESCE(count(r), 0) AS degree
                      ORDER BY degree DESC
                      LIMIT $max_nodes
                      WITH collect({{node: n}}) AS filtered_nodes
                      UNWIND filtered_nodes AS node_info
                      WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                      OPTIONAL MATCH (a)-[r]-(b)
                      WHERE a IN kept_nodes AND b IN kept_nodes
                      RETURN filtered_nodes AS node_info,
                          collect(DISTINCT r) AS relationships
                      """
                  result_set = None
                  try:
                      result_set = await session.run(
                          main_query,
                          {"max_nodes": max_nodes},
                      )
                      record = await result_set.single()
                  finally:
                      if result_set:
                          await result_set.consume()
              else:
                  # Run subgraph query for specific node_label
                  subgraph_query = f"""
                      MATCH (start:`{workspace_label}`)
                      WHERE start.entity_id = $entity_id
                      OPTIONAL MATCH path = (start)-[*BFS 0..{depth}]-(end:`{workspace_label}`)
                      WHERE path IS NULL OR ALL(n IN nodes(path) WHERE '{workspace_label}' IN labels(n))
                      WITH start, collect(DISTINCT end) AS discovered_nodes
                      WITH start, [node IN discovered_nodes WHERE node IS NOT NULL AND node <> start] AS other_nodes
                      WITH
                          CASE
                              WHEN 1 + size(other_nodes) <= $max_nodes THEN [start] + other_nodes
                              ELSE [start] + other_nodes[0..$max_other_nodes]
                          END AS limited_nodes,
                          1 + size(other_nodes) > $max_nodes AS is_truncated
                      UNWIND limited_nodes AS n
                      OPTIONAL MATCH (n)-[r]-(m)
                      WHERE m IN limited_nodes
                      WITH limited_nodes, collect(DISTINCT r) AS relationships, is_truncated
                      RETURN
                          [node IN limited_nodes | {{node: node}}] AS node_info,
                          [rel IN relationships WHERE rel IS NOT NULL] AS relationships,
                          is_truncated
                      """
                  result_set = None
                  try:
                      result_set = await session.run(
                          subgraph_query,
                          {
                              "entity_id": node_label,
                              "max_nodes": max_nodes,
                              # The start node always occupies one slot.
                              "max_other_nodes": max(max_nodes - 1, 0),
                          },
                      )
                      record = await result_set.single()

                      # If no record found, return empty KnowledgeGraph
                      if not record:
                          logger.debug(
                              f"[{self.workspace}] No nodes found for entity_id: {node_label}"
                          )
                          return result

                      # Check if the result was truncated
                      if record.get("is_truncated"):
                          result.is_truncated = True
                          logger.info(
                              f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes"
                          )
                  finally:
                      if result_set:
                          await result_set.consume()

              if record:
                  # Convert driver nodes, skipping any internal id already seen.
                  for node_info in record["node_info"]:
                      node = node_info["node"]
                      node_id = node.id
                      if node_id not in seen_nodes:
                          result.nodes.append(
                              KnowledgeGraphNode(
                                  id=f"{node_id}",
                                  labels=[node.get("entity_id")],
                                  properties=dict(node),
                              )
                          )
                          seen_nodes.add(node_id)

                  # Convert driver relationships, again de-duplicated by id.
                  for rel in record["relationships"]:
                      edge_id = rel.id
                      if edge_id not in seen_edges:
                          start = rel.start_node
                          end = rel.end_node
                          result.edges.append(
                              KnowledgeGraphEdge(
                                  id=f"{edge_id}",
                                  type=rel.type,
                                  source=f"{start.id}",
                                  target=f"{end.id}",
                                  properties=dict(rel),
                              )
                          )
                          seen_edges.add(edge_id)

                  logger.info(
                      f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                  )
          except Exception as e:
              # Best effort: report the failure and fall through to return
              # whatever has been accumulated in ``result``.
              logger.warning(
                  f"[{self.workspace}] Memgraph error during subgraph query: {str(e)}"
              )
      return result
  async def get_all_nodes(self) -> list[dict]:
      """Fetch every node that carries the workspace label.

      Returns:
          One dictionary per node holding the node's properties, plus an
          ``"id"`` key set to the node's ``entity_id`` (``None`` if absent).

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      label = self._get_workspace_label()
      cypher = f"""
          MATCH (n:`{label}`)
          RETURN n
          """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cursor = await session.run(cypher)
          collected = []
          async for row in cursor:
              properties = dict(row["n"])
              # Expose entity_id under "id" as well, for easier access.
              collected.append({**properties, "id": properties.get("entity_id")})
          await cursor.consume()
          return collected
  async def get_all_edges(self) -> list[dict]:
      """Fetch every relationship between nodes of the workspace.

      Returns:
          One dictionary per returned row: the relationship's properties with
          ``"source"`` and ``"target"`` set to the endpoint entity IDs. The
          match pattern is undirected, so rows follow the query's orientation
          rather than the stored direction.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      label = self._get_workspace_label()
      cypher = f"""
          MATCH (a:`{label}`)-[r]-(b:`{label}`)
          RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
          """
      async with self._driver.session(
          database=self._DATABASE, default_access_mode="READ"
      ) as session:
          cursor = await session.run(cypher)
          collected = []
          async for row in cursor:
              # The returned property map is extended in place with the endpoints.
              entry = row["properties"]
              entry.update(source=row["source"], target=row["target"])
              collected.append(entry)
          await cursor.consume()
          return collected
  async def get_popular_labels(self, limit: int = 300) -> list[str]:
      """Get popular labels by node degree (most connected entities)

      Args:
          limit: Maximum number of labels(entity names) to return. Sent to the
              server as a query parameter rather than formatted into the
              query text.

      Returns:
          List of labels(entity names) sorted by degree (highest first), ties
          broken by label in ascending order. An empty list is returned when
          the query fails; the error is logged, not raised.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      try:
          workspace_label = self._get_workspace_label()
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              query = f"""
                  MATCH (n:`{workspace_label}`)
                  WHERE n.entity_id IS NOT NULL
                  OPTIONAL MATCH (n)-[r]-()
                  WITH n.entity_id AS label, count(r) AS degree
                  ORDER BY degree DESC, label ASC
                  LIMIT $limit
                  RETURN label
                  """
              result = await session.run(query, limit=int(limit))
              try:
                  labels = [record["label"] async for record in result]
              finally:
                  # Release the result while its session is still open, on
                  # both the success and the failure path.
                  await result.consume()

              logger.debug(
                  f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
              )
              return labels
      except Exception as e:
          logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
          return []
  async def search_labels(self, query: str, limit: int = 50) -> list[str]:
      """Search labels(entity names) by case-insensitive substring match

      Args:
          query: Search query string; it is lower-cased and stripped, and a
              blank query returns an empty list without touching the database.
          limit: Maximum number of results to return. Sent to the server as a
              query parameter rather than formatted into the query text.

      Returns:
          List of matching labels(entity names) sorted by relevance: exact
          matches first, then prefix matches, then other matches with
          shorter labels ranked higher; ties are broken by label ascending.
          An empty list is returned when the query fails; the error is
          logged, not raised.

      Raises:
          RuntimeError: If the driver has not been initialized.
      """
      if self._driver is None:
          raise RuntimeError(
              "Memgraph driver is not initialized. Call 'await initialize()' first."
          )

      query_lower = query.lower().strip()
      if not query_lower:
          return []

      try:
          workspace_label = self._get_workspace_label()
          async with self._driver.session(
              database=self._DATABASE, default_access_mode="READ"
          ) as session:
              cypher_query = f"""
                  MATCH (n:`{workspace_label}`)
                  WHERE n.entity_id IS NOT NULL
                  WITH n.entity_id AS label, toLower(n.entity_id) AS label_lower
                  WHERE label_lower CONTAINS $query_lower
                  WITH label, label_lower,
                      CASE
                          WHEN label_lower = $query_lower THEN 1000
                          WHEN label_lower STARTS WITH $query_lower THEN 500
                          ELSE 100 - size(label)
                      END AS score
                  ORDER BY score DESC, label ASC
                  LIMIT $limit
                  RETURN label
                  """
              result = await session.run(
                  cypher_query, query_lower=query_lower, limit=int(limit)
              )
              try:
                  labels = [record["label"] async for record in result]
              finally:
                  # Release the result while its session is still open, on
                  # both the success and the failure path.
                  await result.consume()

              logger.debug(
                  f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
              )
              return labels
      except Exception as e:
          logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
          return []