"""HTTP routes for model training: base models, datasets and training jobs.

Every view is protected by ``jwt_required`` and the project's
``login_required`` decorator, and additionally rejects callers whose role
is not admin or developer (see ``_require_trainer``).
"""

import os
import time

from flask import request, jsonify, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity

from app.constants import OperationType, UserRole
from app.decorators import login_required
from app.models import Operation, User
from app.routes import training_routes
from app.services import training_manager as tm
from app.utils import handle_operation_failure, handle_operation_success


def _jwt_user():
    """Return the ``User`` row for the current JWT identity, or ``None``.

    The identity is coerced to ``int``; a missing or non-numeric identity
    yields ``None`` instead of raising.
    """
    uid = get_jwt_identity()
    try:
        pk = int(uid) if uid is not None else None
    except (TypeError, ValueError):
        pk = None
    return User.query.get(pk) if pk is not None else None


def _role_is_trainer(role) -> bool:
    """Return ``True`` when *role* denotes an admin or a developer.

    Accepts a ``UserRole`` member, or any other object whose ``name`` /
    ``value`` attribute (or, failing that, its string form) spells
    admin/developer.
    """
    if role is None:
        return False
    if isinstance(role, UserRole):
        return role in (UserRole.DEVELOPER, UserRole.ADMIN)
    # Fallback for roles that are not UserRole members (e.g. plain strings):
    # compare case-insensitively on both the name and the value spelling.
    name = str(getattr(role, 'name', role)).upper()
    value = str(getattr(role, 'value', role)).lower()
    return name in ('ADMIN', 'DEVELOPER') or value in ('admin', 'developer')


def _require_trainer(user):
    """Return a 403 ``(response, status)`` tuple unless *user* may train.

    Returns ``None`` when access is allowed. A ``None`` user is denied,
    because its role resolves to ``None``.
    """
    if not _role_is_trainer(getattr(user, 'role', None)):
        return jsonify({'message': '仅管理员或开发人员可访问模型训练功能'}), 403
    return None


@training_routes.route('/base-models', methods=['GET'])
@jwt_required()
@login_required
def base_models():
    """List the base models reported by the training manager."""
    user = _jwt_user()
    denied = _require_trainer(user)
    if denied:
        return denied
    return jsonify({'base_models': tm.list_base_models()}), 200


@training_routes.route('/datasets', methods=['GET'])
@jwt_required()
@login_required
def list_datasets():
    """List the datasets reported by the training manager."""
    user = _jwt_user()
    denied = _require_trainer(user)
    if denied:
        return denied
    return jsonify({'datasets': tm.list_datasets(current_app)}), 200


@training_routes.route('/datasets/upload', methods=['POST'])
@jwt_required()
@login_required
def upload_dataset():
    """Accept a dataset ZIP upload and hand it to the training manager.

    Reads the ``dataset_zip`` file part and the optional ``name`` form
    field. Responds 400 when the file is missing, exceeds the configured
    ``MAX_DATASET_ZIP_SIZE`` (default 500 MiB), or the training manager
    raises ``ValueError``; responds 201 with the stored dataset otherwise.
    Each outcome is passed through ``handle_operation_failure`` /
    ``handle_operation_success`` with an ``Operation`` record.
    """
    start_time = time.time()
    current_user = _jwt_user()
    current_user_id = current_user.user_id if current_user else None
    denied = _require_trainer(current_user)
    if denied:
        return denied

    new_operation = Operation(
        operation_type=OperationType.CREATE,
        description='上传训练数据集',
        ip_address=request.remote_addr,
        device_info=request.user_agent.string,
    )

    zip_file = request.files.get('dataset_zip')
    display_name = (request.form.get('name') or '').strip()
    if not zip_file:
        msg = '【上传数据集失败】未选择 ZIP 文件'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400

    max_size = current_app.config.get('MAX_DATASET_ZIP_SIZE', 500 * 1024 * 1024)
    # Measure the upload by seeking to the end, then rewind so the
    # training manager reads the stream from the start.
    zip_file.seek(0, os.SEEK_END)
    size = zip_file.tell()
    zip_file.seek(0)
    if size > max_size:
        # Report the limit that is actually configured rather than a
        # hard-coded figure; with the default this still reads "500MB".
        limit_mb = max_size / (1024 * 1024)
        msg = f'【上传数据集失败】文件超过 {limit_mb:g}MB 限制'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400

    try:
        item = tm.upload_dataset(
            current_app,
            zip_file,
            display_name,
            current_user_id,
            current_user.username,
        )
    except ValueError as exc:
        msg = f'【上传数据集失败】{exc}'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400

    new_operation = handle_operation_success(new_operation, start_time, current_user_id)
    return jsonify({
        'operation': new_operation.to_dict(),
        'dataset': item,
    }), 201


