| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346 |
- import os
- import asyncio
- import random
- from dataclasses import dataclass
- from typing import final
- import configparser
- from ..utils import logger
- from ..base import BaseGraphStorage
- from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
- from ..kg.shared_storage import get_data_init_lock
- import pipmaster as pm
- if not pm.is_installed("neo4j"):
- pm.install("neo4j")
- from neo4j import (
- AsyncGraphDatabase,
- AsyncManagedTransaction,
- )
- from neo4j.exceptions import TransientError, ResultFailedError
- from dotenv import load_dotenv
- # use the .env that is inside the current folder
- load_dotenv(dotenv_path=".env", override=False)
- MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
- config = configparser.ConfigParser()
- config.read("config.ini", "utf-8")
- @final
- @dataclass
- class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
    """Set up the storage and resolve which workspace it operates on.

    Workspace resolution order:
      1. the MEMGRAPH_WORKSPACE environment variable, when set and non-blank;
      2. the ``workspace`` argument;
      3. the literal ``"base"`` when neither yields a non-blank value.

    The driver is left unset here.
    """
    env_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
    env_wins = bool(env_workspace and env_workspace.strip())
    requested_workspace = workspace  # kept only for the log line below

    resolved = env_workspace if env_wins else workspace
    if not resolved or not str(resolved).strip():
        resolved = "base"

    super().__init__(
        namespace=namespace,
        workspace=resolved,
        global_config=global_config,
        embedding_func=embedding_func,
    )

    # Emitted only after the base initializer has run.
    if env_wins:
        logger.info(
            f"Using MEMGRAPH_WORKSPACE environment variable: '{env_workspace}' (overriding '{requested_workspace}/{namespace}')"
        )
    self._driver = None
def _get_workspace_label(self) -> str:
    """Return the workspace name escaped for use inside a backtick-quoted Cypher label.

    Surrounding whitespace is stripped and every backtick is doubled, so the
    value cannot terminate the quoted identifier it is embedded in
    (e.g. MATCH (n:`{label}`)). All other characters are kept unchanged.
    A blank workspace yields "base". The result is NOT safe as an unquoted
    identifier.
    """
    trimmed = self.workspace.strip()
    return trimmed.replace("`", "``") if trimmed else "base"
async def initialize(self):
    """Create the Memgraph driver, ensure the entity_id index, and verify connectivity.

    Connection settings come from the MEMGRAPH_URI, MEMGRAPH_USERNAME,
    MEMGRAPH_PASSWORD and MEMGRAPH_DATABASE environment variables, falling
    back to the [memgraph] section of config.ini and then to built-in
    defaults. The whole routine runs under the shared data-init lock.

    Index creation is best-effort: any failure (including "already exists")
    is logged as a warning and does not abort initialization.

    Raises:
        Exception: Re-raised after logging when the session cannot be opened
            or the ``RETURN 1`` connectivity probe fails.
    """
    async with get_data_init_lock():
        uri = os.environ.get(
            "MEMGRAPH_URI",
            config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
        )
        username = os.environ.get(
            "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
        )
        password = os.environ.get(
            "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
        )
        database = os.environ.get(
            "MEMGRAPH_DATABASE",
            config.get("memgraph", "database", fallback="memgraph"),
        )

        self._driver = AsyncGraphDatabase.driver(
            uri,
            auth=(username, password),
        )
        self._DATABASE = database

        # Resolved before the try blocks so the except handler below can
        # always format it.
        workspace_label = self._get_workspace_label()

        try:
            async with self._driver.session(database=database) as session:
                # Create index for workspace nodes on entity_id if it doesn't exist
                try:
                    # The label must be backtick-quoted, like every other
                    # query in this class: _get_workspace_label() only
                    # escapes for the quoted form, so an unquoted label
                    # breaks on names such as "my-space".
                    index_result = await session.run(
                        f"""CREATE INDEX ON :`{workspace_label}`(entity_id)"""
                    )
                    # Consume so a server-side failure surfaces here rather
                    # than on the next statement.
                    await index_result.consume()
                    logger.info(
                        f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
                    )
                except Exception as e:
                    # Index may already exist, which is not an error
                    logger.warning(
                        f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
                    )
                await session.run("RETURN 1")
                logger.info(f"[{self.workspace}] Connected to Memgraph at {uri}")
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Failed to connect to Memgraph at {uri}: {e}"
            )
            raise
async def finalize(self):
    """Close the driver if one is open and clear the reference; otherwise do nothing."""
    if self._driver is None:
        return
    await self._driver.close()
    self._driver = None
async def __aexit__(self, exc_type, exc, tb):
    """Async context-manager exit: release the driver via finalize().

    Returns None, so an exception raised inside the ``async with`` body is
    not suppressed.
    """
    await self.finalize()
    return None
