check_environment.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 环境检查脚本
  5. 自动检查系统要求和依赖安装情况
  6. """
  7. import sys
  8. import subprocess
  9. import importlib
  10. from pathlib import Path
  11. def check_python_version():
  12. """检查Python版本"""
  13. print("🔍 检查Python版本...")
  14. version = sys.version_info
  15. if version.major == 3 and version.minor >= 8:
  16. print(f"✅ Python版本: {version.major}.{version.minor}.{version.micro} (符合要求)")
  17. return True
  18. else:
  19. print(f"❌ Python版本: {version.major}.{version.minor}.{version.micro} (需要Python 3.8+)")
  20. return False
  21. def check_memory():
  22. """检查系统内存"""
  23. print("\n🔍 检查系统内存...")
  24. try:
  25. import psutil
  26. memory = psutil.virtual_memory()
  27. memory_gb = memory.total / (1024**3)
  28. if memory_gb >= 8:
  29. print(f"✅ 系统内存: {memory_gb:.1f}GB (符合要求)")
  30. return True
  31. else:
  32. print(f"⚠️ 系统内存: {memory_gb:.1f}GB (推荐8GB+)")
  33. return True # 不强制要求
  34. except ImportError:
  35. print("⚠️ 无法检查内存 (psutil未安装)")
  36. return True
  37. def check_gpu():
  38. """检查GPU"""
  39. print("\n🔍 检查GPU...")
  40. try:
  41. import torch
  42. if torch.cuda.is_available():
  43. gpu_count = torch.cuda.device_count()
  44. gpu_name = torch.cuda.get_device_name(0)
  45. print(f"✅ 检测到GPU: {gpu_name} (数量: {gpu_count})")
  46. return True
  47. else:
  48. print("⚠️ 未检测到可用GPU (可以使用CPU训练,但速度较慢)")
  49. return True
  50. except ImportError:
  51. print("⚠️ 无法检查GPU (PyTorch未安装)")
  52. return True
  53. def check_dependencies():
  54. """检查核心依赖"""
  55. print("\n🔍 检查核心依赖...")
  56. required_packages = {
  57. 'torch': 'PyTorch',
  58. 'transformers': 'Transformers',
  59. 'datasets': 'Datasets',
  60. 'peft': 'PEFT',
  61. 'accelerate': 'Accelerate',
  62. 'huggingface_hub': 'Hugging Face Hub',
  63. 'safetensors': 'SafeTensors',
  64. 'sentencepiece': 'SentencePiece',
  65. 'tqdm': 'TQDM',
  66. 'requests': 'Requests',
  67. 'numpy': 'NumPy',
  68. 'pandas': 'Pandas',
  69. 'psutil': 'PSUtil'
  70. }
  71. missing_packages = []
  72. for package, name in required_packages.items():
  73. try:
  74. module = importlib.import_module(package)
  75. version = getattr(module, '__version__', 'unknown')
  76. print(f"✅ {name}: {version}")
  77. except ImportError:
  78. print(f"❌ {name}: 未安装")
  79. missing_packages.append(package)
  80. return len(missing_packages) == 0, missing_packages
  81. def check_project_structure():
  82. """检查项目结构"""
  83. print("\n🔍 检查项目结构...")
  84. required_files = [
  85. 'requirements.txt',
  86. 'main.py',
  87. 'start_training.sh',
  88. 'src/app.py',
  89. 'src/config/settings.py',
  90. 'src/model/trainer.py',
  91. 'src/data/loader.py'
  92. ]
  93. missing_files = []
  94. for file_path in required_files:
  95. if Path(file_path).exists():
  96. print(f"✅ {file_path}")
  97. else:
  98. print(f"❌ {file_path}: 文件不存在")
  99. missing_files.append(file_path)
  100. return len(missing_files) == 0, missing_files
  101. def provide_solutions(missing_packages, missing_files):
  102. """提供解决方案"""
  103. if missing_packages or missing_files:
  104. print("\n🔧 解决方案:")
  105. if missing_packages:
  106. print("\n📦 安装缺失的依赖:")
  107. print("pip install -r requirements.txt")
  108. print("\n或者逐个安装:")
  109. for package in missing_packages:
  110. print(f"pip install {package}")
  111. if missing_files:
  112. print("\n📁 缺失的文件:")
  113. for file_path in missing_files:
  114. print(f"- {file_path}")
  115. print("请确保您在正确的项目目录中运行此脚本")
  116. def main():
  117. """主函数"""
  118. print("🚀 Qwen安全模型训练项目 - 环境检查")
  119. print("=" * 50)
  120. checks = [
  121. check_python_version(),
  122. check_memory(),
  123. check_gpu()
  124. ]
  125. deps_ok, missing_packages = check_dependencies()
  126. checks.append(deps_ok)
  127. structure_ok, missing_files = check_project_structure()
  128. checks.append(structure_ok)
  129. print("\n" + "=" * 50)
  130. if all(checks):
  131. print("🎉 环境检查通过!您可以开始使用项目了。")
  132. print("\n📚 下一步:")
  133. print("1. 运行: python main.py --help")
  134. print("2. 或者: ./start_training.sh --help")
  135. print("3. 查看: README.md 了解详细使用方法")
  136. else:
  137. print("⚠️ 环境检查发现问题,请按照以下建议解决:")
  138. provide_solutions(missing_packages, missing_files)
  139. print("\n解决问题后,请重新运行此脚本进行检查。")
  140. return all(checks)
  141. if __name__ == "__main__":
  142. success = main()
  143. sys.exit(0 if success else 1)