enrichers.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import List, Optional
  2. from fastapi import APIRouter, Depends, HTTPException, Query
  3. from flowsint_core.core.celery import celery
  4. from flowsint_core.core.graph import create_graph_service
  5. from flowsint_core.core.models import Profile
  6. from flowsint_core.core.postgre_db import get_db
  7. from flowsint_core.core.services import (
  8. create_enricher_service,
  9. create_enricher_template_service,
  10. )
  11. from flowsint_core.core.services.type_registry_service import create_type_registry_service
  12. from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
  13. from pydantic import BaseModel
  14. from sqlalchemy.orm import Session
  15. from app.api.deps import get_current_user
# Runs once, at module import, before any route below is served. The handlers
# in this module consult ENRICHER_REGISTRY; presumably this call is what
# populates it (e.g. via registration side effects) — TODO confirm in
# flowsint_enrichers.
load_all_enrichers()
# Request body for POST /{enricher_name}/launch.
# NOTE: the class name, field order and (absent) docstring all appear in the
# generated OpenAPI schema, so change them only deliberately.
class launchEnricherPayload(BaseModel):
    # Graph element IDs of the nodes to enrich; looked up via
    # graph_service.get_nodes_by_ids_for_task in launch_enricher.
    node_ids: List[str]
    # Sketch the nodes belong to; used to build the graph service and
    # forwarded unchanged to the Celery task.
    sketch_id: str
# Router for the enricher endpoints; the "" and "/{enricher_name}/launch"
# paths below are relative to wherever this router is mounted.
router = APIRouter()
  21. @router.get("")
  22. def get_enrichers(
  23. category: Optional[str] = Query(None),
  24. db: Session = Depends(get_db),
  25. current_user: Profile = Depends(get_current_user),
  26. ):
  27. """Get all enrichers, optionally filtered by category."""
  28. enricher_service = create_enricher_service(db)
  29. return enricher_service.get_all_enrichers(
  30. category, current_user.id, ENRICHER_REGISTRY
  31. )
  32. @router.post("/{enricher_name}/launch")
  33. async def launch_enricher(
  34. enricher_name: str,
  35. payload: launchEnricherPayload,
  36. current_user: Profile = Depends(get_current_user),
  37. db: Session = Depends(get_db),
  38. ):
  39. try:
  40. # Retrieve nodes from Neo4J by their element IDs
  41. type_registry = create_type_registry_service(db)
  42. resolver = type_registry.build_type_resolver(current_user.id)
  43. graph_service = create_graph_service(sketch_id=payload.sketch_id, type_resolver=resolver)
  44. entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids)
  45. # Send deserialized nodes
  46. entities = [
  47. entity.model_dump(mode="json", serialize_as_any=True) for entity in entities
  48. ]
  49. if not entities:
  50. raise HTTPException(
  51. status_code=404, detail="No entities found with provided IDs"
  52. )
  53. is_template = False
  54. enricher_in_registry = ENRICHER_REGISTRY.enricher_exists(enricher_name)
  55. if not enricher_in_registry:
  56. template_service = create_enricher_template_service(db)
  57. template = template_service.find_by_name(enricher_name, current_user.id)
  58. if not template:
  59. raise HTTPException(
  60. status_code=404,
  61. detail=f"Enricher '{enricher_name}' not found",
  62. )
  63. is_template = True
  64. task_name = "run_template_enricher" if is_template else "run_enricher"
  65. task = celery.send_task(
  66. task_name,
  67. args=[
  68. enricher_name,
  69. entities,
  70. payload.sketch_id,
  71. str(current_user.id),
  72. ],
  73. )
  74. return {"id": task.id}
  75. except HTTPException:
  76. raise
  77. except Exception as e:
  78. print(e)
  79. raise HTTPException(
  80. status_code=500, detail=f"Error launching enricher: {str(e)}"
  81. )