async def index_done_callback(self):
    """No-op hook: nothing is flushed here because Memgraph handles persistence automatically."""
    return None
async def has_node(self, node_id: str) -> bool:
    """Report whether a node with the given entity_id exists in this workspace.

    Args:
        node_id: entity_id of the node to look for.

    Returns:
        bool: True when at least one matching node exists, False otherwise
        (including when the query yields no record).

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any query error, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    read_session = self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    )
    async with read_session as session:
        cursor = None
        try:
            label = self._get_workspace_label()
            cypher = (
                f"MATCH (n:`{label}` {{entity_id: $entity_id}}) "
                "RETURN count(n) > 0 AS node_exists"
            )
            cursor = await session.run(cypher, entity_id=node_id)
            row = await cursor.single()
            await cursor.consume()  # drain before leaving the session
            if row is None:
                return False
            return row["node_exists"]
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
            )
            if cursor is not None:
                # Drain the result even on the failure path.
                await cursor.consume()
            raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Report whether any relationship connects two nodes of this workspace.

    The match is undirected, so an edge in either direction counts.

    Args:
        source_node_id: entity_id of one endpoint.
        target_node_id: entity_id of the other endpoint.

    Returns:
        bool: True when at least one relationship exists, False otherwise.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any query error, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    read_session = self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    )
    async with read_session as session:
        cursor = None
        try:
            label = self._get_workspace_label()
            cypher = (
                f"MATCH (a:`{label}` {{entity_id: $source_entity_id}})"
                f"-[r]-(b:`{label}` {{entity_id: $target_entity_id}}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            cursor = await session.run(
                cypher,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )  # type: ignore
            row = await cursor.single()
            await cursor.consume()  # drain before leaving the session
            if row is None:
                return False
            return row["edgeExists"]
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
            )
            if cursor is not None:
                # Drain the result even on the failure path.
                await cursor.consume()
            raise
- async def get_node(self, node_id: str) -> dict[str, str] | None:
- """Get node by its label identifier, return only node properties
- Args:
- node_id: The node label to look up
- Returns:
- dict: Node properties if found
- None: If node not found
- Raises:
- Exception: If there is an error executing the query
- """
- if self._driver is None:
- raise RuntimeError(
- "Memgraph driver is not initialized. Call 'await initialize()' first."
- )
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- try:
- workspace_label = self._get_workspace_label()
- query = (
- f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
- )
- result = await session.run(query, entity_id=node_id)
- try:
- records = await result.fetch(
- 2
- ) # Get 2 records for duplication check
- if len(records) > 1:
- logger.warning(
- f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
- )
- if records:
- node = records[0]["n"]
- node_dict = dict(node)
- # Remove workspace label from labels list if it exists
- if "labels" in node_dict:
- node_dict["labels"] = [
- label
- for label in node_dict["labels"]
- if label != workspace_label
- ]
- return node_dict
- return None
- finally:
- await result.consume() # Ensure result is fully consumed
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
- )
- raise
async def node_degree(self, node_id: str) -> int:
    """Count the relationships attached to the node with the given entity_id.

    Relationships are counted regardless of direction. When the query
    produces no usable record, a warning is logged and 0 is returned.

    Args:
        node_id: entity_id of the node.

    Returns:
        int: Number of relationships, or 0 when no record is returned.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any query error, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    read_session = self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    )
    async with read_session as session:
        try:
            label = self._get_workspace_label()
            cypher = f"""
                MATCH (n:`{label}` {{entity_id: $entity_id}})
                OPTIONAL MATCH (n)-[r]-()
                RETURN COUNT(r) AS degree
            """
            cursor = await session.run(cypher, entity_id=node_id)
            try:
                row = await cursor.single()
                if row:
                    return row["degree"]
                logger.warning(
                    f"[{self.workspace}] No node found with label '{node_id}'"
                )
                return 0
            finally:
                await cursor.consume()  # always drain the result
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
            )
            raise