@training_routes.route('/jobs', methods=['GET'])
@jwt_required()
@login_required
def list_jobs():
    """List the training jobs reported by the training manager."""
    user = _jwt_user()
    denied = _require_trainer(user)
    if denied:
        return denied
    return jsonify({'jobs': tm.list_jobs(current_app)}), 200


# The rule must declare ``job_id``: without the URL variable Flask calls the
# view with no arguments and it fails with TypeError.
# NOTE(review): the default (string) converter is assumed - confirm the id
# type expected by ``tm.resume_job``.
@training_routes.route('/jobs/<job_id>/resume', methods=['POST'])
@jwt_required()
@login_required
def resume_job_route(job_id):
    """Resume the training job identified by *job_id*.

    Responds 400 with the error text when the training manager raises
    ``ValueError``; otherwise 200 with the job it returns.
    """
    user = _jwt_user()
    denied = _require_trainer(user)
    if denied:
        return denied
    try:
        job = tm.resume_job(current_app, job_id)
    except ValueError as exc:
        return jsonify({'message': str(exc)}), 400
    return jsonify({'job': job}), 200


# See resume_job_route: the ``job_id`` URL variable is required by the view
# signature. NOTE(review): string converter assumed - confirm.
@training_routes.route('/jobs/<job_id>', methods=['GET'])
@jwt_required()
@login_required
def job_detail(job_id):
    """Return the training job identified by *job_id*.

    Responds 404 when the training manager returns a falsy value for
    the id.
    """
    user = _jwt_user()
    denied = _require_trainer(user)
    if denied:
        return denied
    job = tm.get_job(current_app, job_id)
    if not job:
        return jsonify({'message': '训练任务不存在'}), 404
    return jsonify({'job': job}), 200


@training_routes.route('/jobs', methods=['POST'])
@jwt_required()
@login_required
def start_job():
    """Start a training job from the JSON request body.

    Recognised keys (with the defaults applied here when absent):
    ``dataset_id``, ``base_model``, ``disease_category``,
    ``output_model_name``, ``epochs``, ``imgsz``, ``batch``,
    ``augmentation`` and ``register_to_library``.

    Responds 400 when the body is not a JSON object or the training
    manager raises ``ValueError``; responds 201 with the created job
    otherwise. Both outcomes are recorded through an ``Operation``.
    """
    start_time = time.time()
    current_user = _jwt_user()
    current_user_id = current_user.user_id if current_user else None
    denied = _require_trainer(current_user)
    if denied:
        return denied

    data = request.get_json() or {}
    new_operation = Operation(
        operation_type=OperationType.EXECUTE,
        description='启动 YOLO 模型训练',
        ip_address=request.remote_addr,
        device_info=request.user_agent.string,
    )

    # A JSON array or scalar has no ``.get``; reject it explicitly instead
    # of failing with AttributeError (HTTP 500) further down.
    if not isinstance(data, dict):
        msg = '【启动训练失败】请求体必须是 JSON 对象'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400

    try:
        job = tm.start_job(
            current_app,
            owner_id=current_user_id,
            owner_username=current_user.username,
            dataset_id=data.get('dataset_id'),
            base_model=data.get('base_model', 'yolov8n-seg.pt'),
            disease_category=data.get('disease_category', ''),
            output_model_name=data.get('output_model_name', ''),
            epochs=data.get('epochs', 50),
            imgsz=data.get('imgsz', 640),
            batch=data.get('batch', 8),
            augmentation=data.get('augmentation', '随机点+颜色扭曲+高斯模糊'),
            register_to_library=data.get('register_to_library', True),
        )
    except ValueError as exc:
        msg = f'【启动训练失败】{exc}'
        new_operation = handle_operation_failure(new_operation, start_time, msg, current_user_id)
        return jsonify({'operation': new_operation.to_dict(), 'message': msg}), 400

    new_operation = handle_operation_success(new_operation, start_time, current_user_id)
    return jsonify({
        'operation': new_operation.to_dict(),
        'job': job,
    }), 201