seed_models.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # -*- coding: utf-8 -*-
  2. """下载 YOLOv8 分割基线权重并复制为各业务模型文件名(演示推理用)。"""
  3. from __future__ import annotations
  4. import os
  5. import shutil
  6. import sys
  7. from pathlib import Path
  8. ROOT = Path(__file__).resolve().parents[1]
  9. sys.path.insert(0, str(ROOT))
  10. os.environ.setdefault(
  11. 'SQLALCHEMY_DATABASE_URI',
  12. 'mysql+pymysql://root:bridgedisease_root@127.0.0.1:3307/bridge_disease?charset=utf8mb4',
  13. )
  14. MODELS_DIR = ROOT / 'app' / 'static' / 'models'
  15. BASE_NAME = 'yolov8n-seg.pt'
  16. def download_base_weights() -> Path:
  17. """通过 Ultralytics 拉取 yolov8n-seg.pt。"""
  18. from ultralytics import YOLO
  19. MODELS_DIR.mkdir(parents=True, exist_ok=True)
  20. target = MODELS_DIR / BASE_NAME
  21. if target.is_file() and target.stat().st_size > 1_000_000:
  22. print(f'[skip] 基线权重已存在: {target}')
  23. return target
  24. print('正在下载 yolov8n-seg.pt(首次约 6MB,需联网)…')
  25. cwd_before = Path.cwd()
  26. os.chdir(MODELS_DIR)
  27. try:
  28. YOLO(BASE_NAME)
  29. finally:
  30. os.chdir(cwd_before)
  31. if not target.is_file():
  32. # 部分版本会下载到用户目录 weights/
  33. candidates = list(MODELS_DIR.glob('*.pt')) + list((Path.home() / '.cache' / 'ultralytics').rglob(BASE_NAME))
  34. for c in candidates:
  35. if c.is_file() and c.stat().st_size > 1_000_000:
  36. if c.resolve() != target.resolve():
  37. shutil.copy2(c, target)
  38. break
  39. if not target.is_file():
  40. print('下载失败:未找到 yolov8n-seg.pt', file=sys.stderr)
  41. sys.exit(1)
  42. print(f'[ok] 基线权重: {target} ({target.stat().st_size // 1024} KB)')
  43. return target
  44. def main() -> None:
  45. from app import create_app
  46. from app.models import Model, db
  47. base = download_base_weights()
  48. app = create_app()
  49. with app.app_context():
  50. models = Model.query.order_by(Model.model_id.asc()).all()
  51. if not models:
  52. print('model 表为空,请先导入 sql/init_db.sql')
  53. sys.exit(1)
  54. rel_prefix = 'static/models'
  55. copied = 0
  56. for m in models:
  57. filename = Path(str(m.model_path).replace('\\', '/')).name
  58. dest = MODELS_DIR / filename
  59. if not dest.is_file() or dest.stat().st_size < 1_000_000:
  60. shutil.copy2(base, dest)
  61. copied += 1
  62. print(f'[copy] {filename}')
  63. rel = f'{rel_prefix}/{filename}'
  64. if m.model_path.replace('\\', '/') != rel:
  65. m.model_path = rel
  66. print(f'[path] {m.model_name} -> {rel}')
  67. db.session.commit()
  68. print(f'完成:复制 {copied} 个模型文件,共 {len(models)} 条 model 记录可用。')
  69. print('说明:演示环境共用 yolov8n-seg 权重;上线请替换为各病害专项训练权重。')
  70. if __name__ == '__main__':
  71. main()