open-embodied 9 months ago
parent
commit
5f01b3fc99
100 changed files with 9956 additions and 57 deletions
  1. 354 57
      README.md
  2. 81 0
      configs/config.yaml
  3. 44 0
      configs/yolov5s.yaml
  4. 3 0
      main.py
  5. 3 0
      map_temp.html
  6. 3 0
      models/__init__.py
  7. BIN
      models/__pycache__/__init__.cpython-39.pyc
  8. BIN
      models/__pycache__/common.cpython-39.pyc
  9. BIN
      models/__pycache__/experimental.cpython-39.pyc
  10. BIN
      models/__pycache__/yolo.cpython-39.pyc
  11. 662 0
      models/common.py
  12. 120 0
      models/experimental.py
  13. 59 0
      models/hub/anchors.yaml
  14. 51 0
      models/hub/yolov3-spp.yaml
  15. 41 0
      models/hub/yolov3-tiny.yaml
  16. 51 0
      models/hub/yolov3.yaml
  17. 48 0
      models/hub/yolov5-bifpn.yaml
  18. 42 0
      models/hub/yolov5-fpn.yaml
  19. 54 0
      models/hub/yolov5-p2.yaml
  20. 41 0
      models/hub/yolov5-p34.yaml
  21. 56 0
      models/hub/yolov5-p6.yaml
  22. 67 0
      models/hub/yolov5-p7.yaml
  23. 48 0
      models/hub/yolov5-panet.yaml
  24. 60 0
      models/hub/yolov5l6.yaml
  25. 60 0
      models/hub/yolov5m6.yaml
  26. 60 0
      models/hub/yolov5n6.yaml
  27. 48 0
      models/hub/yolov5s-ghost.yaml
  28. 48 0
      models/hub/yolov5s-transformer.yaml
  29. 60 0
      models/hub/yolov5s6.yaml
  30. 60 0
      models/hub/yolov5x6.yaml
  31. 464 0
      models/tf.py
  32. 329 0
      models/yolo.py
  33. 48 0
      models/yolov5l.yaml
  34. 48 0
      models/yolov5m.yaml
  35. 48 0
      models/yolov5n.yaml
  36. 48 0
      models/yolov5s.yaml
  37. 48 0
      models/yolov5x.yaml
  38. 3 0
      requirements.txt
  39. BIN
      resources/videos/f1.mp4
  40. BIN
      ui/__pycache__/splash_screen.cpython-39.pyc
  41. BIN
      ui/assets/__pycache__/icons.cpython-39.pyc
  42. 16 0
      ui/assets/icons.py
  43. 1 0
      ui/assets/loading.gif
  44. 1 0
      ui/assets/loading.png
  45. 1 0
      ui/assets/location.png
  46. 34 0
      ui/assets/map.html
  47. 773 0
      ui/assets/style.qss
  48. BIN
      ui/components/__pycache__/alert_panel.cpython-38.pyc
  49. BIN
      ui/components/__pycache__/alert_panel.cpython-39.pyc
  50. BIN
      ui/components/__pycache__/camera_manager.cpython-39.pyc
  51. BIN
      ui/components/__pycache__/camera_view.cpython-38.pyc
  52. BIN
      ui/components/__pycache__/camera_view.cpython-39.pyc
  53. BIN
      ui/components/__pycache__/camera_widget.cpython-39.pyc
  54. BIN
      ui/components/__pycache__/control_panel.cpython-38.pyc
  55. BIN
      ui/components/__pycache__/control_panel.cpython-39.pyc
  56. BIN
      ui/components/__pycache__/drone_manager.cpython-39.pyc
  57. BIN
      ui/components/__pycache__/fire_detection.cpython-39.pyc
  58. BIN
      ui/components/__pycache__/grid_camera_view.cpython-39.pyc
  59. BIN
      ui/components/__pycache__/map_view.cpython-38.pyc
  60. BIN
      ui/components/__pycache__/map_view.cpython-39.pyc
  61. BIN
      ui/components/__pycache__/statistics_panel.cpython-38.pyc
  62. BIN
      ui/components/__pycache__/statistics_panel.cpython-39.pyc
  63. 651 0
      ui/components/alert_panel.py
  64. 885 0
      ui/components/camera_view.py
  65. 402 0
      ui/components/control_panel.py
  66. 598 0
      ui/components/drone_manager.py
  67. 303 0
      ui/components/grid_camera_view.py
  68. 724 0
      ui/components/map_view.py
  69. 961 0
      ui/components/statistics_panel.py
  70. BIN
      ui/pages/__pycache__/main_window.cpython-38.pyc
  71. BIN
      ui/pages/__pycache__/main_window.cpython-39.pyc
  72. 750 0
      ui/pages/main_window.py
  73. 559 0
      ui/splash_screen.py
  74. 37 0
      utils/__init__.py
  75. BIN
      utils/__pycache__/__init__.cpython-311.pyc
  76. BIN
      utils/__pycache__/__init__.cpython-38.pyc
  77. BIN
      utils/__pycache__/__init__.cpython-39.pyc
  78. BIN
      utils/__pycache__/augmentations.cpython-311.pyc
  79. BIN
      utils/__pycache__/augmentations.cpython-38.pyc
  80. BIN
      utils/__pycache__/augmentations.cpython-39.pyc
  81. BIN
      utils/__pycache__/autoanchor.cpython-38.pyc
  82. BIN
      utils/__pycache__/autoanchor.cpython-39.pyc
  83. BIN
      utils/__pycache__/autobatch.cpython-38.pyc
  84. BIN
      utils/__pycache__/callbacks.cpython-38.pyc
  85. BIN
      utils/__pycache__/config_loader.cpython-38.pyc
  86. BIN
      utils/__pycache__/config_loader.cpython-39.pyc
  87. BIN
      utils/__pycache__/datasets.cpython-311.pyc
  88. BIN
      utils/__pycache__/datasets.cpython-38.pyc
  89. BIN
      utils/__pycache__/datasets.cpython-39.pyc
  90. BIN
      utils/__pycache__/downloads.cpython-311.pyc
  91. BIN
      utils/__pycache__/downloads.cpython-38.pyc
  92. BIN
      utils/__pycache__/downloads.cpython-39.pyc
  93. BIN
      utils/__pycache__/flame_detector.cpython-39.pyc
  94. BIN
      utils/__pycache__/general.cpython-311.pyc
  95. BIN
      utils/__pycache__/general.cpython-38.pyc
  96. BIN
      utils/__pycache__/general.cpython-39.pyc
  97. BIN
      utils/__pycache__/loss.cpython-38.pyc
  98. BIN
      utils/__pycache__/metrics.cpython-311.pyc
  99. BIN
      utils/__pycache__/metrics.cpython-38.pyc
  100. BIN
      utils/__pycache__/metrics.cpython-39.pyc

+ 354 - 57
README.md

@@ -1,92 +1,389 @@
-# project0090-SenTong
+# 森瞳森林多模态灾害监测系统
+
+## 项目概述
+
+森瞳森林多模态灾害监测系统是一个基于深度学习的智能森林监控平台,采用YOLOv5作为核心检测框架,结合多任务学习方法,实现对森林火灾、野生动物活动、地质灾害等多种灾害类型的实时监测和预警。系统集成了多路摄像头实时监控、无人机集群管理、GIS地理信息系统等功能,为森林安全管理提供全方位的技术支持。
+
+## 核心技术特点
+
+### 1. 多任务深度学习框架
+
+- **检测网络**: 基于YOLOv5的多任务检测网络
+- **模型优化**:
+  - 任务特定的检测头设计
+  - 多尺度特征融合
+  - 注意力机制增强
+- **训练策略**:
+  - 多任务联合训练
+  - 渐进式学习
+  - 数据增强技术
+
+### 2. 特征增强模块
+
+- **SPD-Conv空间金字塔扩张卷积**
+  - 增强小目标检测能力
+  - 多尺度特征提取
+- **CBAM通道空间注意力机制**
+  - 提升模糊目标识别效果
+  - 自适应特征权重调整
+- **高分辨率特征保留模块**
+  - 保持细节信息
+  - 提高检测精度
+
+## 系统功能
+
+### 1. 多路摄像头监控 🎥
+
+- **视频源支持**:
+  - IP摄像头 (RTSP/RTMP)
+  - USB摄像头
+  - 本地视频文件
+- **显示模式**:
+  - 9路同屏显示
+  - 单路全屏
+  - 自定义布局
+- **视频处理**:
+  - 实时编解码
+  - 画面增强
+  - 运动检测
+
+### 2. 智能灾害检测 🔥
+
+- **火灾检测**:
+  - 烟雾识别
+  - 火焰检测
+  - 热成像分析
+- **野生动物监测**:
+  - 物种识别
+  - 行为分析
+  - 数量统计
+- **地质灾害预警**:
+  - 地表变形检测
+  - 滑坡预警
+  - 泥石流监测
+
+### 3. 无人机集群管理 🚁
+
+- **飞行控制**:
+  - 一键起飞/降落
+  - 自动返航
+  - 紧急停止
+- **任务规划**:
+  - 航线规划
+  - 区域覆盖
+  - 目标跟踪
+- **数据采集**:
+  - 高清图像
+  - 热成像
+  - 多光谱数据
+
+### 4. GIS地图集成 🗺️
+
+- **地图功能**:
+  - 多图层显示
+  - 实时定位
+  - 区域标记
+- **数据可视化**:
+  - 热力图
+  - 轨迹回放
+  - 统计图表
+
+## 系统要求
+
+### 硬件要求
+
+- **处理器**: Intel Core i7-9700K或更高
+- **内存**: 16GB RAM (推荐32GB)
+- **显卡**: NVIDIA RTX 2060 6GB或更高
+- **存储**: 500GB SSD (系统盘)
+- **网络**: 千兆以太网
+
+### 软件要求
+
+- **操作系统**:
+  - Windows 10/11 专业版
+  - Ubuntu 20.04 LTS或更高
+- **Python环境**:
+  - Python 3.8+
+  - CUDA 11.3+
+  - cuDNN 8.2+
+- **依赖库版本**:
+  - PyTorch 1.10+
+  - OpenCV 4.5+
+  - PyQt5 5.15+
+
+## 快速开始
+
+### 1. 环境配置
+
+```bash
+# 创建虚拟环境
+python -m venv venv
+
+# 激活虚拟环境
+# Windows
+venv\Scripts\activate
+# Linux/Mac
+source venv/bin/activate
+
+# 安装依赖
+pip install -r requirements.txt
+
+# 安装CUDA工具包(如果需要GPU加速)
+# 请访问NVIDIA官网下载对应版本的CUDA和cuDNN
+```
+
+### 2. 配置文件说明
+
+在 `configs/config.yaml` 中配置系统参数:
+
+```yaml
+# 基础配置
+app_name: "森瞳森林多模态灾害监测系统"
+theme: "dark"  # dark/light
+language: "zh_CN"
+debug_mode: false
+
+# 更新配置
+update_interval: 60  # 秒
+auto_save: true
+save_interval: 300  # 秒
+
+# 地图配置
+map_center: [39.916527, 116.397128]  # 北京市中心
+map_zoom: 12
+map_type: "satellite"  # satellite/terrain/roadmap
+
+# 监测区域配置
+monitor_regions:
+  - name: "北京密云"
+    latitude: 40.3764
+    longitude: 116.8301
+    radius: 5  # 公里
+    camera_ids: ["CAM001", "CAM002"]
+  
+# 摄像头配置
+cameras:
+  - id: "CAM001"
+    name: "密云水库东"
+    type: "rtsp"
+    url: "rtsp://admin:admin@192.168.1.100:554"
+    enabled: true
+  
+# 检测模型配置
+models:
+  fire:
+    weights: "weights/fire_detection.pt"
+    conf_thres: 0.25
+    iou_thres: 0.45
+  animal:
+    weights: "weights/animal_detection.pt"
+    conf_thres: 0.3
+    iou_thres: 0.5
+```
+
+### 3. 启动系统
+
+```bash
+# 启动主程序
+python main.py
+
+# 启动带调试信息的程序
+python main.py --debug
+
+# 指定配置文件启动
+python main.py --config configs/custom_config.yaml
+```
+
+## 使用指南
+
+### 1. 界面布局
 
+系统界面分为四个主要区域:
 
+- **顶部**: 主菜单栏和工具栏
+- **左侧**: 控制面板和地图显示
+- **中间**: 摄像头监控界面
+- **右侧**: 统计信息和告警面板
 
-## Getting started
+### 2. 基本操作流程
 
-To make it easy for you to get started with GitLab, here's a list of recommended next steps.
+#### 2.1 系统初始化
 
-Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
+1. 启动系统后,检查所有摄像头连接状态
+2. 确认无人机通信正常
+3. 验证地图服务可用性
 
-## Add your files
+#### 2.2 监控操作
 
-- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
-- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
+1. 选择监控区域
+2. 配置检测参数
+3. 启动视频流
+4. 开启智能检测
+
+#### 2.3 告警处理
+
+1. 接收告警信息
+2. 查看告警详情
+3. 确认和处理告警
+4. 生成告警报告
+
+### 3. 快捷键
+
+- **系统操作**:
+  - `Ctrl+S`: 保存当前配置
+  - `Ctrl+R`: 刷新监控画面
+  - `Ctrl+Q`: 退出系统
+  - `F11`: 全屏切换
+- **视图操作**:
+  - `1-9`: 切换摄像头视图
+  - `Space`: 暂停/继续视频流
+  - `Ctrl+C`: 截图
+  - `Ctrl+V`: 粘贴图片
+- **地图操作**:
+  - `+/-`: 缩放地图
+  - `方向键`: 平移地图
+  - `Home`: 返回默认视图
+
+## 开发指南
+
+### 1. 项目结构
 
 ```
-cd existing_repo
-git remote add origin http://47.103.205.229:8099/seafog/project0090-sentong.git
-git branch -M main
-git push -uf origin main
+project_root/
+├── main.py                    # 主程序入口
+├── README.md                  # 项目说明文档
+├── requirements.txt           # 项目依赖
+├── map_temp.html             # 地图临时文件
+│
+├── ui/                       # 用户界面相关
+│   ├── assets/              # 静态资源
+│   ├── components/          # UI组件
+│   │   ├── alert_panel.py      # 告警面板
+│   │   ├── camera_view.py      # 摄像头视图
+│   │   ├── control_panel.py    # 控制面板
+│   │   ├── drone_manager.py    # 无人机管理
+│   │   ├── grid_camera_view.py # 网格摄像头视图
+│   │   ├── map_view.py         # 地图视图
+│   │   └── statistics_panel.py # 统计面板
+│   ├── pages/               # 页面
+│   │   └── main_window.py   # 主窗口
+│   └── splash_screen.py     # 启动画面
+│
+├── configs/                  # 配置文件
+│   ├── config.yaml          # 主配置
+│   └── yolov5s.yaml        # YOLOv5模型配置
+│
+├── utils/                   # 工具函数
+│   ├── loggers/            # 日志工具
+│   ├── aws/                # AWS相关工具
+│   ├── config_loader.py    # 配置加载器
+│   ├── camera_detector.py  # 摄像头检测器
+│   ├── torch_utils.py      # PyTorch工具
+│   ├── metrics.py          # 评估指标
+│   ├── plots.py           # 绘图工具
+│   ├── loss.py            # 损失函数
+│   ├── general.py         # 通用工具
+│   ├── datasets.py        # 数据集处理
+│   └── ...                # 其他工具函数
+│
+├── models/                  # 模型相关
+├── weights/                # 模型权重
+├── results/                # 结果输出
+├── scripts/                # 脚本文件
+├── resources/              # 资源文件
+└── data/                   # 数据文件
 ```
 
-## Integrate with your tools
+### 2. 开发规范
+
+#### 2.1 代码风格
+
+- 遵循PEP 8规范
+- 使用类型注解
+- 编写详细的文档字符串
+- 保持代码模块化
 
-- [ ] [Set up project integrations](http://47.103.205.229:8099/seafog/project0090-sentong/-/settings/integrations)
+#### 2.2 Git提交规范
 
-## Collaborate with your team
+```
+feat: 新功能
+fix: 修复bug
+docs: 文档更新
+style: 代码格式化
+refactor: 代码重构
+test: 测试相关
+chore: 构建过程或辅助工具的变动
+```
 
-- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
-- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
-- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
-- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
-- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
+#### 2.3 测试规范
 
-## Test and Deploy
+- 单元测试覆盖率 > 80%
+- 集成测试覆盖主要功能
+- 提交前本地测试通过
+- 编写测试文档
 
-Use the built-in continuous integration in GitLab.
+## 常见问题
 
-- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html)
-- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing(SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
-- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
-- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
-- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
+### 1. 系统启动问题
 
-***
+- **问题**: 系统无法启动
+- **解决方案**:
+  1. 检查Python版本
+  2. 验证依赖完整性
+  3. 查看日志文件
+  4. 确认配置文件正确
 
-# Editing this README
+### 2. 摄像头连接问题
 
-When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template.
+- **问题**: 摄像头画面不显示
+- **解决方案**:
+  1. 检查网络连接
+  2. 验证摄像头地址
+  3. 确认权限设置
+  4. 更新摄像头驱动
 
-## Suggestions for a good README
-Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
+### 3. 性能优化
 
-## Name
-Choose a self-explaining name for your project.
+- **问题**: 系统运行缓慢
+- **解决方案**:
+  1. 降低分辨率
+  2. 调整检测频率
+  3. 优化GPU使用
+  4. 清理缓存数据
 
-## Description
-Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
+## 技术支持
 
-## Badges
-On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
+### 问题反馈
 
-## Visuals
-Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
+- Issues: [GitHub Issues](https://github.com/your-repo/issues)
+- 邮箱: 674137120@qq.com
+- QQ: 674137120
 
-## Installation
-Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
+### 文档资源
 
-## Usage
-Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
+- [在线文档](https://docs.example.com)
+- [API参考](https://api.example.com)
+- [开发Wiki](https://wiki.example.com)
 
-## Support
-Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
+## 贡献指南
 
-## Roadmap
-If you have ideas for releases in the future, it is a good idea to list them in the README.
+欢迎提交 Pull Request 或 Issue。在贡献代码前,请:
 
-## Contributing
-State if you are open to contributions and what your requirements are for accepting them.
+1. Fork 本仓库
+2. 创建新的分支
+3. 提交变更
+4. 创建 Pull Request
 
-For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
+## 致谢
 
-You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
+感谢以下开源项目的支持:
 
-## Authors and acknowledgment
-Show your appreciation to those who have contributed to the project.
+- YOLOv5
+- PyQt5
+- OpenCV
+- PyTorch
 
-## License
-For open source projects, say how it is licensed.
+---
 
-## Project status
-If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
+© 2025 森瞳科技。保留所有权利。

+ 81 - 0
configs/config.yaml

@@ -0,0 +1,81 @@
+# 森林多模态灾害监测系统配置文件
+
+# 应用基础配置
+app_name: "森瞳森林多模态灾害监测系统"
+theme: "dark"
+update_interval: 60
+baidu_map_key: "F47f0642o1uGqSGnm3T6JLxHHLOjnx2T"
+baidu_map_domain: "localhost"
+
+# 模型配置
+model_size: s                   # YOLOv5模型大小:n, s, m, l, x
+num_classes_fire: 2             # 火灾类别数(火焰、烟雾)
+num_classes_animal: 5           # 动物类别数(可根据保护区内具体动物调整)
+num_classes_landslide: 3        # 地质灾害类别数(滑坡、泥石流、山体崩塌)
+conf_threshold: 0.25            # 检测置信度阈值
+iou_threshold: 0.45             # NMS IOU阈值
+
+# 数据配置
+image_size: 640                 # 输入图像大小
+batch_size: 16                  # 批次大小
+data_augmentation: true         # 是否使用数据增强
+
+# 训练配置
+learning_rate: 0.01             # 学习率
+weight_decay: 0.0005            # 权重衰减
+epochs: 100                     # 训练轮数
+save_interval: 10               # 模型保存间隔
+
+# 系统配置
+device: cuda:0                  # 设备,cuda:0或cpu
+num_workers: 4                  # 数据加载线程数
+weights_path: weights           # 权重保存路径
+logs_path: logs                 # 日志保存路径
+
+# 监测区域配置
+monitor_regions:
+  - name: "北京密云"
+    latitude: 40.3764
+    longitude: 116.8301
+    radius: 5
+    priority: high
+  - name: "杭州西湖"
+    latitude: 30.2650
+    longitude: 120.1331
+    radius: 3
+    priority: medium
+  - name: "四川卧龙"
+    latitude: 31.0500
+    longitude: 103.1500
+    radius: 10
+    priority: high
+
+# GIS集成配置
+gis_api_key: ""                 # GIS API密钥
+map_center: [39.916527, 116.397128]
+map_zoom: 12
+
+# UI配置
+dark_mode: true                 # 是否使用暗色模式
+language: zh_CN                 # 语言设置
+auto_refresh: 60                # 自动刷新间隔(秒)
+
+# 告警配置
+alert_threshold: 0.75           # 告警阈值
+alert_methods: [ui, sound]      # 告警方式:ui界面、声音
+alert_interval: 30              # 告警间隔(秒)
+random_alert:                   # 随机告警配置
+  enabled: false               # 是否启用随机告警
+  interval: 5                  # 随机告警检查间隔(秒)
+  probability: 0.3             # 生成告警的概率(0-1)
+  types:                       # 可能的告警类型
+    - fire                     # 火灾
+    - animal                   # 野生动物
+    - landslide               # 山体滑坡
+    - pest                    # 病虫害
+  locations:                   # 可能的告警位置
+    - 北部山区
+    - 南部林区
+    - 东部山脊
+    - 西部谷地
+    - 中央林场 

+ 44 - 0
configs/yolov5s.yaml

@@ -0,0 +1,44 @@
+# YOLOv5s模型配置
+# 参考自: https://github.com/ultralytics/yolov5/blob/master/models/yolov5s.yaml
+
+# 参数
+nc: 80  # 类别数量 (例如COCO数据集有80类)
+depth_multiple: 0.33  # 模型深度因子
+width_multiple: 0.50  # 模型宽度因子
+
+# 网络结构定义
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 9, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, C3, [1024, False]],  # 9
+  ]
+
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]]],  # Detect(P3, P4, P5)
+  ] 

+ 3 - 0
main.py

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fcdd9dbd8beb5c6619114e8cf9b1116553992c0a94b60672efe16bb1ada2d16
+size 2375

+ 3 - 0
map_temp.html

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b89b9b57e40b956446a4663719f1a6287ef3522c85046d9a1c35d392f68a3045
+size 17135

+ 3 - 0
models/__init__.py

@@ -0,0 +1,3 @@
+from .yolo import Model, Detect
+
+__all__ = ['Model', 'Detect'] 

BIN
models/__pycache__/__init__.cpython-39.pyc


BIN
models/__pycache__/common.cpython-39.pyc


BIN
models/__pycache__/experimental.cpython-39.pyc


BIN
models/__pycache__/yolo.cpython-39.pyc


+ 662 - 0
models/common.py

@@ -0,0 +1,662 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Common modules
+"""
+
+import json
+import math
+import platform
+import warnings
+from collections import OrderedDict, namedtuple
+from copy import copy
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+import requests
+import torch
+import torch.nn as nn
+import yaml
+from PIL import Image
+from torch.cuda import amp
+
+from utils.datasets import exif_transpose, letterbox
+from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
+                           make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
+from utils.plots import Annotator, colors, save_one_box
+from utils.torch_utils import copy_attr, time_sync
+
+
+def autopad(k, p=None):  # kernel, padding
+    # Pad to 'same'
+    if p is None:
+        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
+    return p
+
+
+class Conv(nn.Module):
+    # Standard convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def forward_fuse(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(Conv):
+    # Depth-wise convolution class
+    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class TransformerLayer(nn.Module):
+    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
+    def __init__(self, c, num_heads):
+        super().__init__()
+        self.q = nn.Linear(c, c, bias=False)
+        self.k = nn.Linear(c, c, bias=False)
+        self.v = nn.Linear(c, c, bias=False)
+        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
+        self.fc1 = nn.Linear(c, c, bias=False)
+        self.fc2 = nn.Linear(c, c, bias=False)
+
+    def forward(self, x):
+        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
+        x = self.fc2(self.fc1(x)) + x
+        return x
+
+
+class TransformerBlock(nn.Module):
+    # Vision Transformer https://arxiv.org/abs/2010.11929
+    def __init__(self, c1, c2, num_heads, num_layers):
+        super().__init__()
+        self.conv = None
+        if c1 != c2:
+            self.conv = Conv(c1, c2)
+        self.linear = nn.Linear(c2, c2)  # learnable position embedding
+        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
+        self.c2 = c2
+
+    def forward(self, x):
+        if self.conv is not None:
+            x = self.conv(x)
+        b, _, w, h = x.shape
+        p = x.flatten(2).permute(2, 0, 1)
+        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_, c2, 3, 1, g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
+        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
+        self.act = nn.SiLU()
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+    def forward(self, x):
+        y1 = self.cv3(self.m(self.cv1(x)))
+        y2 = self.cv2(x)
+        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
+class C3(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
+
+    def forward(self, x):
+        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class C3TR(C3):
+    # C3 module with TransformerBlock()
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = TransformerBlock(c_, c_, 4, n)
+
+
+class C3SPP(C3):
+    # C3 module with SPP()
+    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = SPP(c_, c_, k)
+
+
+class C3Ghost(C3):
+    # C3 module with GhostBottleneck()
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)  # hidden channels
+        self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
+
+
+class SPP(nn.Module):
+    # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
+    def __init__(self, c1, c2, k=(5, 9, 13)):
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+    def forward(self, x):
+        x = self.cv1(x)
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
+            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class SPPF(nn.Module):
+    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_ * 4, c2, 1, 1)
+        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
+            y1 = self.m(x)
+            y2 = self.m(y1)
+            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
+
+
+class Focus(nn.Module):
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+        # self.contract = Contract(gain=2)
+
+    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+        # return self.conv(self.contract(x))
+
+
+class GhostConv(nn.Module):
+    # Ghost Convolution https://github.com/huawei-noah/ghostnet
+    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
+        super().__init__()
+        c_ = c2 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, k, s, None, g, act)
+        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
+
+    def forward(self, x):
+        y = self.cv1(x)
+        return torch.cat([y, self.cv2(y)], 1)
+
+
+class GhostBottleneck(nn.Module):
+    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
+    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
+        super().__init__()
+        c_ = c2 // 2
+        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
+                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
+                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
+        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
+                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
+
+    def forward(self, x):
+        return self.conv(x) + self.shortcut(x)
+
+
+class Contract(nn.Module):
+    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
+    def __init__(self, gain=2):
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        b, c, h, w = x.size()  # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
+        s = self.gain
+        x = x.view(b, c, h // s, s, w // s, s)  # x(1,64,40,2,40,2)
+        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
+        return x.view(b, c * s * s, h // s, w // s)  # x(1,256,40,40)
+
+
+class Expand(nn.Module):
+    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
+    def __init__(self, gain=2):
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        b, c, h, w = x.size()  # assert C / s ** 2 == 0, 'Indivisible gain'
+        s = self.gain
+        x = x.view(b, s, s, c // s ** 2, h, w)  # x(1,2,2,16,80,80)
+        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
+        return x.view(b, c // s ** 2, h * s, w * s)  # x(1,16,160,160)
+
+
+class Concat(nn.Module):
+    # Concatenate a list of tensors along dimension
+    def __init__(self, dimension=1):
+        super().__init__()
+        self.d = dimension
+
+    def forward(self, x):
+        return torch.cat(x, self.d)
+
+
+class DetectMultiBackend(nn.Module):
+    # YOLOv5 MultiBackend class for python inference on various backends
+    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
+        # Usage:
+        #   PyTorch:      weights = *.pt
+        #   TorchScript:            *.torchscript
+        #   CoreML:                 *.mlmodel
+        #   OpenVINO:               *.xml
+        #   TensorFlow:             *_saved_model
+        #   TensorFlow:             *.pb
+        #   TensorFlow Lite:        *.tflite
+        #   TensorFlow Edge TPU:    *_edgetpu.tflite
+        #   ONNX Runtime:           *.onnx
+        #   OpenCV DNN:             *.onnx with dnn=True
+        #   TensorRT:               *.engine
+        from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
+
+        super().__init__()
+        w = str(weights[0] if isinstance(weights, list) else weights)
+        suffix = Path(w).suffix.lower()
+        suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel', '.xml']
+        check_suffix(w, suffixes)  # check weights have acceptable suffix
+        pt, jit, onnx, engine, tflite, pb, saved_model, coreml, xml = (suffix == x for x in suffixes)  # backends
+        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+        w = attempt_download(w)  # download if not local
+        if data:  # data.yaml path (optional)
+            with open(data, errors='ignore') as f:
+                names = yaml.safe_load(f)['names']  # class names
+
+        if pt:  # PyTorch
+            model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
+            stride = max(int(model.stride.max()), 32)  # model stride
+            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
+        elif jit:  # TorchScript
+            LOGGER.info(f'Loading {w} for TorchScript inference...')
+            extra_files = {'config.txt': ''}  # model metadata
+            model = torch.jit.load(w, _extra_files=extra_files)
+            if extra_files['config.txt']:
+                d = json.loads(extra_files['config.txt'])  # extra_files dict
+                stride, names = int(d['stride']), d['names']
+        elif dnn:  # ONNX OpenCV DNN
+            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
+            check_requirements(('opencv-python>=4.5.4',))
+            net = cv2.dnn.readNetFromONNX(w)
+        elif onnx:  # ONNX Runtime
+            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
+            cuda = torch.cuda.is_available()
+            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
+            import onnxruntime
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+            session = onnxruntime.InferenceSession(w, providers=providers)
+        elif xml:  # OpenVINO
+            LOGGER.info(f'Loading {w} for OpenVINO inference...')
+            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+            import openvino.inference_engine as ie
+            core = ie.IECore()
+            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
+            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
+        elif engine:  # TensorRT
+            LOGGER.info(f'Loading {w} for TensorRT inference...')
+            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
+            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
+            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            logger = trt.Logger(trt.Logger.INFO)
+            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
+                model = runtime.deserialize_cuda_engine(f.read())
+            bindings = OrderedDict()
+            for index in range(model.num_bindings):
+                name = model.get_binding_name(index)
+                dtype = trt.nptype(model.get_binding_dtype(index))
+                shape = tuple(model.get_binding_shape(index))
+                data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
+                bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
+            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
+            context = model.create_execution_context()
+            batch_size = bindings['images'].shape[0]
+        elif coreml:  # CoreML
+            LOGGER.info(f'Loading {w} for CoreML inference...')
+            import coremltools as ct
+            model = ct.models.MLModel(w)
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            if saved_model:  # SavedModel
+                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+                import tensorflow as tf
+                model = tf.keras.models.load_model(w)
+            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+                import tensorflow as tf
+
+                def wrap_frozen_graph(gd, inputs, outputs):
+                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                    return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
+                                   tf.nest.map_structure(x.graph.as_graph_element, outputs))
+
+                graph_def = tf.Graph().as_graph_def()
+                graph_def.ParseFromString(open(w, 'rb').read())
+                frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
+            elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+                    from tflite_runtime.interpreter import Interpreter, load_delegate
+                except ImportError:
+                    import tensorflow as tf
+                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
+                if 'edgetpu' in w.lower():  # Edge TPU https://coral.ai/software/#edgetpu-runtime
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
+                    delegate = {'Linux': 'libedgetpu.so.1',
+                                'Darwin': 'libedgetpu.1.dylib',
+                                'Windows': 'edgetpu.dll'}[platform.system()]
+                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+                else:  # Lite
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                    interpreter = Interpreter(model_path=w)  # load TFLite model
+                interpreter.allocate_tensors()  # allocate
+                input_details = interpreter.get_input_details()  # inputs
+                output_details = interpreter.get_output_details()  # outputs
+        self.__dict__.update(locals())  # assign all variables to self
+
+    def forward(self, im, augment=False, visualize=False, val=False):
+        # YOLOv5 MultiBackend inference
+        b, ch, h, w = im.shape  # batch, channel, height, width
+        if self.pt or self.jit:  # PyTorch
+            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
+            return y if val else y[0]
+        elif self.dnn:  # ONNX OpenCV DNN
+            im = im.cpu().numpy()  # torch to numpy
+            self.net.setInput(im)
+            y = self.net.forward()
+        elif self.onnx:  # ONNX Runtime
+            im = im.cpu().numpy()  # torch to numpy
+            y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+        elif self.xml:  # OpenVINO
+            im = im.cpu().numpy()  # FP32
+            desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW')  # Tensor Description
+            request = self.executable_network.requests[0]  # inference request
+            request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im))  # name=next(iter(request.input_blobs))
+            request.infer()
+            y = request.output_blobs['output'].buffer  # name=next(iter(request.output_blobs))
+        elif self.engine:  # TensorRT
+            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
+            self.binding_addrs['images'] = int(im.data_ptr())
+            self.context.execute_v2(list(self.binding_addrs.values()))
+            y = self.bindings['output'].data
+        elif self.coreml:  # CoreML
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = Image.fromarray((im[0] * 255).astype('uint8'))
+            # im = im.resize((192, 320), Image.ANTIALIAS)
+            y = self.model.predict({'image': im})  # coordinates are xywh normalized
+            if 'confidence' in y:
+                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
+                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+            else:
+                y = y[sorted(y)[-1]]  # last output
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            if self.saved_model:  # SavedModel
+                y = self.model(im, training=False).numpy()
+            elif self.pb:  # GraphDef
+                y = self.frozen_func(x=self.tf.constant(im)).numpy()
+            elif self.tflite:  # Lite
+                input, output = self.input_details[0], self.output_details[0]
+                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
+                if int8:
+                    scale, zero_point = input['quantization']
+                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
+                self.interpreter.set_tensor(input['index'], im)
+                self.interpreter.invoke()
+                y = self.interpreter.get_tensor(output['index'])
+                if int8:
+                    scale, zero_point = output['quantization']
+                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
+            y[..., :4] *= [w, h, w, h]  # xywh normalized to pixels
+
+        y = torch.tensor(y) if isinstance(y, np.ndarray) else y
+        return (y, []) if val else y
+
+    def warmup(self, imgsz=(1, 3, 640, 640), half=False):
+        # Warmup model by running inference once
+        if self.pt or self.jit or self.onnx or self.engine:  # warmup types
+            if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
+                im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float)  # input image
+                self.forward(im)  # warmup
+
+
+class AutoShape(nn.Module):
+    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    conf = 0.25  # NMS confidence threshold
+    iou = 0.45  # NMS IoU threshold
+    agnostic = False  # NMS class-agnostic
+    multi_label = False  # NMS multiple labels per box
+    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
+    max_det = 1000  # maximum number of detections per image
+    amp = False  # Automatic Mixed Precision (AMP) inference
+
+    def __init__(self, model):
+        super().__init__()
+        LOGGER.info('Adding AutoShape... ')
+        copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
+        self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
+        self.pt = not self.dmb or model.pt  # PyTorch model
+        self.model = model.eval()
+
+    def _apply(self, fn):
+        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
+        self = super()._apply(fn)
+        if self.pt:
+            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
+            m.stride = fn(m.stride)
+            m.grid = list(map(fn, m.grid))
+            if isinstance(m.anchor_grid, list):
+                m.anchor_grid = list(map(fn, m.anchor_grid))
+        return self
+
+    @torch.no_grad()
+    def forward(self, imgs, size=640, augment=False, profile=False):
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
+        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
+        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+
+        t = [time_sync()]
+        p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
+        autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
+        if isinstance(imgs, torch.Tensor):  # torch
+            with amp.autocast(enabled=autocast):
+                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
+
+        # Pre-process
+        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
+        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+        for i, im in enumerate(imgs):
+            f = f'image{i}'  # filename
+            if isinstance(im, (str, Path)):  # filename or uri
+                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
+                im = np.asarray(exif_transpose(im))
+            elif isinstance(im, Image.Image):  # PIL Image
+                im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
+            files.append(Path(f).with_suffix('.jpg').name)
+            if im.shape[0] < 5:  # image in CHW
+                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
+            im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3)  # enforce 3ch input
+            s = im.shape[:2]  # HWC
+            shape0.append(s)  # image shape
+            g = (size / max(s))  # gain
+            shape1.append([y * g for y in s])
+            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
+        shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs]  # pad
+        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
+        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
+        t.append(time_sync())
+
+        with amp.autocast(enabled=autocast):
+            # Inference
+            y = self.model(x, augment, profile)  # forward
+            t.append(time_sync())
+
+            # Post-process
+            y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
+                                    agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det)  # NMS
+            for i in range(n):
+                scale_coords(shape1, y[i][:, :4], shape0[i])
+
+            t.append(time_sync())
+            return Detections(imgs, y, files, t, self.names, x.shape)
+
+
+class Detections:
+    # YOLOv5 detections class for inference results
+    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
+        super().__init__()
+        d = pred[0].device  # device
+        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
+        self.imgs = imgs  # list of images as numpy arrays
+        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
+        self.names = names  # class names
+        self.files = files  # image filenames
+        self.times = times  # profiling times
+        self.xyxy = pred  # xyxy pixels
+        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
+        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
+        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
+        self.n = len(self.pred)  # number of images (batch size)
+        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
+        self.s = shape  # inference BCHW shape
+
+    def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
+        crops = []
+        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
+            s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
+            if pred.shape[0]:
+                for c in pred[:, -1].unique():
+                    n = (pred[:, -1] == c).sum()  # detections per class
+                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
+                if show or save or render or crop:
+                    annotator = Annotator(im, example=str(self.names))
+                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
+                        label = f'{self.names[int(cls)]} {conf:.2f}'
+                        if crop:
+                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
+                            crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
+                                          'im': save_one_box(box, im, file=file, save=save)})
+                        else:  # all others
+                            annotator.box_label(box, label, color=colors(cls))
+                    im = annotator.im
+            else:
+                s += '(no detections)'
+
+            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
+            if pprint:
+                LOGGER.info(s.rstrip(', '))
+            if show:
+                im.show(self.files[i])  # show
+            if save:
+                f = self.files[i]
+                im.save(save_dir / f)  # save
+                if i == self.n - 1:
+                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
+            if render:
+                self.imgs[i] = np.asarray(im)
+        if crop:
+            if save:
+                LOGGER.info(f'Saved results to {save_dir}\n')
+            return crops
+
+    def print(self):
+        self.display(pprint=True)  # print results
+        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
+                    self.t)
+
+    def show(self):
+        self.display(show=True)  # show results
+
+    def save(self, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
+        self.display(save=True, save_dir=save_dir)  # save results
+
+    def crop(self, save=True, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
+        return self.display(crop=True, save=save, save_dir=save_dir)  # crop results
+
+    def render(self):
+        self.display(render=True)  # render results
+        return self.imgs
+
+    def pandas(self):
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
+
+    def tolist(self):
+        # return a list of Detections objects, i.e. 'for result in results.tolist():'
+        r = range(self.n)  # iterable
+        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
+        # for d in x:
+        #    for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
+        #        setattr(d, k, getattr(d, k)[0])  # pop out of list
+        return x
+
+    def __len__(self):
+        return self.n
+
+
+class Classify(nn.Module):
+    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
+        self.flat = nn.Flatten()
+
+    def forward(self, x):
+        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
+        return self.flat(self.conv(z))  # flatten to x(b,c2)

+ 120 - 0
models/experimental.py

@@ -0,0 +1,120 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Experimental modules
+"""
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from models.common import Conv
+from utils.downloads import attempt_download
+
+
+class CrossConv(nn.Module):
+    # Cross Convolution Downsample
+    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, (1, k), (1, s))
+        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class Sum(nn.Module):
+    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
+    def __init__(self, n, weight=False):  # n: number of inputs
+        super().__init__()
+        self.weight = weight  # apply weights boolean
+        self.iter = range(n - 1)  # iter object
+        if weight:
+            self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True)  # layer weights
+
+    def forward(self, x):
+        y = x[0]  # no weight
+        if self.weight:
+            w = torch.sigmoid(self.w) * 2
+            for i in self.iter:
+                y = y + x[i + 1] * w[i]
+        else:
+            for i in self.iter:
+                y = y + x[i + 1]
+        return y
+
+
+class MixConv2d(nn.Module):
+    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
+    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
+        super().__init__()
+        n = len(k)  # number of convolutions
+        if equal_ch:  # equal c_ per group
+            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
+            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
+        else:  # equal weight.numel() per group
+            b = [c2] + [0] * n
+            a = np.eye(n + 1, n, k=-1)
+            a -= np.roll(a, 1, axis=1)
+            a *= np.array(k) ** 2
+            a[0] = 1
+            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
+
+        self.m = nn.ModuleList(
+            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
+
+
+class Ensemble(nn.ModuleList):
+    # Ensemble of models
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, augment=False, profile=False, visualize=False):
+        y = []
+        for module in self:
+            y.append(module(x, augment, profile, visualize)[0])
+        # y = torch.stack(y).max(0)[0]  # max ensemble
+        # y = torch.stack(y).mean(0)  # mean ensemble
+        y = torch.cat(y, 1)  # nms ensemble
+        return y, None  # inference, train output
+
+
+def attempt_load(weights, map_location=None, inplace=True, fuse=True):
+    from models.yolo import Detect, Model
+
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()
+    for w in weights if isinstance(weights, list) else [weights]:
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
+        if fuse:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
+        else:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse
+
+    # Compatibility updates
+    for m in model.modules():
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
+            if type(m) is Detect:
+                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
+                    delattr(m, 'anchor_grid')
+                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
+        elif type(m) is Conv:
+            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print(f'Ensemble created with {weights}\n')
+        for k in ['names']:
+            setattr(model, k, getattr(model[-1], k))
+        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
+        return model  # return ensemble

