model_route.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. import os
  2. import time
  3. from flask import request, jsonify, current_app
  4. from flask_jwt_extended import jwt_required, get_jwt_identity
  5. from werkzeug.utils import secure_filename
  6. from app.constants import OperationType, UserRole
  7. from app.decorators import login_required
  8. from app.models import Operation, Model, db, User, Detection
  9. from app.routes import model_routes
  10. from app.utils import handle_operation_failure, allowed_model_file, handle_file_upload, handle_operation_success, \
  11. adjust_page_if_needed, get_pagination_params, delete_file, user_rate_limit
  12. @model_routes.route('/upload', methods=['POST'])
  13. @jwt_required()
  14. @login_required
  15. def upload():
  16. start_time = time.time() # 记录操作开始时间
  17. # 获取请求中的表单数据
  18. model_file = request.files.get('model_file')
  19. disease_category = request.form.get('disease_category')
  20. augmentation = request.form.get('augmentation', '原图')
  21. layers = int(request.form.get('layers'))
  22. parameters = int(request.form.get('parameters'))
  23. GFLOPs = float(request.form.get('GFLOPs'))
  24. box_p = float(request.form.get('box_p', 0.0))
  25. box_r = float(request.form.get('box_r', 0.0))
  26. box_mAP50 = float(request.form.get('box_mAP50', 0.0))
  27. box_mAP50_95 = float(request.form.get('box_mAP50_95', 0.0))
  28. mask_p = float(request.form.get('mask_p', 0.0))
  29. mask_r = float(request.form.get('mask_r', 0.0))
  30. mask_mAP50 = float(request.form.get('mask_mAP50', 0.0))
  31. mask_mAP50_95 = float(request.form.get('mask_mAP50_95', 0.0))
  32. f1_score = float(request.form.get('f1_score', 0.0))
  33. fitness_score = float(request.form.get('fitness_score', 0.0))
  34. # 创建一个新的操作记录
  35. new_operation = Operation(
  36. operation_type=OperationType.CREATE,
  37. description="上传模型",
  38. ip_address=request.remote_addr,
  39. device_info=request.user_agent.string,
  40. )
  41. # 获取当前用户身份(使用 access token)
  42. current_user_id = get_jwt_identity()
  43. current_user = User.query.get(current_user_id)
  44. # 先获取文件名
  45. file_name = secure_filename(model_file.filename) if model_file else None
  46. # 检查文件名是否已存在
  47. existing_model = Model.query.filter_by(model_name=file_name).first() if file_name else None
  48. # 校验字段
  49. validation_checks = [
  50. (model_file and not allowed_model_file(model_file), "【上传模型失败】模型文件不合规", 400),
  51. (current_user.role != UserRole.DEVELOPER, f"【上传模型失败】您非开发人员,无权上传模型", 403),
  52. (model_file and existing_model, f"【上传模型失败】模型 {file_name} 已存在,请重新上传", 400),
  53. (not disease_category or not layers or not parameters or not GFLOPs,
  54. "【上传模型失败】模型病害类别、层数、参数量或计算量为空", 400),
  55. ]
  56. for condition, message, code in validation_checks:
  57. if condition:
  58. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  59. current_app.logger.warning(message + f', operator: {current_user}')
  60. return jsonify({
  61. 'operation': new_operation.to_dict(),
  62. }), code
  63. # 保存文件到指定目录(返回相对路径)
  64. file_path = handle_file_upload(model_file, 'models')
  65. new_model = Model(
  66. model_name=file_name,
  67. model_path=file_path,
  68. disease_category=disease_category,
  69. augmentation=augmentation,
  70. layers=layers,
  71. parameters=parameters,
  72. GFLOPs=GFLOPs,
  73. box_p=box_p,
  74. box_r=box_r,
  75. box_mAP50=box_mAP50,
  76. box_mAP50_95=box_mAP50_95,
  77. mask_p=mask_p,
  78. mask_r=mask_r,
  79. mask_mAP50=mask_mAP50,
  80. mask_mAP50_95=mask_mAP50_95,
  81. f1_score=f1_score,
  82. fitness_score=fitness_score,
  83. owner_id=current_user_id,
  84. )
  85. db.session.add(new_model)
  86. db.session.commit()
  87. # 记录操作
  88. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  89. current_app.logger.info(f"【上传模型成功】new_model: {new_model}, operator: {current_user}")
  90. return jsonify({
  91. 'operation': new_operation.to_dict(),
  92. 'new_model': new_model.to_dict(),
  93. }), 201
  94. @model_routes.route('/detail/<int:model_id>', methods=['GET'])
  95. @jwt_required()
  96. @login_required
  97. def detail(model_id):
  98. # 获取当前用户身份(使用 access token)
  99. current_user_id = get_jwt_identity()
  100. current_user = User.query.get(current_user_id)
  101. # 获取指定模型文件
  102. model = Model.query.get(model_id)
  103. # 校验字段
  104. validation_checks = [
  105. (not model, f"【获取模型 ID={model_id} 详情失败】该模型不存在", 404),
  106. ]
  107. for condition, message, code in validation_checks:
  108. if condition:
  109. current_app.logger.warning(message + f', operator: {current_user}')
  110. return jsonify({
  111. 'failure_message': message,
  112. }), code
  113. return jsonify({
  114. 'model': model.to_dict(),
  115. }), 200
  116. @model_routes.route('/update/<int:model_id>', methods=['PUT'])
  117. @jwt_required()
  118. @login_required
  119. def update(model_id):
  120. start_time = time.time() # 记录操作开始时间
  121. # 获取请求中的表单数据
  122. disease_category = request.form.get('disease_category')
  123. augmentation = request.form.get('augmentation')
  124. layers = int(request.form.get('layers') or 0)
  125. parameters = int(request.form.get('parameters') or 0)
  126. GFLOPs = float(request.form.get('GFLOPs') or 0.0)
  127. box_p = float(request.form.get('box_p') or 0.0)
  128. box_r = float(request.form.get('box_r') or 0.0)
  129. box_mAP50 = float(request.form.get('box_mAP50') or 0.0)
  130. box_mAP50_95 = float(request.form.get('box_mAP50_95') or 0.0)
  131. mask_p = float(request.form.get('mask_p') or 0.0)
  132. mask_r = float(request.form.get('mask_r') or 0.0)
  133. mask_mAP50 = float(request.form.get('mask_mAP50') or 0.0)
  134. mask_mAP50_95 = float(request.form.get('mask_mAP50_95') or 0.0)
  135. f1_score = float(request.form.get('f1_score') or 0.0)
  136. fitness_score = float(request.form.get('fitness_score') or 0.0)
  137. # 创建一个新的操作记录
  138. new_operation = Operation(
  139. operation_type=OperationType.UPDATE,
  140. description=f"更新模型 ID={model_id} 信息",
  141. ip_address=request.remote_addr,
  142. device_info=request.user_agent.string,
  143. )
  144. # 获取当前用户的身份(使用 access token)
  145. current_user_id = get_jwt_identity()
  146. current_user = User.query.get(current_user_id)
  147. # 获取指定模型文件
  148. updated_model = Model.query.get(model_id)
  149. # 校验字段
  150. validation_checks = [
  151. (not disease_category, f"【更新模型 ID={model_id} 信息失败】病害类别为空", 400),
  152. (not updated_model, f"【更新模型 ID={model_id} 信息失败】该模型不存在", 404),
  153. (updated_model and current_user.role != UserRole.DEVELOPER,
  154. f"【更新模型 ID={model_id} 信息失败】您非开发人员,权限不足", 403),
  155. ]
  156. for condition, message, code in validation_checks:
  157. if condition:
  158. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  159. current_app.logger.warning(message + f', operator: {current_user}')
  160. return jsonify({
  161. 'operation': new_operation.to_dict(),
  162. }), code
  163. # 更新模型文件信息
  164. updated_model.disease_category = disease_category
  165. updated_model.augmentation = augmentation if augmentation else updated_model.augmentation
  166. updated_model.layers = layers
  167. updated_model.parameters = parameters
  168. updated_model.GFLOPs = GFLOPs
  169. updated_model.box_p = box_p if box_p else updated_model.box_p
  170. updated_model.box_r = box_r if box_r else updated_model.box_r
  171. updated_model.box_mAP50 = box_mAP50 if box_mAP50 else updated_model.box_mAP50
  172. updated_model.box_mAP50_95 = box_mAP50_95 if box_mAP50_95 else updated_model.box_mAP50_95
  173. updated_model.mask_p = mask_p if mask_p else updated_model.mask_p
  174. updated_model.mask_r = mask_r if mask_r else updated_model.mask_r
  175. updated_model.mask_mAP50 = mask_mAP50 if mask_mAP50 else updated_model.mask_mAP50
  176. updated_model.mask_mAP50_95 = mask_mAP50_95 if mask_mAP50_95 else updated_model.mask_mAP50_95
  177. updated_model.f1_score = f1_score if f1_score else updated_model.f1_score
  178. updated_model.fitness_score = fitness_score if fitness_score else updated_model.fitness_score
  179. db.session.commit()
  180. # 记录操作
  181. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  182. current_app.logger.info(
  183. f"【更新模型 ID={model_id} 信息成功】updated_model: {updated_model}, operator: {current_user}")
  184. return jsonify({
  185. 'operation': new_operation.to_dict(),
  186. 'updated_model': updated_model.to_dict(),
  187. }), 200
  188. @model_routes.route('/delete/<int:model_id>', methods=['DELETE'])
  189. @jwt_required()
  190. @login_required
  191. def delete_model(model_id):
  192. start_time = time.time() # 记录操作开始时间
  193. # 创建一个新的操作记录
  194. new_operation = Operation(
  195. operation_type=OperationType.DELETE,
  196. description=f"删除模型 ID={model_id}",
  197. ip_address=request.remote_addr,
  198. device_info=request.user_agent.string,
  199. )
  200. # 获取当前用户身份(使用 access token)
  201. current_user_id = get_jwt_identity()
  202. current_user = User.query.get(current_user_id)
  203. # 获取指定模型
  204. deleted_model = Model.query.get(model_id)
  205. # 校验字段
  206. validation_checks = [
  207. (not deleted_model, f"【删除模型 ID={model_id} 失败】该模型不存在", 404),
  208. (deleted_model and current_user.role != UserRole.DEVELOPER,
  209. f"【删除模型 ID={model_id} 失败】您非开发人员,权限不足", 403),
  210. (deleted_model and Detection.query.filter_by(model_id=model_id).first(),
  211. f"【删除模型 ID={model_id} 失败】该模型存在关联的检测分割记录,无法删除", 400),
  212. ]
  213. for condition, message, code in validation_checks:
  214. if condition:
  215. new_operation = handle_operation_failure(new_operation, start_time, message, current_user_id)
  216. current_app.logger.warning(message + f', operator: {current_user}')
  217. return jsonify({
  218. 'operation': new_operation.to_dict(),
  219. }), code
  220. # 删除实际文件
  221. file_abs_path = os.path.join(current_app.root_path, deleted_model.model_path)
  222. delete_file(file_abs_path)
  223. # 删除数据库记录
  224. db.session.delete(deleted_model)
  225. db.session.commit()
  226. # 记录操作
  227. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  228. current_app.logger.info(
  229. f"【删除模型 ID={model_id} 成功】deleted_model: {deleted_model}, operator: {current_user}")
  230. return jsonify({
  231. 'operation': new_operation.to_dict(),
  232. 'deleted_model': deleted_model.to_dict(),
  233. }), 200
  234. @model_routes.route('/models/all', methods=['GET'])
  235. @jwt_required()
  236. @login_required
  237. def all_models():
  238. # 获取分页参数(从请求中获取,默认为第 1 页,每页 5 条记录)
  239. default_page = request.args.get('page', 1, type=int)
  240. default_per_page = request.args.get('per_page', 5, type=int)
  241. page, per_page = get_pagination_params(default_page, default_per_page)
  242. # 获取当前用户身份(使用 access token)
  243. current_user_id = get_jwt_identity()
  244. current_user = User.query.get(current_user_id)
  245. # 获取所有模型文件
  246. query = (
  247. Model.query
  248. .join(User, Model.owner_id == User.user_id)
  249. .add_columns(User.username.label('owner_username'))
  250. .order_by(Model.model_id.asc())
  251. )
  252. page, models_total, pages = adjust_page_if_needed(query, page, per_page)
  253. paginated = query.paginate(page=page, per_page=per_page, error_out=False)
  254. models = []
  255. for model, owner_username in paginated.items:
  256. model_dict = model.to_dict()
  257. model_dict.update({'owner_username': owner_username})
  258. models.append(model_dict)
  259. current_app.logger.info(
  260. f"【获取所有模型成功】total: {models_total}, per_page: {per_page}, page: {page}, pages: {pages}, models: {models}, operator: {current_user}")
  261. return jsonify({
  262. 'models': models,
  263. 'total': models_total,
  264. 'per_page': per_page,
  265. 'page': page,
  266. 'pages': pages,
  267. }), 200
  268. @model_routes.route('/statistics', methods=['GET'])
  269. @jwt_required()
  270. @login_required
  271. def statistics():
  272. # 查询模型总数
  273. total_models = Model.query.count()
  274. # 构建统计数据
  275. models_statistics = {
  276. 'total': total_models,
  277. }
  278. return jsonify({
  279. "models_statistics": models_statistics,
  280. }), 200