| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- import os
- import json
- import xml.etree.ElementTree as ET
- from neo4j import GraphDatabase
# Constants
WORKING_DIR = "./dickens"  # directory holding the GraphML export and the JSON output
BATCH_SIZE_NODES = 500  # maximum nodes sent per batched query
BATCH_SIZE_EDGES = 100  # maximum edges sent per batched query

# Neo4j connection settings. They are read from the environment so real
# credentials need not be committed to source; the defaults are the previous
# hard-coded values ("your_password" is presumably a placeholder to override).
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "your_password")
def _data_text(element, key, namespace, default=""):
    """Return the text of ``element``'s ``<data key="...">`` child.

    Args:
        element: A GraphML ``<node>`` or ``<edge>`` element.
        key: The GraphML data key id to look up (e.g. ``"d1"``).
        namespace: Prefix map passed to ``Element.find``.
        default: Value returned when the child is missing or has no text.

    An empty child such as ``<data key="d2"/>`` has ``.text`` of ``None`` in
    ElementTree; it is treated the same as a missing child.
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return default
    return child.text


def xml_to_json(xml_file):
    """Parse a GraphML file into plain node and edge records.

    Args:
        xml_file: A path or binary file object accepted by ``ET.parse``.

    Returns:
        ``{"nodes": [...], "edges": [...]}``. Each node record has the keys
        ``id``, ``entity_type``, ``description`` and ``source_id``. Each edge
        record has ``source``, ``target``, ``weight`` (float, 0.0 when
        absent), ``description``, ``keywords`` and ``source_id``. Missing or
        empty data elements become ``""``. Returns ``None`` if anything
        fails; the error is printed rather than raised.
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        # Print the root element's tag and attributes to confirm the file has been correctly loaded
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")
        data = {"nodes": [], "edges": []}
        # GraphML elements live in this default namespace; the "" prefix maps
        # the unprefixed names in the paths below to it (Python 3.8+).
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
        # NOTE(review): the data key ids (d1..d9) are assumed to match the
        # attribute order of the GraphML writer that produced the file;
        # confirm against the file's <key> declarations if the export changes.
        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    # Ids and entity types may be wrapped in double quotes; strip them.
                    "id": node.get("id").strip('"'),
                    "entity_type": _data_text(node, "d1", namespace).strip('"'),
                    "description": _data_text(node, "d2", namespace),
                    "source_id": _data_text(node, "d3", namespace),
                }
            )
        for edge in root.findall(".//edge", namespace):
            weight_text = _data_text(edge, "d5", namespace, default=None)
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    "weight": float(weight_text) if weight_text is not None else 0.0,
                    "description": _data_text(edge, "d6", namespace),
                    "keywords": _data_text(edge, "d9", namespace),
                    "source_id": _data_text(edge, "d8", namespace),
                }
            )
        # Print the number of nodes and edges found
        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        # Script-level boundary: report and signal failure with None.
        print(f"An error occurred: {e}")
        return None
