瀏覽代碼

初始化项目

eric.w 1 周之前
父節點
當前提交
f3147f96e0

+ 6 - 0
data/datasets/dataset_1/dataset.yaml

@@ -0,0 +1,6 @@
+path: C:\pro\islandfox\data\datasets\dataset_1
+train: train/images
+val: val/images
+nc: 1
+
+names: ["text"]

二進制
data/datasets/dataset_1/train/images/2.jpg


二進制
data/datasets/dataset_1/train/labels.cache


+ 0 - 0
data/datasets/dataset_1/train/labels/2.txt


二進制
data/datasets/dataset_1/val/images/2.jpg


二進制
data/datasets/dataset_1/val/labels.cache


+ 0 - 0
data/datasets/dataset_1/val/labels/2.txt


二進制
data/datasets/dataset_2.zip


二進制
rtdetr-l.pt


+ 110 - 0
runs/detect/models/model_1/args.yaml

@@ -0,0 +1,110 @@
+task: detect
+mode: train
+model: yolov8n.pt
+data: C:\pro\islandfox\data\datasets\dataset_1\dataset.yaml
+epochs: 50
+time: null
+patience: 100
+batch: 16
+imgsz: 640
+save: true
+save_period: -1
+cache: false
+device: cpu
+workers: 8
+project: models
+name: model_1
+exist_ok: true
+pretrained: true
+optimizer: auto
+verbose: true
+seed: 0
+deterministic: true
+single_cls: false
+rect: false
+cos_lr: false
+close_mosaic: 10
+resume: false
+amp: true
+fraction: 1.0
+profile: false
+freeze: null
+multi_scale: 0.0
+compile: false
+overlap_mask: true
+mask_ratio: 4
+dropout: 0.0
+val: true
+split: val
+save_json: false
+conf: null
+iou: 0.7
+max_det: 300
+half: false
+dnn: false
+plots: true
+end2end: null
+source: null
+vid_stride: 1
+stream_buffer: false
+visualize: false
+augment: false
+agnostic_nms: false
+classes: null
+retina_masks: false
+embed: null
+show: false
+save_frames: false
+save_txt: false
+save_conf: false
+save_crop: false
+show_labels: true
+show_conf: true
+show_boxes: true
+line_width: null
+format: torchscript
+keras: false
+optimize: false
+int8: false
+dynamic: false
+simplify: true
+opset: null
+workspace: null
+nms: false
+lr0: 0.01
+lrf: 0.01
+momentum: 0.937
+weight_decay: 0.0005
+warmup_epochs: 3.0
+warmup_momentum: 0.8
+warmup_bias_lr: 0.1
+box: 7.5
+cls: 0.5
+cls_pw: 0.0
+dfl: 1.5
+pose: 12.0
+kobj: 1.0
+rle: 1.0
+angle: 1.0
+nbs: 64
+hsv_h: 0.015
+hsv_s: 0.7
+hsv_v: 0.4
+degrees: 0.0
+translate: 0.1
+scale: 0.5
+shear: 0.0
+perspective: 0.0
+flipud: 0.0
+fliplr: 0.5
+bgr: 0.0
+mosaic: 1.0
+mixup: 0.0
+cutmix: 0.0
+copy_paste: 0.0
+copy_paste_mode: flip
+auto_augment: randaugment
+erasing: 0.4
+cfg: null
+tracker: botsort.yaml
+save_dir: C:\pro\islandfox\runs\detect\models\model_1

+ 7 - 0
runs/detect/models/model_1/results.csv

@@ -0,0 +1,7 @@
+epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
+1,0.821639,0,7.51482,0,0,0,0,0,0,8.08585,0,0,0,0
+2,1.60716,0,7.53179,0,0,0,0,0,0,8.09174,0,1.9604e-05,1.9604e-05,1.9604e-05
+3,2.34892,0,7.51607,0,0,0,0,0,0,8.09925,0,3.8416e-05,3.8416e-05,3.8416e-05
+4,3.63407,0,7.47746,0,0,0,0,0,0,8.10813,0,5.6436e-05,5.6436e-05,5.6436e-05
+5,4.3287,0,7.44589,0,0,0,0,0,0,8.11619,0,7.3664e-05,7.3664e-05,7.3664e-05
+6,5.08103,0,7.37537,0,0,0,0,0,0,8.12396,0,9.01e-05,9.01e-05,9.01e-05

二進制
runs/detect/models/model_1/train_batch0.jpg


二進制
runs/detect/models/model_1/train_batch1.jpg


二進制
runs/detect/models/model_1/train_batch2.jpg


二進制
runs/detect/models/model_1/weights/best.pt


二進制
runs/detect/models/model_1/weights/last.pt


+ 5 - 0
src/backend/models/__init__.py

@@ -103,8 +103,13 @@ class ModelStatus(PyEnum):
     
 class ModelArchitecture(PyEnum):
     YOLO_V5 = 'yolo_v5'
+    YOLO_V6 = 'yolo_v6'
     YOLO_V8 = 'yolo_v8'
     YOLO_V9 = 'yolo_v9'
