training_route.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import os
  2. import time
  3. from flask import request, jsonify, current_app
  4. from flask_jwt_extended import jwt_required, get_jwt_identity
  5. from app.constants import OperationType, UserRole
  6. from app.decorators import login_required
  7. from app.models import Operation, User
  8. from app.routes import training_routes
  9. from app.services import training_manager as tm
  10. from app.utils import handle_operation_failure, handle_operation_success
  11. def _jwt_user():
  12. uid = get_jwt_identity()
  13. try:
  14. pk = int(uid) if uid is not None else None
  15. except (TypeError, ValueError):
  16. pk = None
  17. return User.query.get(pk) if pk is not None else None
  18. def _role_is_trainer(role) -> bool:
  19. if role is None:
  20. return False
  21. if isinstance(role, UserRole):
  22. return role in (UserRole.DEVELOPER, UserRole.ADMIN)
  23. name = str(getattr(role, 'name', role)).upper()
  24. value = str(getattr(role, 'value', role)).lower()
  25. return name in ('ADMIN', 'DEVELOPER') or value in ('admin', 'developer')
  26. def _require_trainer(user):
  27. if not _role_is_trainer(getattr(user, 'role', None)):
  28. return jsonify({'message': '仅管理员或开发人员可访问模型训练功能'}), 403
  29. return None
  30. @training_routes.route('/base-models', methods=['GET'])
  31. @jwt_required()
  32. @login_required
  33. def base_models():
  34. user = _jwt_user()
  35. denied = _require_trainer(user)
  36. if denied:
  37. return denied
  38. return jsonify({'base_models': tm.list_base_models()}), 200
  39. @training_routes.route('/datasets', methods=['GET'])
  40. @jwt_required()
  41. @login_required
  42. def list_datasets():
  43. user = _jwt_user()
  44. denied = _require_trainer(user)
  45. if denied:
  46. return denied
  47. return jsonify({'datasets': tm.list_datasets(current_app)}), 200
  48. @training_routes.route('/datasets/upload', methods=['POST'])
  49. @jwt_required()
  50. @login_required
  51. def upload_dataset():
  52. start_time = time.time()
  53. current_user = _jwt_user()
  54. current_user_id = current_user.user_id if current_user else None
  55. denied = _require_trainer(current_user)
  56. if denied:
  57. return denied
  58. new_operation = Operation(
  59. operation_type=OperationType.CREATE,
  60. description='上传训练数据集',
  61. ip_address=request.remote_addr,
  62. device_info=request.user_agent.string,
  63. )
  64. zip_file = request.files.get('dataset_zip')
  65. display_name = (request.form.get('name') or '').strip()
  66. if not zip_file:
  67. msg = '【上传数据集失败】未选择 ZIP 文件'
  68. new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
  69. return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
  70. max_size = current_app.config.get('MAX_DATASET_ZIP_SIZE', 500 * 1024 * 1024)
  71. zip_file.seek(0, os.SEEK_END)
  72. size = zip_file.tell()
  73. zip_file.seek(0)
  74. if size > max_size:
  75. msg = '【上传数据集失败】文件超过 500MB 限制'
  76. new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
  77. return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
  78. try:
  79. item = tm.upload_dataset(
  80. current_app,
  81. zip_file,
  82. display_name,
  83. current_user_id,
  84. current_user.username,
  85. )
  86. except ValueError as exc:
  87. msg = f'【上传数据集失败】{exc}'
  88. new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
  89. return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
  90. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  91. return jsonify({
  92. 'operation': new_operation.to_dict(),
  93. 'dataset': item,
  94. }), 201
  95. @training_routes.route('/jobs', methods=['GET'])
  96. @jwt_required()
  97. @login_required
  98. def list_jobs():
  99. user = _jwt_user()
  100. denied = _require_trainer(user)
  101. if denied:
  102. return denied
  103. return jsonify({'jobs': tm.list_jobs(current_app)}), 200
  104. @training_routes.route('/jobs/<job_id>/resume', methods=['POST'])
  105. @jwt_required()
  106. @login_required
  107. def resume_job_route(job_id):
  108. user = _jwt_user()
  109. denied = _require_trainer(user)
  110. if denied:
  111. return denied
  112. try:
  113. job = tm.resume_job(current_app, job_id)
  114. except ValueError as exc:
  115. return jsonify({'message': str(exc)}), 400
  116. return jsonify({'job': job}), 200
  117. @training_routes.route('/jobs/<job_id>', methods=['GET'])
  118. @jwt_required()
  119. @login_required
  120. def job_detail(job_id):
  121. user = _jwt_user()
  122. denied = _require_trainer(user)
  123. if denied:
  124. return denied
  125. job = tm.get_job(current_app, job_id)
  126. if not job:
  127. return jsonify({'message': '训练任务不存在'}), 404
  128. return jsonify({'job': job}), 200
  129. @training_routes.route('/jobs', methods=['POST'])
  130. @jwt_required()
  131. @login_required
  132. def start_job():
  133. start_time = time.time()
  134. current_user = _jwt_user()
  135. current_user_id = current_user.user_id if current_user else None
  136. denied = _require_trainer(current_user)
  137. if denied:
  138. return denied
  139. data = request.get_json() or {}
  140. new_operation = Operation(
  141. operation_type=OperationType.EXECUTE,
  142. description='启动 YOLO 模型训练',
  143. ip_address=request.remote_addr,
  144. device_info=request.user_agent.string,
  145. )
  146. try:
  147. job = tm.start_job(
  148. current_app,
  149. owner_id=current_user_id,
  150. owner_username=current_user.username,
  151. dataset_id=data.get('dataset_id'),
  152. base_model=data.get('base_model', 'yolov8n-seg.pt'),
  153. disease_category=data.get('disease_category', ''),
  154. output_model_name=data.get('output_model_name', ''),
  155. epochs=data.get('epochs', 50),
  156. imgsz=data.get('imgsz', 640),
  157. batch=data.get('batch', 8),
  158. augmentation=data.get('augmentation', '随机点+颜色扭曲+高斯模糊'),
  159. register_to_library=data.get('register_to_library', True),
  160. )
  161. except ValueError as exc:
  162. msg = f'【启动训练失败】{exc}'
  163. new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
  164. return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
  165. new_operation = handle_operation_success(new_operation, start_time, current_user_id)
  166. return jsonify({
  167. 'operation': new_operation.to_dict(),
  168. 'job': job,
  169. }), 201