# graph_visualizer.py
from typing import Optional, Tuple, Dict, List
import numpy as np
import networkx as nx
import pipmaster as pm

# Added automatic libraries install using pipmaster
# Each package below is installed at import time when pipmaster reports it as
# missing, so the imports that follow can succeed on a fresh environment.
if not pm.is_installed("moderngl"):
    pm.install("moderngl")
if not pm.is_installed("imgui_bundle"):
    pm.install("imgui_bundle")
if not pm.is_installed("pyglm"):
    pm.install("pyglm")
if not pm.is_installed("python-louvain"):
    pm.install("python-louvain")

import moderngl
from imgui_bundle import imgui, immapp, hello_imgui
import community  # presumably provided by "python-louvain" (installed above) - confirm
import glm  # presumably provided by "pyglm" (installed above) - confirm
import tkinter as tk
from tkinter import filedialog
import traceback
import colorsys
import os

# Font file names.
# NOTE(review): where these files are looked up is not visible in this section.
CUSTOM_FONT = "font.ttf"
DEFAULT_FONT_ENG = "Geist-Regular.ttf"
DEFAULT_FONT_CHI = "SmileySans-Oblique.ttf"
  26. class Node3D:
  27. """Class representing a 3D node in the graph"""
  28. def __init__(
  29. self, position: glm.vec3, color: glm.vec3, label: str, size: float, idx: int
  30. ):
  31. self.position = position
  32. self.color = color
  33. self.label = label
  34. self.size = size
  35. self.idx = idx
  36. class GraphViewer:
  37. """Main class for 3D graph visualization"""
  38. def __init__(self):
  39. self.glctx = None # ModernGL context
  40. self.graph: Optional[nx.Graph] = None
  41. self.nodes: List[Node3D] = []
  42. self.id_node_map: Dict[str, Node3D] = {}
  43. self.communities = None
  44. self.community_colors = None
  45. # Window dimensions
  46. self.window_width = 1280
  47. self.window_height = 720
  48. # Camera parameters
  49. self.position = glm.vec3(0.0, -10.0, 0.0) # Initial camera position
  50. self.front = glm.vec3(0.0, 1.0, 0.0) # Direction camera is facing
  51. self.up = glm.vec3(0.0, 0.0, 1.0) # Up vector
  52. self.yaw = 90.0 # Horizontal rotation (around Z axis)
  53. self.pitch = 0.0 # Vertical rotation
  54. self.move_speed = 0.05
  55. self.mouse_sensitivity = 0.15
  56. # Graph visualization settings
  57. self.layout_type = "Spring"
  58. self.node_scale = 0.2
  59. self.edge_width = 0.5
  60. self.show_labels = True
  61. self.label_size = 2
  62. self.label_color = (1.0, 1.0, 1.0, 1.0)
  63. self.label_culling_distance = 10.0
  64. self.available_layouts = ("Spring", "Circular", "Shell", "Random")
  65. self.background_color = (0.05, 0.05, 0.05, 1.0)
  66. # Mouse interaction
  67. self.last_mouse_pos = None
  68. self.mouse_pressed = False
  69. self.mouse_button = -1
  70. self.first_mouse = True
  71. # File dialog state
  72. self.show_load_error = False
  73. self.error_message = ""
  74. # Selection state
  75. self.selected_node: Optional[Node3D] = None
  76. self.highlighted_node: Optional[Node3D] = None
  77. # Node id map
  78. self.node_id_fbo = None
  79. self.node_id_texture = None
  80. self.node_id_depth = None
  81. self.node_id_texture_np: np.ndarray = None
  82. # Static data
  83. self.sphere_data = create_sphere()
  84. # Initialization flag
  85. self.initialized = False
  86. def setup(self):
  87. self.setup_render_context()
  88. self.setup_shaders()
  89. self.setup_buffers()
  90. self.initialized = True
  91. def handle_keyboard_input(self):
  92. """Handle WASD keyboard input for camera movement"""
  93. io = imgui.get_io()
  94. if io.want_capture_keyboard:
  95. return
  96. # Calculate camera vectors
  97. right = glm.normalize(glm.cross(self.front, self.up))
  98. # Get movement direction from WASD keys
  99. if imgui.is_key_down(imgui.Key.w): # Forward
  100. self.position += self.front * self.move_speed * 0.1
  101. if imgui.is_key_down(imgui.Key.s): # Backward
  102. self.position -= self.front * self.move_speed * 0.1
  103. if imgui.is_key_down(imgui.Key.a): # Left
  104. self.position -= right * self.move_speed * 0.1
  105. if imgui.is_key_down(imgui.Key.d): # Right
  106. self.position += right * self.move_speed * 0.1
  107. if imgui.is_key_down(imgui.Key.q): # Up
  108. self.position += self.up * self.move_speed * 0.1
  109. if imgui.is_key_down(imgui.Key.e): # Down
  110. self.position -= self.up * self.move_speed * 0.1
    def handle_mouse_interaction(self):
        """Handle mouse interaction for camera control and node selection

        Right-button drag rotates the camera, the mouse wheel changes
        ``move_speed``, a left click selects the highlighted node, and the
        node under the cursor is highlighted while no button is held.
        """
        # Do nothing while any ImGui item is active, hovered or focused
        if (
            imgui.is_any_item_active()
            or imgui.is_any_item_hovered()
            or imgui.is_any_item_focused()
        ):
            return
        io = imgui.get_io()
        mouse_pos = (io.mouse_pos.x, io.mouse_pos.y)
        # Ignore positions outside the window rectangle
        if (
            mouse_pos[0] < 0
            or mouse_pos[1] < 0
            or mouse_pos[0] >= self.window_width
            or mouse_pos[1] >= self.window_height
        ):
            return
        # Handle first mouse input
        # (only records the position so the first delta is not a jump)
        if self.first_mouse:
            self.last_mouse_pos = mouse_pos
            self.first_mouse = False
            return
        # Handle mouse movement for camera rotation
        if self.mouse_pressed and self.mouse_button == 1:  # Right mouse button
            dx = self.last_mouse_pos[0] - mouse_pos[0]
            dy = self.last_mouse_pos[1] - mouse_pos[1]  # Reversed for intuitive control
            dx *= self.mouse_sensitivity
            dy *= self.mouse_sensitivity
            # yaw and pitch are kept in degrees (converted with np.radians below)
            self.yaw += dx
            self.pitch += dy
            # Limit pitch to avoid flipping
            self.pitch = np.clip(self.pitch, -89.0, 89.0)
            # Update front vector
            self.front = glm.normalize(
                glm.vec3(
                    np.cos(np.radians(self.yaw)) * np.cos(np.radians(self.pitch)),
                    np.sin(np.radians(self.yaw)) * np.cos(np.radians(self.pitch)),
                    np.sin(np.radians(self.pitch)),
                )
            )
        # NOTE(review): this return and the ones above leave last_mouse_pos
        # untouched, so the next rotation delta is measured from a stale position.
        if not imgui.is_window_hovered():
            return
        # Mouse wheel adjusts the keyboard movement speed (never below 0.01)
        if io.mouse_wheel != 0:
            self.move_speed += io.mouse_wheel * 0.05
            self.move_speed = np.max([self.move_speed, 0.01])
        # Handle mouse press/release
        for button in range(3):
            if imgui.is_mouse_clicked(button):
                self.mouse_pressed = True
                self.mouse_button = button
                if button == 0 and self.highlighted_node:  # Left click for selection
                    self.selected_node = self.highlighted_node
            if imgui.is_mouse_released(button) and self.mouse_button == button:
                self.mouse_pressed = False
                self.mouse_button = -1
        # Handle node hovering
        if not self.mouse_pressed:
            hovered = self.find_node_at((int(mouse_pos[0]), int(mouse_pos[1])))
            self.highlighted_node = hovered
        # Update last mouse position
        self.last_mouse_pos = mouse_pos
  172. def update_layout(self):
  173. """Update the graph layout"""
  174. pos = nx.spring_layout(
  175. self.graph,
  176. dim=3,
  177. pos={
  178. node_id: list(node.position)
  179. for node_id, node in self.id_node_map.items()
  180. },
  181. k=2.0,
  182. iterations=100,
  183. weight=None,
  184. )
  185. # Update node positions
  186. for node_id, position in pos.items():
  187. self.id_node_map[node_id].position = glm.vec3(position)
  188. self.update_buffers()
  189. def render_node_details(self):
  190. """Render node details window"""
  191. if self.selected_node and imgui.begin("Node Details"):
  192. imgui.text(f"ID: {self.selected_node.label}")
  193. if self.graph:
  194. node_data = self.graph.nodes[self.selected_node.label]
  195. imgui.text(f"Type: {node_data.get('type', 'default')}")
  196. degree = self.graph.degree[self.selected_node.label]
  197. imgui.text(f"Degree: {degree}")
  198. for key, value in node_data.items():
  199. if key != "type":
  200. imgui.text(f"{key}: {value}")
  201. if value and imgui.is_item_hovered():
  202. imgui.set_tooltip(str(value))
  203. imgui.separator()
  204. connections = self.graph[self.selected_node.label]
  205. if connections:
  206. imgui.text("Connections:")
  207. keys = next(iter(connections.values())).keys()
  208. if imgui.begin_table(
  209. "Connections",
  210. len(keys) + 1,
  211. imgui.TableFlags_.borders
  212. | imgui.TableFlags_.row_bg
  213. | imgui.TableFlags_.resizable
  214. | imgui.TableFlags_.hideable,
  215. ):
  216. imgui.table_setup_column("Node")
  217. for key in keys:
  218. imgui.table_setup_column(key)
  219. imgui.table_headers_row()
  220. for neighbor, edge_data in connections.items():
  221. imgui.table_next_row()
  222. imgui.table_set_column_index(0)
  223. if imgui.selectable(str(neighbor), True)[0]:
  224. # Select neighbor node
  225. self.selected_node = self.id_node_map[neighbor]
  226. self.position = self.selected_node.position - self.front
  227. for idx, key in enumerate(keys):
  228. imgui.table_set_column_index(idx + 1)
  229. value = str(edge_data.get(key, ""))
  230. imgui.text(value)
  231. if value and imgui.is_item_hovered():
  232. imgui.set_tooltip(value)
  233. imgui.end_table()
  234. imgui.end()
  235. def setup_render_context(self):
  236. """Initialize ModernGL context"""
  237. self.glctx = moderngl.create_context()
  238. self.glctx.enable(moderngl.DEPTH_TEST | moderngl.CULL_FACE)
  239. self.glctx.clear_color = self.background_color
    def setup_shaders(self):
        """Setup vertex and fragment shaders for node and edge rendering

        Creates three programs on ``self.glctx``:
        ``node_prog`` (instanced spheres), ``edge_prog`` (lines expanded to
        quads in a geometry shader) and ``node_id_prog`` (spheres colored by
        instance id, used for picking).
        """
        # Node shader program
        # Each instance is the sphere mesh scaled by its size and the global
        # ``scale`` uniform, then moved to the instance position. The selected
        # and highlighted instances get fixed override colors.
        self.node_prog = self.glctx.program(
            vertex_shader="""
                #version 330
                uniform mat4 mvp;
                uniform vec3 camera;
                uniform int selected_node;
                uniform int highlighted_node;
                uniform float scale;
                in vec3 in_position;
                in vec3 in_instance_position;
                in vec3 in_instance_color;
                in float in_instance_size;
                out vec3 frag_color;
                out vec3 frag_normal;
                out vec3 frag_view_dir;
                void main() {
                    vec3 pos = in_position * in_instance_size * scale + in_instance_position;
                    gl_Position = mvp * vec4(pos, 1.0);
                    frag_normal = normalize(in_position);
                    frag_view_dir = normalize(camera - pos);
                    if (selected_node == gl_InstanceID) {
                        frag_color = vec3(1.0, 0.5, 0.0);
                    }
                    else if (highlighted_node == gl_InstanceID) {
                        frag_color = vec3(1.0, 0.8, 0.2);
                    }
                    else {
                        frag_color = in_instance_color;
                    }
                }
            """,
            fragment_shader="""
                #version 330
                in vec3 frag_color;
                in vec3 frag_normal;
                in vec3 frag_view_dir;
                out vec4 outColor;
                void main() {
                    // Edge detection based on normal-view angle
                    float edge = 1.0 - abs(dot(frag_normal, frag_view_dir));
                    // Create sharp outline
                    float outline = smoothstep(0.8, 0.9, edge);
                    // Mix the sphere color with outline
                    vec3 final_color = mix(frag_color, vec3(0.0), outline);
                    outColor = vec4(final_color, 1.0);
                }
            """,
        )
        # Edge shader program with wide lines using geometry shader
        # The geometry shader turns each line into a 4-vertex triangle strip
        # whose width is ``edge_width`` divided by ``viewport_size``.
        self.edge_prog = self.glctx.program(
            vertex_shader="""
                #version 330
                uniform mat4 mvp;
                in vec3 in_position;
                in vec3 in_color;
                out vec3 v_color;
                out vec4 v_position;
                void main() {
                    v_position = mvp * vec4(in_position, 1.0);
                    gl_Position = v_position;
                    v_color = in_color;
                }
            """,
            geometry_shader="""
                #version 330
                layout(lines) in;
                layout(triangle_strip, max_vertices = 4) out;
                uniform float edge_width;
                uniform vec2 viewport_size;
                in vec3 v_color[];
                in vec4 v_position[];
                out vec3 g_color;
                out float edge_coord;
                void main() {
                    // Get the two vertices of the line
                    vec4 p1 = v_position[0];
                    vec4 p2 = v_position[1];
                    // Perspective division
                    vec4 p1_ndc = p1 / p1.w;
                    vec4 p2_ndc = p2 / p2.w;
                    // Calculate line direction in screen space
                    vec2 dir = normalize((p2_ndc.xy - p1_ndc.xy) * viewport_size);
                    vec2 normal = vec2(-dir.y, dir.x);
                    // Calculate half width based on screen space
                    float half_width = edge_width * 0.5;
                    vec2 offset = normal * (half_width / viewport_size);
                    // Emit vertices with proper depth
                    gl_Position = vec4(p1_ndc.xy + offset, p1_ndc.z, 1.0);
                    gl_Position *= p1.w; // Restore perspective
                    g_color = v_color[0];
                    edge_coord = 1.0;
                    EmitVertex();
                    gl_Position = vec4(p1_ndc.xy - offset, p1_ndc.z, 1.0);
                    gl_Position *= p1.w;
                    g_color = v_color[0];
                    edge_coord = -1.0;
                    EmitVertex();
                    gl_Position = vec4(p2_ndc.xy + offset, p2_ndc.z, 1.0);
                    gl_Position *= p2.w;
                    g_color = v_color[1];
                    edge_coord = 1.0;
                    EmitVertex();
                    gl_Position = vec4(p2_ndc.xy - offset, p2_ndc.z, 1.0);
                    gl_Position *= p2.w;
                    g_color = v_color[1];
                    edge_coord = -1.0;
                    EmitVertex();
                    EndPrimitive();
                }
            """,
            fragment_shader="""
                #version 330
                in vec3 g_color;
                in float edge_coord;
                out vec4 fragColor;
                void main() {
                    // Edge outline parameters
                    float outline_width = 0.2; // Width of the outline relative to edge
                    float edge_softness = 0.1; // Softness of the edge
                    float edge_dist = abs(edge_coord);
                    // Calculate outline
                    float outline_factor = smoothstep(1.0 - outline_width - edge_softness,
                                                      1.0 - outline_width,
                                                      edge_dist);
                    // Mix edge color with outline (black)
                    vec3 final_color = mix(g_color, vec3(0.0), outline_factor);
                    // Calculate alpha for anti-aliasing
                    float alpha = 1.0 - smoothstep(1.0 - edge_softness, 1.0, edge_dist);
                    fragColor = vec4(final_color, alpha);
                }
            """,
        )
        # Id framebuffer shader program
        # Writes gl_InstanceID into the color as three 8-bit channels
        # (R = bits 16-23, G = bits 8-15, B = bits 0-7), each divided by 255,
        # with alpha 1.0.
        self.node_id_prog = self.glctx.program(
            vertex_shader="""
                #version 330
                uniform mat4 mvp;
                uniform float scale;
                in vec3 in_position;
                in vec3 in_instance_position;
                in float in_instance_size;
                out vec3 frag_color;
                vec3 int_to_rgb(int value) {
                    float R = float((value >> 16) & 0xFF);
                    float G = float((value >> 8) & 0xFF);
                    float B = float(value & 0xFF);
                    // normalize to [0, 1]
                    return vec3(R / 255.0, G / 255.0, B / 255.0);
                }
                void main() {
                    vec3 pos = in_position * in_instance_size * scale + in_instance_position;
                    gl_Position = mvp * vec4(pos, 1.0);
                    frag_color = int_to_rgb(gl_InstanceID);
                }
            """,
            fragment_shader="""
                #version 330
                in vec3 frag_color;
                out vec4 outColor;
                void main() {
                    outColor = vec4(frag_color, 1.0);
                }
            """,
        )
  407. def setup_buffers(self):
  408. """Setup vertex buffers for nodes and edges"""
  409. # We'll create these when loading the graph
  410. self.node_vbo = None
  411. self.node_color_vbo = None
  412. self.node_size_vbo = None
  413. self.edge_vbo = None
  414. self.edge_color_vbo = None
  415. self.node_vao = None
  416. self.edge_vao = None
  417. self.node_id_vao = None
  418. self.sphere_pos_vbo = None
  419. self.sphere_index_buffer = None
  420. def load_file(self, filepath: str):
  421. """Load a GraphML file with error handling"""
  422. try:
  423. # Clear existing data
  424. self.id_node_map.clear()
  425. self.nodes.clear()
  426. self.selected_node = None
  427. self.highlighted_node = None
  428. self.setup_buffers()
  429. # Load new graph
  430. self.graph = nx.read_graphml(filepath)
  431. self.calculate_layout()
  432. self.update_buffers()
  433. self.show_load_error = False
  434. self.error_message = ""
  435. except Exception as _:
  436. self.show_load_error = True
  437. self.error_message = traceback.format_exc()
  438. print(self.error_message)
  439. def calculate_layout(self):
  440. """Calculate 3D layout for the graph"""
  441. if not self.graph:
  442. return
  443. # Detect communities for coloring
  444. self.communities = community.best_partition(self.graph)
  445. num_communities = len(set(self.communities.values()))
  446. self.community_colors = generate_colors(num_communities)
  447. # Calculate layout based on selected type
  448. if self.layout_type == "Spring":
  449. pos = nx.spring_layout(
  450. self.graph, dim=3, k=2.0, iterations=100, weight=None
  451. )
  452. elif self.layout_type == "Circular":
  453. pos_2d = nx.circular_layout(self.graph)
  454. pos = {node: np.array((x, 0.0, y)) for node, (x, y) in pos_2d.items()}
  455. elif self.layout_type == "Shell":
  456. # Group nodes by community for shell layout
  457. comm_lists = [[] for _ in range(num_communities)]
  458. for node, comm in self.communities.items():
  459. comm_lists[comm].append(node)
  460. pos_2d = nx.shell_layout(self.graph, comm_lists)
  461. pos = {node: np.array((x, 0.0, y)) for node, (x, y) in pos_2d.items()}
  462. else: # Random
  463. pos = {node: np.random.rand(3) * 2 - 1 for node in self.graph.nodes()}
  464. # Scale positions
  465. positions = np.array(list(pos.values()))
  466. if len(positions) > 0:
  467. scale = 10.0 / max(1.0, np.max(np.abs(positions)))
  468. pos = {node: coords * scale for node, coords in pos.items()}
  469. # Calculate degree-based sizes
  470. degrees = dict(self.graph.degree())
  471. max_degree = max(degrees.values()) if degrees else 1
  472. min_degree = min(degrees.values()) if degrees else 1
  473. idx = 0
  474. # Create nodes with community colors
  475. for node_id in self.graph.nodes():
  476. position = glm.vec3(pos[node_id])
  477. color = self.get_node_color(node_id)
  478. # Normalize sizes between 0.5 and 2.0
  479. size = 1.0
  480. if max_degree != min_degree:
  481. # Normalize and scale size
  482. normalized = (degrees[node_id] - min_degree) / (max_degree - min_degree)
  483. size = 0.5 + normalized * 1.5
  484. if node_id in self.id_node_map:
  485. node = self.id_node_map[node_id]
  486. node.position = position
  487. node.base_color = color
  488. node.color = color
  489. node.size = size
  490. else:
  491. node = Node3D(position, color, str(node_id), size, idx)
  492. self.id_node_map[node_id] = node
  493. self.nodes.append(node)
  494. idx += 1
  495. self.update_buffers()
  496. def get_node_color(self, node_id: str) -> glm.vec3:
  497. """Get RGBA color based on community"""
  498. if self.communities and node_id in self.communities:
  499. comm_id = self.communities[node_id]
  500. color = self.community_colors[comm_id]
  501. return color
  502. return glm.vec3(0.5, 0.5, 0.5)
  503. def update_buffers(self):
  504. """Update vertex buffers with current node and edge data using batch rendering"""
  505. if not self.graph:
  506. return
  507. # Update node buffers
  508. node_positions = []
  509. node_colors = []
  510. node_sizes = []
  511. for node in self.nodes:
  512. node_positions.append(node.position)
  513. node_colors.append(node.color) # Only use RGB components
  514. node_sizes.append(node.size)
  515. if node_positions:
  516. node_positions = np.array(node_positions, dtype=np.float32)
  517. node_colors = np.array(node_colors, dtype=np.float32)
  518. node_sizes = np.array(node_sizes, dtype=np.float32)
  519. self.node_vbo = self.glctx.buffer(node_positions.tobytes())
  520. self.node_color_vbo = self.glctx.buffer(node_colors.tobytes())
  521. self.node_size_vbo = self.glctx.buffer(node_sizes.tobytes())
  522. self.sphere_pos_vbo = self.glctx.buffer(self.sphere_data[0].tobytes())
  523. self.sphere_index_buffer = self.glctx.buffer(self.sphere_data[1].tobytes())
  524. self.node_vao = self.glctx.vertex_array(
  525. self.node_prog,
  526. [
  527. (self.sphere_pos_vbo, "3f", "in_position"),
  528. (self.node_vbo, "3f /i", "in_instance_position"),
  529. (self.node_color_vbo, "3f /i", "in_instance_color"),
  530. (self.node_size_vbo, "f /i", "in_instance_size"),
  531. ],
  532. index_buffer=self.sphere_index_buffer,
  533. index_element_size=4,
  534. )
  535. self.node_vao.instances = len(self.nodes)
  536. self.node_id_vao = self.glctx.vertex_array(
  537. self.node_id_prog,
  538. [
  539. (self.sphere_pos_vbo, "3f", "in_position"),
  540. (self.node_vbo, "3f /i", "in_instance_position"),
  541. (self.node_size_vbo, "f /i", "in_instance_size"),
  542. ],
  543. index_buffer=self.sphere_index_buffer,
  544. index_element_size=4,
  545. )
  546. self.node_id_vao.instances = len(self.nodes)
  547. # Update edge buffers
  548. edge_positions = []
  549. edge_colors = []
  550. for edge in self.graph.edges():
  551. start_node = self.id_node_map[edge[0]]
  552. end_node = self.id_node_map[edge[1]]
  553. edge_positions.append(start_node.position)
  554. edge_colors.append(start_node.color)
  555. edge_positions.append(end_node.position)
  556. edge_colors.append(end_node.color)
  557. if edge_positions:
  558. edge_positions = np.array(edge_positions, dtype=np.float32)
  559. edge_colors = np.array(edge_colors, dtype=np.float32)
  560. self.edge_vbo = self.glctx.buffer(edge_positions.tobytes())
  561. self.edge_color_vbo = self.glctx.buffer(edge_colors.tobytes())
  562. self.edge_vao = self.glctx.vertex_array(
  563. self.edge_prog,
  564. [
  565. (self.edge_vbo, "3f", "in_position"),
  566. (self.edge_color_vbo, "3f", "in_color"),
  567. ],
  568. )
  569. def update_view_proj_matrix(self):
  570. """Update view matrix based on camera parameters"""
  571. self.view_matrix = glm.lookAt(
  572. self.position, self.position + self.front, self.up
  573. )
  574. aspect_ratio = self.window_width / self.window_height
  575. self.proj_matrix = glm.perspective(
  576. glm.radians(60.0), # FOV
  577. aspect_ratio, # Aspect ratio
  578. 0.001, # Near plane
  579. 1000.0, # Far plane
  580. )
  581. def find_node_at(self, screen_pos: Tuple[int, int]) -> Optional[Node3D]:
  582. """Find the node at a specific screen position"""
  583. if (
  584. self.node_id_texture_np is None
  585. or self.node_id_texture_np.shape[1] != self.window_width
  586. or self.node_id_texture_np.shape[0] != self.window_height
  587. or screen_pos[0] < 0
  588. or screen_pos[1] < 0
  589. or screen_pos[0] >= self.window_width
  590. or screen_pos[1] >= self.window_height
  591. ):
  592. return None
  593. x = screen_pos[0]
  594. y = self.window_height - screen_pos[1] - 1
  595. pixel = self.node_id_texture_np[y, x]
  596. if pixel[3] == 0:
  597. return None
  598. R = int(round(pixel[0] * 255))
  599. G = int(round(pixel[1] * 255))
  600. B = int(round(pixel[2] * 255))
  601. index = (R << 16) | (G << 8) | B
  602. if index > len(self.nodes):
  603. return None
  604. return self.nodes[index]
  605. def is_node_visible_at(self, screen_pos: Tuple[int, int], node_idx: int) -> bool:
  606. """Check if a node exists at a specific screen position"""
  607. node = self.find_node_at(screen_pos)
  608. return node is not None and node.idx == node_idx
  609. def render_settings(self):
  610. """Render settings window"""
  611. if imgui.begin("Graph Settings"):
  612. # Layout type combo
  613. changed, value = imgui.combo(
  614. "Layout",
  615. self.available_layouts.index(self.layout_type),
  616. self.available_layouts,
  617. )
  618. if changed:
  619. self.layout_type = self.available_layouts[value]
  620. self.calculate_layout() # Recalculate layout when changed
  621. # Node size slider
  622. changed, value = imgui.slider_float("Node Scale", self.node_scale, 0.01, 10)
  623. if changed:
  624. self.node_scale = value
  625. # Edge width slider
  626. changed, value = imgui.slider_float("Edge Width", self.edge_width, 0, 20)
  627. if changed:
  628. self.edge_width = value
  629. # Show labels checkbox
  630. changed, value = imgui.checkbox("Show Labels", self.show_labels)
  631. if changed:
  632. self.show_labels = value
  633. if self.show_labels:
  634. # Label size slider
  635. changed, value = imgui.slider_float(
  636. "Label Size", self.label_size, 0.5, 10.0
  637. )
  638. if changed:
  639. self.label_size = value
  640. # Label color picker
  641. changed, value = imgui.color_edit4(
  642. "Label Color",
  643. self.label_color,
  644. imgui.ColorEditFlags_.picker_hue_wheel,
  645. )
  646. if changed:
  647. self.label_color = (value[0], value[1], value[2], value[3])
  648. # Label culling distance slider
  649. changed, value = imgui.slider_float(
  650. "Label Culling Distance", self.label_culling_distance, 0.1, 100.0
  651. )
  652. if changed:
  653. self.label_culling_distance = value
  654. # Background color picker
  655. changed, value = imgui.color_edit4(
  656. "Background Color",
  657. self.background_color,
  658. imgui.ColorEditFlags_.picker_hue_wheel,
  659. )
  660. if changed:
  661. self.background_color = (value[0], value[1], value[2], value[3])
  662. imgui.end()
  663. def save_node_id_texture_to_png(self, filename):
  664. # Convert to a PIL Image and save as PNG
  665. from PIL import Image
  666. scaled_array = self.node_id_texture_np * 255
  667. img = Image.fromarray(
  668. scaled_array.astype(np.uint8),
  669. "RGBA",
  670. )
  671. img = img.transpose(method=Image.FLIP_TOP_BOTTOM)
  672. img.save(filename)
  673. def render_id_map(self, mvp: glm.mat4):
  674. """Render an offscreen id map where each node is drawn with a unique id color."""
  675. # Lazy initialization of id framebuffer
  676. if self.node_id_texture is not None:
  677. if (
  678. self.node_id_texture.width != self.window_width
  679. or self.node_id_texture.height != self.window_height
  680. ):
  681. self.node_id_fbo = None
  682. self.node_id_texture = None
  683. self.node_id_texture_np = None
  684. self.node_id_depth = None
  685. if self.node_id_texture is None:
  686. self.node_id_texture = self.glctx.texture(
  687. (self.window_width, self.window_height), components=4, dtype="f4"
  688. )
  689. self.node_id_depth = self.glctx.depth_renderbuffer(
  690. size=(self.window_width, self.window_height)
  691. )
  692. self.node_id_fbo = self.glctx.framebuffer(
  693. color_attachments=[self.node_id_texture],
  694. depth_attachment=self.node_id_depth,
  695. )
  696. self.node_id_texture_np = np.zeros(
  697. (self.window_height, self.window_width, 4), dtype=np.float32
  698. )
  699. # Bind the offscreen framebuffer
  700. self.node_id_fbo.use()
  701. self.glctx.clear(0, 0, 0, 0)
  702. # Render nodes
  703. if self.node_id_vao:
  704. self.node_id_prog["mvp"].write(mvp.to_bytes())
  705. self.node_id_prog["scale"].write(np.float32(self.node_scale).tobytes())
  706. self.node_id_vao.render(moderngl.TRIANGLES)
  707. # Revert to default framebuffer
  708. self.glctx.screen.use()
  709. self.node_id_texture.read_into(self.node_id_texture_np.data)
  710. def render(self):
  711. """Render the graph"""
  712. # Clear screen
  713. self.glctx.clear(*self.background_color, depth=1)
  714. if not self.graph:
  715. return
  716. # Enable blending for transparency
  717. self.glctx.enable(moderngl.BLEND)
  718. self.glctx.blend_func = moderngl.SRC_ALPHA, moderngl.ONE_MINUS_SRC_ALPHA
  719. # Update view and projection matrices
  720. self.update_view_proj_matrix()
  721. mvp = self.proj_matrix * self.view_matrix
  722. # Render edges first (under nodes)
  723. if self.edge_vao:
  724. self.edge_prog["mvp"].write(mvp.to_bytes())
  725. self.edge_prog["edge_width"].value = (
  726. float(self.edge_width) * 2.0
  727. ) # Double the width for better visibility
  728. self.edge_prog["viewport_size"].value = (
  729. float(self.window_width),
  730. float(self.window_height),
  731. )
  732. self.edge_vao.render(moderngl.LINES)
  733. # Render nodes
  734. if self.node_vao:
  735. self.node_prog["mvp"].write(mvp.to_bytes())
  736. self.node_prog["camera"].write(self.position.to_bytes())
  737. self.node_prog["selected_node"].write(
  738. np.int32(self.selected_node.idx).tobytes()
  739. if self.selected_node
  740. else np.int32(-1).tobytes()
  741. )
  742. self.node_prog["highlighted_node"].write(
  743. np.int32(self.highlighted_node.idx).tobytes()
  744. if self.highlighted_node
  745. else np.int32(-1).tobytes()
  746. )
  747. self.node_prog["scale"].write(np.float32(self.node_scale).tobytes())
  748. self.node_vao.render(moderngl.TRIANGLES)
  749. self.glctx.disable(moderngl.BLEND)
  750. # Render id map
  751. self.render_id_map(mvp)
  752. def render_labels(self):
  753. # Render labels if enabled
  754. if self.show_labels and self.nodes:
  755. # Save current font scale
  756. original_scale = imgui.get_font_size()
  757. self.update_view_proj_matrix()
  758. mvp = self.proj_matrix * self.view_matrix
  759. for node in self.nodes:
  760. # Project node position to screen space
  761. pos = mvp * glm.vec4(
  762. node.position[0], node.position[1], node.position[2], 1.0
  763. )
  764. # Check if node is behind camera
  765. if pos.w > 0 and pos.w < self.label_culling_distance:
  766. screen_x = (pos.x / pos.w + 1) * self.window_width / 2
  767. screen_y = (-pos.y / pos.w + 1) * self.window_height / 2
  768. if self.is_node_visible_at(
  769. (int(screen_x), int(screen_y)), node.idx
  770. ):
  771. # Set font scale
  772. imgui.set_window_font_scale(float(self.label_size) * node.size)
  773. # Calculate label size
  774. label_size = imgui.calc_text_size(node.label)
  775. # Adjust position to center the label
  776. screen_x -= label_size.x / 2
  777. screen_y -= label_size.y / 2
  778. # Set text color with calculated alpha
  779. imgui.push_style_color(imgui.Col_.text, self.label_color)
  780. # Draw label using ImGui
  781. imgui.set_cursor_pos((screen_x, screen_y))
  782. imgui.text(node.label)
  783. # Restore text color
  784. imgui.pop_style_color()
  785. # Restore original font scale
  786. imgui.set_window_font_scale(original_scale)
  787. def reset_view(self):
  788. """Reset camera view to default"""
  789. self.position = glm.vec3(0.0, -10.0, 0.0)
  790. self.front = glm.vec3(0.0, 1.0, 0.0)
  791. self.yaw = 90.0
  792. self.pitch = 0.0
  793. def generate_colors(n: int) -> List[glm.vec3]:
  794. """Generate n distinct colors using HSV color space"""
  795. colors = []
  796. for i in range(n):
  797. # Use golden ratio to generate well-distributed hues
  798. hue = (i * 0.618033988749895) % 1.0
  799. # Fixed saturation and value for vibrant colors
  800. saturation = 0.8
  801. value = 0.95
  802. # Convert HSV to RGB
  803. rgb = colorsys.hsv_to_rgb(hue, saturation, value)
  804. # Add alpha channel
  805. colors.append(glm.vec3(rgb))
  806. return colors
  807. def show_file_dialog() -> Optional[str]:
  808. """Show a file dialog for selecting GraphML files"""
  809. file_path = filedialog.askopenfilename(
  810. title="Select GraphML File",
  811. filetypes=[("GraphML files", "*.graphml"), ("All files", "*.*")],
  812. )
  813. return file_path if file_path else None
  814. def create_sphere(sectors: int = 32, rings: int = 16) -> Tuple:
  815. """
  816. Creates a sphere.
  817. """
  818. R = 1.0 / (rings - 1)
  819. S = 1.0 / (sectors - 1)
  820. # Use those names as normals and uvs are part of the API
  821. vertices_l = [0.0] * (rings * sectors * 3)
  822. # normals_l = [0.0] * (rings * sectors * 3)
  823. uvs_l = [0.0] * (rings * sectors * 2)
  824. v, n, t = 0, 0, 0
  825. for r in range(rings):
  826. for s in range(sectors):
  827. y = np.sin(-np.pi / 2 + np.pi * r * R)
  828. x = np.cos(2 * np.pi * s * S) * np.sin(np.pi * r * R)
  829. z = np.sin(2 * np.pi * s * S) * np.sin(np.pi * r * R)
  830. uvs_l[t] = s * S
  831. uvs_l[t + 1] = r * R
  832. vertices_l[v] = x
  833. vertices_l[v + 1] = y
  834. vertices_l[v + 2] = z
  835. t += 2
  836. v += 3
  837. n += 3
  838. indices = [0] * rings * sectors * 6
  839. i = 0
  840. for r in range(rings - 1):
  841. for s in range(sectors - 1):
  842. indices[i] = r * sectors + s
  843. indices[i + 1] = (r + 1) * sectors + (s + 1)
  844. indices[i + 2] = r * sectors + (s + 1)
  845. indices[i + 3] = r * sectors + s
  846. indices[i + 4] = (r + 1) * sectors + s
  847. indices[i + 5] = (r + 1) * sectors + (s + 1)
  848. i += 6
  849. vbo_vertices = np.array(vertices_l, dtype=np.float32)
  850. vbo_elements = np.array(indices, dtype=np.uint32)
  851. return (vbo_vertices, vbo_elements)
  852. def draw_text_with_bg(
  853. text: str,
  854. text_pos: imgui.ImVec2Like,
  855. text_size: imgui.ImVec2Like,
  856. bg_color: int,
  857. ):
  858. imgui.get_window_draw_list().add_rect_filled(
  859. (text_pos[0] - 5, text_pos[1] - 5),
  860. (text_pos[0] + text_size[0] + 5, text_pos[1] + text_size[1] + 5),
  861. bg_color,
  862. 3.0,
  863. )
  864. imgui.set_cursor_pos(text_pos)
  865. imgui.text(text)
  866. def main():
  867. """Main application entry point"""
  868. viewer = GraphViewer()
  869. show_fps = True
  870. text_bg_color = imgui.IM_COL32(0, 0, 0, 100)
  871. def gui():
  872. if not viewer.initialized:
  873. viewer.setup()
  874. # # Change the theme
  875. # tweaked_theme = hello_imgui.get_runner_params().imgui_window_params.tweaked_theme
  876. # tweaked_theme.theme = hello_imgui.ImGuiTheme_.darcula_darker
  877. # hello_imgui.apply_tweaked_theme(tweaked_theme)
  878. viewer.window_width = int(imgui.get_window_width())
  879. viewer.window_height = int(imgui.get_window_height())
  880. # Handle keyboard and mouse input
  881. viewer.handle_keyboard_input()
  882. viewer.handle_mouse_interaction()
  883. style = imgui.get_style()
  884. window_bg_color = style.color_(imgui.Col_.window_bg.value)
  885. window_bg_color.w = 0.8
  886. style.set_color_(imgui.Col_.window_bg.value, window_bg_color)
  887. # Main control window
  888. imgui.begin("Graph Controls")
  889. if imgui.button("Load GraphML"):
  890. filepath = show_file_dialog()
  891. if filepath:
  892. viewer.load_file(filepath)
  893. # Show error message if loading failed
  894. if viewer.show_load_error:
  895. imgui.push_style_color(imgui.Col_.text, (1.0, 0.0, 0.0, 1.0))
  896. imgui.text(f"Error loading file: {viewer.error_message}")
  897. imgui.pop_style_color()
  898. imgui.separator()
  899. # Camera controls help
  900. imgui.text("Camera Controls:")
  901. imgui.bullet_text("Hold Right Mouse - Look around")
  902. imgui.bullet_text("W/S - Move forward/backward")
  903. imgui.bullet_text("A/D - Move left/right")
  904. imgui.bullet_text("Q/E - Move up/down")
  905. imgui.bullet_text("Left Mouse - Select node")
  906. imgui.bullet_text("Wheel - Change the movement speed")
  907. imgui.separator()
  908. # Camera settings
  909. _, viewer.move_speed = imgui.slider_float(
  910. "Movement Speed", viewer.move_speed, 0.01, 2.0
  911. )
  912. _, viewer.mouse_sensitivity = imgui.slider_float(
  913. "Mouse Sensitivity", viewer.mouse_sensitivity, 0.01, 0.5
  914. )
  915. imgui.separator()
  916. imgui.begin_horizontal("buttons")
  917. if imgui.button("Reset Camera"):
  918. viewer.reset_view()
  919. if imgui.button("Update Layout") and viewer.graph:
  920. viewer.update_layout()
  921. # if imgui.button("Save Node ID Texture"):
  922. # viewer.save_node_id_texture_to_png("node_id_texture.png")
  923. imgui.end_horizontal()
  924. imgui.end()
  925. # Render node details window if a node is selected
  926. viewer.render_node_details()
  927. # Render graph settings window
  928. viewer.render_settings()
  929. # Render FPS
  930. if show_fps:
  931. imgui.set_window_font_scale(1)
  932. fps_text = f"FPS: {hello_imgui.frame_rate():.1f}"
  933. text_size = imgui.calc_text_size(fps_text)
  934. cursor_pos = (10, viewer.window_height - text_size.y - 10)
  935. draw_text_with_bg(fps_text, cursor_pos, text_size, text_bg_color)
  936. # Render highlighted node ID
  937. if viewer.highlighted_node:
  938. imgui.set_window_font_scale(1)
  939. node_text = f"Node ID: {viewer.highlighted_node.label}"
  940. text_size = imgui.calc_text_size(node_text)
  941. cursor_pos = (
  942. viewer.window_width - text_size.x - 10,
  943. viewer.window_height - text_size.y - 10,
  944. )
  945. draw_text_with_bg(node_text, cursor_pos, text_size, text_bg_color)
  946. window_bg_color.w = 0
  947. style.set_color_(imgui.Col_.window_bg.value, window_bg_color)
  948. # Render labels
  949. viewer.render_labels()
  950. def custom_background():
  951. if viewer.initialized:
  952. viewer.render()
  953. runner_params = hello_imgui.RunnerParams()
  954. runner_params.app_window_params.window_geometry.size = (
  955. viewer.window_width,
  956. viewer.window_height,
  957. )
  958. runner_params.app_window_params.window_title = "3D GraphML Viewer"
  959. runner_params.callbacks.show_gui = gui
  960. runner_params.callbacks.custom_background = custom_background
  961. def load_font():
  962. # You will need to provide it yourself, or use another font.
  963. font_filename = CUSTOM_FONT
  964. io = imgui.get_io()
  965. io.fonts.tex_desired_width = 4096 # Larger texture for better CJK font quality
  966. font_size_pixels = 14
  967. asset_dir = os.path.join(os.path.dirname(__file__), "assets")
  968. # Try to load custom font
  969. if not os.path.isfile(font_filename):
  970. font_filename = os.path.join(asset_dir, font_filename)
  971. if os.path.isfile(font_filename):
  972. custom_font = io.fonts.add_font_from_file_ttf(
  973. filename=font_filename,
  974. size_pixels=font_size_pixels,
  975. glyph_ranges_as_int_list=io.fonts.get_glyph_ranges_chinese_full(),
  976. )
  977. io.font_default = custom_font
  978. return
  979. # Load default fonts
  980. io.fonts.add_font_from_file_ttf(
  981. filename=os.path.join(asset_dir, DEFAULT_FONT_ENG),
  982. size_pixels=font_size_pixels,
  983. )
  984. font_config = imgui.ImFontConfig()
  985. font_config.merge_mode = True
  986. io.font_default = io.fonts.add_font_from_file_ttf(
  987. filename=os.path.join(asset_dir, DEFAULT_FONT_CHI),
  988. size_pixels=font_size_pixels,
  989. font_cfg=font_config,
  990. glyph_ranges_as_int_list=io.fonts.get_glyph_ranges_chinese_full(),
  991. )
  992. runner_params.callbacks.load_additional_fonts = load_font
  993. tk_root = tk.Tk()
  994. tk_root.withdraw() # Hide the main window
  995. immapp.run(runner_params)
  996. tk_root.destroy() # Destroy the main window
# Start the viewer only when this file is executed as a script
if __name__ == "__main__":
    main()