# flows.py - API routes for listing, editing, simulating and launching flows.
  1. from typing import Any, Dict, List, Optional
  2. from uuid import UUID
  3. from fastapi import APIRouter, Depends, HTTPException, Query, status
  4. from flowsint_core.core.celery import celery
  5. from flowsint_core.core.graph import create_graph_service
  6. from flowsint_core.core.models import Profile
  7. from flowsint_core.core.postgre_db import get_db
  8. from flowsint_core.core.services import (
  9. NotFoundError,
  10. PermissionDeniedError,
  11. create_flow_service,
  12. )
  13. from flowsint_core.core.services.type_registry_service import (
  14. create_type_registry_service,
  15. )
  16. from flowsint_core.core.types import FlowBranch, FlowEdge, FlowNode, FlowStep
  17. from flowsint_core.utils import extract_input_schema_flow
  18. from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
  19. from flowsint_types import (
  20. ASN,
  21. CIDR,
  22. CryptoNFT,
  23. CryptoWallet,
  24. CryptoWalletTransaction,
  25. DNSRecord,
  26. Domain,
  27. Email,
  28. Individual,
  29. Ip,
  30. Organization,
  31. Phone,
  32. Phrase,
  33. Port,
  34. SocialAccount,
  35. Username,
  36. Website,
  37. )
  38. from pydantic import BaseModel
  39. from sqlalchemy.orm import Session
  40. from app.api.deps import get_current_user
  41. from app.api.schemas.flow import FlowCreate, FlowRead, FlowUpdate
# Import-time side effect: run the enricher loader once, before any route in
# this module queries ENRICHER_REGISTRY (presumably this call is what
# populates the registry - confirm against flowsint_enrichers).
load_all_enrichers()
class FlowComputationRequest(BaseModel):
    """Request body of the flow computation endpoint.

    Describes the graph to expand into branches: its nodes, its edges and,
    optionally, the name of the input type used to pick a sample seed value.
    """

    nodes: List[FlowNode]
    edges: List[FlowEdge]
    # When omitted, compute_flows substitutes "string".
    inputType: Optional[str] = None
class FlowComputationResponse(BaseModel):
    """Response body of the flow computation endpoint."""

    # Branches produced by compute_flow_branches.
    flowBranches: List[FlowBranch]
    # Sample seed value that was fed to the input node(s).
    initialData: Any
class StepSimulationRequest(BaseModel):
    """Payload pairing a list of flow branches with a step index.

    NOTE(review): no endpoint visible in this part of the file consumes this
    model; presumably it backs a step-by-step simulation route - confirm.
    """

    flowBranches: List[FlowBranch]
    currentStepIndex: int
class launchFlowPayload(BaseModel):
    """Request body of the flow launch endpoint."""

    # Ids of the graph nodes to run the flow on; launch_flow passes them to
    # the graph service to fetch the matching entities.
    node_ids: List[str]
    # Sketch the nodes belong to; also forwarded to the background task.
    sketch_id: str