async def get_all_labels(self) -> list[str]:
    """List the distinct entity_id values of all nodes in this workspace.

    Returns:
        list[str]: entity_id values, ordered by the query's ORDER BY clause,
        e.g. ["Company", "Person", ...].

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any query error, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    read_session = self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    )
    async with read_session as session:
        cursor = None
        try:
            label = self._get_workspace_label()
            cypher = f"""
                MATCH (n:`{label}`)
                WHERE n.entity_id IS NOT NULL
                RETURN DISTINCT n.entity_id AS label
                ORDER BY label
            """
            cursor = await session.run(cypher)
            found = [row["label"] async for row in cursor]
            await cursor.consume()
            return found
        except Exception as e:
            logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
            if cursor is not None:
                # Drain the result even on the failure path.
                await cursor.consume()
            raise
- async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
- """Retrieves all edges (relationships) for a particular node identified by its label.
- Args:
- source_node_id: Label of the node to get edges for
- Returns:
- list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
- None: If no edges found
- Raises:
- Exception: If there is an error executing the query
- """
- if self._driver is None:
- raise RuntimeError(
- "Memgraph driver is not initialized. Call 'await initialize()' first."
- )
- try:
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- results = None
- try:
- workspace_label = self._get_workspace_label()
- query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
- OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
- WHERE connected.entity_id IS NOT NULL
- RETURN n.entity_id AS node_entity_id,
- connected.entity_id AS connected_entity_id,
- startNode(r).entity_id AS start_entity_id"""
- results = await session.run(query, entity_id=source_node_id)
- edges = []
- async for record in results:
- node_entity_id = record["node_entity_id"]
- connected_entity_id = record["connected_entity_id"]
- start_entity_id = record["start_entity_id"]
- if not node_entity_id or not connected_entity_id:
- continue
- # Preserve the original edge direction via startNode(r)
- if start_entity_id == node_entity_id:
- edges.append((node_entity_id, connected_entity_id))
- else:
- edges.append((connected_entity_id, node_entity_id))
- await results.consume() # Ensure results are consumed
- return edges
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
- )
- if results is not None:
- await (
- results.consume()
- ) # Ensure results are consumed even on error
- raise
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
- )
- raise
- async def get_edge(
- self, source_node_id: str, target_node_id: str
- ) -> dict[str, str] | None:
- """Get edge properties between two nodes.
- Args:
- source_node_id: Label of the source node
- target_node_id: Label of the target node
- Returns:
- dict: Edge properties if found, default properties if not found or on error
- Raises:
- Exception: If there is an error executing the query
- """
- if self._driver is None:
- raise RuntimeError(
- "Memgraph driver is not initialized. Call 'await initialize()' first."
- )
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- result = None
- try:
- workspace_label = self._get_workspace_label()
- query = f"""
- MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
- RETURN properties(r) as edge_properties
- """
- result = await session.run(
- query,
- source_entity_id=source_node_id,
- target_entity_id=target_node_id,
- )
- records = await result.fetch(2)
- await result.consume()
- if records:
- edge_result = dict(records[0]["edge_properties"])
- for key, default_value in {
- "weight": 1.0,
- "source_id": None,
- "description": None,
- "keywords": None,
- }.items():
- if key not in edge_result:
- edge_result[key] = default_value
- logger.warning(
- f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
- )
- return edge_result
- return None
- except Exception as e:
- logger.error(
- f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
- )
- if result is not None:
- await (
- result.consume()
- ) # Ensure the result is consumed even on error
- raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """Create or update one node, retrying the transaction on transient errors.

    The node is MERGEd on (workspace label, entity_id = node_id) and
    ``node_data`` is merged into its properties. A failed attempt is retried
    up to 100 times with exponential backoff (0.2s base, factor 1.1) plus a
    small random jitter, but only for errors classified as transient.

    Args:
        node_id: entity_id used to match or create the node.
        node_data: Properties to set; must contain an "entity_id" key.

    Raises:
        RuntimeError: If the driver has not been initialized.
        ValueError: If ``node_data`` has no "entity_id" key.
        Exception: Non-transient errors, or a transient error on the last
            attempt, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )
    if "entity_id" not in node_data:
        raise ValueError(
            "Memgraph: node properties must contain an 'entity_id' field"
        )

    # Manual transaction-level retry schedule.
    max_retries = 100
    initial_wait_time = 0.2
    backoff_factor = 1.1
    jitter_factor = 0.1

    for attempt in range(max_retries):
        try:
            logger.debug(
                f"[{self.workspace}] Attempting node upsert, attempt {attempt + 1}/{max_retries}"
            )
            async with self._driver.session(database=self._DATABASE) as session:
                label = self._get_workspace_label()

                async def _merge_node(tx: AsyncManagedTransaction):
                    cypher = f"""
                        MERGE (n:`{label}` {{entity_id: $entity_id}})
                        SET n += $properties
                    """
                    outcome = await tx.run(
                        cypher, entity_id=node_id, properties=node_data
                    )
                    await outcome.consume()  # drain before the tx commits

                await session.execute_write(_merge_node)
            return  # success: stop retrying
        except (TransientError, ResultFailedError) as e:
            # Walk the __cause__ chain to find the underlying error.
            cause = e
            while hasattr(cause, "__cause__") and cause.__cause__:
                cause = cause.__cause__

            message = str(e)
            retryable = (
                isinstance(cause, TransientError)
                or isinstance(e, TransientError)
                or "TransientError" in message
                or "Cannot resolve conflicting transactions" in message
            )
            if not retryable:
                logger.error(
                    f"[{self.workspace}] Non-transient error during node upsert: {str(e)}"
                )
                raise
            if attempt >= max_retries - 1:
                logger.error(
                    f"[{self.workspace}] Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
                )
                raise

            # Exponential backoff with jitter before the next attempt.
            jitter = random.uniform(0, jitter_factor) * initial_wait_time
            wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
            logger.warning(
                f"[{self.workspace}] Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
            )
            await asyncio.sleep(wait_time)
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Unexpected error during node upsert: {str(e)}"
            )
            raise
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """Create or update the DIRECTED edge between two nodes, retrying on transient errors.

    Both endpoints are looked up by entity_id within this workspace; the
    relationship is MERGEd without a direction in the pattern and
    ``edge_data`` is merged into its properties. Because the endpoints are
    MATCHed, nothing is written when either node is absent. A failed attempt
    is retried up to 100 times with exponential backoff (0.2s base, factor
    1.1) plus jitter, but only for errors classified as transient.

    Args:
        source_node_id (str): entity_id of the source node.
        target_node_id (str): entity_id of the target node.
        edge_data (dict): Properties to set on the edge.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Non-transient errors, or a transient error on the last
            attempt, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # Manual transaction-level retry schedule.
    max_retries = 100
    initial_wait_time = 0.2
    backoff_factor = 1.1
    jitter_factor = 0.1

    for attempt in range(max_retries):
        try:
            logger.debug(
                f"[{self.workspace}] Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
            )
            async with self._driver.session(database=self._DATABASE) as session:

                async def _merge_edge(tx: AsyncManagedTransaction):
                    label = self._get_workspace_label()
                    cypher = f"""
                        MATCH (source:`{label}` {{entity_id: $source_entity_id}})
                        WITH source
                        MATCH (target:`{label}` {{entity_id: $target_entity_id}})
                        MERGE (source)-[r:DIRECTED]-(target)
                        SET r += $properties
                        RETURN r, source, target
                    """
                    outcome = await tx.run(
                        cypher,
                        source_entity_id=source_node_id,
                        target_entity_id=target_node_id,
                        properties=edge_data,
                    )
                    try:
                        await outcome.fetch(2)
                    finally:
                        await outcome.consume()  # always drain the result

                await session.execute_write(_merge_edge)
            return  # success: stop retrying
        except (TransientError, ResultFailedError) as e:
            # Walk the __cause__ chain to find the underlying error.
            cause = e
            while hasattr(cause, "__cause__") and cause.__cause__:
                cause = cause.__cause__

            message = str(e)
            retryable = (
                isinstance(cause, TransientError)
                or isinstance(e, TransientError)
                or "TransientError" in message
                or "Cannot resolve conflicting transactions" in message
            )
            if not retryable:
                logger.error(
                    f"[{self.workspace}] Non-transient error during edge upsert: {str(e)}"
                )
                raise
            if attempt >= max_retries - 1:
                logger.error(
                    f"[{self.workspace}] Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
                )
                raise

            # Exponential backoff with jitter before the next attempt.
            jitter = random.uniform(0, jitter_factor) * initial_wait_time
            wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
            logger.warning(
                f"[{self.workspace}] Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
            )
            await asyncio.sleep(wait_time)
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Unexpected error during edge upsert: {str(e)}"
            )
            raise