+ 59 - 0
models/hub/anchors.yaml

@@ -0,0 +1,59 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+# Default anchors for COCO data
+
+
+# P5 -------------------------------------------------------------------------------------------------------------------
+# P5-640:
+anchors_p5_640:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+
+# P6 -------------------------------------------------------------------------------------------------------------------
+# P6-640:  thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11,  21,19,  17,41,  43,32,  39,70,  86,64,  65,131,  134,130,  120,265,  282,180,  247,354,  512,387
+anchors_p6_640:
+  - [9,11,  21,19,  17,41]  # P3/8
+  - [43,32,  39,70,  86,64]  # P4/16
+  - [65,131,  134,130,  120,265]  # P5/32
+  - [282,180,  247,354,  512,387]  # P6/64
+
+# P6-1280:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27,  44,40,  38,94,  96,68,  86,152,  180,137,  140,301,  303,264,  238,542,  436,615,  739,380,  925,792
+anchors_p6_1280:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# P6-1920:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41,  67,59,  57,141,  144,103,  129,227,  270,205,  209,452,  455,396,  358,812,  653,922,  1109,570,  1387,1187
+anchors_p6_1920:
+  - [28,41,  67,59,  57,141]  # P3/8
+  - [144,103,  129,227,  270,205]  # P4/16
+  - [209,452,  455,396,  358,812]  # P5/32
+  - [653,922,  1109,570,  1387,1187]  # P6/64
+
+
+# P7 -------------------------------------------------------------------------------------------------------------------
+# P7-640:  thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11,  13,30,  29,20,  30,46,  61,38,  39,92,  78,80,  146,66,  79,163,  149,150,  321,143,  157,303,  257,402,  359,290,  524,372
+anchors_p7_640:
+  - [11,11,  13,30,  29,20]  # P3/8
+  - [30,46,  61,38,  39,92]  # P4/16
+  - [78,80,  146,66,  79,163]  # P5/32
+  - [149,150,  321,143,  157,303]  # P6/64
+  - [257,402,  359,290,  524,372]  # P7/128
+
+# P7-1280:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22,  54,36,  32,77,  70,83,  138,71,  75,173,  165,159,  148,334,  375,151,  334,317,  251,626,  499,474,  750,326,  534,814,  1079,818
+anchors_p7_1280:
+  - [19,22,  54,36,  32,77]  # P3/8
+  - [70,83,  138,71,  75,173]  # P4/16
+  - [165,159,  148,334,  375,151]  # P5/32
+  - [334,317,  251,626,  499,474]  # P6/64
+  - [750,326,  534,814,  1079,818]  # P7/128
+
+# P7-1920:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34,  81,55,  47,115,  105,124,  207,107,  113,259,  247,238,  222,500,  563,227,  501,476,  376,939,  749,711,  1126,489,  801,1222,  1618,1227
+anchors_p7_1920:
+  - [29,34,  81,55,  47,115]  # P3/8
+  - [105,124,  207,107,  113,259]  # P4/16
+  - [247,238,  222,500,  563,227]  # P5/32
+  - [501,476,  376,939,  749,711]  # P6/64
+  - [1126,489,  801,1222,  1618,1227]  # P7/128

+ 51 - 0
models/hub/yolov3-spp.yaml

