sketches.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. from typing import Any, Dict, List, Literal, Optional
  2. from uuid import UUID
  3. from fastapi import (
  4. APIRouter,
  5. BackgroundTasks,
  6. Depends,
  7. File,
  8. Form,
  9. HTTPException,
  10. UploadFile,
  11. status,
  12. )
  13. from flowsint_core.core.graph import GraphNode
  14. from flowsint_core.core.models import Profile
  15. from flowsint_core.core.postgre_db import get_db
  16. from flowsint_core.core.services import (
  17. create_sketch_service,
  18. NotFoundError,
  19. PermissionDeniedError,
  20. ValidationError,
  21. DatabaseError,
  22. )
  23. from flowsint_core.core.services.type_registry_service import create_type_registry_service
  24. from flowsint_core.imports import (
  25. EntityMapping,
  26. ImportService,
  27. create_import_service,
  28. FileParseResult,
  29. )
  30. from flowsint_core.core.graph import create_graph_service
  31. from pydantic import BaseModel, Field
  32. from sqlalchemy.orm import Session
  33. from app.api.deps import get_current_user
  34. from app.api.schemas.sketch import SketchCreate, SketchRead, SketchUpdate
  35. from app.api.sketch_utils import update_sketch_timestamp
  36. router = APIRouter()
  37. class NodeData(BaseModel):
  38. label: str = Field(default="Node", description="Label/name of the node")
  39. type: str = Field(default="Node", description="Type of the node")
  40. class Config:
  41. extra = "allow"
  42. class NodeDeleteInput(BaseModel):
  43. nodeIds: List[str]
  44. class RelationshipDeleteInput(BaseModel):
  45. relationshipIds: List[str]
  46. class NodeEditInput(BaseModel):
  47. nodeId: str
  48. updates: Dict[str, Any]
  49. class RelationshipEditInput(BaseModel):
  50. relationshipId: str
  51. data: Dict[str, Any] = Field(
  52. default_factory=dict, description="Updated data for the relationship"
  53. )
  54. class NodeMergeInput(BaseModel):
  55. id: str
  56. data: NodeData = Field(
  57. default_factory=NodeData, description="Updated data for the node"
  58. )
  59. class RelationInput(BaseModel):
  60. source: str
  61. target: str
  62. type: Literal["one-way", "two-way"]
  63. label: str = "RELATED_TO"
  64. class NodePosition(BaseModel):
  65. nodeId: str
  66. x: float
  67. y: float
  68. class UpdatePositionsInput(BaseModel):
  69. positions: List[NodePosition]
  70. class EntityMappingInput(BaseModel):
  71. """Pydantic model for parsing entity mapping input from frontend."""
  72. id: str
  73. entity_type: str
  74. include: bool = True
  75. nodeLabel: str
  76. node_id: Optional[str] = None
  77. data: Dict[str, Any]
  78. class ImportExecuteResponse(BaseModel):
  79. """Response model for import execution."""
  80. status: str
  81. nodes_created: int
  82. nodes_skipped: int
  83. errors: List[str]
  84. @router.post("/create", response_model=SketchRead, status_code=status.HTTP_201_CREATED)
  85. def create_sketch(
  86. data: SketchCreate,
  87. db: Session = Depends(get_db),
  88. current_user: Profile = Depends(get_current_user),
  89. ):
  90. service = create_sketch_service(db)
  91. try:
  92. sketch_data = data.model_dump()
  93. return service.create(
  94. title=sketch_data.get("title"),
  95. description=sketch_data.get("description"),
  96. investigation_id=sketch_data.get("investigation_id"),
  97. owner_id=current_user.id,
  98. )
  99. except ValidationError as e:
  100. raise HTTPException(status_code=404, detail=str(e))
  101. except PermissionDeniedError:
  102. raise HTTPException(status_code=403, detail="Forbidden")
  103. @router.get("", response_model=List[SketchRead])
  104. def list_sketches(
  105. db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
  106. ):
  107. service = create_sketch_service(db)
  108. return service.list_sketches(current_user.id)
  109. @router.get("/{sketch_id}")
  110. def get_sketch_by_id(
  111. sketch_id: UUID,
  112. db: Session = Depends(get_db),
  113. current_user: Profile = Depends(get_current_user),
  114. ):
  115. service = create_sketch_service(db)
  116. try:
  117. return service.get_by_id(sketch_id, current_user.id)
  118. except NotFoundError:
  119. raise HTTPException(status_code=404, detail="Sketch not found")
  120. except PermissionDeniedError:
  121. raise HTTPException(status_code=403, detail="Forbidden")
  122. @router.put("/{id}", response_model=SketchRead)
  123. def update_sketch(
  124. id: UUID,
  125. payload: SketchUpdate,
  126. db: Session = Depends(get_db),
  127. current_user: Profile = Depends(get_current_user),
  128. ):
  129. service = create_sketch_service(db)
  130. try:
  131. return service.update(id, current_user.id, payload.model_dump(exclude_unset=True))
  132. except NotFoundError:
  133. raise HTTPException(status_code=404, detail="Sketch not found")
  134. except PermissionDeniedError:
  135. raise HTTPException(status_code=403, detail="Forbidden")
  136. @router.delete("/{id}", status_code=204)
  137. def delete_sketch(
  138. id: UUID,
  139. db: Session = Depends(get_db),
  140. current_user: Profile = Depends(get_current_user),
  141. ):
  142. service = create_sketch_service(db)
  143. try:
  144. service.delete(id, current_user.id)
  145. except NotFoundError:
  146. raise HTTPException(status_code=404, detail="Sketch not found")
  147. except PermissionDeniedError:
  148. raise HTTPException(status_code=403, detail="Forbidden")
  149. except DatabaseError:
  150. raise HTTPException(status_code=500, detail="Failed to clean up graph data")
  151. @router.get("/{sketch_id}/graph")
  152. async def get_sketch_nodes(
  153. sketch_id: str,
  154. format: str | None = None,
  155. db: Session = Depends(get_db),
  156. current_user: Profile = Depends(get_current_user),
  157. ):
  158. """Get the nodes and edges for a sketch."""
  159. service = create_sketch_service(db)
  160. try:
  161. return service.get_graph(UUID(sketch_id), current_user.id, format)
  162. except NotFoundError:
  163. raise HTTPException(status_code=404, detail="Graph not found")
  164. except PermissionDeniedError:
  165. raise HTTPException(status_code=403, detail="Forbidden")
  166. @router.post("/{sketch_id}/nodes/add")
  167. @update_sketch_timestamp
  168. def add_node(
  169. sketch_id: str,
  170. node: GraphNode,
  171. background_tasks: BackgroundTasks,
  172. db: Session = Depends(get_db),
  173. current_user: Profile = Depends(get_current_user),
  174. ):
  175. service = create_sketch_service(db)
  176. try:
  177. return service.add_node(UUID(sketch_id), current_user.id, node)
  178. except NotFoundError:
  179. raise HTTPException(status_code=404, detail="Sketch not found")
  180. except PermissionDeniedError:
  181. raise HTTPException(status_code=403, detail="Forbidden")
  182. except ValidationError:
  183. raise HTTPException(status_code=400, detail="Node creation failed")
  184. except DatabaseError as e:
  185. raise HTTPException(status_code=500, detail=str(e))
  186. @router.post("/{sketch_id}/relations/add")
  187. @update_sketch_timestamp
  188. def add_edge(
  189. sketch_id: str,
  190. relation: RelationInput,
  191. background_tasks: BackgroundTasks,
  192. db: Session = Depends(get_db),
  193. current_user: Profile = Depends(get_current_user),
  194. ):
  195. service = create_sketch_service(db)
  196. try:
  197. return service.add_relationship(
  198. UUID(sketch_id), current_user.id, relation.source, relation.target, relation.label
  199. )
  200. except NotFoundError:
  201. raise HTTPException(status_code=404, detail="Sketch not found")
  202. except PermissionDeniedError:
  203. raise HTTPException(status_code=403, detail="Forbidden")
  204. except ValidationError:
  205. raise HTTPException(status_code=400, detail="Edge creation failed")
  206. except DatabaseError:
  207. raise HTTPException(status_code=500, detail="Failed to create edge")
  208. @router.put("/{sketch_id}/nodes/edit")
  209. @update_sketch_timestamp
  210. def edit_node(
  211. sketch_id: str,
  212. node_edit: NodeEditInput,
  213. background_tasks: BackgroundTasks,
  214. db: Session = Depends(get_db),
  215. current_user: Profile = Depends(get_current_user),
  216. ):
  217. service = create_sketch_service(db)
  218. try:
  219. return service.update_node(
  220. UUID(sketch_id), current_user.id, node_edit.nodeId, node_edit.updates
  221. )
  222. except NotFoundError as e:
  223. raise HTTPException(status_code=404, detail=str(e))
  224. except PermissionDeniedError:
  225. raise HTTPException(status_code=403, detail="Forbidden")
  226. except DatabaseError:
  227. raise HTTPException(status_code=500, detail="Failed to update node")
  228. @router.put("/{sketch_id}/nodes/positions")
  229. @update_sketch_timestamp
  230. def update_node_positions(
  231. sketch_id: str,
  232. data: UpdatePositionsInput,
  233. background_tasks: BackgroundTasks,
  234. db: Session = Depends(get_db),
  235. current_user: Profile = Depends(get_current_user),
  236. ):
  237. """Update positions (x, y) for multiple nodes in batch."""
  238. service = create_sketch_service(db)
  239. try:
  240. positions = [pos.model_dump() for pos in data.positions]
  241. return service.update_node_positions(UUID(sketch_id), current_user.id, positions)
  242. except NotFoundError:
  243. raise HTTPException(status_code=404, detail="Sketch not found")
  244. except PermissionDeniedError:
  245. raise HTTPException(status_code=403, detail="Forbidden")
  246. except DatabaseError:
  247. raise HTTPException(status_code=500, detail="Failed to update node positions")
  248. @router.delete("/{sketch_id}/nodes")
  249. @update_sketch_timestamp
  250. def delete_nodes(
  251. sketch_id: str,
  252. nodes: NodeDeleteInput,
  253. background_tasks: BackgroundTasks,
  254. db: Session = Depends(get_db),
  255. current_user: Profile = Depends(get_current_user),
  256. ):
  257. service = create_sketch_service(db)
  258. try:
  259. return service.delete_nodes(UUID(sketch_id), current_user.id, nodes.nodeIds)
  260. except NotFoundError:
  261. raise HTTPException(status_code=404, detail="Sketch not found")
  262. except PermissionDeniedError:
  263. raise HTTPException(status_code=403, detail="Forbidden")
  264. except DatabaseError:
  265. raise HTTPException(status_code=500, detail="Failed to delete nodes")
  266. @router.delete("/{sketch_id}/relationships")
  267. @update_sketch_timestamp
  268. def delete_relationships(
  269. sketch_id: str,
  270. relationships: RelationshipDeleteInput,
  271. background_tasks: BackgroundTasks,
  272. db: Session = Depends(get_db),
  273. current_user: Profile = Depends(get_current_user),
  274. ):
  275. service = create_sketch_service(db)
  276. try:
  277. return service.delete_relationships(
  278. UUID(sketch_id), current_user.id, relationships.relationshipIds
  279. )
  280. except NotFoundError:
  281. raise HTTPException(status_code=404, detail="Sketch not found")
  282. except PermissionDeniedError:
  283. raise HTTPException(status_code=403, detail="Forbidden")
  284. except DatabaseError:
  285. raise HTTPException(status_code=500, detail="Failed to delete relationships")
  286. @router.put("/{sketch_id}/relationships/edit")
  287. @update_sketch_timestamp
  288. def edit_relationship(
  289. sketch_id: str,
  290. relationship_edit: RelationshipEditInput,
  291. background_tasks: BackgroundTasks,
  292. db: Session = Depends(get_db),
  293. current_user: Profile = Depends(get_current_user),
  294. ):
  295. service = create_sketch_service(db)
  296. try:
  297. return service.update_relationship(
  298. UUID(sketch_id),
  299. current_user.id,
  300. relationship_edit.relationshipId,
  301. relationship_edit.data,
  302. )
  303. except NotFoundError as e:
  304. raise HTTPException(status_code=404, detail=str(e))
  305. except PermissionDeniedError:
  306. raise HTTPException(status_code=403, detail="Forbidden")
  307. except DatabaseError:
  308. raise HTTPException(status_code=500, detail="Failed to update relationship")
  309. @router.post("/{sketch_id}/nodes/merge")
  310. @update_sketch_timestamp
  311. def merge_nodes(
  312. sketch_id: str,
  313. oldNodes: List[str],
  314. newNode: NodeMergeInput,
  315. background_tasks: BackgroundTasks,
  316. db: Session = Depends(get_db),
  317. current_user: Profile = Depends(get_current_user),
  318. ):
  319. service = create_sketch_service(db)
  320. try:
  321. node_data = newNode.data.model_dump() if newNode.data else {}
  322. return service.merge_nodes(
  323. UUID(sketch_id), current_user.id, oldNodes, newNode.id, node_data
  324. )
  325. except NotFoundError:
  326. raise HTTPException(status_code=404, detail="Sketch not found")
  327. except PermissionDeniedError:
  328. raise HTTPException(status_code=403, detail="Forbidden")
  329. except ValidationError as e:
  330. raise HTTPException(status_code=400, detail=str(e))
  331. except DatabaseError as e:
  332. raise HTTPException(status_code=500, detail=str(e))
  333. @router.get("/{sketch_id}/nodes/{node_id}")
  334. def get_related_nodes(
  335. sketch_id: str,
  336. node_id: str,
  337. db: Session = Depends(get_db),
  338. current_user: Profile = Depends(get_current_user),
  339. ):
  340. service = create_sketch_service(db)
  341. try:
  342. return service.get_neighbors(UUID(sketch_id), current_user.id, node_id)
  343. except NotFoundError as e:
  344. raise HTTPException(status_code=404, detail=str(e))
  345. except PermissionDeniedError:
  346. raise HTTPException(status_code=403, detail="Forbidden")
  347. except DatabaseError:
  348. raise HTTPException(status_code=500, detail="Failed to retrieve related nodes")
  349. @router.post("/{sketch_id}/import/analyze", response_model=FileParseResult)
  350. async def analyze_import_file(
  351. sketch_id: str,
  352. file: UploadFile = File(...),
  353. db: Session = Depends(get_db),
  354. current_user: Profile = Depends(get_current_user),
  355. ):
  356. """Analyze an uploaded TXT or JSON file for import."""
  357. service = create_sketch_service(db)
  358. try:
  359. service.get_by_id(UUID(sketch_id), current_user.id)
  360. except NotFoundError:
  361. raise HTTPException(status_code=404, detail="Sketch not found")
  362. except PermissionDeniedError:
  363. raise HTTPException(status_code=403, detail="Forbidden")
  364. if not file.filename or not file.filename.lower().endswith((".txt", ".json")):
  365. raise HTTPException(
  366. status_code=400,
  367. detail="Only .txt and .json files are supported. Please upload a correct format.",
  368. )
  369. try:
  370. content = await file.read()
  371. except Exception as e:
  372. raise HTTPException(status_code=400, detail=f"Failed to read file: {str(e)}")
  373. try:
  374. type_registry = create_type_registry_service(db)
  375. resolver = type_registry.build_type_resolver(current_user.id)
  376. graph_service = create_graph_service(sketch_id=sketch_id, enable_batching=False, type_resolver=resolver)
  377. import_service = create_import_service(graph_service)
  378. result = import_service.analyze_file(
  379. file_content=content,
  380. filename=file.filename or "unknown.txt",
  381. )
  382. except ValueError as e:
  383. raise HTTPException(status_code=400, detail=str(e))
  384. except Exception as e:
  385. raise HTTPException(status_code=500, detail=f"Failed to parse file: {str(e)}")
  386. return result
  387. @router.post("/{sketch_id}/import/execute", response_model=ImportExecuteResponse)
  388. @update_sketch_timestamp
  389. async def execute_import(
  390. sketch_id: str,
  391. entity_mappings_json: str = Form(...),
  392. background_tasks: BackgroundTasks = BackgroundTasks(),
  393. db: Session = Depends(get_db),
  394. current_user: Profile = Depends(get_current_user),
  395. ):
  396. """Execute the import of entities into the sketch."""
  397. import json
  398. service = create_sketch_service(db)
  399. try:
  400. service.get_by_id(UUID(sketch_id), current_user.id)
  401. except NotFoundError:
  402. raise HTTPException(status_code=404, detail="Sketch not found")
  403. except PermissionDeniedError:
  404. raise HTTPException(status_code=403, detail="Forbidden")
  405. try:
  406. mappings = json.loads(entity_mappings_json)
  407. nodes = mappings.get("nodes", [])
  408. edges = mappings.get("edges", [])
  409. entity_mapping_inputs = [EntityMappingInput(**m) for m in nodes]
  410. except json.JSONDecodeError:
  411. raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON")
  412. except Exception as e:
  413. raise HTTPException(
  414. status_code=400, detail=f"Failed to parse entity_mappings: {str(e)}"
  415. )
  416. entity_mappings = [
  417. EntityMapping(
  418. id=m.id,
  419. entity_type=m.entity_type,
  420. nodeLabel=m.nodeLabel,
  421. data=m.data,
  422. include=m.include,
  423. node_id=m.node_id,
  424. )
  425. for m in entity_mapping_inputs
  426. ]
  427. type_registry = create_type_registry_service(db)
  428. resolver = type_registry.build_type_resolver(current_user.id)
  429. graph_service = create_graph_service(sketch_id=sketch_id, enable_batching=False, type_resolver=resolver)
  430. import_service = create_import_service(graph_service)
  431. try:
  432. result = import_service.execute_import(
  433. entity_mappings=entity_mappings,
  434. edges=edges,
  435. )
  436. except Exception as e:
  437. raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
  438. return ImportExecuteResponse(
  439. status=result.status,
  440. nodes_created=result.nodes_created,
  441. nodes_skipped=result.nodes_skipped,
  442. errors=result.errors,
  443. )
  444. @router.get("/{id}/export")
  445. async def export_sketch(
  446. id: str,
  447. format: str = "json",
  448. db: Session = Depends(get_db),
  449. current_user: Profile = Depends(get_current_user),
  450. ):
  451. """Export the sketch in the specified format."""
  452. service = create_sketch_service(db)
  453. try:
  454. return service.export_sketch(UUID(id), current_user.id, format)
  455. except NotFoundError:
  456. raise HTTPException(status_code=404, detail="Sketch not found")
  457. except PermissionDeniedError:
  458. raise HTTPException(status_code=403, detail="Forbidden")
  459. except ValidationError as e:
  460. raise HTTPException(status_code=400, detail=str(e))