async def upsert_nodes_batch(self, nodes: list[tuple[str, dict[str, str]]]) -> None:
    """Create or update many nodes with a single UNWIND query.

    An empty ``nodes`` list returns immediately. Each node is MERGEd on
    (workspace label, entity_id = node_id) and its data merged into the
    node's properties. The retry policy for transient errors is the
    100-attempt exponential backoff (0.2s base, factor 1.1) with jitter.

    Args:
        nodes: List of (node_id, node_data) tuples; every node_data must
            contain an "entity_id" key.

    Raises:
        RuntimeError: If the driver has not been initialized.
        ValueError: If any node_data has no "entity_id" key.
        Exception: Non-transient errors, or a transient error on the last
            attempt, re-raised after logging.
    """
    if not nodes:
        return
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()

    rows = []
    for node_id, node_data in nodes:
        if "entity_id" not in node_data:
            raise ValueError(
                "Memgraph: node properties must contain an 'entity_id' field"
            )
        rows.append({"entity_id": node_id, "props": node_data})

    max_retries = 100
    initial_wait_time = 0.2
    backoff_factor = 1.1
    jitter_factor = 0.1

    for attempt in range(max_retries):
        try:
            async with self._driver.session(database=self._DATABASE) as session:

                async def _merge_nodes(tx: AsyncManagedTransaction):
                    cypher = f"""
                        UNWIND $nodes AS row
                        MERGE (n:`{label}` {{entity_id: row.entity_id}})
                        SET n += row.props
                    """
                    outcome = await tx.run(cypher, nodes=rows)
                    await outcome.consume()

                await session.execute_write(_merge_nodes)
            return  # success: stop retrying
        except (TransientError, ResultFailedError) as e:
            # Walk the __cause__ chain to find the underlying error.
            cause = e
            while hasattr(cause, "__cause__") and cause.__cause__:
                cause = cause.__cause__

            message = str(e)
            retryable = (
                isinstance(cause, TransientError)
                or isinstance(e, TransientError)
                or "TransientError" in message
                or "Cannot resolve conflicting transactions" in message
            )
            if not retryable:
                logger.error(
                    f"[{self.workspace}] Non-transient error during batch node upsert: {str(e)}"
                )
                raise
            if attempt >= max_retries - 1:
                logger.error(
                    f"[{self.workspace}] Memgraph transient error during batch node upsert after {max_retries} retries: {str(e)}"
                )
                raise

            # Exponential backoff with jitter before the next attempt.
            jitter = random.uniform(0, jitter_factor) * initial_wait_time
            wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
            logger.warning(
                f"[{self.workspace}] Batch node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f}s... Error: {str(e)}"
            )
            await asyncio.sleep(wait_time)
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Unexpected error during batch node upsert: {str(e)}"
            )
            raise
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Check which of the given entity_ids exist, using one UNWIND query.

    Args:
        node_ids: entity_ids to check. An empty list short-circuits to an
            empty set without touching the driver.

    Returns:
        set[str]: The subset of ``node_ids`` that exist in this workspace.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any query error, re-raised after logging.
    """
    if not node_ids:
        return set()
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()
    try:
        read_session = self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        )
        async with read_session as session:
            cypher = f"""
                UNWIND $ids AS id
                MATCH (n:`{label}` {{entity_id: id}})
                RETURN n.entity_id AS entity_id
            """
            cursor = await session.run(cypher, ids=node_ids)
            rows = await cursor.data()
            await cursor.consume()
            return {row["entity_id"] for row in rows}
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error during batch node existence check: {str(e)}"
        )
        raise
async def upsert_edges_batch(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """Create or update many DIRECTED edges with a single UNWIND query.

    An empty ``edges`` list returns immediately. For each row both endpoints
    are MATCHed by entity_id, so rows whose endpoints are absent write
    nothing. The retry policy for transient errors is the 100-attempt
    exponential backoff (0.2s base, factor 1.1) with jitter.

    Args:
        edges: List of (source_node_id, target_node_id, edge_data) tuples.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Non-transient errors, or a transient error on the last
            attempt, re-raised after logging.
    """
    if not edges:
        return
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()
    rows = []
    for src, tgt, edge_data in edges:
        rows.append({"src": src, "tgt": tgt, "props": edge_data})

    max_retries = 100
    initial_wait_time = 0.2
    backoff_factor = 1.1
    jitter_factor = 0.1

    for attempt in range(max_retries):
        try:
            async with self._driver.session(database=self._DATABASE) as session:

                async def _merge_edges(tx: AsyncManagedTransaction):
                    cypher = f"""
                        UNWIND $edges AS row
                        MATCH (source:`{label}` {{entity_id: row.src}})
                        WITH source, row
                        MATCH (target:`{label}` {{entity_id: row.tgt}})
                        MERGE (source)-[r:DIRECTED]-(target)
                        SET r += row.props
                        RETURN r
                    """
                    outcome = await tx.run(cypher, edges=rows)
                    await outcome.consume()

                await session.execute_write(_merge_edges)
            return  # success: stop retrying
        except (TransientError, ResultFailedError) as e:
            # Walk the __cause__ chain to find the underlying error.
            cause = e
            while hasattr(cause, "__cause__") and cause.__cause__:
                cause = cause.__cause__

            message = str(e)
            retryable = (
                isinstance(cause, TransientError)
                or isinstance(e, TransientError)
                or "TransientError" in message
                or "Cannot resolve conflicting transactions" in message
            )
            if not retryable:
                logger.error(
                    f"[{self.workspace}] Non-transient error during batch edge upsert: {str(e)}"
                )
                raise
            if attempt >= max_retries - 1:
                logger.error(
                    f"[{self.workspace}] Memgraph transient error during batch edge upsert after {max_retries} retries: {str(e)}"
                )
                raise

            # Exponential backoff with jitter before the next attempt.
            jitter = random.uniform(0, jitter_factor) * initial_wait_time
            wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
            logger.warning(
                f"[{self.workspace}] Batch edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f}s... Error: {str(e)}"
            )
            await asyncio.sleep(wait_time)
        except Exception as e:
            logger.error(
                f"[{self.workspace}] Unexpected error during batch edge upsert: {str(e)}"
            )
            raise
async def delete_node(self, node_id: str) -> None:
    """Delete the node with the given entity_id together with its relationships.

    Runs a DETACH DELETE inside a managed write transaction.

    Args:
        node_id: entity_id of the node to delete.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any error from the transaction, re-raised after logging.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    async def _detach_delete(tx: AsyncManagedTransaction):
        label = self._get_workspace_label()
        cypher = f"""
            MATCH (n:`{label}` {{entity_id: $entity_id}})
            DETACH DELETE n
        """
        outcome = await tx.run(cypher, entity_id=node_id)
        logger.debug(f"[{self.workspace}] Deleted node with label {node_id}")
        await outcome.consume()

    try:
        write_session = self._driver.session(database=self._DATABASE)
        async with write_session as session:
            await session.execute_write(_detach_delete)
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
        raise