def convert_xml_to_json(xml_path, output_path):
    """Convert a GraphML file to JSON and write the result to ``output_path``.

    Args:
        xml_path: Path of the GraphML input file.
        output_path: Path of the JSON file to (over)write, UTF-8 encoded.

    Returns:
        The converted dict, or ``None`` when the input file does not exist or
        conversion produced nothing; a message is printed in both cases.
    """
    if not os.path.exists(xml_path):
        print(f"Error: File not found - {xml_path}")
        return None

    converted = xml_to_json(xml_path)
    if not converted:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(converted, out_file, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return converted
def process_in_batches(tx, query, data, batch_size, param_name=None):
    """Run ``query`` once per slice of ``data``, passing the slice as a list parameter.

    Args:
        tx: Transaction-like object exposing ``run(query, parameters)``.
        query: Cypher query that consumes the list parameter (e.g. via UNWIND).
        data: Sequence of row dicts to send.
        batch_size: Maximum number of rows per ``run`` call; must be >= 1.
        param_name: Name of the list parameter. When omitted, ``"nodes"`` is
            used if the query references ``$nodes``, otherwise ``"edges"``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if param_name is None:
        # Key off the parameter reference rather than the bare word, so an
        # edge query that merely mentions "nodes" is not misbound.
        param_name = "nodes" if "$nodes" in query else "edges"
    for start in range(0, len(data), batch_size):
        batch = data[start : start + batch_size]
        # consume() drains the result so a failing batch surfaces at this call.
        tx.run(query, {param_name: batch}).consume()
def _write_in_batches(session, query, rows, batch_size):
    """Write ``rows`` with ``query``, committing a separate transaction per batch.

    Each slice of at most ``batch_size`` rows is handed to
    ``process_in_batches`` inside its own ``execute_write`` call, so the
    transaction size is bounded by the batch size rather than the whole input.
    """
    for start in range(0, len(rows), batch_size):
        session.execute_write(
            process_in_batches, query, rows[start : start + batch_size], batch_size
        )


def main():
    """Convert the GraphML export in WORKING_DIR to JSON and load it into Neo4j.

    Steps:
      1. Write ``graph_data.json`` next to the GraphML file.
      2. MERGE the nodes.
      3. Create the relationships with APOC.
      4. Replace each imported node's labels with its ``entity_type``.

    Requires the APOC plugin on the server. Errors from the Neo4j phase are
    printed rather than raised.
    """
    # Paths
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert XML to JSON
    json_data = convert_xml_to_json(xml_file, json_file)
    if json_data is None:
        return

    # Load nodes and edges
    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    # Nodes keep a temporary :Entity label until the final relabel step. The
    # edge step and the relabel step can therefore restrict themselves to
    # :Entity nodes instead of scanning or touching every node in the database.
    create_nodes_query = """
    UNWIND $nodes AS node
    MERGE (e:Entity {id: node.id})
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    RETURN count(*)
    """

    # Relationship type rules, in order:
    #   1. the first matching well-known verb;
    #   2. otherwise the first comma-separated keyword (quotes removed, trimmed);
    #   3. otherwise 'related', so an edge without keywords never passes an
    #      empty type to APOC.
    create_edges_query = """
    UNWIND $edges AS edge
    MATCH (source:Entity {id: edge.source})
    MATCH (target:Entity {id: edge.target})
    WITH source, target, edge,
         trim(REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')) AS firstKeyword
    WITH source, target, edge,
         CASE
             WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
             WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
             WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
             WHEN edge.keywords CONTAINS 'located' THEN 'located'
             WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
             WHEN firstKeyword <> '' THEN firstKeyword
             ELSE 'related'
         END AS relType
    CALL apoc.create.relationship(source, relType, {
        weight: edge.weight,
        description: edge.description,
        keywords: edge.keywords,
        source_id: edge.source_id
    }, target) YIELD rel
    RETURN count(*)
    """

    # Replace each :Entity node's labels with its entity_type. Nodes without
    # an entity_type keep :Entity rather than receiving an empty label.
    # NOTE(review): typed nodes lose :Entity here, so re-running the import
    # MERGEs fresh copies of them. Keep :Entity as a permanent extra label if
    # re-imports must be idempotent.
    set_displayname_and_labels_query = """
    MATCH (n:Entity)
    SET n.displayName = n.id
    WITH n,
         CASE
             WHEN coalesce(n.entity_type, '') = '' THEN ['Entity']
             ELSE [n.entity_type]
         END AS newLabels
    CALL apoc.create.setLabels(n, newLabels) YIELD node
    RETURN count(*)
    """

    try:
        with GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        ) as driver, driver.session() as session:
            _write_in_batches(session, create_nodes_query, nodes, BATCH_SIZE_NODES)
            _write_in_batches(session, create_edges_query, edges, BATCH_SIZE_EDGES)
            # Managed transaction: retried on transient errors, result drained.
            session.execute_write(
                lambda tx: tx.run(set_displayname_and_labels_query).consume()
            )
    except Exception as e:
        # Script-level boundary: report the failure instead of a traceback.
        print(f"Error occurred: {e}")
# Run the import only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|