download_model.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型下载脚本
  5. 提供多种下载方式和错误处理机制
  6. """
  7. import os
  8. import sys
  9. import ssl
  10. import subprocess
  11. from pathlib import Path
  12. from typing import Optional
  13. def setup_ssl_bypass():
  14. """设置SSL绕过配置"""
  15. # 禁用SSL验证
  16. ssl._create_default_https_context = ssl._create_unverified_context
  17. # 设置环境变量
  18. os.environ['PYTHONHTTPSVERIFY'] = '0'
  19. os.environ['CURL_CA_BUNDLE'] = ''
  20. os.environ['REQUESTS_CA_BUNDLE'] = ''
  21. # 禁用urllib3警告
  22. import urllib3
  23. urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
  24. def download_with_modelscope_cli(model_name: str, output_dir: str) -> bool:
  25. """使用ModelScope命令行工具下载"""
  26. try:
  27. print(f"尝试使用ModelScope CLI下载: {model_name}")
  28. # 设置环境变量
  29. env = os.environ.copy()
  30. env.update({
  31. 'PYTHONHTTPSVERIFY': '0',
  32. 'SSL_VERIFY': 'false',
  33. 'CURL_CA_BUNDLE': '',
  34. 'REQUESTS_CA_BUNDLE': ''
  35. })
  36. cmd = ['modelscope', 'download', '--model', model_name, '--local_dir', output_dir]
  37. result = subprocess.run(cmd, env=env, capture_output=True, text=True)
  38. if result.returncode == 0:
  39. print(f"✓ ModelScope CLI下载成功: {output_dir}")
  40. return True
  41. else:
  42. print(f"✗ ModelScope CLI下载失败: {result.stderr}")
  43. return False
  44. except Exception as e:
  45. print(f"✗ ModelScope CLI下载异常: {e}")
  46. return False
  47. def download_with_python_api(model_name: str, output_dir: str) -> bool:
  48. """使用Python API下载"""
  49. try:
  50. print(f"尝试使用Python API下载: {model_name}")
  51. from modelscope import snapshot_download
  52. # 模型名称映射
  53. modelscope_names = {
  54. "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct",
  55. "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
  56. "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct"
  57. }
  58. ms_model_name = modelscope_names.get(model_name, model_name)
  59. downloaded_path = snapshot_download(
  60. model_id=ms_model_name,
  61. local_dir=output_dir
  62. )
  63. print(f"✓ Python API下载成功: {downloaded_path}")
  64. return True
  65. except Exception as e:
  66. print(f"✗ Python API下载失败: {e}")
  67. return False
  68. def verify_model_files(model_dir: str) -> bool:
  69. """验证模型文件完整性"""
  70. model_path = Path(model_dir)
  71. if not model_path.exists():
  72. print(f"✗ 模型目录不存在: {model_dir}")
  73. return False
  74. # 检查必要文件
  75. required_files = [
  76. 'config.json',
  77. 'tokenizer.json',
  78. 'tokenizer_config.json'
  79. ]
  80. missing_files = []
  81. for file_name in required_files:
  82. if not (model_path / file_name).exists():
  83. missing_files.append(file_name)
  84. if missing_files:
  85. print(f"✗ 缺少必要文件: {missing_files}")
  86. return False
  87. # 检查是否只有临时文件夹
  88. contents = list(model_path.iterdir())
  89. if len(contents) == 1 and contents[0].name.startswith('._____temp'):
  90. print(f"✗ 模型目录只包含临时文件夹")
  91. return False
  92. print(f"✓ 模型文件验证通过")
  93. return True
  94. def get_user_model_choice():
  95. """获取用户的模型选择"""
  96. from src.config.model_configs import ModelRegistry
  97. print("\n=== 模型选择 ===")
  98. print("可用的模型:")
  99. # 显示可用模型列表
  100. registry = ModelRegistry()
  101. models = registry.list_models() # 返回 {key: name} 格式
  102. model_list = list(models.keys())
  103. for i, (model_key, model_name) in enumerate(models.items(), 1):
  104. # 获取完整配置以显示架构信息
  105. config = registry.get_model_config(model_key)
  106. print(f" {i}. {model_key}: {model_name} ({config.architecture})")
  107. print(f"\n默认模型: Qwen2.5-1.5B-Instruct (qwen2.5-1.5b-instruct)")
  108. print("请选择模型 (输入数字编号,或直接回车使用默认模型):")
  109. try:
  110. user_input = input("> ").strip()
  111. if not user_input: # 用户直接回车,使用默认模型
  112. return "qwen2.5-1.5b-instruct"
  113. # 尝试解析为数字
  114. try:
  115. choice_num = int(user_input)
  116. if 1 <= choice_num <= len(model_list):
  117. selected_key = model_list[choice_num - 1]
  118. print(f"已选择: {selected_key}")
  119. return selected_key
  120. else:
  121. print(f"无效的选择编号,使用默认模型")
  122. return "qwen2.5-1.5b-instruct"
  123. except ValueError:
  124. # 尝试直接匹配模型键
  125. if user_input in models:
  126. print(f"已选择: {user_input}")
  127. return user_input
  128. else:
  129. print(f"未找到模型 '{user_input}',使用默认模型")
  130. return "qwen2.5-1.5b-instruct"
  131. except KeyboardInterrupt:
  132. print("\n用户取消,使用默认模型")
  133. return "qwen2.5-1.5b-instruct"
  134. except Exception as e:
  135. print(f"输入错误: {e},使用默认模型")
  136. return "qwen2.5-1.5b-instruct"
  137. def main():
  138. """主函数"""
  139. # 获取用户选择的模型
  140. model_key = get_user_model_choice()
  141. # 获取模型配置
  142. from src.config.model_configs import ModelRegistry
  143. registry = ModelRegistry()
  144. model_config = registry.get_model_config(model_key)
  145. model_name = model_config.model_id # 用于下载的实际模型ID
  146. output_dir = f"/qwen/models/{model_config.name.replace('-', '_')}"
  147. print(f"开始下载模型: {model_name}")
  148. print(f"输出目录: {output_dir}")
  149. # 清理不完整的下载
  150. if Path(output_dir).exists():
  151. if not verify_model_files(output_dir):
  152. print(f"清理不完整的下载目录: {output_dir}")
  153. import shutil
  154. shutil.rmtree(output_dir)
  155. # 设置SSL绕过
  156. setup_ssl_bypass()
  157. # 创建输出目录
  158. Path(output_dir).mkdir(parents=True, exist_ok=True)
  159. # 尝试多种下载方式
  160. success = False
  161. # 方式1: ModelScope CLI
  162. if not success:
  163. success = download_with_modelscope_cli(model_name, output_dir)
  164. # 方式2: Python API
  165. if not success:
  166. success = download_with_python_api(model_name, output_dir)
  167. # 验证下载结果
  168. if success:
  169. if verify_model_files(output_dir):
  170. print(f"\n🎉 模型下载成功!")
  171. print(f"模型路径: {output_dir}")
  172. return 0
  173. else:
  174. print(f"\n❌ 模型下载不完整")
  175. return 1
  176. else:
  177. print(f"\n❌ 所有下载方式都失败了")
  178. print(f"\n建议解决方案:")
  179. print(f"1. 检查网络连接")
  180. print(f"2. 配置代理服务器")
  181. print(f"3. 手动下载模型文件")
  182. return 1
  183. if __name__ == "__main__":
  184. sys.exit(main())