@@ -0,0 +1,51 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# darknet53 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [32, 3, 1]],  # 0
+   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+   [-1, 1, Bottleneck, [64]],
+   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+   [-1, 2, Bottleneck, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+   [-1, 8, Bottleneck, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+   [-1, 8, Bottleneck, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+   [-1, 4, Bottleneck, [1024]],  # 10
+  ]
+
+# YOLOv3-SPP head
+head:
+  [[-1, 1, Bottleneck, [1024, False]],
+   [-1, 1, SPP, [512, [5, 9, 13]]],
+   [-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Bottleneck, [256, False]],
+   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+   [[27, 22, 15], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
+  ]

+ 41 - 0
models/hub/yolov3-tiny.yaml

@@ -0,0 +1,41 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,14, 23,27, 37,58]  # P4/16
+  - [81,82, 135,169, 344,319]  # P5/32
+
+# YOLOv3-tiny backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [16, 3, 1]],  # 0
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 1-P1/2
+   [-1, 1, Conv, [32, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 3-P2/4
+   [-1, 1, Conv, [64, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 5-P3/8
+   [-1, 1, Conv, [128, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 7-P4/16
+   [-1, 1, Conv, [256, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 9-P5/32
+   [-1, 1, Conv, [512, 3, 1]],
+   [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]],  # 11
+   [-1, 1, nn.MaxPool2d, [2, 1, 0]],  # 12
+  ]
+
+# YOLOv3-tiny head
+head:
+  [[-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Conv, [256, 3, 1]],  # 19 (P4/16-medium)
+
+   [[19, 15], 1, Detect, [nc, anchors]],  # Detect(P4, P5)
+  ]

+ 51 - 0
models/hub/yolov3.yaml

@@ -0,0 +1,51 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# darknet53 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [32, 3, 1]],  # 0
+   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+   [-1, 1, Bottleneck, [64]],
+   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+   [-1, 2, Bottleneck, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+   [-1, 8, Bottleneck, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+   [-1, 8, Bottleneck, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+   [-1, 4, Bottleneck, [1024]],  # 10
+  ]
+
+# YOLOv3 head
+head:
+  [[-1, 1, Bottleneck, [1024, False]],
+   [-1, 1, Conv, [512, [1, 1]]],
+   [-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Bottleneck, [256, False]],
+   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+   [[27, 22, 15], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/hub/yolov5-bifpn.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 BiFPN head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14, 6], 1, Concat, [1]],  # cat P4 <--- BiFPN change
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 42 - 0
models/hub/yolov5-fpn.yaml

@@ -0,0 +1,42 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 FPN head
+head:
+  [[-1, 3, C3, [1024, False]],  # 10 (P5/32-large)
+
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 3, C3, [512, False]],  # 14 (P4/16-medium)
+
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)
+
+   [[18, 14, 10], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 54 - 0
models/hub/yolov5-p2.yaml

@@ -0,0 +1,54 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors: 3  # AutoAnchor evolves 3 anchors per P output layer
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head with (P2, P3, P4, P5) outputs
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 2], 1, Concat, [1]],  # cat backbone P2
+   [-1, 1, C3, [128, False]],  # 21 (P2/4-xsmall)
+
+   [-1, 1, Conv, [128, 3, 2]],
+   [[-1, 18], 1, Concat, [1]],  # cat head P3
+   [-1, 3, C3, [256, False]],  # 24 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 27 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 30 (P5/32-large)
+
+   [[21, 24, 27, 30], 1, Detect, [nc, anchors]],  # Detect(P2, P3, P4, P5)
+  ]

+ 41 - 0
models/hub/yolov5-p34.yaml

@@ -0,0 +1,41 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors: 3  # AutoAnchor evolves 3 anchors per P output layer
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [ [ -1, 1, Conv, [ 64, 6, 2, 2 ] ],  # 0-P1/2
+    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
+    [ -1, 3, C3, [ 128 ] ],
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
+    [ -1, 6, C3, [ 256 ] ],
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
+    [ -1, 9, C3, [ 512 ] ],
+    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 7-P5/32
+    [ -1, 3, C3, [ 1024 ] ],
+    [ -1, 1, SPPF, [ 1024, 5 ] ],  # 9
+  ]
+
+# YOLOv5 v6.0 head with (P3, P4) outputs
+head:
+  [ [ -1, 1, Conv, [ 512, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
+    [ -1, 3, C3, [ 512, False ] ],  # 13
+
+    [ -1, 1, Conv, [ 256, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
+    [ -1, 3, C3, [ 256, False ] ],  # 17 (P3/8-small)
+
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],
+    [ [ -1, 14 ], 1, Concat, [ 1 ] ],  # cat head P4
+    [ -1, 3, C3, [ 512, False ] ],  # 20 (P4/16-medium)
+
+    [ [ 17, 20 ], 1, Detect, [ nc, anchors ] ],  # Detect(P3, P4)
+  ]

+ 56 - 0
models/hub/yolov5-p6.yaml

@@ -0,0 +1,56 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors: 3  # AutoAnchor evolves 3 anchors per P output layer
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head with (P3, P4, P5, P6) outputs
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 67 - 0
models/hub/yolov5-p7.yaml

@@ -0,0 +1,67 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors: 3  # AutoAnchor evolves 3 anchors per P output layer
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, Conv, [1280, 3, 2]],  # 11-P7/128
+   [-1, 3, C3, [1280]],
+   [-1, 1, SPPF, [1280, 5]],  # 13
+  ]
+
+# YOLOv5 v6.0 head with (P3, P4, P5, P6, P7) outputs
+head:
+  [[-1, 1, Conv, [1024, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 10], 1, Concat, [1]],  # cat backbone P6
+   [-1, 3, C3, [1024, False]],  # 17
+
+   [-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 21
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 25
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 29 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 26], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 32 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 22], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 35 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 18], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 38 (P6/64-xlarge)
+
+   [-1, 1, Conv, [1024, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P7
+   [-1, 3, C3, [1280, False]],  # 41 (P7/128-xxlarge)
+
+   [[29, 32, 35, 38, 41], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6, P7)
+  ]

+ 48 - 0
models/hub/yolov5-panet.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 PANet head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 60 - 0
models/hub/yolov5l6.yaml

@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 60 - 0
models/hub/yolov5m6.yaml

@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.67  # model depth multiple
+width_multiple: 0.75  # layer channel multiple
+anchors:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 60 - 0
models/hub/yolov5n6.yaml

@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.25  # layer channel multiple
+anchors:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 48 - 0
models/hub/yolov5s-ghost.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, GhostConv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3Ghost, [128]],
+   [-1, 1, GhostConv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3Ghost, [256]],
+   [-1, 1, GhostConv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3Ghost, [512]],
+   [-1, 1, GhostConv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3Ghost, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, GhostConv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3Ghost, [512, False]],  # 13
+
+   [-1, 1, GhostConv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3Ghost, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, GhostConv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3Ghost, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, GhostConv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3Ghost, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/hub/yolov5s-transformer.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3TR, [1024]],  # 9 <--- C3TR() Transformer module
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 60 - 0
models/hub/yolov5s6.yaml

@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 60 - 0
models/hub/yolov5x6.yaml

@@ -0,0 +1,60 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.33  # model depth multiple
+width_multiple: 1.25  # layer channel multiple
+anchors:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [768]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 11
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [768, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
+   [-1, 3, C3, [768, False]],  # 15
+
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 19
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 20], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 16], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
+
+   [-1, 1, Conv, [768, 3, 2]],
+   [[-1, 12], 1, Concat, [1]],  # cat head P6
+   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
+
+   [[23, 26, 29, 32], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5, P6)
+  ]

+ 464 - 0
models/tf.py

@@ -0,0 +1,464 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+TensorFlow, Keras and TFLite versions of YOLOv5
+Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
+
+Usage:
+    $ python models/tf.py --weights yolov5s.pt
+
+Export:
+    $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
+"""
+
+import argparse
+import sys
+from copy import deepcopy
+from pathlib import Path
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[1]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+# ROOT = ROOT.relative_to(Path.cwd())  # relative
+
+import numpy as np
+import tensorflow as tf
+import torch
+import torch.nn as nn
+from tensorflow import keras
+
+from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
+from models.experimental import CrossConv, MixConv2d, attempt_load
+from models.yolo import Detect
+from utils.activations import SiLU
+from utils.general import LOGGER, make_divisible, print_args
+
+
+class TFBN(keras.layers.Layer):
+    # TensorFlow BatchNormalization wrapper
+    def __init__(self, w=None):
+        super().__init__()
+        self.bn = keras.layers.BatchNormalization(
+            beta_initializer=keras.initializers.Constant(w.bias.numpy()),
+            gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
+            moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
+            moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
+            epsilon=w.eps)
+
+    def call(self, inputs):
+        return self.bn(inputs)
+
+
+class TFPad(keras.layers.Layer):
+    def __init__(self, pad):
+        super().__init__()
+        self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
+
+    def call(self, inputs):
+        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
+
+
+class TFConv(keras.layers.Layer):
+    # Standard convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, weights, kernel, stride, padding, groups
+        super().__init__()
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+        assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
+        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
+        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
+
+        conv = keras.layers.Conv2D(
+            c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False if hasattr(w, 'bn') else True,
+            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
+        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
+        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+
+        # YOLOv5 activations
+        if isinstance(w.act, nn.LeakyReLU):
+            self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
+        elif isinstance(w.act, nn.Hardswish):
+            self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
+        elif isinstance(w.act, (nn.SiLU, SiLU)):
+            self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
+        else:
+            raise Exception(f'no matching TensorFlow activation found for {w.act}')
+
+    def call(self, inputs):
+        return self.act(self.bn(self.conv(inputs)))
+
+
+class TFFocus(keras.layers.Layer):
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
+
+    def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
+        # inputs = inputs / 255  # normalize 0-255 to 0-1
+        return self.conv(tf.concat([inputs[:, ::2, ::2, :],
+                                    inputs[:, 1::2, ::2, :],
+                                    inputs[:, ::2, 1::2, :],
+                                    inputs[:, 1::2, 1::2, :]], 3))
+
+
+class TFBottleneck(keras.layers.Layer):
+    # Standard bottleneck
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):  # ch_in, ch_out, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
+        self.add = shortcut and c1 == c2
+
+    def call(self, inputs):
+        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
+
+
+class TFConv2d(keras.layers.Layer):
+    # Substitution for PyTorch nn.Conv2D
+    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
+        super().__init__()
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+        self.conv = keras.layers.Conv2D(
+            c2, k, s, 'VALID', use_bias=bias,
+            kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )
+
+    def call(self, inputs):
+        return self.conv(inputs)
+
+
+class TFBottleneckCSP(keras.layers.Layer):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
+        self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
+        self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
+        self.bn = TFBN(w.bn)
+        self.act = lambda x: keras.activations.relu(x, alpha=0.1)
+        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
+
+    def call(self, inputs):
+        y1 = self.cv3(self.m(self.cv1(inputs)))
+        y2 = self.cv2(inputs)
+        return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
+
+
+class TFC3(keras.layers.Layer):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
+        self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
+        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
+
+    def call(self, inputs):
+        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
+
+
+class TFSPP(keras.layers.Layer):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
+        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
+
+    def call(self, inputs):
+        x = self.cv1(inputs)
+        return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
+
+
+class TFSPPF(keras.layers.Layer):
+    # Spatial pyramid pooling-Fast layer
+    def __init__(self, c1, c2, k=5, w=None):
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
+        self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
+
+    def call(self, inputs):
+        x = self.cv1(inputs)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+        return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
+
+
+class TFDetect(keras.layers.Layer):
+    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
+        super().__init__()
+        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
+        self.nc = nc  # number of classes
+        self.no = nc + 5  # number of outputs per anchor
+        self.nl = len(anchors)  # number of detection layers
+        self.na = len(anchors[0]) // 2  # number of anchors
+        self.grid = [tf.zeros(1)] * self.nl  # init grid
+        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
+        self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
+                                      [self.nl, 1, -1, 1, 2])
+        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
+        self.training = False  # set to False after building model
+        self.imgsz = imgsz
+        for i in range(self.nl):
+            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
+            self.grid[i] = self._make_grid(nx, ny)
+
+    def call(self, inputs):
+        z = []  # inference output
+        x = []
+        for i in range(self.nl):
+            x.append(self.m[i](inputs[i]))
+            # x(bs,20,20,255) to x(bs,3,20,20,85)
+            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
+            x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
+
+            if not self.training:  # inference
+                y = tf.sigmoid(x[i])
+                xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
+                # Normalize xywh to 0-1 to reduce calibration error
+                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
+                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
+                y = tf.concat([xy, wh, y[..., 4:]], -1)
+                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
+
+        return x if self.training else (tf.concat(z, 1), x)
+
+    @staticmethod
+    def _make_grid(nx=20, ny=20):
+        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
+        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
+
+
+class TFUpsample(keras.layers.Layer):
+    def __init__(self, size, scale_factor, mode, w=None):  # warning: all arguments needed including 'w'
+        super().__init__()
+        assert scale_factor == 2, "scale_factor must be 2"
+        self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
+        # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
+        # with default arguments: align_corners=False, half_pixel_centers=False
+        # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
+        #                                                            size=(x.shape[1] * 2, x.shape[2] * 2))
+
+    def call(self, inputs):
+        return self.upsample(inputs)
+
+
+class TFConcat(keras.layers.Layer):
+    def __init__(self, dimension=1, w=None):
+        super().__init__()
+        assert dimension == 1, "convert only NCHW to NHWC concat"
+        self.d = 3
+
+    def call(self, inputs):
+        return tf.concat(inputs, self.d)
+
+
+def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
+    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
+    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
+    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
+    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
+
+    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
+    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
+        m_str = m
+        m = eval(m) if isinstance(m, str) else m  # eval strings
+        for j, a in enumerate(args):
+            try:
+                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
+            except NameError:
+                pass
+
+        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
+        if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
+            c1, c2 = ch[f], args[0]
+            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+            args = [c1, c2, *args[1:]]
+            if m in [BottleneckCSP, C3]:
+                args.insert(2, n)
+                n = 1
+        elif m is nn.BatchNorm2d:
+            args = [ch[f]]
+        elif m is Concat:
+            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
+        elif m is Detect:
+            args.append([ch[x + 1] for x in f])
+            if isinstance(args[1], int):  # number of anchors
+                args[1] = [list(range(args[1] * 2))] * len(f)
+            args.append(imgsz)
+        else:
+            c2 = ch[f]
+
+        tf_m = eval('TF' + m_str.replace('nn.', ''))
+        m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
+            else tf_m(*args, w=model.model[i])  # module
+
+        torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
+        t = str(m)[8:-2].replace('__main__.', '')  # module type
+        np = sum(x.numel() for x in torch_m_.parameters())  # number params
+        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
+        LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10}  {t:<40}{str(args):<30}')  # print
+        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
+        layers.append(m_)
+        ch.append(c2)
+    return keras.Sequential(layers), sorted(save)
+
+
+class TFModel:
+    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)):  # model, channels, classes
+        super().__init__()
+        if isinstance(cfg, dict):
+            self.yaml = cfg  # model dict
+        else:  # is *.yaml
+            import yaml  # for torch hub
+            self.yaml_file = Path(cfg).name
+            with open(cfg) as f:
+                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict
+
+        # Define model
+        if nc and nc != self.yaml['nc']:
+            LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
+            self.yaml['nc'] = nc  # override yaml value
+        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
+
+    def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
+                conf_thres=0.25):
+        y = []  # outputs
+        x = inputs
+        for i, m in enumerate(self.model.layers):
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+
+            x = m(x)  # run
+            y.append(x if m.i in self.savelist else None)  # save output
+
+        # Add TensorFlow NMS
+        if tf_nms:
+            boxes = self._xywh2xyxy(x[0][..., :4])
+            probs = x[0][:, :, 4:5]
+            classes = x[0][:, :, 5:]
+            scores = probs * classes
+            if agnostic_nms:
+                nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
+                return nms, x[1]
+            else:
+                boxes = tf.expand_dims(boxes, 2)
+                nms = tf.image.combined_non_max_suppression(
+                    boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
+                return nms, x[1]
+
+        return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
+        # x = x[0][0]  # [x(1,6300,85), ...] to x(6300,85)
+        # xywh = x[..., :4]  # x(6300,4) boxes
+        # conf = x[..., 4:5]  # x(6300,1) confidences
+        # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1))  # x(6300,1)  classes
+        # return tf.concat([conf, cls, xywh], 1)
+
+    @staticmethod
+    def _xywh2xyxy(xywh):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
+        return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
+
+
+class AgnosticNMS(keras.layers.Layer):
+    # TF Agnostic NMS
+    def call(self, input, topk_all, iou_thres, conf_thres):
+        # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
+        return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
+                         fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
+                         name='agnostic_nms')
+
+    @staticmethod
+    def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
+        boxes, classes, scores = x
+        class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
+        scores_inp = tf.reduce_max(scores, -1)
+        selected_inds = tf.image.non_max_suppression(
+            boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)
+        selected_boxes = tf.gather(boxes, selected_inds)
+        padded_boxes = tf.pad(selected_boxes,
+                              paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
+                              mode="CONSTANT", constant_values=0.0)
+        selected_scores = tf.gather(scores_inp, selected_inds)
+        padded_scores = tf.pad(selected_scores,
+                               paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
+                               mode="CONSTANT", constant_values=-1.0)
+        selected_classes = tf.gather(class_inds, selected_inds)
+        padded_classes = tf.pad(selected_classes,
+                                paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
+                                mode="CONSTANT", constant_values=-1.0)
+        valid_detections = tf.shape(selected_inds)[0]
+        return padded_boxes, padded_scores, padded_classes, valid_detections
+
+
+def representative_dataset_gen(dataset, ncalib=100):
+    # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
+    for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
+        input = np.transpose(img, [1, 2, 0])
+        input = np.expand_dims(input, axis=0).astype(np.float32)
+        input /= 255
+        yield [input]
+        if n >= ncalib:
+            break
+
+
+def run(weights=ROOT / 'yolov5s.pt',  # weights path
+        imgsz=(640, 640),  # inference size h,w
+        batch_size=1,  # batch size
+        dynamic=False,  # dynamic batch size
+        ):
+    # PyTorch model
+    im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
+    model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
+    _ = model(im)  # inference
+    model.info()
+
+    # TensorFlow model
+    im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
+    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
+    _ = tf_model.predict(im)  # inference
+
+    # Keras model
+    im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
+    keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
+    keras_model.summary()
+
+    LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
+    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
+    opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+    print_args(FILE.stem, opt)
+    return opt
+
+
+def main(opt):
+    run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)

+ 329 - 0
models/yolo.py

@@ -0,0 +1,329 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+YOLO-specific modules
+
+Usage:
+    $ python path/to/models/yolo.py --cfg yolov5s.yaml
+"""
+
+import argparse
+import sys
+from copy import deepcopy
+from pathlib import Path
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[3]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+# ROOT = ROOT.relative_to(Path.cwd())  # relative
+
+from models.common import *
+from models.experimental import *
+from utils.autoanchor import check_anchor_order
+from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
+from utils.plots import feature_visualization
+from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync
+
+try:
+    import thop  # for FLOPs computation
+except ImportError:
+    thop = None
+
+
+class Detect(nn.Module):
+    stride = None  # strides computed during build
+    onnx_dynamic = False  # ONNX export parameter
+
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
+        super().__init__()
+        self.inplace = inplace  # 添加inplace属性
+        self.nc = nc  # number of classes
+        self.no = nc + 5  # number of outputs per anchor
+        self.nl = len(anchors)  # number of detection layers
+        self.na = len(anchors[0]) // 2  # number of anchors
+        self.grid = [torch.zeros(1)] * self.nl  # init grid
+        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
+        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
+        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+
+    def forward(self, x):
+        z = []  # inference output
+        for i in range(self.nl):
+            x[i] = self.m[i](x[i])  # conv
+            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+            if not self.training:  # inference
+                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
+
+                y = x[i].sigmoid()
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
+                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
+                z.append(y.view(bs, -1, self.no))
+
+        return x if self.training else (torch.cat(z, 1), x)
+
+    def _make_grid(self, nx=20, ny=20, i=0):
+        d = self.anchors[i].device
+        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
+            yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij')
+        else:
+            yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)])
+        grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
+        anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
+            .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
+        return grid, anchor_grid
+
+
+class Model(nn.Module):
+    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
+        super().__init__()
+        if isinstance(cfg, dict):
+            self.yaml = cfg  # model dict
+        else:  # is *.yaml
+            import yaml  # for torch hub
+            self.yaml_file = Path(cfg).name
+            with open(cfg, encoding='ascii', errors='ignore') as f:
+                self.yaml = yaml.safe_load(f)  # model dict
+
+        # Define model
+        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
+        if nc and nc != self.yaml['nc']:
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            self.yaml['nc'] = nc  # override yaml value
+        if anchors:
+            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
+            self.yaml['anchors'] = round(anchors)  # override yaml value
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
+        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+        self.inplace = self.yaml.get('inplace', True)
+
+        # Build strides, anchors
+        m = self.model[-1]  # Detect()
+        if isinstance(m, Detect):
+            s = 256  # 2x min stride
+            m.inplace = self.inplace
+            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+            m.anchors /= m.stride.view(-1, 1, 1)
+            check_anchor_order(m)
+            self.stride = m.stride
+            self._initialize_biases()  # only run once
+
+        # Init weights, biases
+        initialize_weights(self)
+        self.info()
+        LOGGER.info('')
+
+    def forward(self, x, augment=False, profile=False, visualize=False):
+        if augment:
+            return self._forward_augment(x)  # augmented inference, None
+        return self._forward_once(x, profile, visualize)  # single-scale inference, train
+
+    def _forward_augment(self, x):
+        img_size = x.shape[-2:]  # height, width
+        s = [1, 0.83, 0.67]  # scales
+        f = [None, 3, None]  # flips (2-ud, 3-lr)
+        y = []  # outputs
+        for si, fi in zip(s, f):
+            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+            yi = self._forward_once(xi)[0]  # forward
+            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+            yi = self._descale_pred(yi, fi, si, img_size)
+            y.append(yi)
+        y = self._clip_augmented(y)  # clip augmented tails
+        return torch.cat(y, 1), None  # augmented inference, train
+
+    def _forward_once(self, x, profile=False, visualize=False):
+        y, dt = [], []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+            if profile:
+                self._profile_one_layer(m, x, dt)
+            x = m(x)  # run
+            y.append(x if m.i in self.save else None)  # save output
+            if visualize:
+                feature_visualization(x, m.type, m.i, save_dir=visualize)
+        return x
+
+    def _descale_pred(self, p, flips, scale, img_size):
+        # de-scale predictions following augmented inference (inverse operation)
+        if self.inplace:
+            p[..., :4] /= scale  # de-scale
+            if flips == 2:
+                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
+            elif flips == 3:
+                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
+        else:
+            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
+            if flips == 2:
+                y = img_size[0] - y  # de-flip ud
+            elif flips == 3:
+                x = img_size[1] - x  # de-flip lr
+            p = torch.cat((x, y, wh, p[..., 4:]), -1)
+        return p
+
+    def _clip_augmented(self, y):
+        # Clip YOLOv5 augmented inference tails
+        nl = self.model[-1].nl  # number of detection layers (P3-P5)
+        g = sum(4 ** x for x in range(nl))  # grid points
+        e = 1  # exclude layer count
+        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
+        y[0] = y[0][:, :-i]  # large
+        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
+        y[-1] = y[-1][:, i:]  # small
+        return y
+
+    def _profile_one_layer(self, m, x, dt):
+        c = isinstance(m, Detect)  # is final layer, copy input as inplace fix
+        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        t = time_sync()
+        for _ in range(10):
+            m(x.copy() if c else x)
+        dt.append((time_sync() - t) * 100)
+        if m == self.model[0]:
+            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  {'module'}")
+        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
+        if c:
+            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
+
+    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+        # https://arxiv.org/abs/1708.02002 section 3.3
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
+        m = self.model[-1]  # Detect() module
+        for mi, s in zip(m.m, m.stride):  # from
+            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
+            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def _print_biases(self):
+        m = self.model[-1]  # Detect() module
+        for mi in m.m:  # from
+            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
+            LOGGER.info(
+                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+    # def _print_weights(self):
+    #     for m in self.model.modules():
+    #         if type(m) is Bottleneck:
+    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
+
+    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
+        LOGGER.info('Fusing layers... ')
+        for m in self.model.modules():
+            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
+                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                delattr(m, 'bn')  # remove batchnorm
+                m.forward = m.forward_fuse  # update forward
+        self.info()
+        return self
+
+    def info(self, verbose=False, img_size=640):  # print model information
+        model_info(self, verbose, img_size)
+
+    def _apply(self, fn):
+        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
+        self = super()._apply(fn)
+        m = self.model[-1]  # Detect()
+        if isinstance(m, Detect):
+            m.stride = fn(m.stride)
+            m.grid = list(map(fn, m.grid))
+            if isinstance(m.anchor_grid, list):
+                m.anchor_grid = list(map(fn, m.anchor_grid))
+        return self
+
+
+def parse_model(d, ch):  # model_dict, input_channels(3)
+    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
+    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
+    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
+    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
+
+    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
+    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
+        m = eval(m) if isinstance(m, str) else m  # eval strings
+        for j, a in enumerate(args):
+            try:
+                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
+            except NameError:
+                pass
+
+        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
+        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
+                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
+            c1, c2 = ch[f], args[0]
+            if c2 != no:  # if not output
+                c2 = make_divisible(c2 * gw, 8)
+
+            args = [c1, c2, *args[1:]]
+            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
+                args.insert(2, n)  # number of repeats
+                n = 1
+        elif m is nn.BatchNorm2d:
+            args = [ch[f]]
+        elif m is Concat:
+            c2 = sum(ch[x] for x in f)
+        elif m is Detect:
+            args.append([ch[x] for x in f])
+            if isinstance(args[1], int):  # number of anchors
+                args[1] = [list(range(args[1] * 2))] * len(f)
+        elif m is Contract:
+            c2 = ch[f] * args[0] ** 2
+        elif m is Expand:
+            c2 = ch[f] // args[0] ** 2
+        else:
+            c2 = ch[f]
+
+        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
+        t = str(m)[8:-2].replace('__main__.', '')  # module type
+        np = sum(x.numel() for x in m_.parameters())  # number params
+        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
+        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
+        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
+        layers.append(m_)
+        if i == 0:
+            ch = []
+        ch.append(c2)
+    return nn.Sequential(*layers), sorted(save)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--profile', action='store_true', help='profile model speed')
+    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
+    opt = parser.parse_args()
+    opt.cfg = check_yaml(opt.cfg)  # check YAML
+    print_args(FILE.stem, opt)
+    device = select_device(opt.device)
+
+    # Create model
+    model = Model(opt.cfg).to(device)
+    model.train()
+
+    # Profile
+    if opt.profile:
+        img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
+        y = model(img, profile=True)
+
+    # Test all models
+    if opt.test:
+        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
+            try:
+                _ = Model(cfg)
+            except Exception as e:
+                print(f'Error in {cfg}: {e}')
+
+    # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
+    # from torch.utils.tensorboard import SummaryWriter
+    # tb_writer = SummaryWriter('.')
+    # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
+    # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), [])  # add model graph

+ 48 - 0
models/yolov5l.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/yolov5m.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.67  # model depth multiple
+width_multiple: 0.75  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/yolov5n.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.25  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/yolov5s.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 2  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 9, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, C3, [1024, False]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 48 - 0
models/yolov5x.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.33  # model depth multiple
+width_multiple: 1.25  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

+ 3 - 0
requirements.txt

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d6ea6ff9149c5bbc6ba49163fb27b258dc9fe5ae6256d87b488e2e571fc49c4
+size 329

BIN
resources/videos/f1.mp4


BIN
ui/__pycache__/splash_screen.cpython-39.pyc


BIN
ui/assets/__pycache__/icons.cpython-39.pyc


+ 16 - 0
ui/assets/icons.py

@@ -0,0 +1,16 @@
+"""图标资源文件"""
+
+# 网格视图图标
+GRID_ICON = """
+iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAB7SURBVEiJ7ZXBCcAgDEUfbsIo3aSbdBNHcRMXscEWKZgmtVQo9EFAPx+/RhPgBxTgAHbNm/yAPbLkO6ADDWiRJc/dfUbkyfV6WvLcnc7IkicnMCNLnvwCI7LkyTPQI0uePAE1suTJI1AiS548ACmy5MkdkCNLntyAFP7HBZEjGBl5wPb2AAAAAElFTkSuQmCC
+"""
+
+# 单视图图标
+SINGLE_ICON = """
+iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAABhSURBVEiJ7ZRBCsAgDATH0v9/OT2VQqHYaLLxUmYhkGQcE0MBEkABLLxjZh4vklSBE+gPJN2lA2jACZwrklpRMhA9Lg++7Ymkw0eGR48k7T4yPHokadOR4dEjSauPDP/jYgIvVBgZcxqU4QAAAABJRU5ErkJggg==
+"""
+
+# 刷新图标
+REFRESH_ICON = """
+iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAGxSURBVEiJ3ZW/S1tRFMc/5+Y1EZGaQQQhQ0h4YIYsDv4onUJwKljI4OAg/gUdOjg5CEI3cexQsDgIhQ6C4FKhU0EUQQQJGPNDm/fejc17uTkOeQ9jY0zSvIwFv8u953LP53LOveceuCFJwBPgGfAQmAAGgR7gHDgG9oHvwFfgt+/7f+8KPg+sAW3+ZwPYBKrAX2AEWARmgFY/5hSoAO+Bn+FEQRDcB94CL4EbQAf4AnwCfgBnQB/wGHgOzPo5deA18MH3/WYQBPYm+CqgwE/gDdCM8zxvyvO8ac/zJqLj0Tm+77eBt8APoAmsxJIHg8FQkiS7SqkDpVRVKXWklNoJw3A5DMPelB4GYdgbhuFyGIZHSqmqUupAKbUbBMFQkjxN0zEReWeMObTWHhtjDowxO0qpmVartZAkyVCSJENa6wVjzI4x5sBae2it3U/TdEzSNB1T1tpdEWkAkyLSEJFdY0wVmBWROWPMnIjMiUhFRPZFpCEikyJSt9buKmvtgYhcAP0icoGIHAC9wJWIXACISEtELkXkSkSMMeZSRFoi0gYQkRZwfp3yN/0DwGjcl8m8hbIAAAAASUVORK5CYII=
+""" 

+ 1 - 0
ui/assets/loading.gif


+ 1 - 0
ui/assets/loading.png

@@ -0,0 +1 @@
+PNG... 

+ 1 - 0
ui/assets/location.png

@@ -0,0 +1 @@
+iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAOxAAADsQBlSsOGwAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAGUSURBVDiNldO/S1tRFMfxz72+GBXULg4FoS4OLhYEESxYEMTi4NLBwUFwKv4BDg5OQqFDBzsWOnRw6B8gCA6CDiIO/qhDBykWOzQQgz+i0ZjkvQzFGH0x+YbveM7hfM/33nPuNf0GY1VgO4zhFZ7hEU7QwCF2sI5VnN0XNu4BX8AiPuI5nuA+6viNn/iBL/gdAW3cwQIWMRwBXdzCJeooRXkdP/AJn7GHVgS0cBvvMYeRXOEQu/gW5Q2cR/kMDXzFN+wPAHOYxnhU7OFHzKuRXsNpzE9wgq84wkUEtHEHbzGLx1F5gFWsYSPmTVxFeQvf8QV72AyTFl5jHpN4GJUHWMYyNmN+hcsoP8UmvmMVv1q4i7eYwVRUHmIJn7Ad8xauY36GrQBXsNfCPbzBa0xH5TEWsYTdmF/hT8wvsB3gdoDPAryHGUxE5V8BbuMy5g1cxPwSOwFuBPgWrzCJB1F5GsltHMf8HH9jfoHdANcDnMULjEflSSTXcRTzM/yL+Sl+BvgxwOe4F5X/ARVjvr4UwrDKAAAAAElFTkSuQmCC 

+ 34 - 0
ui/assets/map.html

@@ -0,0 +1,34 @@
+
+        <!DOCTYPE html>
+        <html>
+        <head>
+            <meta charset="utf-8" />
+            <title>森林监控地图</title>
+            <script type="text/javascript" src="https://api.map.baidu.com/api?v=3.0&ak=您的百度地图密钥"></script>
+            <style>
+                html, body, #map {
+                    height: 100%;
+                    margin: 0;
+                    padding: 0;
+                }
+            </style>
+        </head>
+        <body>
+            <div id="map"></div>
+            <script>
+                var map = new BMap.Map("map");
+                var point = new BMap.Point(116.397128, 39.916527);
+                map.centerAndZoom(point, 12);
+                map.enableScrollWheelZoom();
+                map.addControl(new BMap.NavigationControl());
+                map.addControl(new BMap.ScaleControl());
+                map.addControl(new BMap.OverviewMapControl());
+                map.addControl(new BMap.MapTypeControl());
+                
+                // 添加定位控件
+                var geolocationControl = new BMap.GeolocationControl();
+                map.addControl(geolocationControl);
+            </script>
+        </body>
+        </html>
+        

+ 773 - 0
ui/assets/style.qss

@@ -0,0 +1,773 @@
+/* 森林多模态灾害监测系统样式表 */
+
+/* 全局变量 */
+* {
+    font-family: "Microsoft YaHei", "SimHei", "Arial", sans-serif;
+    outline: none;
+}
+
+/* 主窗口 */
+QMainWindow {
+    background-color: #102030;
+    color: #d0d0d0;
+}
+
+/* 状态栏 */
+QStatusBar {
+    background-color: #0a1520;
+    color: #d0d0d0;
+    border-top: 1px solid #2a3a4a;
+    min-height: 22px;
+}
+
+QStatusBar::item {
+    border: none;
+}
+
+/* 菜单栏 */
+QMenuBar {
+    background-color: #0a1520;
+    color: #d0d0d0;
+    border-bottom: 1px solid #2a3a4a;
+}
+
+QMenuBar::item {
+    background: transparent;
+    padding: 4px 8px;
+}
+
+QMenuBar::item:selected {
+    background: #2a4a6a;
+    border-radius: 2px;
+}
+
+QMenu {
+    background-color: #152535;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+}
+
+QMenu::item {
+    padding: 5px 20px 5px 20px;
+    border-radius: 2px;
+}
+
+QMenu::item:selected {
+    background-color: #2a4a6a;
+}
+
+QMenu::separator {
+    height: 1px;
+    background-color: #2a3a4a;
+    margin: 4px 8px;
+}
+
+/* 工具栏 */
+QToolBar {
+    background-color: #152535;
+    border-bottom: 1px solid #2a3a4a;
+    spacing: 2px;
+    padding: 2px;
+}
+
+QToolBar::separator {
+    width: 1px;
+    background-color: #2a3a4a;
+    margin: 0 4px;
+}
+
+QToolButton {
+    background-color: transparent;
+    border-radius: 2px;
+    padding: 3px;
+    margin: 1px;
+}
+
+QToolButton:hover {
+    background-color: #2a4a6a;
+}
+
+QToolButton:pressed {
+    background-color: #1a3a5a;
+}
+
+/* 标签页 */
+QTabWidget::pane {
+    border: 1px solid #2a3a4a;
+    background-color: #102030;
+}
+
+QTabWidget::tab-bar {
+    left: 5px;
+}
+
+QTabBar::tab {
+    background-color: #152535;
+    color: #b0b0b0;
+    border: 1px solid #2a3a4a;
+    border-bottom: none;
+    border-top-left-radius: 4px;
+    border-top-right-radius: 4px;
+    padding: 5px 10px;
+    min-width: 80px;
+}
+
+QTabBar::tab:selected {
+    background-color: #203040;
+    color: #00ccff;
+    border-bottom: none;
+}
+
+QTabBar::tab:!selected {
+    margin-top: 2px;
+}
+
+/* 滚动条 */
+QScrollBar:vertical {
+    border: none;
+    background: #152535;
+    width: 10px;
+    margin: 0px;
+}
+
+QScrollBar::handle:vertical {
+    background: #2a4a6a;
+    min-height: 20px;
+    border-radius: 5px;
+}
+
+QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical {
+    border: none;
+    background: none;
+    height: 0px;
+}
+
+QScrollBar:horizontal {
+    border: none;
+    background: #152535;
+    height: 10px;
+    margin: 0px;
+}
+
+QScrollBar::handle:horizontal {
+    background: #2a4a6a;
+    min-width: 20px;
+    border-radius: 5px;
+}
+
+QScrollBar::add-line:horizontal, QScrollBar::sub-line:horizontal {
+    border: none;
+    background: none;
+    width: 0px;
+}
+
+/* 按钮 */
+QPushButton {
+    background-color: #2a4a6a;
+    color: #d0d0d0;
+    border: 1px solid #3a5a7a;
+    border-radius: 3px;
+    padding: 5px 10px;
+    min-width: 60px;
+}
+
+QPushButton:hover {
+    background-color: #3a5a7a;
+    border: 1px solid #4a6a8a;
+}
+
+QPushButton:pressed {
+    background-color: #1a3a5a;
+}
+
+QPushButton:disabled {
+    background-color: #1a2a3a;
+    color: #707070;
+    border: 1px solid #2a3a4a;
+}
+
+/* 输入框 */
+QLineEdit, QTextEdit, QPlainTextEdit {
+    background-color: #0a1520;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    border-radius: 3px;
+    padding: 3px;
+    selection-background-color: #2a4a6a;
+}
+
+QLineEdit:focus, QTextEdit:focus, QPlainTextEdit:focus {
+    border: 1px solid #4a6a8a;
+}
+
+/* 下拉框 */
+QComboBox {
+    background-color: #152535;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    border-radius: 3px;
+    padding: 3px 18px 3px 3px;
+    min-width: 6em;
+}
+
+QComboBox:on {
+    background-color: #1a3a5a;
+}
+
+QComboBox::drop-down {
+    subcontrol-origin: padding;
+    subcontrol-position: top right;
+    width: 15px;
+    border-left: 1px solid #2a3a4a;
+}
+
+QComboBox::down-arrow {
+    image: url(ui/assets/dropdown_arrow.png);
+    width: 10px;
+    height: 10px;
+}
+
+QComboBox QAbstractItemView {
+    background-color: #152535;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    selection-background-color: #2a4a6a;
+}
+
+/* 滑块 */
+QSlider::groove:horizontal {
+    border: 1px solid #2a3a4a;
+    height: 6px;
+    background: #0a1520;
+    margin: 2px 0;
+    border-radius: 3px;
+}
+
+QSlider::handle:horizontal {
+    background: #3a5a7a;
+    border: 1px solid #4a6a8a;
+    width: 14px;
+    height: 14px;
+    margin: -4px 0;
+    border-radius: 7px;
+}
+
+QSlider::sub-page:horizontal {
+    background: qlineargradient(x1: 0, y1: 0.5, x2: 1, y2: 0.5, stop: 0 #00a0e0, stop: 1 #00e0a0);
+    border: 1px solid #2a3a4a;
+    height: 6px;
+    border-radius: 3px;
+}
+
+/* 进度条 */
+QProgressBar {
+    border: 1px solid #2a3a4a;
+    border-radius: 3px;
+    background-color: #0a1520;
+    text-align: center;
+    color: #d0d0d0;
+}
+
+QProgressBar::chunk {
+    background-color: qlineargradient(x1: 0, y1: 0.5, x2: 1, y2: 0.5, stop: 0 #00a0e0, stop: 1 #00e0a0);
+    border-radius: 2px;
+}
+
+/* 复选框 */
+QCheckBox {
+    color: #d0d0d0;
+    spacing: 5px;
+}
+
+QCheckBox::indicator {
+    width: 15px;
+    height: 15px;
+}
+
+QCheckBox::indicator:unchecked {
+    background-color: #0a1520;
+    border: 1px solid #2a3a4a;
+    border-radius: 2px;
+}
+
+QCheckBox::indicator:checked {
+    background-color: #2a4a6a;
+    border: 1px solid #3a5a7a;
+    border-radius: 2px;
+    image: url(ui/assets/checkbox_checked.png);
+}
+
+/* 单选框 */
+QRadioButton {
+    color: #d0d0d0;
+    spacing: 5px;
+}
+
+QRadioButton::indicator {
+    width: 15px;
+    height: 15px;
+}
+
+QRadioButton::indicator:unchecked {
+    background-color: #0a1520;
+    border: 1px solid #2a3a4a;
+    border-radius: 7px;
+}
+
+QRadioButton::indicator:checked {
+    background-color: #2a4a6a;
+    border: 1px solid #3a5a7a;
+    border-radius: 7px;
+    image: url(ui/assets/radio_checked.png);
+}
+
+/* 分组框 */
+QGroupBox {
+    background-color: #102030;
+    color: #00ccff;
+    border: 1px solid #2a3a4a;
+    border-radius: 5px;
+    margin-top: 20px;
+    font-weight: bold;
+}
+
+QGroupBox::title {
+    subcontrol-origin: margin;
+    subcontrol-position: top left;
+    left: 10px;
+    padding: 0 5px;
+}
+
+/* 表格 */
+QTableView, QTableWidget {
+    background-color: #102030;
+    color: #d0d0d0;
+    gridline-color: #2a3a4a;
+    selection-background-color: #2a4a6a;
+    selection-color: #ffffff;
+    alternate-background-color: #152535;
+}
+
+QTableView QHeaderView::section, QTableWidget QHeaderView::section {
+    background-color: #1a2a3a;
+    color: #00ccff;
+    border: 1px solid #2a3a4a;
+    padding: 4px;
+}
+
+/* 列表 */
+QListView, QListWidget {
+    background-color: #102030;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    border-radius: 3px;
+    selection-background-color: #2a4a6a;
+}
+
+QListView::item, QListWidget::item {
+    padding: 5px;
+}
+
+QListView::item:selected, QListWidget::item:selected {
+    background-color: #2a4a6a;
+}
+
+/* 树视图 */
+QTreeView, QTreeWidget {
+    background-color: #102030;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    selection-background-color: #2a4a6a;
+}
+
+QTreeView::branch:has-siblings:!adjoins-item {
+    border-image: url(ui/assets/branch_line.png) 0;
+}
+
+QTreeView::branch:has-siblings:adjoins-item {
+    border-image: url(ui/assets/branch_more.png) 0;
+}
+
+QTreeView::branch:!has-children:!has-siblings:adjoins-item {
+    border-image: url(ui/assets/branch_end.png) 0;
+}
+
+QTreeView::branch:has-children:!has-siblings:closed,
+QTreeView::branch:closed:has-children:has-siblings {
+    border-image: none;
+    image: url(ui/assets/branch_closed.png);
+}
+
+QTreeView::branch:open:has-children:!has-siblings,
+QTreeView::branch:open:has-children:has-siblings {
+    border-image: none;
+    image: url(ui/assets/branch_open.png);
+}
+
+/* 日期选择器 */
+QDateEdit, QTimeEdit, QDateTimeEdit {
+    background-color: #152535;
+    color: #d0d0d0;
+    border: 1px solid #2a3a4a;
+    border-radius: 3px;
+    padding: 3px;
+}
+
+QDateEdit::drop-down, QTimeEdit::drop-down, QDateTimeEdit::drop-down {
+    subcontrol-origin: padding;
+    subcontrol-position: top right;
+    width: 15px;
+    border-left: 1px solid #2a3a4a;
+}
+
+/* 控制面板特殊样式 */
+#controlPanel {
+    background-color: #152535;
+    border-radius: 5px;
+    border: 1px solid #2a3a4a;
+}
+
+#controlPanel QLabel {
+    color: #00ccff;
+    font-weight: bold;
+}
+
+/* 告警面板特殊样式 */
+#alertPanel {
+    background-color: #152535;
+    border-radius: 5px;
+    border: 1px solid #2a3a4a;
+}
+
+#alertPanel QTableWidget {
+    selection-background-color: #2a4a6a;
+}
+
+/* 数据分析面板样式 */
+#analysisPanel {
+    background-color: #152535;
+    border-radius: 5px;
+    border: 1px solid #2a3a4a;
+}
+
+/* 摄像头视图样式 */
+#cameraView {
+    border: 2px solid #2a4a6a;
+    border-radius: 5px;
+}
+
+/* 地图视图样式 */
+#mapView {
+    border: 2px solid #2a4a6a;
+    border-radius: 5px;
+}
+
+/* 数据标签特殊样式 */
+.dataLabel {
+    color: #00e0a0;
+    font-size: 14px;
+    font-weight: bold;
+}
+
+/* 告警标签特殊样式 */
+.alertLabel {
+    color: #ff5050;
+    font-size: 14px;
+    font-weight: bold;
+}
+
+/* 为特定告警类型设置颜色 */
+.fireAlert {
+    color: #ff5050;
+}
+
+.animalAlert {
+    color: #ffaa00;
+}
+
+.landslideAlert {
+    color: #aa5500;
+}
+
+.forestAlert {
+    color: #00cc00;
+}
+
+/* 亮色主题变量 */
+.light {
+    /* 主窗口 */
+    QMainWindow {
+        background-color: #f0f0f0;
+        color: #303030;
+    }
+
+    /* 状态栏 */
+    QStatusBar {
+        background-color: #e0e0e0;
+        color: #303030;
+        border-top: 1px solid #c0c0c0;
+    }
+
+    /* 菜单栏 */
+    QMenuBar {
+        background-color: #e0e0e0;
+        color: #303030;
+        border-bottom: 1px solid #c0c0c0;
+    }
+
+    QMenuBar::item:selected {
+        background: #c0d0e0;
+    }
+
+    QMenu {
+        background-color: #f0f0f0;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+    }
+
+    QMenu::item:selected {
+        background-color: #c0d0e0;
+    }
+
+    QMenu::separator {
+        background-color: #c0c0c0;
+    }
+
+    /* 工具栏 */
+    QToolBar {
+        background-color: #e0e0e0;
+        border-bottom: 1px solid #c0c0c0;
+    }
+
+    QToolBar::separator {
+        background-color: #c0c0c0;
+    }
+
+    QToolButton:hover {
+        background-color: #c0d0e0;
+    }
+
+    QToolButton:pressed {
+        background-color: #a0b0c0;
+    }
+
+    /* 标签页 */
+    QTabWidget::pane {
+        border: 1px solid #c0c0c0;
+        background-color: #f0f0f0;
+    }
+
+    QTabBar::tab {
+        background-color: #e0e0e0;
+        color: #505050;
+        border: 1px solid #c0c0c0;
+    }
+
+    QTabBar::tab:selected {
+        background-color: #f0f0f0;
+        color: #0080c0;
+    }
+
+    /* 滚动条 */
+    QScrollBar:vertical, QScrollBar:horizontal {
+        background: #e0e0e0;
+    }
+
+    QScrollBar::handle:vertical, QScrollBar::handle:horizontal {
+        background: #b0b0b0;
+    }
+
+    /* 按钮 */
+    QPushButton {
+        background-color: #d0d0d0;
+        color: #303030;
+        border: 1px solid #b0b0b0;
+    }
+
+    QPushButton:hover {
+        background-color: #c0d0e0;
+        border: 1px solid #a0b0c0;
+    }
+
+    QPushButton:pressed {
+        background-color: #a0b0c0;
+    }
+
+    QPushButton:disabled {
+        background-color: #e0e0e0;
+        color: #a0a0a0;
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 输入框 */
+    QLineEdit, QTextEdit, QPlainTextEdit {
+        background-color: #ffffff;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+        selection-background-color: #c0d0e0;
+    }
+
+    QLineEdit:focus, QTextEdit:focus, QPlainTextEdit:focus {
+        border: 1px solid #a0b0c0;
+    }
+
+    /* 下拉框 */
+    QComboBox {
+        background-color: #ffffff;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+    }
+
+    QComboBox:on {
+        background-color: #e0e0e0;
+    }
+
+    QComboBox::drop-down {
+        border-left: 1px solid #c0c0c0;
+    }
+
+    QComboBox QAbstractItemView {
+        background-color: #ffffff;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+        selection-background-color: #c0d0e0;
+    }
+
+    /* 滑块 */
+    QSlider::groove:horizontal {
+        border: 1px solid #c0c0c0;
+        background: #e0e0e0;
+    }
+
+    QSlider::handle:horizontal {
+        background: #b0b0b0;
+        border: 1px solid #909090;
+    }
+
+    QSlider::sub-page:horizontal {
+        background: qlineargradient(x1: 0, y1: 0, x2: 1, y2: 0, stop: 0 #0080c0, stop: 1 #00c080);
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 进度条 */
+    QProgressBar {
+        border: 1px solid #c0c0c0;
+        background-color: #e0e0e0;
+        color: #303030;
+    }
+
+    QProgressBar::chunk {
+        background-color: qlineargradient(x1: 0, y1: 0.5, x2: 1, y2: 0.5, stop: 0 #0080c0, stop: 1 #00c080);
+    }
+
+    /* 复选框 */
+    QCheckBox {
+        color: #303030;
+    }
+
+    QCheckBox::indicator:unchecked {
+        background-color: #ffffff;
+        border: 1px solid #c0c0c0;
+    }
+
+    QCheckBox::indicator:checked {
+        background-color: #c0d0e0;
+        border: 1px solid #a0b0c0;
+    }
+
+    /* 单选框 */
+    QRadioButton {
+        color: #303030;
+    }
+
+    QRadioButton::indicator:unchecked {
+        background-color: #ffffff;
+        border: 1px solid #c0c0c0;
+    }
+
+    QRadioButton::indicator:checked {
+        background-color: #c0d0e0;
+        border: 1px solid #a0b0c0;
+    }
+
+    /* 分组框 */
+    QGroupBox {
+        background-color: #f0f0f0;
+        color: #0080c0;
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 表格 */
+    QTableView, QTableWidget {
+        background-color: #ffffff;
+        color: #303030;
+        gridline-color: #c0c0c0;
+        selection-background-color: #c0d0e0;
+        selection-color: #000000;
+        alternate-background-color: #f5f5f5;
+    }
+
+    QTableView QHeaderView::section, QTableWidget QHeaderView::section {
+        background-color: #e0e0e0;
+        color: #0080c0;
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 列表 */
+    QListView, QListWidget {
+        background-color: #ffffff;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+        selection-background-color: #c0d0e0;
+    }
+
+    /* 树视图 */
+    QTreeView, QTreeWidget {
+        background-color: #ffffff;
+        color: #303030;
+        border: 1px solid #c0c0c0;
+        selection-background-color: #c0d0e0;
+    }
+
+    /* 控制面板特殊样式 */
+    #controlPanel {
+        background-color: #f5f5f5;
+        border: 1px solid #c0c0c0;
+    }
+
+    #controlPanel QLabel {
+        color: #0080c0;
+    }
+
+    /* 告警面板特殊样式 */
+    #alertPanel {
+        background-color: #f5f5f5;
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 数据分析面板样式 */
+    #analysisPanel {
+        background-color: #f5f5f5;
+        border: 1px solid #c0c0c0;
+    }
+
+    /* 摄像头视图样式 */
+    #cameraView {
+        border: 2px solid #a0b0c0;
+    }
+
+    /* 地图视图样式 */
+    #mapView {
+        border: 2px solid #a0b0c0;
+    }
+
+    /* 数据标签特殊样式 */
+    .dataLabel {
+        color: #00a080;
+    }
+
+    /* 告警标签特殊样式 */
+    .alertLabel {
+        color: #ff0000;
+    }
+} 

BIN
ui/components/__pycache__/alert_panel.cpython-38.pyc


BIN
ui/components/__pycache__/alert_panel.cpython-39.pyc


BIN
ui/components/__pycache__/camera_manager.cpython-39.pyc


BIN
ui/components/__pycache__/camera_view.cpython-38.pyc


BIN
ui/components/__pycache__/camera_view.cpython-39.pyc


BIN
ui/components/__pycache__/camera_widget.cpython-39.pyc


BIN
ui/components/__pycache__/control_panel.cpython-38.pyc


BIN
ui/components/__pycache__/control_panel.cpython-39.pyc


BIN
ui/components/__pycache__/drone_manager.cpython-39.pyc


BIN
ui/components/__pycache__/fire_detection.cpython-39.pyc


BIN
ui/components/__pycache__/grid_camera_view.cpython-39.pyc


BIN
ui/components/__pycache__/map_view.cpython-38.pyc


BIN
ui/components/__pycache__/map_view.cpython-39.pyc


BIN
ui/components/__pycache__/statistics_panel.cpython-38.pyc


BIN
ui/components/__pycache__/statistics_panel.cpython-39.pyc


+ 651 - 0
ui/components/alert_panel.py

@@ -0,0 +1,651 @@
+import os
+import time
+from datetime import datetime
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
+                            QTableWidget, QTableWidgetItem, QHeaderView, QComboBox,
+                            QGroupBox, QToolBar, QAction, QMenu, QAbstractItemView, QMessageBox, QDialog, QProgressBar)
+from PyQt5.QtCore import Qt, pyqtSlot, QSize, QTimer, QDateTime, pyqtSignal
+from PyQt5.QtGui import QIcon, QColor, QBrush, QFont, QPixmap
+import random
+
+class AlertPanel(QWidget):
+    """告警面板组件,显示各类灾害告警信息"""
+    
+    # 添加信号用于通知统计面板
+    alert_added = pyqtSignal(str, str)  # 参数: alert_type, region
+    alert_processed = pyqtSignal(str, str)  # 参数: alert_type, region - 新增信号用于通知告警已处理
+    
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.alerts = []  # 保存所有告警
+        self.current_filter = "all"  # 当前过滤类型
+        self.current_severity = "all"  # 当前严重程度过滤
+        
+        # 从配置文件获取随机告警设置
+        self.random_alert_config = config.get('random_alert', {
+            'enabled': False,
+            'interval': 5,
+            'probability': 0.3,
+            'types': ['fire', 'animal', 'landslide', 'forest_degradation', 'pest'],
+            'locations': ['北部山区', '南部林区', '东部山脊', '西部谷地', '中央林场']
+        })
+        
+        self.init_ui()
+        
+        # 设置自动更新定时器
+        self.update_timer = QTimer(self)
+        self.update_timer.timeout.connect(self.update_alerts)
+        self.update_timer.start(self.random_alert_config['interval'] * 1000)  # 转换为毫秒
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建工具栏
+        toolbar = QHBoxLayout()
+        
+        # 添加标题
+        title_label = QLabel("告警信息")
+        title_label.setFont(QFont("Microsoft YaHei", 12, QFont.Bold))
+        title_label.setStyleSheet("color: #00e6e6; margin: 5px;")
+        toolbar.addWidget(title_label)
+        
+        # 添加分隔符
+        toolbar.addStretch(1)
+        
+        # 添加随机告警开关
+        self.random_alert_btn = QPushButton(f"随机告警: {'开' if self.random_alert_config['enabled'] else '关'}")
+        self.random_alert_btn.setCheckable(True)  # 使按钮可切换
+        self.random_alert_btn.setChecked(self.random_alert_config['enabled'])  # 设置初始状态
+        self.random_alert_btn.setFixedWidth(100)
+        self.random_alert_btn.clicked.connect(self.toggle_random_alerts)
+        self.random_alert_btn.setStyleSheet("""
+            QPushButton {
+                background-color: #666666;
+                color: white;
+                border: none;
+                border-radius: 4px;
+                padding: 5px;
+            }
+            QPushButton:checked {
+                background-color: #4CAF50;
+            }
+        """)
+        toolbar.addWidget(self.random_alert_btn)
+        
+        # 添加过滤下拉框
+        self.filter_combo = QComboBox()
+        self.filter_combo.addItem("全部告警", "all")
+        self.filter_combo.addItem("火灾告警", "fire")
+        self.filter_combo.addItem("动物告警", "animal")
+        self.filter_combo.addItem("滑坡告警", "landslide")
+        self.filter_combo.addItem("森林退化告警", "forest")
+        self.filter_combo.addItem("病虫害告警", "pest")
+        self.filter_combo.currentIndexChanged.connect(self.filter_alerts)
+        toolbar.addWidget(QLabel("过滤: "))
+        toolbar.addWidget(self.filter_combo)
+        
+        # 添加严重程度过滤下拉框
+        self.severity_combo = QComboBox()
+        self.severity_combo.addItem("所有等级", "all")
+        self.severity_combo.addItem("高", "high")
+        self.severity_combo.addItem("中", "medium")
+        self.severity_combo.addItem("低", "low")
+        self.severity_combo.currentIndexChanged.connect(self.filter_alerts)
+        toolbar.addWidget(QLabel("严重程度: "))
+        toolbar.addWidget(self.severity_combo)
+        
+        # 添加按钮
+        self.clear_btn = QPushButton("清空")
+        self.clear_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'clear.png')))
+        self.clear_btn.clicked.connect(self.clear_alerts)
+        toolbar.addWidget(self.clear_btn)
+        
+        # 添加工具栏到布局
+        layout.addLayout(toolbar)
+        
+        # 创建表格
+        self.alert_table = QTableWidget()
+        self.alert_table.setColumnCount(6)
+        self.alert_table.setHorizontalHeaderLabels(["时间", "类型", "位置", "详情", "等级", "操作"])
+        self.alert_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
+        self.alert_table.setSelectionBehavior(QAbstractItemView.SelectRows)
+        self.alert_table.setAlternatingRowColors(True)
+        self.alert_table.setStyleSheet("alternate-background-color: #0c1e32; background-color: #081a2e; color: white; "
+                                      "QHeaderView::section { background-color: #15253a; color: white; padding: 4px; "
+                                      "border: 1px solid #1e3a5a; font-weight: bold; }"
+                                      "QTableView { gridline-color: #1e3a5a; border: 1px solid #1e3a5a; }"
+                                      "QTableView::item:selected { background-color: #2a4a6a; }")
+        
+        # 设置列宽
+        self.alert_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeToContents)
+        self.alert_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
+        self.alert_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
+        self.alert_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
+        self.alert_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
+        self.alert_table.horizontalHeader().setSectionResizeMode(5, QHeaderView.ResizeToContents)
+        
+        # 双击行处理
+        self.alert_table.cellDoubleClicked.connect(self.show_alert_detail)
+        
+        # 添加表格到布局
+        layout.addWidget(self.alert_table)
+        
+        # 设置最小高度
+        self.setMinimumHeight(150)
+        
+    def add_alert(self, alert):
+        """添加告警到列表和表格"""
+        # 添加到告警列表
+        self.alerts.append(alert)
+        
+        # 发出信号通知统计面板
+        print(f"发送告警信号: 类型={alert['type']}, 区域={alert['location']}")
+        self.alert_added.emit(alert['type'], alert['location'])
+        
+        # 检查是否符合当前过滤条件
+        type_match = self.current_filter == "all" or alert['type'] == self.current_filter
+        severity_match = self.current_severity == "all" or alert['level'] == self.current_severity
+        
+        if type_match and severity_match:
+            self.add_filtered_alert(alert)
+        
+    def get_alert_type_name(self, alert_type):
+        """获取告警类型的中文名称
+        
+        Args:
+            alert_type (str): 告警类型
+            
+        Returns:
+            str: 告警类型的中文名称
+        """
+        type_names = {
+            "fire": "森林火灾",
+            "animal": "野生动物异常",
+            "landslide": "山体滑坡",
+            "forest_degradation": "森林退化",
+            "pest": "病虫害"
+        }
+        return type_names.get(alert_type, "未知类型")
+        
+    def get_alert_type_color(self, alert_type):
+        """获取告警类型显示颜色"""
+        type_colors = {
+            'fire': QColor(255, 200, 200),  # 红色
+            'animal': QColor(200, 255, 200),  # 绿色
+            'landslide': QColor(200, 200, 255),  # 蓝色
+            'pest': QColor(230, 190, 255)  # 紫色
+        }
+        return type_colors.get(alert_type, QColor(240, 240, 240))
+        
+    def get_alert_level_name(self, level):
+        """获取告警等级的中文名称
+        
+        Args:
+            level (int): 告警等级
+            
+        Returns:
+            str: 告警等级的中文名称
+        """
+        level_names = {
+            1: "一级 (紧急)",
+            2: "二级 (高危)",
+            3: "三级 (中危)",
+            4: "四级 (低危)",
+            5: "五级 (提示)"
+        }
+        return level_names.get(level, "未知等级")
+        
+    def get_alert_level_color(self, level):
+        """获取告警等级显示颜色"""
+        level_colors = {
+            'high': QColor(255, 100, 100),  # 红色
+            'medium': QColor(255, 200, 100),  # 橙色
+            'low': QColor(200, 200, 200),  # 灰色
+            'processed': QColor(100, 220, 100)  # 绿色,表示已处理
+        }
+        return level_colors.get(level, QColor(240, 240, 240))
+        
+    @pyqtSlot(int)
+    def filter_alerts(self, index):
+        """过滤告警"""
+        # 获取当前过滤类型
+        self.current_filter = self.filter_combo.currentData()
+        self.current_severity = self.severity_combo.currentData()
+        
+        # 清空表格
+        self.alert_table.setRowCount(0)
+        
+        # 重新添加符合条件的告警
+        for alert in self.alerts:
+            type_match = self.current_filter == "all" or alert['type'] == self.current_filter
+            severity_match = self.current_severity == "all" or alert['level'] == self.current_severity
+            
+            if type_match and severity_match:
+                self.add_filtered_alert(alert)
+                
+    def add_filtered_alert(self, alert):
+        """添加过滤后的告警到表格(不添加到告警列表)"""
+        row = self.alert_table.rowCount()
+        self.alert_table.insertRow(row)
+        
+        # 设置单元格内容
+        self.alert_table.setItem(row, 0, QTableWidgetItem(alert['time']))
+        
+        # 根据类型设置显示名称和颜色
+        type_item = QTableWidgetItem(self.get_alert_type_name(alert['type']))
+        type_item.setData(Qt.UserRole, alert['type'])
+        type_item.setBackground(self.get_alert_type_color(alert['type']))
+        self.alert_table.setItem(row, 1, type_item)
+        
+        self.alert_table.setItem(row, 2, QTableWidgetItem(alert['location']))
+        self.alert_table.setItem(row, 3, QTableWidgetItem(alert['detail']))
+        
+        # 根据等级设置显示名称和颜色
+        level_item = QTableWidgetItem(self.get_alert_level_name(alert['level']))
+        level_item.setBackground(self.get_alert_level_color(alert['level']))
+        self.alert_table.setItem(row, 4, level_item)
+        
+        # 操作按钮
+        btn_cell = QWidget()
+        btn_layout = QHBoxLayout(btn_cell)
+        btn_layout.setContentsMargins(2, 2, 2, 2)
+        
+        details_btn = QPushButton("详情")
+        details_btn.setFixedWidth(60)
+        details_btn.clicked.connect(lambda _, a=alert: self.show_alert_details(a))
+        
+        handle_btn = QPushButton("处理")
+        handle_btn.setFixedWidth(60)
+        if alert['level'] == 'high':
+            handle_btn.setEnabled(False)
+        handle_btn.clicked.connect(lambda _, r=row: self.handle_alert(r))
+        
+        btn_layout.addWidget(details_btn)
+        btn_layout.addWidget(handle_btn)
+        btn_cell.setLayout(btn_layout)
+        
+        self.alert_table.setCellWidget(row, 5, btn_cell)
+        
+        # 自动滚动到最新行
+        self.alert_table.scrollToItem(self.alert_table.item(row, 0))
+        
+    @pyqtSlot()
+    def clear_alerts(self):
+        """清空告警"""
+        self.alerts = []
+        self.alert_table.setRowCount(0)
+        
+    @pyqtSlot(int, int)
+    def show_alert_detail(self, row, column):
+        """显示告警详情(双击行时触发)"""
+        # 实际项目中可以打开告警详情对话框
+        print(f"显示第 {row} 行告警详情")
+        
+    def show_alert_details(self, alert):
+        """显示告警详情"""
+        # 防止重复触发
+        current_time = datetime.now()
+        if hasattr(self, 'last_detail_time'):
+            if (current_time - self.last_detail_time).total_seconds() < 1:
+                return
+        self.last_detail_time = current_time
+        
+        msg = QMessageBox()
+        msg.setWindowTitle("预警详情")
+        
+        details = f"""
+        时间:{alert['time']}
+        类型:{self.get_alert_type_name(alert['type'])}
+        位置:{alert['location']}
+        严重程度:{alert['level']}
+        详细信息:{alert['detail']}
+        """
+        
+        msg.setText(details)
+        msg.setIcon(QMessageBox.Information)
+        msg.exec_()
+        
+    def handle_alert(self, row):
+        """处理告警"""
+        # 防止重复处理
+        current_time = datetime.now()
+        if hasattr(self, 'last_handle_time'):
+            if (current_time - self.last_handle_time).total_seconds() < 1:
+                return
+        self.last_handle_time = current_time
+        
+        # 获取告警数据
+        alert = self.alerts[row]
+        
+        # 向护林员发送通知
+        self.send_notification_to_ranger(alert)
+        
+        # 降低告警等级 
+        original_level = alert['level']
+        if alert['level'] == 'high':
+            alert['level'] = 'medium'
+        elif alert['level'] == 'medium':
+            alert['level'] = 'low'
+        elif alert['level'] == 'low':
+            alert['level'] = 'processed'  # 添加'processed'状态表示已完全处理
+        
+        # 更新UI显示
+        level_name = self.get_alert_level_name(alert['level'])
+        level_color = self.get_alert_level_color(alert['level'])
+        
+        level_item = QTableWidgetItem(level_name)
+        level_item.setBackground(level_color)
+        self.alert_table.setItem(row, 4, level_item)
+        
+        # 如果已完全处理,禁用处理按钮
+        if alert['level'] == 'processed':
+            cell_widget = self.alert_table.cellWidget(row, 5)
+            if cell_widget:
+                for child in cell_widget.children():
+                    if isinstance(child, QPushButton) and child.text() == "处理":
+                        child.setEnabled(False)
+                        child.setText("已处理")
+                        break
+        
+        # 发送告警已处理信号
+        print(f"发送告警处理信号: 类型={alert['type']}, 区域={alert['location']}")
+        self.alert_processed.emit(alert['type'], alert['location'])
+        
+        # 滚动到当前行
+        self.alert_table.scrollToItem(self.alert_table.item(row, 0))
+        
+    def send_notification_to_ranger(self, alert):
+        """向护林员发送通知
+        
+        Args:
+            alert (dict): 告警信息
+        """
+        from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, 
+                                   QFrame, QSpacerItem, QSizePolicy, QProgressBar)
+        from PyQt5.QtGui import QFont, QPixmap, QIcon
+        from PyQt5.QtCore import Qt, QTimer
+        import os
+        
+        # 根据告警类型确定应联系的组织
+        alert_type = alert['type']
+        organization = self.get_responsible_organization(alert_type)
+        
+        # 创建自定义对话框
+        dialog = QDialog(self)
+        dialog.setWindowTitle(f"通知{organization}")
+        dialog.setMinimumWidth(500)
+        dialog.setStyleSheet("""
+            QDialog {
+                background-color: #f5f5f5;
+                border: 1px solid #e0e0e0;
+                border-radius: 5px;
+            }
+            QLabel {
+                color: #333333;
+            }
+            QPushButton {
+                background-color: #2196f3;
+                color: white;
+                border: none;
+                padding: 8px 16px;
+                border-radius: 4px;
+            }
+            QPushButton:hover {
+                background-color: #0d8bf2;
+            }
+            QPushButton:pressed {
+                background-color: #0a75cf;
+            }
+            QFrame {
+                border: 1px solid #e0e0e0;
+                border-radius: 4px;
+                background-color: white;
+            }
+            QProgressBar {
+                border: 1px solid #e0e0e0;
+                border-radius: 4px;
+                background-color: #f0f0f0;
+                text-align: center;
+            }
+            QProgressBar::chunk {
+                background-color: #4caf50;
+                border-radius: 3px;
+            }
+        """)
+        
+        # 创建布局
+        layout = QVBoxLayout(dialog)
+        
+        # 标题区域
+        title_layout = QHBoxLayout()
+        
+        # 尝试根据告警类型添加图标
+        icon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', f"{alert_type}.png")
+        icon_label = QLabel()
+        if os.path.exists(icon_path):
+            pixmap = QPixmap(icon_path).scaled(48, 48, Qt.KeepAspectRatio, Qt.SmoothTransformation)
+            icon_label.setPixmap(pixmap)
+        else:
+            # 使用默认警告图标
+            icon_label.setText("⚠️")
+            icon_label.setFont(QFont("Arial", 24))
+        
+        icon_label.setFixedSize(60, 60)
+        icon_label.setAlignment(Qt.AlignCenter)
+        title_layout.addWidget(icon_label)
+        
+        # 标题文本
+        title_text = QLabel("正在发送告警通知...")
+        title_text.setFont(QFont("Microsoft YaHei", 14, QFont.Bold))
+        title_layout.addWidget(title_text, 1)
+        
+        layout.addLayout(title_layout)
+        
+        # 进度条
+        progress_bar = QProgressBar()
+        progress_bar.setRange(0, 100)
+        progress_bar.setValue(0)
+        progress_bar.setFixedHeight(20)
+        layout.addWidget(progress_bar)
+        
+        # 状态标签
+        status_label = QLabel("正在连接通信系统...")
+        status_label.setAlignment(Qt.AlignCenter)
+        layout.addWidget(status_label)
+        
+        # 分隔线
+        separator = QFrame()
+        separator.setFrameShape(QFrame.HLine)
+        separator.setFrameShadow(QFrame.Sunken)
+        layout.addWidget(separator)
+        
+        # 信息面板
+        info_frame = QFrame()
+        info_layout = QVBoxLayout(info_frame)
+        
+        # 添加告警信息
+        info_layout.addWidget(self.create_info_row("告警类型", self.get_alert_type_name(alert_type)))
+        info_layout.addWidget(self.create_info_row("告警区域", alert['location']))
+        info_layout.addWidget(self.create_info_row("告警等级", self.get_alert_level_name(alert['level'])))
+        info_layout.addWidget(self.create_info_row("详细信息", alert['detail']))
+        info_layout.addWidget(self.create_info_row("时间", alert['time']))
+        info_layout.addWidget(self.create_info_row("接收组织", organization))
+        
+        layout.addWidget(info_frame)
+        
+        # 添加消息
+        message_label = QLabel("")
+        message_label.setWordWrap(True)
+        message_label.setAlignment(Qt.AlignCenter)
+        message_label.setFont(QFont("Microsoft YaHei", 10))
+        message_label.setStyleSheet("color: #666666; margin: 10px;")
+        layout.addWidget(message_label)
+        
+        # 按钮区域
+        button_layout = QHBoxLayout()
+        button_layout.addStretch()
+        
+        # 禁用关闭按钮,等待通知发送完成
+        close_button = QPushButton("关闭")
+        close_button.setFixedWidth(120)
+        close_button.setEnabled(False)  # 先禁用按钮
+        close_button.clicked.connect(dialog.accept)
+        button_layout.addWidget(close_button)
+        
+        layout.addLayout(button_layout)
+        
+        # 设置进度条更新的定时器
+        progress = 0
+        timer = QTimer(dialog)
+        
+        # 进度模拟阶段
+        stages = [
+            (10, "正在连接通信系统..."),
+            (30, f"正在向{organization}发送告警信息..."),
+            (60, f"等待{organization}确认接收..."),
+            (90, f"{organization}已确认接收告警信息"),
+            (100, f"通知流程完成,{organization}将立即处理")
+        ]
+        current_stage = 0
+        
+        def update_progress():
+            nonlocal progress, current_stage
+            
+            # 更新进度条
+            progress += 5
+            if progress > 100:
+                progress = 100
+                
+            progress_bar.setValue(progress)
+            
+            # 检查是否需要更新阶段
+            if current_stage < len(stages) and progress >= stages[current_stage][0]:
+                status_label.setText(stages[current_stage][1])
+                current_stage += 1
+                
+            # 通知完成时
+            if progress >= 100:
+                timer.stop()
+                title_text.setText(f"已向{organization}发送告警通知")
+                message_label.setText(f"相关{organization}已接收通知,并将前往现场处理。\n系统会持续跟踪处理进度,直至告警解除。")
+                close_button.setEnabled(True)  # 启用关闭按钮
+                
+                # 在控制台打印日志
+                print(f"已向{organization}发送告警通知: {alert_type} - {alert['location']}")
+                
+        # 启动定时器
+        timer.timeout.connect(update_progress)
+        timer.start(150)  # 每150毫秒更新一次
+        
+        # 显示对话框
+        dialog.exec_()
+        
+    def create_info_row(self, label_text, value_text):
+        """创建信息行
+        
+        Args:
+            label_text (str): 标签文本
+            value_text (str): 值文本
+            
+        Returns:
+            QWidget: 包含标签和值的行
+        """
+        from PyQt5.QtWidgets import QWidget, QHBoxLayout, QLabel
+        from PyQt5.QtGui import QFont
+        from PyQt5.QtCore import Qt
+        
+        row = QWidget()
+        layout = QHBoxLayout(row)
+        layout.setContentsMargins(0, 5, 0, 5)
+        
+        label = QLabel(label_text + ":")
+        label.setFixedWidth(80)
+        label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
+        label.setFont(QFont("Microsoft YaHei", 10, QFont.Bold))
+        
+        value = QLabel(value_text)
+        value.setWordWrap(True)
+        value.setFont(QFont("Microsoft YaHei", 10))
+        value.setStyleSheet("color: #333333;")
+        
+        layout.addWidget(label)
+        layout.addWidget(value, 1)
+        
+        return row
+        
+    def get_responsible_organization(self, alert_type):
+        """根据告警类型获取负责组织
+        
+        Args:
+            alert_type (str): 告警类型
+            
+        Returns:
+            str: 负责组织名称
+        """
+        organizations = {
+            "fire": "森林消防队",
+            "animal": "野生动物保护中心",
+            "landslide": "地质灾害应急中心",
+            "forest_degradation": "森林修复小组",
+            "pest": "病虫害防治站"
+        }
+        return organizations.get(alert_type, "森林管理部门")
+        
+    def toggle_random_alerts(self):
+        """切换随机告警功能"""
+        self.random_alert_config['enabled'] = not self.random_alert_config['enabled']
+        self.random_alert_btn.setText(f"随机告警: {'开' if self.random_alert_config['enabled'] else '关'}")
+        
+        # 立即生成一个告警作为反馈
+        if self.random_alert_config['enabled']:
+            self.generate_random_alert()
+            
+    def generate_random_alert(self):
+        """生成一个随机告警"""
+        # 根据当前过滤类型选择告警类型
+        if self.current_filter != "all":
+            alert_type = self.current_filter
+        else:
+            alert_type = random.choice(self.random_alert_config['types'])
+            
+        # 根据当前严重程度过滤选择告警等级
+        if self.current_severity != "all":
+            alert_level = self.current_severity
+        else:
+            alert_level = random.choice(['high', 'medium', 'low'])
+            
+        alert_location = random.choice(self.random_alert_config['locations'])
+        
+        print(f"告警面板生成新告警: 类型={alert_type}, 区域={alert_location}")
+        
+        # 为病虫害生成特殊的详情
+        if alert_type == 'pest':
+            pest_types = ['松毛虫', '美国白蛾', '落叶松毛虫', '杨树食叶害虫', '松材线虫病']
+            pest_type = random.choice(pest_types)
+            area = random.randint(10, 200)
+            detail = f'检测到{pest_type}病虫害,受灾面积约{area}平方米'
+        else:
+            detail = f'新检测到的告警 #{len(self.alerts) + 1}'
+        
+        alert = {
+            'time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            'type': alert_type,
+            'location': alert_location,
+            'detail': detail,
+            'level': alert_level
+        }
+        
+        self.add_alert(alert)
+        
+    def update_alerts(self):
+        """更新告警(由定时器调用)"""
+        # 如果启用了随机告警,则有机会生成新告警
+        if self.random_alert_config['enabled'] and random.random() < self.random_alert_config['probability']:
+            self.generate_random_alert()
+        
+        # 更新UI显示
+        self.alert_table.viewport().update() 

