| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- import os
- import time
- from flask import request, jsonify, current_app
- from flask_jwt_extended import jwt_required, get_jwt_identity
- from app.constants import OperationType, UserRole
- from app.decorators import login_required
- from app.models import Operation, User
- from app.routes import training_routes
- from app.services import training_manager as tm
- from app.utils import handle_operation_failure, handle_operation_success
def _jwt_user():
    """Look up the ``User`` whose primary key is the current JWT identity.

    The identity returned by ``get_jwt_identity()`` is coerced with ``int()``.

    Returns:
        The ``User`` from ``User.query.get``, or ``None`` when the identity is
        missing or cannot be converted to an integer.
    """
    identity = get_jwt_identity()
    if identity is None:
        return None
    try:
        user_pk = int(identity)
    except (TypeError, ValueError):
        # Non-numeric identity: treat as "no user" rather than erroring.
        return None
    return User.query.get(user_pk)
- def _role_is_trainer(role) -> bool:
- if role is None:
- return False
- if isinstance(role, UserRole):
- return role in (UserRole.DEVELOPER, UserRole.ADMIN)
- name = str(getattr(role, 'name', role)).upper()
- value = str(getattr(role, 'value', role)).lower()
- return name in ('ADMIN', 'DEVELOPER') or value in ('admin', 'developer')
def _require_trainer(user):
    """Gate a training endpoint on the caller's role.

    Args:
        user: The resolved user (may be ``None``); only its ``role``
            attribute is inspected.

    Returns:
        ``None`` when the user holds a trainer role, otherwise a
        ``(jsonify(...), 403)`` response tuple for the view to return as-is.
    """
    role = getattr(user, 'role', None)
    if _role_is_trainer(role):
        return None
    return jsonify({'message': '仅管理员或开发人员可访问模型训练功能'}), 403
@training_routes.route('/base-models', methods=['GET'])
@jwt_required()
@login_required
def base_models():
    """GET /base-models: return ``tm.list_base_models()`` under ``base_models``.

    Users without a trainer role receive the 403 response from
    ``_require_trainer``.
    """
    forbidden = _require_trainer(_jwt_user())
    if forbidden is not None:
        return forbidden
    payload = {'base_models': tm.list_base_models()}
    return jsonify(payload), 200
@training_routes.route('/datasets', methods=['GET'])
@jwt_required()
@login_required
def list_datasets():
    """GET /datasets: return ``tm.list_datasets(current_app)`` under ``datasets``.

    Users without a trainer role receive the 403 response from
    ``_require_trainer``.
    """
    forbidden = _require_trainer(_jwt_user())
    if forbidden is not None:
        return forbidden
    payload = {'datasets': tm.list_datasets(current_app)}
    return jsonify(payload), 200
@training_routes.route('/datasets/upload', methods=['POST'])
@jwt_required()
@login_required
def upload_dataset():
    """POST /datasets/upload: store a ZIP training dataset and log an Operation.

    Multipart form fields:
        dataset_zip: The ZIP archive (required).
        name: Optional display name; surrounding whitespace is stripped.

    Returns:
        ``(json, 201)`` with ``operation`` and ``dataset`` on success.
        ``(json, 400)`` with ``operation`` and ``message`` in three cases: no
        file was sent, the file is larger than the ``MAX_DATASET_ZIP_SIZE``
        config value (default ``500 * 1024 * 1024`` bytes), or
        ``tm.upload_dataset`` raises ``ValueError``.
        The 403 response from ``_require_trainer`` for non-trainer users.
    """
    start_time = time.time()
    current_user = _jwt_user()
    current_user_id = current_user.user_id if current_user else None
    denied = _require_trainer(current_user)
    if denied:
        return denied
    # Past this guard current_user is not None: _require_trainer rejects a
    # None user because its role resolves to None.

    new_operation = Operation(
        operation_type=OperationType.CREATE,
        description='上传训练数据集',
        ip_address=request.remote_addr,
        device_info=request.user_agent.string,
    )

    def _fail(msg):
        """Record *msg* as the operation's failure and build the 400 response."""
        failed = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': failed.to_dict(), 'message': msg}), 400

    zip_file = request.files.get('dataset_zip')
    display_name = (request.form.get('name') or '').strip()
    if not zip_file:
        return _fail('【上传数据集失败】未选择 ZIP 文件')

    max_size = current_app.config.get('MAX_DATASET_ZIP_SIZE', 500 * 1024 * 1024)
    # Measure the upload by seeking to the end of its stream, then rewind so
    # tm.upload_dataset reads it from the beginning.
    zip_file.seek(0, os.SEEK_END)
    size = zip_file.tell()
    zip_file.seek(0)
    if size > max_size:
        # Report the configured limit rather than a hard-coded "500MB", which
        # was wrong whenever MAX_DATASET_ZIP_SIZE differed from the default.
        # The text is unchanged for the default limit.
        limit_mb = max_size / (1024 * 1024)
        limit_label = f'{limit_mb:.0f}' if limit_mb == int(limit_mb) else f'{limit_mb:.1f}'
        return _fail(f'【上传数据集失败】文件超过 {limit_label}MB 限制')

    try:
        item = tm.upload_dataset(
            current_app,
            zip_file,
            display_name,
            current_user_id,
            current_user.username,
        )
    except ValueError as exc:
        return _fail(f'【上传数据集失败】{exc}')

    done_operation = handle_operation_success(new_operation, start_time, current_user_id)
    return jsonify({
        'operation': done_operation.to_dict(),
        'dataset': item,
    }), 201