async def remove_nodes(self, nodes: list[str]):
    """Delete several nodes, one after another, via delete_node().

    Args:
        nodes: entity_ids of the nodes to delete.

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )
    # Sequential on purpose: each deletion completes before the next starts.
    for entity_id in nodes:
        await self.delete_node(entity_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
    """Delete the relationships between each given pair of nodes.

    Pairs are processed one at a time, each in its own session and managed
    write transaction. The match is undirected, so every relationship
    between the two nodes is removed regardless of direction.

    Args:
        edges: (source, target) entity_id pairs.

    Raises:
        RuntimeError: If the driver has not been initialized.
        Exception: Any error from a transaction, re-raised after logging;
            later pairs are then not processed.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    for source, target in edges:

        # Defined per pair and awaited within the same iteration, so the
        # closure always sees this iteration's source/target.
        async def _delete_between(tx: AsyncManagedTransaction):
            label = self._get_workspace_label()
            cypher = f"""
                MATCH (source:`{label}` {{entity_id: $source_entity_id}})-[r]-(target:`{label}` {{entity_id: $target_entity_id}})
                DELETE r
            """
            outcome = await tx.run(
                cypher, source_entity_id=source, target_entity_id=target
            )
            logger.debug(
                f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
            )
            await outcome.consume()

        try:
            write_session = self._driver.session(database=self._DATABASE)
            async with write_session as session:
                await session.execute_write(_delete_between)
        except Exception as e:
            logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
            raise