# Router that every endpoint in this module registers itself on.
router = APIRouter()
  57. @router.get("", response_model=List[FlowRead])
  58. def get_flows(
  59. category: Optional[str] = Query(None),
  60. db: Session = Depends(get_db),
  61. current_user: Profile = Depends(get_current_user),
  62. ):
  63. service = create_flow_service(db)
  64. return service.get_all_flows(category, current_user.id)
  65. @router.get("/raw_materials")
  66. async def get_material_list():
  67. enrichers = ENRICHER_REGISTRY.list_by_categories()
  68. enricher_categories = {
  69. category: [
  70. {
  71. "class_name": enricher.get("class_name"),
  72. "category": enricher.get("category"),
  73. "name": enricher.get("name"),
  74. "module": enricher.get("module"),
  75. "documentation": enricher.get("documentation"),
  76. "description": enricher.get("description"),
  77. "inputs": enricher.get("inputs"),
  78. "outputs": enricher.get("outputs"),
  79. "type": "enricher",
  80. "params": enricher.get("params"),
  81. "params_schema": enricher.get("params_schema"),
  82. "required_params": enricher.get("required_params"),
  83. "icon": enricher.get("icon"),
  84. }
  85. for enricher in enricher_list
  86. ]
  87. for category, enricher_list in enrichers.items()
  88. }
  89. object_inputs = [
  90. extract_input_schema_flow(Phrase),
  91. extract_input_schema_flow(Organization),
  92. extract_input_schema_flow(Individual),
  93. extract_input_schema_flow(Domain),
  94. extract_input_schema_flow(Website),
  95. extract_input_schema_flow(Ip),
  96. extract_input_schema_flow(DNSRecord),
  97. extract_input_schema_flow(Port),
  98. extract_input_schema_flow(Phone),
  99. extract_input_schema_flow(ASN),
  100. extract_input_schema_flow(CIDR),
  101. extract_input_schema_flow(Username),
  102. extract_input_schema_flow(SocialAccount),
  103. extract_input_schema_flow(Email),
  104. extract_input_schema_flow(CryptoWallet),
  105. extract_input_schema_flow(CryptoWalletTransaction),
  106. extract_input_schema_flow(CryptoNFT),
  107. ]
  108. flattened_enrichers = {"types": object_inputs}
  109. flattened_enrichers.update(enricher_categories)
  110. return {"items": flattened_enrichers}
  111. @router.get("/input_type/{input_type}")
  112. async def get_material_by_input_type(input_type: str):
  113. enrichers = ENRICHER_REGISTRY.list_by_input_type(input_type)
  114. return {"items": enrichers}
  115. @router.post("/create", response_model=FlowRead, status_code=status.HTTP_201_CREATED)
  116. def create_flow(
  117. payload: FlowCreate,
  118. db: Session = Depends(get_db),
  119. current_user: Profile = Depends(get_current_user),
  120. ):
  121. service = create_flow_service(db)
  122. return service.create(
  123. name=payload.name,
  124. description=payload.description,
  125. category=payload.category,
  126. flow_schema=payload.flow_schema,
  127. )
  128. @router.get("/{flow_id}", response_model=FlowRead)
  129. def get_flow_by_id(
  130. flow_id: UUID,
  131. db: Session = Depends(get_db),
  132. current_user: Profile = Depends(get_current_user),
  133. ):
  134. service = create_flow_service(db)
  135. try:
  136. return service.get_by_id(flow_id)
  137. except NotFoundError:
  138. raise HTTPException(status_code=404, detail="Flow not found")
  139. @router.put("/{flow_id}", response_model=FlowRead)
  140. def update_flow(
  141. flow_id: UUID,
  142. payload: FlowUpdate,
  143. db: Session = Depends(get_db),
  144. current_user: Profile = Depends(get_current_user),
  145. ):
  146. service = create_flow_service(db)
  147. try:
  148. return service.update(flow_id, payload.model_dump(exclude_unset=True))
  149. except NotFoundError:
  150. raise HTTPException(status_code=404, detail="Flow not found")
  151. @router.delete("/{flow_id}", status_code=status.HTTP_204_NO_CONTENT)
  152. def delete_flow(
  153. flow_id: UUID,
  154. db: Session = Depends(get_db),
  155. current_user: Profile = Depends(get_current_user),
  156. ):
  157. service = create_flow_service(db)
  158. try:
  159. service.delete(flow_id)
  160. return None
  161. except NotFoundError:
  162. raise HTTPException(status_code=404, detail="Flow not found")
  163. @router.post("/{flow_id}/launch")
  164. async def launch_flow(
  165. flow_id: str,
  166. payload: launchFlowPayload,
  167. db: Session = Depends(get_db),
  168. current_user: Profile = Depends(get_current_user),
  169. ):
  170. service = create_flow_service(db)
  171. try:
  172. flow = service.get_by_id(UUID(flow_id))
  173. service.get_sketch_for_launch(payload.sketch_id, current_user.id)
  174. # Retrieve entities from Neo4J by their element IDs
  175. type_registry = create_type_registry_service(db)
  176. resolver = type_registry.build_type_resolver(current_user.id)
  177. graph_service = create_graph_service(
  178. sketch_id=payload.sketch_id, type_resolver=resolver
  179. )
  180. entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids)
  181. # Compute flow branches
  182. nodes = [FlowNode(**node) for node in flow.flow_schema["nodes"]]
  183. edges = [FlowEdge(**edge) for edge in flow.flow_schema["edges"]]
  184. entities = [
  185. entity.model_dump(mode="json", serialize_as_any=True) for entity in entities
  186. ]
  187. sample_value = (
  188. entities[0].get("nodeLabel", "sample_value")
  189. if len(entities)
  190. else "sample_value"
  191. )
  192. flow_branches = compute_flow_branches(sample_value, nodes, edges)
  193. serializable_branches = [branch.model_dump() for branch in flow_branches]
  194. task = celery.send_task(
  195. "run_flow",
  196. args=[
  197. serializable_branches,
  198. entities,
  199. payload.sketch_id,
  200. str(current_user.id),
  201. ],
  202. )
  203. return {"id": task.id}
  204. except NotFoundError as e:
  205. raise HTTPException(status_code=404, detail=str(e))
  206. except PermissionDeniedError:
  207. raise HTTPException(status_code=403, detail="Forbidden")
  208. except Exception as e:
  209. print(e)
  210. raise HTTPException(status_code=500, detail=f"Error launching flow: {str(e)}")
  211. @router.post("/{flow_id}/compute", response_model=FlowComputationResponse)
  212. def compute_flows(
  213. request: FlowComputationRequest, current_user: Profile = Depends(get_current_user)
  214. ):
  215. initial_data = generate_sample_data(request.inputType or "string")
  216. flow_branches = compute_flow_branches(initial_data, request.nodes, request.edges)
  217. return FlowComputationResponse(flowBranches=flow_branches, initialData=initial_data)
  218. def generate_sample_data(type_str: str) -> Any:
  219. type_str = type_str.lower() if type_str else "string"
  220. if type_str == "string":
  221. return "sample_text"
  222. elif type_str == "number":
  223. return 42
  224. elif type_str == "boolean":
  225. return True
  226. elif type_str == "array":
  227. return [1, 2, 3]
  228. elif type_str == "object":
  229. return {"key": "value"}
  230. elif type_str == "url":
  231. return "https://example.com"
  232. elif type_str == "email":
  233. return "user@example.com"
  234. elif type_str == "domain":
  235. return "example.com"
  236. elif type_str == "ip":
  237. return "192.168.1.1"
  238. else:
  239. return f"sample_{type_str}"
  240. def compute_flow_branches(
  241. initial_value: Any, nodes: List[FlowNode], edges: List[FlowEdge]
  242. ) -> List[FlowBranch]:
  243. """Computes flow branches based on nodes and edges with proper DFS traversal"""
  244. input_nodes = [node for node in nodes if node.data.get("type") == "type"]
  245. if not input_nodes:
  246. return [
  247. FlowBranch(
  248. id="error",
  249. name="Error",
  250. steps=[
  251. FlowStep(
  252. nodeId="error",
  253. inputs={},
  254. params={},
  255. type="error",
  256. outputs={},
  257. status="error",
  258. branchId="error",
  259. depth=0,
  260. )
  261. ],
  262. )
  263. ]
  264. node_map = {node.id: node for node in nodes}
  265. branches = []
  266. branch_counter = 0
  267. enricher_outputs = {}
  268. def calculate_path_length(start_node: str, visited: set = None) -> int:
  269. if visited is None:
  270. visited = set()
  271. if start_node in visited:
  272. return float("inf")
  273. visited.add(start_node)
  274. out_edges = [edge for edge in edges if edge.source == start_node]
  275. if not out_edges:
  276. return 1
  277. min_length = float("inf")
  278. for edge in out_edges:
  279. length = calculate_path_length(edge.target, visited.copy())
  280. min_length = min(min_length, length)
  281. return 1 + min_length
  282. def get_outgoing_edges(node_id: str) -> List[FlowEdge]:
  283. out_edges = [edge for edge in edges if edge.source == node_id]
  284. return sorted(out_edges, key=lambda e: calculate_path_length(e.target))
  285. def create_step(
  286. node_id: str,
  287. branch_id: str,
  288. depth: int,
  289. input_data: Dict[str, Any],
  290. is_input_node: bool,
  291. outputs: Dict[str, Any],
  292. node_params: Optional[Dict[str, Any]] = None,
  293. ) -> FlowStep:
  294. return FlowStep(
  295. nodeId=node_id,
  296. params=node_params,
  297. inputs={} if is_input_node else input_data,
  298. outputs=outputs,
  299. type="type" if is_input_node else "enricher",
  300. status="pending",
  301. branchId=branch_id,
  302. depth=depth,
  303. )
  304. def explore_branch(
  305. current_node_id: str,
  306. branch_id: str,
  307. branch_name: str,
  308. depth: int,
  309. input_data: Dict[str, Any],
  310. path: List[str],
  311. branch_visited: set,
  312. steps: List[FlowStep],
  313. parent_outputs: Dict[str, Any] = None,
  314. ) -> None:
  315. nonlocal branch_counter
  316. if current_node_id in path:
  317. return
  318. current_node = node_map.get(current_node_id)
  319. if not current_node:
  320. return
  321. is_input_node = current_node.data.get("type") == "type"
  322. if is_input_node:
  323. outputs_array = current_node.data["outputs"].get("properties", [])
  324. first_output_name = (
  325. outputs_array[0].get("name", "output") if outputs_array else "output"
  326. )
  327. current_outputs = {first_output_name: initial_value}
  328. else:
  329. if current_node_id in enricher_outputs:
  330. current_outputs = enricher_outputs[current_node_id]
  331. else:
  332. current_outputs = process_node_data(current_node, input_data)
  333. enricher_outputs[current_node_id] = current_outputs
  334. node_params = current_node.data.get("params", {})
  335. current_step = create_step(
  336. current_node_id,
  337. branch_id,
  338. depth,
  339. input_data,
  340. is_input_node,
  341. current_outputs,
  342. node_params,
  343. )
  344. steps.append(current_step)
  345. path.append(current_node_id)
  346. branch_visited.add(current_node_id)
  347. out_edges = get_outgoing_edges(current_node_id)
  348. if not out_edges:
  349. branches.append(FlowBranch(id=branch_id, name=branch_name, steps=steps[:]))
  350. else:
  351. for i, edge in enumerate(out_edges):
  352. if edge.target in path:
  353. continue
  354. output_key = edge.sourceHandle
  355. if not output_key and current_outputs:
  356. output_key = list(current_outputs.keys())[0]
  357. output_value = current_outputs.get(output_key) if output_key else None
  358. if output_value is None and parent_outputs:
  359. output_value = (
  360. parent_outputs.get(output_key) if output_key else None
  361. )
  362. next_input = {edge.targetHandle or "input": output_value}
  363. if i == 0:
  364. explore_branch(
  365. edge.target,
  366. branch_id,
  367. branch_name,
  368. depth + 1,
  369. next_input,
  370. path,
  371. branch_visited,
  372. steps,
  373. current_outputs,
  374. )
  375. else:
  376. branch_counter += 1
  377. new_branch_id = f"{branch_id}-{branch_counter}"
  378. new_branch_name = f"{branch_name} (Branch {branch_counter})"
  379. new_steps = steps[: len(steps)]
  380. new_branch_visited = branch_visited.copy()
  381. explore_branch(
  382. edge.target,
  383. new_branch_id,
  384. new_branch_name,
  385. depth + 1,
  386. next_input,
  387. path[:],
  388. new_branch_visited,
  389. new_steps,
  390. current_outputs,
  391. )
  392. path.pop()
  393. steps.pop()
  394. for index, input_node in enumerate(input_nodes):
  395. branch_id = f"branch-{index}"
  396. branch_name = f"Flow {index + 1}" if len(input_nodes) > 1 else "Main Flow"
  397. explore_branch(
  398. input_node.id,
  399. branch_id,
  400. branch_name,
  401. 0,
  402. {},
  403. [],
  404. set(),
  405. [],
  406. None,
  407. )
  408. branches.sort(key=lambda branch: len(branch.steps))
  409. return branches
  410. def process_node_data(node: FlowNode, inputs: Dict[str, Any]) -> Dict[str, Any]:
  411. """Process node data based on node type and inputs"""
  412. outputs = {}
  413. output_types = node.data["outputs"].get("properties", [])
  414. for output in output_types:
  415. output_name = output.get("name", "output")
  416. class_name = node.data.get("class_name", "")
  417. if class_name in ["ReverseResolveEnricher", "ResolveEnricher"]:
  418. outputs[output_name] = (
  419. "192.168.1.1" if "ip" in output_name.lower() else "example.com"
  420. )
  421. elif class_name == "SubdomainEnricher":
  422. outputs[output_name] = f"sub.{inputs.get('input', 'example.com')}"
  423. elif class_name == "WhoisEnricher":
  424. outputs[output_name] = {
  425. "domain": inputs.get("input", "example.com"),
  426. "registrar": "Example Registrar",
  427. "creation_date": "2020-01-01",
  428. }
  429. elif class_name == "IpToInfosEnricher":
  430. outputs[output_name] = {
  431. "country": "France",
  432. "city": "Paris",
  433. "coordinates": {"lat": 48.8566, "lon": 2.3522},
  434. }
  435. elif class_name == "MaigretEnricher":
  436. outputs[output_name] = {
  437. "username": inputs.get("input", "user123"),
  438. "platforms": ["twitter", "github", "linkedin"],
  439. }
  440. elif class_name == "HoleheEnricher":
  441. outputs[output_name] = {
  442. "email": inputs.get("input", "user@example.com"),
  443. "exists": True,
  444. "platforms": ["gmail", "github"],
  445. }
  446. elif class_name == "SireneEnricher":
  447. outputs[output_name] = {
  448. "name": inputs.get("input", "Example Corp"),
  449. "siret": "12345678901234",
  450. "address": "1 Example Street",
  451. }
  452. else:
  453. outputs[output_name] = inputs.get("input") or f"flowed_{output_name}"
  454. return outputs