detection_route.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. import base64
  2. import json
  3. import os
  4. import time
  5. import traceback
  6. from datetime import datetime
  7. from io import BytesIO
  8. from pathlib import Path
  9. from zoneinfo import ZoneInfo
  10. import numpy as np
  11. import torch
  12. from PIL import Image
  13. from flask import jsonify, request, current_app, Response, stream_with_context
  14. from flask_jwt_extended import jwt_required, get_jwt_identity
  15. from ultralytics import YOLO
  16. from app import Config
  17. from app.constants import TaskStatus, OperationType, UserRole, DiseaseGrade
  18. from app.decorators import login_required
  19. from app.models import Detection, Media, Model, Operation, User, db
  20. from app.routes import detection_routes
  21. from app.utils import handle_operation_success, handle_operation_failure, compute_count, compute_perimeter, \
  22. compute_area, compute_shape_complexity, compute_texture_roughness, compute_crack_width, compute_avg_hue, \
  23. evaluate_disease_severity, get_pagination_params, adjust_page_if_needed, unify_result_media_format, delete_file, \
  24. user_rate_limit, DateTimeEncoder
  25. # 获取文件夹配置,并确保目录存在
  26. MODELS_FOLDER = Config.MODELS_FOLDER
  27. MEDIAS_FOLDER = Config.MEDIAS_FOLDER
  28. RESULTS_FOLDER = Config.RESULTS_FOLDER
  29. os.makedirs(MODELS_FOLDER, exist_ok=True)
  30. os.makedirs(MEDIAS_FOLDER, exist_ok=True)
  31. os.makedirs(RESULTS_FOLDER, exist_ok=True)
  32. def _resolve_static_path(raw_path, base_folder):
  33. """
  34. 兼容数据库中不同格式的路径:
  35. - static\\models\\xxx.pt
  36. - static/models/xxx.pt
  37. - models/xxx.pt
  38. - 仅文件名 xxx.pt
  39. """
  40. if not raw_path:
  41. return None
  42. normalized = str(raw_path).replace('\\', '/').lstrip('/')
  43. if normalized.startswith('static/'):
  44. return Path(current_app.root_path) / normalized
  45. return Path(base_folder) / os.path.basename(normalized)
  46. @detection_routes.route('/detection_segmentation', methods=['POST'])
  47. @jwt_required()
  48. @login_required
  49. def detection_segmentation():
  50. start_time = time.time() # 记录操作开始时间
  51. # 获取请求中的 JSON 参数
  52. data = request.get_json()
  53. model_id = data.get('model_id')
  54. media_id = data.get('media_id')
  55. # 创建一个新的操作记录
  56. new_operation = Operation(
  57. operation_type=OperationType.EXECUTE,
  58. description="执行病害检测分割",
  59. ip_address=request.remote_addr,
  60. device_info=request.user_agent.string,
  61. )
  62. # 获取当前用户身份(使用 access token)
  63. current_user_id = get_jwt_identity()
  64. current_user = User.query.get(current_user_id)
  65. # 获取模型和媒体
  66. model = Model.query.get(model_id)
  67. media = Media.query.get(media_id)
  68. # 校验字段
  69. validation_checks = [
  70. (not model_id or not media_id, "【检测分割失败】媒体/模型 ID 为空", 400),
  71. (not model, f"【检测分割失败】模型 ID={model_id} 不存在", 404),
  72. (not media, f"【检测分割失败】媒体 ID={media_id} 不存在", 404),
  73. ]
  74. for condition, message, code in validation_checks:
  75. if condition:
  76. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  77. current_app.logger.warning(message)
  78. return jsonify({
  79. 'operation': new_operation.to_dict(),
  80. }), code
  81. # 查找是否已存在相同 owner_id 和 media_id 的检测分割记录
  82. existing_detection = Detection.query.filter_by(owner_id=current_user_id, media_id=media_id).first()
  83. if existing_detection:
  84. # 如果存在则更新检测时间,后续会更新其他字段
  85. new_detection = existing_detection
  86. new_detection.model_id = model_id
  87. new_detection.detection_at = datetime.now(ZoneInfo("Asia/Shanghai"))
  88. current_app.logger.info("【检测分割】找到已有记录,进行更新")
  89. else:
  90. # 如果不存在则创建新的检测分割记录
  91. new_detection = Detection(
  92. detection_at=datetime.now(ZoneInfo("Asia/Shanghai")),
  93. owner_id=current_user_id,
  94. model_id=model_id,
  95. media_id=media_id,
  96. )
  97. db.session.add(new_detection)
  98. db.session.commit()
  99. try:
  100. # 更新任务状态为进行中
  101. new_detection.status = TaskStatus.IN_PROGRESS
  102. db.session.commit()
  103. # 导入模型/媒体并执行预测
  104. model_path = _resolve_static_path(model.model_path, MODELS_FOLDER)
  105. source_path = _resolve_static_path(media.media_path, MEDIAS_FOLDER)
  106. if not model_path or not os.path.isfile(model_path):
  107. failure_message = f"【检测分割失败】模型文件不存在:{model.model_path}"
  108. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  109. current_app.logger.warning(failure_message)
  110. return jsonify({'operation': new_operation.to_dict()}), 400
  111. if not source_path or not os.path.isfile(source_path):
  112. failure_message = f"【检测分割失败】媒体文件不存在:{media.media_path}"
  113. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  114. current_app.logger.warning(failure_message)
  115. return jsonify({'operation': new_operation.to_dict()}), 400
  116. yolo_model = YOLO(model_path)
  117. use_cuda = torch.cuda.is_available()
  118. results = yolo_model.predict(
  119. source=source_path,
  120. imgsz=1024,
  121. half=use_cuda,
  122. device='cuda' if use_cuda else 'cpu',
  123. retina_masks=True,
  124. save=True,
  125. project=RESULTS_FOLDER,
  126. name=current_user.username,
  127. stream=True,
  128. exist_ok=True, # 每次都保存在同一文件夹
  129. )
  130. def generate():
  131. # 初始化媒体帧数
  132. frame_count = media.frame_count
  133. disease_frame_count = frame_count
  134. # 初始化病害指标
  135. total_disease_count = 0
  136. total_disease_perimeter = 0.0
  137. total_disease_area = 0.0
  138. total_shape_complexity = 0.0
  139. total_texture_roughness = 0.0
  140. total_crack_width = 0.0
  141. total_avg_hue = 0.0
  142. # 初始化检测分割耗时
  143. total_detection_duration = 0.0
  144. # 初始消息
  145. init_msg = {
  146. 'type': 'START',
  147. 'existing_detection': bool(existing_detection),
  148. }
  149. yield json.dumps(init_msg, cls=DateTimeEncoder) + '\n'
  150. for idx, result in enumerate(results):
  151. # 计算帧检测时长
  152. frame_detection_duration = sum(result.speed.values())
  153. total_detection_duration += frame_detection_duration
  154. # 统计帧级指标
  155. if result.masks is None:
  156. disease_frame_count = max(disease_frame_count - 1, 0)
  157. else:
  158. masks = result.masks
  159. masks_data = masks.data.cpu().numpy() # 确保在 CPU 上
  160. combined_masks = np.any(masks_data, axis=0).astype(np.uint8) # 确保重复区域只计算一次
  161. frame_disease_count = compute_count(masks) # 病害数量
  162. frame_disease_perimeter = compute_perimeter(combined_masks) # 病害周长(像素)
  163. frame_disease_area = compute_area(combined_masks) # 病害面积(像素)
  164. frame_shape_complexity = compute_shape_complexity(frame_disease_perimeter,
  165. frame_disease_area) # 形状复杂度
  166. frame_texture_roughness = compute_texture_roughness(combined_masks) # 纹理粗糙度
  167. frame_crack_width = compute_crack_width(
  168. combined_masks) if "裂缝" in model.disease_category else 0.0 # 裂缝宽度
  169. frame_avg_hue = compute_avg_hue(combined_masks,
  170. result.orig_img) if "锈蚀" in model.disease_category else 0.0 # 平均色调
  171. total_disease_count += frame_disease_count
  172. total_disease_perimeter += frame_disease_perimeter
  173. total_disease_area += frame_disease_area
  174. total_shape_complexity += frame_shape_complexity
  175. total_texture_roughness += frame_texture_roughness
  176. total_crack_width += frame_crack_width
  177. total_avg_hue += frame_avg_hue
  178. # 绘制并编码图像
  179. annotated_img = result.plot() # ndarray
  180. annotated_pil = Image.fromarray(annotated_img) # 转为 PIL.Image
  181. buf = BytesIO()
  182. annotated_pil.save(buf, format='JPEG') # 正确保存
  183. b64 = base64.b64encode(buf.getvalue()).decode('ascii') # 编码为 base64 字符串
  184. frame_msg = {
  185. 'type': 'FRAME',
  186. 'frame_index': idx,
  187. 'frame_image': b64,
  188. 'frame_detection_duration': frame_detection_duration,
  189. }
  190. yield json.dumps(frame_msg, cls=DateTimeEncoder) + '\n'
  191. current_app.logger.info(
  192. f"【检测分割】total_disease_count: {total_disease_count}, disease_frame_count: {disease_frame_count}")
  193. # 计算平均病害指标
  194. average_disease_count = total_disease_count // disease_frame_count if disease_frame_count != 0 else 0
  195. average_disease_perimeter = total_disease_perimeter / disease_frame_count if disease_frame_count != 0 else 0.0
  196. average_disease_area = total_disease_area / disease_frame_count if disease_frame_count != 0 else 0.0
  197. average_shape_complexity = total_shape_complexity / disease_frame_count if disease_frame_count != 0 else 0.0
  198. average_texture_roughness = total_texture_roughness / disease_frame_count if disease_frame_count != 0 else 0.0
  199. average_crack_width = total_crack_width / disease_frame_count if disease_frame_count != 0 else 0.0
  200. average_avg_hue = total_avg_hue / disease_frame_count if disease_frame_count != 0 else 0.0
  201. # 计算帧平均检测分割耗时
  202. avg_frame_detection_duration = total_detection_duration / frame_count if frame_count != 0 else 0.0
  203. # 根据检测结果计算病害严重性得分、病害等级、病害描述
  204. disease_severity_score, disease_grade, disease_description = evaluate_disease_severity(
  205. average_disease_count,
  206. average_disease_perimeter,
  207. average_disease_area,
  208. average_shape_complexity,
  209. average_texture_roughness,
  210. average_crack_width,
  211. average_avg_hue, media,
  212. )
  213. if isinstance(disease_grade, str):
  214. disease_grade = DiseaseGrade(disease_grade.lower())
  215. # 检测分割结果路径
  216. result_path = unify_result_media_format(media, current_user)
  217. # 更新检测分割信息
  218. new_detection.status = TaskStatus.COMPLETED
  219. new_detection.result_path = result_path
  220. new_detection.disease_count = average_disease_count
  221. new_detection.disease_perimeter = average_disease_perimeter
  222. new_detection.disease_area = average_disease_area
  223. new_detection.shape_complexity = average_shape_complexity
  224. new_detection.texture_roughness = average_texture_roughness
  225. new_detection.crack_width = average_crack_width
  226. new_detection.avg_hue = average_avg_hue
  227. new_detection.disease_severity_score = disease_severity_score
  228. new_detection.disease_grade = disease_grade
  229. new_detection.disease_description = disease_description
  230. new_detection.detection_duration = total_detection_duration
  231. new_detection.avg_frame_detection_duration = avg_frame_detection_duration
  232. db.session.commit()
  233. # 记录操作
  234. success_op = handle_operation_success(new_operation, start_time, current_user_id)
  235. end_msg = {
  236. 'type': 'END',
  237. 'existing_detection': bool(existing_detection),
  238. 'new_detection': new_detection.to_dict(),
  239. 'operation': success_op.to_dict(),
  240. }
  241. current_app.logger.info(f"【检测分割成功】new_detection: {success_op}")
  242. yield json.dumps(end_msg, cls=DateTimeEncoder) + '\n'
  243. return Response(stream_with_context(generate()), mimetype='application/json')
  244. except Exception as error:
  245. # 发生错误,更新任务状态为失败
  246. new_detection.status = TaskStatus.FAILED
  247. db.session.commit()
  248. # 记录操作失败
  249. failure_message = f"【检测分割错误】服务器内部发生错误,请联系管理员"
  250. new_operation = handle_operation_failure(new_operation, start_time, failure_message, current_user_id)
  251. # 获取详细的堆栈追踪信息
  252. stack_trace = traceback.format_exc()
  253. # 获取请求的相关信息
  254. request_method = request.method
  255. request_url = request.url
  256. request_data = request.get_data(as_text=True)
  257. # 记录日志,帮助排查问题
  258. current_app.logger.error(f"Error: {str(error)}\nStack Trace: {stack_trace}\n"
  259. f"Request Method: {request_method}\n"
  260. f"Request URL: {request_url}\n"
  261. f"Request Data: {request_data}")
  262. return jsonify({
  263. 'operation': new_operation.to_dict(),
  264. }), 500
  265. @detection_routes.route('/detail/<int:detection_id>', methods=['GET'])
  266. @jwt_required()
  267. @login_required
  268. def detail(detection_id):
  269. # 获取当前用户身份(使用 access token)
  270. current_user_id = get_jwt_identity()
  271. current_user = User.query.get(current_user_id)
  272. # 获取指定检测分割记录
  273. detection = Detection.query.get(detection_id)
  274. # 校验字段
  275. validation_checks = [
  276. (not detection, f"【获取检测分割 ID={detection_id} 详情失败】该检测分割记录不存在", 404),
  277. (detection and detection.owner_id != current_user_id and current_user.role != UserRole.ADMIN
  278. and current_user.role != UserRole.DEVELOPER,
  279. f"【获取检测分割 ID={detection_id} 详情失败】您非管理员/开发人员,无法查看他人的检测分割记录详情", 403),
  280. ]
  281. for condition, message, code in validation_checks:
  282. if condition:
  283. current_app.logger.warning(message + f', operator: {current_user}')
  284. return jsonify({
  285. 'failure_message': message,
  286. }), code
  287. return jsonify({
  288. 'detection': detection.to_dict(),
  289. }), 200
  290. @detection_routes.route('/delete/<int:detection_id>', methods=['GET'])
  291. @jwt_required()
  292. @login_required
  293. def delete_detection(detection_id):
  294. start_time = time.time() # 记录操作开始时间
  295. # 创建一个新的操作记录
  296. new_operation = Operation(
  297. operation_type=OperationType.DELETE,
  298. description=f"删除检测分割 ID={detection_id} 记录",
  299. ip_address=request.remote_addr,
  300. device_info=request.user_agent.string,
  301. )
  302. # 获取当前用户身份(使用 access token)
  303. current_user_id = get_jwt_identity()
  304. current_user = User.query.get(current_user_id)
  305. # 获取指定媒体
  306. deleted_detection = Detection.query.get(detection_id)
  307. # 校验字段
  308. validation_checks = [
  309. (not deleted_detection, f"【删除检测分割 ID={detection_id} 记录失败】该检测分割记录不存在", 404),
  310. (deleted_detection and deleted_detection.owner_id != current_user_id and current_user.role != UserRole.ADMIN
  311. and current_user.role != UserRole.DEVELOPER,
  312. f"【删除检测分割 ID={detection_id} 记录失败】您非管理员/开发人员,无法删除他人的检测分割记录", 403),
  313. ]
  314. for condition, message, code in validation_checks:
  315. if condition:
  316. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  317. current_app.logger.warning(message + f', operator: {current_user}')
  318. return jsonify({
  319. 'operation': new_operation.to_dict(),
  320. }), code
  321. # 删除实际文件
  322. file_abs_path = _resolve_static_path(deleted_detection.result_path, RESULTS_FOLDER)
  323. delete_file(file_abs_path)
  324. # 删除数据库记录
  325. db.session.delete(deleted_detection)
  326. db.session.commit()
  327. # 记录操作
  328. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  329. current_app.logger.info(
  330. f"【删除检测分割 ID={detection_id} 记录成功】deleted_detection: {deleted_detection}, operator: {current_user}")
  331. return jsonify({
  332. 'operation': new_operation.to_dict(),
  333. 'deleted_detection': deleted_detection.to_dict(),
  334. }), 200
  335. @detection_routes.route('/detections/<int:user_id>', methods=['GET'])
  336. @jwt_required()
  337. @login_required
  338. def user_detections(user_id):
  339. # 获取分页参数(从请求中获取,默认为第 1 页,每页 5 条记录)
  340. default_page = request.args.get('page', 1, type=int)
  341. default_per_page = request.args.get('per_page', 5, type=int)
  342. page, per_page = get_pagination_params(default_page, default_per_page)
  343. # 获取当前用户身份(使用 access token)
  344. current_user_id = get_jwt_identity()
  345. current_user = User.query.get(current_user_id)
  346. # 获取指定用户身份
  347. user = User.query.get(user_id)
  348. # 校验字段
  349. validation_checks = [
  350. (not user, f"【获取用户 ID={user_id} 检测分割记录失败】该用户不存在", 404),
  351. (current_user_id != user_id and current_user.role != UserRole.ADMIN and current_user.role != UserRole.DEVELOPER,
  352. f"【获取用户 ID={user_id} 检测分割记录失败】您非管理员/开发人员,无法查看他人的检测分割记录", 403),
  353. ]
  354. for condition, message, code in validation_checks:
  355. if condition:
  356. current_app.logger.warning(message)
  357. return jsonify({
  358. 'failure_message': message,
  359. }), code
  360. # 获取指定用户检测分割记录
  361. query = (
  362. Detection.query
  363. .filter(Detection.owner_id == user_id)
  364. .join(Model, Detection.model_id == Model.model_id)
  365. .join(Media, Detection.media_id == Media.media_id)
  366. .join(User, Detection.owner_id == User.user_id)
  367. .add_columns(
  368. Model.model_name.label('model_name'),
  369. Media.media_name.label('media_name'),
  370. Media.file_type.label('media_type'),
  371. User.username.label('owner_username'),
  372. )
  373. )
  374. page, detections_total, pages = adjust_page_if_needed(query, page, per_page)
  375. paginated = query.paginate(page=page, per_page=per_page, error_out=False)
  376. detections = []
  377. for detection, model_name, media_name, media_type, owner_username in paginated.items:
  378. detection_dict = detection.to_dict()
  379. detection_dict.update({
  380. 'model_name': model_name,
  381. 'media_name': media_name,
  382. 'media_type': media_type,
  383. 'owner_username': owner_username,
  384. })
  385. detections.append(detection_dict)
  386. current_app.logger.info(
  387. f"【获取用户 ID={user_id} 检测分割记录成功】total: {detections_total}, per_page: {per_page}, page: {page}, pages: {pages}, detections: {detections}, operator: {current_user}")
  388. return jsonify({
  389. 'detections': detections,
  390. 'total': detections_total,
  391. 'per_page': per_page,
  392. 'page': page,
  393. 'pages': pages,
  394. }), 200
  395. @detection_routes.route('/detections/all', methods=['GET'])
  396. @jwt_required()
  397. @login_required
  398. def all_detections():
  399. # 获取分页参数(从请求中获取,默认为第 1 页,每页 5 条记录)
  400. default_page = request.args.get('page', 1, type=int)
  401. default_per_page = request.args.get('per_page', 5, type=int)
  402. page, per_page = get_pagination_params(default_page, default_per_page)
  403. # 获取当前用户身份(使用 access token)
  404. current_user_id = get_jwt_identity()
  405. current_user = User.query.get(current_user_id)
  406. if current_user.role != UserRole.ADMIN and current_user.role != UserRole.DEVELOPER:
  407. failure_message = f"【获取所有检测分割记录失败】您非管理员/开发人员,权限不足"
  408. current_app.logger.warning(failure_message)
  409. return jsonify({
  410. 'failure_message': failure_message,
  411. }), 403
  412. # 获取所有媒体
  413. query = (
  414. Detection.query
  415. .join(Model, Detection.model_id == Model.model_id)
  416. .join(Media, Detection.media_id == Media.media_id)
  417. .join(User, Detection.owner_id == User.user_id)
  418. .add_columns(
  419. Model.model_name.label('model_name'),
  420. Media.media_name.label('media_name'),
  421. Media.file_type.label('media_type'),
  422. User.username.label('owner_username'),
  423. )
  424. .order_by(Detection.detection_id.asc())
  425. )
  426. page, detections_total, pages = adjust_page_if_needed(query, page, per_page)
  427. paginated = query.paginate(page=page, per_page=per_page, error_out=False)
  428. detections = []
  429. for detection, model_name, media_name, media_type, owner_username in paginated.items:
  430. detection_dict = detection.to_dict()
  431. detection_dict.update({
  432. 'model_name': model_name,
  433. 'media_name': media_name,
  434. 'media_type': media_type,
  435. 'owner_username': owner_username,
  436. })
  437. detections.append(detection_dict)
  438. current_app.logger.info(
  439. f"【获取所有检测分割记录成功】total: {detections_total}, per_page: {per_page}, page: {page}, pages: {pages}, detections: {detections}, operator: {current_user}")
  440. return jsonify({
  441. 'detections': detections,
  442. 'total': detections_total,
  443. 'per_page': per_page,
  444. 'page': page,
  445. 'pages': pages,
  446. }), 200
  447. @detection_routes.route('/statistics', methods=['GET'])
  448. @jwt_required()
  449. @login_required
  450. def statistics():
  451. # 查询检测记录总数
  452. total_detections = Detection.query.count()
  453. # 查询不同状态的检测记录数量
  454. pending_detections = Detection.query.filter(Detection.status == TaskStatus.PENDING).count()
  455. in_progress_detections = Detection.query.filter(Detection.status == TaskStatus.IN_PROGRESS).count()
  456. completed_detections = Detection.query.filter(Detection.status == TaskStatus.COMPLETED).count()
  457. failed_detections = Detection.query.filter(Detection.status == TaskStatus.FAILED).count()
  458. # 构建统计数据
  459. detections_statistics = {
  460. 'total': total_detections,
  461. 'pending': pending_detections,
  462. 'in_progress': in_progress_detections,
  463. 'completed': completed_detections,
  464. 'failed': failed_detections,
  465. }
  466. return jsonify({
  467. "detections_statistics": detections_statistics,
  468. }), 200