create_demo_dataset.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
# -*- coding: utf-8 -*-
"""
Generate a mini YOLO segmentation demo dataset from seed_assets/medias
(placeholder annotations, only meant for smoke-testing the training pipeline).

Outputs:
    training_templates/bridge_hazard_demo/
    training_templates/bridge_hazard_demo.zip
"""
  8. from __future__ import annotations
  9. import shutil
  10. import zipfile
  11. from pathlib import Path
  12. from PIL import Image
# Base directory: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Directory the source images are read from.
ASSETS = ROOT / 'seed_assets' / 'medias'
# Directory the generated dataset is written to (recreated on every run by main()).
OUT_ROOT = ROOT / 'training_templates' / 'bridge_hazard_demo'
# Archive of OUT_ROOT, written next to it.
ZIP_PATH = ROOT / 'training_templates' / 'bridge_hazard_demo.zip'

# Each entry: (source image file name, split, class id, class description).
# The description is informational only; main() discards it when iterating.
SAMPLES = [
    ('01_concrete_crack_bridge.jpg', 'train', 0, 'concrete_crack'),
    ('02_bridge_concrete_cracks.jpg', 'train', 0, 'concrete_crack'),
    ('04_concrete_bending_cracks.jpg', 'train', 0, 'concrete_crack'),
    ('06_shrinkage_cracks_concrete.jpg', 'train', 0, 'concrete_crack'),
    ('03_steel_bridge_corrosion.jpg', 'val', 1, 'steel_corrosion'),
    ('08_concrete_rebar_corrosion.jpg', 'val', 1, 'steel_corrosion'),
]
  26. def rect_polygon(class_id: int, cx=0.5, cy=0.5, w=0.35, h=0.35) -> str:
  27. """生成归一化矩形四顶点分割标注行。"""
  28. x1, y1 = cx - w / 2, cy - h / 2
  29. x2, y2 = cx + w / 2, cy - h / 2
  30. x3, y3 = cx + w / 2, cy + h / 2
  31. x4, y4 = cx - w / 2, cy + h / 2
  32. pts = [x1, y1, x2, y2, x3, y3, x4, y4]
  33. pts = [max(0.0, min(1.0, p)) for p in pts]
  34. return f"{class_id} " + " ".join(f"{p:.6f}" for p in pts)
  35. def write_data_yaml(root: Path):
  36. content = """# 检澜演示数据集 — 占位标注,仅用于试跑 YOLOv8 分割训练
  37. path: .
  38. train: train/images
  39. val: val/images
  40. nc: 2
  41. names:
  42. 0: concrete_crack
  43. 1: steel_corrosion
  44. """
  45. (root / 'data.yaml').write_text(content, encoding='utf-8')
  46. def main():
  47. if not ASSETS.is_dir():
  48. print(f'缺少 {ASSETS},请先运行 scripts/download_real_medias.py')
  49. return 1
  50. if OUT_ROOT.exists():
  51. shutil.rmtree(OUT_ROOT)
  52. OUT_ROOT.mkdir(parents=True)
  53. for split in ('train', 'val'):
  54. (OUT_ROOT / split / 'images').mkdir(parents=True)
  55. (OUT_ROOT / split / 'labels').mkdir(parents=True)
  56. idx = 0
  57. for src_name, split, class_id, _ in SAMPLES:
  58. src = ASSETS / src_name
  59. if not src.is_file():
  60. print(f'[skip] 缺少 {src_name}')
  61. continue
  62. idx += 1
  63. stem = f'demo_{idx:03d}'
  64. img_name = f'{stem}.jpg'
  65. dest_img = OUT_ROOT / split / 'images' / img_name
  66. with Image.open(src) as im:
  67. im.convert('RGB').save(dest_img, format='JPEG', quality=90)
  68. label_line = rect_polygon(class_id)
  69. (OUT_ROOT / split / 'labels' / f'{stem}.txt').write_text(
  70. label_line + '\n', encoding='utf-8'
  71. )
  72. print(f'[ok] {split} {img_name} class={class_id}')
  73. if idx < 2:
  74. print('有效图片不足,无法生成数据集')
  75. return 1
  76. write_data_yaml(OUT_ROOT)
  77. if ZIP_PATH.is_file():
  78. ZIP_PATH.unlink()
  79. with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zf:
  80. for file in OUT_ROOT.rglob('*'):
  81. if file.is_file():
  82. arc = file.relative_to(OUT_ROOT.parent)
  83. zf.write(file, arc.as_posix())
  84. print()
  85. print('已生成:')
  86. print(f' 目录 {OUT_ROOT}')
  87. print(f' ZIP {ZIP_PATH}')
  88. print()
  89. print('下一步:登录 developer → 模型训练 → 上传 bridge_hazard_demo.zip')
  90. print('建议试跑:yolov8n-seg,epochs=5,batch=2')
  91. return 0
# Script entry point: use main()'s return value (0 or 1) as the process exit code.
if __name__ == '__main__':
    raise SystemExit(main())