|
|
@@ -0,0 +1,510 @@
|
|
|
+"""YOLOv8 分割模型训练任务管理(后台线程 + 本地 JSON 持久化)。"""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import json
|
|
|
+import os
|
|
|
+import shutil
|
|
|
+import threading
|
|
|
+import traceback
|
|
|
+import uuid
|
|
|
+import zipfile
|
|
|
+from datetime import datetime
|
|
|
+from pathlib import Path
|
|
|
+from zoneinfo import ZoneInfo
|
|
|
+
|
|
|
+import torch
|
|
|
+import yaml
|
|
|
+from ultralytics import YOLO
|
|
|
+from werkzeug.utils import secure_filename
|
|
|
+
|
|
|
+from app.models import Model, db
|
|
|
+
|
|
|
# Time zone applied to every timestamp this module produces (job and dataset
# records, log-line prefixes).
TZ = ZoneInfo('Asia/Shanghai')

# Base checkpoints a training job may start from. `id` is the weight file name
# that start_job() validates against; `label` is the accompanying display text.
BASE_MODELS = [
    {'id': 'yolov8n-seg.pt', 'label': 'YOLOv8n-seg(轻量,适合试跑)'},
    {'id': 'yolov8s-seg.pt', 'label': 'YOLOv8s-seg(均衡)'},
    {'id': 'yolov8m-seg.pt', 'label': 'YOLOv8m-seg(精度更高,耗时更长)'},
]

# Held by _load_json() / _save_json() while a metadata file is read or written.
_lock = threading.Lock()
# Application object recorded by init_training_manager().
_app = None
|
|
|
+
|
|
|
+
|
|
|
def init_training_manager(app):
    """Bind this module to *app* and prepare its working folders.

    Stores the application in the module-level ``_app``, creates the dataset,
    training-run, metadata and model folders named in ``app.config`` when they
    do not exist yet, then hands over to ``resume_stale_jobs``.
    """
    global _app
    _app = app
    folder_keys = (
        'DATASETS_FOLDER',
        'TRAINING_RUNS_FOLDER',
        'TRAINING_META_FOLDER',
        'MODELS_FOLDER',
    )
    for config_key in folder_keys:
        os.makedirs(app.config[config_key], exist_ok=True)
    resume_stale_jobs(app)
|
|
|
+
|
|
|
+
|
|
|
+def _app_for_thread(app):
|
|
|
+ """Flask 应用代理需转为真实对象,否则后台训练线程可能无法执行。"""
|
|
|
+ return app._get_current_object() if hasattr(app, '_get_current_object') else app
|
|
|
+
|
|
|
+
|
|
|
def _start_training_thread(app, job_id: str):
    """Run ``_run_training`` for *job_id* on a new daemon thread."""
    worker = threading.Thread(
        target=_run_training,
        args=(_app_for_thread(app), job_id),
        daemon=True,
    )
    worker.start()
|
|
|
+
|
|
|
+
|
|
|
def _watch_pending_job(app, job_id: str):
    """Start the training thread again if the job has not begun after two seconds.

    A job that is still ``pending`` and has no log lines is treated as never
    having started.
    """
    from time import sleep

    sleep(2)
    job = get_job(app, job_id)
    if not job:
        return
    if job.get('status') != 'pending' or job.get('logs'):
        return
    app.logger.warning('训练任务 %s 未真正启动,正在重试', job_id)
    _start_training_thread(app, job_id)
|
|
|
+
|
|
|
+
|
|
|
def resume_stale_jobs(app):
    """Restart one pending job after the backend has started.

    Training threads do not survive a process restart, so a job persisted as
    ``pending`` would otherwise never run. Nothing is started while any job is
    recorded as ``running``; otherwise the first pending job in ``list_jobs``
    order is started and the remaining ones are left untouched.
    """
    # NOTE(review): a job left in 'running' by a previous process also takes
    # this early exit — confirm that is the intended behaviour after a crash.
    for existing in list_jobs(app):
        if existing.get('status') == 'running':
            return
    pending = next(
        (candidate for candidate in list_jobs(app) if candidate.get('status') == 'pending'),
        None,
    )
    if pending is None:
        return
    job_id = pending['job_id']
    app.logger.info('恢复排队中的训练任务:%s', job_id)
    _start_training_thread(app, job_id)
