|
@@ -25,6 +25,70 @@ import sys
|
|
|
# 设置日志
|
|
# 设置日志
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
+# Ultralytics 架构与预训练权重映射
|
|
|
|
|
+ULTRALYTICS_ARCHITECTURES = {
|
|
|
|
|
+ ModelArchitecture.YOLO_V6: "yolov3u.pt",
|
|
|
|
|
+ ModelArchitecture.YOLO_V8: "yolov8n.pt",
|
|
|
|
|
+ ModelArchitecture.YOLO_V9: "yolov9t.pt",
|
|
|
|
|
+ ModelArchitecture.YOLO_V10: "yolov10n.pt",
|
|
|
|
|
+ ModelArchitecture.YOLO_V11: "yolo11n.pt",
|
|
|
|
|
+ ModelArchitecture.RT_DETR: "rtdetr-l.pt",
|
|
|
|
|
+ ModelArchitecture.YOLO_WORLD: "yolov8s-worldv2.pt",
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+DEFAULT_TRAINING_PARAMS = {
|
|
|
|
|
+ 'epochs': 50,
|
|
|
|
|
+ 'batch_size': 16,
|
|
|
|
|
+ 'img_size': 640,
|
|
|
|
|
+ 'conf_thres': 0.25,
|
|
|
|
|
+ 'iou_thres': 0.45,
|
|
|
|
|
+ 'lr0': 0.01,
|
|
|
|
|
+ 'lrf': 0.01,
|
|
|
|
|
+ 'patience': 50,
|
|
|
|
|
+ 'optimizer': 'SGD',
|
|
|
|
|
+ 'weight_decay': 0.0005,
|
|
|
|
|
+ 'momentum': 0.937,
|
|
|
|
|
+ 'warmup_epochs': 3,
|
|
|
|
|
+ 'workers': 4,
|
|
|
|
|
+ 'device': 'cpu',
|
|
|
|
|
+ 'mosaic': 1.0,
|
|
|
|
|
+ 'cache': False,
|
|
|
|
|
+ 'save_period': -1,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def merge_training_params(parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
|
|
|
|
+ merged = DEFAULT_TRAINING_PARAMS.copy()
|
|
|
|
|
+ if parameters:
|
|
|
|
|
+ merged.update(parameters)
|
|
|
|
|
+ return merged
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def parameters_need_upgrade(parameters: Optional[Dict[str, Any]]) -> bool:
|
|
|
|
|
+ if not parameters:
|
|
|
|
|
+ return True
|
|
|
|
|
+ return any(key not in parameters for key in DEFAULT_TRAINING_PARAMS)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_ultralytics_train_args(parameters: dict) -> List[str]:
|
|
|
|
|
+ params = merge_training_params(parameters)
|
|
|
|
|
+ args = [
|
|
|
|
|
+ f"lr0={params['lr0']}",
|
|
|
|
|
+ f"lrf={params['lrf']}",
|
|
|
|
|
+ f"patience={params['patience']}",
|
|
|
|
|
+ f"optimizer={params['optimizer']}",
|
|
|
|
|
+ f"weight_decay={params['weight_decay']}",
|
|
|
|
|
+ f"momentum={params['momentum']}",
|
|
|
|
|
+ f"warmup_epochs={params['warmup_epochs']}",
|
|
|
|
|
+ f"workers={params['workers']}",
|
|
|
|
|
+ f"mosaic={params['mosaic']}",
|
|
|
|
|
+ f"save_period={params['save_period']}",
|
|
|
|
|
+ f"device={params['device']}",
|
|
|
|
|
+ ]
|
|
|
|
|
+ if params.get('cache'):
|
|
|
|
|
+ args.append("cache=True")
|
|
|
|
|
+ return args
|
|
|
|
|
+
|
|
|
class YOLOModel:
|
|
class YOLOModel:
|
|
|
"""YOLO模型包装类,用于在cv_operation中使用"""
|
|
"""YOLO模型包装类,用于在cv_operation中使用"""
|
|
|
|
|
|
|
@@ -54,15 +118,10 @@ class YOLOModel:
|
|
|
print(yolov5_path)
|
|
print(yolov5_path)
|
|
|
# 使用YOLOv5加载模型
|
|
# 使用YOLOv5加载模型
|
|
|
self.model = torch.hub.load(yolov5_path, 'custom', path=self.model_path, source="local")
|
|
self.model = torch.hub.load(yolov5_path, 'custom', path=self.model_path, source="local")
|
|
|
- elif self.architecture == ModelArchitecture.YOLO_V8:
|
|
|
|
|
- # 使用YOLOv8加载模型
|
|
|
|
|
|
|
+ elif self.architecture in ULTRALYTICS_ARCHITECTURES:
|
|
|
from ultralytics import YOLO
|
|
from ultralytics import YOLO
|
|
|
print(self.model_path)
|
|
print(self.model_path)
|
|
|
self.model = YOLO(self.model_path)
|
|
self.model = YOLO(self.model_path)
|
|
|
- elif self.architecture == ModelArchitecture.YOLO_V9:
|
|
|
|
|
- # 使用YOLOv9加载模型
|
|
|
|
|
- from ultralytics import YOLO
|
|
|
|
|
- self.model = YOLO(self.model_path)
|
|
|
|
|
else:
|
|
else:
|
|
|
raise ValueError(f"不支持的模型架构: {self.architecture}")
|
|
raise ValueError(f"不支持的模型架构: {self.architecture}")
|
|
|
|
|
|
|
@@ -186,17 +245,47 @@ class ModelService:
|
|
|
logger.error(f"模型加载失败: {str(e)}")
|
|
logger.error(f"模型加载失败: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
+ def _apply_normalized_parameters(self, model: Model, persist: bool = False) -> Model:
|
|
|
|
|
+ """将旧模型参数与最新默认参数合并,必要时写回数据库"""
|
|
|
|
|
+ merged = merge_training_params(model.parameters)
|
|
|
|
|
+ should_persist = persist and parameters_need_upgrade(model.parameters)
|
|
|
|
|
+ model.parameters = merged
|
|
|
|
|
+ if should_persist:
|
|
|
|
|
+ model.updated_at = datetime.utcnow()
|
|
|
|
|
+ self.db.commit()
|
|
|
|
|
+ self.db.refresh(model)
|
|
|
|
|
+ return model
|
|
|
|
|
+
|
|
|
def get_models(self) -> List[Model]:
|
|
def get_models(self) -> List[Model]:
|
|
|
"""获取所有模型"""
|
|
"""获取所有模型"""
|
|
|
- return self.db.query(Model).all()
|
|
|
|
|
|
|
+ return self._normalize_models_list(self.db.query(Model).all())
|
|
|
|
|
|
|
|
def get_models_by_dataset(self, dataset_id: int) -> List[Model]:
|
|
def get_models_by_dataset(self, dataset_id: int) -> List[Model]:
|
|
|
"""获取特定数据集的所有模型"""
|
|
"""获取特定数据集的所有模型"""
|
|
|
- return self.db.query(Model).filter(Model.dataset_id == dataset_id).all()
|
|
|
|
|
|
|
+ return self._normalize_models_list(
|
|
|
|
|
+ self.db.query(Model).filter(Model.dataset_id == dataset_id).all()
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _normalize_models_list(self, models: List[Model]) -> List[Model]:
|
|
|
|
|
+ """批量兼容旧模型参数,缺失字段时一次性写回数据库"""
|
|
|
|
|
+ upgraded = False
|
|
|
|
|
+ for model in models:
|
|
|
|
|
+ if parameters_need_upgrade(model.parameters):
|
|
|
|
|
+ model.parameters = merge_training_params(model.parameters)
|
|
|
|
|
+ model.updated_at = datetime.utcnow()
|
|
|
|
|
+ upgraded = True
|
|
|
|
|
+ else:
|
|
|
|
|
+ model.parameters = merge_training_params(model.parameters)
|
|
|
|
|
+ if upgraded:
|
|
|
|
|
+ self.db.commit()
|
|
|
|
|
+ return models
|
|
|
|
|
|
|
|
def get_model(self, model_id: int) -> Optional[Model]:
|
|
def get_model(self, model_id: int) -> Optional[Model]:
|
|
|
"""获取指定模型"""
|
|
"""获取指定模型"""
|
|
|
- return self.db.query(Model).filter(Model.id == model_id).first()
|
|
|
|
|
|
|
+ model = self.db.query(Model).filter(Model.id == model_id).first()
|
|
|
|
|
+ if model:
|
|
|
|
|
+ return self._apply_normalized_parameters(model, persist=True)
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
def create_model(self,
|
|
def create_model(self,
|
|
|
name: str,
|
|
name: str,
|
|
@@ -210,21 +299,7 @@ class ModelService:
|
|
|
if not dataset:
|
|
if not dataset:
|
|
|
raise HTTPException(status_code=404, detail="Dataset not found")
|
|
raise HTTPException(status_code=404, detail="Dataset not found")
|
|
|
|
|
|
|
|
- # 验证参数
|
|
|
|
|
- if parameters is None:
|
|
|
|
|
- parameters = {}
|
|
|
|
|
-
|
|
|
|
|
- # 根据架构类型设置默认参数
|
|
|
|
|
- if 'epochs' not in parameters:
|
|
|
|
|
- parameters['epochs'] = 50
|
|
|
|
|
- if 'batch_size' not in parameters:
|
|
|
|
|
- parameters['batch_size'] = 16
|
|
|
|
|
- if 'img_size' not in parameters:
|
|
|
|
|
- parameters['img_size'] = 640
|
|
|
|
|
- if 'conf_thres' not in parameters:
|
|
|
|
|
- parameters['conf_thres'] = 0.25
|
|
|
|
|
- if 'iou_thres' not in parameters:
|
|
|
|
|
- parameters['iou_thres'] = 0.45
|
|
|
|
|
|
|
+ parameters = merge_training_params(parameters)
|
|
|
|
|
|
|
|
model = Model(
|
|
model = Model(
|
|
|
name=name,
|
|
name=name,
|
|
@@ -246,9 +321,15 @@ class ModelService:
|
|
|
|
|
|
|
|
def update_model(self, model_id: int, **kwargs) -> Model:
|
|
def update_model(self, model_id: int, **kwargs) -> Model:
|
|
|
"""更新模型信息"""
|
|
"""更新模型信息"""
|
|
|
- model = self.get_model(model_id)
|
|
|
|
|
|
|
+ model = self.db.query(Model).filter(Model.id == model_id).first()
|
|
|
if not model:
|
|
if not model:
|
|
|
raise HTTPException(status_code=404, detail="Model not found")
|
|
raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
+
|
|
|
|
|
+ if kwargs.get('parameters') is not None:
|
|
|
|
|
+ kwargs['parameters'] = merge_training_params({
|
|
|
|
|
+ **merge_training_params(model.parameters),
|
|
|
|
|
+ **kwargs['parameters'],
|
|
|
|
|
+ })
|
|
|
|
|
|
|
|
# 更新提供的字段
|
|
# 更新提供的字段
|
|
|
for key, value in kwargs.items():
|
|
for key, value in kwargs.items():
|
|
@@ -296,7 +377,8 @@ class ModelService:
|
|
|
if model.status == ModelStatus.TRAINING:
|
|
if model.status == ModelStatus.TRAINING:
|
|
|
raise HTTPException(status_code=400, detail="Model is already training")
|
|
raise HTTPException(status_code=400, detail="Model is already training")
|
|
|
|
|
|
|
|
- # 更新模型状态为训练中
|
|
|
|
|
|
|
+ # 兼容旧模型:训练前补齐并持久化完整参数
|
|
|
|
|
+ model.parameters = merge_training_params(model.parameters)
|
|
|
model.status = ModelStatus.TRAINING
|
|
model.status = ModelStatus.TRAINING
|
|
|
model.updated_at = datetime.utcnow()
|
|
model.updated_at = datetime.utcnow()
|
|
|
|
|
|
|
@@ -368,10 +450,13 @@ class ModelService:
|
|
|
try:
|
|
try:
|
|
|
# 获取模型训练参数
|
|
# 获取模型训练参数
|
|
|
logger.info(f"开始训练模型 ID:{model_id}, 架构:{model.architecture}")
|
|
logger.info(f"开始训练模型 ID:{model_id}, 架构:{model.architecture}")
|
|
|
- if model.architecture in [ModelArchitecture.YOLO_V8, ModelArchitecture.YOLO_V9]:
|
|
|
|
|
- success = self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv8")
|
|
|
|
|
|
|
+ if model.architecture in ULTRALYTICS_ARCHITECTURES:
|
|
|
|
|
+ success = self._run_yolo_training(
|
|
|
|
|
+ model_id, dataset_dir, model.parameters,
|
|
|
|
|
+ model.architecture, ULTRALYTICS_ARCHITECTURES[model.architecture]
|
|
|
|
|
+ )
|
|
|
elif model.architecture == ModelArchitecture.YOLO_V5:
|
|
elif model.architecture == ModelArchitecture.YOLO_V5:
|
|
|
- success = self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv5")
|
|
|
|
|
|
|
+ success = self._run_yolo_training(model_id, dataset_dir, model.parameters, model.architecture, "yolov5s.pt")
|
|
|
else:
|
|
else:
|
|
|
logger.error(f"不支持的模型架构: {model.architecture}")
|
|
logger.error(f"不支持的模型架构: {model.architecture}")
|
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
@@ -397,14 +482,15 @@ class ModelService:
|
|
|
except:
|
|
except:
|
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
- def _run_yolo_training(self, model_id: int, dataset_dir: str, parameters: dict, model_type: str) -> bool:
|
|
|
|
|
- """运行YOLO训练任务,统一处理YOLOv5和YOLOv8的训练逻辑
|
|
|
|
|
|
|
+ def _run_yolo_training(self, model_id: int, dataset_dir: str, parameters: dict, architecture: ModelArchitecture, pretrained_weights: str) -> bool:
|
|
|
|
|
+ """运行YOLO训练任务,统一处理各架构的训练逻辑
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
model_id: 模型ID
|
|
model_id: 模型ID
|
|
|
dataset_dir: 数据集目录
|
|
dataset_dir: 数据集目录
|
|
|
parameters: 训练参数
|
|
parameters: 训练参数
|
|
|
- model_type: "YOLOv5" 或 "YOLOv8"
|
|
|
|
|
|
|
+ architecture: 模型架构
|
|
|
|
|
+ pretrained_weights: 预训练权重文件名
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
bool: 训练是否成功
|
|
bool: 训练是否成功
|
|
@@ -415,15 +501,19 @@ class ModelService:
|
|
|
|
|
|
|
|
# 添加初始日志
|
|
# 添加初始日志
|
|
|
timestamp = time.strftime("%H:%M:%S", time.localtime())
|
|
timestamp = time.strftime("%H:%M:%S", time.localtime())
|
|
|
- TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 开始{model_type}训练...")
|
|
|
|
|
|
|
+ TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 开始{architecture.value}训练...")
|
|
|
|
|
|
|
|
- # 提取训练参数
|
|
|
|
|
- epochs = parameters.get('epochs', 20)
|
|
|
|
|
- batch_size = parameters.get('batch_size', 8)
|
|
|
|
|
- img_size = parameters.get('img_size', 640)
|
|
|
|
|
|
|
+ params = merge_training_params(parameters)
|
|
|
|
|
+ epochs = params['epochs']
|
|
|
|
|
+ batch_size = params['batch_size']
|
|
|
|
|
+ img_size = params['img_size']
|
|
|
|
|
+ device = params['device']
|
|
|
|
|
|
|
|
# 记录训练参数
|
|
# 记录训练参数
|
|
|
- TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 训练参数: epochs={epochs}, batch_size={batch_size}, img_size={img_size}")
|
|
|
|
|
|
|
+ TRAINING_LOGS[model_id].append(
|
|
|
|
|
+ f"[{timestamp}] [系统] 训练参数: epochs={epochs}, batch_size={batch_size}, "
|
|
|
|
|
+ f"img_size={img_size}, lr0={params['lr0']}, optimizer={params['optimizer']}, device={device}"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# 设置训练基本参数
|
|
# 设置训练基本参数
|
|
|
data_yaml_path = os.path.join(dataset_dir, "dataset.yaml")
|
|
data_yaml_path = os.path.join(dataset_dir, "dataset.yaml")
|
|
@@ -434,22 +524,22 @@ class ModelService:
|
|
|
os.makedirs(os.path.join(project_dir, name), exist_ok=True)
|
|
os.makedirs(os.path.join(project_dir, name), exist_ok=True)
|
|
|
|
|
|
|
|
# 准备训练命令 - 不同模型有不同命令格式
|
|
# 准备训练命令 - 不同模型有不同命令格式
|
|
|
- if model_type == "YOLOv8":
|
|
|
|
|
|
|
+ if architecture in ULTRALYTICS_ARCHITECTURES:
|
|
|
cmd = [
|
|
cmd = [
|
|
|
"yolo",
|
|
"yolo",
|
|
|
"task=detect",
|
|
"task=detect",
|
|
|
"mode=train",
|
|
"mode=train",
|
|
|
f"data={os.path.abspath(data_yaml_path)}",
|
|
f"data={os.path.abspath(data_yaml_path)}",
|
|
|
- "model=yolov8n.pt",
|
|
|
|
|
|
|
+ f"model={pretrained_weights}",
|
|
|
f"epochs={epochs}",
|
|
f"epochs={epochs}",
|
|
|
f"batch={batch_size}",
|
|
f"batch={batch_size}",
|
|
|
f"imgsz={img_size}",
|
|
f"imgsz={img_size}",
|
|
|
f"project={project_dir}",
|
|
f"project={project_dir}",
|
|
|
f"name={name}",
|
|
f"name={name}",
|
|
|
"exist_ok=True",
|
|
"exist_ok=True",
|
|
|
- "device=cpu"
|
|
|
|
|
|
|
+ *build_ultralytics_train_args(params),
|
|
|
]
|
|
]
|
|
|
- elif model_type == "YOLOv5":
|
|
|
|
|
|
|
+ elif architecture == ModelArchitecture.YOLO_V5:
|
|
|
# 确保YOLOv5代码存在
|
|
# 确保YOLOv5代码存在
|
|
|
yolov5_dir = "yolov5"
|
|
yolov5_dir = "yolov5"
|
|
|
if not os.path.exists(yolov5_dir):
|
|
if not os.path.exists(yolov5_dir):
|
|
@@ -469,15 +559,25 @@ class ModelService:
|
|
|
"--project", project_dir,
|
|
"--project", project_dir,
|
|
|
"--name", name,
|
|
"--name", name,
|
|
|
"--exist-ok",
|
|
"--exist-ok",
|
|
|
- "--device", "cpu"
|
|
|
|
|
|
|
+ "--device", str(device),
|
|
|
|
|
+ "--lr0", str(params['lr0']),
|
|
|
|
|
+ "--lrf", str(params['lrf']),
|
|
|
|
|
+ "--patience", str(params['patience']),
|
|
|
|
|
+ "--optimizer", str(params['optimizer']),
|
|
|
|
|
+ "--weight-decay", str(params['weight_decay']),
|
|
|
|
|
+ "--momentum", str(params['momentum']),
|
|
|
|
|
+ "--warmup-epochs", str(params['warmup_epochs']),
|
|
|
|
|
+ "--workers", str(params['workers']),
|
|
|
]
|
|
]
|
|
|
|
|
+ if params.get('cache'):
|
|
|
|
|
+ cmd.append("--cache")
|
|
|
else:
|
|
else:
|
|
|
- logger.error(f"不支持的YOLO模型类型: {model_type}")
|
|
|
|
|
|
|
+ logger.error(f"不支持的模型架构: {architecture}")
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
# 打印命令行
|
|
# 打印命令行
|
|
|
cmd_str = ' '.join(cmd)
|
|
cmd_str = ' '.join(cmd)
|
|
|
- logger.info(f"执行{model_type}训练命令: {cmd_str}")
|
|
|
|
|
|
|
+ logger.info(f"执行{architecture.value}训练命令: {cmd_str}")
|
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 执行命令: {cmd_str}")
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] 执行命令: {cmd_str}")
|
|
|
|
|
|
|
|
# 记录开始时间
|
|
# 记录开始时间
|
|
@@ -531,28 +631,28 @@ class ModelService:
|
|
|
self.update_model(model_id,
|
|
self.update_model(model_id,
|
|
|
file_path=best_model_path,
|
|
file_path=best_model_path,
|
|
|
status=ModelStatus.COMPLETED)
|
|
status=ModelStatus.COMPLETED)
|
|
|
- success_msg = f"{model_type}训练完成,耗时: {training_time:.1f}秒,模型保存在: {best_model_path}"
|
|
|
|
|
|
|
+ success_msg = f"{architecture.value}训练完成,耗时: {training_time:.1f}秒,模型保存在: {best_model_path}"
|
|
|
logger.info(success_msg)
|
|
logger.info(success_msg)
|
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] {success_msg}")
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [系统] {success_msg}")
|
|
|
return True
|
|
return True
|
|
|
else:
|
|
else:
|
|
|
- error_msg = f"{model_type}训练完成但模型文件不存在"
|
|
|
|
|
|
|
+ error_msg = f"{architecture.value}训练完成但模型文件不存在"
|
|
|
logger.error(error_msg)
|
|
logger.error(error_msg)
|
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
|
- error_message=f"{model_type} training completed but model file not found")
|
|
|
|
|
|
|
+ error_message=f"{architecture.value} training completed but model file not found")
|
|
|
return False
|
|
return False
|
|
|
else:
|
|
else:
|
|
|
- error_msg = f"{model_type}训练失败,返回码: {returncode}"
|
|
|
|
|
|
|
+ error_msg = f"{architecture.value}训练失败,返回码: {returncode}"
|
|
|
logger.error(error_msg)
|
|
logger.error(error_msg)
|
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
self._update_training_status(model_id, ModelStatus.FAILED,
|
|
|
- error_message=f"{model_type} training failed with return code: {returncode}")
|
|
|
|
|
|
|
+ error_message=f"{architecture.value} training failed with return code: {returncode}")
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
timestamp = time.strftime("%H:%M:%S", time.localtime())
|
|
timestamp = time.strftime("%H:%M:%S", time.localtime())
|
|
|
- error_msg = f"{model_type}训练过程中发生异常: {str(e)}"
|
|
|
|
|
|
|
+ error_msg = f"{architecture.value}训练过程中发生异常: {str(e)}"
|
|
|
logger.error(error_msg, exc_info=True)
|
|
logger.error(error_msg, exc_info=True)
|
|
|
if model_id in TRAINING_LOGS:
|
|
if model_id in TRAINING_LOGS:
|
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
TRAINING_LOGS[model_id].append(f"[{timestamp}] [错误] {error_msg}")
|
|
@@ -1205,10 +1305,13 @@ class ModelService:
|
|
|
"""使用YOLOv8进行训练"""
|
|
"""使用YOLOv8进行训练"""
|
|
|
model = self.get_model(model_id)
|
|
model = self.get_model(model_id)
|
|
|
if model:
|
|
if model:
|
|
|
- self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv8")
|
|
|
|
|
|
|
+ self._run_yolo_training(
|
|
|
|
|
+ model_id, dataset_dir, model.parameters,
|
|
|
|
|
+ ModelArchitecture.YOLO_V8, ULTRALYTICS_ARCHITECTURES[ModelArchitecture.YOLO_V8]
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
def _train_yolov5(self, model_id: int, dataset_dir: str) -> None:
|
|
def _train_yolov5(self, model_id: int, dataset_dir: str) -> None:
|
|
|
"""使用YOLOv5进行训练"""
|
|
"""使用YOLOv5进行训练"""
|
|
|
model = self.get_model(model_id)
|
|
model = self.get_model(model_id)
|
|
|
if model:
|
|
if model:
|
|
|
- self._run_yolo_training(model_id, dataset_dir, model.parameters, "YOLOv5")
|
|
|
|
|
|
|
+ self._run_yolo_training(model_id, dataset_dir, model.parameters, ModelArchitecture.YOLO_V5, "yolov5s.pt")
|