config_loader.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os
  2. import yaml
  3. import json
  4. def load_config(config_path=None):
  5. """
  6. 加载系统配置
  7. 参数:
  8. config_path (str, optional): 配置文件路径,默认为None,使用默认配置文件
  9. 返回:
  10. dict: 配置字典
  11. """
  12. # 如果未指定配置文件,使用默认配置文件
  13. if config_path is None:
  14. config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'config.yaml')
  15. # 检查文件是否存在
  16. if not os.path.exists(config_path):
  17. print(f"警告:配置文件 {config_path} 不存在,使用默认配置")
  18. return get_default_config()
  19. # 根据文件扩展名选择加载方式
  20. _, ext = os.path.splitext(config_path)
  21. try:
  22. if ext.lower() in ['.yaml', '.yml']:
  23. with open(config_path, 'r', encoding='utf-8') as f:
  24. config = yaml.safe_load(f)
  25. elif ext.lower() == '.json':
  26. with open(config_path, 'r', encoding='utf-8') as f:
  27. config = json.load(f)
  28. else:
  29. print(f"不支持的配置文件格式:{ext},使用默认配置")
  30. return get_default_config()
  31. except Exception as e:
  32. print(f"加载配置文件失败:{e},使用默认配置")
  33. return get_default_config()
  34. # 合并默认配置和加载的配置
  35. default_config = get_default_config()
  36. merged_config = {**default_config, **config}
  37. return merged_config
  38. def get_default_config():
  39. """
  40. 获取默认配置
  41. 返回:
  42. dict: 默认配置字典
  43. """
  44. return {
  45. # 模型配置
  46. 'model_size': 's', # YOLOv5模型大小:n, s, m, l, x
  47. 'num_classes_fire': 2, # 火灾类别数(火焰、烟雾)
  48. 'num_classes_animal': 5, # 动物类别数(可根据保护区内具体动物调整)
  49. 'num_classes_landslide': 3, # 地质灾害类别数(滑坡、泥石流、山体崩塌)
  50. 'conf_threshold': 0.25, # 检测置信度阈值
  51. 'iou_threshold': 0.45, # NMS IOU阈值
  52. # 数据配置
  53. 'image_size': 640, # 输入图像大小
  54. 'batch_size': 16, # 批次大小
  55. 'data_augmentation': True, # 是否使用数据增强
  56. # 训练配置
  57. 'learning_rate': 0.01, # 学习率
  58. 'weight_decay': 0.0005, # 权重衰减
  59. 'epochs': 100, # 训练轮数
  60. 'save_interval': 10, # 模型保存间隔
  61. # 系统配置
  62. 'device': 'cuda:0', # 设备,cuda:0或cpu
  63. 'num_workers': 4, # 数据加载线程数
  64. 'weights_path': 'weights', # 权重保存路径
  65. 'logs_path': 'logs', # 日志保存路径
  66. # 监测区域配置
  67. 'monitor_regions': [
  68. {
  69. 'name': '北部山区',
  70. 'latitude': 40.123,
  71. 'longitude': 116.456,
  72. 'radius': 50, # 监测半径(公里)
  73. 'priority': 'high'
  74. },
  75. {
  76. 'name': '南部林区',
  77. 'latitude': 39.876,
  78. 'longitude': 115.789,
  79. 'radius': 40,
  80. 'priority': 'medium'
  81. }
  82. ],
  83. # GIS集成配置
  84. 'gis_api_key': '', # GIS API密钥
  85. 'map_center': [39.9, 116.3], # 地图中心点
  86. 'map_zoom': 8, # 地图缩放级别
  87. # UI配置
  88. 'dark_mode': True, # 是否使用暗色模式
  89. 'language': 'zh_CN', # 语言设置
  90. 'auto_refresh': 60, # 自动刷新间隔(秒)
  91. # 告警配置
  92. 'alert_threshold': 0.75, # 告警阈值
  93. 'alert_methods': ['ui', 'sound'], # 告警方式:ui界面、声音
  94. 'alert_interval': 30 # 告警间隔(秒)
  95. }