|
|
|
+
|
|
|
+
|
|
|
def resume_job(app, job_id: str) -> dict:
    """Manually restart a job that has not started running.

    Returns:
        The job record: unchanged when it is already ``running``, otherwise
        re-read after the training thread and its watchdog have been started.

    Raises:
        ValueError: The job does not exist, has already completed or failed,
            or another job is currently recorded as running.
    """
    job = get_job(app, job_id)
    if not job:
        raise ValueError('训练任务不存在')

    status = job.get('status')
    if status == 'running':
        return job
    if status in ('completed', 'failed'):
        raise ValueError(f"任务已结束({job.get('status')}),无法恢复")

    for other in list_jobs(app):
        if other.get('status') == 'running':
            raise ValueError('已有训练任务进行中,请等待完成')

    _start_training_thread(app, job_id)
    watchdog = threading.Thread(
        target=_watch_pending_job,
        args=(_app_for_thread(app), job_id),
        daemon=True,
    )
    watchdog.start()

    refreshed = get_job(app, job_id)
    return refreshed or job
|
|
|
+
|
|
|
+
|
|
|
+def _meta_path(app, name: str) -> Path:
|
|
|
+ return Path(app.config['TRAINING_META_FOLDER']) / name
|
|
|
+
|
|
|
+
|
|
|
+def _json_default(obj):
|
|
|
+ if isinstance(obj, datetime):
|
|
|
+ return obj.isoformat()
|
|
|
+ raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')
|
|
|
+
|
|
|
+
|
|
|
def _load_json(app, name: str, default):
    """Read metadata file *name* and return its decoded JSON content.

    *default* is returned when the file does not exist. When the file exists
    but cannot be read or decoded, its ``.json.bak`` sibling is tried next;
    *default* is returned if that fails as well. The whole read happens while
    ``_lock`` is held.
    """
    with _lock:
        primary = _meta_path(app, name)
        if not primary.is_file():
            return default
        backup = primary.with_suffix('.json.bak')
        for source in (primary, backup):
            if source.is_file():
                try:
                    with open(source, encoding='utf-8') as fh:
                        return json.load(fh)
                except (json.JSONDecodeError, OSError):
                    # Unreadable or corrupt: fall through to the next candidate.
                    pass
        return default
|
|
|
+
|
|
|
+
|
|
|
def _save_json(app, name: str, data):
    """Write *data* as JSON to metadata file *name*.

    The current file, if any, is first copied to a ``.json.bak`` sibling. The
    new content is then written to a ``.json.tmp`` sibling and moved over the
    target, so the target is replaced in a single step. The whole write happens
    while ``_lock`` is held.
    """
    with _lock:
        target = _meta_path(app, name)
        target.parent.mkdir(parents=True, exist_ok=True)
        if target.is_file():
            try:
                shutil.copy2(target, target.with_suffix('.json.bak'))
            except OSError:
                # Best effort: a failed backup must not block the save itself.
                pass
        text = json.dumps(data, ensure_ascii=False, indent=2, default=_json_default)
        scratch = target.with_suffix('.json.tmp')
        scratch.write_text(text, encoding='utf-8')
        scratch.replace(target)
|
|
|
+
|
|
|
+
|
|
|
def _now_iso():
    """Return the current time in ``TZ`` as an ISO-8601 string."""
    current = datetime.now(TZ)
    return current.isoformat()
|
|
|
+
|
|
|
+
|
|
|
def list_base_models():
    """Return the selectable base checkpoints.

    The module-level ``BASE_MODELS`` list itself is returned, not a copy.
    """
    available = BASE_MODELS
    return available
