start_training.sh 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. #!/bin/bash
  2. # -*- coding: utf-8 -*-
  3. #
  4. # 神机网络安全模型自动化训练启动脚本
  5. #
  6. # 使用方法:
  7. # 1. 完整训练: ./start_training.sh
  8. # 2. 仅数据处理: ./start_training.sh --mode data
  9. # 3. 仅模型训练: ./start_training.sh --mode train
  10. # 4. 仅模型测试: ./start_training.sh --mode test
  11. # 5. 交互模式: ./start_training.sh --mode interactive
  12. #
  13. set -e # 遇到错误立即退出
  14. # 颜色定义
  15. RED='\033[0;31m'
  16. GREEN='\033[0;32m'
  17. YELLOW='\033[1;33m'
  18. BLUE='\033[0;34m'
  19. NC='\033[0m' # No Color
  20. # 日志函数
  21. log_info() {
  22. echo -e "${GREEN}[INFO]${NC} $1"
  23. }
  24. log_warn() {
  25. echo -e "${YELLOW}[WARN]${NC} $1"
  26. }
  27. log_error() {
  28. echo -e "${RED}[ERROR]${NC} $1"
  29. }
  30. log_step() {
  31. echo -e "${BLUE}[STEP]${NC} $1"
  32. }
  33. # 检查系统环境
  34. check_system() {
  35. log_step "检查系统环境..."
  36. # 检查Python
  37. if ! command -v python3 &> /dev/null; then
  38. log_error "Python3 未安装"
  39. exit 1
  40. fi
  41. python_version=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
  42. log_info "Python版本: $python_version"
  43. # 检查pip
  44. if ! command -v pip3 &> /dev/null; then
  45. log_error "pip3 未安装"
  46. exit 1
  47. fi
  48. # screen检查已移除,统一在前台运行
  49. # 检查CUDA
  50. if command -v nvidia-smi &> /dev/null; then
  51. log_info "检测到NVIDIA GPU"
  52. nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits | head -1
  53. else
  54. log_warn "未检测到NVIDIA GPU,将使用CPU训练"
  55. fi
  56. log_info "系统环境检查完成"
  57. }
  58. # 配置pip镜像源
  59. configure_pip_mirror() {
  60. log_step "配置pip镜像源..."
  61. # 测试网络连接并选择最佳镜像源
  62. if ping -c 1 -W 3 pypi.tuna.tsinghua.edu.cn &> /dev/null; then
  63. log_info "配置清华大学镜像源"
  64. pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
  65. pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn
  66. elif ping -c 1 -W 3 mirrors.aliyun.com &> /dev/null; then
  67. log_info "配置阿里云镜像源"
  68. pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
  69. pip config set global.trusted-host mirrors.aliyun.com
  70. elif ping -c 1 -W 3 pypi.douban.com &> /dev/null; then
  71. log_info "配置豆瓣镜像源"
  72. pip config set global.index-url https://pypi.douban.com/simple/
  73. pip config set global.trusted-host pypi.douban.com
  74. else
  75. log_warn "网络连接检查失败,使用默认源(可能较慢)"
  76. fi
  77. # 设置其他pip优化参数
  78. pip config set global.timeout 300
  79. pip config set global.retries 3
  80. log_info "当前pip配置:"
  81. pip config list || log_warn "无法显示pip配置"
  82. }
  83. # 设置虚拟环境
  84. setup_venv() {
  85. log_step "设置Python虚拟环境..."
  86. VENV_DIR="./venv"
  87. if [ ! -d "$VENV_DIR" ]; then
  88. log_info "创建虚拟环境..."
  89. python3 -m venv "$VENV_DIR"
  90. fi
  91. log_info "激活虚拟环境..."
  92. source "$VENV_DIR/bin/activate"
  93. # 配置pip镜像源
  94. configure_pip_mirror
  95. # 升级pip和构建工具
  96. log_info "升级pip、setuptools和wheel..."
  97. pip install --upgrade pip setuptools wheel || {
  98. log_warn "构建工具升级失败,继续使用当前版本"
  99. }
  100. log_info "虚拟环境设置完成"
  101. }
  102. # 安装依赖
  103. install_dependencies() {
  104. log_step "安装Python依赖..."
  105. # 优先使用基础依赖文件,避免安装问题
  106. if [ -f "requirements-basic.txt" ]; then
  107. log_info "使用requirements-basic.txt安装核心依赖..."
  108. pip install -r requirements-basic.txt || {
  109. log_error "基础依赖安装失败"
  110. exit 1
  111. }
  112. # 尝试安装可选依赖
  113. log_info "安装可选依赖..."
  114. # 安装bitsandbytes(量化支持)
  115. pip install bitsandbytes>=0.39.0 || {
  116. log_warn "bitsandbytes安装失败,量化功能可能不可用"
  117. }
  118. # 安装nvidia-ml-py(GPU监控)
  119. if command -v nvidia-smi &> /dev/null; then
  120. pip install nvidia-ml-py>=12.535.108 || {
  121. log_warn "nvidia-ml-py安装失败,GPU监控功能可能不可用"
  122. }
  123. # 尝试安装flash-attn
  124. log_info "检测到NVIDIA GPU,尝试安装flash-attn(可选)..."
  125. log_info "正在预安装torch以支持flash-attn编译..."
  126. pip install torch>=2.0.0 || log_warn "torch预安装失败,flash-attn可能无法安装"
  127. pip install flash-attn>=2.0.0 --no-build-isolation || {
  128. log_warn "flash-attn安装失败,将跳过此依赖(不影响基本功能)"
  129. log_warn "如需flash-attn,请手动安装:pip install flash-attn --no-build-isolation"
  130. }
  131. else
  132. log_info "未检测到GPU,跳过GPU相关可选依赖"
  133. fi
  134. elif [ -f "requirements.txt" ]; then
  135. log_info "使用requirements.txt安装依赖..."
  136. pip install -r requirements.txt || {
  137. log_error "依赖安装失败,请检查requirements.txt"
  138. exit 1
  139. }
  140. else
  141. log_error "未找到依赖文件 (requirements-basic.txt 或 requirements.txt)"
  142. exit 1
  143. fi
  144. log_info "依赖安装完成"
  145. }
  146. # 设置环境变量
  147. setup_environment() {
  148. log_step "设置环境变量..."
  149. # 设置HuggingFace缓存目录
  150. export HF_HOME="./cache/huggingface"
  151. export TRANSFORMERS_CACHE="./cache/transformers"
  152. # 设置ModelScope缓存目录
  153. export MODELSCOPE_CACHE="./cache/modelscope"
  154. # 设置CUDA相关环境变量
  155. export CUDA_VISIBLE_DEVICES=0
  156. # 设置Python路径
  157. export PYTHONPATH="$PWD/src:$PYTHONPATH"
  158. log_info "环境变量设置完成"
  159. }
  160. # 创建日志目录
  161. setup_logging() {
  162. log_step "设置日志目录..."
  163. LOG_DIR="./logs"
  164. mkdir -p "$LOG_DIR"
  165. # 生成日志文件名
  166. TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
  167. LOG_FILE="$LOG_DIR/training_$TIMESTAMP.log"
  168. log_info "日志文件: $LOG_FILE"
  169. }
  170. # 运行训练
  171. run_training() {
  172. local mode="$1"
  173. local additional_args="$2"
  174. log_step "开始模型训练..."
  175. # 构建命令
  176. CMD="python3 main.py --mode $mode $additional_args"
  177. log_info "执行命令: $CMD"
  178. # 前台运行训练
  179. log_info "在前台运行训练..."
  180. log_info "提示: 如需后台运行,可使用 screen 或 nohup 命令"
  181. $CMD 2>&1 | tee "$LOG_FILE"
  182. }
  183. # 显示帮助信息
  184. show_help() {
  185. echo "神机网络安全模型自动化训练系统"
  186. echo ""
  187. echo "使用方法:"
  188. echo " $0 [选项]"
  189. echo ""
  190. echo "选项:"
  191. echo " --mode MODE 运行模式 (full|data|train|test|interactive|check)"
  192. echo " --force-download 强制重新下载数据"
  193. echo " --model-path PATH 模型路径 (用于test和interactive模式)"
  194. echo " --resume 从最新checkpoint继续训练"
  195. echo " --resume-from PATH 从指定checkpoint路径继续训练"
  196. echo " --model MODEL 选择基础模型 (qwen|chatglm|baichuan|llama等)"
  197. echo " --list-models 列出支持的模型"
  198. echo " --help 显示此帮助信息"
  199. echo ""
  200. echo "运行模式:"
  201. echo " full 完整训练流程 (默认)"
  202. echo " data 仅数据下载和处理"
  203. echo " train 仅模型训练"
  204. echo " test 仅模型测试"
  205. echo " interactive 交互式对话"
  206. echo " check 检查系统环境"
  207. echo ""
  208. echo "示例:"
  209. echo " $0 # 完整训练"
  210. echo " $0 --mode data --force-download # 重新下载数据"
  211. echo " $0 --mode train # 仅训练模型"
  212. echo " $0 --mode train --resume # 从最新checkpoint继续训练"
  213. echo " $0 --mode train --model chatglm # 使用ChatGLM模型训练"
  214. echo " $0 --list-models # 列出支持的模型"
  215. echo " $0 --mode test # 测试模型"
  216. echo " $0 --mode interactive # 交互模式"
  217. }
  218. # 主函数
  219. main() {
  220. # 默认参数
  221. MODE="full"
  222. ADDITIONAL_ARGS=""
  223. # 如果没有提供任何参数,显示帮助信息
  224. if [[ $# -eq 0 ]]; then
  225. show_help
  226. echo ""
  227. echo "提示: 如果要运行完整训练流程,请使用: $0 --mode full"
  228. exit 0
  229. fi
  230. # 解析命令行参数
  231. while [[ $# -gt 0 ]]; do
  232. case $1 in
  233. --mode)
  234. MODE="$2"
  235. shift 2
  236. ;;
  237. --force-download)
  238. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --force-download"
  239. shift
  240. ;;
  241. --model-path)
  242. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --model-path '$2'"
  243. shift 2
  244. ;;
  245. --resume)
  246. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --resume"
  247. shift
  248. ;;
  249. --resume-from)
  250. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --resume-from '$2'"
  251. shift 2
  252. ;;
  253. --model)
  254. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --model '$2'"
  255. shift 2
  256. ;;
  257. --list-models)
  258. ADDITIONAL_ARGS="$ADDITIONAL_ARGS --list-models"
  259. shift
  260. ;;
  261. --help)
  262. show_help
  263. exit 0
  264. ;;
  265. *)
  266. log_error "未知参数: $1"
  267. show_help
  268. exit 1
  269. ;;
  270. esac
  271. done
  272. # 验证模式
  273. case $MODE in
  274. full|data|train|test|interactive|check)
  275. ;;
  276. *)
  277. log_error "无效的运行模式: $MODE"
  278. show_help
  279. exit 1
  280. ;;
  281. esac
  282. echo "==========================================="
  283. echo "神机网络安全模型自动化训练系统"
  284. echo "==========================================="
  285. echo "运行模式: $MODE"
  286. echo "开始时间: $(date)"
  287. echo "==========================================="
  288. # 执行步骤
  289. check_system
  290. setup_venv
  291. install_dependencies
  292. setup_environment
  293. setup_logging
  294. # 运行训练
  295. run_training "$MODE" "$ADDITIONAL_ARGS"
  296. log_info "训练完成"
  297. }
  298. # 错误处理
  299. trap 'log_error "脚本执行失败,退出码: $?"' ERR
  300. # 运行主函数
  301. main "$@"