+ 885 - 0
ui/components/camera_view.py

@@ -0,0 +1,885 @@
+import os
+import cv2
+import numpy as np
+from datetime import datetime
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel, 
+                            QPushButton, QComboBox, QToolBar, QAction, 
+                            QGridLayout, QFrame, QSplitter, QFileDialog,
+                            QMessageBox)
+from PyQt5.QtCore import Qt, QTimer, pyqtSlot, pyqtSignal, QSize, QRect, QThread
+from PyQt5.QtGui import QImage, QPixmap, QIcon, QPainter, QPen, QColor, QFont
+import random
+import time
+import torch
+import torch.backends.cudnn as cudnn
+from pathlib import Path
+import sys
+
+# 添加项目根目录到系统路径
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[2]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))
+
+from models.common import DetectMultiBackend
+from models.yolo import Detect
+from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
+from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
+                         increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
+from utils.plots import Annotator, colors, save_one_box
+from utils.torch_utils import select_device, time_sync
+
+class VideoThread(QThread):
+    """视频处理线程,用于读取摄像头或视频文件"""
+    update_frame = pyqtSignal(np.ndarray)
+    update_detections = pyqtSignal(list)
+    fire_detected = pyqtSignal(str)  # 修改为发送区域信息的信号
+    error_signal = pyqtSignal(str)
+    
+    def __init__(self, source=0):
+        super().__init__()
+        self.source = source
+        self.running = False
+        self.model = None
+        self.device = select_device('0' if torch.cuda.is_available() else 'cpu')
+        self.half = False
+        self.stride = 32
+        self.imgsz = [640, 640]
+        self.current_region = "中央林场"  # 添加默认区域
+        self.init_model()
+        
+    def init_model(self):
+        """初始化YOLOv5火焰检测模型"""
+        try:
+            # 获取项目根目录路径
+            weights = str(ROOT / 'weights/best.pt')
+            
+            if not os.path.exists(weights):
+                print(f"错误:模型文件 {weights} 不存在")
+                self.model = None
+                return
+                
+            # 加载模型
+            self.model = DetectMultiBackend(weights, device=self.device)
+            self.stride = self.model.stride
+            self.imgsz = check_img_size(self.imgsz, s=self.stride)  # 检查图片尺寸
+            
+            # 打印模型支持的类别
+            print("模型支持的类别:", self.model.names)
+            
+            # 设置半精度
+            self.half = self.device.type != 'cpu'  # 仅在GPU上使用半精度
+            if self.half:
+                self.model.half()  # 转换模型为半精度
+            else:
+                self.model.float()  # 使用单精度
+
+            # 确保Detect层有inplace属性
+            for m in self.model.model.modules():
+                if isinstance(m, Detect):
+                    if not hasattr(m, 'inplace'):
+                        m.inplace = True
+            
+            # 预热模型
+            if self.device.type != 'cpu':
+                dummy = torch.zeros(1, 3, *self.imgsz).to(self.device)
+                dummy = dummy.half() if self.half else dummy.float()
+                for _ in range(3):  # 预热3次
+                    with torch.no_grad():
+                        self.model(dummy)  # 预热推理
+                torch.cuda.empty_cache()  # 清理显存
+                
+            print(f"火焰检测模型加载成功:{weights}")
+            
+            # 设置CUDA性能优化
+            if self.device.type != 'cpu':
+                cudnn.benchmark = True  # 加速固定大小图像的推理
+                cudnn.deterministic = False  # 提高速度
+                
+        except Exception as e:
+            print(f"火焰检测模型加载失败: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            self.model = None
+            self.error_signal.emit(f"模型加载失败: {str(e)}")
+        
+    def set_source(self, source):
+        """设置视频源"""
+        self.source = source
+        
+    def run(self):
+        """运行线程,读取视频并进行检测"""
+        if self.model is None:
+            self.error_signal.emit("错误:模型未加载")
+            return
+            
+        try:
+            self.running = True
+            
+            # 初始化参数
+            conf_thres = 0.25  # 置信度阈值
+            iou_thres = 0.45  # NMS IOU阈值
+            max_det = 1000  # 每张图片最大检测数量
+            line_thickness = 2  # 减小边框线条粗细以提高性能
+            hide_labels = False  # 是否隐藏标签
+            hide_conf = False  # 是否隐藏置信度
+            
+            # 设置数据源
+            source = str(self.source)
+            webcam = source.isnumeric()
+            
+            # 检查视频源
+            if webcam:
+                try:
+                    source = int(source)
+                    cap = cv2.VideoCapture(source)
+                    if not cap.isOpened():
+                        self.error_signal.emit(f"无法打开摄像头 {source}")
+                        return
+                    cap.release()
+                except Exception as e:
+                    self.error_signal.emit(f"摄像头初始化失败: {str(e)}")
+                    return
+            else:
+                if not os.path.exists(source):
+                    self.error_signal.emit(f"视频文件不存在: {source}")
+                    return
+                try:
+                    cap = cv2.VideoCapture(source)
+                    if not cap.isOpened():
+                        self.error_signal.emit(f"无法打开视频文件: {source}")
+                        return
+                    # 检查视频是否可读
+                    ret, frame = cap.read()
+                    if not ret:
+                        self.error_signal.emit(f"无法读取视频帧: {source}")
+                        return
+                    cap.release()
+                except Exception as e:
+                    self.error_signal.emit(f"视频文件打开失败: {str(e)}")
+                    return
+            
+            # 获取模型信息
+            stride = self.model.stride
+            names = self.model.names
+            pt = getattr(self.model, 'pt', True)
+            
+            # 检查图像尺寸
+            self.imgsz = check_img_size(self.imgsz, s=stride)
+            
+            try:
+                # 设置数据加载器
+                if webcam:
+                    cudnn.benchmark = True
+                    dataset = LoadStreams(str(source), img_size=self.imgsz[0], stride=stride, auto=pt)
+                else:
+                    dataset = LoadImages(source, img_size=self.imgsz[0], stride=stride, auto=pt)
+            except Exception as e:
+                self.error_signal.emit(f"数据加载失败: {str(e)}")
+                return
+                
+            # 处理每一帧
+            for path, im, im0s, vid_cap, s in dataset:
+                if not self.running:
+                    break
+                    
+                try:
+                    # 预处理图像
+                    im = torch.from_numpy(im).to(self.device)
+                    im = im.half() if self.half else im.float()
+                    im /= 255
+                    if len(im.shape) == 3:
+                        im = im[None]
+                        
+                    # 推理
+                    with torch.no_grad():
+                        pred = self.model(im, augment=False)  # 禁用数据增强以提高速度
+                        if isinstance(pred, (list, tuple)):
+                            pred = pred[0]
+                    
+                    # NMS
+                    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det)
+                    
+                    # 处理检测结果
+                    for i, det in enumerate(pred):
+                        if webcam:
+                            im0 = im0s[i].copy()
+                        else:
+                            im0 = im0s.copy()
+                            
+                        s += '%gx%g ' % im.shape[2:]
+                        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+                        
+                        if len(det):
+                            # 将边界框从img_size缩放到im0大小
+                            det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
+                            
+                            detections = []
+                            fire_detected = False
+                            
+                            for c in det[:, -1].unique():
+                                n = (det[:, -1] == c).sum()
+                                s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "
+                            
+                            for *xyxy, conf, cls in reversed(det):
+                                c = int(cls)
+                                if conf > conf_thres:  # 如果置信度大于阈值
+                                    label = None if hide_labels else (
+                                        names[c] if hide_conf else f'{names[c]} {conf:.2f}'
+                                    )
+                                    annotator.box_label(xyxy, label, color=colors(c, True))
+                                    
+                                    # 如果检测到火焰,发送信号
+                                    if names[c].lower() == 'fire' and conf > 0.5:  # 提高火焰检测的置信度阈值
+                                        print(f"检测到火焰!类别:{names[c]},置信度:{conf:.2f}")  # 添加调试输出
+                                        self.fire_detected.emit(self.current_region)  # 发送当前区域信息
+                            
+                            self.update_detections.emit(detections)
+                        
+                        im0 = annotator.result()
+                        self.update_frame.emit(im0)
+                        
+                except Exception as e:
+                    print(f"处理帧时出错: {str(e)}")
+                    import traceback
+                    traceback.print_exc()
+                    continue
+                
+                self.msleep(10)  # 减小延迟时间,提高帧率
+                
+        except Exception as e:
+            print(f"视频处理线程出错: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            self.error_signal.emit(f"视频处理出错: {str(e)}")
+            
+        finally:
+            self.running = False
+        
+    def draw_box(self, img, xyxy, label):
+        """在图像上绘制边界框和标签"""
+        x1, y1, x2, y2 = [int(x) for x in xyxy]
+        color = (0, 0, 255)  # 红色
+        
+        # 绘制边界框
+        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
+        
+        # 绘制标签背景
+        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
+        cv2.rectangle(img, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
+        
+        # 绘制标签
+        cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
+        
+        return img
+        
+    def stop(self):
+        """停止线程"""
+        self.running = False
+        self.wait()
+        
+    def preprocess_frame(self, frame):
+        """预处理帧,用于模型输入"""
+        # 缩放到模型所需尺寸
+        resized = cv2.resize(frame, (640, 640))
+        # 转换为RGB
+        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
+        # 归一化
+        normalized = rgb / 255.0
+        # 转换为tensor
+        img = torch.from_numpy(normalized).float()
+        # 调整维度顺序 (H, W, C) -> (C, H, W)
+        img = img.permute(2, 0, 1)
+        # 添加批次维度
+        img = img.unsqueeze(0)
+        return img
+        
+    def mock_detections(self, frame):
+        """模拟检测结果,用于开发阶段"""
+        height, width = frame.shape[:2]
+        
+        # 模拟一些检测结果
+        detections = [
+            {
+                'task': 'fire',
+                'class': 0,  # 火焰
+                'label': '火灾',
+                'confidence': 0.85,
+                'bbox': [width * 0.1, height * 0.2, width * 0.2, height * 0.3]  # [x1, y1, x2, y2]
+            },
+            {
+                'task': 'animal',
+                'class': 2,  # 某种动物
+                'label': '野生动物',
+                'confidence': 0.76,
+                'bbox': [width * 0.6, height * 0.5, width * 0.8, height * 0.7]
+            },
+            {
+                'task': 'pest',
+                'class': 3,  # 病虫害
+                'label': '病虫害-松毛虫',
+                'confidence': 0.82,
+                'bbox': [width * 0.3, height * 0.4, width * 0.5, height * 0.6],
+                'subtype': '松毛虫',
+                'severity': '中度'
+            }
+        ]
+        
+        # 每秒随机变化一下位置,模拟运动
+        random.seed(int(time.time()))
+        
+        for det in detections:
+            # 随机移动边界框
+            x1, y1, x2, y2 = det['bbox']
+            dx = random.uniform(-10, 10)
+            dy = random.uniform(-10, 10)
+            
+            # 确保边界框在图像内
+            x1 = max(0, min(width - 10, x1 + dx))
+            y1 = max(0, min(height - 10, y1 + dy))
+            x2 = max(x1 + 10, min(width, x2 + dx))
+            y2 = max(y1 + 10, min(height, y2 + dy))
+            
+            det['bbox'] = [x1, y1, x2, y2]
+        
+        return detections
+        
+    def draw_detections(self, frame, detections):
+        """在帧上绘制检测结果"""
+        for det in detections:
+            # 获取边界框和标签
+            x1, y1, x2, y2 = [int(c) for c in det['bbox']]
+            label = f"{det['label']} {det['confidence']:.2f}"
+            
+            # 根据任务选择颜色
+            if det['task'] == 'fire':
+                color = (0, 0, 255)  # 红色
+            elif det['task'] == 'animal':
+                color = (0, 255, 0)  # 绿色
+            elif det['task'] == 'landslide':
+                color = (255, 0, 0)  # 蓝色
+            elif det['task'] == 'pest':
+                color = (128, 0, 128)  # 紫色
+            else:
+                color = (255, 255, 0)  # 青色
+                
+            # 绘制边界框
+            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+            
+            # 绘制标签背景
+            text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
+            cv2.rectangle(frame, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
+            
+            # 绘制标签
+            cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
+            
+            # 如果是病虫害,添加额外信息
+            if det['task'] == 'pest' and 'subtype' in det:
+                severity_text = f"类型: {det['subtype']}"
+                if 'severity' in det:
+                    severity_text += f" | 严重程度: {det['severity']}"
+                cv2.putText(frame, severity_text, (x1, y1 - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        return frame
+
+class CameraView(QWidget):
+    """摄像头视图组件,用于显示实时视频流和检测结果"""
+    
+    fire_detected = pyqtSignal(str)  # 修改为发送区域信息的信号
+    
+    def __init__(self, config=None):
+        super().__init__()
+        self.config = config
+        self.detection_results = []
+        self.current_source = '0'  # 默认使用摄像头
+        self.output_size = 400
+        self.current_region = "中央林场"  # 默认区域
+        
+        # 修改设备选择逻辑
+        try:
+            if torch.cuda.is_available():
+                self.device = select_device('0')
+            else:
+                self.device = select_device('cpu')
+        except Exception as e:
+            print(f"GPU初始化失败,使用CPU: {str(e)}")
+            self.device = select_device('cpu')
+        
+        # 创建视频处理线程
+        self.video_thread = VideoThread(self.current_source)
+        self.video_thread.update_frame.connect(self.update_frame)
+        self.video_thread.update_detections.connect(self.update_detections)
+        self.video_thread.fire_detected.connect(self.on_fire_detected)
+        self.video_thread.error_signal.connect(self.on_error)
+        
+        # 创建结果保存目录
+        self.results_dir = str(ROOT / 'results')
+        os.makedirs(self.results_dir, exist_ok=True)
+        
+        self.init_ui()
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建顶部工具栏布局
+        toolbar_layout = QHBoxLayout()
+        
+        # 添加摄像头选择下拉框
+        self.camera_combo = QComboBox()
+        self.camera_combo.addItem("摄像头0", '0')
+        self.camera_combo.addItem("摄像头1", '1')
+        self.camera_combo.addItem("视频文件", '-1')
+        self.camera_combo.currentIndexChanged.connect(self.change_camera)
+        toolbar_layout.addWidget(QLabel("视频源:"))
+        toolbar_layout.addWidget(self.camera_combo)
+        
+        # 添加空白占位
+        toolbar_layout.addStretch(1)
+        
+        # 添加截图按钮
+        self.snapshot_btn = QPushButton("截图")
+        self.snapshot_btn.setIcon(QIcon(str(ROOT / 'ui/assets/snapshot.png')))
+        self.snapshot_btn.clicked.connect(self.take_snapshot)
+        toolbar_layout.addWidget(self.snapshot_btn)
+        
+        # 添加工具栏到布局
+        layout.addLayout(toolbar_layout)
+        
+        # 创建视频显示区域
+        self.video_frame = QLabel()
+        self.video_frame.setAlignment(Qt.AlignCenter)
+        self.video_frame.setMinimumSize(640, 480)
+        self.video_frame.setStyleSheet("background-color: black;")
+        
+        # 添加视频帧到布局
+        layout.addWidget(self.video_frame, 1)
+        
+        # 创建状态栏
+        status_bar = QHBoxLayout()
+        
+        # 添加状态信息
+        self.status_label = QLabel("状态: 未启动")
+        status_bar.addWidget(self.status_label)
+        
+        # 添加FPS信息
+        self.fps_label = QLabel("FPS: 0")
+        status_bar.addWidget(self.fps_label)
+        
+        # 添加检测结果信息
+        self.detection_label = QLabel("检测结果: 0")
+        status_bar.addWidget(self.detection_label)
+        
+        # 添加时间戳
+        self.timestamp_label = QLabel(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        status_bar.addWidget(self.timestamp_label, 1, Qt.AlignRight)
+        
+        # 添加状态栏到布局
+        layout.addLayout(status_bar)
+        
+        # 创建定时器更新时间戳
+        self.timer = QTimer(self)
+        self.timer.timeout.connect(self.update_timestamp)
+        self.timer.start(1000)  # 每秒更新一次
+        
+        # 计算FPS的变量
+        self.frame_count = 0
+        self.fps = 0
+        self.fps_timer = QTimer(self)
+        self.fps_timer.timeout.connect(self.calculate_fps)
+        self.fps_timer.start(1000)  # 每秒计算一次FPS
+        
+    def update_frame(self, frame):
+        """更新视频帧"""
+        # 计算帧数
+        self.frame_count += 1
+        
+        # 转换OpenCV的BGR格式为RGB
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        
+        # 转换为QImage
+        h, w, ch = rgb_frame.shape
+        bytes_per_line = ch * w
+        qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
+        
+        # 缩放到显示区域大小
+        pixmap = QPixmap.fromImage(qt_image)
+        pixmap = pixmap.scaled(self.video_frame.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
+        
+        # 设置图像
+        self.video_frame.setPixmap(pixmap)
+        
+    def update_detections(self, detections):
+        """更新检测结果"""
+        self.detection_results = detections
+        self.detection_label.setText(f"检测结果: {len(detections)}")
+        
+    def update_timestamp(self):
+        """更新时间戳"""
+        self.timestamp_label.setText(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        
+    def calculate_fps(self):
+        """计算FPS"""
+        self.fps = self.frame_count
+        self.fps_label.setText(f"FPS: {self.fps}")
+        self.frame_count = 0
+        
+    def update_view(self):
+        """更新视图(由外部定时器调用)"""
+        # 由于视频帧由线程自动更新,这里主要更新其他UI元素
+        if self.video_thread.isRunning():
+            self.status_label.setText("状态: 监控中")
+        else:
+            self.status_label.setText("状态: 已停止")
+            
+    def start_monitoring(self):
+        """开始监控"""
+        if not self.video_thread.isRunning():
+            self.video_thread.start()
+            self.status_label.setText("状态: 监控中")
+            
+    def stop_monitoring(self):
+        """停止监控"""
+        if self.video_thread.isRunning():
+            self.video_thread.stop()
+            self.status_label.setText("状态: 已停止")
+            
+    @pyqtSlot(int)
+    def change_camera(self, index):
+        """改变摄像头"""
+        # 获取选择的摄像头
+        source = self.camera_combo.currentData()
+        
+        # 如果选择了视频文件
+        if source == '-1':
+            fileName, _ = QFileDialog.getOpenFileName(
+                self,
+                "选择视频文件",
+                "",
+                "视频文件 (*.mp4 *.avi *.mkv *.mov);;所有文件 (*.*)"
+            )
+            if not fileName:
+                # 如果用户取消选择,恢复到之前的选项
+                self.camera_combo.setCurrentIndex(0)
+                return
+            source = fileName
+        
+        # 停止当前视频
+        self.stop_monitoring()
+        
+        # 设置新的视频源
+        self.current_source = source
+        self.video_thread.set_source(source)
+        
+        # 重新开始监控
+        self.start_monitoring()
+        
+    @pyqtSlot()
+    def take_snapshot(self):
+        """截取当前帧"""
+        # 获取当前显示的图像
+        pixmap = self.video_frame.pixmap()
+        
+        if pixmap and not pixmap.isNull():
+            # 创建保存目录
+            snapshots_dir = os.path.join(self.results_dir, 'snapshots')
+            os.makedirs(snapshots_dir, exist_ok=True)
+            
+            # 保存图像
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            file_path = os.path.join(snapshots_dir, f"snapshot_{timestamp}.jpg")
+            
+            if pixmap.save(file_path, "JPG"):
+                self.status_label.setText(f"状态: 已保存截图 {file_path}")
+            else:
+                self.status_label.setText("状态: 截图保存失败")
+                
+    def resizeEvent(self, event):
+        """窗口大小改变事件"""
+        # 如果有图像,重新缩放
+        if self.video_frame.pixmap() and not self.video_frame.pixmap().isNull():
+            pixmap = self.video_frame.pixmap().scaled(
+                self.video_frame.size(), 
+                Qt.KeepAspectRatio, 
+                Qt.SmoothTransformation
+            )
+            self.video_frame.setPixmap(pixmap)
+            
+    def on_fire_detected(self, region):
+        """处理火焰检测信号"""
+        # 转发信号到主窗口,包含区域信息
+        self.fire_detected.emit(region)
+
+    def on_error(self, error_msg):
+        """处理错误信息"""
+        self.status_label.setText(f"状态: {error_msg}")
+        QMessageBox.warning(self, "错误", error_msg)
+        self.reset_camera()
+        
+    def reset_camera(self):
+        """重置摄像头状态"""
+        self.stop_monitoring()
+        self.camera_combo.setCurrentIndex(0)  # 切换回默认摄像头
+        self.video_frame.clear()  # 清空视频显示
+        self.video_frame.setStyleSheet("background-color: black;")
+
+    def detect_vid(self):
+        """视频检测函数"""
+        try:
+            # 初始化模型参数
+            model = self.model
+            output_size = self.output_size
+            imgsz = [640, 640]  # 推理时的输入图像尺寸(像素)
+            conf_thres = 0.25  # 置信度阈值
+            iou_thres = 0.45  # NMS(非极大值抑制)IOU阈值
+            max_det = 1000  # 每张图片最大检测数
+            
+            # 设备选择逻辑
+            try:
+                if torch.cuda.is_available():
+                    device = select_device('0')
+                    half = True  # 在GPU上使用半精度
+                else:
+                    device = select_device('cpu')
+                    half = False  # CPU不使用半精度
+            except Exception as e:
+                print(f"GPU初始化失败,使用CPU: {str(e)}")
+                device = select_device('cpu')
+                half = False
+            
+            view_img = False  # 是否显示检测结果
+            save_txt = False  # 是否保存结果到*.txt文件
+            save_conf = False  # 是否保存置信度到标签文件
+            save_crop = False  # 是否保存裁剪后的预测框图像
+            nosave = False  # 是否禁用保存图像/视频
+            classes = None  # 按类别过滤:--class 0,或者--class 0 2 3
+            agnostic_nms = False  # 是否使用类别无关的NMS
+            augment = False  # 是否进行增强推理
+            visualize = False  # 是否可视化特征
+            line_thickness = 3  # 边框的厚度(像素)
+            hide_labels = False  # 是否隐藏标签
+            hide_conf = False  # 是否隐藏置信度
+            dnn = False  # 是否使用OpenCV DNN进行ONNX推理
+
+            source = str(self.vid_source)  # 设置视频源(文件/目录)
+            webcam = self.webcam  # 是否使用摄像头
+            device = select_device(self.device)  # 选择推理设备
+            
+            # 获取模型信息
+            stride = model.stride
+            names = model.names
+            pt = getattr(model, 'pt', True)
+            
+            # 检查图像尺寸
+            imgsz = check_img_size(imgsz, s=stride)  # 检查图像尺寸是否合适
+            
+            # 检查视频源
+            if webcam:
+                try:
+                    source = int(source)
+                    cap = cv2.VideoCapture(source)
+                    if not cap.isOpened():
+                        self.error_signal.emit(f"无法打开摄像头 {source}")
+                        return
+                    cap.release()
+                except Exception as e:
+                    self.error_signal.emit(f"摄像头初始化失败: {str(e)}")
+                    return
+            else:
+                if not os.path.exists(source):
+                    self.error_signal.emit(f"视频文件不存在: {source}")
+                    return
+                try:
+                    cap = cv2.VideoCapture(source)
+                    if not cap.isOpened():
+                        self.error_signal.emit(f"无法打开视频文件: {source}")
+                        return
+                    ret, frame = cap.read()
+                    if not ret:
+                        self.error_signal.emit(f"无法读取视频帧: {source}")
+                        return
+                    cap.release()
+                except Exception as e:
+                    self.error_signal.emit(f"视频文件打开失败: {str(e)}")
+                    return
+
+            try:
+                # 设置数据加载器
+                if webcam:
+                    cudnn.benchmark = True  # 设置为True以加速恒定图像大小的推理
+                    dataset = LoadStreams(str(source), img_size=imgsz[0], stride=stride, auto=pt)
+                else:
+                    dataset = LoadImages(source, img_size=imgsz[0], stride=stride, auto=pt)
+            except Exception as e:
+                self.error_signal.emit(f"数据加载失败: {str(e)}")
+                return
+
+            # 预热模型
+            if pt and device.type != 'cpu':
+                model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))
+
+            # 处理每一帧
+            for path, im, im0s, vid_cap, s in dataset:
+                try:
+                    # 预处理
+                    im = torch.from_numpy(im).to(device)
+                    im = im.half() if half else im.float()  # uint8 转为 fp16/32
+                    im /= 255  # 将像素值从0-255归一化到0.0-1.0
+                    if len(im.shape) == 3:
+                        im = im[None]  # 扩展维度以符合batch size要求
+
+                    # 推理过程
+                    with torch.no_grad():
+                        pred = model(im, augment=augment, visualize=visualize)
+                        if isinstance(pred, (list, tuple)):
+                            pred = pred[0]
+
+                    # NMS非极大值抑制
+                    pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
+
+                    # 处理检测结果
+                    for i, det in enumerate(pred):
+                        if webcam:  # 如果是摄像头流
+                            im0 = im0s[i].copy()
+                        else:
+                            im0 = im0s.copy()
+
+                        # 初始化标注器
+                        annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+
+                        if len(det):
+                            # 将边框坐标从img_size映射到原始图像尺寸
+                            det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
+
+                            # 处理每个检测结果
+                            for *xyxy, conf, cls in reversed(det):
+                                c = int(cls)
+                                if conf > conf_thres:  # 如果置信度大于阈值
+                                    label = None if hide_labels else (
+                                        names[c] if hide_conf else f'{names[c]} {conf:.2f}'
+                                    )
+                                    annotator.box_label(xyxy, label, color=colors(c, True))
+                                    
+                                    # 如果检测到火焰,发送信号
+                                    if names[c].lower() == 'fire' and conf > 0.5:  # 提高火焰检测的置信度阈值
+                                        print(f"检测到火焰!类别:{names[c]},置信度:{conf:.2f}")  # 添加调试输出
+                                        self.fire_detected.emit(self.current_region)  # 发送当前区域信息
+
+                        # 获取标注后的图像
+                        im0 = annotator.result()
+                        
+                        # 调整图像大小并显示
+                        resize_scale = output_size / im0.shape[0]
+                        im0 = cv2.resize(im0, (0, 0), fx=resize_scale, fy=resize_scale)
+                        
+                        # 转换为Qt图像并显示
+                        rgb_image = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
+                        h, w, ch = rgb_image.shape
+                        qt_image = QImage(rgb_image.data, w, h, w * ch, QImage.Format_RGB888)
+                        self.vid_img.setPixmap(QPixmap.fromImage(qt_image))
+
+                except Exception as e:
+                    print(f"处理帧时出错: {str(e)}")
+                    import traceback
+                    traceback.print_exc()
+                    continue
+
+                # 检查是否需要停止
+                if self.stopEvent.is_set():
+                    break
+
+                self.msleep(10)  # 控制帧率
+
+        except Exception as e:
+            print(f"视频处理出错: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            self.error_signal.emit(f"视频处理出错: {str(e)}")
+        finally:
+            self.reset_vid()
+
+class MockDetector:
+    """模拟检测器类,用于生成模拟检测结果"""
+    
+    def __init__(self):
+        """初始化检测器"""
+        self.detection_mode = "all"  # 默认检测所有类型
+        
+    def set_mode(self, mode):
+        """设置检测模式"""
+        self.detection_mode = mode
+        
+    def detect(self, frame):
+        """在图像上执行检测"""
+        # 实际项目中,这里应该调用真实的检测模型
+        # 这里仅做模拟,随机生成检测结果
+        
+        height, width = frame.shape[:2]
+        detections = []
+        
+        # 随机决定是否生成检测结果
+        if random.random() < 0.2:  # 20%概率生成检测
+            # 可检测的类型
+            types = {
+                'all': ['fire', 'animal', 'landslide', 'forest', 'pest'],
+                'fire': ['fire'],
+                'animal': ['animal'],
+                'landslide': ['landslide'],
+                'forest': ['forest'],
+                'pest': ['pest']
+            }
+            
+            # 根据检测模式选择可能的检测类型
+            possible_types = types.get(self.detection_mode, ['fire'])
+            
+            # 随机选择一种类型
+            detection_type = random.choice(possible_types)
+            
+            # 类型对应的标签和颜色
+            labels = {
+                'fire': '火灾',
+                'animal': '野生动物',
+                'landslide': '滑坡',
+                'forest': '森林退化',
+                'pest': '病虫害'
+            }
+            
+            colors = {
+                'fire': (0, 0, 255),      # 红色
+                'animal': (0, 255, 0),    # 绿色
+                'landslide': (255, 0, 0), # 蓝色
+                'forest': (255, 255, 0),  # 青色
+                'pest': (128, 0, 128)     # 紫色
+            }
+            
+            # 对于病虫害类型,生成更详细的标签
+            if detection_type == 'pest':
+                pest_types = ['松毛虫', '美国白蛾', '落叶松毛虫', '杨树食叶害虫', '松材线虫病']
+                pest_subtype = random.choice(pest_types)
+                label = f"{labels[detection_type]}-{pest_subtype}"
+            else:
+                label = labels[detection_type]
+            
+            # 随机生成边界框
+            x = random.randint(10, width - 100)
+            y = random.randint(10, height - 100)
+            w = random.randint(50, 150)
+            h = random.randint(50, 150)
+            
+            # 随机生成置信度
+            confidence = random.uniform(0.65, 0.95)
+            
+            # 创建检测结果
+            detection = {
+                'type': detection_type,
+                'bbox': [x, y, x+w, y+h],
+                'confidence': confidence,
+                'label': label,
+                'color': colors[detection_type]
+            }
+            
+            detections.append(detection)
+            
+        return detections 

+ 402 - 0
ui/components/control_panel.py

@@ -0,0 +1,402 @@
+import os
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton, 
+                            QLabel, QComboBox, QGroupBox, QSlider, QCheckBox,
+                            QSpinBox, QDoubleSpinBox, QFormLayout, QTabWidget)
+from PyQt5.QtCore import Qt, pyqtSlot, QSize, QSettings
+from PyQt5.QtGui import QIcon, QFont
+
+class ControlPanel(QWidget):
+    """控制面板组件,用于系统参数调整和控制"""
+    
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.init_ui()
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(5, 5, 5, 5)
+        
+        # 创建标签页
+        tab_widget = QTabWidget()
+        
+        # 添加检测控制页
+        detection_tab = self.create_detection_tab()
+        tab_widget.addTab(detection_tab, "检测控制")
+        
+        # 添加摄像头控制页
+        camera_tab = self.create_camera_tab()
+        tab_widget.addTab(camera_tab, "摄像头控制")
+        
+        # 添加告警控制页
+        alert_tab = self.create_alert_tab()
+        tab_widget.addTab(alert_tab, "告警控制")
+        
+        # 添加标签页到布局
+        layout.addWidget(tab_widget)
+        
+        # 底部按钮区域
+        button_layout = QHBoxLayout()
+        
+        # 启动按钮
+        self.start_btn = QPushButton("启动监测")
+        self.start_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'start.png')))
+        self.start_btn.clicked.connect(self.start_monitoring)
+        button_layout.addWidget(self.start_btn)
+        
+        # 停止按钮
+        self.stop_btn = QPushButton("停止监测")
+        self.stop_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'stop.png')))
+        self.stop_btn.clicked.connect(self.stop_monitoring)
+        self.stop_btn.setEnabled(False)
+        button_layout.addWidget(self.stop_btn)
+        
+        # 添加按钮区域到布局
+        layout.addLayout(button_layout)
+        
+    def create_detection_tab(self):
+        """创建检测控制标签页"""
+        widget = QWidget()
+        layout = QVBoxLayout(widget)
+        
+        # 检测任务分组
+        task_group = QGroupBox("检测任务")
+        task_layout = QVBoxLayout(task_group)
+        
+        # 添加复选框
+        self.fire_check = QCheckBox("火灾检测")
+        self.fire_check.setChecked(True)
+        task_layout.addWidget(self.fire_check)
+        
+        self.animal_check = QCheckBox("动物检测")
+        self.animal_check.setChecked(True)
+        task_layout.addWidget(self.animal_check)
+        
+        self.landslide_check = QCheckBox("滑坡检测")
+        self.landslide_check.setChecked(True)
+        task_layout.addWidget(self.landslide_check)
+        
+        self.pest_check = QCheckBox("病虫害检测")
+        self.pest_check.setChecked(True)
+        task_layout.addWidget(self.pest_check)
+        
+        layout.addWidget(task_group)
+        
+        # 检测参数分组
+        param_group = QGroupBox("检测参数")
+        param_layout = QFormLayout(param_group)
+        
+        # 置信度阈值
+        self.conf_slider = QSlider(Qt.Horizontal)
+        self.conf_slider.setRange(1, 100)
+        self.conf_slider.setValue(int(self.config.get('conf_threshold', 0.25) * 100))
+        self.conf_slider.setTickPosition(QSlider.TicksBelow)
+        self.conf_slider.setTickInterval(10)
+        self.conf_slider.valueChanged.connect(self.update_conf_threshold)
+        
+        self.conf_label = QLabel(f"{self.conf_slider.value() / 100:.2f}")
+        conf_layout = QHBoxLayout()
+        conf_layout.addWidget(self.conf_slider)
+        conf_layout.addWidget(self.conf_label)
+        
+        param_layout.addRow("置信度阈值:", conf_layout)
+        
+        # IOU阈值
+        self.iou_slider = QSlider(Qt.Horizontal)
+        self.iou_slider.setRange(1, 100)
+        self.iou_slider.setValue(int(self.config.get('iou_threshold', 0.45) * 100))
+        self.iou_slider.setTickPosition(QSlider.TicksBelow)
+        self.iou_slider.setTickInterval(10)
+        self.iou_slider.valueChanged.connect(self.update_iou_threshold)
+        
+        self.iou_label = QLabel(f"{self.iou_slider.value() / 100:.2f}")
+        iou_layout = QHBoxLayout()
+        iou_layout.addWidget(self.iou_slider)
+        iou_layout.addWidget(self.iou_label)
+        
+        param_layout.addRow("IOU阈值:", iou_layout)
+        
+        # 输入尺寸
+        self.size_combo = QComboBox()
+        self.size_combo.addItem("320x320", 320)
+        self.size_combo.addItem("416x416", 416)
+        self.size_combo.addItem("512x512", 512)
+        self.size_combo.addItem("640x640", 640)
+        self.size_combo.addItem("1280x1280", 1280)
+        
+        # 设置默认值
+        default_size = self.config.get('image_size', 640)
+        index = self.size_combo.findData(default_size)
+        if index >= 0:
+            self.size_combo.setCurrentIndex(index)
+            
+        param_layout.addRow("输入尺寸:", self.size_combo)
+        
+        # 批处理大小
+        self.batch_spin = QSpinBox()
+        self.batch_spin.setRange(1, 64)
+        self.batch_spin.setValue(self.config.get('batch_size', 16))
+        param_layout.addRow("批处理大小:", self.batch_spin)
+        
+        layout.addWidget(param_group)
+        
+        # 添加保存按钮
+        self.save_params_btn = QPushButton("保存参数")
+        self.save_params_btn.clicked.connect(self.save_detection_params)
+        layout.addWidget(self.save_params_btn)
+        
+        return widget
+        
+    def create_camera_tab(self):
+        """创建摄像头控制标签页"""
+        widget = QWidget()
+        layout = QVBoxLayout(widget)
+        
+        # 摄像头选择分组
+        camera_group = QGroupBox("摄像头选择")
+        camera_layout = QFormLayout(camera_group)
+        
+        # 摄像头列表
+        self.camera_combo = QComboBox()
+        self.camera_combo.addItem("默认摄像头", 0)
+        self.camera_combo.addItem("USB摄像头1", 1)
+        self.camera_combo.addItem("网络摄像头", "rtsp://admin:admin@192.168.1.100:554/stream")
+        camera_layout.addRow("摄像头:", self.camera_combo)
+        
+        # 分辨率选择
+        self.resolution_combo = QComboBox()
+        self.resolution_combo.addItem("320x240", (320, 240))
+        self.resolution_combo.addItem("640x480", (640, 480))
+        self.resolution_combo.addItem("1280x720", (1280, 720))
+        self.resolution_combo.addItem("1920x1080", (1920, 1080))
+        camera_layout.addRow("分辨率:", self.resolution_combo)
+        
+        # 帧率选择
+        self.fps_spin = QSpinBox()
+        self.fps_spin.setRange(1, 60)
+        self.fps_spin.setValue(30)
+        camera_layout.addRow("帧率:", self.fps_spin)
+        
+        layout.addWidget(camera_group)
+        
+        # 图像调整分组
+        adjust_group = QGroupBox("图像调整")
+        adjust_layout = QFormLayout(adjust_group)
+        
+        # 亮度调整
+        self.brightness_slider = QSlider(Qt.Horizontal)
+        self.brightness_slider.setRange(0, 100)
+        self.brightness_slider.setValue(50)
+        adjust_layout.addRow("亮度:", self.brightness_slider)
+        
+        # 对比度调整
+        self.contrast_slider = QSlider(Qt.Horizontal)
+        self.contrast_slider.setRange(0, 100)
+        self.contrast_slider.setValue(50)
+        adjust_layout.addRow("对比度:", self.contrast_slider)
+        
+        # 饱和度调整
+        self.saturation_slider = QSlider(Qt.Horizontal)
+        self.saturation_slider.setRange(0, 100)
+        self.saturation_slider.setValue(50)
+        adjust_layout.addRow("饱和度:", self.saturation_slider)
+        
+        layout.addWidget(adjust_group)
+        
+        # 添加应用按钮
+        self.apply_camera_btn = QPushButton("应用设置")
+        self.apply_camera_btn.clicked.connect(self.apply_camera_settings)
+        layout.addWidget(self.apply_camera_btn)
+        
+        return widget
+        
+    def create_alert_tab(self):
+        """创建告警控制标签页"""
+        widget = QWidget()
+        layout = QVBoxLayout(widget)
+        
+        # 告警阈值分组
+        threshold_group = QGroupBox("告警阈值设置")
+        threshold_layout = QFormLayout(threshold_group)
+        
+        # 火灾告警阈值
+        self.fire_threshold = QSlider(Qt.Horizontal)
+        self.fire_threshold.setRange(50, 95)
+        self.fire_threshold.setValue(75)
+        self.fire_threshold.setTracking(True)
+        self.fire_threshold.setTickPosition(QSlider.TicksBelow)
+        threshold_layout.addRow("火灾告警阈值:", self.fire_threshold)
+        
+        # 动物告警阈值
+        self.animal_threshold = QSlider(Qt.Horizontal)
+        self.animal_threshold.setRange(50, 95)
+        self.animal_threshold.setValue(70)
+        self.animal_threshold.setTracking(True)
+        self.animal_threshold.setTickPosition(QSlider.TicksBelow)
+        threshold_layout.addRow("动物告警阈值:", self.animal_threshold)
+        
+        # 滑坡告警阈值
+        self.landslide_threshold = QSlider(Qt.Horizontal)
+        self.landslide_threshold.setRange(50, 95)
+        self.landslide_threshold.setValue(80)
+        self.landslide_threshold.setTracking(True)
+        self.landslide_threshold.setTickPosition(QSlider.TicksBelow)
+        threshold_layout.addRow("滑坡告警阈值:", self.landslide_threshold)
+        
+        # 病虫害告警阈值
+        self.pest_threshold = QSlider(Qt.Horizontal)
+        self.pest_threshold.setRange(50, 95)
+        self.pest_threshold.setValue(70)
+        self.pest_threshold.setTracking(True)
+        self.pest_threshold.setTickPosition(QSlider.TicksBelow)
+        threshold_layout.addRow("病虫害告警阈值:", self.pest_threshold)
+        
+        layout.addWidget(threshold_group)
+        
+        # 告警方式分组
+        method_group = QGroupBox("告警方式")
+        method_layout = QVBoxLayout(method_group)
+        
+        # 添加复选框
+        self.ui_alert_check = QCheckBox("界面告警")
+        self.ui_alert_check.setChecked(True)
+        method_layout.addWidget(self.ui_alert_check)
+        
+        self.sound_alert_check = QCheckBox("声音告警")
+        self.sound_alert_check.setChecked(True)
+        method_layout.addWidget(self.sound_alert_check)
+        
+        self.sms_alert_check = QCheckBox("短信告警")
+        self.sms_alert_check.setChecked(False)
+        method_layout.addWidget(self.sms_alert_check)
+        
+        self.email_alert_check = QCheckBox("邮件告警")
+        self.email_alert_check.setChecked(False)
+        method_layout.addWidget(self.email_alert_check)
+        
+        layout.addWidget(method_group)
+        
+        # 添加应用按钮
+        self.apply_alert_btn = QPushButton("应用设置")
+        self.apply_alert_btn.clicked.connect(self.apply_alert_settings)
+        layout.addWidget(self.apply_alert_btn)
+        
+        return widget
+        
+    @pyqtSlot(int)
+    def update_conf_threshold(self, value):
+        """更新置信度阈值显示"""
+        self.conf_label.setText(f"{value / 100:.2f}")
+        
+    @pyqtSlot(int)
+    def update_iou_threshold(self, value):
+        """更新IOU阈值显示"""
+        self.iou_label.setText(f"{value / 100:.2f}")
+        
+    @pyqtSlot()
+    def save_detection_params(self):
+        """保存检测参数"""
+        # 获取参数
+        conf_threshold = self.conf_slider.value() / 100
+        iou_threshold = self.iou_slider.value() / 100
+        image_size = self.size_combo.currentData()
+        batch_size = self.batch_spin.value()
+        
+        # 更新配置
+        self.config['conf_threshold'] = conf_threshold
+        self.config['iou_threshold'] = iou_threshold
+        self.config['image_size'] = image_size
+        self.config['batch_size'] = batch_size
+        
+        # 保存检测任务
+        self.config['enable_fire_detection'] = self.fire_check.isChecked()
+        self.config['enable_animal_detection'] = self.animal_check.isChecked()
+        self.config['enable_landslide_detection'] = self.landslide_check.isChecked()
+        self.config['enable_pest_detection'] = self.pest_check.isChecked()
+        
+        # 保存到设置文件(实际项目中实现)
+        print("检测参数已保存")
+        
+    @pyqtSlot()
+    def apply_camera_settings(self):
+        """应用摄像头设置"""
+        # 获取参数
+        camera_source = self.camera_combo.currentData()
+        resolution = self.resolution_combo.currentData()
+        fps = self.fps_spin.value()
+        brightness = self.brightness_slider.value()
+        contrast = self.contrast_slider.value()
+        saturation = self.saturation_slider.value()
+        
+        # 更新配置
+        self.config['camera_source'] = camera_source
+        self.config['camera_resolution'] = resolution
+        self.config['camera_fps'] = fps
+        self.config['camera_brightness'] = brightness
+        self.config['camera_contrast'] = contrast
+        self.config['camera_saturation'] = saturation
+        
+        # 应用设置(实际项目中实现)
+        print("摄像头设置已应用")
+        
+    @pyqtSlot()
+    def apply_alert_settings(self):
+        """应用告警设置"""
+        # 获取参数
+        fire_threshold = self.fire_threshold.value()
+        animal_threshold = self.animal_threshold.value()
+        landslide_threshold = self.landslide_threshold.value()
+        pest_threshold = self.pest_threshold.value()
+        
+        # 获取告警方式
+        alert_methods = []
+        if self.ui_alert_check.isChecked():
+            alert_methods.append('ui')
+        if self.sound_alert_check.isChecked():
+            alert_methods.append('sound')
+        if self.sms_alert_check.isChecked():
+            alert_methods.append('sms')
+        if self.email_alert_check.isChecked():
+            alert_methods.append('email')
+            
+        # 更新配置
+        self.config['fire_alert_threshold'] = fire_threshold
+        self.config['animal_alert_threshold'] = animal_threshold
+        self.config['landslide_alert_threshold'] = landslide_threshold
+        self.config['pest_alert_threshold'] = pest_threshold
+        self.config['alert_methods'] = alert_methods
+        
+        # 应用设置(实际项目中实现)
+        print("告警设置已应用")
+        
+    @pyqtSlot()
+    def start_monitoring(self):
+        """启动监测"""
+        # 切换按钮状态
+        self.start_btn.setEnabled(False)
+        self.stop_btn.setEnabled(True)
+        
+        # 通知主窗口启动监测
+        parent = self.parent()
+        while parent:
+            if hasattr(parent, 'start_monitoring'):
+                parent.start_monitoring()
+                break
+            parent = parent.parent()
+        
+    @pyqtSlot()
+    def stop_monitoring(self):
+        """停止监测"""
+        # 切换按钮状态
+        self.start_btn.setEnabled(True)
+        self.stop_btn.setEnabled(False)
+        
+        # 通知主窗口停止监测
+        parent = self.parent()
+        while parent:
+            if hasattr(parent, 'stop_monitoring'):
+                parent.stop_monitoring()
+                break
+            parent = parent.parent() 