+    YOLO_V10 = 'yolo_v10'
+    YOLO_V11 = 'yolo_v11'
+    RT_DETR = 'rt_detr'
+    YOLO_WORLD = 'yolo_world'
 
 class Model(Base):
     __tablename__ = "models"

+ 156 - 53
src/backend/services/model.py

@@ -25,6 +25,70 @@ import sys
 # 设置日志
 logger = logging.getLogger(__name__)
 
+# Ultralytics 架构与预训练权重映射
+ULTRALYTICS_ARCHITECTURES = {
+    ModelArchitecture.YOLO_V6: "yolov3u.pt",
+    ModelArchitecture.YOLO_V8: "yolov8n.pt",
+    ModelArchitecture.YOLO_V9: "yolov9t.pt",
+    ModelArchitecture.YOLO_V10: "yolov10n.pt",
+    ModelArchitecture.YOLO_V11: "yolo11n.pt",
+    ModelArchitecture.RT_DETR: "rtdetr-l.pt",
+    ModelArchitecture.YOLO_WORLD: "yolov8s-worldv2.pt",
+}
+
+DEFAULT_TRAINING_PARAMS = {
+    'epochs': 50,
+    'batch_size': 16,
+    'img_size': 640,
+    'conf_thres': 0.25,
+    'iou_thres': 0.45,
+    'lr0': 0.01,
+    'lrf': 0.01,
+    'patience': 50,
+    'optimizer': 'SGD',
+    'weight_decay': 0.0005,
+    'momentum': 0.937,
+    'warmup_epochs': 3,
+    'workers': 4,
+    'device': 'cpu',
+    'mosaic': 1.0,
+    'cache': False,
+    'save_period': -1,
+}
+
+
+def merge_training_params(parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    merged = DEFAULT_TRAINING_PARAMS.copy()
+    if parameters:
+        merged.update(parameters)
+    return merged
+
+
+def parameters_need_upgrade(parameters: Optional[Dict[str, Any]]) -> bool:
+    if not parameters:
+        return True
+    return any(key not in parameters for key in DEFAULT_TRAINING_PARAMS)
+
+
+def build_ultralytics_train_args(parameters: dict) -> List[str]:
+    params = merge_training_params(parameters)
+    args = [
+        f"lr0={params['lr0']}",
+        f"lrf={params['lrf']}",
+        f"patience={params['patience']}",
+        f"optimizer={params['optimizer']}",
+        f"weight_decay={params['weight_decay']}",
+        f"momentum={params['momentum']}",
+        f"warmup_epochs={params['warmup_epochs']}",
+        f"workers={params['workers']}",
+        f"mosaic={params['mosaic']}",
+        f"save_period={params['save_period']}",
+        f"device={params['device']}",
+    ]
+    if params.get('cache'):
+        args.append("cache=True")
+    return args
+
 class YOLOModel:
     """YOLO模型包装类,用于在cv_operation中使用"""
     
@@ -54,15 +118,10 @@ class YOLOModel:
                 print(yolov5_path)
                 # 使用YOLOv5加载模型
                 self.model = torch.hub.load(yolov5_path, 'custom', path=self.model_path, source="local")
-            elif self.architecture == ModelArchitecture.YOLO_V8:
-                # 使用YOLOv8加载模型
+            elif self.architecture in ULTRALYTICS_ARCHITECTURES:
                 from ultralytics import YOLO
                 print(self.model_path)
                 self.model = YOLO(self.model_path)
-            elif self.architecture == ModelArchitecture.YOLO_V9:
-                # 使用YOLOv9加载模型
-                from ultralytics import YOLO
-                self.model = YOLO(self.model_path)
             else:
                 raise ValueError(f"不支持的模型架构: {self.architecture}")
                 
@@ -186,17 +245,47 @@ class ModelService:
             logger.error(f"模型加载失败: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
             
+    def _apply_normalized_parameters(self, model: Model, persist: bool = False) -> Model:
+        """将旧模型参数与最新默认参数合并,必要时写回数据库"""
+        merged = merge_training_params(model.parameters)
+        should_persist = persist and parameters_need_upgrade(model.parameters)
+        model.parameters = merged
+        if should_persist:
+            model.updated_at = datetime.utcnow()
+            self.db.commit()
+            self.db.refresh(model)
+        return model
+
     def get_models(self) -> List[Model]:
         """获取所有模型"""
-        return self.db.query(Model).all()
+        return self._normalize_models_list(self.db.query(Model).all())
         
     def get_models_by_dataset(self, dataset_id: int) -> List[Model]:
         """获取特定数据集的所有模型"""
-        return self.db.query(Model).filter(Model.dataset_id == dataset_id).all()
+        return self._normalize_models_list(
+            self.db.query(Model).filter(Model.dataset_id == dataset_id).all()
+        )
+
+    def _normalize_models_list(self, models: List[Model]) -> List[Model]:
+        """批量兼容旧模型参数,缺失字段时一次性写回数据库"""
+        upgraded = False
+        for model in models:
+            if parameters_need_upgrade(model.parameters):
+                model.parameters = merge_training_params(model.parameters)
+                model.updated_at = datetime.utcnow()
+                upgraded = True
+            else:
+                model.parameters = merge_training_params(model.parameters)
+        if upgraded:
+            self.db.commit()
+        return models
 
     def get_model(self, model_id: int) -> Optional[Model]:
         """获取指定模型"""
-        return self.db.query(Model).filter(Model.id == model_id).first()
+        model = self.db.query(Model).filter(Model.id == model_id).first()
+        if model:
+            return self._apply_normalized_parameters(model, persist=True)
+        return None
 
     def create_model(self, 
                     name: str, 
@@ -210,21 +299,7 @@ class ModelService:
             if not dataset:
                 raise HTTPException(status_code=404, detail="Dataset not found")
                 
-            # 验证参数
-            if parameters is None:
-                parameters = {}
-                
-            # 根据架构类型设置默认参数
-            if 'epochs' not in parameters:
-                parameters['epochs'] = 50
-            if 'batch_size' not in parameters:
-                parameters['batch_size'] = 16
-            if 'img_size' not in parameters:
-                parameters['img_size'] = 640
-            if 'conf_thres' not in parameters:
-                parameters['conf_thres'] = 0.25
-            if 'iou_thres' not in parameters:
-                parameters['iou_thres'] = 0.45
+            parameters = merge_training_params(parameters)
                 
             model = Model(
                 name=name,
@@ -246,9 +321,15 @@ class ModelService:
 
     def update_model(self, model_id: int, **kwargs) -> Model:
         """更新模型信息"""
-        model = self.get_model(model_id)
+        model = self.db.query(Model).filter(Model.id == model_id).first()
         if not model:
             raise HTTPException(status_code=404, detail="Model not found")
+
+        if kwargs.get('parameters') is not None:
+            kwargs['parameters'] = merge_training_params({
+                **merge_training_params(model.parameters),
+                **kwargs['parameters'],
+            })
             
         # 更新提供的字段
         for key, value in kwargs.items():
@@ -296,7 +377,8 @@ class ModelService:
         if model.status == ModelStatus.TRAINING:
             raise HTTPException(status_code=400, detail="Model is already training")
             
-        # 更新模型状态为训练中
+        # 兼容旧模型:训练前补齐并持久化完整参数
+        model.parameters = merge_training_params(model.parameters)
         model.status = ModelStatus.TRAINING
         model.updated_at = datetime.utcnow()
         
@@ -368,10 +450,13 @@ class ModelService:
             try:
                 # 获取模型训练参数
                 logger.info(f"开始训练模型 ID:{model_id}, 架构:{model.architecture}")
-                if model.architecture in [ModelArchitecture.YOLO_V8, ModelArchitecture.YOLO_V9]:
-                    success = self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv8")
+                if model.architecture in ULTRALYTICS_ARCHITECTURES:
+                    success = self._run_yolo_training(
+                        model_id, dataset_dir, model.parameters,
+                        model.architecture, ULTRALYTICS_ARCHITECTURES[model.architecture]
+                    )
                 elif model.architecture == ModelArchitecture.YOLO_V5:
-                    success = self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv5")
+                    success = self._run_yolo_training(model_id, dataset_dir, model.parameters, model.architecture, "yolov5s.pt")
                 else:
                     logger.error(f"不支持的模型架构: {model.architecture}")
                     self._update_training_status(model_id, ModelStatus.FAILED, 
@@ -397,14 +482,15 @@ class ModelService:
                 except:
                     pass
 
-    def _run_yolo_training(self, model_id: int, dataset_dir: str, parameters: dict, model_type: str) -> bool:
-        """运行YOLO训练任务,统一处理YOLOv5和YOLOv8的训练逻辑
+    def _run_yolo_training(self, model_id: int, dataset_dir: str, parameters: dict, architecture: ModelArchitecture, pretrained_weights: str) -> bool:
+        """运行YOLO训练任务,统一处理各架构的训练逻辑
         
         Args:
             model_id: 模型ID
             dataset_dir: 数据集目录
             parameters: 训练参数
-            model_type: "YOLOv5" 或 "YOLOv8"
+            architecture: 模型架构
+            pretrained_weights: 预训练权重文件名
             
         Returns:
             bool: 训练是否成功
@@ -415,15 +501,19 @@ class ModelService:
             
             # 添加初始日志
             timestamp = time.strftime("%H:%M:%S", time.localtime())
-            TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 开始{model_type}训练...")
+            TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 开始{architecture.value}训练...")
             
-            # 提取训练参数
-            epochs = parameters.get('epochs', 20)
-            batch_size = parameters.get('batch_size', 8)
-            img_size = parameters.get('img_size', 640)
+            params = merge_training_params(parameters)
+            epochs = params['epochs']
+            batch_size = params['batch_size']
+            img_size = params['img_size']
+            device = params['device']
             
             # 记录训练参数
-            TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 训练参数: epochs={epochs}, batch_size={batch_size}, img_size={img_size}")
+            TRAINING_LOGS[model_id].append(
+                f"[{timestamp}] [系统] 训练参数: epochs={epochs}, batch_size={batch_size}, "
+                f"img_size={img_size}, lr0={params['lr0']}, optimizer={params['optimizer']}, device={device}"
+            )
             
             # 设置训练基本参数
             data_yaml_path = os.path.join(dataset_dir, "dataset.yaml")
@@ -434,22 +524,22 @@ class ModelService:
             os.makedirs(os.path.join(project_dir, name), exist_ok=True)
             
             # 准备训练命令 - 不同模型有不同命令格式
-            if model_type == "YOLOv8":
+            if architecture in ULTRALYTICS_ARCHITECTURES:
                 cmd = [
                     "yolo",
                     "task=detect",
                     "mode=train",
                     f"data={os.path.abspath(data_yaml_path)}",
-                    "model=yolov8n.pt", 
+                    f"model={pretrained_weights}",
                     f"epochs={epochs}",
                     f"batch={batch_size}",
                     f"imgsz={img_size}",
                     f"project={project_dir}",
                     f"name={name}",
                     "exist_ok=True",
-                    "device=cpu"
+                    *build_ultralytics_train_args(params),
                 ]
-            elif model_type == "YOLOv5":
+            elif architecture == ModelArchitecture.YOLO_V5:
                 # 确保YOLOv5代码存在
                 yolov5_dir = "yolov5"
                 if not os.path.exists(yolov5_dir):
@@ -469,15 +559,25 @@ class ModelService:
                     "--project", project_dir,
                     "--name", name,
                     "--exist-ok",
-                    "--device", "cpu"
+                    "--device", str(device),
+                    "--lr0", str(params['lr0']),
+                    "--lrf", str(params['lrf']),
+                    "--patience", str(params['patience']),
+                    "--optimizer", str(params['optimizer']),
+                    "--weight-decay", str(params['weight_decay']),
+                    "--momentum", str(params['momentum']),
+                    "--warmup-epochs", str(params['warmup_epochs']),
+                    "--workers", str(params['workers']),
                 ]
+                if params.get('cache'):
+                    cmd.append("--cache")
             else:
-                logger.error(f"不支持的YOLO模型类型: {model_type}")
+                logger.error(f"不支持的模型架构: {architecture}")
                 return False
             
             # 打印命令行 
             cmd_str = ' '.join(cmd)
-            logger.info(f"执行{model_type}训练命令: {cmd_str}")
+            logger.info(f"执行{architecture.value}训练命令: {cmd_str}")
             TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 执行命令: {cmd_str}")
             
             # 记录开始时间
@@ -531,28 +631,28 @@ class ModelService:
                     self.update_model(model_id, 
                                     file_path=best_model_path, 
                                     status=ModelStatus.COMPLETED)
-                    success_msg = f"{model_type}训练完成,耗时: {training_time:.1f}秒,模型保存在: {best_model_path}"
+                    success_msg = f"{architecture.value}训练完成,耗时: {training_time:.1f}秒,模型保存在: {best_model_path}"
                     logger.info(success_msg)
                     TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] {success_msg}")
                     return True
                 else:
-                    error_msg = f"{model_type}训练完成但模型文件不存在"
+                    error_msg = f"{architecture.value}训练完成但模型文件不存在"
                     logger.error(error_msg)
                     TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
                     self._update_training_status(model_id, ModelStatus.FAILED, 
-                        error_message=f"{model_type} training completed but model file not found")
+                        error_message=f"{architecture.value} training completed but model file not found")
                     return False
             else:
-                error_msg = f"{model_type}训练失败,返回码: {returncode}"
+                error_msg = f"{architecture.value}训练失败,返回码: {returncode}"
                 logger.error(error_msg)
                 TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
                 self._update_training_status(model_id, ModelStatus.FAILED,
-                    error_message=f"{model_type} training failed with return code: {returncode}")
+                    error_message=f"{architecture.value} training failed with return code: {returncode}")
                 return False
                 
         except Exception as e:
             timestamp = time.strftime("%H:%M:%S", time.localtime())
-            error_msg = f"{model_type}训练过程中发生异常: {str(e)}"
+            error_msg = f"{architecture.value}训练过程中发生异常: {str(e)}"
             logger.error(error_msg, exc_info=True)
             if model_id in TRAINING_LOGS:
                 TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
@@ -1205,10 +1305,13 @@ class ModelService:
         """使用YOLOv8进行训练"""
         model = self.get_model(model_id)
         if model:
-            self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv8")
+            self._run_yolo_training(
+                model_id, dataset_dir, model.parameters,
+                ModelArchitecture.YOLO_V8, ULTRALYTICS_ARCHITECTURES[ModelArchitecture.YOLO_V8]
+            )
 
     def _train_yolov5(self, model_id: int, dataset_dir: str) -> None:
         """使用YOLOv5进行训练"""
         model = self.get_model(model_id)
         if model:
-            self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv5") 
+            self._run_yolo_training(model_id, dataset_dir, model.parameters, ModelArchitecture.YOLO_V5, "yolov5s.pt") 

+ 4 - 2
src/components/model/ModelTester.vue

@@ -138,6 +138,7 @@ import { ref, defineProps, defineEmits, watch } from 'vue';
 import { useQuasar } from 'quasar';
 import { useI18n } from 'vue-i18n';
 import { modelService, Model } from '../../services/model';
+import { mergeTrainingParams } from '../../utils/trainingParams';
 
 const { t } = useI18n();
 const $q = useQuasar();
@@ -248,8 +249,9 @@ watch(() => props.selectedModel, () => {
   
   // 如果有模型,从模型参数中获取默认阈值
   if (props.selectedModel?.parameters) {
-    confThreshold.value = props.selectedModel.parameters.conf_thres || 0.25;
-    iouThreshold.value = props.selectedModel.parameters.iou_thres || 0.45;
+    const params = mergeTrainingParams(props.selectedModel.parameters);
+    confThreshold.value = params.conf_thres;
+    iouThreshold.value = params.iou_thres;
   }
 });
 </script>

+ 225 - 0
src/components/model/TrainingParamsForm.vue

@@ -0,0 +1,225 @@
+<template>
+  <div class="row q-col-gutter-md">
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.epochs"
+        type="number"
+        :label="t('model.training.params.epochs')"
+        class="full-width"
+        :disable="disabled"
+        min="1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.batch_size"
+        type="number"
+        :label="t('model.training.params.batchSize')"
+        class="full-width"
+        :disable="disabled"
+        min="1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.img_size"
+        type="number"
+        :label="t('model.training.params.imgSize')"
+        class="full-width"
+        :disable="disabled"
+        min="32"
+        step="32"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-select
+        outlined
+        v-model="params.device"
+        :options="deviceOptions"
+        emit-value
+        map-options
+        :label="t('model.training.params.device')"
+        class="full-width"
+        :disable="disabled"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.lr0"
+        type="number"
+        :label="t('model.training.params.lr0')"
+        class="full-width"
+        :disable="disabled"
+        step="0.001"
+        min="0"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.lrf"
+        type="number"
+        :label="t('model.training.params.lrf')"
+        class="full-width"
+        :disable="disabled"
+        step="0.01"
+        min="0"
+        max="1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-select
+        outlined
+        v-model="params.optimizer"
+        :options="optimizerOptions"
+        :label="t('model.training.params.optimizer')"
+        class="full-width"
+        :disable="disabled"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.patience"
+        type="number"
+        :label="t('model.training.params.patience')"
+        class="full-width"
+        :disable="disabled"
+        min="0"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.weight_decay"
+        type="number"
+        :label="t('model.training.params.weightDecay')"
+        class="full-width"
+        :disable="disabled"
+        step="0.0001"
+        min="0"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.momentum"
+        type="number"
+        :label="t('model.training.params.momentum')"
+        class="full-width"
+        :disable="disabled"
+        step="0.001"
+        min="0"
+        max="1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.warmup_epochs"
+        type="number"
+        :label="t('model.training.params.warmupEpochs')"
+        class="full-width"
+        :disable="disabled"
+        min="0"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.workers"
+        type="number"
+        :label="t('model.training.params.workers')"
+        class="full-width"
+        :disable="disabled"
+        min="0"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.mosaic"
+        type="number"
+        :label="t('model.training.params.mosaic')"
+        class="full-width"
+        :disable="disabled"
+        step="0.1"
+        min="0"
+        max="1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.save_period"
+        type="number"
+        :label="t('model.training.params.savePeriod')"
+        class="full-width"
+        :disable="disabled"
+        min="-1"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.conf_thres"
+        type="number"
+        :label="t('model.training.params.confThres')"
+        step="0.05"
+        min="0"
+        max="1"
+        class="full-width"
+        :disable="disabled"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-input
+        outlined
+        v-model.number="params.iou_thres"
+        type="number"
+        :label="t('model.training.params.iouThres')"
+        step="0.05"
+        min="0"
+        max="1"
+        class="full-width"
+        :disable="disabled"
+      />
+    </div>
+    <div class="col-12 col-md-6">
+      <q-toggle
+        v-model="params.cache"
+        :label="t('model.training.params.cache')"
+        :disable="disabled"
+      />
+    </div>
+  </div>
+</template>
+
+<script setup lang="ts">
+import { computed } from 'vue'
+import { useI18n } from 'vue-i18n'
+import {
+  TrainingParameters,
+  OPTIMIZER_OPTIONS,
+  DEVICE_OPTIONS
+} from '../../utils/trainingParams'
+
+const { t } = useI18n()
+
+defineProps<{
+  params: TrainingParameters
+  disabled?: boolean
+}>()
+
+const optimizerOptions = OPTIMIZER_OPTIONS
+const deviceOptions = computed(() =>
+  DEVICE_OPTIONS.map(opt => ({
+    value: opt.value,
+    label: t(`model.training.params.deviceOptions.${opt.value}`)
+  }))
+)
+</script>

+ 18 - 1
src/i18n/locales/en.ts

@@ -571,7 +571,24 @@ export default {
         batchSize: 'Batch Size',
         imgSize: 'Image Size',
         confThres: 'Confidence Threshold',
-        iouThres: 'IOU Threshold'
+        iouThres: 'IOU Threshold',
+        lr0: 'Initial Learning Rate',
+        lrf: 'Final LR Factor',
+        patience: 'Early Stop Patience',
+        optimizer: 'Optimizer',
+        weightDecay: 'Weight Decay',
+        momentum: 'Momentum',
+        warmupEpochs: 'Warmup Epochs',
+        workers: 'Data Loader Workers',
+        device: 'Training Device',
+        mosaic: 'Mosaic Augmentation',
+        cache: 'Cache Training Images',
+        savePeriod: 'Save Period',
+        deviceOptions: {
+          cpu: 'CPU',
+          cuda: 'GPU (CUDA)',
+          '0': 'GPU 0'
+        }
       },
       start: 'Start Training',
       stop: 'Stop Training',

+ 18 - 1
src/i18n/locales/zh-CN.ts

@@ -571,7 +571,24 @@ export default {
         batchSize: '批次大小',
         imgSize: '图像尺寸',
         confThres: '置信度阈值',
-        iouThres: 'IOU阈值'
+        iouThres: 'IOU阈值',
+        lr0: '初始学习率',
+        lrf: '最终学习率因子',
+        patience: '早停耐心值',
+        optimizer: '优化器',
+        weightDecay: '权重衰减',
+        momentum: '动量',
+        warmupEpochs: '预热轮数',
+        workers: '数据加载线程',
+        device: '训练设备',
+        mosaic: 'Mosaic 增强',
+        cache: '缓存训练图像',
+        savePeriod: '模型保存周期',
+        deviceOptions: {
+          cpu: 'CPU',
+          cuda: 'GPU (CUDA)',
+          '0': 'GPU 0'
+        }
       },
       start: '开始训练',
       stop: '停止训练',

+ 20 - 7
src/services/model.ts

@@ -1,4 +1,12 @@
 import { apiService } from './api';
+import { mergeTrainingParams, TrainingParameters } from '../utils/trainingParams';
+
+function normalizeModel(model: Model): Model {
+  return {
+    ...model,
+    parameters: mergeTrainingParams(model.parameters as Partial<TrainingParameters>)
+  };
+}
 
 // Enums to match backend
 export enum ModelStatus {
@@ -10,8 +18,13 @@ export enum ModelStatus {
 
 export enum ModelArchitecture {
   YOLO_V5 = 'yolo_v5',
+  YOLO_V6 = 'yolo_v6',
   YOLO_V8 = 'yolo_v8',
-  YOLO_V9 = 'yolo_v9'
+  YOLO_V9 = 'yolo_v9',
+  YOLO_V10 = 'yolo_v10',
+  YOLO_V11 = 'yolo_v11',
+  RT_DETR = 'rt_detr',
+  YOLO_WORLD = 'yolo_world'
 }
 
 // Model interfaces
@@ -81,7 +94,7 @@ export class ModelService {
    */
   async getModels(): Promise<Model[]> {
     const response = await apiService.get<Model[]>('/models/');
-    return response;
+    return response.map(normalizeModel);
   }
 
   /**
@@ -89,7 +102,7 @@ export class ModelService {
    */
   async getModelsByDataset(datasetId: number): Promise<Model[]> {
     const response = await apiService.get<Model[]>(`/models?dataset_id=${datasetId}`);
-    return response;
+    return response.map(normalizeModel);
   }
 
   /**
@@ -97,7 +110,7 @@ export class ModelService {
    */
   async getModel(id: number): Promise<Model> {
     const response = await apiService.get<Model>(`/models/${id}`);
-    return response;
+    return normalizeModel(response);
   }
 
   /**
@@ -112,7 +125,7 @@ export class ModelService {
     };
     
     const response = await apiService.post<Model>('/models/', requestData);
-    return response;
+    return normalizeModel(response);
   }
 
   /**
@@ -124,7 +137,7 @@ export class ModelService {
     if (data.parameters !== undefined) requestData.parameters = data.parameters;
     
     const response = await apiService.put<Model>(`/models/${id}`, requestData);
-    return response;
+    return normalizeModel(response);
   }
 
   /**
@@ -139,7 +152,7 @@ export class ModelService {
    */
   async startTraining(id: number): Promise<Model> {
     const response = await apiService.post<Model>(`/models/${id}/train`);
-    return response;
+    return normalizeModel(response);
   }
 
   /**

+ 51 - 0
src/utils/trainingParams.ts

@@ -0,0 +1,51 @@
+export interface TrainingParameters {
+  epochs: number
+  batch_size: number
+  img_size: number
+  conf_thres: number
+  iou_thres: number
+  lr0: number
+  lrf: number
+  patience: number
+  optimizer: string
+  weight_decay: number
+  momentum: number
+  warmup_epochs: number
+  workers: number
+  device: string
+  mosaic: number
+  cache: boolean
+  save_period: number
+}
+
+export const DEFAULT_TRAINING_PARAMS: TrainingParameters = {
+  epochs: 50,
+  batch_size: 16,
+  img_size: 640,
+  conf_thres: 0.25,
+  iou_thres: 0.45,
+  lr0: 0.01,
+  lrf: 0.01,
+  patience: 50,
+  optimizer: 'SGD',
+  weight_decay: 0.0005,
+  momentum: 0.937,
+  warmup_epochs: 3,
+  workers: 4,
+  device: 'cpu',
+  mosaic: 1.0,
+  cache: false,
+  save_period: -1
+}
+
+export const OPTIMIZER_OPTIONS = ['SGD', 'Adam', 'AdamW', 'auto'] as const
+
+export const DEVICE_OPTIONS = [
+  { value: 'cpu', label: 'CPU' },
+  { value: 'cuda', label: 'GPU (CUDA)' },
+  { value: '0', label: 'GPU 0' }
+] as const
+
+export function mergeTrainingParams(params?: Partial<TrainingParameters>): TrainingParameters {
+  return { ...DEFAULT_TRAINING_PARAMS, ...params }
+}

+ 37 - 142
src/views/Model.vue

@@ -139,63 +139,12 @@
                   disable
                 />
               </div>
-              <div class="col-12 col-md-6">
-                <q-input
-                  outlined
-                  v-model.number="epochs"
-                  type="number"
-                  :label="t('model.training.params.epochs')"
-                  class="full-width"
-                  :disable="selectedModel.status === ModelStatus.TRAINING"
-                />
-              </div>
-              <div class="col-12 col-md-6">
-                <q-input
-                  outlined
-                  v-model.number="batchSize"
-                  type="number"
-                  :label="t('model.training.params.batchSize')"
-                  class="full-width"
-                  :disable="selectedModel.status === ModelStatus.TRAINING"
-                />
-              </div>
-              <div class="col-12 col-md-6">
-                <q-input
-                  outlined
-                  v-model.number="imgSize"
-                  type="number"
-                  :label="t('model.training.params.imgSize')"
-                  class="full-width"
-                  :disable="selectedModel.status === ModelStatus.TRAINING"
-                />
-              </div>
-              <div class="col-12 col-md-6">
-                <q-input
-                  outlined
-                  v-model.number="confThres"
-                  type="number"
-                  :label="t('model.training.params.confThres')"
-                  step="0.05"
-                  min="0"
-                  max="1"
-                  class="full-width"
-                  :disable="selectedModel.status === ModelStatus.TRAINING"
-                />
-              </div>
-              <div class="col-12 col-md-6">
-                <q-input
-                  outlined
-                  v-model.number="iouThres"
-                  type="number"
-                  :label="t('model.training.params.iouThres')"
-                  step="0.05"
-                  min="0"
-                  max="1"
-                  class="full-width"
-                  :disable="selectedModel.status === ModelStatus.TRAINING"
-                />
-              </div>
             </div>
+            <TrainingParamsForm
+              :params="trainingParams"
+              class="q-mt-md"
+              :disabled="selectedModel.status === ModelStatus.TRAINING"
+            />
             <q-btn
               color="primary"
               icon="play_arrow"
@@ -421,12 +370,12 @@
     
     <!-- 新建模型对话框 -->
     <q-dialog v-model="createModelDialog" persistent>
-      <q-card style="min-width: 350px">
+      <q-card style="min-width: 720px; max-width: 90vw">
         <q-card-section>
           <div class="text-h6">{{ t('model.dialogs.create.title') }}</div>
         </q-card-section>
         
-        <q-card-section class="q-pt-none">
+        <q-card-section class="q-pt-none" style="max-height: 70vh; overflow-y: auto">
           <q-input 
             outlined 
             v-model="newModel.name" 
@@ -456,52 +405,8 @@
             :label="t('model.training.params.architecture')"
             class="q-mt-md"
           />
-          
-          <q-input 
-            outlined 
-            v-model.number="newModel.parameters!.epochs" 
-            type="number"
-            :label="t('model.training.params.epochs')"
-            class="q-mt-md"
-          />
-          
-          <q-input 
-            outlined 
-            v-model.number="newModel.parameters!.batch_size" 
-            type="number"
-            :label="t('model.training.params.batchSize')"
-            class="q-mt-md"
-          />
-          
-          <q-input 
-            outlined 
-            v-model.number="newModel.parameters!.img_size" 
-            type="number"
-            :label="t('model.training.params.imgSize')"
-            class="q-mt-md"
-          />
-          
-          <q-input 
-            outlined 
-            v-model.number="newModel.parameters!.conf_thres" 
-            type="number"
-            :label="t('model.training.params.confThres')"
-            step="0.05"
-            min="0"
-            max="1"
-            class="q-mt-md"
-          />
-          
-          <q-input 
-            outlined 
-            v-model.number="newModel.parameters!.iou_thres" 
-            type="number"
-            :label="t('model.training.params.iouThres')"
-            step="0.05"
-            min="0"
-            max="1"
-            class="q-mt-md"
-          />
+
+          <TrainingParamsForm :params="newModelParameters" class="q-mt-md" />
         </q-card-section>
         
         <q-card-actions align="right">
@@ -536,6 +441,8 @@ import { modelService, Model, ModelStatus, ModelArchitecture, CreateModelRequest
 import { useRouter, useRoute } from 'vue-router'
 import { AnnotationService } from '../services/annotation'
 import ModelTester from '../components/model/ModelTester.vue'
+import TrainingParamsForm from '../components/model/TrainingParamsForm.vue'
+import { DEFAULT_TRAINING_PARAMS, mergeTrainingParams, TrainingParameters } from '../utils/trainingParams'
 import { Terminal } from 'xterm'
 import { FitAddon } from 'xterm-addon-fit'
 import 'xterm/css/xterm.css'
@@ -577,30 +484,27 @@ const props = defineProps({
 // 模型架构选项
 const architectureOptions = [
   { value: ModelArchitecture.YOLO_V5, label: 'YOLO v5' },
+  { value: ModelArchitecture.YOLO_V6, label: 'YOLO v3' },
   { value: ModelArchitecture.YOLO_V8, label: 'YOLO v8' },
-  { value: ModelArchitecture.YOLO_V9, label: 'YOLO v9' }
+  { value: ModelArchitecture.YOLO_V9, label: 'YOLO v9' },
+  { value: ModelArchitecture.YOLO_V10, label: 'YOLO v10' },
+  { value: ModelArchitecture.YOLO_V11, label: 'YOLO v11' },
+  { value: ModelArchitecture.RT_DETR, label: 'RT-DETR' },
+  { value: ModelArchitecture.YOLO_WORLD, label: 'YOLO-World' }
 ]
 
 // 新模型表单
 const newModel = ref<CreateModelRequest>({
   name: '',
-  architecture: ModelArchitecture.YOLO_V8,  // 默认使用YOLOv8
+  architecture: ModelArchitecture.YOLO_V8,
   dataset_id: undefined,
-  parameters: {
-    epochs: 50,
-    batch_size: 16,
-    img_size: 640,
-    conf_thres: 0.25,
-    iou_thres: 0.45
-  }
+  parameters: { ...DEFAULT_TRAINING_PARAMS }
 })
 
+const newModelParameters = ref<TrainingParameters>({ ...DEFAULT_TRAINING_PARAMS })
+
 // 训练参数
-const epochs = ref(50)
-const batchSize = ref(16)
-const imgSize = ref(640)
-const confThres = ref(0.25)
-const iouThres = ref(0.45)
+const trainingParams = ref<TrainingParameters>({ ...DEFAULT_TRAINING_PARAMS })
 
 // 获取选中的数据集
 const selectedDataset = computed(() => {
@@ -722,13 +626,7 @@ const updateModelParameters = async () => {
   
   try {
     await modelService.updateModel(selectedModel.value.id, {
-      parameters: {
-        epochs: epochs.value,
-        batch_size: batchSize.value,
-        img_size: imgSize.value,
-        conf_thres: confThres.value,
-        iou_thres: iouThres.value
-      }
+      parameters: { ...trainingParams.value }
     })
   } catch (error) {
     throw new Error(t('model.notifications.updateParamsFailed', { error: (error as Error).message }))
@@ -737,14 +635,12 @@ const updateModelParameters = async () => {
 
 // 选择模型
 const selectModel = (model: Model) => {
-  selectedModel.value = model
-  
-  // 更新训练参数
-  epochs.value = model.parameters.epochs || 50
-  batchSize.value = model.parameters.batch_size || 16
-  imgSize.value = model.parameters.img_size || 640
-  confThres.value = model.parameters.conf_thres || 0.25
-  iouThres.value = model.parameters.iou_thres || 0.45
+  const normalized = {
+    ...model,
+    parameters: mergeTrainingParams(model.parameters)
+  }
+  selectedModel.value = normalized
+  trainingParams.value = { ...normalized.parameters }
   
   // 如果模型正在训练,获取日志并开始轮询状态
   if (model.status === ModelStatus.TRAINING) {
@@ -774,21 +670,19 @@ const createModel = async () => {
       return
     }
     
-    const model = await modelService.createModel(newModel.value)
+    const model = await modelService.createModel({
+      ...newModel.value,
+      parameters: { ...newModelParameters.value }
+    })
     models.value.push(model)
     
     newModel.value = {
       name: '',
       architecture: ModelArchitecture.YOLO_V8,
       dataset_id: undefined,
-      parameters: {
-        epochs: 50,
-        batch_size: 16,
-        img_size: 640,
-        conf_thres: 0.25,
-        iou_thres: 0.45
-      }
+      parameters: { ...DEFAULT_TRAINING_PARAMS }
     }
+    newModelParameters.value = { ...DEFAULT_TRAINING_PARAMS }
     
     createModelDialog.value = false
     
@@ -887,9 +781,10 @@ const refreshModel = async (modelId: number) => {
       models.value[index] = model
     }
     
-    // 更新选中的模型
+    // 更新选中的模型并同步训练参数表单
     if (selectedModel.value?.id === modelId) {
       selectedModel.value = model
+      trainingParams.value = mergeTrainingParams(model.parameters)
     }
     
     // 如果模型正在训练中,获取训练日志

二進制
yolo11n.pt


二進制
yolov10n.pt


二進制
yolov8s-worldv2.pt


二進制
yolov9t.pt