detection_route.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. import base64
  2. import json
  3. import os
  4. import time
  5. import traceback
  6. from datetime import datetime
  7. from io import BytesIO
  8. from pathlib import Path
  9. from zoneinfo import ZoneInfo
  10. import numpy as np
  11. from PIL import Image
  12. from flask import jsonify, request, current_app, Response, stream_with_context
  13. from flask_jwt_extended import jwt_required, get_jwt_identity
  14. from ultralytics import YOLO
  15. from app import Config
  16. from app.constants import TaskStatus, OperationType, UserRole
  17. from app.decorators import login_required
  18. from app.models import Detection, Media, Model, Operation, User, db
  19. from app.routes import detection_routes
  20. from app.utils import handle_operation_success, handle_operation_failure, compute_count, compute_perimeter, \
  21. compute_area, compute_shape_complexity, compute_texture_roughness, compute_crack_width, compute_avg_hue, \
  22. evaluate_disease_severity, get_pagination_params, adjust_page_if_needed, unify_result_media_format, delete_file, \
  23. user_rate_limit, DateTimeEncoder
  24. # 获取文件夹配置,并确保目录存在
  25. MODELS_FOLDER = Config.MODELS_FOLDER
  26. MEDIAS_FOLDER = Config.MEDIAS_FOLDER
  27. RESULTS_FOLDER = Config.RESULTS_FOLDER
  28. os.makedirs(MODELS_FOLDER, exist_ok=True)
  29. os.makedirs(MEDIAS_FOLDER, exist_ok=True)
  30. os.makedirs(RESULTS_FOLDER, exist_ok=True)
  31. def _resolve_static_path(raw_path, base_folder):
  32. """
  33. 兼容数据库中不同格式的路径:
  34. - static\\models\\xxx.pt
  35. - static/models/xxx.pt
  36. - models/xxx.pt
  37. - 仅文件名 xxx.pt
  38. """
  39. if not raw_path:
  40. return None
  41. normalized = str(raw_path).replace('\\', '/').lstrip('/')
  42. if normalized.startswith('static/'):
  43. return Path(current_app.root_path) / normalized
  44. return Path(base_folder) / os.path.basename(normalized)
  45. @detection_routes.route('/detection_segmentation', methods=['POST'])
  46. @jwt_required()
  47. @login_required
  48. def detection_segmentation():
  49. start_time = time.time() # 记录操作开始时间
  50. # 获取请求中的 JSON 参数
  51. data = request.get_json()
  52. model_id = data.get('model_id')
  53. media_id = data.get('media_id')
  54. # 创建一个新的操作记录
  55. new_operation = Operation(
  56. operation_type=OperationType.EXECUTE,
  57. description="执行病害检测分割",
  58. ip_address=request.remote_addr,
  59. device_info=request.user_agent.string,
  60. )
  61. # 获取当前用户身份(使用 access token)
  62. current_user_id = get_jwt_identity()
  63. current_user = User.query.get(current_user_id)
  64. # 获取模型和媒体
  65. model = Model.query.get(model_id)
  66. media = Media.query.get(media_id)
  67. # 校验字段
  68. validation_checks = [
  69. (not model_id or not media_id, "【检测分割失败】媒体/模型 ID 为空", 400),
  70. (not model, f"【检测分割失败】模型 ID={model_id} 不存在", 404),
  71. (not media, f"【检测分割失败】媒体 ID={media_id} 不存在", 404),
  72. ]
  73. for condition, message, code in validation_checks:
  74. if condition:
  75. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  76. current_app.logger.warning(message)
  77. return jsonify({
  78. 'operation': new_operation.to_dict(),
  79. }), code
  80. # 查找是否已存在相同 owner_id 和 media_id 的检测分割记录
  81. existing_detection = Detection.query.filter_by(owner_id=current_user_id, media_id=media_id).first()
  82. if existing_detection:
  83. # 如果存在则更新检测时间,后续会更新其他字段
  84. new_detection = existing_detection
  85. new_detection.model_id = model_id
  86. new_detection.detection_at = datetime.now(ZoneInfo("Asia/Shanghai"))
  87. current_app.logger.info("【检测分割】找到已有记录,进行更新")
  88. else:
  89. # 如果不存在则创建新的检测分割记录
  90. new_detection = Detection(
  91. detection_at=datetime.now(ZoneInfo("Asia/Shanghai")),
  92. owner_id=current_user_id,
  93. model_id=model_id,
  94. media_id=media_id,
  95. )
  96. db.session.add(new_detection)
  97. db.session.commit()
  98. try:
  99. # 更新任务状态为进行中
  100. new_detection.status = TaskStatus.IN_PROGRESS
  101. db.session.commit()
  102. # 导入模型/媒体并执行预测
  103. model_path = _resolve_static_path(model.model_path, MODELS_FOLDER)
  104. source_path = _resolve_static_path(media.media_path, MEDIAS_FOLDER)
  105. if not model_path or not os.path.isfile(model_path):
  106. failure_message = f"【检测分割失败】模型文件不存在:{model.model_path}"
  107. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  108. current_app.logger.warning(failure_message)
  109. return jsonify({'operation': new_operation.to_dict()}), 400
  110. if not source_path or not os.path.isfile(source_path):
  111. failure_message = f"【检测分割失败】媒体文件不存在:{media.media_path}"
  112. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  113. current_app.logger.warning(failure_message)
  114. return jsonify({'operation': new_operation.to_dict()}), 400
  115. yolo_model = YOLO(model_path)
  116. results = yolo_model.predict(
  117. source=source_path,
  118. imgsz=1024,
  119. half=True,
  120. retina_masks=True,
  121. save=True,
  122. project=RESULTS_FOLDER,
  123. name=current_user.username,
  124. stream=True,
  125. exist_ok=True, # 每次都保存在同一文件夹
  126. )
  127. def generate():
  128. # 初始化媒体帧数
  129. frame_count = media.frame_count
  130. disease_frame_count = frame_count
  131. # 初始化病害指标
  132. total_disease_count = 0
  133. total_disease_perimeter = 0.0
  134. total_disease_area = 0.0
  135. total_shape_complexity = 0.0
  136. total_texture_roughness = 0.0
  137. total_crack_width = 0.0
  138. total_avg_hue = 0.0
  139. # 初始化检测分割耗时
  140. total_detection_duration = 0.0
  141. # 初始消息
  142. init_msg = {
  143. 'type': 'START',
  144. 'existing_detection': bool(existing_detection),
  145. }
  146. yield json.dumps(init_msg, cls=DateTimeEncoder) + '\n'
  147. for idx, result in enumerate(results):
  148. # 计算帧检测时长
  149. frame_detection_duration = sum(result.speed.values())
  150. total_detection_duration += frame_detection_duration
  151. # 统计帧级指标
  152. if result.masks is None:
  153. disease_frame_count = max(disease_frame_count - 1, 0)
  154. else:
  155. masks = result.masks
  156. masks_data = masks.data.cpu().numpy() # 确保在 CPU 上
  157. combined_masks = np.any(masks_data, axis=0).astype(np.uint8) # 确保重复区域只计算一次
  158. frame_disease_count = compute_count(masks) # 病害数量
  159. frame_disease_perimeter = compute_perimeter(combined_masks) # 病害周长(像素)
  160. frame_disease_area = compute_area(combined_masks) # 病害面积(像素)
  161. frame_shape_complexity = compute_shape_complexity(frame_disease_perimeter,
  162. frame_disease_area) # 形状复杂度
  163. frame_texture_roughness = compute_texture_roughness(combined_masks) # 纹理粗糙度
  164. frame_crack_width = compute_crack_width(
  165. combined_masks) if "裂缝" in model.disease_category else 0.0 # 裂缝宽度
  166. frame_avg_hue = compute_avg_hue(combined_masks,
  167. result.orig_img) if "锈蚀" in model.disease_category else 0.0 # 平均色调
  168. total_disease_count += frame_disease_count
  169. total_disease_perimeter += frame_disease_perimeter
  170. total_disease_area += frame_disease_area
  171. total_shape_complexity += frame_shape_complexity
  172. total_texture_roughness += frame_texture_roughness
  173. total_crack_width += frame_crack_width
  174. total_avg_hue += frame_avg_hue
  175. # 绘制并编码图像
  176. annotated_img = result.plot() # ndarray
  177. annotated_pil = Image.fromarray(annotated_img) # 转为 PIL.Image
  178. buf = BytesIO()
  179. annotated_pil.save(buf, format='JPEG') # 正确保存
  180. b64 = base64.b64encode(buf.getvalue()).decode('ascii') # 编码为 base64 字符串
  181. frame_msg = {
  182. 'type': 'FRAME',
  183. 'frame_index': idx,
  184. 'frame_image': b64,
  185. 'frame_detection_duration': frame_detection_duration,
  186. }
  187. yield json.dumps(frame_msg, cls=DateTimeEncoder) + '\n'
  188. current_app.logger.info(
  189. f"【检测分割】total_disease_count: {total_disease_count}, disease_frame_count: {disease_frame_count}")
  190. # 计算平均病害指标
  191. average_disease_count = total_disease_count // disease_frame_count if disease_frame_count != 0 else 0
  192. average_disease_perimeter = total_disease_perimeter / disease_frame_count if disease_frame_count != 0 else 0.0
  193. average_disease_area = total_disease_area / disease_frame_count if disease_frame_count != 0 else 0.0
  194. average_shape_complexity = total_shape_complexity / disease_frame_count if disease_frame_count != 0 else 0.0
  195. average_texture_roughness = total_texture_roughness / disease_frame_count if disease_frame_count != 0 else 0.0
  196. average_crack_width = total_crack_width / disease_frame_count if disease_frame_count != 0 else 0.0
  197. average_avg_hue = total_avg_hue / disease_frame_count if disease_frame_count != 0 else 0.0
  198. # 计算帧平均检测分割耗时
  199. avg_frame_detection_duration = total_detection_duration / frame_count if frame_count != 0 else 0.0
  200. # 根据检测结果计算病害严重性得分、病害等级、病害描述
  201. disease_severity_score, disease_grade, disease_description = evaluate_disease_severity(
  202. average_disease_count,
  203. average_disease_perimeter,
  204. average_disease_area,
  205. average_shape_complexity,
  206. average_texture_roughness,
  207. average_crack_width,
  208. average_avg_hue, media,
  209. )
  210. # 检测分割结果路径
  211. result_path = unify_result_media_format(media, current_user)
  212. # 更新检测分割信息
  213. new_detection.status = TaskStatus.COMPLETED
  214. new_detection.result_path = result_path
  215. new_detection.disease_count = average_disease_count
  216. new_detection.disease_perimeter = average_disease_perimeter
  217. new_detection.disease_area = average_disease_area
  218. new_detection.shape_complexity = average_shape_complexity
  219. new_detection.texture_roughness = average_texture_roughness
  220. new_detection.crack_width = average_crack_width
  221. new_detection.avg_hue = average_avg_hue
  222. new_detection.disease_severity_score = disease_severity_score
  223. new_detection.disease_grade = disease_grade
  224. new_detection.disease_description = disease_description
  225. new_detection.detection_duration = total_detection_duration
  226. new_detection.avg_frame_detection_duration = avg_frame_detection_duration
  227. db.session.commit()
  228. # 记录操作
  229. success_op = handle_operation_success(new_operation, start_time, current_user_id)
  230. end_msg = {
  231. 'type': 'END',
  232. 'existing_detection': bool(existing_detection),
  233. 'new_detection': new_detection.to_dict(),
  234. 'operation': success_op.to_dict(),
  235. }
  236. current_app.logger.info(f"【检测分割成功】new_detection: {success_op}")
  237. yield json.dumps(end_msg, cls=DateTimeEncoder) + '\n'
  238. return Response(stream_with_context(generate()), mimetype='application/json')
  239. except Exception as error:
  240. # 发生错误,更新任务状态为失败
  241. new_detection.status = TaskStatus.FAILED
  242. db.session.commit()
  243. # 记录操作失败
  244. failure_message = f"【检测分割错误】服务器内部发生错误,请联系管理员"
  245. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  246. # 获取详细的堆栈追踪信息
  247. stack_trace = traceback.format_exc()
  248. # 获取请求的相关信息
  249. request_method = request.method
  250. request_url = request.url
  251. request_data = request.get_data(as_text=True)
  252. # 记录日志,帮助排查问题
  253. current_app.logger.error(f"Error: {str(error)}\nStack Trace: {stack_trace}\n"
  254. f"Request Method: {request_method}\n"
  255. f"Request URL: {request_url}\n"
  256. f"Request Data: {request_data}")
  257. return jsonify({
  258. 'operation': new_operation.to_dict(),
  259. }), 500
  260. @detection_routes.route('/detail/<int:detection_id>', methods=['GET'])
  261. @jwt_required()
  262. @login_required
  263. def detail(detection_id):
  264. # 获取当前用户身份(使用 access token)
  265. current_user_id = get_jwt_identity()
  266. current_user = User.query.get(current_user_id)
  267. # 获取指定检测分割记录
  268. detection = Detection.query.get(detection_id)
  269. # 校验字段
  270. validation_checks = [
  271. (not detection, f"【获取检测分割 ID={detection_id} 详情失败】该检测分割记录不存在", 404),
  272. (detection and detection.owner_id != current_user_id and current_user.role != UserRole.ADMIN
  273. and current_user.role != UserRole.DEVELOPER,
  274. f"【获取检测分割 ID={detection_id} 详情失败】您非管理员/开发人员,无法查看他人的检测分割记录详情", 403),
  275. ]
  276. for condition, message, code in validation_checks:
  277. if condition:
  278. current_app.logger.warning(message + f', operator: {current_user}')
  279. return jsonify({
  280. 'failure_message': message,
  281. }), code
  282. return jsonify({
  283. 'detection': detection.to_dict(),
  284. }), 200
  285. @detection_routes.route('/delete/<int:detection_id>', methods=['GET'])
  286. @jwt_required()
  287. @login_required
  288. def delete_detection(detection_id):
  289. start_time = time.time() # 记录操作开始时间
  290. # 创建一个新的操作记录
  291. new_operation = Operation(
  292. operation_type=OperationType.DELETE,
  293. description=f"删除检测分割 ID={detection_id} 记录",
  294. ip_address=request.remote_addr,
  295. device_info=request.user_agent.string,
  296. )
  297. # 获取当前用户身份(使用 access token)
  298. current_user_id = get_jwt_identity()
  299. current_user = User.query.get(current_user_id)
  300. # 获取指定媒体
  301. deleted_detection = Detection.query.get(detection_id)
  302. # 校验字段
  303. validation_checks = [
  304. (not deleted_detection, f"【删除检测分割 ID={detection_id} 记录失败】该检测分割记录不存在", 404),
  305. (deleted_detection and deleted_detection.owner_id != current_user_id and current_user.role != UserRole.ADMIN
  306. and current_user.role != UserRole.DEVELOPER,
  307. f"【删除检测分割 ID={detection_id} 记录失败】您非管理员/开发人员,无法删除他人的检测分割记录", 403),
  308. ]
  309. for condition, message, code in validation_checks:
  310. if condition:
  311. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  312. current_app.logger.warning(message + f', operator: {current_user}')
  313. return jsonify({
  314. 'operation': new_operation.to_dict(),
  315. }), code
  316. # 删除实际文件
  317. file_abs_path = _resolve_static_path(deleted_detection.result_path, RESULTS_FOLDER)
  318. delete_file(file_abs_path)
  319. # 删除数据库记录
  320. db.session.delete(deleted_detection)
  321. db.session.commit()
  322. # 记录操作
  323. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  324. current_app.logger.info(
  325. f"【删除检测分割 ID={detection_id} 记录成功】deleted_detection: {deleted_detection}, operator: {current_user}")
  326. return jsonify({
  327. 'operation': new_operation.to_dict(),
  328. 'deleted_detection': deleted_detection.to_dict(),
  329. }), 200
  330. @detection_routes.route('/detections/<int:user_id>', methods=['GET'])
  331. @jwt_required()
  332. @login_required
  333. def user_detections(user_id):
  334. # 获取分页参数(从请求中获取,默认为第 1 页,每页 5 条记录)
  335. default_page = request.args.get('page', 1, type=int)
  336. default_per_page = request.args.get('per_page', 5, type=int)
  337. page, per_page = get_pagination_params(default_page, default_per_page)
  338. # 获取当前用户身份(使用 access token)
  339. current_user_id = get_jwt_identity()
  340. current_user = User.query.get(current_user_id)
  341. # 获取指定用户身份
  342. user = User.query.get(user_id)
  343. # 校验字段
  344. validation_checks = [
  345. (not user, f"【获取用户 ID={user_id} 检测分割记录失败】该用户不存在", 404),
  346. (current_user_id != user_id and current_user.role != UserRole.ADMIN and current_user.role != UserRole.DEVELOPER,
  347. f"【获取用户 ID={user_id} 检测分割记录失败】您非管理员/开发人员,无法查看他人的检测分割记录", 403),
  348. ]
  349. for condition, message, code in validation_checks:
  350. if condition:
  351. current_app.logger.warning(message)
  352. return jsonify({
  353. 'failure_message': message,
  354. }), code
  355. # 获取指定用户检测分割记录
  356. query = (
  357. Detection.query
  358. .filter(Detection.owner_id == user_id)
  359. .join(Model, Detection.model_id == Model.model_id)
  360. .join(Media, Detection.media_id == Media.media_id)
  361. .join(User, Detection.owner_id == User.user_id)
  362. .add_columns(
  363. Model.model_name.label('model_name'),
  364. Media.media_name.label('media_name'),
  365. Media.file_type.label('media_type'),
  366. User.username.label('owner_username'),
  367. )
  368. )
  369. page, detections_total, pages = adjust_page_if_needed(query, page, per_page)
  370. paginated = query.paginate(page=page, per_page=per_page, error_out=False)
  371. detections = []
  372. for detection, model_name, media_name, media_type, owner_username in paginated.items:
  373. detection_dict = detection.to_dict()
  374. detection_dict.update({
  375. 'model_name': model_name,
  376. 'media_name': media_name,
  377. 'media_type': media_type,
  378. 'owner_username': owner_username,
  379. })
  380. detections.append(detection_dict)
  381. current_app.logger.info(
  382. f"【获取用户 ID={user_id} 检测分割记录成功】total: {detections_total}, per_page: {per_page}, page: {page}, pages: {pages}, detections: {detections}, operator: {current_user}")
  383. return jsonify({
  384. 'detections': detections,
  385. 'total': detections_total,
  386. 'per_page': per_page,
  387. 'page': page,
  388. 'pages': pages,
  389. }), 200
  390. @detection_routes.route('/detections/all', methods=['GET'])
  391. @jwt_required()
  392. @login_required
  393. def all_detections():
  394. # 获取分页参数(从请求中获取,默认为第 1 页,每页 5 条记录)
  395. default_page = request.args.get('page', 1, type=int)
  396. default_per_page = request.args.get('per_page', 5, type=int)
  397. page, per_page = get_pagination_params(default_page, default_per_page)
  398. # 获取当前用户身份(使用 access token)
  399. current_user_id = get_jwt_identity()
  400. current_user = User.query.get(current_user_id)
  401. if current_user.role != UserRole.ADMIN and current_user.role != UserRole.DEVELOPER:
  402. failure_message = f"【获取所有检测分割记录失败】您非管理员/开发人员,权限不足"
  403. current_app.logger.warning(failure_message)
  404. return jsonify({
  405. 'failure_message': failure_message,
  406. }), 403
  407. # 获取所有媒体
  408. query = (
  409. Detection.query
  410. .join(Model, Detection.model_id == Model.model_id)
  411. .join(Media, Detection.media_id == Media.media_id)
  412. .join(User, Detection.owner_id == User.user_id)
  413. .add_columns(
  414. Model.model_name.label('model_name'),
  415. Media.media_name.label('media_name'),
  416. Media.file_type.label('media_type'),
  417. User.username.label('owner_username'),
  418. )
  419. .order_by(Detection.detection_id.asc())
  420. )
  421. page, detections_total, pages = adjust_page_if_needed(query, page, per_page)
  422. paginated = query.paginate(page=page, per_page=per_page, error_out=False)
  423. detections = []
  424. for detection, model_name, media_name, media_type, owner_username in paginated.items:
  425. detection_dict = detection.to_dict()
  426. detection_dict.update({
  427. 'model_name': model_name,
  428. 'media_name': media_name,
  429. 'media_type': media_type,
  430. 'owner_username': owner_username,
  431. })
  432. detections.append(detection_dict)
  433. current_app.logger.info(
  434. f"【获取所有检测分割记录成功】total: {detections_total}, per_page: {per_page}, page: {page}, pages: {pages}, detections: {detections}, operator: {current_user}")
  435. return jsonify({
  436. 'detections': detections,
  437. 'total': detections_total,
  438. 'per_page': per_page,
  439. 'page': page,
  440. 'pages': pages,
  441. }), 200
  442. @detection_routes.route('/statistics', methods=['GET'])
  443. @jwt_required()
  444. @login_required
  445. def statistics():
  446. # 查询检测记录总数
  447. total_detections = Detection.query.count()
  448. # 查询不同状态的检测记录数量
  449. pending_detections = Detection.query.filter(Detection.status == TaskStatus.PENDING).count()
  450. in_progress_detections = Detection.query.filter(Detection.status == TaskStatus.IN_PROGRESS).count()
  451. completed_detections = Detection.query.filter(Detection.status == TaskStatus.COMPLETED).count()
  452. failed_detections = Detection.query.filter(Detection.status == TaskStatus.FAILED).count()
  453. # 构建统计数据
  454. detections_statistics = {
  455. 'total': total_detections,
  456. 'pending': pending_detections,
  457. 'in_progress': in_progress_detections,
  458. 'completed': completed_detections,
  459. 'failed': failed_detections,
  460. }
  461. return jsonify({
  462. "detections_statistics": detections_statistics,
  463. }), 200