+ 598 - 0
ui/components/drone_manager.py

@@ -0,0 +1,598 @@
+import os
+import cv2
+import numpy as np
+import random
+from datetime import datetime
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel, 
+                            QPushButton, QComboBox, QToolBar, QAction, 
+                            QGridLayout, QFrame, QSplitter, QFileDialog,
+                            QTableWidget, QTableWidgetItem, QHeaderView,
+                            QAbstractItemView, QGroupBox, QTabWidget)
+from PyQt5.QtCore import Qt, QTimer, pyqtSlot, pyqtSignal, QSize, QRect, QThread
+from PyQt5.QtGui import QImage, QPixmap, QIcon, QPainter, QPen, QColor, QFont, QBrush
+
+class DroneSimulator(QThread):
+    """无人机模拟器,用于模拟无人机状态和视频流"""
+    update_frame = pyqtSignal(int, np.ndarray)  # 发送无人机ID和视频帧
+    update_status = pyqtSignal(int, dict)  # 发送无人机ID和状态信息
+    update_detection = pyqtSignal(int, list)  # 发送无人机ID和检测结果
+    
+    def __init__(self, drone_id, drone_type="DJI Mavic Air 2"):
+        super().__init__()
+        self.drone_id = drone_id
+        self.drone_type = drone_type
+        self.running = False
+        self.battery = 100
+        self.altitude = 120  # 初始高度,米
+        self.speed = 0
+        self.gps = {"lat": 39.916527 + random.uniform(-0.01, 0.01), 
+                    "lng": 116.397128 + random.uniform(-0.01, 0.01)}
+        self.signal = 95
+        self.status = "待命"
+        
+        # 选择一个示例视频作为无人机视频源
+        self.video_files = [
+            "resources/videos/drone_forest_1.mp4",
+            "resources/videos/drone_forest_2.mp4",
+            "resources/videos/drone_forest_3.mp4"
+        ]
+        # 使用无人机ID作为随机种子选择视频源,确保每个无人机有不同的视频
+        random.seed(drone_id)
+        self.video_source = "resources/videos/forest_fire.mp4"  # 默认使用一个通用视频
+        
+        # 初始化帧计数
+        self.frame_count = 0
+        
+    def run(self):
+        """运行无人机模拟器"""
+        self.running = True
+        cap = cv2.VideoCapture(self.video_source)
+        
+        if not cap.isOpened():
+            # 无法打开视频,使用生成的图像
+            print(f"无法打开视频源,使用生成图像: {self.video_source}")
+            while self.running:
+                # 生成模拟画面
+                frame = np.zeros((480, 640, 3), dtype=np.uint8)
+                # 添加一些背景
+                frame[:] = (30, 50, 30)  # 深绿色背景
+                # 添加一些文字
+                text = f"无人机 #{self.drone_id} - 信号丢失"
+                cv2.putText(frame, text, (50, 240), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+                
+                # 添加当前状态信息
+                info_text = f"电池: {self.battery}% | 高度: {self.altitude}m | 信号: {self.signal}%"
+                cv2.putText(frame, info_text, (50, 280), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 200, 200), 1)
+                
+                # 更新无人机状态
+                self.update_drone_status()
+                
+                # 发送帧和状态
+                self.update_frame.emit(self.drone_id, frame)
+                self.update_status.emit(self.drone_id, self.get_status())
+                
+                # 控制帧率
+                self.msleep(100)
+        else:
+            while self.running:
+                ret, frame = cap.read()
+                if not ret:
+                    # 视频结束,从头开始
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                    continue
+                
+                # 添加无人机信息叠加层
+                frame = self.add_drone_info(frame)
+                
+                # 每隔一段时间生成模拟检测结果
+                self.frame_count += 1
+                if self.frame_count % 30 == 0:  # 每30帧生成一次检测结果
+                    detections = self.generate_mock_detection(frame)
+                    self.update_detection.emit(self.drone_id, detections)
+                    # 将检测结果绘制在画面上
+                    frame = self.draw_detections(frame, detections)
+                
+                # 更新无人机状态
+                self.update_drone_status()
+                
+                # 发送帧和状态
+                self.update_frame.emit(self.drone_id, frame)
+                self.update_status.emit(self.drone_id, self.get_status())
+                
+                # 控制帧率
+                self.msleep(50)
+            
+            cap.release()
+    
+    def stop(self):
+        """停止无人机模拟器"""
+        self.running = False
+        self.wait()
+    
+    def update_drone_status(self):
+        """更新无人机状态"""
+        # 模拟电池消耗
+        self.battery = max(0, self.battery - random.uniform(0.01, 0.05))
+        
+        # 模拟高度变化
+        if random.random() < 0.3:
+            self.altitude += random.uniform(-1, 1)
+            self.altitude = max(30, min(200, self.altitude))
+        
+        # 模拟速度变化
+        self.speed = random.uniform(0, 8)
+        
+        # 模拟GPS位置变化
+        self.gps["lat"] += random.uniform(-0.0001, 0.0001)
+        self.gps["lng"] += random.uniform(-0.0001, 0.0001)
+        
+        # 模拟信号强度变化
+        if random.random() < 0.2:
+            self.signal += random.uniform(-2, 1)
+            self.signal = max(60, min(100, self.signal))
+    
+    def get_status(self):
+        """获取无人机状态信息"""
+        return {
+            "id": self.drone_id,
+            "type": self.drone_type,
+            "battery": self.battery,
+            "altitude": self.altitude,
+            "speed": self.speed,
+            "gps": self.gps,
+            "signal": self.signal,
+            "status": self.status,
+            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        }
+    
+    def add_drone_info(self, frame):
+        """在视频帧上添加无人机信息"""
+        height, width = frame.shape[:2]
+        
+        # 添加半透明的顶部信息栏
+        overlay = frame.copy()
+        cv2.rectangle(overlay, (0, 0), (width, 40), (0, 0, 0), -1)
+        cv2.addWeighted(overlay, 0.7, frame, 0.3, 0, frame, 0)
+        
+        # 添加无人机ID和类型
+        cv2.putText(frame, f"无人机 #{self.drone_id} | {self.drone_type}", 
+                   (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
+        
+        # 添加电池和GPS信息
+        battery_text = f"电池: {int(self.battery)}%"
+        battery_color = (0, 255, 0) if self.battery > 30 else (0, 165, 255) if self.battery > 15 else (0, 0, 255)
+        cv2.putText(frame, battery_text, (width - 300, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, battery_color, 1)
+        
+        gps_text = f"GPS: {self.gps['lat']:.4f}, {self.gps['lng']:.4f}"
+        cv2.putText(frame, gps_text, (width - 180, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
+        
+        # 添加时间戳
+        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        cv2.putText(frame, timestamp, (width - 180, height - 10), 
+                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        # 添加高度和速度信息
+        altitude_text = f"高度: {int(self.altitude)}m"
+        cv2.putText(frame, altitude_text, (10, height - 30), 
+                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        speed_text = f"速度: {self.speed:.1f}m/s"
+        cv2.putText(frame, speed_text, (10, height - 10), 
+                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        return frame
+    
+    def generate_mock_detection(self, frame):
+        """生成模拟检测结果"""
+        height, width = frame.shape[:2]
+        
+        # 随机决定是否生成检测结果
+        if random.random() < 0.3:  # 30%的概率生成检测
+            detection_type = random.choice(['fire', 'animal', 'landslide', 'pest'])
+            
+            # 根据检测类型设置标签
+            label_map = {
+                'fire': '火灾',
+                'animal': '野生动物',
+                'landslide': '滑坡',
+                'pest': '病虫害'
+            }
+            
+            # 对于病虫害类型,随机选择具体的病虫害种类
+            pest_types = ['松毛虫', '美国白蛾', '落叶松毛虫', '杨树食叶害虫', '松材线虫病']
+            pest_subtypes = ['轻度', '中度', '重度']
+            
+            # 随机位置
+            x1 = random.randint(50, width - 150)
+            y1 = random.randint(50, height - 150)
+            w = random.randint(50, 150)
+            h = random.randint(50, 150)
+            x2 = x1 + w
+            y2 = y1 + h
+            
+            # 随机置信度
+            confidence = random.uniform(0.65, 0.95)
+            
+            # 如果是病虫害类型,创建更详细的标签
+            if detection_type == 'pest':
+                pest_type = random.choice(pest_types)
+                pest_subtype = random.choice(pest_subtypes)
+                return [{
+                    'task': detection_type,
+                    'class': 0,
+                    'label': f"{label_map[detection_type]}-{pest_type}({pest_subtype})",
+                    'confidence': confidence,
+                    'bbox': [x1, y1, x2, y2],
+                    'subtype': pest_type,
+                    'severity': pest_subtype
+                }]
+            else:
+                return [{
+                    'task': detection_type,
+                    'class': 0,
+                    'label': label_map[detection_type],
+                    'confidence': confidence,
+                    'bbox': [x1, y1, x2, y2]
+                }]
+        else:
+            return []  # 没有检测结果
+    
+    def draw_detections(self, frame, detections):
+        """在帧上绘制检测结果"""
+        for det in detections:
+            # 获取边界框和标签
+            x1, y1, x2, y2 = [int(c) for c in det['bbox']]
+            label = f"{det['label']} {det['confidence']:.2f}"
+            
+            # 根据任务选择颜色
+            if det['task'] == 'fire':
+                color = (0, 0, 255)  # 红色
+            elif det['task'] == 'animal':
+                color = (0, 255, 0)  # 绿色
+            elif det['task'] == 'landslide':
+                color = (255, 0, 0)  # 蓝色
+            elif det['task'] == 'pest':
+                color = (128, 0, 128)  # 紫色
+            else:
+                color = (255, 255, 0)  # 青色
+                
+            # 绘制边界框
+            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+            
+            # 绘制标签背景
+            text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
+            cv2.rectangle(frame, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
+            
+            # 绘制标签
+            cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
+            
+            # 如果是病虫害,绘制额外信息
+            if det['task'] == 'pest' and 'severity' in det:
+                # 在框的上方显示严重程度
+                severity_text = f"严重程度: {det['severity']}"
+                cv2.putText(frame, severity_text, (x1, y1 - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+        
+        return frame
+
+class DroneManager(QWidget):
+    """无人机管理组件,用于管理多个无人机,显示视频流和状态"""
+    
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.drones = {}  # 存储无人机模拟器,key为无人机ID
+        self.drone_frames = {}  # 存储无人机视频帧
+        self.drone_status = {}  # 存储无人机状态
+        self.drone_detections = {}  # 存储无人机检测结果
+        
+        self.init_ui()
+        
+        # 创建自动更新定时器
+        self.timer = QTimer(self)
+        self.timer.timeout.connect(self.update_drone_display)
+        self.timer.start(200)  # 每200毫秒更新一次显示
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        main_layout = QVBoxLayout(self)
+        main_layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建工具栏
+        toolbar_layout = QHBoxLayout()
+        
+        # 添加标题
+        title_label = QLabel("无人机管理")
+        title_label.setFont(QFont("Microsoft YaHei", 12, QFont.Bold))
+        title_label.setStyleSheet("color: white; margin: 5px;")
+        toolbar_layout.addWidget(title_label)
+        
+        # 添加空白占位
+        toolbar_layout.addStretch(1)
+        
+        # 添加无人机类型选择
+        self.drone_type_combo = QComboBox()
+        self.drone_type_combo.addItem("DJI Mavic Air 2")
+        self.drone_type_combo.addItem("DJI Phantom 4")
+        self.drone_type_combo.addItem("DJI Inspire 2")
+        self.drone_type_combo.addItem("Autel EVO II")
+        toolbar_layout.addWidget(QLabel("无人机类型:"))
+        toolbar_layout.addWidget(self.drone_type_combo)
+        
+        # 添加添加无人机按钮
+        self.add_drone_btn = QPushButton("添加无人机")
+        self.add_drone_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'add.png')))
+        self.add_drone_btn.clicked.connect(self.add_drone)
+        toolbar_layout.addWidget(self.add_drone_btn)
+        
+        # 添加删除无人机按钮
+        self.remove_drone_btn = QPushButton("删除无人机")
+        self.remove_drone_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'remove.png')))
+        self.remove_drone_btn.clicked.connect(self.remove_drone)
+        toolbar_layout.addWidget(self.remove_drone_btn)
+        
+        main_layout.addLayout(toolbar_layout)
+        
+        # 创建分割器
+        splitter = QSplitter(Qt.Vertical)
+        main_layout.addWidget(splitter, 1)
+        
+        # 创建无人机视图区域 - 使用标签页
+        self.tab_widget = QTabWidget()
+        self.tab_widget.setTabPosition(QTabWidget.North)
+        self.tab_widget.setStyleSheet("QTabWidget::pane { border: 0; } QTabBar::tab { background-color: #102040; color: white; padding: 6px 12px; margin-right: 2px; } QTabBar::tab:selected { background-color: #1a3a5a; }")
+        splitter.addWidget(self.tab_widget)
+        
+        # 创建无人机控制面板
+        control_panel = QWidget()
+        control_layout = QVBoxLayout(control_panel)
+        
+        # 添加无人机状态表格
+        self.status_table = QTableWidget()
+        self.status_table.setColumnCount(9)
+        self.status_table.setHorizontalHeaderLabels(["ID", "类型", "电池", "高度", "速度", "经度", "纬度", "信号", "状态"])
+        self.status_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
+        self.status_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
+        self.status_table.setSelectionBehavior(QAbstractItemView.SelectRows)
+        self.status_table.setAlternatingRowColors(True)
+        self.status_table.setStyleSheet("alternate-background-color: #0c1e32; background-color: #081a2e; color: white; "
+                                      "QHeaderView::section { background-color: #15253a; color: white; padding: 4px; "
+                                      "border: 1px solid #1e3a5a; font-weight: bold; }")
+        control_layout.addWidget(self.status_table)
+        
+        # 添加按钮栏
+        btn_layout = QHBoxLayout()
+        
+        # 起飞按钮
+        self.takeoff_btn = QPushButton("起飞")
+        self.takeoff_btn.clicked.connect(self.takeoff_drone)
+        btn_layout.addWidget(self.takeoff_btn)
+        
+        # 降落按钮
+        self.land_btn = QPushButton("降落")
+        self.land_btn.clicked.connect(self.land_drone)
+        btn_layout.addWidget(self.land_btn)
+        
+        # 返航按钮
+        self.return_btn = QPushButton("返航")
+        self.return_btn.clicked.connect(self.return_drone)
+        btn_layout.addWidget(self.return_btn)
+        
+        # 停止按钮
+        self.stop_btn = QPushButton("紧急停止")
+        self.stop_btn.setStyleSheet("background-color: #8b0000; color: white;")
+        self.stop_btn.clicked.connect(self.emergency_stop_drone)
+        btn_layout.addWidget(self.stop_btn)
+        
+        control_layout.addLayout(btn_layout)
+        
+        splitter.addWidget(control_panel)
+        
+        # 设置分割器比例
+        splitter.setSizes([int(self.height() * 0.7), int(self.height() * 0.3)])
+        
+        # 添加一些初始无人机
+        QTimer.singleShot(500, self.add_initial_drones)
+    
+    def add_initial_drones(self):
+        """添加初始无人机"""
+        for i in range(3):  # 添加3个初始无人机
+            self.add_drone()
+    
+    def add_drone(self):
+        """添加一个无人机"""
+        # 生成新的无人机ID
+        drone_id = len(self.drones) + 1
+        drone_type = self.drone_type_combo.currentText()
+        
+        # 创建无人机模拟器
+        drone = DroneSimulator(drone_id, drone_type)
+        drone.update_frame.connect(self.update_drone_frame)
+        drone.update_status.connect(self.update_drone_status)
+        drone.update_detection.connect(self.update_drone_detection)
+        drone.start()
+        
+        # 存储无人机
+        self.drones[drone_id] = drone
+        
+        # 创建视频显示标签页
+        drone_tab = QWidget()
+        tab_layout = QVBoxLayout(drone_tab)
+        tab_layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建视频帧标签
+        frame_label = QLabel()
+        frame_label.setAlignment(Qt.AlignCenter)
+        frame_label.setMinimumSize(640, 480)
+        frame_label.setStyleSheet("background-color: black;")
+        tab_layout.addWidget(frame_label)
+        
+        # 添加标签页
+        self.tab_widget.addTab(drone_tab, f"无人机 #{drone_id}")
+        
+        # 切换到新标签页
+        self.tab_widget.setCurrentIndex(self.tab_widget.count() - 1)
+        
+        # 初始化视频帧
+        self.drone_frames[drone_id] = frame_label
+        
+        # 更新状态表格
+        self.update_status_table()
+    
+    def remove_drone(self):
+        """移除选中的无人机"""
+        # 获取当前选中的标签页
+        current_index = self.tab_widget.currentIndex()
+        if current_index >= 0:
+            # 获取无人机ID
+            drone_id = int(self.tab_widget.tabText(current_index).split("#")[1])
+            
+            # 停止无人机模拟器
+            if drone_id in self.drones:
+                self.drones[drone_id].stop()
+                del self.drones[drone_id]
+            
+            # 移除视频帧
+            if drone_id in self.drone_frames:
+                del self.drone_frames[drone_id]
+            
+            # 移除状态
+            if drone_id in self.drone_status:
+                del self.drone_status[drone_id]
+            
+            # 移除检测结果
+            if drone_id in self.drone_detections:
+                del self.drone_detections[drone_id]
+            
+            # 移除标签页
+            self.tab_widget.removeTab(current_index)
+            
+            # 更新状态表格
+            self.update_status_table()
+    
+    def update_drone_frame(self, drone_id, frame):
+        """更新无人机视频帧"""
+        if drone_id in self.drone_frames:
+            # 转换为QImage并显示
+            height, width, channels = frame.shape
+            bytes_per_line = channels * width
+            q_image = QImage(frame.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped()
+            self.drone_frames[drone_id].setPixmap(QPixmap.fromImage(q_image).scaled(
+                self.drone_frames[drone_id].width(), 
+                self.drone_frames[drone_id].height(),
+                Qt.KeepAspectRatio,
+                Qt.SmoothTransformation
+            ))
+    
+    def update_drone_status(self, drone_id, status):
+        """更新无人机状态"""
+        self.drone_status[drone_id] = status
+    
+    def update_drone_detection(self, drone_id, detections):
+        """更新无人机检测结果"""
+        self.drone_detections[drone_id] = detections
+        
+        # 如果有检测结果,可以向主窗口发送告警
+        if detections and hasattr(self.parent(), 'alert_panel'):
+            for det in detections:
+                # 构造告警信息
+                alert = {
+                    'time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                    'type': det['task'],
+                    'location': f"无人机 #{drone_id} 位置",
+                    'detail': f"检测到{det['label']},置信度: {det['confidence']:.2f}",
+                    'level': 'high' if det['confidence'] > 0.85 else 'medium'
+                }
+                # 向告警面板添加告警
+                self.parent().alert_panel.add_alert(alert)
+    
+    def update_drone_display(self):
+        """更新无人机显示和状态表格"""
+        self.update_status_table()
+    
+    def update_status_table(self):
+        """更新状态表格"""
+        self.status_table.setRowCount(0)
+        
+        for drone_id, status in self.drone_status.items():
+            row = self.status_table.rowCount()
+            self.status_table.insertRow(row)
+            
+            # 设置ID
+            self.status_table.setItem(row, 0, QTableWidgetItem(str(drone_id)))
+            
+            # 设置类型
+            self.status_table.setItem(row, 1, QTableWidgetItem(status['type']))
+            
+            # 设置电池
+            battery_item = QTableWidgetItem(f"{int(status['battery'])}%")
+            if status['battery'] > 30:
+                battery_item.setForeground(QBrush(QColor(0, 255, 0)))
+            elif status['battery'] > 15:
+                battery_item.setForeground(QBrush(QColor(255, 165, 0)))
+            else:
+                battery_item.setForeground(QBrush(QColor(255, 0, 0)))
+            self.status_table.setItem(row, 2, battery_item)
+            
+            # 设置高度
+            self.status_table.setItem(row, 3, QTableWidgetItem(f"{int(status['altitude'])}m"))
+            
+            # 设置速度
+            self.status_table.setItem(row, 4, QTableWidgetItem(f"{status['speed']:.1f}m/s"))
+            
+            # 设置经度
+            self.status_table.setItem(row, 5, QTableWidgetItem(f"{status['gps']['lng']:.6f}"))
+            
+            # 设置纬度
+            self.status_table.setItem(row, 6, QTableWidgetItem(f"{status['gps']['lat']:.6f}"))
+            
+            # 设置信号
+            signal_item = QTableWidgetItem(f"{int(status['signal'])}%")
+            if status['signal'] > 80:
+                signal_item.setForeground(QBrush(QColor(0, 255, 0)))
+            elif status['signal'] > 60:
+                signal_item.setForeground(QBrush(QColor(255, 165, 0)))
+            else:
+                signal_item.setForeground(QBrush(QColor(255, 0, 0)))
+            self.status_table.setItem(row, 7, signal_item)
+            
+            # 设置状态
+            self.status_table.setItem(row, 8, QTableWidgetItem(status['status']))
+    
+    def takeoff_drone(self):
+        """起飞选中的无人机"""
+        selected_rows = self.status_table.selectionModel().selectedRows()
+        for index in selected_rows:
+            drone_id = int(self.status_table.item(index.row(), 0).text())
+            if drone_id in self.drones:
+                self.drones[drone_id].status = "已起飞"
+    
+    def land_drone(self):
+        """降落选中的无人机"""
+        selected_rows = self.status_table.selectionModel().selectedRows()
+        for index in selected_rows:
+            drone_id = int(self.status_table.item(index.row(), 0).text())
+            if drone_id in self.drones:
+                self.drones[drone_id].status = "正在降落"
+    
+    def return_drone(self):
+        """返航选中的无人机"""
+        selected_rows = self.status_table.selectionModel().selectedRows()
+        for index in selected_rows:
+            drone_id = int(self.status_table.item(index.row(), 0).text())
+            if drone_id in self.drones:
+                self.drones[drone_id].status = "返航中"
+    
+    def emergency_stop_drone(self):
+        """紧急停止选中的无人机"""
+        selected_rows = self.status_table.selectionModel().selectedRows()
+        for index in selected_rows:
+            drone_id = int(self.status_table.item(index.row(), 0).text())
+            if drone_id in self.drones:
+                self.drones[drone_id].status = "紧急停止"
+    
+    def closeEvent(self, event):
+        """窗口关闭时停止所有无人机"""
+        for drone in self.drones.values():
+            drone.stop()
+        event.accept() 

+ 303 - 0
ui/components/grid_camera_view.py

@@ -0,0 +1,303 @@
+from PyQt5.QtWidgets import (QWidget, QGridLayout, QLabel, QPushButton, 
+                            QVBoxLayout, QHBoxLayout, QComboBox, QMenu)
+from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QSize
+from PyQt5.QtGui import QImage, QPixmap, QIcon
+import cv2
+import numpy as np
+import base64
+from ..assets.icons import GRID_ICON, SINGLE_ICON, REFRESH_ICON
+
+def create_icon_from_base64(base64_str):
+    """从Base64字符串创建图标"""
+    pixmap = QPixmap()
+    pixmap.loadFromData(base64.b64decode(base64_str))
+    return QIcon(pixmap)
+
+class CameraGridCell(QLabel):
+    """单个摄像头格子组件"""
+    clicked = pyqtSignal(int)  # 点击信号,传递格子索引
+    fire_detected = pyqtSignal(str)  # 火灾检测信号
+    animal_detected = pyqtSignal(object, str, float)  # 动物检测信号
+    
+    def __init__(self, index, parent=None):
+        super().__init__(parent)
+        self.index = index
+        self.active = False
+        self.last_fire_check = 0  # 上次火情检查时间
+        self.fire_check_interval = 1.0  # 火情检查间隔(秒)
+        self.setAlignment(Qt.AlignCenter)
+        self.setMinimumSize(320, 240)
+        self.setStyleSheet("""
+            QLabel {
+                border: 2px solid #1e3a5a;
+                background-color: #0a1a2a;
+                color: white;
+            }
+            QLabel:hover {
+                border: 2px solid #3a6a9a;
+            }
+        """)
+        self.setText("摄像头未连接")
+        
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:
+            self.clicked.emit(self.index)
+        super().mousePressEvent(event)
+        
+    def setImage(self, image, camera_info=None):
+        """设置图像并进行灾害检测"""
+        if isinstance(image, np.ndarray):
+            # 保存原始图像用于显示
+            display_image = image.copy()
+            
+            # 进行火灾检测
+            current_time = cv2.getTickCount() / cv2.getTickFrequency()
+            if current_time - self.last_fire_check >= self.fire_check_interval:
+                self.last_fire_check = current_time
+                
+                # 火灾检测
+                if self.detect_fire(image):
+                    region = camera_info['name'] if camera_info else f"摄像头 {self.index + 1}"
+                    self.fire_detected.emit(region)
+                    # 在图像上标注火灾警告
+                    cv2.putText(display_image, "火灾警告!", (10, 30),
+                              cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+                
+                # 动物检测
+                animal_result = self.detect_animal(image)
+                if animal_result:
+                    species, confidence = animal_result
+                    self.animal_detected.emit(image, species, confidence)
+                    # 在图像上标注动物信息
+                    cv2.putText(display_image, f"{species} ({confidence:.2f}%)", (10, 60),
+                              cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+            
+            # 转换BGR到RGB并显示
+            display_image = cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB)
+            height, width, channel = display_image.shape
+            bytesPerLine = 3 * width
+            qImg = QImage(display_image.data, width, height, bytesPerLine, QImage.Format_RGB888)
+            self.setPixmap(QPixmap.fromImage(qImg).scaled(
+                self.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
+            self.active = True
+        else:
+            self.setText("摄像头未连接")
+            self.active = False
+            
+    def detect_fire(self, image):
+        """检测火灾
+        使用简单的颜色阈值方法检测火焰
+        """
+        try:
+            # 转换到HSV颜色空间
+            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+            
+            # 定义火焰的颜色范围(红色和橙色)
+            lower_red1 = np.array([0, 120, 70])
+            upper_red1 = np.array([10, 255, 255])
+            lower_red2 = np.array([170, 120, 70])
+            upper_red2 = np.array([180, 255, 255])
+            
+            # 创建掩码
+            mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
+            mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
+            mask = cv2.bitwise_or(mask1, mask2)
+            
+            # 计算火焰像素占比
+            fire_ratio = np.sum(mask > 0) / (mask.shape[0] * mask.shape[1])
+            
+            # 如果火焰像素占比超过阈值,认为检测到火灾
+            return fire_ratio > 0.01
+            
+        except Exception as e:
+            print(f"火灾检测出错: {e}")
+            return False
+            
+    def detect_animal(self, image):
+        """检测动物
+        这里使用简单的运动检测作为示例
+        实际项目中应该使用更复杂的目标检测模型
+        """
+        try:
+            # 在实际项目中,这里应该使用预训练的目标检测模型
+            # 例如 YOLO 或 SSD
+            # 这里仅作为示例返回模拟结果
+            return None
+            
+        except Exception as e:
+            print(f"动物检测出错: {e}")
+            return None
+
+class GridCameraView(QWidget):
+    """九宫格摄像头视图"""
+    
+    # 添加灾害检测信号
+    fire_detected = pyqtSignal(str)  # 火灾检测信号
+    animal_detected = pyqtSignal(object, str, float)  # 动物检测信号
+    
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.cameras = {}  # 存储摄像头对象
+        self.current_layout = "grid"  # grid或single
+        self.active_cell = None
+        
+        # 创建图标
+        self.grid_icon = create_icon_from_base64(GRID_ICON)
+        self.single_icon = create_icon_from_base64(SINGLE_ICON)
+        self.refresh_icon = create_icon_from_base64(REFRESH_ICON)
+        
+        self.init_ui()
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        self.main_layout = QVBoxLayout(self)
+        self.main_layout.setContentsMargins(5, 5, 5, 5)
+        self.main_layout.setSpacing(5)
+        
+        # 创建工具栏
+        toolbar = QHBoxLayout()
+        
+        # 添加布局切换按钮
+        self.layout_btn = QPushButton("切换视图")
+        self.layout_btn.setIcon(self.grid_icon)
+        self.layout_btn.clicked.connect(self.toggle_layout)
+        toolbar.addWidget(self.layout_btn)
+        
+        # 添加摄像头选择下拉框
+        self.camera_combo = QComboBox()
+        self.camera_combo.addItem("所有摄像头")
+        self.camera_combo.addItem("地面摄像头")
+        self.camera_combo.addItem("无人机摄像头")
+        self.camera_combo.currentIndexChanged.connect(self.on_camera_type_changed)
+        toolbar.addWidget(QLabel("摄像头类型:"))
+        toolbar.addWidget(self.camera_combo)
+        
+        toolbar.addStretch()
+        
+        # 添加刷新按钮
+        refresh_btn = QPushButton("刷新")
+        refresh_btn.setIcon(self.refresh_icon)
+        refresh_btn.clicked.connect(self.refresh_cameras)
+        toolbar.addWidget(refresh_btn)
+        
+        self.main_layout.addLayout(toolbar)
+        
+        # 创建九宫格容器
+        self.grid_container = QWidget()
+        self.grid_layout = QGridLayout(self.grid_container)
+        self.grid_layout.setSpacing(5)
+        
+        # 创建9个摄像头格子
+        self.cells = []
+        for i in range(9):
+            cell = CameraGridCell(i)
+            cell.clicked.connect(self.on_cell_clicked)
+            # 连接灾害检测信号
+            cell.fire_detected.connect(self.fire_detected)
+            cell.animal_detected.connect(self.animal_detected)
+            self.cells.append(cell)
+            self.grid_layout.addWidget(cell, i // 3, i % 3)
+            
+        self.main_layout.addWidget(self.grid_container)
+        
+        # 创建定时器用于更新画面
+        self.update_timer = QTimer(self)
+        self.update_timer.timeout.connect(self.update_frames)
+        self.update_timer.start(33)  # 约30fps
+        
+        # 初始化摄像头
+        self.refresh_cameras()
+        
+    def toggle_layout(self):
+        """切换布局模式"""
+        if self.current_layout == "grid":
+            self.current_layout = "single"
+            self.layout_btn.setIcon(self.single_icon)
+            # 隐藏除了活动格子之外的所有格子
+            for i, cell in enumerate(self.cells):
+                cell.setVisible(i == self.active_cell if self.active_cell is not None else i == 0)
+        else:
+            self.current_layout = "grid"
+            self.layout_btn.setIcon(self.grid_icon)
+            # 显示所有格子
+            for cell in self.cells:
+                cell.setVisible(True)
+                
+    def on_cell_clicked(self, index):
+        """处理格子点击事件"""
+        self.active_cell = index
+        if self.current_layout == "grid":
+            self.toggle_layout()  # 切换到单视图模式
+            
+    def on_camera_type_changed(self, index):
+        """处理摄像头类型切换"""
+        self.refresh_cameras()
+        
+    def refresh_cameras(self):
+        """刷新摄像头列表"""
+        # 这里应该实现摄像头检测和连接逻辑
+        # 示例:模拟9个摄像头
+        self.cameras.clear()
+        camera_type = self.camera_combo.currentText()
+        
+        if camera_type in ["所有摄像头", "地面摄像头"]:
+            # 添加4个地面摄像头
+            for i in range(4):
+                try:
+                    cap = cv2.VideoCapture(i)
+                    if cap.isOpened():
+                        self.cameras[i] = {
+                            'capture': cap,
+                            'type': 'ground',
+                            'name': f'地面摄像头 {i+1}'
+                        }
+                except Exception as e:
+                    print(f"连接摄像头 {i} 失败: {e}")
+                    
+        if camera_type in ["所有摄像头", "无人机摄像头"]:
+            # 模拟5个无人机摄像头
+            for i in range(5):
+                self.cameras[i+4] = {
+                    'capture': None,  # 实际项目中应该连接到无人机视频流
+                    'type': 'drone',
+                    'name': f'无人机 {i+1}'
+                }
+                
+    def update_frames(self):
+        """更新所有摄像头画面"""
+        if self.current_layout == "single" and self.active_cell is not None:
+            # 单视图模式只更新活动格子
+            self.update_cell(self.active_cell)
+        else:
+            # 网格模式更新所有格子
+            for i in range(9):
+                self.update_cell(i)
+                
+    def update_cell(self, index):
+        """更新单个格子的画面"""
+        if index in self.cameras:
+            camera = self.cameras[index]
+            if camera['type'] == 'ground' and camera['capture'] is not None:
+                ret, frame = camera['capture'].read()
+                if ret:
+                    self.cells[index].setImage(frame, camera)
+                    return
+            elif camera['type'] == 'drone':
+                # 模拟无人机画面
+                frame = np.zeros((480, 640, 3), dtype=np.uint8)
+                cv2.putText(frame, f"无人机 {index-3} 画面", (50, 240),
+                           cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+                self.cells[index].setImage(frame, camera)
+                return
+                
+        # 如果没有摄像头或获取画面失败
+        self.cells[index].setImage(None)
+        
+    def closeEvent(self, event):
+        """关闭时释放摄像头资源"""
+        self.update_timer.stop()
+        for camera in self.cameras.values():
+            if camera['type'] == 'ground' and camera['capture'] is not None:
+                camera['capture'].release()
+        super().closeEvent(event) 

+ 724 - 0
ui/components/map_view.py

@@ -0,0 +1,724 @@
+import os
+import requests
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton, 
+                            QLabel, QComboBox, QGroupBox, QToolBar, QAction, QGridLayout)
+from PyQt5.QtCore import Qt, QUrl, pyqtSlot, QSize
+from PyQt5.QtGui import QIcon
+from PyQt5.QtWebEngineWidgets import QWebEngineView, QWebEnginePage, QWebEngineProfile, QWebEngineSettings
+from PyQt5.QtWebEngineCore import QWebEngineUrlRequestInterceptor
+
+class RequestInterceptor(QWebEngineUrlRequestInterceptor):
+    """网络请求拦截器,用于调试地图加载问题"""
+    def interceptRequest(self, info):
+        print(f"请求URL: {info.requestUrl().toString()}")
+        print(f"请求方法: {info.requestMethod()}")
+        print(f"请求头: {info.requestHeaders()}")
+
+class CustomWebEnginePage(QWebEnginePage):
+    def javaScriptConsoleMessage(self, level, message, line, source):
+        level_str = {
+            0: "INFO",
+            1: "WARNING",
+            2: "ERROR"
+        }.get(level, "UNKNOWN")
+        print(f"JS控制台 [{level_str}] {message} (第{line}行, 来源: {source})")
+
+class MapView(QWidget):
+    """地图视图组件,用于在地理信息系统上显示监测区域和告警位置"""
+    
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.init_ui()
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建地图容器
+        map_container = QWidget()
+        map_layout = QVBoxLayout(map_container)
+        map_layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建工具栏
+        toolbar = QToolBar()
+        toolbar.setIconSize(QSize(16, 16))
+        
+        # 添加地图类型选择下拉框
+        self.map_type_combo = QComboBox()
+        self.map_type_combo.addItems(["卫星图", "地形图", "道路图", "混合图"])
+        self.map_type_combo.setCurrentIndex(0)  # 设置默认选中卫星图
+        self.map_type_combo.currentIndexChanged.connect(self.change_map_type)
+        toolbar.addWidget(QLabel("地图类型: "))
+        toolbar.addWidget(self.map_type_combo)
+        
+        toolbar.addSeparator()
+        
+        # 添加区域选择下拉框
+        self.region_combo = QComboBox()
+        for region in self.config.get('monitor_regions', []):
+            self.region_combo.addItem(region['name'], region)
+        self.region_combo.currentIndexChanged.connect(self.change_region)
+        toolbar.addWidget(QLabel("监测区域: "))
+        toolbar.addWidget(self.region_combo)
+        
+        toolbar.addSeparator()
+        
+        # 添加缩放按钮
+        self.zoom_in_btn = QPushButton("放大")
+        self.zoom_in_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'zoom_in.png')))
+        self.zoom_in_btn.clicked.connect(self.zoom_in)
+        toolbar.addWidget(self.zoom_in_btn)
+        
+        self.zoom_out_btn = QPushButton("缩小")
+        self.zoom_out_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'zoom_out.png')))
+        self.zoom_out_btn.clicked.connect(self.zoom_out)
+        toolbar.addWidget(self.zoom_out_btn)
+        
+        toolbar.addSeparator()
+        
+        # 显示/隐藏告警点
+        self.show_alerts_btn = QPushButton("显示告警")
+        self.show_alerts_btn.setCheckable(True)
+        self.show_alerts_btn.setChecked(True)
+        self.show_alerts_btn.clicked.connect(self.toggle_alerts)
+        toolbar.addWidget(self.show_alerts_btn)
+        
+        toolbar.addSeparator()
+        
+        # 添加定位按钮
+        self.location_btn = QPushButton("定位")
+        self.location_btn.setIcon(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'location.png')))
+        self.location_btn.clicked.connect(self.locate_current_position)
+        toolbar.addWidget(self.location_btn)
+        
+        # 添加工具栏到地图布局
+        map_layout.addWidget(toolbar)
+        
+        # 创建Web视图用于显示地图
+        self.web_view = QWebEngineView()
+        
+        # 配置Edge浏览器引擎
+        profile = QWebEngineProfile.defaultProfile()
+        profile.setHttpUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Edg/91.0.864.59")
+        
+        # 添加请求拦截器
+        interceptor = RequestInterceptor()
+        profile.setUrlRequestInterceptor(interceptor)
+        
+        # 启用必要的Web功能
+        settings = profile.settings()
+        settings.setAttribute(QWebEngineSettings.JavascriptEnabled, True)
+        settings.setAttribute(QWebEngineSettings.LocalStorageEnabled, True)
+        settings.setAttribute(QWebEngineSettings.WebGLEnabled, True)
+        settings.setAttribute(QWebEngineSettings.PluginsEnabled, True)
+        settings.setAttribute(QWebEngineSettings.FullScreenSupportEnabled, True)
+        
+        # 启用跨域访问
+        profile.setHttpCacheType(QWebEngineProfile.DiskHttpCache)
+        profile.setPersistentCookiesPolicy(QWebEngineProfile.AllowPersistentCookies)
+        
+        # 设置页面
+        page = CustomWebEnginePage(self.web_view)
+        self.web_view.setPage(page)
+        map_layout.addWidget(self.web_view)
+        
+        # 将地图容器添加到主布局
+        layout.addWidget(map_container)
+        
+        # 加载初始地图
+        self.load_map()
+        
+    def load_map(self):
+        """加载百度地图"""
+        try:
+            # 获取地图配置
+            center = self.config.get('map_center', [39.915, 116.404])
+            zoom = self.config.get('map_zoom', 15)
+            
+            # 使用临时文件加载地图
+            import tempfile
+            temp_path = os.path.join(tempfile.gettempdir(), "map_temp.html")
+            
+            # 构建HTML
+            html = f"""<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8" />
+    <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
+    <meta name="viewport" content="initial-scale=1.0, user-scalable=no" />
+    <title>森林监测地图</title>
+    <script type="text/javascript" src="https://api.map.baidu.com/api?v=3.0&ak={self.config.get('baidu_map_key', '')}&type=webgl"></script>
+    <style type="text/css">
+        html, body, #map {{
+            height: 100%;
+            width: 100%;
+            margin: 0;
+            padding: 0;
+        }}
+        .loading {{
+            position: absolute;
+            top: 50%;
+            left: 50%;
+            transform: translate(-50%, -50%);
+            text-align: center;
+            font-family: Arial, sans-serif;
+        }}
+        .offline-map {{
+            width: 100%;
+            height: 100%;
+            background-color: #eee;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            justify-content: center;
+            font-family: Arial, sans-serif;
+        }}
+        .info-window {{
+            padding: 5px;
+            min-width: 150px;
+        }}
+        .info-window .title {{
+            font-weight: bold;
+            margin-bottom: 5px;
+            color: #3a8ee6;
+        }}
+    </style>
+</head>
+<body>
+    <div id="map">
+        <div class="loading">加载百度地图中...</div>
+    </div>
+    
+    <script type="text/javascript">
+        // 全局变量
+        var map;
+        var alertMarkers = [];
+        var isMapLoaded = false;
+        
+        // 初始化函数
+        function initMap() {{
+            try {{
+                console.log("开始初始化地图...");
+                
+                // 检查BMap对象是否存在
+                if (typeof BMapGL === "undefined") {{
+                    throw new Error("BMapGL未定义,百度地图API未正确加载");
+                }}
+                
+                console.log("创建地图实例...");
+                // 创建地图实例
+                map = new BMapGL.Map("map");
+                
+                console.log("设置地图中心点...");
+                // 创建点坐标(注意经纬度顺序:经度在前,纬度在后)
+                var point = new BMapGL.Point({center[1]}, {center[0]});
+                
+                console.log("初始化地图视图...");
+                // 初始化地图
+                map.centerAndZoom(point, {zoom});
+                
+                // 设置默认地图类型为卫星图
+                map.setMapType(BMAP_SATELLITE_MAP);
+                
+                // 设置默认显示选项,隐藏POI图标
+                map.setDisplayOptions({{
+                    poi: false,  // 隐藏POI图标
+                    poiText: false,  // 隐藏POI文字
+                    building: false  // 隐藏3D建筑物
+                }});
+                
+                // 开启鼠标滚轮缩放
+                map.enableScrollWheelZoom(true);
+                
+                console.log("添加地图控件...");
+                // 添加地图控件
+                map.addControl(new BMapGL.NavigationControl());    // 导航控件
+                map.addControl(new BMapGL.ScaleControl());         // 比例尺控件
+                map.addControl(new BMapGL.ZoomControl());          // 缩放控件
+                map.addControl(new BMapGL.LocationControl());      // 定位控件
+                
+                // 地图初始化成功标记
+                isMapLoaded = true;
+                console.log("地图初始化完成");
+
+                // 自动定位到当前位置
+                setTimeout(function() {{
+                    // 定义错误状态处理函数
+                    function handleLocationError(status) {{
+                        var errorMsg = "";
+                        switch(status) {{
+                            case 6:
+                                errorMsg = "定位权限被拒绝,请在浏览器设置中允许获取位置信息";
+                                break;
+                            case 2:
+                            case 8:
+                                errorMsg = "定位不可用或超时,尝试使用IP定位";
+                                break;
+                            default:
+                                errorMsg = "定位失败(错误码:" + status + "),尝试使用IP定位";
+                        }}
+                        console.error(errorMsg);
+                        return errorMsg;
+                    }}
+
+                    // 显示定位结果的函数
+                    function showLocationResult(point, accuracy, address, locationType, errorMsg) {{
+                        var marker = new BMapGL.Marker(point);
+                        map.addOverlay(marker);
+                        map.centerAndZoom(point, locationType === 'ip' ? 12 : 18);
+
+                        // 如果是精确定位,显示精度圈
+                        if (accuracy && locationType !== 'ip') {{
+                            var circle = new BMapGL.Circle(point, accuracy, {{
+                                strokeColor: "#1E90FF",
+                                strokeWeight: 1,
+                                strokeOpacity: 0.5,
+                                fillColor: "#1E90FF",
+                                fillOpacity: 0.1
+                            }});
+                            map.addOverlay(circle);
+                        }}
+
+                        var locationTypeText = {{
+                            'sdk': '手机GPS',
+                            'h5': '浏览器定位',
+                            'ip': 'IP定位'
+                        }}[locationType] || '未知方式';
+
+                        var infoWindow = new BMapGL.InfoWindow(
+                            '<div class="info-window">' +
+                            '<div class="title">当前位置</div>' +
+                            '<div>定位方式: ' + locationTypeText + '</div>' +
+                            '<div>经度: ' + point.lng.toFixed(6) + '</div>' +
+                            '<div>纬度: ' + point.lat.toFixed(6) + '</div>' +
+                            '<div>地址: ' + address + '</div>' +
+                            (accuracy && locationType !== 'ip' ? '<div>定位精度: ' + accuracy.toFixed(1) + '米</div>' : '') +
+                            (errorMsg ? '<div style="color: #ff9900;">' + errorMsg + '</div>' : '') +
+                            '</div>'
+                        );
+                        marker.openInfoWindow(infoWindow);
+                    }}
+
+                    // 先尝试SDK定位
+                    var geolocation = new BMapGL.Geolocation();
+                    geolocation.enableSDKLocation();
+                    geolocation.getCurrentPosition(function(r) {{
+                        if(this.getStatus() == BMAP_STATUS_SUCCESS) {{
+                            // SDK定位成功
+                            var geoc = new BMapGL.Geocoder();
+                            geoc.getLocation(r.point, function(rs) {{
+                                var addComp = rs.addressComponents;
+                                var address = addComp.province + addComp.city + 
+                                            addComp.district + addComp.street + 
+                                            addComp.streetNumber;
+                                showLocationResult(r.point, r.accuracy, address, 'sdk');
+                            }});
+                        }} else {{
+                            // SDK定位失败,尝试浏览器H5定位
+                            var h5geolocation = new BMapGL.Geolocation();
+                            h5geolocation.getCurrentPosition(function(r) {{
+                                if(this.getStatus() == BMAP_STATUS_SUCCESS) {{
+                                    // H5定位成功
+                                    var geoc = new BMapGL.Geocoder();
+                                    geoc.getLocation(r.point, function(rs) {{
+                                        var addComp = rs.addressComponents;
+                                        var address = addComp.province + addComp.city + 
+                                                    addComp.district + addComp.street + 
+                                                    addComp.streetNumber;
+                                        showLocationResult(r.point, r.accuracy, address, 'h5');
+                                    }});
+                                }} else {{
+                                    // H5定位也失败,使用IP定位
+                                    var errorMsg = handleLocationError(this.getStatus());
+                                    var myCity = new BMapGL.LocalCity();
+                                    myCity.get(function(result) {{
+                                        var geoc = new BMapGL.Geocoder();
+                                        geoc.getLocation(result.center, function(rs) {{
+                                            var addComp = rs.addressComponents;
+                                            var address = addComp.province + addComp.city + 
+                                                        addComp.district + addComp.street + 
+                                                        addComp.streetNumber;
+                                            showLocationResult(result.center, null, address, 'ip', 
+                                                '注意:由于无法获取精确位置,已切换到IP定位(精度较低)');
+                                        }});
+                                    }});
+                                }}
+                            }}, {{
+                                enableHighAccuracy: true,
+                                timeout: 5000,
+                                maximumAge: 0
+                            }});
+                        }}
+                    }}, {{
+                        enableHighAccuracy: true,
+                        timeout: 5000,
+                        maximumAge: 0,
+                        SDKLocation: true,
+                        coordType: 'bd09ll',
+                        poiDistance: true,
+                        poiNumber: 1
+                    }});
+                }}, 1000);
+                
+                // 定义对外接口
+                window.mapFunctions = {{
+                    setMapType: function(type) {{
+                        if (!isMapLoaded) return false;
+                        try {{
+                            switch(type) {{
+                                case 0: // 卫星图
+                                    map.setMapType(BMAP_SATELLITE_MAP);
+                                    map.setDisplayOptions({{
+                                        poi: false,
+                                        poiText: false,
+                                        building: false
+                                    }});
+                                    map.setTilt(0);
+                                    break;
+                                case 1: // 地形图
+                                    map.setMapType(BMAP_NORMAL_MAP);
+                                    map.setDisplayOptions({{
+                                        poi: false,
+                                        poiText: false,
+                                        building: false
+                                    }});
+                                    break;
+                                case 2: // 道路图
+                                    map.setMapType(BMAP_NORMAL_MAP);
+                                    map.setDisplayOptions({{
+                                        poi: false,
+                                        poiText: false,
+                                        building: false
+                                    }});
+                                    break;
+                                case 3: // 混合图
+                                    map.setMapType(BMAP_SATELLITE_MAP);
+                                    map.setDisplayOptions({{
+                                        poi: false,
+                                        poiText: false,
+                                        building: true
+                                    }});
+                                    break;
+                            }}
+                            return true;
+                        }} catch(e) {{
+                            console.error("切换地图类型出错:", e);
+                            return false;
+                        }}
+                    }},
+                    
+                    zoomIn: function() {{
+                        if (!isMapLoaded) return false;
+                        try {{
+                            map.zoomIn();
+                            return true;
+                        }} catch(e) {{
+                            console.error("地图放大出错:", e);
+                            return false;
+                        }}
+                    }},
+                    
+                    zoomOut: function() {{
+                        if (!isMapLoaded) return false;
+                        try {{
+                            map.zoomOut();
+                            return true;
+                        }} catch(e) {{
+                            console.error("地图缩小出错:", e);
+                            return false;
+                        }}
+                    }},
+                    
+                    panTo: function(lat, lng) {{
+                        if (!isMapLoaded) return false;
+                        try {{
+                            var point = new BMapGL.Point(lng, lat);
+                            map.panTo(point);
+                            return true;
+                        }} catch(e) {{
+                            console.error("地图平移出错:", e);
+                            return false;
+                        }}
+                    }},
+                    
+                    toggleAlerts: function(show) {{
+                        if (!isMapLoaded) return false;
+                        try {{
+                            if (alertMarkers.length > 0) {{
+                                alertMarkers.forEach(function(marker) {{
+                                    if (show) marker.show();
+                                    else marker.hide();
+                                }});
+                            }}
+                            return true;
+                        }} catch(e) {{
+                            console.error("切换告警显示出错:", e);
+                            return false;
+                        }}
+                    }},
+                    
+                    locateCurrentPosition: function() {{
+                        if (!isMapLoaded) {{
+                            console.error('地图未加载完成');
+                            return false;
+                        }}
+                        
+                        // 创建定位控件
+                        var locationControl = new BMapGL.LocationControl();
+                        locationControl.addEventListener("locationSuccess", function(e){{
+                            var address = '';
+                            address += e.addressComponent.province;
+                            address += e.addressComponent.city;
+                            address += e.addressComponent.district;
+                            address += e.addressComponent.street;
+                            address += e.addressComponent.streetNumber;
+                            
+                            // 在marker上显示信息窗口
+                            var infoWindow = new BMapGL.InfoWindow(
+                                '<div class="info-window">' +
+                                '<div class="title">当前位置</div>' +
+                                '<div>经度: ' + e.point.lng + '</div>' +
+                                '<div>纬度: ' + e.point.lat + '</div>' +
+                                '<div>地址: ' + address + '</div>' +
+                                '</div>'
+                            );
+                            var marker = new BMapGL.Marker(e.point);
+                            map.addOverlay(marker);
+                            marker.openInfoWindow(infoWindow);
+                            
+                            console.log('定位成功');
+                        }});
+                        locationControl.addEventListener("locationError", function(e){{
+                            console.error('定位失败:' + e.message);
+                        }});
+                        locationControl.location();
+                        return true;
+                    }}
+                }};
+                
+            }} catch (e) {{
+                console.error("地图初始化失败:", e);
+                document.getElementById("map").innerHTML = 
+                    '<div class="offline-map"><h2>百度地图初始化失败</h2><p>错误信息: ' + e.message + '</p></div>';
+            }}
+        }}
+        
+        // 添加示例告警点
+        function addExampleAlerts() {{
+            // 清空告警点数组
+            alertMarkers = [];
+            // 暂时不添加示例告警点
+        }}
+
+        // 页面加载完成后初始化地图
+        window.onload = initMap;
+    </script>
+</body>
+</html>"""
+            
+            # 将HTML保存到临时文件
+            with open(temp_path, "w", encoding="utf-8") as f:
+                f.write(html)
+            
+            # 从本地文件加载
+            self.web_view.load(QUrl.fromLocalFile(temp_path))
+            print("正在从本地文件加载百度地图...")
+            
+        except Exception as e:
+            print(f"地图加载失败: {e}")
+            # 显示错误信息
+            error_html = f"""
+            <html>
+            <body style="background-color: #f0f0f0; color: #333; font-family: Arial, sans-serif; text-align: center; padding: 50px;">
+                <h2>地图加载失败</h2>
+                <p>错误信息: {str(e)}</p>
+                <p>请检查以下内容:</p>
+                <ol style="text-align: left; max-width: 500px; margin: 0 auto;">
+                    <li>确认网络连接正常</li>
+                    <li>确认百度地图API密钥有效</li>
+                    <li>检查PyQt WebEngine设置</li>
+                    <li>查看控制台输出的详细错误信息</li>
+                </ol>
+            </body>
+            </html>
+            """
+            self.web_view.setHtml(error_html)
+      
+    @pyqtSlot(int)
+    def change_map_type(self, index):
+        """改变地图类型"""
+        # 通过JS调用地图函数
+        script = f"""
+            try {{
+                console.log('调用地图类型切换,类型索引:', {index});
+                if (window.mapFunctions && typeof window.mapFunctions.setMapType === 'function') {{
+                    window.mapFunctions.setMapType({index});
+                    return true;
+                }} else {{
+                    console.error('地图类型切换函数不可用');
+                    return false;
+                }}
+            }} catch(e) {{
+                console.error('执行地图类型切换时出错:', e);
+                return false;
+            }}
+        """
+        self.web_view.page().runJavaScript(script)
+        
+    @pyqtSlot(int)
+    def change_region(self, index):
+        """改变监测区域"""
+        if index >= 0 and index < len(self.config.get('monitor_regions', [])):
+            region = self.config['monitor_regions'][index]
+            # 通过JS调用地图函数
+            script = """
+                try {
+                    if (window.mapFunctions && typeof window.mapFunctions.panTo === 'function') {
+                        window.mapFunctions.panTo(0, 0);
+                        return true;
+                    } else {
+                        console.error('地图平移函数不可用');
+                        return false;
+                    }
+                } catch(e) {
+                    console.error('平移到区域时出错:', e);
+                    return false;
+                }
+            """
+            self.web_view.page().runJavaScript(script)
+        
+    @pyqtSlot()
+    def zoom_in(self):
+        """地图放大"""
+        script = """
+            try {
+                if (window.mapFunctions && typeof window.mapFunctions.zoomIn === 'function') {
+                    window.mapFunctions.zoomIn();
+                    return true;
+                } else {
+                    console.error('地图放大函数不可用');
+                    return false;
+                }
+            } catch(e) {
+                console.error('地图放大时出错:', e);
+                return false;
+            }
+        """
+        self.web_view.page().runJavaScript(script)
+        
+    @pyqtSlot()
+    def zoom_out(self):
+        """地图缩小"""
+        script = """
+            try {
+                if (window.mapFunctions && typeof window.mapFunctions.zoomOut === 'function') {
+                    window.mapFunctions.zoomOut();
+                    return true;
+                } else {
+                    console.error('地图缩小函数不可用');
+                    return false;
+                }
+            } catch(e) {
+                console.error('地图缩小时出错:', e);
+                return false;
+            }
+        """
+        self.web_view.page().runJavaScript(script)
+        
+    @pyqtSlot(bool)
+    def toggle_alerts(self, checked):
+        """切换告警点显示状态"""
+        script = f"""
+            try {{
+                if (window.mapFunctions && typeof window.mapFunctions.toggleAlerts === 'function') {{
+                    window.mapFunctions.toggleAlerts({str(checked).lower()});
+                    return true;
+                }} else {{
+                    console.error('告警点切换函数不可用');
+                    return false;
+                }}
+            }} catch(e) {{
+                console.error('切换告警点时出错:', e);
+                return false;
+            }}
+        """
+        self.web_view.page().runJavaScript(script)
+        
+    def locate_current_position(self):
+        """定位当前位置"""
+        script = """
+            try {
+                if (!isMapLoaded) {
+                    console.error('地图未加载完成');
+                    return false;
+                }
+                
+                // 创建定位控件
+                var locationControl = new BMapGL.LocationControl();
+                locationControl.addEventListener("locationSuccess", function(e){{
+                    var address = '';
+                    address += e.addressComponent.province;
+                    address += e.addressComponent.city;
+                    address += e.addressComponent.district;
+                    address += e.addressComponent.street;
+                    address += e.addressComponent.streetNumber;
+                    
+                    // 在marker上显示信息窗口
+                    var infoWindow = new BMapGL.InfoWindow(
+                        '<div class="info-window">' +
+                        '<div class="title">当前位置</div>' +
+                        '<div>经度: ' + e.point.lng + '</div>' +
+                        '<div>纬度: ' + e.point.lat + '</div>' +
+                        '<div>地址: ' + address + '</div>' +
+                        '</div>'
+                    );
+                    var marker = new BMapGL.Marker(e.point);
+                    map.addOverlay(marker);
+                    marker.openInfoWindow(infoWindow);
+                    
+                    console.log('定位成功');
+                }});
+                locationControl.addEventListener("locationError", function(e){{
+                    console.error('定位失败:' + e.message);
+                }});
+                locationControl.location();
+                return true;
+            } catch(e) {
+                console.error('执行定位时出错:', e);
+                return false;
+            }
+        """
+        self.web_view.page().runJavaScript(script)
+        
+    def update_view(self):
+        """更新地图视图"""
+        # 实际项目中可以在这里添加刷新告警点等逻辑
+        pass
+
+    def get_accurate_ip_location(self):
+        """获取更精确的IP定位信息"""
+        try:
+            # 使用腾讯位置服务
+            response = requests.get(
+                'https://apis.map.qq.com/ws/location/v1/ip',
+                params={
+                    'key': self.config.get('qq_map_key', ''),  # 需要在配置中添加腾讯地图密钥
+                    'output': 'json'
+                },
+                timeout=5
+            )
+            if response.status_code == 200:
+                data = response.json()
+                if data['status'] == 0:
+                    location = data['result']['location']
+                    return {
+                        'lng': location['lng'],
+                        'lat': location['lat'],
+                        'accuracy': data['result'].get('accuracy', 0)
+                    }
+        except Exception as e:
+            print(f"腾讯地图IP定位失败: {e}")
+        return None

+ 961 - 0
ui/components/statistics_panel.py

@@ -0,0 +1,961 @@
+import os
+import random
+from datetime import datetime, timedelta
+from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSlider,
+                              QTabWidget, QGroupBox, QComboBox, QPushButton, QTableWidget, QTableWidgetItem, QHeaderView)
+from PyQt5.QtCore import Qt, pyqtSlot, QTimer, QMargins
+from PyQt5.QtGui import QFont, QPainter, QPen, QColor
+from PyQt5.QtChart import QChart, QChartView, QLineSeries, QBarSet, QBarSeries, QPieSeries, QValueAxis, QDateTimeAxis, QBarCategoryAxis, QLegend
+
+class StatisticsPanel(QWidget):
+    """统计面板组件,显示各类灾害检测统计信息和趋势图表"""
+    
+    def __init__(self, config, parent=None):
+        super().__init__(parent)
+        self.config = config
+        
+        # 初始化统计数据
+        self.alert_stats = {
+            'fire': 0,
+            'animal': 0,
+            'landslide': 0,
+            'pest': 0
+        }
+        
+        # 初始化区域统计数据为空字典
+        self.region_stats = {}
+        
+        # 初始化区域名称映射字典
+        self.region_mapping = {
+            '北部山区': '东北区',  # 假设北部山区属于东北区
+            '南部林区': '东南区',  # 假设南部林区属于东南区
+            '东部山脊': '东北区',  # 假设东部山脊属于东北区
+            '西部谷地': '西北区',  # 假设西部谷地属于西北区
+            '西部林区': '西北区',  # 假设西部林区属于西北区
+            '中央林场': '中部区'   # 假设中央林场属于中部区
+        }
+        
+        # 模拟数据
+        self.init_mock_data()
+        
+        # 初始化UI
+        self.init_ui()
+        
+        # 创建并保存趋势图视图的引用
+        self.trend_chart_view = self.create_trend_chart_view()
+        
+        # 设置自动更新定时器
+        self.update_timer = QTimer(self)
+        self.update_timer.timeout.connect(self.auto_update_statistics)
+        self.update_timer.start(5000)  # 每5秒更新一次统计数据
+        
+    def init_mock_data(self):
+        """初始化模拟数据"""
+        # 病虫害类型统计数据
+        self.pest_type_stats = {
+            '松材线虫': 0,
+            '杨树食叶害虫': 0,
+            '松毛虫': 0,
+            '蚜虫': 0,
+            '松墨天牛': 0,
+            '落叶松针叶锈病': 0
+        }
+        
+        # 初始化区域统计字典
+        if not self.region_stats:
+            self.region_stats = {
+                '东北区': 0,
+                '西北区': 0,
+                '中部区': 0,
+                '东南区': 0,
+                '西南区': 0
+            }
+        
+        # 模拟24小时数据
+        self.hour_data = []
+        now = datetime.now()
+        for i in range(24):
+            time_point = now - timedelta(hours=23-i)
+            self.hour_data.append({
+                'time': time_point,
+                'fire': 0,
+                'animal': 0,
+                'landslide': 0,
+                'forest_degradation': 0,
+                'pest': 0
+            })
+        
+        # 模拟7天数据
+        self.day_data = []
+        for i in range(7):
+            time_point = now - timedelta(days=6-i)
+            self.day_data.append({
+                'time': time_point,
+                'fire': 0,
+                'animal': 0,
+                'landslide': 0,
+                'forest_degradation': 0,
+                'pest': 0
+            })
+        
+        # 模拟30天数据
+        self.month_data = []
+        for i in range(30):
+            time_point = now - timedelta(days=29-i)
+            self.month_data.append({
+                'time': time_point,
+                'fire': 0,
+                'animal': 0,
+                'landslide': 0,
+                'forest_degradation': 0,
+                'pest': 0
+            })
+        
+    def init_ui(self):
+        """初始化UI"""
+        # 创建主布局
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(5, 5, 5, 5)
+        
+        # 创建标签页
+        self.tab_widget = QTabWidget()
+        
+        # 概览标签页
+        overview_tab = self.create_overview_tab()
+        self.tab_widget.addTab(overview_tab, "概览")
+        
+        # 趋势图标签页
+        trend_tab = self.create_trend_tab()
+        self.tab_widget.addTab(trend_tab, "趋势图")
+        
+        # 区域统计标签页
+        region_tab = self.create_region_tab()
+        self.tab_widget.addTab(region_tab, "区域统计")
+        
+        # 添加标签页到布局
+        layout.addWidget(self.tab_widget)
+        
+        # 底部工具栏
+        toolbar = QHBoxLayout()
+        
+        # 时间范围下拉框
+        self.time_range_combo = QComboBox()
+        self.time_range_combo.addItems(["最近24小时", "最近7天", "最近30天", "本月", "本年"])
+        self.time_range_combo.currentIndexChanged.connect(self.change_time_range)
+        toolbar.addWidget(QLabel("时间范围:"))
+        toolbar.addWidget(self.time_range_combo)
+        
+        # 添加刷新按钮
+        refresh_btn = QPushButton("刷新")
+        refresh_btn.clicked.connect(self.refresh_statistics)
+        toolbar.addStretch(1)
+        toolbar.addWidget(refresh_btn)
+        
+        # 添加工具栏到布局
+        layout.addLayout(toolbar)
+        
+        # 设置最小高度
+        self.setMinimumHeight(300)
+        
+    def create_overview_tab(self):
+        """创建概览标签页"""
+        tab = QWidget()
+        tab.setObjectName("overview_tab")  # 添加对象名称
+        layout = QVBoxLayout(tab)
+        
+        # 告警统计
+        stats_group = QGroupBox("告警统计")
+        stats_group.setObjectName("stats_group")  # 添加对象名称
+        stats_layout = QHBoxLayout(stats_group)
+        
+        # 火灾告警
+        fire_label = QLabel(f"火灾告警: {self.alert_stats['fire']}")
+        fire_label.setObjectName("fire_label")  # 添加对象名称
+        fire_label.setStyleSheet("color: red;")
+        stats_layout.addWidget(fire_label)
+        
+        # 动物告警
+        animal_label = QLabel(f"动物告警: {self.alert_stats['animal']}")
+        animal_label.setObjectName("animal_label")  # 添加对象名称
+        animal_label.setStyleSheet("color: green;")
+        stats_layout.addWidget(animal_label)
+        
+        # 滑坡告警
+        landslide_label = QLabel(f"滑坡告警: {self.alert_stats['landslide']}")
+        landslide_label.setObjectName("landslide_label")  # 添加对象名称
+        landslide_label.setStyleSheet("color: blue;")
+        stats_layout.addWidget(landslide_label)
+        
+        # 病虫害告警
+        pest_label = QLabel(f"病虫害: {self.alert_stats['pest']}")
+        pest_label.setObjectName("pest_label")  # 添加对象名称
+        pest_label.setStyleSheet("color: purple;")
+        stats_layout.addWidget(pest_label)
+        
+        layout.addWidget(stats_group)
+        
+        # 饼图 - 使用create_pie_chart方法
+        chart_view = self.create_pie_chart()
+        chart_view.setObjectName("overview_pie_chart")  # 添加对象名称
+        
+        layout.addWidget(chart_view)
+        
+        return tab
+        
+    def create_trend_tab(self):
+        """创建趋势图标签页"""
+        tab = QWidget()
+        tab.setObjectName("trend_tab")  # 添加对象名称
+        layout = QVBoxLayout(tab)
+        
+        # 24小时趋势图
+        trend_chart = QChart()
+        trend_chart.setTitle("24小时告警趋势")
+        trend_chart.setAnimationOptions(QChart.SeriesAnimations)
+        trend_chart.setBackgroundBrush(QColor("#0a1a2a"))
+        trend_chart.setTitleBrush(QColor("white"))
+        trend_chart.setTitleFont(QFont("Microsoft YaHei", 10, QFont.Bold))
+        
+        # 创建折线系列 - 火灾
+        fire_series = QLineSeries()
+        fire_series.setName("火灾告警")
+        fire_series.setColor(QColor(255, 100, 100))
+        
+        # 创建折线系列 - 动物
+        animal_series = QLineSeries()
+        animal_series.setName("动物告警")
+        animal_series.setColor(QColor(100, 255, 100))
+        
+        # 创建折线系列 - 滑坡
+        landslide_series = QLineSeries()
+        landslide_series.setName("滑坡告警")
+        landslide_series.setColor(QColor(100, 100, 255))
+        
+        # 创建折线系列 - 病虫害
+        pest_series = QLineSeries()
+        pest_series.setName("病虫害告警")
+        pest_series.setColor(QColor(180, 100, 200))
+        
+        # 添加数据点
+        for i, data in enumerate(self.hour_data):
+            timestamp = data['time'].timestamp() * 1000  # 转换为毫秒
+            fire_series.append(timestamp, data['fire'])
+            animal_series.append(timestamp, data['animal'])
+            landslide_series.append(timestamp, data['landslide'])
+            pest_series.append(timestamp, data['pest'])
+        
+        # 添加系列到图表
+        trend_chart.addSeries(fire_series)
+        trend_chart.addSeries(animal_series)
+        trend_chart.addSeries(landslide_series)
+        trend_chart.addSeries(pest_series)
+        
+        # 创建X轴(时间轴)
+        axis_x = QDateTimeAxis()
+        axis_x.setFormat("HH:mm")
+        axis_x.setTitleText("时间")
+        axis_x.setTickCount(8)  # 显示8个刻度
+        axis_x.setRange(
+            self.hour_data[0]['time'],
+            self.hour_data[-1]['time']
+        )
+        axis_x.setTitleBrush(QColor("white"))
+        axis_x.setLabelsColor(QColor("white"))
+        
+        # 创建Y轴
+        axis_y = QValueAxis()
+        axis_y.setLabelFormat("%d")
+        axis_y.setTitleText("告警数量")
+        axis_y.setRange(0, 15)  # 增大Y轴范围,避免文字被裁剪
+        axis_y.setTickCount(6)  # 增加刻度数量以更好显示
+        axis_y.setTitleBrush(QColor("white"))
+        axis_y.setLabelsColor(QColor("white"))
+        
+        # 添加坐标轴到图表
+        trend_chart.addAxis(axis_x, Qt.AlignBottom)
+        trend_chart.addAxis(axis_y, Qt.AlignLeft)
+        
+        # 将所有系列依附到坐标轴
+        fire_series.attachAxis(axis_x)
+        fire_series.attachAxis(axis_y)
+        animal_series.attachAxis(axis_x)
+        animal_series.attachAxis(axis_y)
+        landslide_series.attachAxis(axis_x)
+        landslide_series.attachAxis(axis_y)
+        pest_series.attachAxis(axis_x)
+        pest_series.attachAxis(axis_y)
+        
+        # 设置图例位置和样式
+        trend_chart.legend().setVisible(True)
+        trend_chart.legend().setAlignment(Qt.AlignBottom)
+        trend_chart.legend().setLabelColor(QColor("white"))
+        trend_chart.legend().setMarkerShape(QLegend.MarkerShapeCircle)  
+        
+        # 创建图表视图
+        chart_view = QChartView(trend_chart)
+        chart_view.setRenderHint(QPainter.Antialiasing)
+        chart_view.setMinimumHeight(250)  # 增加最小高度
+        chart_view.setBackgroundBrush(QColor("#0a1a2a"))
+        
+        layout.addWidget(chart_view)
+        
+        return tab
+        
+    def create_region_tab(self):
+        """创建区域统计标签页"""
+        tab = QWidget()
+        layout = QVBoxLayout(tab)
+        
+        # 添加区域告警统计图表
+        chart_view = self.create_region_chart_view()
+        layout.addWidget(chart_view)
+        
+        # 添加病虫害类型统计表格
+        pest_table = self.create_pest_table()
+        layout.addWidget(pest_table)
+        
+        return tab
+        
+    def create_pest_table(self):
+        """创建病虫害类型统计表格"""
+        # 创建表格
+        table = QTableWidget()
+        table.setObjectName("pest_table")  # 添加对象名称便于检索
+        table.setColumnCount(2)
+        table.setRowCount(len(self.pest_type_stats))
+        
+        # 设置表头
+        table.setHorizontalHeaderLabels(["病虫害类型", "检测数量"])
+        
+        # 添加数据
+        row = 0
+        for pest_type, count in self.pest_type_stats.items():
+            # 添加病虫害类型
+            type_item = QTableWidgetItem(pest_type)
+            type_item.setTextAlignment(Qt.AlignCenter)
+            table.setItem(row, 0, type_item)
+            
+            # 添加数量
+            count_item = QTableWidgetItem(str(count))
+            count_item.setTextAlignment(Qt.AlignCenter)
+            table.setItem(row, 1, count_item)
+            
+            row += 1
+        
+        # 设置表格样式
+        table.setStyleSheet("""
+            QTableWidget {
+                background-color: rgba(0, 20, 40, 0.8);
+                color: white;
+                gridline-color: rgba(80, 160, 220, 0.5);
+                border: 1px solid rgba(80, 160, 220, 0.5);
+                border-radius: 4px;
+            }
+            QHeaderView::section {
+                background-color: rgba(0, 60, 120, 0.9);
+                color: white;
+                padding: 4px;
+                border: 1px solid rgba(80, 160, 220, 0.5);
+            }
+            QTableWidget::item {
+                border-bottom: 1px solid rgba(80, 160, 220, 0.3);
+            }
+        """)
+        
+        # 调整表格大小
+        table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
+        table.verticalHeader().setVisible(False)
+        table.setFixedHeight(200)
+        
+        # 打印当前病虫害统计数据用于调试
+        print("病虫害统计数据:")
+        for pest_type, count in self.pest_type_stats.items():
+            print(f"  {pest_type}: {count}")
+        
+        return table
+        
+    @pyqtSlot(int)
+    def change_time_range(self, index):
+        """切换时间范围"""
+        # 获取当前选择的时间范围
+        time_range = self.time_range_combo.currentText()
+        
+        # 更新趋势图
+        if time_range == "最近24小时":
+            self.update_trend_chart(self.hour_data, "HH:mm", 8)
+        elif time_range == "最近7天":
+            self.update_trend_chart(self.day_data[:7], "MM-dd", 7)
+        elif time_range == "最近30天":
+            self.update_trend_chart(self.day_data, "MM-dd", 10)
+        elif time_range == "本月":
+            # 筛选本月数据
+            now = datetime.now()
+            month_data = [d for d in self.day_data if d['time'].month == now.month]
+            self.update_trend_chart(month_data, "MM-dd", 10)
+        elif time_range == "本年":
+            # 筛选本年数据 (这里使用全部数据模拟)
+            self.update_trend_chart(self.day_data, "MM-dd", 10)
+            
+    def update_trend_chart(self, data, date_format, tick_count):
+        """更新趋势图"""
+        if not data:
+            return
+            
+        # 获取当前标签页中的图表
+        chart = self.trend_chart_view.chart()
+        
+        # 清除所有系列
+        chart.removeAllSeries()
+        
+        # 创建新系列
+        fire_series = QLineSeries()
+        fire_series.setName("火灾告警")
+        fire_series.setColor(QColor(255, 100, 100))
+        
+        animal_series = QLineSeries()
+        animal_series.setName("动物告警")
+        animal_series.setColor(QColor(100, 255, 100))
+        
+        landslide_series = QLineSeries()
+        landslide_series.setName("滑坡告警")
+        landslide_series.setColor(QColor(100, 100, 255))
+        
+        pest_series = QLineSeries()
+        pest_series.setName("病虫害告警")
+        pest_series.setColor(QColor(180, 100, 200))
+        
+        # 添加数据点
+        for i, d in enumerate(data):
+            timestamp = d['time'].timestamp() * 1000  # 转换为毫秒
+            fire_series.append(timestamp, d['fire'])
+            animal_series.append(timestamp, d['animal'])
+            landslide_series.append(timestamp, d['landslide'])
+            pest_series.append(timestamp, d['pest'])
+        
+        # 添加系列到图表
+        chart.addSeries(fire_series)
+        chart.addSeries(animal_series)
+        chart.addSeries(landslide_series)
+        chart.addSeries(pest_series)
+        
+        # 更新图表标题
+        chart.setTitle(f"{self.time_range_combo.currentText()}告警趋势")
+        
+        # 创建/更新坐标轴
+        chart.createDefaultAxes()
+        
+        # 更新X轴
+        x_axis = chart.axes(Qt.Horizontal)[0]
+        if isinstance(x_axis, QDateTimeAxis):
+            x_axis.setFormat(date_format)
+            x_axis.setTickCount(tick_count)
+            x_axis.setRange(data[0]['time'], data[-1]['time'])
+        
+        # 更新Y轴
+        y_axis = chart.axes(Qt.Vertical)[0]
+        if isinstance(y_axis, QValueAxis):
+            # 寻找最大值
+            max_value = 0
+            for d in data:
+                for t in ['fire', 'animal', 'landslide', 'forest_degradation', 'pest']:
+                    max_value = max(max_value, d[t])
+                    
+            # 确保Y轴至少有一些高度,即使没有数据
+            if max_value < 1:
+                max_value = 1
+                
+            # 设置Y轴范围,上浮20%以便更好地查看
+            y_axis.setRange(0, max_value * 1.2)
+            
+            # 根据数据范围确定合适的刻度数量
+            if max_value <= 5:
+                y_axis.setTickCount(max_value + 1)  # 每个值一个刻度
+            else:
+                y_axis.setTickCount(6)  # 较大范围使用5-6个刻度
+        
+    @pyqtSlot()
+    def refresh_statistics(self):
+        """刷新统计数据"""
+        # 不重置数据,只刷新UI
+        # 更新当前标签页
+        current_tab = self.tab_widget.currentIndex()
+        if current_tab == 0:
+            # 更新概览标签页
+            overview_tab = self.create_overview_tab()
+            self.tab_widget.removeTab(0)
+            self.tab_widget.insertTab(0, overview_tab, "概览")
+        elif current_tab == 1:
+            # 更新趋势图标签页
+            self.change_time_range(self.time_range_combo.currentIndex())
+        elif current_tab == 2:
+            # 更新区域统计标签页
+            region_tab = self.create_region_tab()
+            self.tab_widget.removeTab(2)
+            self.tab_widget.insertTab(2, region_tab, "区域统计")
+            
+        # 切换回当前标签页
+        self.tab_widget.setCurrentIndex(current_tab)
+        
+        print("已刷新统计面板显示")
+    
+    @pyqtSlot()
+    def auto_update_statistics(self):
+        """自动更新统计数据(由定时器调用)"""
+        # 更新数据
+        self.update_statistics()
+        
+        # 动态更新UI
+        self.update_current_tab()
+        
+    def update_current_tab(self):
+        """更新所有标签页的UI,确保数据及时刷新"""
+        # 更新概览标签页上的告警统计
+        overview_tab = self.tab_widget.widget(0)
+        if overview_tab:
+            stats_group = overview_tab.findChild(QGroupBox, "stats_group")
+            if stats_group:
+                # 更新标签
+                for i, (key, color) in enumerate([
+                    ('fire', 'red'), 
+                    ('animal', 'green'), 
+                    ('landslide', 'blue'), 
+                    ('pest', 'purple')
+                ]):
+                    label_name = f"{key}_label"
+                    label = stats_group.findChild(QLabel, label_name)
+                    if label:
+                        label.setText(f"{key.capitalize()}告警: {self.alert_stats[key]}")
+            
+            # 更新饼图
+            chart_view = overview_tab.findChild(QChartView)
+            if chart_view:
+                # 创建新饼图
+                new_chart_view = self.create_pie_chart()
+                # 替换旧饼图
+                layout = overview_tab.layout()
+                layout.replaceWidget(chart_view, new_chart_view)
+                chart_view.deleteLater()
+                
+        # 更新趋势图
+        self.change_time_range(self.time_range_combo.currentIndex())
+        
+        # 保存当前区域图表视图的引用,便于后续更新
+        if not hasattr(self, 'region_chart_view'):
+            self.region_chart_view = self.create_region_chart_view()
+            
+        # 更新区域统计图
+        region_tab = self.tab_widget.widget(2)
+        if region_tab:
+            chart_view = region_tab.findChild(QChartView)
+            if chart_view:
+                # 创建新区域图
+                new_chart_view = self.create_region_chart_view()
+                # 替换旧区域图
+                layout = region_tab.layout()
+                layout.replaceWidget(chart_view, new_chart_view)
+                chart_view.deleteLater()
+                
+                # 更新病虫害类型表格
+                table = region_tab.findChild(QTableWidget)
+                if table:
+                    for row, (pest_type, count) in enumerate(self.pest_type_stats.items()):
+                        if row < table.rowCount() and count > 0:
+                            count_item = QTableWidgetItem(str(count))
+                            count_item.setTextAlignment(Qt.AlignCenter)
+                            table.setItem(row, 1, count_item)
+                
+        # 触发刷新
+        self.update()  # 强制刷新UI
+        print("已刷新所有统计图表")
+        
+    def update_statistics(self):
+        """更新统计数据(由外部定时器调用)"""
+        # 获取当前时间
+        now = datetime.now()
+        
+        # 对所有数据点进行处理,实现自然衰减的效果
+        for data in self.hour_data:
+            for key in ['fire', 'animal', 'landslide', 'forest_degradation', 'pest']:
+                # 如果是当前小时的数据,保持不变
+                if data['time'].hour == now.hour and data['time'].day == now.day:
+                    continue
+                
+                # 对于旧数据,每次更新减少20%,但不低于0
+                # 这样可以实现数值的自然衰减,使图表能体现趋势变化
+                if data[key] > 0:
+                    data[key] = max(0, data[key] - 0.2)
+        
+        # 对日数据也进行类似处理
+        for data in self.day_data:
+            # 如果不是当前日期的数据,进行衰减
+            if data['time'].day != now.day or data['time'].month != now.month:
+                for key in ['fire', 'animal', 'landslide', 'forest_degradation', 'pest']:
+                    if data[key] > 0:
+                        data[key] = max(0, data[key] - 0.1)  # 日数据衰减更慢一些
+        
+        # 更新UI显示
+        self.update_current_tab()
+
+    def create_pie_chart(self):
+        """创建新的饼图供外部使用"""
+        pie_chart = QChart()
+        pie_chart.setTitle("告警类型分布")
+        pie_chart.setAnimationOptions(QChart.SeriesAnimations)
+        pie_chart.setBackgroundBrush(QColor("#0a1a2a"))
+        pie_chart.setTitleBrush(QColor("white"))
+        pie_chart.setTitleFont(QFont("Microsoft YaHei", 10, QFont.Bold))
+        
+        # 创建饼图系列
+        series = QPieSeries()
+        
+        # 检查是否有数据
+        total_alerts = sum(self.alert_stats.values())
+        if total_alerts > 0:
+            series.append("火灾告警", self.alert_stats['fire'])
+            series.append("动物告警", self.alert_stats['animal'])
+            series.append("滑坡告警", self.alert_stats['landslide'])
+            series.append("病虫害告警", self.alert_stats['pest'])
+            
+            # 设置切片颜色
+            if len(series.slices()) >= 5:
+                series.slices()[0].setBrush(QColor(255, 100, 100))  # 红色
+                series.slices()[1].setBrush(QColor(100, 255, 100))  # 绿色
+                series.slices()[2].setBrush(QColor(100, 100, 255))  # 蓝色
+                series.slices()[3].setBrush(QColor(255, 200, 100))  # 橙色
+                series.slices()[4].setBrush(QColor(180, 100, 200))  # 紫色
+                
+                # 设置标签颜色
+                for slice in series.slices():
+                    slice.setLabelColor(QColor("white"))
+                    slice.setLabelFont(QFont("Microsoft YaHei", 9))
+                
+                # 突出显示第一个切片
+                series.slices()[0].setExploded(True)
+                series.slices()[0].setLabelVisible(True)
+        else:
+            # 如果没有数据,添加一个空的占位切片
+            placeholder = series.append("无告警数据", 1)
+            placeholder.setBrush(QColor(100, 100, 100))  # 灰色
+            placeholder.setLabelColor(QColor("white"))
+            placeholder.setLabelFont(QFont("Microsoft YaHei", 9))
+            placeholder.setLabelVisible(True)
+        
+        pie_chart.addSeries(series)
+        pie_chart.legend().setLabelColor(QColor("white"))
+        
+        # 创建图表视图
+        chart_view = QChartView(pie_chart)
+        chart_view.setRenderHint(QPainter.Antialiasing)
+        chart_view.setBackgroundBrush(QColor("#0a1a2a"))
+        
+        return chart_view
+
+    def create_trend_chart_view(self):
+        """创建新的趋势图视图供外部使用"""
+        trend_chart = QChart()
+        trend_chart.setTitle("24小时告警趋势")
+        trend_chart.setAnimationOptions(QChart.SeriesAnimations)
+        trend_chart.setBackgroundBrush(QColor("#0a1a2a"))
+        trend_chart.setTitleBrush(QColor("white"))
+        trend_chart.setTitleFont(QFont("Microsoft YaHei", 10, QFont.Bold))
+        
+        # 设置图表边距,增加底部和左侧空间显示坐标文字
+        trend_chart.setMargins(QMargins(10, 10, 10, 20))
+        
+        # 创建折线系列 - 火灾
+        fire_series = QLineSeries()
+        fire_series.setName("火灾告警")
+        fire_series.setColor(QColor(255, 100, 100))
+        
+        # 创建折线系列 - 动物
+        animal_series = QLineSeries()
+        animal_series.setName("动物告警")
+        animal_series.setColor(QColor(100, 255, 100))
+        
+        # 创建折线系列 - 滑坡
+        landslide_series = QLineSeries()
+        landslide_series.setName("滑坡告警")
+        landslide_series.setColor(QColor(100, 100, 255))
+        
+        # 创建折线系列 - 病虫害
+        pest_series = QLineSeries()
+        pest_series.setName("病虫害告警")
+        pest_series.setColor(QColor(180, 100, 200))
+        
+        # 添加数据点
+        for i, data in enumerate(self.hour_data):
+            timestamp = data['time'].timestamp() * 1000  # 转换为毫秒
+            fire_series.append(timestamp, data['fire'])
+            animal_series.append(timestamp, data['animal'])
+            landslide_series.append(timestamp, data['landslide'])
+            pest_series.append(timestamp, data['pest'])
+        
+        # 添加系列到图表
+        trend_chart.addSeries(fire_series)
+        trend_chart.addSeries(animal_series)
+        trend_chart.addSeries(landslide_series)
+        trend_chart.addSeries(pest_series)
+        
+        # 创建X轴(时间轴)
+        axis_x = QDateTimeAxis()
+        axis_x.setFormat("HH:mm")
+        axis_x.setTitleText("时间")
+        axis_x.setTickCount(8)  # 显示8个刻度
+        axis_x.setRange(
+            self.hour_data[0]['time'],
+            self.hour_data[-1]['time']
+        )
+        axis_x.setTitleBrush(QColor("white"))
+        axis_x.setLabelsColor(QColor("white"))
+        axis_x.setLabelsFont(QFont("Microsoft YaHei", 8))
+        axis_x.setTitleFont(QFont("Microsoft YaHei", 9, QFont.Bold))
+        
+        # 创建Y轴
+        axis_y = QValueAxis()
+        axis_y.setLabelFormat("%d")
+        axis_y.setTitleText("告警数量")
+        axis_y.setRange(0, 15)  # 增大Y轴范围,避免文字被裁剪
+        axis_y.setTickCount(6)  # 增加刻度数量以更好显示
+        axis_y.setTitleBrush(QColor("white"))
+        axis_y.setLabelsColor(QColor("white")) 
+        axis_y.setLabelsFont(QFont("Microsoft YaHei", 8))
+        axis_y.setTitleFont(QFont("Microsoft YaHei", 9, QFont.Bold))
+        
+        # 添加坐标轴到图表
+        trend_chart.addAxis(axis_x, Qt.AlignBottom)
+        trend_chart.addAxis(axis_y, Qt.AlignLeft)
+        
+        # 将所有系列依附到坐标轴
+        fire_series.attachAxis(axis_x)
+        fire_series.attachAxis(axis_y)
+        animal_series.attachAxis(axis_x)
+        animal_series.attachAxis(axis_y)
+        landslide_series.attachAxis(axis_x)
+        landslide_series.attachAxis(axis_y)
+        pest_series.attachAxis(axis_x)
+        pest_series.attachAxis(axis_y)
+        
+        # 设置图例位置和样式
+        trend_chart.legend().setVisible(True)
+        trend_chart.legend().setAlignment(Qt.AlignBottom)
+        trend_chart.legend().setLabelColor(QColor("white"))
+        trend_chart.legend().setMarkerShape(QLegend.MarkerShapeCircle)
+        trend_chart.legend().setFont(QFont("Microsoft YaHei", 8))
+        
+        # 创建图表视图
+        chart_view = QChartView(trend_chart)
+        chart_view.setRenderHint(QPainter.Antialiasing)
+        chart_view.setMinimumHeight(280)  # 增加最小高度
+        chart_view.setBackgroundBrush(QColor("#0a1a2a"))
+        
+        return chart_view
+
+    def create_region_chart_view(self):
+        """创建区域统计图视图"""
+        chart = QChart()
+        chart.setTitle("区域统计")
+        chart.setAnimationOptions(QChart.SeriesAnimations)
+        chart.setBackgroundBrush(QColor("#0a1a2a"))
+        chart.setTitleBrush(QColor("white"))
+        chart.setTitleFont(QFont("Microsoft YaHei", 10, QFont.Bold))
+        
+        # 设置图表边距,增加底部和左侧空间显示坐标文字
+        chart.setMargins(QMargins(10, 10, 10, 20))
+        
+        # 创建数据集
+        barset = QBarSet("告警数量")
+        barset.setColor(QColor(100, 200, 255))  # 设置柱状图颜色
+        
+        # 获取区域名称和数据
+        regions = list(self.region_stats.keys())
+        values = list(self.region_stats.values())
+        
+        # 如果区域统计为空,使用默认区域
+        if not regions:
+            regions = ['东北区', '西北区', '中部区', '东南区', '西南区']
+            values = [0, 0, 0, 0, 0]
+            print("警告: 区域统计数据为空,使用默认空值")
+        
+        # 打印当前区域统计数据用于调试
+        print("区域统计数据:")
+        for region, value in zip(regions, values):
+            print(f"  {region}: {value}")
+        
+        # 添加数据到集合
+        for value in values:
+            barset.append(value)
+        
+        # 创建条形系列
+        series = QBarSeries()
+        series.append(barset)
+        series.setLabelsVisible(True)
+        series.setLabelsPosition(QBarSeries.LabelsInsideEnd)  # 标签位置在柱内端
+        
+        # 添加系列到图表
+        chart.addSeries(series)
+        
+        # 创建X轴(区域)
+        axis_x = QBarCategoryAxis()
+        axis_x.append(regions)
+        axis_x.setTitleText("区域")
+        axis_x.setTitleBrush(QColor("white"))
+        axis_x.setLabelsColor(QColor("white"))
+        axis_x.setLabelsFont(QFont("Microsoft YaHei", 8))
+        axis_x.setTitleFont(QFont("Microsoft YaHei", 9, QFont.Bold))
+        
+        # 创建Y轴
+        axis_y = QValueAxis()
+        axis_y.setLabelFormat("%d")
+        axis_y.setTitleText("告警数量")
+        
+        # 设置Y轴范围,确保即使是小数值也能看到变化
+        max_value = max(values) if values and max(values) > 0 else 1
+        axis_y.setRange(0, max_value * 1.2 + 1)  # 最大值上浮20%,并确保至少有高度
+        
+        axis_y.setTickCount(6)
+        axis_y.setTitleBrush(QColor("white"))
+        axis_y.setLabelsColor(QColor("white"))
+        axis_y.setLabelsFont(QFont("Microsoft YaHei", 8))
+        axis_y.setTitleFont(QFont("Microsoft YaHei", 9, QFont.Bold))
+        
+        # 添加坐标轴到图表
+        chart.addAxis(axis_x, Qt.AlignBottom)
+        chart.addAxis(axis_y, Qt.AlignLeft)
+        
+        # 将系列附加到坐标轴
+        series.attachAxis(axis_x)
+        series.attachAxis(axis_y)
+        
+        # 图例设置
+        chart.legend().setVisible(False)  # 隐藏图例
+        
+        # 创建图表视图
+        chart_view = QChartView(chart)
+        chart_view.setRenderHint(QPainter.Antialiasing)
+        chart_view.setMinimumHeight(280)  # 增加最小高度
+        chart_view.setBackgroundBrush(QColor("#0a1a2a"))
+        
+        return chart_view
+
+    def handle_new_alert(self, alert_type, region):
+        """处理新的告警信息
+        
+        Args:
+            alert_type (str): 告警类型 ('fire', 'animal', 'landslide', 'forest_degradation', 'pest')
+            region (str): 告警区域
+        """
+        print(f"统计面板收到新告警: 类型={alert_type}, 区域={region}")
+        
+        # 更新告警统计总数
+        if alert_type in self.alert_stats:
+            self.alert_stats[alert_type] += 1
+            print(f"更新告警统计: {alert_type} = {self.alert_stats[alert_type]}")
+            
+        # 更新区域统计
+        mapped_region = self.region_mapping.get(region, '中部区')  # 如果没有映射则默认为中部区
+        if mapped_region not in self.region_stats:
+            self.region_stats[mapped_region] = 0
+        self.region_stats[mapped_region] += 1
+        print(f"更新区域统计: {mapped_region} = {self.region_stats[mapped_region]}")
+        
+        # 更新趋势数据
+        now = datetime.now()
+        
+        # 更新小时数据
+        for data in self.hour_data:
+            if data['time'].hour == now.hour and data['time'].day == now.day:
+                if alert_type in data:
+                    data[alert_type] += 1
+                    print(f"更新小时趋势: {alert_type} = {data[alert_type]}")
+                break
+        
+        # 更新日数据
+        for data in self.day_data:
+            if data['time'].day == now.day and data['time'].month == now.month:
+                if alert_type in data:
+                    data[alert_type] += 1
+                    print(f"更新日趋势: {alert_type} = {data[alert_type]}")
+                break
+        
+        # 更新月数据
+        for data in self.month_data:
+            if data['time'].day == now.day and data['time'].month == now.month:
+                if alert_type in data:
+                    data[alert_type] += 1
+                    print(f"更新月趋势: {alert_type} = {data[alert_type]}")
+                break
+        
+        # 更新UI显示
+        self.update_current_tab()
+        
+        # 强制刷新UI
+        from PyQt5.QtWidgets import QApplication
+        QApplication.processEvents()
+
+    def handle_alert_processed(self, alert_type, region):
+        """处理告警已被处理的信号
+        
+        Args:
+            alert_type (str): 告警类型 ('fire', 'animal', 'landslide', 'forest_degradation', 'pest')
+            region (str): 告警区域
+        """
+        print(f"统计面板收到告警处理通知: 类型={alert_type}, 区域={region}")
+        
+        # 减少告警统计总数,但确保不会小于0
+        if alert_type in self.alert_stats and self.alert_stats[alert_type] > 0:
+            self.alert_stats[alert_type] -= 1
+            print(f"减少告警统计: {alert_type} = {self.alert_stats[alert_type]}")
+            
+        # 区域名称映射(将告警区域名称映射到区域统计中的键)
+        region_mapping = {
+            '北部山区': '东北区',
+            '南部林区': '东南区',
+            '东部山脊': '东北区',
+            '西部谷地': '西北区',
+            '西部林区': '西北区',
+            '中央林场': '中部区'
+        }
+        
+        # 减少区域统计总数
+        mapped_region = region_mapping.get(region, region)
+        
+        if mapped_region in self.region_stats and self.region_stats[mapped_region] > 0:
+            self.region_stats[mapped_region] -= 1
+            print(f"减少区域统计: {mapped_region} = {self.region_stats[mapped_region]}")
+            
+        # 如果是病虫害类型,减少一种随机病虫害
+        if alert_type == 'pest':
+            import random
+            pest_types = [k for k, v in self.pest_type_stats.items() if v > 0]
+            if pest_types:
+                selected_pest = random.choice(pest_types)
+                self.pest_type_stats[selected_pest] -= 1
+                print(f"减少病虫害类型: {selected_pest} = {self.pest_type_stats[selected_pest]}")
+            
+        # 更新当前小时的趋势数据 - 减少告警量
+        now = datetime.now()
+        for data in self.hour_data:
+            if data['time'].hour == now.hour and data['time'].day == now.day:
+                if alert_type in data and data[alert_type] > 0:
+                    data[alert_type] -= 1
+                    print(f"减少小时趋势: {alert_type} = {data[alert_type]}")
+                break
+                
+        # 更新当前日期的趋势数据 - 减少告警量
+        for data in self.day_data:
+            if data['time'].day == now.day and data['time'].month == now.month:
+                if alert_type in data and data[alert_type] > 0:
+                    data[alert_type] -= 1
+                    print(f"减少日趋势: {alert_type} = {data[alert_type]}")
+                break
+                
+        # 基于当前选择的时间范围更新趋势图
+        current_index = self.time_range_combo.currentIndex() if hasattr(self, 'time_range_combo') else 0
+        self.change_time_range(current_index)
+            
+        # 更新UI - 确保立即刷新所有标签页
+        self.update_current_tab()
+        
+        # 强制刷新UI
+        from PyQt5.QtWidgets import QApplication
+        QApplication.processEvents() 

BIN
ui/pages/__pycache__/main_window.cpython-38.pyc


BIN
ui/pages/__pycache__/main_window.cpython-39.pyc


+ 750 - 0
ui/pages/main_window.py

@@ -0,0 +1,750 @@
+import os
+import sys
+from datetime import datetime
+from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
+                            QTabWidget, QLabel, QPushButton, QComboBox, 
+                            QStatusBar, QSplitter, QProgressBar, QAction, 
+                            QFileDialog, QMessageBox, QMenu, QGroupBox)
+from PyQt5.QtCore import Qt, QTimer, pyqtSlot, QUrl, QSize
+from PyQt5.QtGui import QIcon, QPixmap, QFont, QColor, QPainter
+from PyQt5.QtWebEngineWidgets import QWebEngineView
+import psutil
+from PyQt5.QtWidgets import QApplication
+import time
+
+from ui.components.map_view import MapView
+from ui.components.camera_view import CameraView
+from ui.components.alert_panel import AlertPanel
+from ui.components.control_panel import ControlPanel
+from ui.components.statistics_panel import StatisticsPanel
+from ui.components.drone_manager import DroneManager
+from ui.components.grid_camera_view import GridCameraView
+from PyQt5.QtChart import QChartView, QChart, QPieSeries
+
+class MainWindow(QMainWindow):
+    """
+    森林多模态灾害监测系统主窗口
+    """
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.last_fire_update_time = 0  # 添加上次火灾更新时间记录
+        self.fire_update_interval = 5  # 设置最小更新间隔为5秒
+        self.last_animal_update = 0  # 添加动物检测更新时间记录
+        self.min_update_interval = 5  # 设置最小更新间隔为5秒
+        self.init_ui()
+        
+        # 连接摄像头视图的火灾检测信号
+        self.camera_view.fire_detected.connect(self.on_fire_detected)
+        
+    def init_ui(self):
+        """初始化UI界面"""
+        # 设置窗口属性
+        self.setWindowTitle("森林多模态灾害监测系统")
+        self.setGeometry(100, 100, 1280, 800)
+        
+        # 设置图标
+        icon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'icon.png')
+        if os.path.exists(icon_path):
+            self.setWindowIcon(QIcon(icon_path))
+        
+        # 创建主布局
+        main_widget = QWidget()
+        self.setCentralWidget(main_widget)
+        main_layout = QVBoxLayout(main_widget)
+        main_layout.setContentsMargins(5, 5, 5, 5)
+        
+        # 创建顶部工具栏
+        self.create_toolbar()
+        
+        # 创建三栏主分割器
+        main_splitter = QSplitter(Qt.Horizontal)
+        main_layout.addWidget(main_splitter)
+        
+        # === 左侧栏:控制面板和地图 ===
+        left_panel = QWidget()
+        left_layout = QVBoxLayout(left_panel)
+        left_layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 控制面板
+        self.control_panel = ControlPanel(self.config)
+        left_layout.addWidget(self.control_panel)
+        
+        # 地图视图
+        self.map_view = MapView(self.config)
+        left_layout.addWidget(self.map_view)
+        
+        # === 中间栏:摄像头视图和告警面板 ===
+        middle_panel = QWidget()
+        middle_layout = QVBoxLayout(middle_panel)
+        middle_layout.setContentsMargins(0, 0, 0, 0)
+        
+        # 创建标签页容器
+        camera_tabs = QTabWidget()
+        camera_tabs.setTabPosition(QTabWidget.South)
+        camera_tabs.setStyleSheet("QTabBar::tab { background-color: #102040; color: white; padding: 6px 12px; margin-right: 2px; border-top-left-radius: 4px; border-top-right-radius: 4px; } QTabBar::tab:selected { background-color: #1a3a5a; }")
+        
+        # 添加九宫格视图标签页
+        self.grid_camera_view = GridCameraView()
+        camera_tabs.addTab(self.grid_camera_view, "多路监控")
+        
+        # 连接九宫格视图的灾害检测信号
+        self.grid_camera_view.fire_detected.connect(self.on_fire_detected)
+        self.grid_camera_view.animal_detected.connect(self.on_animal_detected)
+        
+        # 添加单路摄像头视图标签页
+        self.camera_view = CameraView(self.config)
+        camera_tabs.addTab(self.camera_view, "单路监控")
+        
+        # 添加无人机管理标签页
+        self.drone_manager = DroneManager(self.config)
+        camera_tabs.addTab(self.drone_manager, "无人机集群")
+        
+        # 将标签页容器添加到中间布局
+        middle_layout.addWidget(camera_tabs, 7)  # 占70%高度
+        
+        # 告警面板
+        self.alert_panel = AlertPanel(self.config)
+        middle_layout.addWidget(self.alert_panel, 3)  # 占30%高度
+        
+        # 连接告警面板的信号
+        self.alert_panel.alert_added.connect(self.on_alert_added)
+        # 连接告警处理信号
+        self.alert_panel.alert_processed.connect(self.on_alert_processed)
+        
+        # === 右侧栏:统计信息面板 ===
+        right_panel = QWidget()
+        right_layout = QVBoxLayout(right_panel)
+        right_layout.setContentsMargins(0, 0, 0, 0)
+        right_layout.setSpacing(5)
+        right_panel.setStyleSheet("background-color: #0a1a2a; "
+                               "QGroupBox { background-color: #102040; border: 1px solid #1e3a5a; border-radius: 5px; margin-top: 8px; } "
+                               "QGroupBox::title { subcontrol-origin: margin; left: 10px; padding: 0 3px; color: white; font-weight: bold; }"
+                               "QLabel { color: white; }"
+                               "QComboBox { color: white; background-color: #1a3a5a; border: 1px solid #2a4a6a; }"
+                               "QPushButton { color: white; background-color: #1a3a5a; border: 1px solid #2a4a6a; }")
+        
+        # 标题标签
+        stats_title = QLabel("统计信息")
+        stats_title.setFont(QFont("Microsoft YaHei", 14, QFont.Bold))
+        stats_title.setStyleSheet("color: white; margin: 5px;")
+        stats_title.setAlignment(Qt.AlignCenter)
+        right_layout.addWidget(stats_title)
+        
+        # 创建统计面板
+        self.statistics_panel = StatisticsPanel(self.config)
+        
+        # ==== 概览部分 ====
+        overview_group = QGroupBox("告警概览")
+        overview_group.setObjectName("overview_group")
+        overview_layout = QVBoxLayout(overview_group)
+        overview_layout.setContentsMargins(5, 10, 5, 10)  # 设置合适的边距
+        
+        # 告警统计数字
+        stats_layout = QHBoxLayout()
+        fire_label = QLabel(f"火灾告警: {self.statistics_panel.alert_stats['fire']}")
+        fire_label.setObjectName("fire_overview")
+        fire_label.setStyleSheet("color: #ff6666; font-weight: bold;")
+        
+        animal_label = QLabel(f"动物告警: {self.statistics_panel.alert_stats['animal']}")
+        animal_label.setObjectName("animal_overview")
+        animal_label.setStyleSheet("color: #66ff66; font-weight: bold;")
+        
+        landslide_label = QLabel(f"滑坡告警: {self.statistics_panel.alert_stats['landslide']}")
+        landslide_label.setObjectName("landslide_overview")
+        landslide_label.setStyleSheet("color: #6666ff; font-weight: bold;")
+        
+        pest_label = QLabel(f"病虫害: {self.statistics_panel.alert_stats['pest']}")
+        pest_label.setObjectName("pest_overview")
+        pest_label.setStyleSheet("color: #cc99ff; font-weight: bold;")
+        
+        stats_layout.addWidget(fire_label)
+        stats_layout.addWidget(animal_label)
+        stats_layout.addWidget(landslide_label)
+        stats_layout.addWidget(pest_label)
+        overview_layout.addLayout(stats_layout)
+        
+        # 饼图视图
+        pie_chart = self.statistics_panel.create_pie_chart()
+        pie_chart.setObjectName("overview_pie_chart")
+        pie_chart.setMinimumHeight(250)  # 设置最小高度
+        pie_chart.setFixedHeight(250)    # 设置固定高度,防止尺寸变化
+        overview_layout.addWidget(pie_chart)
+        
+        right_layout.addWidget(overview_group)
+        
+        # ==== 趋势图部分 ====
+        trend_group = QGroupBox("告警趋势")
+        trend_group.setObjectName("trend_group")
+        trend_layout = QVBoxLayout(trend_group)
+        trend_layout.setContentsMargins(5, 10, 5, 10)  # 设置合适的边距
+        
+        # 时间范围选择器
+        time_range_layout = QHBoxLayout()
+        time_range_layout.addWidget(QLabel("时间范围:"))
+        time_range_combo = QComboBox()
+        time_range_combo.setObjectName("time_range_combo")
+        time_range_combo.addItems(["最近24小时", "最近7天", "最近30天", "本月", "本年"])
+        time_range_combo.currentIndexChanged.connect(self.statistics_panel.change_time_range)
+        time_range_layout.addWidget(time_range_combo)
+        time_range_layout.addStretch(1)
+        trend_layout.addLayout(time_range_layout)
+        
+        # 趋势图视图
+        trend_chart = self.statistics_panel.create_trend_chart_view()
+        trend_chart.setObjectName("trend_chart_view")
+        trend_chart.setMinimumHeight(280)  # 设置最小高度
+        trend_chart.setFixedHeight(280)    # 设置固定高度,防止尺寸变化
+        trend_layout.addWidget(trend_chart)
+        
+        right_layout.addWidget(trend_group)
+        
+        # ==== 区域统计部分 ====
+        region_group = QGroupBox("区域统计")
+        region_group.setObjectName("region_group")
+        region_layout = QVBoxLayout(region_group)
+        region_layout.setContentsMargins(5, 10, 5, 10)  # 设置合适的边距
+        
+        # 区域统计图表
+        region_chart = self.statistics_panel.create_region_chart_view()
+        region_chart.setObjectName("region_chart_view")
+        region_chart.setMinimumHeight(280)  # 设置最小高度
+        region_chart.setFixedHeight(280)    # 设置固定高度,防止尺寸变化
+        region_layout.addWidget(region_chart)
+        
+        right_layout.addWidget(region_group)
+        
+        # 刷新按钮
+        refresh_btn = QPushButton("刷新统计信息")
+        refresh_btn.clicked.connect(self.statistics_panel.refresh_statistics)
+        right_layout.addWidget(refresh_btn)
+        
+        # 将三个面板添加到主分割器
+        main_splitter.addWidget(left_panel)
+        main_splitter.addWidget(middle_panel)
+        main_splitter.addWidget(right_panel)
+        
+        # 设置三栏分割比例
+        main_splitter.setSizes([int(self.width() * 0.2), int(self.width() * 0.55), int(self.width() * 0.25)])
+        
+        # 创建状态栏
+        self.create_statusbar()
+        
+        # 创建定时器,定期更新UI
+        self.update_timer = QTimer(self)
+        self.update_timer.timeout.connect(self.update_ui)
+        self.update_timer.start(5000)  # 每5秒更新一次
+        
+        # 应用暗色/亮色主题
+        self.apply_theme()
+        
+    def create_toolbar(self):
+        """创建工具栏"""
+        self.toolbar = self.addToolBar("主工具栏")
+        self.toolbar.setMovable(False)
+        
+        # 添加启动监控按钮
+        start_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'start.png')), "启动监控", self)
+        start_action.triggered.connect(self.start_monitoring)
+        self.toolbar.addAction(start_action)
+        
+        # 添加停止监控按钮
+        stop_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'stop.png')), "停止监控", self)
+        stop_action.triggered.connect(self.stop_monitoring)
+        self.toolbar.addAction(stop_action)
+        
+        self.toolbar.addSeparator()
+        
+        # 添加无人机控制菜单
+        drone_menu = QMenu("无人机", self)
+        
+        # 添加起飞所有无人机动作
+        takeoff_all_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'takeoff.png')), "起飞所有无人机", self)
+        takeoff_all_action.triggered.connect(self.takeoff_all_drones)
+        drone_menu.addAction(takeoff_all_action)
+        
+        # 添加返航所有无人机动作
+        return_all_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'return.png')), "返航所有无人机", self)
+        return_all_action.triggered.connect(self.return_all_drones)
+        drone_menu.addAction(return_all_action)
+        
+        # 添加紧急停止所有无人机动作
+        emergency_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'emergency.png')), "紧急停止所有无人机", self)
+        emergency_action.triggered.connect(self.emergency_stop_all_drones)
+        drone_menu.addAction(emergency_action)
+        
+        # 添加无人机菜单按钮
+        drone_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'drone.png')), "无人机控制", self)
+        drone_action.setMenu(drone_menu)
+        self.toolbar.addAction(drone_action)
+        
+        self.toolbar.addSeparator()
+        
+        # 添加导入数据按钮
+        import_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'import.png')), "导入数据", self)
+        import_action.triggered.connect(self.import_data)
+        self.toolbar.addAction(import_action)
+        
+        # 添加导出报告按钮
+        export_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'export.png')), "导出报告", self)
+        export_action.triggered.connect(self.export_report)
+        self.toolbar.addAction(export_action)
+        
+        self.toolbar.addSeparator()
+        
+        # 添加设置按钮
+        settings_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'settings.png')), "设置", self)
+        settings_action.triggered.connect(self.show_settings)
+        self.toolbar.addAction(settings_action)
+        
+        # 添加帮助按钮
+        help_action = QAction(QIcon(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'help.png')), "帮助", self)
+        help_action.triggered.connect(self.show_help)
+        self.toolbar.addAction(help_action)
+        
+    def create_statusbar(self):
+        """创建状态栏"""
+        self.statusbar = QStatusBar()
+        self.setStatusBar(self.statusbar)
+        
+        # 添加系统状态标签
+        self.status_label = QLabel("系统状态: 待机")
+        self.statusbar.addWidget(self.status_label)
+        
+        # 添加CPU使用率进度条
+        self.cpu_label = QLabel("CPU: ")
+        self.statusbar.addPermanentWidget(self.cpu_label)
+        
+        self.cpu_progress = QProgressBar()
+        self.cpu_progress.setMaximumWidth(100)
+        self.cpu_progress.setMaximumHeight(16)
+        self.cpu_progress.setRange(0, 100)
+        self.cpu_progress.setValue(0)
+        self.statusbar.addPermanentWidget(self.cpu_progress)
+        
+        # 添加内存使用率进度条
+        self.memory_label = QLabel("内存: ")
+        self.statusbar.addPermanentWidget(self.memory_label)
+        
+        self.memory_progress = QProgressBar()
+        self.memory_progress.setMaximumWidth(100)
+        self.memory_progress.setMaximumHeight(16)
+        self.memory_progress.setRange(0, 100)
+        self.memory_progress.setValue(0)
+        self.statusbar.addPermanentWidget(self.memory_progress)
+        
+        # 添加时间标签
+        self.time_label = QLabel()
+        self.update_time()  # 初始化时间
+        self.statusbar.addPermanentWidget(self.time_label)
+        
+        # 创建时间更新定时器
+        self.time_timer = QTimer(self)
+        self.time_timer.timeout.connect(self.update_time)
+        self.time_timer.start(1000)  # 每秒更新一次
+        
+    def update_time(self):
+        """更新状态栏时间"""
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.time_label.setText(current_time)
+        
+    def apply_theme(self):
+        """应用主题"""
+        if self.config.get('dark_mode', True):
+            # 应用暗色主题
+            self.setStyleSheet("""
+                QMainWindow, QWidget {
+                    background-color: #2D2D30;
+                    color: #FFFFFF;
+                }
+                QTabWidget::pane {
+                    border: 1px solid #3F3F46;
+                    background-color: #252526;
+                }
+                QTabBar::tab {
+                    background-color: #2D2D30;
+                    color: #FFFFFF;
+                    padding: 5px 15px;
+                    border: 1px solid #3F3F46;
+                }
+                QTabBar::tab:selected {
+                    background-color: #007ACC;
+                }
+                QPushButton {
+                    background-color: #0E639C;
+                    color: white;
+                    border: none;
+                    padding: 5px 15px;
+                    border-radius: 2px;
+                }
+                QPushButton:hover {
+                    background-color: #1177BB;
+                }
+                QPushButton:pressed {
+                    background-color: #13669C;
+                }
+            """)
+        else:
+            # 应用亮色主题
+            self.setStyleSheet("")
+        
+    def update_ui(self):
+        """定期更新UI界面"""
+        # 更新系统状态
+        self.status_label.setText("系统状态: 正常运行中")
+        
+        # 获取真实CPU和内存使用率
+        cpu_usage = psutil.cpu_percent()
+        memory_usage = psutil.virtual_memory().percent
+        
+        self.cpu_progress.setValue(cpu_usage)
+        self.memory_progress.setValue(memory_usage)
+        
+        # 更新子组件
+        self.camera_view.update_view()
+        self.map_view.update_view()
+        self.alert_panel.update_alerts()
+        self.statistics_panel.update_statistics()
+        
+        # 添加无人机组件更新
+        if hasattr(self, 'drone_manager'):
+            self.drone_manager.update_drone_display()
+        
+        # 更新右侧概览面板的统计数据
+        self.update_overview_stats()
+        
+    @pyqtSlot()
+    def start_monitoring(self):
+        """启动监控"""
+        self.status_label.setText("系统状态: 监控中")
+        self.camera_view.start_monitoring()
+        # 更新其他组件状态
+        
+    @pyqtSlot()
+    def stop_monitoring(self):
+        """停止监控"""
+        self.status_label.setText("系统状态: 已停止")
+        self.camera_view.stop_monitoring()
+        # 更新其他组件状态
+        
+    @pyqtSlot()
+    def import_data(self):
+        """导入数据"""
+        file_dialog = QFileDialog()
+        file_dialog.setFileMode(QFileDialog.ExistingFile)
+        file_dialog.setNameFilter("数据文件 (*.csv *.json *.zip)")
+        
+        if file_dialog.exec_():
+            file_paths = file_dialog.selectedFiles()
+            if file_paths:
+                # 处理导入文件
+                QMessageBox.information(self, "导入数据", f"已选择导入文件: {file_paths[0]}")
+                
+    @pyqtSlot()
+    def export_report(self):
+        """导出报告"""
+        file_dialog = QFileDialog()
+        file_dialog.setAcceptMode(QFileDialog.AcceptSave)
+        file_dialog.setNameFilter("报告文件 (*.pdf *.docx *.html)")
+        
+        if file_dialog.exec_():
+            file_paths = file_dialog.selectedFiles()
+            if file_paths:
+                # 处理导出报告
+                QMessageBox.information(self, "导出报告", f"报告将保存到: {file_paths[0]}")
+                
+    @pyqtSlot()
+    def show_settings(self):
+        """显示设置对话框"""
+        self.tab_widget.setCurrentIndex(3)  # 切换到设置页面
+        
+    @pyqtSlot()
+    def show_help(self):
+        """显示帮助信息"""
+        QMessageBox.information(self, "帮助", "森林多模态灾害监测系统\n版本: 1.0.0\n\n基于YOLOv5的智能监测系统,用于森林火灾、滑坡、动物盗猎等多灾害监测。")
+        
+    @pyqtSlot()
+    def takeoff_all_drones(self):
+        """起飞所有无人机"""
+        if hasattr(self, 'drone_manager'):
+            # 切换到无人机管理标签页
+            for i in range(4):  # 假设标签页在索引1-3
+                if "无人机" in self.findChildren(QTabWidget)[i].tabText(1):
+                    self.findChildren(QTabWidget)[i].setCurrentIndex(1)
+                    break
+            
+            # 修改所有无人机状态为已起飞
+            for drone_id, drone in self.drone_manager.drones.items():
+                drone.status = "已起飞"
+            
+            # 更新状态表格
+            self.drone_manager.update_status_table()
+            
+            QMessageBox.information(self, "无人机控制", "已发送起飞命令至所有无人机")
+    
+    @pyqtSlot()
+    def return_all_drones(self):
+        """返航所有无人机"""
+        if hasattr(self, 'drone_manager'):
+            # 切换到无人机管理标签页
+            for i in range(4):  # 假设标签页在索引1-3
+                if "无人机" in self.findChildren(QTabWidget)[i].tabText(1):
+                    self.findChildren(QTabWidget)[i].setCurrentIndex(1)
+                    break
+            
+            # 修改所有无人机状态为返航中
+            for drone_id, drone in self.drone_manager.drones.items():
+                drone.status = "返航中"
+            
+            # 更新状态表格
+            self.drone_manager.update_status_table()
+            
+            QMessageBox.information(self, "无人机控制", "已发送返航命令至所有无人机")
+    
+    @pyqtSlot()
+    def emergency_stop_all_drones(self):
+        """紧急停止所有无人机"""
+        if hasattr(self, 'drone_manager'):
+            reply = QMessageBox.warning(self, "紧急停止确认", 
+                                       "确定要紧急停止所有无人机吗?这可能导致无人机坠落!", 
+                                       QMessageBox.Yes | QMessageBox.No,
+                                       QMessageBox.No)
+            
+            if reply == QMessageBox.Yes:
+                # 切换到无人机管理标签页
+                for i in range(4):  # 假设标签页在索引1-3
+                    if "无人机" in self.findChildren(QTabWidget)[i].tabText(1):
+                        self.findChildren(QTabWidget)[i].setCurrentIndex(1)
+                        break
+                
+                # 修改所有无人机状态为紧急停止
+                for drone_id, drone in self.drone_manager.drones.items():
+                    drone.status = "紧急停止"
+                
+                # 更新状态表格
+                self.drone_manager.update_status_table()
+                
+                QMessageBox.critical(self, "无人机控制", "已发送紧急停止命令至所有无人机")
+    
+    def closeEvent(self, event):
+        """窗口关闭事件"""
+        reply = QMessageBox.question(self, '退出确认', 
+                                    "确定要退出系统吗?", 
+                                    QMessageBox.Yes | QMessageBox.No,
+                                    QMessageBox.No)
+        
+        if reply == QMessageBox.Yes:
+            # 停止所有正在运行的线程和定时器
+            self.update_timer.stop()
+            self.time_timer.stop()
+            self.camera_view.stop_monitoring()
+            
+            # 停止所有无人机
+            if hasattr(self, 'drone_manager'):
+                for drone in self.drone_manager.drones.values():
+                    drone.stop()
+            
+            event.accept()
+        else:
+            event.ignore()
+
+    def update_overview_stats(self):
+        """更新右侧统计面板中的数据显示"""
+        # 更新告警概览组的统计数字
+        fire_label = self.findChild(QLabel, "fire_overview")
+        if fire_label:
+            fire_label.setText(f"火灾告警: {self.statistics_panel.alert_stats['fire']}")
+            
+        animal_label = self.findChild(QLabel, "animal_overview")
+        if animal_label:
+            animal_label.setText(f"动物告警: {self.statistics_panel.alert_stats['animal']}")
+            
+        landslide_label = self.findChild(QLabel, "landslide_overview")
+        if landslide_label:
+            landslide_label.setText(f"滑坡告警: {self.statistics_panel.alert_stats['landslide']}")
+            
+        pest_label = self.findChild(QLabel, "pest_overview")
+        if pest_label:
+            pest_label.setText(f"病虫害: {self.statistics_panel.alert_stats['pest']}")
+            
+        # 更新各个图表
+        self.update_overview_chart()
+        self.update_trend_chart()
+        self.update_region_chart()
+
+    def update_overview_chart(self):
+        """更新右侧统计面板中的饼图"""
+        # 获取当前概览组中的饼图
+        overview_group = self.findChild(QGroupBox, "overview_group")
+        if overview_group:
+            # 获取旧的饼图视图
+            old_chart_view = None
+            for i in range(overview_group.layout().count()):
+                item = overview_group.layout().itemAt(i)
+                if item and item.widget() and isinstance(item.widget(), QChartView):
+                    old_chart_view = item.widget()
+                    break
+            
+            if old_chart_view:
+                # 创建新的饼图
+                new_chart_view = self.statistics_panel.create_pie_chart()
+                new_chart_view.setObjectName("overview_pie_chart")
+                new_chart_view.setMinimumHeight(250)  # 设置最小高度
+                new_chart_view.setFixedHeight(250)    # 设置固定高度,防止尺寸变化
+                
+                # 替换旧的饼图
+                layout = overview_group.layout()
+                layout.replaceWidget(old_chart_view, new_chart_view)
+                old_chart_view.deleteLater() 
+
+    def update_trend_chart(self):
+        """更新趋势图"""
+        trend_group = self.findChild(QGroupBox, "trend_group")
+        if trend_group:
+            # 获取旧的趋势图
+            old_chart_view = trend_group.findChild(QChartView, "trend_chart_view")
+            if old_chart_view:
+                # 创建新的趋势图
+                new_chart_view = self.statistics_panel.create_trend_chart_view()
+                new_chart_view.setObjectName("trend_chart_view")
+                
+                # 保持尺寸一致
+                new_chart_view.setMinimumHeight(280)
+                new_chart_view.setFixedHeight(280)
+                
+                # 替换旧的趋势图
+                layout = trend_group.layout()
+                for i in range(layout.count()):
+                    item = layout.itemAt(i)
+                    if item.widget() == old_chart_view:
+                        layout.replaceWidget(old_chart_view, new_chart_view)
+                        old_chart_view.deleteLater()
+                        break
+    
+    def update_region_chart(self):
+        """更新区域统计图"""
+        region_group = self.findChild(QGroupBox, "region_group")
+        if region_group:
+            # 获取旧的区域图
+            old_chart_view = region_group.findChild(QChartView, "region_chart_view")
+            if old_chart_view:
+                # 创建新的区域图
+                new_chart_view = self.statistics_panel.create_region_chart_view()
+                new_chart_view.setObjectName("region_chart_view")
+                
+                # 保持尺寸一致
+                new_chart_view.setMinimumHeight(280)
+                new_chart_view.setFixedHeight(280)
+                
+                # 替换旧的区域图
+                layout = region_group.layout()
+                for i in range(layout.count()):
+                    item = layout.itemAt(i)
+                    if item.widget() == old_chart_view:
+                        layout.replaceWidget(old_chart_view, new_chart_view)
+                        old_chart_view.deleteLater()
+                        print("已更新区域统计图")
+                        break
+            else:
+                print("警告: 未找到区域图表视图")
+        else:
+            print("警告: 未找到区域统计组")
+
+    def on_alert_added(self, alert_type, region):
+        """处理新告警信号,更新统计面板"""
+        print(f"主窗口收到新告警: 类型={alert_type}, 区域={region}")
+        
+        # 调用统计面板的方法处理新告警
+        self.statistics_panel.handle_new_alert(alert_type, region)
+        
+        # 立即更新右侧统计信息显示
+        self.update_overview_stats()
+        
+        # 强制更新区域图表
+        self.update_region_chart()
+        
+        # 重置随机更新定时器,避免冲突
+        self.statistics_panel.update_timer.stop()
+        self.statistics_panel.update_timer.start(10000)  # 增加到10秒,减少干扰
+        
+        # 强制立即刷新所有图表
+        QApplication.processEvents()  # 确保UI更新立即可见 
+
+    def on_alert_processed(self, alert_type, region):
+        """处理告警处理信号,更新统计面板"""
+        print(f"主窗口收到告警处理: 类型={alert_type}, 区域={region}")
+        
+        # 调用统计面板的方法处理告警处理
+        self.statistics_panel.handle_alert_processed(alert_type, region)
+        
+        # 立即更新右侧统计信息显示
+        self.update_overview_stats()
+        
+        # 强制更新区域图表
+        self.update_region_chart()
+        
+        # 重置随机更新定时器,避免冲突
+        self.statistics_panel.update_timer.stop()
+        self.statistics_panel.update_timer.start(10000)  # 增加到10秒,减少干扰
+        
+        # 强制立即刷新所有图表
+        QApplication.processEvents()  # 确保UI更新立即可见 
+
+    def on_fire_detected(self, region):
+        """处理火灾检测信号"""
+        current_time = datetime.now().timestamp()
+        # 检查是否达到最小更新间隔
+        if current_time - self.last_fire_update_time < self.fire_update_interval:
+            return  # 如果间隔太短,直接返回不处理
+            
+        print(f"收到火灾检测信号,区域:{region}")
+        # 更新上次处理时间
+        self.last_fire_update_time = current_time
+        
+        # 创建并添加告警
+        alert = {
+            'type': 'fire',
+            'location': region,
+            'time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            'detail': '检测到火灾隐患',
+            'level': 'high',
+            'status': '未处理'
+        }
+        self.alert_panel.add_alert(alert)  # 这会触发alert_added信号,统计更新将在on_alert_added中处理 
+
+    def on_animal_detected(self, image, species, confidence):
+        """处理动物检测信号"""
+        current_time = time.time()
+        if current_time - self.last_animal_update < self.min_update_interval:
+            return
+            
+        self.last_animal_update = current_time
+        
+        # 创建动物检测告警
+        alert = {
+            'type': 'animal',
+            'location': '未知位置',  # 可以根据摄像头位置更新
+            'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'status': 'unprocessed',
+            'details': f'检测到{species},置信度: {confidence:.1f}%',
+            'image': image  # 保存检测到动物的图像
+        }
+        
+        # 添加到告警面板
+        self.alert_panel.add_alert(alert)
+        
+        # 调用统计面板的方法处理新告警
+        self.statistics_panel.handle_new_alert('animal', '未知位置')
+        
+        # 立即更新右侧统计信息显示
+        self.update_overview_stats()
+        
+        # 强制更新区域图表
+        self.update_region_chart()
+        
+        # 重置随机更新定时器,避免冲突
+        self.statistics_panel.update_timer.stop()
+        self.statistics_panel.update_timer.start(10000)  # 增加到10秒,减少干扰
+        
+        # 强制立即刷新所有图表
+        QApplication.processEvents()  # 确保UI更新立即可见 