async def drop(self) -> dict[str, str]:
    """Delete every node carrying the current workspace label.

    Only nodes matched by the workspace label (and, through DETACH DELETE,
    their relationships) are removed; nodes without that label are not
    touched by the query.

    Returns:
        dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "workspace data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}

    Raises:
        RuntimeError: If the driver has not been initialized. Errors raised
            while opening the session or running the query are not
            propagated; they are logged and reported through the returned
            dict.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # Resolve the label before the try block: the except handler logs it, so
    # it must already be bound when opening the session itself fails.
    workspace_label = self._get_workspace_label()
    try:
        async with self._driver.session(database=self._DATABASE) as session:
            query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
            result = await session.run(query)
            await result.consume()
            logger.info(
                f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
            )
            return {"status": "success", "message": "workspace data dropped"}
    except Exception as e:
        logger.error(
            f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
        )
        return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the combined degree of an edge's two endpoint nodes.

    Args:
        src_id: ``entity_id`` of the source node.
        tgt_id: ``entity_id`` of the target node.

    Returns:
        int: ``node_degree(src_id) + node_degree(tgt_id)``, where a ``None``
        degree counts as 0.

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # Look up both endpoints first (source, then target), then add them up.
    endpoint_degrees = [
        await self.node_degree(endpoint) for endpoint in (src_id, tgt_id)
    ]
    return sum(int(degree) for degree in endpoint_degrees if degree is not None)
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: int = None,
) -> KnowledgeGraph:
    """
    Retrieve a subgraph of the current workspace.

    With ``node_label == "*"`` the ``max_nodes`` nodes of highest degree are
    returned together with the relationships among them. Otherwise the node
    whose ``entity_id`` equals ``node_label`` is used as the start of a
    breadth-first expansion of at most ``max_depth`` hops.

    Args:
        node_label: entity_id of the starting node, * means all nodes
        max_depth: Maximum depth of the subgraph, Defaults to 3
        max_nodes: Maximum nodes to return. Defaults to, and is capped by,
            ``global_config["max_graph_nodes"]`` (1000 when unset)

    Returns:
        KnowledgeGraph object containing nodes and edges, with an is_truncated flag
        indicating whether the graph was truncated due to max_nodes limit.
        Query errors are logged as warnings and whatever was collected so far
        (possibly an empty graph) is returned.

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    # Same guard as the other methods, so an uninitialized driver fails with
    # a clear message instead of an AttributeError on None.
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    # Use the configured ceiling as default and never exceed it.
    max_graph_nodes = self.global_config.get("max_graph_nodes", 1000)
    if max_nodes is None:
        max_nodes = max_graph_nodes
    else:
        max_nodes = min(max_nodes, max_graph_nodes)

    workspace_label = self._get_workspace_label()
    result = KnowledgeGraph()
    seen_nodes = set()
    seen_edges = set()

    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            if node_label == "*":
                # First check total node count to determine if graph is truncated
                count_query = (
                    f"MATCH (n:`{workspace_label}`) RETURN count(n) as total"
                )
                count_result = None
                try:
                    count_result = await session.run(count_query)
                    count_record = await count_result.single()
                    if count_record and count_record["total"] > max_nodes:
                        result.is_truncated = True
                        logger.info(
                            f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                        )
                finally:
                    if count_result:
                        await count_result.consume()

                # Run main query to get nodes with highest degree
                main_query = f"""
                MATCH (n:`{workspace_label}`)
                OPTIONAL MATCH (n)-[r]-()
                WITH n, COALESCE(count(r), 0) AS degree
                ORDER BY degree DESC
                LIMIT $max_nodes
                WITH collect({{node: n}}) AS filtered_nodes
                UNWIND filtered_nodes AS node_info
                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                OPTIONAL MATCH (a)-[r]-(b)
                WHERE a IN kept_nodes AND b IN kept_nodes
                RETURN filtered_nodes AS node_info,
                    collect(DISTINCT r) AS relationships
                """
                result_set = None
                try:
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": max_nodes},
                    )
                    record = await result_set.single()
                finally:
                    if result_set:
                        await result_set.consume()
            else:
                # Run subgraph query for specific node_label
                subgraph_query = f"""
                MATCH (start:`{workspace_label}`)
                WHERE start.entity_id = $entity_id
                OPTIONAL MATCH path = (start)-[*BFS 0..{max_depth}]-(end:`{workspace_label}`)
                WHERE path IS NULL OR ALL(n IN nodes(path) WHERE '{workspace_label}' IN labels(n))
                WITH start, collect(DISTINCT end) AS discovered_nodes
                WITH start, [node IN discovered_nodes WHERE node IS NOT NULL AND node <> start] AS other_nodes
                WITH
                    CASE
                        WHEN 1 + size(other_nodes) <= $max_nodes THEN [start] + other_nodes
                        ELSE [start] + other_nodes[0..$max_other_nodes]
                    END AS limited_nodes,
                    1 + size(other_nodes) > $max_nodes AS is_truncated
                UNWIND limited_nodes AS n
                OPTIONAL MATCH (n)-[r]-(m)
                WHERE m IN limited_nodes
                WITH limited_nodes, collect(DISTINCT r) AS relationships, is_truncated
                RETURN
                    [node IN limited_nodes | {{node: node}}] AS node_info,
                    [rel IN relationships WHERE rel IS NOT NULL] AS relationships,
                    is_truncated
                """
                result_set = None
                try:
                    result_set = await session.run(
                        subgraph_query,
                        {
                            "entity_id": node_label,
                            "max_nodes": max_nodes,
                            # The start node always occupies one slot.
                            "max_other_nodes": max(max_nodes - 1, 0),
                        },
                    )
                    record = await result_set.single()

                    # If no record found, return empty KnowledgeGraph
                    if not record:
                        logger.debug(
                            f"[{self.workspace}] No nodes found for entity_id: {node_label}"
                        )
                        return result

                    # Check if the result was truncated
                    if record.get("is_truncated"):
                        result.is_truncated = True
                        logger.info(
                            f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes"
                        )
                finally:
                    if result_set:
                        await result_set.consume()

            if record:
                # Convert driver nodes, skipping ids that were already added.
                for node_info in record["node_info"]:
                    node = node_info["node"]
                    node_id = node.id
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=f"{node_id}",
                                labels=[node.get("entity_id")],
                                properties=dict(node),
                            )
                        )
                        seen_nodes.add(node_id)

                # Convert driver relationships, skipping duplicates likewise.
                for rel in record["relationships"]:
                    edge_id = rel.id
                    if edge_id not in seen_edges:
                        start = rel.start_node
                        end = rel.end_node
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=f"{edge_id}",
                                type=rel.type,
                                source=f"{start.id}",
                                target=f"{end.id}",
                                properties=dict(rel),
                            )
                        )
                        seen_edges.add(edge_id)

                logger.info(
                    f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )
        except Exception as e:
            # Best effort: report the problem and fall through to return
            # whatever has been collected.
            logger.warning(
                f"[{self.workspace}] Memgraph error during subgraph query: {str(e)}"
            )

    return result
