| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510 |
- """YOLOv8 分割模型训练任务管理(后台线程 + 本地 JSON 持久化)。"""
- from __future__ import annotations
- import json
- import os
- import shutil
- import threading
- import traceback
- import uuid
- import zipfile
- from datetime import datetime
- from pathlib import Path
- from zoneinfo import ZoneInfo
- import torch
- import yaml
- from ultralytics import YOLO
- from werkzeug.utils import secure_filename
- from app.models import Model, db
# Timezone used for every timestamp written to the jobs/datasets JSON files.
TZ = ZoneInfo('Asia/Shanghai')

# Base checkpoints selectable in the UI; ``start_job`` only accepts these ids.
BASE_MODELS = [
    {'id': checkpoint, 'label': label}
    for checkpoint, label in (
        ('yolov8n-seg.pt', 'YOLOv8n-seg(轻量,适合试跑)'),
        ('yolov8s-seg.pt', 'YOLOv8s-seg(均衡)'),
        ('yolov8m-seg.pt', 'YOLOv8m-seg(精度更高,耗时更长)'),
    )
]

# Serialises every read and write of the JSON metadata files (_load_json / _save_json).
_lock = threading.Lock()
# Flask app recorded by init_training_manager().
_app = None
def init_training_manager(app):
    """Bind the manager to *app*, create its storage folders and resume pending work.

    Args:
        app: Flask application; its config must define DATASETS_FOLDER,
            TRAINING_RUNS_FOLDER, TRAINING_META_FOLDER and MODELS_FOLDER.
    """
    global _app
    _app = app
    folder_keys = (
        'DATASETS_FOLDER',
        'TRAINING_RUNS_FOLDER',
        'TRAINING_META_FOLDER',
        'MODELS_FOLDER',
    )
    for config_key in folder_keys:
        os.makedirs(app.config[config_key], exist_ok=True)
    resume_stale_jobs(app)
def _app_for_thread(app):
    """Unwrap a Flask app proxy to the real application object.

    A proxy such as ``current_app`` must be converted to the real object before
    it is handed to a background thread, otherwise the training thread may not
    be able to run. Objects without ``_get_current_object`` are returned as-is.
    """
    try:
        unwrap = app._get_current_object
    except AttributeError:
        return app
    return unwrap()
def _start_training_thread(app, job_id: str):
    """Launch ``_run_training`` for *job_id* on a daemon thread (fire-and-forget)."""
    worker = threading.Thread(
        target=_run_training,
        args=(_app_for_thread(app), job_id),
        daemon=True,
    )
    worker.start()
def _watch_pending_job(app, job_id: str):
    """Relaunch the training thread if the job still looks unstarted after 2 seconds.

    "Unstarted" means the job is still ``pending`` and has no log lines yet.

    NOTE(review): this only inspects jobs.json, not whether the first thread is
    alive. A first thread that is merely slow (still before its first status
    write) gets a second thread for the same job.
    """
    import time

    time.sleep(2)
    job = get_job(app, job_id)
    stalled = bool(job) and job.get('status') == 'pending' and not job.get('logs')
    if not stalled:
        return
    app.logger.warning('训练任务 %s 未真正启动,正在重试', job_id)
    _start_training_thread(app, job_id)
def resume_stale_jobs(app):
    """Restart one pending job after backend startup.

    Training threads are daemons of the web process, so a restart loses them
    while jobs.json still lists the job as ``pending``. This starts a thread for
    the first pending job in ``list_jobs`` order (newest ``created_at`` first),
    unless some job is already marked ``running``.

    NOTE(review): a job that was ``running`` when the process died keeps that
    status forever. This function then returns early, and ``start_job`` and
    ``resume_job`` refuse new work until jobs.json is edited. Consider marking
    such jobs failed at startup once a single-process deployment is confirmed.
    """
    # Take one snapshot so the "anything running?" check and the pending scan
    # agree. Previously jobs.json was read twice.
    jobs = list_jobs(app)
    if any(j.get('status') == 'running' for j in jobs):
        return
    pending = next((j for j in jobs if j.get('status') == 'pending'), None)
    if pending is None:
        return
    job_id = pending['job_id']
    app.logger.info('恢复排队中的训练任务:%s', job_id)
    _start_training_thread(app, job_id)
