graph_visual_with_neo4j.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import os
  2. import json
  3. import xml.etree.ElementTree as ET
  4. from neo4j import GraphDatabase
  5. # Constants
  6. WORKING_DIR = "./dickens"
  7. BATCH_SIZE_NODES = 500
  8. BATCH_SIZE_EDGES = 100
  9. # Neo4j connection credentials
  10. NEO4J_URI = "bolt://localhost:7687"
  11. NEO4J_USERNAME = "neo4j"
  12. NEO4J_PASSWORD = "your_password"
  13. def xml_to_json(xml_file):
  14. try:
  15. tree = ET.parse(xml_file)
  16. root = tree.getroot()
  17. # Print the root element's tag and attributes to confirm the file has been correctly loaded
  18. print(f"Root element: {root.tag}")
  19. print(f"Root attributes: {root.attrib}")
  20. data = {"nodes": [], "edges": []}
  21. # Use namespace
  22. namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
  23. for node in root.findall(".//node", namespace):
  24. node_data = {
  25. "id": node.get("id").strip('"'),
  26. "entity_type": node.find("./data[@key='d1']", namespace).text.strip('"')
  27. if node.find("./data[@key='d1']", namespace) is not None
  28. else "",
  29. "description": node.find("./data[@key='d2']", namespace).text
  30. if node.find("./data[@key='d2']", namespace) is not None
  31. else "",
  32. "source_id": node.find("./data[@key='d3']", namespace).text
  33. if node.find("./data[@key='d3']", namespace) is not None
  34. else "",
  35. }
  36. data["nodes"].append(node_data)
  37. for edge in root.findall(".//edge", namespace):
  38. edge_data = {
  39. "source": edge.get("source").strip('"'),
  40. "target": edge.get("target").strip('"'),
  41. "weight": float(edge.find("./data[@key='d5']", namespace).text)
  42. if edge.find("./data[@key='d5']", namespace) is not None
  43. else 0.0,
  44. "description": edge.find("./data[@key='d6']", namespace).text
  45. if edge.find("./data[@key='d6']", namespace) is not None
  46. else "",
  47. "keywords": edge.find("./data[@key='d9']", namespace).text
  48. if edge.find("./data[@key='d9']", namespace) is not None
  49. else "",
  50. "source_id": edge.find("./data[@key='d8']", namespace).text
  51. if edge.find("./data[@key='d8']", namespace) is not None
  52. else "",
  53. }
  54. data["edges"].append(edge_data)
  55. # Print the number of nodes and edges found
  56. print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
  57. return data
  58. except ET.ParseError as e:
  59. print(f"Error parsing XML file: {e}")
  60. return None
  61. except Exception as e:
  62. print(f"An error occurred: {e}")
  63. return None
  64. def convert_xml_to_json(xml_path, output_path):
  65. """Converts XML file to JSON and saves the output."""
  66. if not os.path.exists(xml_path):
  67. print(f"Error: File not found - {xml_path}")
  68. return None
  69. json_data = xml_to_json(xml_path)
  70. if json_data:
  71. with open(output_path, "w", encoding="utf-8") as f:
  72. json.dump(json_data, f, ensure_ascii=False, indent=2)
  73. print(f"JSON file created: {output_path}")
  74. return json_data
  75. else:
  76. print("Failed to create JSON data")
  77. return None
  78. def process_in_batches(tx, query, data, batch_size):
  79. """Process data in batches and execute the given query."""
  80. for i in range(0, len(data), batch_size):
  81. batch = data[i : i + batch_size]
  82. tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
  83. def main():
  84. # Paths
  85. xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
  86. json_file = os.path.join(WORKING_DIR, "graph_data.json")
  87. # Convert XML to JSON
  88. json_data = convert_xml_to_json(xml_file, json_file)
  89. if json_data is None:
  90. return
  91. # Load nodes and edges
  92. nodes = json_data.get("nodes", [])
  93. edges = json_data.get("edges", [])
  94. # Neo4j queries
  95. create_nodes_query = """
  96. UNWIND $nodes AS node
  97. MERGE (e:Entity {id: node.id})
  98. SET e.entity_type = node.entity_type,
  99. e.description = node.description,
  100. e.source_id = node.source_id,
  101. e.displayName = node.id
  102. REMOVE e:Entity
  103. WITH e, node
  104. CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
  105. RETURN count(*)
  106. """
  107. create_edges_query = """
  108. UNWIND $edges AS edge
  109. MATCH (source {id: edge.source})
  110. MATCH (target {id: edge.target})
  111. WITH source, target, edge,
  112. CASE
  113. WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
  114. WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
  115. WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
  116. WHEN edge.keywords CONTAINS 'located' THEN 'located'
  117. WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
  118. ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
  119. END AS relType
  120. CALL apoc.create.relationship(source, relType, {
  121. weight: edge.weight,
  122. description: edge.description,
  123. keywords: edge.keywords,
  124. source_id: edge.source_id
  125. }, target) YIELD rel
  126. RETURN count(*)
  127. """
  128. set_displayname_and_labels_query = """
  129. MATCH (n)
  130. SET n.displayName = n.id
  131. WITH n
  132. CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
  133. RETURN count(*)
  134. """
  135. # Create a Neo4j driver
  136. driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
  137. try:
  138. # Execute queries in batches
  139. with driver.session() as session:
  140. # Insert nodes in batches
  141. session.execute_write(
  142. process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
  143. )
  144. # Insert edges in batches
  145. session.execute_write(
  146. process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
  147. )
  148. # Set displayName and labels
  149. session.run(set_displayname_and_labels_query)
  150. except Exception as e:
  151. print(f"Error occurred: {e}")
  152. finally:
  153. driver.close()
  154. if __name__ == "__main__":
  155. main()