async def get_all_nodes(self) -> list[dict]:
    """Return the properties of every node in the current workspace.

    Returns:
        One dictionary per node holding its properties, plus an ``id`` key
        set to the node's ``entity_id`` (``None`` when that property is
        missing).

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        cypher = f"""
        MATCH (n:`{label}`)
        RETURN n
        """
        cursor = await session.run(cypher)
        collected = []
        async for row in cursor:
            properties = dict(row["n"])
            # Expose entity_id under "id" for easier access
            properties["id"] = properties.get("entity_id")
            collected.append(properties)
        await cursor.consume()
        return collected
async def get_all_edges(self) -> list[dict]:
    """Return the properties of every relationship between workspace nodes.

    Returns:
        One dictionary per returned row: the relationship's properties plus
        ``source`` and ``target`` keys holding the endpoints' ``entity_id``.

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    label = self._get_workspace_label()
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # NOTE(review): the pattern is undirected, so a relationship is
        # presumably reported once per orientation (a->b and b->a) - confirm
        # that callers expect this.
        cypher = f"""
        MATCH (a:`{label}`)-[r]-(b:`{label}`)
        RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
        """
        cursor = await session.run(cypher)
        collected = []
        async for row in cursor:
            edge = row["properties"]
            edge["source"] = row["source"]
            edge["target"] = row["target"]
            collected.append(edge)
        await cursor.consume()
        return collected
