training_manager.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. """YOLOv8 分割模型训练任务管理(后台线程 + 本地 JSON 持久化)。"""
  2. from __future__ import annotations
  3. import json
  4. import os
  5. import shutil
  6. import threading
  7. import traceback
  8. import uuid
  9. import zipfile
  10. from datetime import datetime
  11. from pathlib import Path
  12. from zoneinfo import ZoneInfo
  13. import torch
  14. import yaml
  15. from ultralytics import YOLO
  16. from werkzeug.utils import secure_filename
  17. from app.models import Model, db
  18. TZ = ZoneInfo('Asia/Shanghai')
  19. BASE_MODELS = [
  20. {'id': 'yolov8n-seg.pt', 'label': 'YOLOv8n-seg(轻量,适合试跑)'},
  21. {'id': 'yolov8s-seg.pt', 'label': 'YOLOv8s-seg(均衡)'},
  22. {'id': 'yolov8m-seg.pt', 'label': 'YOLOv8m-seg(精度更高,耗时更长)'},
  23. ]
  24. _lock = threading.Lock()
  25. _app = None
  26. def init_training_manager(app):
  27. global _app
  28. _app = app
  29. for key in ('DATASETS_FOLDER', 'TRAINING_RUNS_FOLDER', 'TRAINING_META_FOLDER'):
  30. path = app.config[key]
  31. os.makedirs(path, exist_ok=True)
  32. os.makedirs(app.config['MODELS_FOLDER'], exist_ok=True)
  33. resume_stale_jobs(app)
  34. def _app_for_thread(app):
  35. """Flask 应用代理需转为真实对象,否则后台训练线程可能无法执行。"""
  36. return app._get_current_object() if hasattr(app, '_get_current_object') else app
  37. def _start_training_thread(app, job_id: str):
  38. real_app = _app_for_thread(app)
  39. threading.Thread(target=_run_training, args=(real_app, job_id), daemon=True).start()
  40. def _watch_pending_job(app, job_id: str):
  41. """若数秒后仍为 pending 且无日志,自动重试拉起训练线程。"""
  42. import time
  43. time.sleep(2)
  44. job = get_job(app, job_id)
  45. if job and job.get('status') == 'pending' and not job.get('logs'):
  46. app.logger.warning('训练任务 %s 未真正启动,正在重试', job_id)
  47. _start_training_thread(app, job_id)
  48. def resume_stale_jobs(app):
  49. """后端启动后恢复 pending 任务(进程重启会导致后台训练线程丢失)。"""
  50. if any(j.get('status') == 'running' for j in list_jobs(app)):
  51. return
  52. for job in list_jobs(app):
  53. if job.get('status') != 'pending':
  54. continue
  55. job_id = job['job_id']
  56. app.logger.info('恢复排队中的训练任务:%s', job_id)
  57. _start_training_thread(app, job_id)
  58. return
  59. def resume_job(app, job_id: str) -> dict:
  60. """手动恢复 pending 任务(供前端刷新或 API 调用)。"""
  61. job = get_job(app, job_id)
  62. if not job:
  63. raise ValueError('训练任务不存在')
  64. if job.get('status') == 'running':
  65. return job
  66. if job.get('status') in ('completed', 'failed'):
  67. raise ValueError(f"任务已结束({job.get('status')}),无法恢复")
  68. if any(j.get('status') == 'running' for j in list_jobs(app)):
  69. raise ValueError('已有训练任务进行中,请等待完成')
  70. _start_training_thread(app, job_id)
  71. threading.Thread(target=_watch_pending_job, args=(_app_for_thread(app), job_id), daemon=True).start()
  72. return get_job(app, job_id) or job
  73. def _meta_path(app, name: str) -> Path:
  74. return Path(app.config['TRAINING_META_FOLDER']) / name
  75. def _json_default(obj):
  76. if isinstance(obj, datetime):
  77. return obj.isoformat()
  78. raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')
  79. def _load_json(app, name: str, default):
  80. with _lock:
  81. path = _meta_path(app, name)
  82. if not path.is_file():
  83. return default
  84. for candidate in (path, path.with_suffix('.json.bak')):
  85. if not candidate.is_file():
  86. continue
  87. try:
  88. with open(candidate, encoding='utf-8') as f:
  89. return json.load(f)
  90. except (json.JSONDecodeError, OSError):
  91. continue
  92. return default
  93. def _save_json(app, name: str, data):
  94. with _lock:
  95. path = _meta_path(app, name)
  96. path.parent.mkdir(parents=True, exist_ok=True)
  97. if path.is_file():
  98. bak = path.with_suffix('.json.bak')
  99. try:
  100. shutil.copy2(path, bak)
  101. except OSError:
  102. pass
  103. tmp = path.with_suffix('.json.tmp')
  104. payload = json.dumps(data, ensure_ascii=False, indent=2, default=_json_default)
  105. tmp.write_text(payload, encoding='utf-8')
  106. tmp.replace(path)
  107. def _now_iso():
  108. return datetime.now(TZ).isoformat()
  109. def list_base_models():
  110. return BASE_MODELS
  111. def _resolve_data_yaml(data_yaml: str) -> str:
  112. """将 data.yaml 的 path 固定为数据集根目录绝对路径,避免 path: . 被解析到错误工作目录。"""
  113. src = Path(data_yaml).resolve()
  114. if not src.is_file():
  115. raise FileNotFoundError(f'未找到 data.yaml:{data_yaml}')
  116. root = src.parent
  117. with open(src, encoding='utf-8') as f:
  118. cfg = yaml.safe_load(f) or {}
  119. cfg['path'] = str(root)
  120. resolved = root / '_dockscope_data.yaml'
  121. with open(resolved, 'w', encoding='utf-8') as f:
  122. yaml.dump(cfg, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
  123. return str(resolved)
  124. def _ensure_base_weights(app, base_model: str) -> Path:
  125. models_dir = Path(app.config['MODELS_FOLDER'])
  126. target = models_dir / base_model
  127. if target.is_file():
  128. return target
  129. YOLO(base_model)
  130. if not target.is_file():
  131. downloaded = list(models_dir.glob(base_model))
  132. if downloaded:
  133. return downloaded[0]
  134. cwd_candidate = Path.cwd() / base_model
  135. if cwd_candidate.is_file():
  136. shutil.copy2(cwd_candidate, target)
  137. if not target.is_file():
  138. raise FileNotFoundError(f'无法获取基线权重 {base_model},请检查网络或手动放入 models 目录')
  139. return target
  140. def _find_data_yaml(root: Path) -> Path | None:
  141. for name in ('data.yaml', 'dataset.yaml'):
  142. for p in root.rglob(name):
  143. return p
  144. return None
  145. def list_datasets(app):
  146. return _load_json(app, 'datasets.json', [])
  147. def get_dataset(app, dataset_id: str):
  148. for item in list_datasets(app):
  149. if item['dataset_id'] == dataset_id:
  150. return item
  151. return None
  152. def upload_dataset(app, zip_file, display_name: str, owner_id: int, owner_username: str):
  153. if not zip_file or not zip_file.filename.lower().endswith('.zip'):
  154. raise ValueError('请上传 YOLO 格式数据集的 ZIP 压缩包')
  155. dataset_id = uuid.uuid4().hex[:12]
  156. root = Path(app.config['DATASETS_FOLDER']) / dataset_id
  157. root.mkdir(parents=True, exist_ok=True)
  158. zip_path = root / secure_filename(zip_file.filename)
  159. zip_file.save(zip_path)
  160. try:
  161. with zipfile.ZipFile(zip_path, 'r') as zf:
  162. zf.extractall(root / 'extracted')
  163. except zipfile.BadZipFile as exc:
  164. shutil.rmtree(root, ignore_errors=True)
  165. raise ValueError('ZIP 文件损坏或格式无效') from exc
  166. finally:
  167. if zip_path.is_file():
  168. zip_path.unlink()
  169. extracted = root / 'extracted'
  170. data_yaml = _find_data_yaml(extracted)
  171. if not data_yaml:
  172. shutil.rmtree(root, ignore_errors=True)
  173. raise ValueError('压缩包内未找到 data.yaml / dataset.yaml,请使用 Ultralytics 标准目录结构')
  174. try:
  175. with open(data_yaml, encoding='utf-8') as f:
  176. yaml.safe_load(f)
  177. except Exception as exc:
  178. shutil.rmtree(root, ignore_errors=True)
  179. raise ValueError(f'data.yaml 解析失败:{exc}') from exc
  180. item = {
  181. 'dataset_id': dataset_id,
  182. 'name': display_name or data_yaml.parent.name,
  183. 'data_yaml': str(data_yaml.resolve()),
  184. 'root_path': str(extracted.resolve()),
  185. 'owner_id': owner_id,
  186. 'owner_username': owner_username,
  187. 'created_at': _now_iso(),
  188. }
  189. datasets = list_datasets(app)
  190. datasets.insert(0, item)
  191. _save_json(app, 'datasets.json', datasets)
  192. return item
  193. def list_jobs(app):
  194. jobs = _load_json(app, 'jobs.json', [])
  195. return sorted(jobs, key=lambda j: j.get('created_at', ''), reverse=True)
  196. def get_job(app, job_id: str):
  197. for job in list_jobs(app):
  198. if job['job_id'] == job_id:
  199. return job
  200. return None
  201. def _update_job(app, job_id: str, **fields):
  202. jobs = list_jobs(app)
  203. for job in jobs:
  204. if job['job_id'] == job_id:
  205. job.update(fields)
  206. job['updated_at'] = _now_iso()
  207. _save_json(app, 'jobs.json', jobs)
  208. return job
  209. return None
  210. def _append_log(app, job_id: str, line: str):
  211. job = get_job(app, job_id)
  212. if not job:
  213. return
  214. logs = job.get('logs') or []
  215. logs.append(f"[{datetime.now(TZ).strftime('%H:%M:%S')}] {line}")
  216. if len(logs) > 500:
  217. logs = logs[-500:]
  218. _update_job(app, job_id, logs=logs)
  219. def _parse_results_csv(csv_path: Path) -> dict:
  220. if not csv_path.is_file():
  221. return {}
  222. try:
  223. import csv
  224. with open(csv_path, encoding='utf-8') as f:
  225. rows = list(csv.DictReader(f))
  226. if not rows:
  227. return {}
  228. last = rows[-1]
  229. return {
  230. 'box_p': float(last.get('metrics/precision(B)', 0) or 0),
  231. 'box_r': float(last.get('metrics/recall(B)', 0) or 0),
  232. 'box_mAP50': float(last.get('metrics/mAP50(B)', 0) or 0),
  233. 'box_mAP50_95': float(last.get('metrics/mAP50-95(B)', 0) or 0),
  234. 'mask_p': float(last.get('metrics/precision(M)', 0) or 0),
  235. 'mask_r': float(last.get('metrics/recall(M)', 0) or 0),
  236. 'mask_mAP50': float(last.get('metrics/mAP50(M)', 0) or 0),
  237. 'mask_mAP50_95': float(last.get('metrics/mAP50-95(M)', 0) or 0),
  238. }
  239. except Exception:
  240. return {}
  241. def _register_model(app, job: dict, weights_path: Path, metrics: dict):
  242. model_name = job['output_model_name']
  243. if not model_name.endswith('.pt'):
  244. model_name = f'{model_name}.pt'
  245. dest = Path(app.config['MODELS_FOLDER']) / model_name
  246. shutil.copy2(weights_path, dest)
  247. rel_path = f'static/models/{model_name}'.replace('\\', '/')
  248. if Model.query.filter_by(model_name=model_name).first():
  249. model_name = f"{Path(model_name).stem}_{job['job_id'][:6]}.pt"
  250. dest = Path(app.config['MODELS_FOLDER']) / model_name
  251. shutil.copy2(weights_path, dest)
  252. rel_path = f'static/models/{model_name}'.replace('\\', '/')
  253. m = YOLO(str(dest))
  254. layers = 0
  255. params = 0
  256. try:
  257. layers = len(list(m.model.modules()))
  258. params = sum(p.numel() for p in m.model.parameters())
  259. except Exception:
  260. pass
  261. mask_mAP50 = metrics.get('mask_mAP50', 0.0)
  262. fitness = round(mask_mAP50 * 2 + metrics.get('box_mAP50', 0.0), 5)
  263. record = Model(
  264. model_name=model_name,
  265. model_path=rel_path,
  266. disease_category=job['disease_category'],
  267. augmentation=job.get('augmentation') or '随机点+颜色扭曲+高斯模糊',
  268. layers=layers or 113,
  269. parameters=params or 0,
  270. GFLOPs=35.3,
  271. box_p=round(metrics.get('box_p', 0.0), 3),
  272. box_r=round(metrics.get('box_r', 0.0), 3),
  273. box_mAP50=round(metrics.get('box_mAP50', 0.0), 3),
  274. box_mAP50_95=round(metrics.get('box_mAP50_95', 0.0), 3),
  275. mask_p=round(metrics.get('mask_p', 0.0), 3),
  276. mask_r=round(metrics.get('mask_r', 0.0), 3),
  277. mask_mAP50=round(mask_mAP50, 3),
  278. mask_mAP50_95=round(metrics.get('mask_mAP50_95', 0.0), 3),
  279. f1_score=round((metrics.get('mask_p', 0) + metrics.get('mask_r', 0)) or 0, 5),
  280. fitness_score=fitness,
  281. owner_id=job['owner_id'],
  282. )
  283. db.session.add(record)
  284. db.session.commit()
  285. registered = record.to_dict()
  286. for key in ('created_at', 'updated_at'):
  287. val = registered.get(key)
  288. if val is not None and hasattr(val, 'isoformat'):
  289. registered[key] = val.isoformat()
  290. return registered
  291. def _run_training(app, job_id: str):
  292. with app.app_context():
  293. job = get_job(app, job_id)
  294. if not job:
  295. return
  296. try:
  297. _update_job(app, job_id, status='running', progress=0, started_at=_now_iso())
  298. _append_log(app, job_id, '开始加载基线权重…')
  299. base_path = _ensure_base_weights(app, job['base_model'])
  300. data_yaml = _resolve_data_yaml(job['data_yaml'])
  301. _append_log(app, job_id, f'数据集配置:{data_yaml}')
  302. use_cuda = torch.cuda.is_available()
  303. device = 'cuda' if use_cuda else 'cpu'
  304. _append_log(app, job_id, f'训练设备:{device}')
  305. model = YOLO(str(base_path))
  306. runs_root = Path(app.config['TRAINING_RUNS_FOLDER'])
  307. run_name = job['run_name']
  308. _append_log(app, job_id, f"启动 YOLOv8 分割训练:epochs={job['epochs']}, imgsz={job['imgsz']}")
  309. def on_epoch_end(trainer):
  310. epoch = trainer.epoch + 1
  311. total = trainer.epochs
  312. progress = min(99, int(epoch / total * 100))
  313. _update_job(app, job_id, progress=progress, current_epoch=epoch, total_epochs=total)
  314. _append_log(app, job_id, f'Epoch {epoch}/{total} 完成')
  315. model.add_callback('on_train_epoch_end', on_epoch_end)
  316. model.train(
  317. data=data_yaml,
  318. epochs=int(job['epochs']),
  319. imgsz=int(job['imgsz']),
  320. batch=int(job['batch']),
  321. project=str(runs_root),
  322. name=run_name,
  323. exist_ok=True,
  324. device=device,
  325. patience=int(job.get('patience', 20)),
  326. verbose=True,
  327. )
  328. run_dir = runs_root / run_name
  329. best_pt = run_dir / 'weights' / 'best.pt'
  330. if not best_pt.is_file():
  331. best_pt = run_dir / 'weights' / 'last.pt'
  332. if not best_pt.is_file():
  333. raise FileNotFoundError('训练结束但未找到 weights/best.pt')
  334. metrics = _parse_results_csv(run_dir / 'results.csv')
  335. _append_log(app, job_id, f'训练完成,best 权重:{best_pt.name}')
  336. registered = None
  337. if job.get('register_to_library', True):
  338. registered = _register_model(app, job, best_pt, metrics)
  339. _append_log(app, job_id, f"已写入模型库:{registered['model_name']}")
  340. _update_job(
  341. app,
  342. job_id,
  343. status='completed',
  344. progress=100,
  345. finished_at=_now_iso(),
  346. best_weights=str(best_pt.resolve()),
  347. metrics=metrics,
  348. registered_model=registered,
  349. )
  350. except Exception as exc:
  351. tb = traceback.format_exc()
  352. _append_log(app, job_id, f'训练失败:{exc}')
  353. _update_job(
  354. app,
  355. job_id,
  356. status='failed',
  357. finished_at=_now_iso(),
  358. error_message=str(exc),
  359. error_trace=tb[-2000:],
  360. )
  361. def start_job(
  362. app,
  363. *,
  364. owner_id: int,
  365. owner_username: str,
  366. dataset_id: str,
  367. base_model: str,
  368. disease_category: str,
  369. output_model_name: str,
  370. epochs: int = 50,
  371. imgsz: int = 640,
  372. batch: int = 8,
  373. augmentation: str = '随机点+颜色扭曲+高斯模糊',
  374. register_to_library: bool = True,
  375. ):
  376. dataset = get_dataset(app, dataset_id)
  377. if not dataset:
  378. raise ValueError('数据集不存在')
  379. allowed = {m['id'] for m in BASE_MODELS}
  380. if base_model not in allowed:
  381. raise ValueError('不支持的基线模型')
  382. for job in list_jobs(app):
  383. if job.get('status') == 'running':
  384. raise ValueError('已有训练任务进行中,请等待完成后再启动')
  385. if not disease_category.strip():
  386. raise ValueError('请填写隐患类别')
  387. raw_name = output_model_name.strip()
  388. if not raw_name:
  389. raise ValueError('请填写输出模型文件名')
  390. if not raw_name.lower().endswith('.pt'):
  391. raw_name = f'{raw_name}.pt'
  392. safe_name = secure_filename(raw_name)
  393. if not safe_name or not safe_name.endswith('.pt'):
  394. safe_name = f"bridge_seg_{uuid.uuid4().hex[:8]}.pt"
  395. job_id = uuid.uuid4().hex[:12]
  396. run_name = f"train_{job_id}"
  397. job = {
  398. 'job_id': job_id,
  399. 'run_name': run_name,
  400. 'status': 'pending',
  401. 'progress': 0,
  402. 'owner_id': owner_id,
  403. 'owner_username': owner_username,
  404. 'dataset_id': dataset_id,
  405. 'dataset_name': dataset['name'],
  406. 'data_yaml': dataset['data_yaml'],
  407. 'base_model': base_model,
  408. 'disease_category': disease_category.strip(),
  409. 'output_model_name': safe_name,
  410. 'epochs': max(1, min(int(epochs), 500)),
  411. 'imgsz': max(320, min(int(imgsz), 1280)),
  412. 'batch': max(1, min(int(batch), 64)),
  413. 'augmentation': augmentation,
  414. 'register_to_library': bool(register_to_library),
  415. 'logs': [],
  416. 'created_at': _now_iso(),
  417. 'updated_at': _now_iso(),
  418. }
  419. jobs = list_jobs(app)
  420. jobs.insert(0, job)
  421. _save_json(app, 'jobs.json', jobs)
  422. _start_training_thread(app, job_id)
  423. threading.Thread(target=_watch_pending_job, args=(_app_for_thread(app), job_id), daemon=True).start()
  424. return job