|
|
|
+
|
|
|
+
|
|
|
+def _resolve_data_yaml(data_yaml: str) -> str:
|
|
|
+ """将 data.yaml 的 path 固定为数据集根目录绝对路径,避免 path: . 被解析到错误工作目录。"""
|
|
|
+ src = Path(data_yaml).resolve()
|
|
|
+ if not src.is_file():
|
|
|
+ raise FileNotFoundError(f'未找到 data.yaml:{data_yaml}')
|
|
|
+ root = src.parent
|
|
|
+ with open(src, encoding='utf-8') as f:
|
|
|
+ cfg = yaml.safe_load(f) or {}
|
|
|
+ cfg['path'] = str(root)
|
|
|
+ resolved = root / '_dockscope_data.yaml'
|
|
|
+ with open(resolved, 'w', encoding='utf-8') as f:
|
|
|
+ yaml.dump(cfg, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
|
|
|
+ return str(resolved)
|
|
|
+
|
|
|
+
|
|
|
+def _ensure_base_weights(app, base_model: str) -> Path:
|
|
|
+ models_dir = Path(app.config['MODELS_FOLDER'])
|
|
|
+ target = models_dir / base_model
|
|
|
+ if target.is_file():
|
|
|
+ return target
|
|
|
+ YOLO(base_model)
|
|
|
+ if not target.is_file():
|
|
|
+ downloaded = list(models_dir.glob(base_model))
|
|
|
+ if downloaded:
|
|
|
+ return downloaded[0]
|
|
|
+ cwd_candidate = Path.cwd() / base_model
|
|
|
+ if cwd_candidate.is_file():
|
|
|
+ shutil.copy2(cwd_candidate, target)
|
|
|
+ if not target.is_file():
|
|
|
+ raise FileNotFoundError(f'无法获取基线权重 {base_model},请检查网络或手动放入 models 目录')
|
|
|
+ return target
|
|
|
+
|
|
|
+
|
|
|
+def _find_data_yaml(root: Path) -> Path | None:
|
|
|
+ for name in ('data.yaml', 'dataset.yaml'):
|
|
|
+ for p in root.rglob(name):
|
|
|
+ return p
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
def list_datasets(app):
    """Return the dataset records stored in ``datasets.json`` (``[]`` when none can be loaded)."""
    records = _load_json(app, 'datasets.json', [])
    return records
|
|
|
+
|
|
|
+
|
|
|
def get_dataset(app, dataset_id: str):
    """Return the dataset record whose ``dataset_id`` equals *dataset_id*, or ``None``."""
    return next(
        (record for record in list_datasets(app) if record['dataset_id'] == dataset_id),
        None,
    )
|
|
|
+
|
|
|
+
|
|
|
def upload_dataset(app, zip_file, display_name: str, owner_id: int, owner_username: str):
    """Unpack an uploaded YOLO dataset archive and register it in ``datasets.json``.

    The archive is saved under a fresh folder in ``DATASETS_FOLDER``, extracted
    into its ``extracted`` sub-folder and deleted again. The dataset folder is
    removed whenever the upload is rejected or unpacking fails, so no partial
    dataset is left on disk.

    Args:
        app: Application whose ``config['DATASETS_FOLDER']`` receives the data.
        zip_file: Uploaded file object exposing ``filename`` and ``save(path)``.
        display_name: Name stored on the record; when empty, the name of the
            folder holding the YAML file is used instead.
        owner_id: Stored on the record as ``owner_id``.
        owner_username: Stored on the record as ``owner_username``.

    Returns:
        The new dataset record, which is also inserted at the front of
        ``datasets.json``.

    Raises:
        ValueError: The upload is missing or not a ``.zip`` file, the archive
            is invalid, it contains no ``data.yaml`` / ``dataset.yaml``, or
            that file cannot be parsed into a mapping.
    """
    # An upload object without a filename must be rejected with the same
    # ValueError as any other non-ZIP upload instead of raising AttributeError.
    filename = getattr(zip_file, 'filename', None) or ''
    if not zip_file or not filename.lower().endswith('.zip'):
        raise ValueError('请上传 YOLO 格式数据集的 ZIP 压缩包')

    dataset_id = uuid.uuid4().hex[:12]
    root = Path(app.config['DATASETS_FOLDER']) / dataset_id
    extracted = root / 'extracted'
    root.mkdir(parents=True, exist_ok=True)

    zip_path = root / secure_filename(filename)
    try:
        zip_file.save(zip_path)
        # NOTE(review): the uncompressed size is not limited; consider a cap
        # before accepting archives from untrusted users.
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(extracted)
    except zipfile.BadZipFile as exc:
        shutil.rmtree(root, ignore_errors=True)
        raise ValueError('ZIP 文件损坏或格式无效') from exc
    except Exception:
        # Any other failure (I/O error, unsupported member, ...) must not
        # leave a half-extracted dataset folder behind.
        shutil.rmtree(root, ignore_errors=True)
        raise
    finally:
        if zip_path.is_file():
            zip_path.unlink()

    data_yaml = _find_data_yaml(extracted)
    if not data_yaml:
        shutil.rmtree(root, ignore_errors=True)
        raise ValueError('压缩包内未找到 data.yaml / dataset.yaml,请使用 Ultralytics 标准目录结构')

    try:
        with open(data_yaml, encoding='utf-8') as f:
            parsed = yaml.safe_load(f)
    except Exception as exc:
        shutil.rmtree(root, ignore_errors=True)
        raise ValueError(f'data.yaml 解析失败:{exc}') from exc
    # An empty file is tolerated, but a list or scalar at the top level can
    # never be used as a dataset config, so reject it at upload time.
    if parsed is not None and not isinstance(parsed, dict):
        shutil.rmtree(root, ignore_errors=True)
        raise ValueError('data.yaml 解析失败:顶层结构应为键值映射')

    item = {
        'dataset_id': dataset_id,
        'name': display_name or data_yaml.parent.name,
        'data_yaml': str(data_yaml.resolve()),
        'root_path': str(extracted.resolve()),
        'owner_id': owner_id,
        'owner_username': owner_username,
        'created_at': _now_iso(),
    }
    datasets = list_datasets(app)
    datasets.insert(0, item)
    _save_json(app, 'datasets.json', datasets)
    return item
|
|
|
+
|
|
|
+
|
|
|
def list_jobs(app):
    """Return every job record from ``jobs.json``, newest ``created_at`` first."""

    def created_at(job):
        return job.get('created_at', '')

    ordered = list(_load_json(app, 'jobs.json', []))
    ordered.sort(key=created_at, reverse=True)
    return ordered
|
|
|
+
|
|
|
+
|
|
|
def get_job(app, job_id: str):
    """Return the job record whose ``job_id`` equals *job_id*, or ``None``."""
    return next(
        (record for record in list_jobs(app) if record['job_id'] == job_id),
        None,
    )
|
|
|
+
|
|
|
+
|
|
|
def _update_job(app, job_id: str, **fields):
    """Merge *fields* into the stored job and persist the whole job list.

    ``updated_at`` is refreshed on every successful update.

    Returns:
        The updated job record, or ``None`` when no job has that id (in which
        case nothing is written).
    """
    jobs = list_jobs(app)
    target = next((record for record in jobs if record['job_id'] == job_id), None)
    if target is None:
        return None
    target.update(fields)
    target['updated_at'] = _now_iso()
    _save_json(app, 'jobs.json', jobs)
    return target
|
|
|
+
|
|
|
+
|
|
|
def _append_log(app, job_id: str, line: str):
    """Append a time-stamped *line* to the job's log, keeping the newest 500 entries.

    Does nothing when the job does not exist.
    """
    job = get_job(app, job_id)
    if not job:
        return
    entries = job.get('logs') or []
    stamp = datetime.now(TZ).strftime('%H:%M:%S')
    entries.append(f"[{stamp}] {line}")
    # Slicing is a no-op until the log grows past 500 entries.
    _update_job(app, job_id, logs=entries[-500:])
|
|
|
+
|
|
|
+
|
|
|
+def _parse_results_csv(csv_path: Path) -> dict:
|
|
|
+ if not csv_path.is_file():
|
|
|
+ return {}
|
|
|
+ try:
|
|
|
+ import csv
|
|
|
+
|
|
|
+ with open(csv_path, encoding='utf-8') as f:
|
|
|
+ rows = list(csv.DictReader(f))
|
|
|
+ if not rows:
|
|
|
+ return {}
|
|
|
+ last = rows[-1]
|
|
|
+ return {
|
|
|
+ 'box_p': float(last.get('metrics/precision(B)', 0) or 0),
|
|
|
+ 'box_r': float(last.get('metrics/recall(B)', 0) or 0),
|
|
|
+ 'box_mAP50': float(last.get('metrics/mAP50(B)', 0) or 0),
|
|
|
+ 'box_mAP50_95': float(last.get('metrics/mAP50-95(B)', 0) or 0),
|
|
|
+ 'mask_p': float(last.get('metrics/precision(M)', 0) or 0),
|
|
|
+ 'mask_r': float(last.get('metrics/recall(M)', 0) or 0),
|
|
|
+ 'mask_mAP50': float(last.get('metrics/mAP50(M)', 0) or 0),
|
|
|
+ 'mask_mAP50_95': float(last.get('metrics/mAP50-95(M)', 0) or 0),
|
|
|
+ }
|
|
|
+ except Exception:
|
|
|
+ return {}
|
|
|
+
|
|
|
+
|
|
|
def _register_model(app, job: dict, weights_path: Path, metrics: dict):
    """Copy trained weights into the model library and create their ``Model`` row.

    Args:
        app: Application whose ``config['MODELS_FOLDER']`` receives the weights.
        job: Job record; ``output_model_name``, ``job_id``, ``disease_category``
            and ``owner_id`` are required, ``augmentation`` is optional.
        weights_path: Checkpoint produced by the training run.
        metrics: Output of ``_parse_results_csv``; missing keys count as 0.

    Returns:
        ``record.to_dict()`` of the new row, with ``created_at`` /
        ``updated_at`` converted to ISO strings when they are date-like.
    """
    models_dir = Path(app.config['MODELS_FOLDER'])
    model_name = job['output_model_name']
    if not model_name.endswith('.pt'):
        model_name = f'{model_name}.pt'

    # Settle the final name BEFORE copying. Copying first would overwrite the
    # weight file of the model (or base checkpoint) that already owns the name.
    name_taken = (
        Model.query.filter_by(model_name=model_name).first() is not None
        or (models_dir / model_name).exists()
    )
    if name_taken:
        model_name = f"{Path(model_name).stem}_{job['job_id'][:6]}.pt"
    dest = models_dir / model_name
    shutil.copy2(weights_path, dest)
    rel_path = f'static/models/{model_name}'.replace('\\', '/')

    m = YOLO(str(dest))
    layers = 0
    params = 0
    try:
        layers = len(list(m.model.modules()))
        params = sum(p.numel() for p in m.model.parameters())
    except Exception:
        # Model statistics are informational only; registration continues
        # with the fallback values below when they cannot be computed.
        pass

    mask_p = metrics.get('mask_p', 0.0)
    mask_r = metrics.get('mask_r', 0.0)
    mask_mAP50 = metrics.get('mask_mAP50', 0.0)
    fitness = round(mask_mAP50 * 2 + metrics.get('box_mAP50', 0.0), 5)
    # F1 is the harmonic mean of precision and recall, not their sum.
    pr_sum = mask_p + mask_r
    f1 = 2 * mask_p * mask_r / pr_sum if pr_sum else 0.0

    record = Model(
        model_name=model_name,
        model_path=rel_path,
        disease_category=job['disease_category'],
        augmentation=job.get('augmentation') or '随机点+颜色扭曲+高斯模糊',
        # NOTE(review): 113 and 35.3 are fixed fallback/placeholder values,
        # not figures computed from the trained model.
        layers=layers or 113,
        parameters=params or 0,
        GFLOPs=35.3,
        box_p=round(metrics.get('box_p', 0.0), 3),
        box_r=round(metrics.get('box_r', 0.0), 3),
        box_mAP50=round(metrics.get('box_mAP50', 0.0), 3),
        box_mAP50_95=round(metrics.get('box_mAP50_95', 0.0), 3),
        mask_p=round(mask_p, 3),
        mask_r=round(mask_r, 3),
        mask_mAP50=round(mask_mAP50, 3),
        mask_mAP50_95=round(metrics.get('mask_mAP50_95', 0.0), 3),
        f1_score=round(f1, 5),
        fitness_score=fitness,
        owner_id=job['owner_id'],
    )
    db.session.add(record)
    db.session.commit()

    registered = record.to_dict()
    # The result is stored in jobs.json, so date-like values become strings.
    for key in ('created_at', 'updated_at'):
        val = registered.get(key)
        if val is not None and hasattr(val, 'isoformat'):
            registered[key] = val.isoformat()
    return registered
|
|
|
+
|
|
|
+
|
|
|
def _run_training(app, job_id: str) -> None:
    """Thread target: run one training job from start to finish.

    Everything runs inside ``app.app_context()``. The job record is moved to
    ``running``, the model is trained with the parameters stored on the job,
    the resulting weights are optionally registered in the model library, and
    the record ends up ``completed`` or ``failed``. Exceptions are recorded on
    the job rather than propagated. Returns silently when the job is unknown.
    """
    with app.app_context():
        # Snapshot of the job as stored when the thread starts; later
        # _update_job() calls do not refresh this dict.
        job = get_job(app, job_id)
        if not job:
            return
        try:
            _update_job(app, job_id, status='running', progress=0, started_at=_now_iso())
            _append_log(app, job_id, '开始加载基线权重…')

            base_path = _ensure_base_weights(app, job['base_model'])
            # Training uses the rewritten config whose `path` is absolute.
            data_yaml = _resolve_data_yaml(job['data_yaml'])
            _append_log(app, job_id, f'数据集配置:{data_yaml}')
            use_cuda = torch.cuda.is_available()
            device = 'cuda' if use_cuda else 'cpu'
            _append_log(app, job_id, f'训练设备:{device}')

            model = YOLO(str(base_path))
            runs_root = Path(app.config['TRAINING_RUNS_FOLDER'])
            run_name = job['run_name']

            _append_log(app, job_id, f"启动 YOLOv8 分割训练:epochs={job['epochs']}, imgsz={job['imgsz']}")

            def on_epoch_end(trainer):
                # trainer.epoch is incremented for display, i.e. treated as
                # zero-based here.
                epoch = trainer.epoch + 1
                total = trainer.epochs
                # Capped at 99: only the final 'completed' update sets 100.
                progress = min(99, int(epoch / total * 100))
                _update_job(app, job_id, progress=progress, current_epoch=epoch, total_epochs=total)
                _append_log(app, job_id, f'Epoch {epoch}/{total} 完成')

            model.add_callback('on_train_epoch_end', on_epoch_end)

            model.train(
                data=data_yaml,
                epochs=int(job['epochs']),
                imgsz=int(job['imgsz']),
                batch=int(job['batch']),
                project=str(runs_root),
                name=run_name,
                exist_ok=True,
                device=device,
                patience=int(job.get('patience', 20)),
                verbose=True,
            )

            # Outputs are expected under <TRAINING_RUNS_FOLDER>/<run_name>,
            # matching the project/name arguments passed to train() above.
            run_dir = runs_root / run_name
            best_pt = run_dir / 'weights' / 'best.pt'
            if not best_pt.is_file():
                # Fall back to the last checkpoint when no best one exists.
                best_pt = run_dir / 'weights' / 'last.pt'
            if not best_pt.is_file():
                raise FileNotFoundError('训练结束但未找到 weights/best.pt')

            # Empty dict when results.csv is missing or unreadable.
            metrics = _parse_results_csv(run_dir / 'results.csv')
            _append_log(app, job_id, f'训练完成,best 权重:{best_pt.name}')

            registered = None
            if job.get('register_to_library', True):
                registered = _register_model(app, job, best_pt, metrics)
                _append_log(app, job_id, f"已写入模型库:{registered['model_name']}")

            _update_job(
                app,
                job_id,
                status='completed',
                progress=100,
                finished_at=_now_iso(),
                best_weights=str(best_pt.resolve()),
                metrics=metrics,
                registered_model=registered,
            )
        except Exception as exc:
            # Top-level boundary of the worker thread: record the failure on
            # the job (message plus the last 2000 characters of the traceback)
            # instead of letting the exception escape the thread.
            tb = traceback.format_exc()
            _append_log(app, job_id, f'训练失败:{exc}')
            _update_job(
                app,
                job_id,
                status='failed',
                finished_at=_now_iso(),
                error_message=str(exc),
                error_trace=tb[-2000:],
            )
|
|
|
+
|
|
|
+
|
|
|
def start_job(
    app,
    *,
    owner_id: int,
    owner_username: str,
    dataset_id: str,
    base_model: str,
    disease_category: str,
    output_model_name: str,
    epochs: int = 50,
    imgsz: int = 640,
    batch: int = 8,
    augmentation: str = '随机点+颜色扭曲+高斯模糊',
    register_to_library: bool = True,
):
    """Validate the request, persist a new ``pending`` job and start training.

    Args:
        app: Application (or application proxy) the job runs under.
        owner_id: Stored on the job and on the registered model.
        owner_username: Stored on the job record.
        dataset_id: Id of a dataset previously stored by ``upload_dataset``.
        base_model: One of the ``id`` values in ``BASE_MODELS``.
        disease_category: Non-blank category label for the resulting model.
        output_model_name: Desired weight file name; ``.pt`` is appended when
            missing and the name is sanitised with ``secure_filename``.
        epochs: Clamped to 1..500.
        imgsz: Clamped to 320..1280.
        batch: Clamped to 1..64.
        augmentation: Free-text description stored with the job.
        register_to_library: Whether the result is added to the model library.

    Returns:
        The job record as it was written to ``jobs.json``.

    Raises:
        ValueError: Unknown dataset, unsupported base model, another job is
            recorded as running, or a blank category / output name.
    """
    dataset = get_dataset(app, dataset_id)
    if not dataset:
        raise ValueError('数据集不存在')

    allowed = {m['id'] for m in BASE_MODELS}
    if base_model not in allowed:
        raise ValueError('不支持的基线模型')

    if any(existing.get('status') == 'running' for existing in list_jobs(app)):
        raise ValueError('已有训练任务进行中,请等待完成后再启动')

    # `or ''` makes a missing (None) field fail with the same validation
    # error as a blank one instead of raising AttributeError.
    category = (disease_category or '').strip()
    if not category:
        raise ValueError('请填写隐患类别')
    raw_name = (output_model_name or '').strip()
    if not raw_name:
        raise ValueError('请填写输出模型文件名')
    if not raw_name.lower().endswith('.pt'):
        raw_name = f'{raw_name}.pt'
    safe_name = secure_filename(raw_name)
    if len(safe_name) > 3 and safe_name.lower().endswith('.pt'):
        # Normalise the extension so a name such as 'Model.PT' keeps the
        # user's stem instead of being replaced by a generated name.
        safe_name = f'{safe_name[:-3]}.pt'
    else:
        # Sanitising left nothing usable (for example a non-ASCII-only name).
        safe_name = f"bridge_seg_{uuid.uuid4().hex[:8]}.pt"

    job_id = uuid.uuid4().hex[:12]
    run_name = f"train_{job_id}"
    # One clock read so created_at and updated_at start out identical.
    now = _now_iso()
    job = {
        'job_id': job_id,
        'run_name': run_name,
        'status': 'pending',
        'progress': 0,
        'owner_id': owner_id,
        'owner_username': owner_username,
        'dataset_id': dataset_id,
        'dataset_name': dataset['name'],
        'data_yaml': dataset['data_yaml'],
        'base_model': base_model,
        'disease_category': category,
        'output_model_name': safe_name,
        'epochs': max(1, min(int(epochs), 500)),
        'imgsz': max(320, min(int(imgsz), 1280)),
        'batch': max(1, min(int(batch), 64)),
        'augmentation': augmentation,
        'register_to_library': bool(register_to_library),
        'logs': [],
        'created_at': now,
        'updated_at': now,
    }

    jobs = list_jobs(app)
    jobs.insert(0, job)
    _save_json(app, 'jobs.json', jobs)

    _start_training_thread(app, job_id)
    # Watchdog: starts the training thread again if the job has not begun.
    threading.Thread(target=_watch_pending_job, args=(_app_for_thread(app), job_id), daemon=True).start()
    return job
|