def resume_job(app, job_id: str) -> dict:
    """Manually restart a pending job (used by front-end refresh or the API).

    Returns:
        The job dict. It is re-read after launching when possible, and returned
        unchanged if the job is already running.

    Raises:
        ValueError: unknown job, job already completed/failed, or another job
            is currently running.
    """
    job = get_job(app, job_id)
    if not job:
        raise ValueError('训练任务不存在')

    status = job.get('status')
    if status == 'running':
        return job
    if status in ('completed', 'failed'):
        raise ValueError(f"任务已结束({job.get('status')}),无法恢复")

    busy = any(other.get('status') == 'running' for other in list_jobs(app))
    if busy:
        raise ValueError('已有训练任务进行中,请等待完成')

    _start_training_thread(app, job_id)
    watcher = threading.Thread(
        target=_watch_pending_job,
        args=(_app_for_thread(app), job_id),
        daemon=True,
    )
    watcher.start()
    refreshed = get_job(app, job_id)
    return refreshed or job
def _meta_path(app, name: str) -> Path:
    """Return the path of metadata file *name* inside TRAINING_META_FOLDER."""
    meta_folder = Path(app.config['TRAINING_META_FOLDER'])
    return meta_folder.joinpath(name)
def _json_default(obj):
    """``json.dumps`` fallback: serialise datetimes as ISO strings, reject anything else.

    Raises:
        TypeError: *obj* is not a ``datetime`` (mirrors the stdlib message).
    """
    if not isinstance(obj, datetime):
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')
    return obj.isoformat()
def _load_json(app, name: str, default):
    """Read a JSON metadata file from TRAINING_META_FOLDER.

    Falls back to the ``.json.bak`` copy (written by ``_save_json``) when the
    primary file exists but cannot be read or decoded.

    Args:
        app: Flask app (only ``app.config`` is read).
        name: File name inside the metadata folder, e.g. ``'jobs.json'``.
        default: Returned when the primary file does not exist, or when neither
            the primary nor the backup can be decoded.

    Returns:
        The decoded JSON value, or ``default``.
    """
    with _lock:
        path = _meta_path(app, name)
        # A missing primary means "no data yet". The backup is only consulted
        # when the primary exists but is unreadable.
        if not path.is_file():
            return default
        for candidate in (path, path.with_suffix('.json.bak')):
            if not candidate.is_file():
                continue
            try:
                with open(candidate, encoding='utf-8') as f:
                    return json.load(f)
            except (ValueError, OSError):
                # ValueError covers json.JSONDecodeError and also
                # UnicodeDecodeError (corrupted bytes). The latter previously
                # escaped and prevented the fallback to the backup.
                continue
        return default
def _save_json(app, name: str, data):
    """Atomically write *data* as JSON to metadata file *name*.

    The current file, if any, is first copied to ``<name>.bak`` (best-effort).
    The new content is written to ``<name>.tmp`` and then moved over the
    original with ``Path.replace``, so readers never see a half-written file.
    """
    with _lock:
        target = _meta_path(app, name)
        target.parent.mkdir(parents=True, exist_ok=True)
        if target.is_file():
            try:
                shutil.copy2(target, target.with_suffix('.json.bak'))
            except OSError:
                pass  # the backup is optional; never block the real write on it
        staging = target.with_suffix('.json.tmp')
        text = json.dumps(data, ensure_ascii=False, indent=2, default=_json_default)
        staging.write_text(text, encoding='utf-8')
        staging.replace(target)
def _now_iso():
    """Return the current time in ``TZ`` (Asia/Shanghai) as an ISO-8601 string with offset."""
    current = datetime.now(tz=TZ)
    return current.isoformat()
def list_base_models():
    """Return the selectable base checkpoints as ``[{'id': ..., 'label': ...}, ...]``.

    Fresh dict copies are returned so callers cannot mutate the module-level
    ``BASE_MODELS``, which ``start_job`` also uses as its allow-list.
    """
    return [dict(entry) for entry in BASE_MODELS]