async def get_popular_labels(self, limit: int = 300) -> list[str]:
    """Get popular labels by node degree (most connected entities)

    Args:
        limit: Maximum number of labels(entity names) to return

    Returns:
        List of labels(entity names) sorted by degree (highest first), ties
        broken by label in ascending order. An empty list is returned when
        the query fails; the error is logged, not raised.

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    result = None
    try:
        workspace_label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # The limit is bound as a query parameter instead of being
            # spliced into the Cypher text; int() keeps numeric strings
            # working and rejects anything else before it reaches the query.
            query = f"""
            MATCH (n:`{workspace_label}`)
            WHERE n.entity_id IS NOT NULL
            OPTIONAL MATCH (n)-[r]-()
            WITH n.entity_id AS label, count(r) AS degree
            ORDER BY degree DESC, label ASC
            LIMIT $limit
            RETURN label
            """
            result = await session.run(query, {"limit": int(limit)})
            labels = [record["label"] async for record in result]
            await result.consume()

            logger.debug(
                f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
            )
            return labels
    except Exception as e:
        logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
        if result is not None:
            await result.consume()
        return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
    """Search labels(entity names) by case-insensitive substring match

    Matches are ranked by a score computed in the query: 1000 for an exact
    match, 500 when the label starts with the search text, otherwise
    ``100 - size(label)`` so shorter labels rank higher. Ties are broken by
    label in ascending order.

    Args:
        query: Search query string; lower-cased and stripped before matching
        limit: Maximum number of results to return

    Returns:
        List of matching labels(entity names) sorted by relevance. An empty
        list is returned for a blank query or when the database query fails
        (the error is logged, not raised).

    Raises:
        RuntimeError: If the driver has not been initialized.
    """
    if self._driver is None:
        raise RuntimeError(
            "Memgraph driver is not initialized. Call 'await initialize()' first."
        )

    query_lower = query.lower().strip()
    if not query_lower:
        return []

    result = None
    try:
        workspace_label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Both the search text and the limit are bound as parameters so
            # no caller-supplied value is spliced into the Cypher text.
            cypher_query = f"""
            MATCH (n:`{workspace_label}`)
            WHERE n.entity_id IS NOT NULL
            WITH n.entity_id AS label, toLower(n.entity_id) AS label_lower
            WHERE label_lower CONTAINS $query_lower
            WITH label, label_lower,
                CASE
                    WHEN label_lower = $query_lower THEN 1000
                    WHEN label_lower STARTS WITH $query_lower THEN 500
                    ELSE 100 - size(label)
                END AS score
            ORDER BY score DESC, label ASC
            LIMIT $limit
            RETURN label
            """
            result = await session.run(
                cypher_query,
                {"query_lower": query_lower, "limit": int(limit)},
            )
            labels = [record["label"] async for record in result]
            await result.consume()

            logger.debug(
                f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
            )
            return labels
    except Exception as e:
        logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
        if result is not None:
            await result.consume()
        return []
|