@training_routes.route('/jobs', methods=['GET'])
@jwt_required()
@login_required
def list_jobs():
    """GET /jobs: return ``tm.list_jobs(current_app)`` under ``jobs``.

    Users without a trainer role receive the 403 response from
    ``_require_trainer``.
    """
    forbidden = _require_trainer(_jwt_user())
    if forbidden is not None:
        return forbidden
    payload = {'jobs': tm.list_jobs(current_app)}
    return jsonify(payload), 200
@training_routes.route('/jobs/<job_id>/resume', methods=['POST'])
@jwt_required()
@login_required
def resume_job_route(job_id):
    """POST /jobs/<job_id>/resume: resume a training job via ``tm.resume_job``.

    Returns:
        ``(json, 200)`` with the job under ``job``.
        ``(json, 400)`` whose ``message`` is the text of any ``ValueError``
        raised by ``tm.resume_job``.
        The 403 response from ``_require_trainer`` for non-trainer users.
    """
    forbidden = _require_trainer(_jwt_user())
    if forbidden is not None:
        return forbidden
    try:
        resumed = tm.resume_job(current_app, job_id)
    except ValueError as err:
        return jsonify({'message': str(err)}), 400
    else:
        return jsonify({'job': resumed}), 200
@training_routes.route('/jobs/<job_id>', methods=['GET'])
@jwt_required()
@login_required
def job_detail(job_id):
    """GET /jobs/<job_id>: return one training job from ``tm.get_job``.

    Returns:
        ``(json, 200)`` with the job under ``job`` when ``tm.get_job``
        returns a truthy value.
        ``(json, 404)`` otherwise.
        The 403 response from ``_require_trainer`` for non-trainer users.
    """
    forbidden = _require_trainer(_jwt_user())
    if forbidden is not None:
        return forbidden
    found = tm.get_job(current_app, job_id)
    if found:
        return jsonify({'job': found}), 200
    return jsonify({'message': '训练任务不存在'}), 404
@training_routes.route('/jobs', methods=['POST'])
@jwt_required()
@login_required
def start_job():
    """POST /jobs: start a YOLO training job and log an Operation.

    JSON body fields, each forwarded to ``tm.start_job`` without validation here:
        dataset_id: no default (``None`` when absent).
        base_model: default ``'yolov8n-seg.pt'``.
        disease_category: default ``''``.
        output_model_name: default ``''``.
        epochs: default ``50``.
        imgsz: default ``640``.
        batch: default ``8``.
        augmentation: has a fixed string default (see code).
        register_to_library: default ``True``.

    An empty or ``null`` body is treated as ``{}``.

    Returns:
        ``(json, 201)`` with ``operation`` and ``job`` on success.
        ``(json, 400)`` with ``operation`` and ``message`` when the body is not
        a JSON object or ``tm.start_job`` raises ``ValueError``.
        The 403 response from ``_require_trainer`` for non-trainer users.
    """
    start_time = time.time()
    current_user = _jwt_user()
    current_user_id = current_user.user_id if current_user else None
    denied = _require_trainer(current_user)
    if denied:
        return denied
    # Past this guard current_user is not None (_require_trainer rejects None).
    data = request.get_json() or {}
    new_operation = Operation(
        operation_type=OperationType.EXECUTE,
        description='启动 YOLO 模型训练',
        ip_address=request.remote_addr,
        device_info=request.user_agent.string,
    )
    # Reject a truthy non-object body (e.g. [1], "x", 5) explicitly. Without
    # this check, data.get() below raises AttributeError, which surfaces as
    # an HTTP 500 and leaves no failure recorded on the operation.
    if not isinstance(data, dict):
        msg = '【启动训练失败】请求体必须是 JSON 对象'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
    try:
        job = tm.start_job(
            current_app,
            owner_id=current_user_id,
            owner_username=current_user.username,
            dataset_id=data.get('dataset_id'),
            base_model=data.get('base_model', 'yolov8n-seg.pt'),
            disease_category=data.get('disease_category', ''),
            output_model_name=data.get('output_model_name', ''),
            epochs=data.get('epochs', 50),
            imgsz=data.get('imgsz', 640),
            batch=data.get('batch', 8),
            augmentation=data.get('augmentation', '随机点+颜色扭曲+高斯模糊'),
            register_to_library=data.get('register_to_library', True),
        )
    except ValueError as exc:
        msg = f'【启动训练失败】{exc}'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400
    new_operation = handle_operation_success(new_operation, start_time, current_user_id)
    return jsonify({
        'operation': new_operation.to_dict(),
        'job': job,
    }), 201
|