def _resolve_data_yaml(data_yaml: str) -> str:
    """Write a copy of the dataset YAML whose ``path`` is the dataset root's absolute path.

    Pinning ``path`` to the YAML's own directory prevents a relative value such
    as ``path: .`` from being resolved against the wrong working directory. The
    copy is written next to the original as ``_dockscope_data.yaml``.

    Args:
        data_yaml: Path to the uploaded ``data.yaml`` / ``dataset.yaml``.

    Returns:
        The path of the rewritten YAML, as a string.

    Raises:
        FileNotFoundError: *data_yaml* does not exist.
        ValueError: the YAML top level is not a mapping. Previously this surfaced
            as an opaque "object does not support item assignment" TypeError.
    """
    src = Path(data_yaml).resolve()
    if not src.is_file():
        raise FileNotFoundError(f'未找到 data.yaml:{data_yaml}')
    root = src.parent
    with open(src, encoding='utf-8') as f:
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(f'data.yaml 顶层必须是键值映射:{data_yaml}')
    cfg['path'] = str(root)
    resolved = root / '_dockscope_data.yaml'
    with open(resolved, 'w', encoding='utf-8') as f:
        yaml.dump(cfg, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
    return str(resolved)
def _ensure_base_weights(app, base_model: str) -> Path:
    """Return the path of base checkpoint *base_model* inside MODELS_FOLDER.

    If ``MODELS_FOLDER/<base_model>`` is missing, the model is loaded with
    ``YOLO(base_model)``, which is expected to download it (hence the network
    hint in the error message). The resulting file is then copied into
    MODELS_FOLDER.

    Args:
        app: Flask app; reads ``app.config['MODELS_FOLDER']``.
        base_model: Checkpoint file name, e.g. ``'yolov8n-seg.pt'``.

    Raises:
        FileNotFoundError: the checkpoint could not be located after loading.
    """
    models_dir = Path(app.config['MODELS_FOLDER'])
    target = models_dir / base_model
    if target.is_file():
        return target

    model = YOLO(base_model)
    # Possible locations of the fetched file: the checkpoint path the loaded
    # model reports (if exposed), then the current working directory. The old
    # models_dir.glob(base_model) check was a no-op duplicate of target.is_file().
    candidates = []
    ckpt_path = getattr(model, 'ckpt_path', None)
    if ckpt_path:
        candidates.append(Path(ckpt_path))
    candidates.append(Path.cwd() / base_model)
    for source in candidates:
        if source.is_file():
            if source.resolve() != target.resolve():
                shutil.copy2(source, target)
            break

    if not target.is_file():
        raise FileNotFoundError(f'无法获取基线权重 {base_model},请检查网络或手动放入 models 目录')
    return target
- def _find_data_yaml(root: Path) -> Path | None:
- for name in ('data.yaml', 'dataset.yaml'):
- for p in root.rglob(name):
- return p
- return None
def list_datasets(app):
    """Return all registered datasets in stored order (``[]`` when datasets.json is absent)."""
    return _load_json(app, name='datasets.json', default=[])
def get_dataset(app, dataset_id: str):
    """Return the dataset dict with the given ``dataset_id``, or ``None``."""
    matches = (ds for ds in list_datasets(app) if ds['dataset_id'] == dataset_id)
    return next(matches, None)
def upload_dataset(app, zip_file, display_name: str, owner_id: int, owner_username: str):
    """Store an uploaded YOLO-format dataset ZIP and register it in datasets.json.

    The archive is saved under ``DATASETS_FOLDER/<dataset_id>/``, extracted into
    ``extracted/`` and then deleted. The dataset YAML is located and checked
    for parseability.

    Args:
        app: Flask app; reads ``app.config['DATASETS_FOLDER']``.
        zip_file: Uploaded file object exposing ``filename`` and ``save()``.
        display_name: Optional name; defaults to the YAML's parent folder name.
        owner_id: Id of the uploading user.
        owner_username: Username of the uploading user.

    Returns:
        The new dataset record (also prepended to datasets.json).

    Raises:
        ValueError: missing/non-``.zip`` upload, corrupt archive, no
            ``data.yaml``/``dataset.yaml`` found, or YAML that fails to parse.
            Other errors from saving/extracting propagate unchanged. In every
            failure case the dataset folder is removed. Previously a failed
            save or a non-BadZipFile extraction error left it behind.
    """
    if not zip_file or not zip_file.filename.lower().endswith('.zip'):
        raise ValueError('请上传 YOLO 格式数据集的 ZIP 压缩包')
    dataset_id = uuid.uuid4().hex[:12]
    root = Path(app.config['DATASETS_FOLDER']) / dataset_id
    root.mkdir(parents=True, exist_ok=True)
    extracted = root / 'extracted'
    try:
        data_yaml = _store_and_validate_zip(zip_file, root, extracted)
    except Exception:
        shutil.rmtree(root, ignore_errors=True)
        raise
    item = {
        'dataset_id': dataset_id,
        'name': display_name or data_yaml.parent.name,
        'data_yaml': str(data_yaml.resolve()),
        'root_path': str(extracted.resolve()),
        'owner_id': owner_id,
        'owner_username': owner_username,
        'created_at': _now_iso(),
    }
    datasets = list_datasets(app)
    datasets.insert(0, item)
    _save_json(app, 'datasets.json', datasets)
    return item


def _store_and_validate_zip(zip_file, root: Path, extracted: Path) -> Path:
    """Save the upload in *root*, extract it to *extracted* and return the dataset YAML path."""
    zip_path = root / secure_filename(zip_file.filename)
    zip_file.save(zip_path)
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(extracted)
    except zipfile.BadZipFile as exc:
        raise ValueError('ZIP 文件损坏或格式无效') from exc
    finally:
        # The archive is only needed for extraction.
        zip_path.unlink(missing_ok=True)
    data_yaml = _find_data_yaml(extracted)
    if not data_yaml:
        raise ValueError('压缩包内未找到 data.yaml / dataset.yaml,请使用 Ultralytics 标准目录结构')
    try:
        with open(data_yaml, encoding='utf-8') as f:
            yaml.safe_load(f)
    except Exception as exc:
        raise ValueError(f'data.yaml 解析失败:{exc}') from exc
    return data_yaml
def list_jobs(app):
    """Return all jobs from jobs.json, newest ``created_at`` first."""
    def _created(job):
        return job.get('created_at', '')

    return sorted(_load_json(app, 'jobs.json', []), key=_created, reverse=True)
def get_job(app, job_id: str):
    """Return the job dict with the given ``job_id``, or ``None``."""
    matches = (job for job in list_jobs(app) if job['job_id'] == job_id)
    return next(matches, None)
def _update_job(app, job_id: str, **fields):
    """Merge *fields* into job *job_id*, stamp ``updated_at`` and persist jobs.json.

    Returns:
        The updated job dict, or ``None`` if no job has that id.

    NOTE(review): ``list_jobs`` and ``_save_json`` take ``_lock`` separately, so
    this read-modify-write is not atomic. Concurrent writers (the training
    thread vs. request handlers) can overwrite each other's changes.
    """
    jobs = list_jobs(app)
    target = next((j for j in jobs if j['job_id'] == job_id), None)
    if target is None:
        return None
    target.update(fields)
    target['updated_at'] = _now_iso()
    _save_json(app, 'jobs.json', jobs)
    return target
def _append_log(app, job_id: str, line: str):
    """Append a ``[HH:MM:SS]``-stamped line to the job's log, keeping the last 500 lines.

    Unknown job ids are ignored silently.
    """
    job = get_job(app, job_id)
    if not job:
        return
    stamp = datetime.now(TZ).strftime('%H:%M:%S')
    history = job.get('logs') or []
    history.append(f"[{stamp}] {line}")
    # Slicing is a no-op for short logs and trims the oldest lines otherwise.
    _update_job(app, job_id, logs=history[-500:])
def _parse_results_csv(csv_path: Path) -> dict:
    """Extract the final-epoch box/mask metrics from an Ultralytics ``results.csv``.

    Args:
        csv_path: Path to the run's ``results.csv``.

    Returns:
        Dict with keys ``box_p``, ``box_r``, ``box_mAP50``, ``box_mAP50_95``,
        ``mask_p``, ``mask_r``, ``mask_mAP50``, ``mask_mAP50_95`` taken from the
        last row. Missing or empty columns become 0.0. Returns ``{}`` when the
        file is missing, has no data rows, or cannot be parsed (best-effort).
    """
    import csv

    if not csv_path.is_file():
        return {}
    columns = (
        ('box_p', 'metrics/precision(B)'),
        ('box_r', 'metrics/recall(B)'),
        ('box_mAP50', 'metrics/mAP50(B)'),
        ('box_mAP50_95', 'metrics/mAP50-95(B)'),
        ('mask_p', 'metrics/precision(M)'),
        ('mask_r', 'metrics/recall(M)'),
        ('mask_mAP50', 'metrics/mAP50(M)'),
        ('mask_mAP50_95', 'metrics/mAP50-95(M)'),
    )
    try:
        # The csv module requires files to be opened with newline=''.
        with open(csv_path, encoding='utf-8', newline='') as f:
            rows = list(csv.DictReader(f))
        if not rows:
            return {}
        # Older Ultralytics releases wrote space-padded header cells, which made
        # exact-name lookups miss and every metric read as 0, so keys are
        # stripped. Surplus cells land under the key None and are skipped.
        last = {key.strip(): value for key, value in rows[-1].items() if key is not None}
        return {out_key: float(last.get(column) or 0) for out_key, column in columns}
    except Exception:
        # Metrics are optional; an odd file must not fail the training job.
        return {}
def _register_model(app, job: dict, weights_path: Path, metrics: dict):
    """Copy trained weights into MODELS_FOLDER and insert a ``Model`` library row.

    Args:
        app: Flask app; ``app.config['MODELS_FOLDER']`` is the destination folder.
        job: Job dict; reads ``output_model_name``, ``job_id``,
            ``disease_category``, ``augmentation`` and ``owner_id``.
        weights_path: Trained checkpoint to copy (``best.pt`` or ``last.pt``).
        metrics: Dict from ``_parse_results_csv``; missing keys count as 0.

    Returns:
        ``record.to_dict()`` with ``created_at``/``updated_at`` as ISO strings.

    Raises:
        Any error from copying, loading the weights with ``YOLO`` or committing
        the DB session. On failure after the copy, the session is rolled back
        and the copied file is removed.
    """
    models_dir = Path(app.config['MODELS_FOLDER'])
    model_name = job['output_model_name']
    if not model_name.endswith('.pt'):
        model_name = f'{model_name}.pt'

    def name_taken(candidate: str) -> bool:
        # Check the file too, not only the DB: the folder can hold files with no
        # Model row (e.g. base weights) that must never be overwritten.
        return (models_dir / candidate).exists() or (
            Model.query.filter_by(model_name=candidate).first() is not None
        )

    # Choose the final name BEFORE copying. Previously the weights were first
    # copied under the requested name, clobbering an existing model's file,
    # and only then renamed.
    if name_taken(model_name):
        stem = f"{Path(model_name).stem}_{job['job_id'][:6]}"
        model_name = f'{stem}.pt'
        counter = 1
        while name_taken(model_name):
            model_name = f'{stem}_{counter}.pt'
            counter += 1

    dest = models_dir / model_name
    rel_path = f'static/models/{model_name}'.replace('\\', '/')
    shutil.copy2(weights_path, dest)
    try:
        m = YOLO(str(dest))
        layers = 0
        params = 0
        try:
            layers = len(list(m.model.modules()))
            params = sum(t.numel() for t in m.model.parameters())
        except Exception:
            pass  # model statistics are optional; defaults are applied below

        mask_p = float(metrics.get('mask_p', 0.0) or 0.0)
        mask_r = float(metrics.get('mask_r', 0.0) or 0.0)
        mask_mAP50 = metrics.get('mask_mAP50', 0.0)
        # F1 is the harmonic mean of precision and recall. The old value (P + R)
        # was not an F1 score and could exceed 1.
        f1 = 2 * mask_p * mask_r / (mask_p + mask_r) if mask_p + mask_r > 0 else 0.0
        # NOTE(review): project-specific weighting kept as-is; it is not the
        # Ultralytics built-in fitness — confirm intended.
        fitness = round(mask_mAP50 * 2 + metrics.get('box_mAP50', 0.0), 5)
        record = Model(
            model_name=model_name,
            model_path=rel_path,
            disease_category=job['disease_category'],
            augmentation=job.get('augmentation') or '随机点+颜色扭曲+高斯模糊',
            # NOTE(review): 113 layers / 35.3 GFLOPs are hard-coded placeholders,
            # not computed from this model.
            layers=layers or 113,
            parameters=params or 0,
            GFLOPs=35.3,
            box_p=round(metrics.get('box_p', 0.0), 3),
            box_r=round(metrics.get('box_r', 0.0), 3),
            box_mAP50=round(metrics.get('box_mAP50', 0.0), 3),
            box_mAP50_95=round(metrics.get('box_mAP50_95', 0.0), 3),
            mask_p=round(mask_p, 3),
            mask_r=round(mask_r, 3),
            mask_mAP50=round(mask_mAP50, 3),
            mask_mAP50_95=round(metrics.get('mask_mAP50_95', 0.0), 3),
            f1_score=round(f1, 5),
            fitness_score=fitness,
            owner_id=job['owner_id'],
        )
        db.session.add(record)
        db.session.commit()
    except Exception:
        # Leave neither a half-written session nor an orphan weights file.
        db.session.rollback()
        dest.unlink(missing_ok=True)
        raise

    registered = record.to_dict()
    for key in ('created_at', 'updated_at'):
        val = registered.get(key)
        if val is not None and hasattr(val, 'isoformat'):
            registered[key] = val.isoformat()
    return registered
def _run_training(app, job_id: str):
    """Background-thread entry point: train one job and record the outcome.

    Runs inside ``app.app_context()`` so that ``_register_model`` can use the DB
    session. Progress, log lines and the final status (``completed`` or
    ``failed``, with error details) are written to jobs.json.

    Only a job still in ``pending`` state is started. ``start_job``,
    ``resume_job``, ``resume_stale_jobs`` and ``_watch_pending_job`` all launch
    threads for pending jobs. Without this guard, a late duplicate thread would
    retrain a job that is already running or finished. The check-then-set is
    not atomic, so this narrows the race rather than eliminating it.
    """
    with app.app_context():
        job = get_job(app, job_id)
        if not job or job.get('status') != 'pending':
            return
        try:
            _update_job(app, job_id, status='running', progress=0, started_at=_now_iso())
            _append_log(app, job_id, '开始加载基线权重…')
            base_path = _ensure_base_weights(app, job['base_model'])
            data_yaml = _resolve_data_yaml(job['data_yaml'])
            _append_log(app, job_id, f'数据集配置:{data_yaml}')
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            _append_log(app, job_id, f'训练设备:{device}')
            model = YOLO(str(base_path))
            runs_root = Path(app.config['TRAINING_RUNS_FOLDER'])
            run_name = job['run_name']
            _append_log(app, job_id, f"启动 YOLOv8 分割训练:epochs={job['epochs']}, imgsz={job['imgsz']}")

            def on_epoch_end(trainer):
                # trainer.epoch is treated as 0-based. Progress is capped at 99
                # so that only the final "completed" update reports 100.
                epoch = trainer.epoch + 1
                total = trainer.epochs
                progress = min(99, int(epoch / total * 100))
                _update_job(app, job_id, progress=progress, current_epoch=epoch, total_epochs=total)
                _append_log(app, job_id, f'Epoch {epoch}/{total} 完成')

            model.add_callback('on_train_epoch_end', on_epoch_end)
            model.train(
                data=data_yaml,
                epochs=int(job['epochs']),
                imgsz=int(job['imgsz']),
                batch=int(job['batch']),
                project=str(runs_root),
                name=run_name,
                exist_ok=True,
                device=device,
                patience=int(job.get('patience', 20)),
                verbose=True,
            )
            # Prefer the directory the trainer reports it wrote to; fall back to
            # the project/name layout requested above.
            fitted_trainer = getattr(model, 'trainer', None)
            save_dir = getattr(fitted_trainer, 'save_dir', None)
            run_dir = Path(save_dir) if save_dir else runs_root / run_name
            best_pt = run_dir / 'weights' / 'best.pt'
            if not best_pt.is_file():
                best_pt = run_dir / 'weights' / 'last.pt'
            if not best_pt.is_file():
                raise FileNotFoundError('训练结束但未找到 weights/best.pt')
            metrics = _parse_results_csv(run_dir / 'results.csv')
            _append_log(app, job_id, f'训练完成,best 权重:{best_pt.name}')
            registered = None
            if job.get('register_to_library', True):
                registered = _register_model(app, job, best_pt, metrics)
                _append_log(app, job_id, f"已写入模型库:{registered['model_name']}")
            _update_job(
                app,
                job_id,
                status='completed',
                progress=100,
                finished_at=_now_iso(),
                best_weights=str(best_pt.resolve()),
                metrics=metrics,
                registered_model=registered,
            )
        except Exception as exc:
            tb = traceback.format_exc()
            _append_log(app, job_id, f'训练失败:{exc}')
            _update_job(
                app,
                job_id,
                status='failed',
                finished_at=_now_iso(),
                error_message=str(exc),
                error_trace=tb[-2000:],  # keep jobs.json bounded
            )
def start_job(
    app,
    *,
    owner_id: int,
    owner_username: str,
    dataset_id: str,
    base_model: str,
    disease_category: str,
    output_model_name: str,
    epochs: int = 50,
    imgsz: int = 640,
    batch: int = 8,
    augmentation: str = '随机点+颜色扭曲+高斯模糊',
    register_to_library: bool = True,
    patience: int = 20,
):
    """Validate a training request, persist it as a ``pending`` job and launch it.

    Args:
        app: Flask application.
        owner_id / owner_username: Requesting user.
        dataset_id: Id of a dataset registered via ``upload_dataset``.
        base_model: One of the ``BASE_MODELS`` ids.
        disease_category: Category label stored with the registered model.
        output_model_name: Desired weights file name (``.pt`` appended if
            missing, sanitised with ``secure_filename``).
        epochs / imgsz / batch: Clamped to [1, 500] / [320, 1280] / [1, 64].
        augmentation: Free-text augmentation description stored with the model.
        register_to_library: Whether the result is inserted into the model library.
        patience: Early-stopping patience passed to training (clamped to >= 0).
            ``_run_training`` already read ``job['patience']`` with a default of
            20, but it could not previously be set.

    Returns:
        The newly created job dict.

    Raises:
        ValueError: unknown dataset or base model, a job already running,
            empty or missing category or output name, or non-integer numeric
            options.
    """
    dataset = get_dataset(app, dataset_id)
    if not dataset:
        raise ValueError('数据集不存在')
    allowed = {m['id'] for m in BASE_MODELS}
    if base_model not in allowed:
        raise ValueError('不支持的基线模型')
    if any(existing.get('status') == 'running' for existing in list_jobs(app)):
        raise ValueError('已有训练任务进行中,请等待完成后再启动')
    # Treat None (e.g. a missing form field) like an empty string, so callers
    # get the intended ValueError instead of an AttributeError.
    category = (disease_category or '').strip()
    if not category:
        raise ValueError('请填写隐患类别')
    raw_name = (output_model_name or '').strip()
    if not raw_name:
        raise ValueError('请填写输出模型文件名')
    if not raw_name.lower().endswith('.pt'):
        raw_name = f'{raw_name}.pt'
    safe_name = secure_filename(raw_name)
    if not safe_name or not safe_name.endswith('.pt'):
        safe_name = f"bridge_seg_{uuid.uuid4().hex[:8]}.pt"
    job_id = uuid.uuid4().hex[:12]
    run_name = f"train_{job_id}"
    now = _now_iso()  # one timestamp, so created_at == updated_at for a new job
    job = {
        'job_id': job_id,
        'run_name': run_name,
        'status': 'pending',
        'progress': 0,
        'owner_id': owner_id,
        'owner_username': owner_username,
        'dataset_id': dataset_id,
        'dataset_name': dataset['name'],
        'data_yaml': dataset['data_yaml'],
        'base_model': base_model,
        'disease_category': category,
        'output_model_name': safe_name,
        'epochs': max(1, min(int(epochs), 500)),
        'imgsz': max(320, min(int(imgsz), 1280)),
        'batch': max(1, min(int(batch), 64)),
        'patience': max(0, int(patience)),
        'augmentation': augmentation,
        'register_to_library': bool(register_to_library),
        'logs': [],
        'created_at': now,
        'updated_at': now,
    }
    jobs = list_jobs(app)
    jobs.insert(0, job)
    _save_json(app, 'jobs.json', jobs)
    _start_training_thread(app, job_id)
    threading.Thread(target=_watch_pending_job, args=(_app_for_thread(app), job_id), daemon=True).start()
    return job
|