| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- # -*- coding: utf-8 -*-
- """下载 YOLOv8 分割基线权重并复制为各业务模型文件名(演示推理用)。"""
- from __future__ import annotations
- import os
- import shutil
- import sys
- from pathlib import Path
# Repository root: the directory one level above the folder holding this script.
ROOT = Path(__file__).resolve().parents[1]

# Put the repository root first on the import path; main() imports the
# project's ``app`` package, which presumably lives directly under ROOT.
sys.path.insert(0, str(ROOT))

# Fallback database URI, used only when the environment does not already
# define one.
# NOTE(review): this default embeds a root account and password in source
# control -- confirm it is a throwaway local/dev credential.
if 'SQLALCHEMY_DATABASE_URI' not in os.environ:
    os.environ['SQLALCHEMY_DATABASE_URI'] = (
        'mysql+pymysql://root:bridgedisease_root@127.0.0.1:3307/bridge_disease?charset=utf8mb4'
    )

# File name of the baseline segmentation weights and the directory that
# holds every model weight file used by the application.
BASE_NAME = 'yolov8n-seg.pt'
MODELS_DIR = ROOT.joinpath('app', 'static', 'models')
# Files at or below this size are treated as missing or truncated downloads.
_MIN_WEIGHT_BYTES = 1_000_000


def _is_complete_weight(path: Path) -> bool:
    """Return True if *path* is a regular file larger than the size floor."""
    return path.is_file() and path.stat().st_size > _MIN_WEIGHT_BYTES


def download_base_weights() -> Path:
    """Ensure the baseline YOLOv8 segmentation weights exist in ``MODELS_DIR``.

    If a sufficiently large ``BASE_NAME`` file is already present it is reused.
    Otherwise the weights are fetched by instantiating ``ultralytics.YOLO``
    with the working directory temporarily switched to ``MODELS_DIR``. If the
    file still is not in place afterwards, other locations are searched and
    the first plausible file is copied to the target path.

    Returns:
        Path to the baseline weight file inside ``MODELS_DIR``.

    Raises:
        SystemExit: with status 1 when no usable weight file could be
            obtained.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    target = MODELS_DIR / BASE_NAME
    if _is_complete_weight(target):
        print(f'[skip] 基线权重已存在: {target}')
        return target

    # Imported only when a download is actually needed, so the cached path
    # does not depend on the (heavy) ultralytics package being importable.
    from ultralytics import YOLO

    if target.is_file():
        # A leftover file below the size floor is considered a broken
        # download. Presumably the loader would try to reuse an existing file
        # of this name instead of re-fetching it -- verify against the
        # installed Ultralytics version. Remove it so the fetch starts clean.
        print(f'[warn] 基线权重文件不完整,已删除并重新下载: {target}')
        target.unlink()

    print('正在下载 yolov8n-seg.pt(首次约 6MB,需联网)…')
    cwd_before = Path.cwd()
    # A bare file name is passed to YOLO, so run from MODELS_DIR to make a
    # relative download land there; always restore the previous directory.
    os.chdir(MODELS_DIR)
    try:
        YOLO(BASE_NAME)
    finally:
        os.chdir(cwd_before)

    if not _is_complete_weight(target):
        # Some versions place the download elsewhere; look through MODELS_DIR
        # and the user's ultralytics cache directory for a usable file.
        # NOTE(review): the '*.pt' glob accepts any weight file in MODELS_DIR,
        # not only the baseline -- confirm that is intended.
        candidates = [
            *MODELS_DIR.glob('*.pt'),
            *(Path.home() / '.cache' / 'ultralytics').rglob(BASE_NAME),
        ]
        for candidate in candidates:
            if not _is_complete_weight(candidate):
                continue
            # Guard against copying a file onto itself.
            if candidate.resolve() != target.resolve():
                shutil.copy2(candidate, target)
            break

    # Apply the same size floor as the cache check so a truncated file is
    # never reported as a successful download.
    if not _is_complete_weight(target):
        print('下载失败:未找到 yolov8n-seg.pt', file=sys.stderr)
        sys.exit(1)
    print(f'[ok] 基线权重: {target} ({target.stat().st_size // 1024} KB)')
    return target
def main() -> None:
    """Provision a weight file for every row of the ``model`` table.

    Obtains the baseline weights via ``download_base_weights()``, then, inside
    the Flask application context, walks all ``Model`` rows ordered by
    ``model_id``. For each row the file name is taken from ``model_path``;
    if that file is missing from ``MODELS_DIR`` or smaller than 1,000,000
    bytes, the baseline weights are copied to it. ``model_path`` is rewritten
    to ``static/models/<filename>`` when it differs, and all changes are
    committed once after the loop.

    Rows whose ``model_path`` is NULL or yields no file name are skipped with
    a warning instead of producing a bogus file or raising.

    Raises:
        SystemExit: with status 1 when the ``model`` table is empty (or when
            the baseline weights could not be obtained).
    """
    # Project imports are deferred so they resolve after the module-level
    # sys.path / environment setup has run.
    from app import create_app
    from app.models import Model, db

    base = download_base_weights()
    app = create_app()
    with app.app_context():
        models = Model.query.order_by(Model.model_id.asc()).all()
        if not models:
            print('model 表为空,请先导入 sql/init_db.sql', file=sys.stderr)
            sys.exit(1)

        rel_prefix = 'static/models'
        copied = 0
        usable = 0
        for m in models:
            # Normalise Windows separators once; a NULL path becomes ''.
            stored = str(m.model_path or '').replace('\\', '/')
            filename = Path(stored.strip()).name
            if not filename:
                print(f'[warn] {m.model_name} 的 model_path 为空,已跳过', file=sys.stderr)
                continue

            dest = MODELS_DIR / filename
            if not dest.is_file() or dest.stat().st_size < 1_000_000:
                shutil.copy2(base, dest)
                copied += 1
                print(f'[copy] {filename}')

            rel = f'{rel_prefix}/{filename}'
            if stored != rel:
                m.model_path = rel
                print(f'[path] {m.model_name} -> {rel}')
            usable += 1

        db.session.commit()

    print(f'完成:复制 {copied} 个模型文件,共 {usable} 条 model 记录可用。')
    print('说明:演示环境共用 yolov8n-seg 权重;上线请替换为各病害专项训练权重。')
# Run the provisioning workflow only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|