+ 559 - 0
ui/splash_screen.py

@@ -0,0 +1,559 @@
+import sys
+import os
+from PyQt5.QtWidgets import (QApplication, QSplashScreen, QProgressBar, 
+                             QLabel, QVBoxLayout, QWidget, QFrame, QDesktopWidget)
+from PyQt5.QtCore import Qt, QTimer, QSize, QPropertyAnimation, QEasingCurve, QPoint
+from PyQt5.QtGui import QPixmap, QFont, QColor, QPainter, QBrush, QPen, QRadialGradient, QLinearGradient, QMovie
+import math  # 添加数学库
+
+class ParticleEffect(QWidget):
+    """粒子效果动画组件"""
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setAttribute(Qt.WA_TranslucentBackground)
+        self.setAttribute(Qt.WA_TransparentForMouseEvents)
+        self.particles = []
+        self.timer = QTimer(self)
+        self.timer.timeout.connect(self.update_particles)
+        self.timer.start(50)  # 每50毫秒更新一次
+        
+        # 初始化粒子
+        for _ in range(100):  # 增加粒子数量
+            self.particles.append({
+                'x': self.width() * 0.5,
+                'y': self.height() * 0.5,
+                'vx': (0.5 - float(os.urandom(1)[0]) / 255.0) * 8,  # 增加速度范围
+                'vy': (0.5 - float(os.urandom(1)[0]) / 255.0) * 8,
+                'size': 1 + float(os.urandom(1)[0]) / 32.0,  # 增加粒子大小
+                'alpha': 0.5 + float(os.urandom(1)[0]) / 255.0,
+                'color': QColor(100 + int(os.urandom(1)[0]) % 155, 
+                                200 + int(os.urandom(1)[0]) % 55, 
+                                220 + int(os.urandom(1)[0]) % 35,  # 更亮的蓝色调
+                                200)
+            })
+    
+    def update_particles(self):
+        width = self.width() or 1920  # 防止宽度为0
+        height = self.height() or 1080
+        
+        for p in self.particles:
+            # 更新位置
+            p['x'] += p['vx']
+            p['y'] += p['vy']
+            
+            # 如果超出边界,重新生成粒子
+            if (p['x'] < 0 or p['x'] > width or 
+                p['y'] < 0 or p['y'] > height):
+                p['x'] = width * 0.5 + (width * 0.3 * (float(os.urandom(1)[0]) / 255.0 - 0.5))  # 随机中心位置
+                p['y'] = height * 0.5 + (height * 0.3 * (float(os.urandom(1)[0]) / 255.0 - 0.5))
+                p['vx'] = (0.5 - float(os.urandom(1)[0]) / 255.0) * 8
+                p['vy'] = (0.5 - float(os.urandom(1)[0]) / 255.0) * 8
+        
+        self.update()  # 触发重绘
+        
+    def paintEvent(self, event):
+        painter = QPainter(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        
+        for p in self.particles:
+            painter.setPen(Qt.NoPen)
+            painter.setBrush(p['color'])
+            size = p['size']
+            painter.drawEllipse(p['x'] - size/2, p['y'] - size/2, size, size)
+            
+            # 为部分粒子添加轻微的发光效果
+            if size > 2:
+                glow = QColor(p['color'])
+                glow.setAlpha(50)
+                painter.setBrush(glow)
+                painter.drawEllipse(p['x'] - size, p['y'] - size, size * 2, size * 2)
+        
+class HexEffect(QWidget):
+    """六边形网格效果"""
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setAttribute(Qt.WA_TranslucentBackground)
+        self.setAttribute(Qt.WA_TransparentForMouseEvents)
+        self.hexagons = []
+        self.timer = QTimer(self)
+        self.timer.timeout.connect(self.update_hexagons)
+        self.timer.start(100)  # 每100毫秒更新一次
+        self.alpha_direction = 1  # 透明度变化方向
+        self.current_alpha = 40  # 初始透明度
+        
+    def update_hexagons(self):
+        # 更新六边形透明度
+        self.current_alpha += self.alpha_direction
+        if self.current_alpha >= 60:  # 最大透明度
+            self.alpha_direction = -1
+        elif self.current_alpha <= 20:  # 最小透明度
+            self.alpha_direction = 1
+            
+        self.update()  # 触发重绘
+        
+    def paintEvent(self, event):
+        painter = QPainter(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        
+        hex_size = 60  # 六边形大小
+        width = self.width()
+        height = self.height()
+        
+        # 计算六边形网格
+        horizontal_spacing = hex_size * 1.5
+        vertical_spacing = hex_size * 0.866 * 2  # sqrt(3)/2 * 2 * hex_size
+        
+        rows = int(height / vertical_spacing) + 2
+        cols = int(width / horizontal_spacing) + 2
+        
+        color = QColor(0, 180, 220, self.current_alpha)
+        painter.setPen(QPen(QColor(0, 220, 255, 80), 1))
+        
+        for row in range(rows):
+            for col in range(cols):
+                x = col * horizontal_spacing
+                y = row * vertical_spacing
+                
+                # 偶数行需要偏移
+                if row % 2 == 1:
+                    x += hex_size * 0.75
+                
+                # 绘制六边形
+                painter.setBrush(QBrush(color))
+                self.draw_hexagon(painter, x, y, hex_size)
+                
+    def draw_hexagon(self, painter, x, y, size):
+        """绘制六边形"""
+        points = []
+        for i in range(6):
+            angle_deg = 60 * i - 30
+            angle_rad = 3.14159 * angle_deg / 180
+            point_x = x + size * 0.5 * math.cos(angle_rad)
+            point_y = y + size * 0.5 * math.sin(angle_rad)
+            points.append(QPoint(point_x, point_y))
+            
+        painter.drawPolygon(points)
+
+class CircuitEffect(QWidget):
+    """电路板效果"""
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setAttribute(Qt.WA_TranslucentBackground)
+        self.setAttribute(Qt.WA_TransparentForMouseEvents)
+        self.circuit_points = []
+        self.circuit_lines = []
+        self.pulse_positions = {}  # 线路上的脉冲位置
+        
+        self.timer = QTimer(self)
+        self.timer.timeout.connect(self.update_pulses)
+        self.timer.start(50)  # 每50毫秒更新一次
+        
+    def generate_circuits(self, width, height):
+        """生成电路图案"""
+        self.circuit_points = []
+        self.circuit_lines = []
+        self.pulse_positions = {}
+        
+        # 网格大小
+        cell_size = 100
+        cols = width // cell_size + 1
+        rows = height // cell_size + 1
+        
+        # 生成网格点
+        for row in range(rows):
+            for col in range(cols):
+                # 添加一些随机性
+                jitter_x = (float(os.urandom(1)[0]) / 255.0 - 0.5) * cell_size * 0.5
+                jitter_y = (float(os.urandom(1)[0]) / 255.0 - 0.5) * cell_size * 0.5
+                
+                x = col * cell_size + jitter_x
+                y = row * cell_size + jitter_y
+                
+                self.circuit_points.append((x, y))
+        
+        # 生成线条连接
+        for i, point1 in enumerate(self.circuit_points):
+            # 找到最近的几个点
+            distances = []
+            for j, point2 in enumerate(self.circuit_points):
+                if i != j:
+                    dx = point1[0] - point2[0]
+                    dy = point1[1] - point2[1]
+                    distance = (dx * dx + dy * dy) ** 0.5
+                    if distance < cell_size * 1.8:  # 只连接较近的点
+                        distances.append((distance, j))
+            
+            # 最多连接3条线
+            distances.sort()
+            for k in range(min(3, len(distances))):
+                j = distances[k][1]
+                if i < j:  # 避免重复添加
+                    self.circuit_lines.append((i, j))
+                    # 初始化脉冲位置,20%的线有脉冲
+                    if float(os.urandom(1)[0]) / 255.0 < 0.2:
+                        self.pulse_positions[(i, j)] = 0.0
+    
+    def update_pulses(self):
+        """更新脉冲位置"""
+        # 更新现有脉冲
+        keys_to_remove = []
+        for line, pos in self.pulse_positions.items():
+            self.pulse_positions[line] = pos + 0.02  # 脉冲前进速度
+            if self.pulse_positions[line] > 1.0:
+                keys_to_remove.append(line)
+        
+        # 移除完成的脉冲
+        for key in keys_to_remove:
+            del self.pulse_positions[key]
+        
+        # 随机添加新脉冲
+        if len(self.circuit_lines) > 0 and len(self.pulse_positions) < len(self.circuit_lines) * 0.2:
+            if float(os.urandom(1)[0]) / 255.0 < 0.1:  # 10%几率添加新脉冲
+                line_idx = int(float(os.urandom(1)[0]) / 255.0 * len(self.circuit_lines))
+                # 确保索引不超出范围
+                line_idx = min(line_idx, len(self.circuit_lines) - 1)
+                line = self.circuit_lines[line_idx]
+                if line not in self.pulse_positions:
+                    self.pulse_positions[line] = 0.0
+        
+        self.update()  # 触发重绘
+        
+    def paintEvent(self, event):
+        # 如果还没有生成电路,则生成
+        if not self.circuit_points:
+            self.generate_circuits(self.width(), self.height())
+            
+        painter = QPainter(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        
+        # 绘制线条
+        painter.setPen(QPen(QColor(0, 180, 220, 40), 1))
+        for i, j in self.circuit_lines:
+            p1 = self.circuit_points[i]
+            p2 = self.circuit_points[j]
+            painter.drawLine(p1[0], p1[1], p2[0], p2[1])
+        
+        # 绘制脉冲
+        for line, pos in self.pulse_positions.items():
+            i, j = line
+            p1 = self.circuit_points[i]
+            p2 = self.circuit_points[j]
+            
+            # 计算脉冲位置
+            pulse_x = p1[0] + (p2[0] - p1[0]) * pos
+            pulse_y = p1[1] + (p2[1] - p1[1]) * pos
+            
+            # 绘制脉冲(亮点)
+            gradient = QRadialGradient(pulse_x, pulse_y, 10)
+            gradient.setColorAt(0, QColor(0, 230, 255, 200))
+            gradient.setColorAt(1, QColor(0, 230, 255, 0))
+            
+            painter.setBrush(QBrush(gradient))
+            painter.setPen(Qt.NoPen)
+            painter.drawEllipse(pulse_x - 10, pulse_y - 10, 20, 20)
+            
+        # 绘制交叉点
+        painter.setPen(Qt.NoPen)
+        painter.setBrush(QBrush(QColor(0, 200, 220, 100)))
+        for point in self.circuit_points:
+            painter.drawEllipse(point[0] - 3, point[1] - 3, 6, 6)
+        
+class CustomSplashScreen(QSplashScreen):
+    """自定义启动屏幕"""
+    def __init__(self):
+        super().__init__()
+        self.setWindowFlag(Qt.FramelessWindowHint)
+        self.setWindowFlag(Qt.WindowStaysOnTopHint)
+        
+        # 获取屏幕尺寸并全屏显示
+        desktop = QDesktopWidget().availableGeometry()
+        self.screen_width = desktop.width()
+        self.screen_height = desktop.height()
+        self.setGeometry(0, 0, self.screen_width, self.screen_height)
+        
+        # 创建基础窗口
+        self.setStyleSheet("""
+            QSplashScreen {
+                background-color: #0a1a2a;
+                border: 0px;
+            }
+        """)
+        
+        # 创建中央组件来布局内容
+        self.central_widget = QWidget(self)
+        self.central_widget.setGeometry(0, 0, self.screen_width, self.screen_height)
+        
+        # 主布局
+        self.layout = QVBoxLayout(self.central_widget)
+        self.layout.setContentsMargins(40, 40, 40, 40)
+        self.layout.setSpacing(15)
+        
+        # 添加空白占位
+        self.layout.addStretch(1)
+        
+        # 添加标题
+        self.title_label = QLabel("森林多模态灾害监测系统", self)
+        self.title_label.setAlignment(Qt.AlignCenter)
+        font = QFont("Microsoft YaHei", 42, QFont.Bold)
+        self.title_label.setFont(font)
+        self.title_label.setStyleSheet("color: #00e6e6; margin-top: 20px;")
+        self.layout.addWidget(self.title_label)
+        
+        # 添加英文副标题
+        self.subtitle_label = QLabel("Forest Multi-modal Disaster Monitoring System", self)
+        self.subtitle_label.setAlignment(Qt.AlignCenter)
+        subtitle_font = QFont("Arial", 20)
+        subtitle_font.setItalic(True)
+        self.subtitle_label.setFont(subtitle_font)
+        self.subtitle_label.setStyleSheet("color: #99f2ff; margin-bottom: 20px;")
+        self.layout.addWidget(self.subtitle_label)
+        
+        # 添加分隔线
+        self.separator = QFrame(self)
+        self.separator.setFrameShape(QFrame.HLine)
+        self.separator.setStyleSheet("background-color: #00a0c0; margin: 20px 200px;")
+        self.separator.setFixedHeight(3)
+        self.layout.addWidget(self.separator)
+        
+        # 添加加载动画
+        self.animation_label = QLabel(self)
+        self.animation_label.setAlignment(Qt.AlignCenter)
+        # 检查GIF文件是否存在,否则使用替代动画
+        animation_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'ui', 'assets', 'loading.gif')
+        if os.path.exists(animation_path):
+            try:
+                self.animation = QMovie(animation_path)
+                self.animation.setScaledSize(QSize(200, 200))
+                self.animation_label.setMovie(self.animation)
+                self.animation.start()
+            except Exception as e:
+                print(f"加载动画文件失败: {e}")
+                self.animation_label.setText("系统初始化中...")
+                self.animation_label.setStyleSheet("color: #99f2ff; font-size: 24px;")
+        else:
+            # 使用文本代替动画
+            self.animation_label.setText("系统初始化中...")
+            self.animation_label.setStyleSheet("color: #99f2ff; font-size: 24px;")
+            print(f"动画文件不存在: {animation_path}")
+        
+        self.layout.addWidget(self.animation_label)
+        
+        # 添加状态标签
+        self.status_label = QLabel("正在加载组件...", self)
+        self.status_label.setAlignment(Qt.AlignCenter)
+        self.status_label.setStyleSheet("color: #99f2ff; font-size: 20px; margin-top: 20px;")
+        self.layout.addWidget(self.status_label)
+        
+        # 添加进度条
+        self.progress_bar = QProgressBar(self)
+        self.progress_bar.setRange(0, 100)
+        self.progress_bar.setValue(0)
+        self.progress_bar.setTextVisible(False)
+        self.progress_bar.setFixedHeight(10)
+        self.progress_bar.setStyleSheet("""
+            QProgressBar {
+                background-color: #001824;
+                border: 1px solid #004080;
+                border-radius: 5px;
+                margin: 10px 100px;
+            }
+            QProgressBar::chunk {
+                background-color: qlineargradient(x1:0, y1:0.5, x2:1, y2:0.5, stop:0 #00ccff, stop:1 #00ffcc);
+                border-radius: 5px;
+            }
+        """)
+        self.layout.addWidget(self.progress_bar)
+        
+        # 添加空白占位
+        self.layout.addStretch(2)
+        
+        # 添加版本信息
+        self.version_label = QLabel("版本 1.0.0  |  © 2025 火眼金睛灾害监测技术实验室", self)
+        self.version_label.setAlignment(Qt.AlignCenter)
+        self.version_label.setStyleSheet("color: #4db8ff; font-size: 14px; margin-bottom: 20px;")
+        self.layout.addWidget(self.version_label)
+        
+        # 添加底部说明
+        self.bottom_label = QLabel("正在启动...", self)
+        self.bottom_label.setAlignment(Qt.AlignCenter)
+        self.bottom_label.setStyleSheet("color: #6698ff; font-size: 12px; margin-bottom: 10px;")
+        self.layout.addWidget(self.bottom_label)
+        
+        # 添加背景效果
+        self.circuit_effect = CircuitEffect(self)
+        self.circuit_effect.setGeometry(0, 0, self.screen_width, self.screen_height)
+        
+        # 添加粒子效果
+        self.particle_effect = ParticleEffect(self)
+        self.particle_effect.setGeometry(0, 0, self.screen_width, self.screen_height)
+        
+        # 初始化进度计时器
+        self.progress_timer = QTimer(self)
+        self.progress_timer.timeout.connect(self.update_progress)
+        self.current_progress = 0
+        
+        # 状态消息列表
+        self.status_messages = [
+            "加载核心模块...",
+            "初始化灾害检测模型...",
+            "配置多模态数据源...",
+            "加载地理信息系统...",
+            "初始化摄像头监控模块...",
+            "配置告警系统...",
+            "连接云端数据中心...",
+            "准备AI分析引擎...",
+            "系统启动完成,正在进入主界面..."
+        ]
+        self.current_message_index = 0
+        
+        # 效果定时器
+        self.effect_timer = QTimer(self)
+        self.effect_timer.timeout.connect(self.update_effects)
+        self.effect_timer.start(100)
+        
+        # 效果变量
+        self.title_phase = 0
+        self.glow_value = 0
+        self.glow_direction = 1
+        
+    def update_effects(self):
+        """更新视觉效果"""
+        # 标题发光效果
+        self.glow_value += self.glow_direction * 2
+        if self.glow_value > 100:
+            self.glow_value = 100
+            self.glow_direction = -1
+        elif self.glow_value < 0:
+            self.glow_value = 0
+            self.glow_direction = 1
+            
+        glow_color = f"rgba(0, {180 + self.glow_value//2}, {220 + self.glow_value//3}, 0.8)"
+        self.title_label.setStyleSheet(f"color: #00e6e6; margin-top: 20px; text-shadow: 0 0 15px {glow_color};")
+        
+        # 底部提示变化
+        self.title_phase += 1
+        if self.title_phase % 50 == 0:
+            dots = "." * ((self.title_phase // 10) % 4)
+            self.bottom_label.setText(f"正在启动{dots}")
+        
+    def start_loading(self, duration=5000):
+        """开始加载动画,持续duration毫秒"""
+        self.progress_timer.start(duration / 100)  # 将持续时间分成100步
+        
+    def update_progress(self):
+        """更新进度和消息"""
+        self.current_progress += 1
+        self.progress_bar.setValue(self.current_progress)
+        
+        # 更新状态消息
+        if self.current_progress % (100 // len(self.status_messages)) == 0:
+            if self.current_message_index < len(self.status_messages):
+                self.status_label.setText(self.status_messages[self.current_message_index])
+                self.current_message_index += 1
+        
+        # 完成加载
+        if self.current_progress >= 100:
+            self.progress_timer.stop()
+            QTimer.singleShot(500, self.finish_loading)  # 延迟半秒后完成
+    
+    def finish_loading(self):
+        """完成加载后的处理"""
+        # 在实际应用中,这里可以发出加载完成的信号
+        pass
+            
+    def paintEvent(self, event):
+        """自定义绘制事件"""
+        super().paintEvent(event)
+        painter = QPainter(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        
+        # 绘制背景渐变
+        gradient = QLinearGradient(0, 0, 0, self.height())
+        gradient.setColorAt(0, QColor(10, 30, 60))
+        gradient.setColorAt(0.5, QColor(5, 20, 40))
+        gradient.setColorAt(1, QColor(10, 30, 60))
+        painter.fillRect(0, 0, self.width(), self.height(), gradient)
+        
+        # 绘制边框
+        margin = 10
+        pen = QPen(QColor(0, 180, 230, 100), 3)
+        painter.setPen(pen)
+        painter.drawRect(margin, margin, self.width()-margin*2, self.height()-margin*2)
+        
+        # 绘制顶部和底部装饰线
+        painter.setPen(QPen(QColor(0, 180, 230, 150), 2))
+        
+        # 顶部装饰
+        top_margin = 30
+        line_width = self.width() * 0.4
+        painter.drawLine(self.width()/2 - line_width/2, top_margin, 
+                         self.width()/2 + line_width/2, top_margin)
+        
+        # 左上角装饰
+        painter.drawLine(margin*2, top_margin*2, margin*6, top_margin*2)
+        painter.drawLine(margin*2, top_margin*2, margin*2, top_margin*5)
+        
+        # 右上角装饰
+        painter.drawLine(self.width()-margin*6, top_margin*2, self.width()-margin*2, top_margin*2)
+        painter.drawLine(self.width()-margin*2, top_margin*2, self.width()-margin*2, top_margin*5)
+        
+        # 底部装饰
+        bottom_margin = 30
+        painter.drawLine(self.width()/2 - line_width/2, self.height()-bottom_margin, 
+                         self.width()/2 + line_width/2, self.height()-bottom_margin)
+        
+        # 左下角装饰
+        painter.drawLine(margin*2, self.height()-top_margin*2, margin*6, self.height()-top_margin*2)
+        painter.drawLine(margin*2, self.height()-top_margin*2, margin*2, self.height()-top_margin*5)
+        
+        # 右下角装饰
+        painter.drawLine(self.width()-margin*6, self.height()-top_margin*2, self.width()-margin*2, self.height()-top_margin*2)
+        painter.drawLine(self.width()-margin*2, self.height()-top_margin*2, self.width()-margin*2, self.height()-top_margin*5)
+
+def show_splash_screen(app, main_window, duration=5000):
+    """显示启动屏幕
+    
+    Args:
+        app: QApplication实例
+        main_window: 主窗口实例
+        duration: 显示启动屏幕的时间(毫秒)
+    """
+    splash = CustomSplashScreen()
+    splash.show()
+    
+    # 开始加载动画
+    splash.start_loading(duration)
+    
+    # 设置计时器,在指定时间后隐藏启动屏幕并显示主窗口
+    QTimer.singleShot(duration, lambda: finish_splash(splash, main_window))
+    
+    # 处理事件,确保启动屏幕显示
+    app.processEvents()
+    
+def finish_splash(splash, main_window):
+    """完成启动屏幕并显示主窗口"""
+    # 淡出效果
+    fade_out = QPropertyAnimation(splash, b"windowOpacity")
+    fade_out.setDuration(1000)  # 延长淡出时间
+    fade_out.setStartValue(1.0)
+    fade_out.setEndValue(0.0)
+    fade_out.setEasingCurve(QEasingCurve.OutQuad)
+    fade_out.finished.connect(splash.close)
+    fade_out.start()
+    
+    # 显示主窗口
+    main_window.show()
+
+# 测试代码
+if __name__ == "__main__":
+    app = QApplication(sys.argv)
+    
+    # 创建一个简单的主窗口用于测试
+    main = QWidget()
+    main.setWindowTitle("测试主窗口")
+    main.resize(800, 600)
+    
+    # 显示启动屏幕
+    show_splash_screen(app, main, 5000)
+    
+    sys.exit(app.exec_()) 

+ 37 - 0
utils/__init__.py

@@ -0,0 +1,37 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+utils/initialization
+"""
+
+
+def notebook_init(verbose=True):
+    # Check system software and hardware
+    print('Checking setup...')
+
+    import os
+    import shutil
+
+    from utils.general import check_requirements, emojis, is_colab
+    from utils.torch_utils import select_device  # imports
+
+    check_requirements(('psutil', 'IPython'))
+    import psutil
+    from IPython import display  # to display images and clear console output
+
+    if is_colab():
+        shutil.rmtree('/content/sample_data', ignore_errors=True)  # remove colab /sample_data directory
+
+    if verbose:
+        # System info
+        # gb = 1 / 1000 ** 3  # bytes to GB
+        gib = 1 / 1024 ** 3  # bytes to GiB
+        ram = psutil.virtual_memory().total
+        total, used, free = shutil.disk_usage("/")
+        display.clear_output()
+        s = f'({os.cpu_count()} CPUs, {ram * gib:.1f} GB RAM, {(total - free) * gib:.1f}/{total * gib:.1f} GB disk)'
+    else:
+        s = ''
+
+    select_device(newline=False)
+    print(emojis(f'Setup complete ✅ {s}'))
+    return display

BIN
utils/__pycache__/__init__.cpython-311.pyc


BIN
utils/__pycache__/__init__.cpython-38.pyc


BIN
utils/__pycache__/__init__.cpython-39.pyc


BIN
utils/__pycache__/augmentations.cpython-311.pyc


BIN
utils/__pycache__/augmentations.cpython-38.pyc


BIN
utils/__pycache__/augmentations.cpython-39.pyc


BIN
utils/__pycache__/autoanchor.cpython-38.pyc


BIN
utils/__pycache__/autoanchor.cpython-39.pyc


BIN
utils/__pycache__/autobatch.cpython-38.pyc


BIN
utils/__pycache__/callbacks.cpython-38.pyc


BIN
utils/__pycache__/config_loader.cpython-38.pyc


BIN
utils/__pycache__/config_loader.cpython-39.pyc


BIN
utils/__pycache__/datasets.cpython-311.pyc


BIN
utils/__pycache__/datasets.cpython-38.pyc


BIN
utils/__pycache__/datasets.cpython-39.pyc


BIN
utils/__pycache__/downloads.cpython-311.pyc


BIN
utils/__pycache__/downloads.cpython-38.pyc


BIN
utils/__pycache__/downloads.cpython-39.pyc


BIN
utils/__pycache__/flame_detector.cpython-39.pyc


BIN
utils/__pycache__/general.cpython-311.pyc


BIN
utils/__pycache__/general.cpython-38.pyc


BIN
utils/__pycache__/general.cpython-39.pyc


BIN
utils/__pycache__/loss.cpython-38.pyc


BIN
utils/__pycache__/metrics.cpython-311.pyc


BIN
utils/__pycache__/metrics.cpython-38.pyc


BIN
utils/__pycache__/metrics.cpython-39.pyc


Some files were not shown because too many files changed in this diff