downloader.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # -*- coding: utf-8 -*-
  2. """
  3. 模型下载器模块
  4. 负责从ModelScope或HuggingFace下载模型。
  5. """
  6. import os
  7. from pathlib import Path
  8. from typing import Optional
  9. from ..config import Config
  10. from ..utils.logger import get_logger
  11. class ModelDownloader:
  12. """模型下载器"""
  13. def __init__(self, config: Config = None, model_key: str = None):
  14. self.config = config or Config()
  15. # 设置当前模型
  16. if model_key:
  17. self.config.set_current_model(model_key)
  18. # 获取模型配置
  19. self.model_config = self.config.get_current_model_config()
  20. self.logger = get_logger(self.__class__.__name__)
  21. # 设置环境变量
  22. self.config.setup_environment()
  23. def download_model(self, model_name: Optional[str] = None,
  24. force_download: bool = False) -> Path:
  25. """下载模型"""
  26. if model_name is None:
  27. model_name = self.config.get_current_model_key()
  28. model_path = self.get_model_path(model_name)
  29. # 检查模型是否已存在
  30. if self.check_model_exists(model_name) and not force_download:
  31. self.logger.info(f"模型已存在: {model_path}")
  32. return model_path
  33. self.logger.info(f"开始下载模型: {model_name}")
  34. try:
  35. if self.config.USE_MODELSCOPE:
  36. return self._download_from_modelscope(model_name, model_path)
  37. else:
  38. return self._download_from_huggingface(model_name, model_path)
  39. except Exception as e:
  40. self.logger.error(f"模型下载失败: {e}")
  41. # 如果ModelScope失败,尝试HuggingFace
  42. if self.config.USE_MODELSCOPE:
  43. self.logger.info("尝试从HuggingFace下载...")
  44. return self._download_from_huggingface(model_name, model_path)
  45. raise
  46. def _download_from_modelscope(self, model_name: str, model_path: Path) -> Path:
  47. """从ModelScope下载模型"""
  48. # 首先尝试git clone方式下载
  49. try:
  50. return self._download_from_modelscope_git(model_name, model_path)
  51. except Exception as git_error:
  52. self.logger.warning(f"Git下载失败: {git_error},尝试SDK方式")
  53. # 如果git方式失败,回退到SDK方式
  54. try:
  55. return self._download_from_modelscope_sdk(model_name, model_path)
  56. except Exception as sdk_error:
  57. self.logger.error(f"SDK下载也失败: {sdk_error}")
  58. raise
  59. def _download_from_modelscope_git(self, model_name: str, model_path: Path) -> Path:
  60. """使用git clone从ModelScope下载模型"""
  61. import subprocess
  62. import shutil
  63. # 获取ModelScope下载ID
  64. download_id = self.config.get_model_id_for_download('modelscope', model_name)
  65. if not download_id:
  66. # 回退到原有映射逻辑
  67. modelscope_names = {
  68. "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
  69. "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct",
  70. "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct"
  71. }
  72. download_id = modelscope_names.get(model_name, model_name)
  73. ms_model_name = download_id
  74. git_url = f"https://www.modelscope.cn/{ms_model_name}.git"
  75. self.logger.info(f"使用git clone从ModelScope下载: {git_url}")
  76. # 检查git是否可用
  77. try:
  78. subprocess.run(["git", "--version"], check=True, capture_output=True)
  79. except (subprocess.CalledProcessError, FileNotFoundError):
  80. raise Exception("Git未安装或不可用")
  81. # 确保目标目录存在
  82. model_path.parent.mkdir(parents=True, exist_ok=True)
  83. # 如果目标目录已存在,先删除
  84. if model_path.exists():
  85. shutil.rmtree(model_path)
  86. # 执行git clone
  87. try:
  88. cmd = ["git", "clone", git_url, str(model_path)]
  89. result = subprocess.run(
  90. cmd,
  91. check=True,
  92. capture_output=True,
  93. text=True,
  94. timeout=1800 # 30分钟超时
  95. )
  96. self.logger.info(f"Git clone成功: {model_path}")
  97. return model_path
  98. except subprocess.TimeoutExpired:
  99. raise Exception("Git clone超时")
  100. except subprocess.CalledProcessError as e:
  101. raise Exception(f"Git clone失败: {e.stderr}")
  102. def _download_from_modelscope_sdk(self, model_name: str, model_path: Path) -> Path:
  103. """使用SDK从ModelScope下载模型"""
  104. try:
  105. # 设置SSL配置以解决证书验证问题
  106. import ssl
  107. import urllib3
  108. import requests
  109. urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
  110. # 彻底禁用SSL验证
  111. original_context = ssl._create_default_https_context
  112. ssl._create_default_https_context = ssl._create_unverified_context
  113. # Monkey patch urllib3 to disable SSL verification
  114. import urllib3.util.ssl_
  115. original_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
  116. def patched_ssl_wrap_socket(*args, **kwargs):
  117. kwargs['cert_reqs'] = ssl.CERT_NONE
  118. kwargs['check_hostname'] = False
  119. return original_ssl_wrap_socket(*args, **kwargs)
  120. urllib3.util.ssl_.ssl_wrap_socket = patched_ssl_wrap_socket
  121. try:
  122. from modelscope import snapshot_download
  123. # 获取ModelScope下载ID
  124. download_id = self.config.get_model_id_for_download('modelscope', model_name)
  125. if not download_id:
  126. # 回退到原有映射逻辑
  127. modelscope_names = {
  128. "Qwen/Qwen2-1.5B": "qwen/Qwen2-1.5B",
  129. "Qwen/Qwen2-1.5B-Instruct": "qwen/Qwen2-1.5B-Instruct",
  130. "Qwen/Qwen2.5-1.5B-Instruct": "qwen/Qwen2.5-1.5B-Instruct"
  131. }
  132. download_id = modelscope_names.get(model_name, model_name)
  133. ms_model_name = download_id
  134. self.logger.info(f"使用SDK从ModelScope下载: {ms_model_name}")
  135. downloaded_path = snapshot_download(
  136. model_id=ms_model_name,
  137. cache_dir=str(self.config.MODELS_DIR),
  138. local_dir=str(model_path)
  139. )
  140. self.logger.info(f"SDK下载完成: {downloaded_path}")
  141. return Path(downloaded_path)
  142. finally:
  143. # 恢复原始函数
  144. ssl._create_default_https_context = original_context
  145. urllib3.util.ssl_.ssl_wrap_socket = original_ssl_wrap_socket
  146. except ImportError:
  147. self.logger.error("ModelScope未安装,请安装: pip install modelscope")
  148. raise
  149. def _download_from_huggingface(self, model_name: str, model_path: Path) -> Path:
  150. """从HuggingFace下载模型"""
  151. try:
  152. from transformers import AutoTokenizer, AutoModelForCausalLM
  153. self.logger.info(f"从HuggingFace下载: {model_name}")
  154. # 下载分词器
  155. tokenizer = AutoTokenizer.from_pretrained(
  156. model_name,
  157. cache_dir=str(self.config.CACHE_DIR),
  158. trust_remote_code=True
  159. )
  160. tokenizer.save_pretrained(str(model_path))
  161. # 下载模型
  162. model = AutoModelForCausalLM.from_pretrained(
  163. model_name,
  164. cache_dir=str(self.config.CACHE_DIR),
  165. trust_remote_code=True
  166. )
  167. model.save_pretrained(str(model_path))
  168. self.logger.info(f"模型下载完成: {model_path}")
  169. return model_path
  170. except Exception as e:
  171. self.logger.error(f"HuggingFace下载失败: {e}")
  172. raise
  173. def check_model_exists(self, model_name: Optional[str] = None) -> bool:
  174. """检查模型是否存在"""
  175. if model_name is None:
  176. model_name = self.config.get_current_model_key()
  177. model_path = self.get_model_path(model_name)
  178. # 检查关键文件是否存在
  179. required_files = ["config.json", "tokenizer.json"]
  180. if not model_path.exists():
  181. return False
  182. for file_name in required_files:
  183. if not (model_path / file_name).exists():
  184. return False
  185. return True
  186. def get_model_path(self, model_name: Optional[str] = None) -> Path:
  187. """获取模型路径"""
  188. if model_name is None:
  189. model_name = self.model_config.name
  190. # 尝试使用配置中的路径方法
  191. if hasattr(self.config, 'get_model_path'):
  192. return Path(self.config.get_model_path(model_name))
  193. # 优先检查原始名称(去掉组织前缀)
  194. simple_name = model_name.split("/")[-1] if "/" in model_name else model_name
  195. simple_path = self.config.MODELS_DIR / simple_name
  196. # 如果简单名称的目录存在,使用它
  197. if simple_path.exists():
  198. return simple_path
  199. # 否则使用下划线替换的名称
  200. return self.config.MODELS_DIR / model_name.replace("/", "_")
  201. def get_model_info(self, model_name: Optional[str] = None) -> dict:
  202. """获取模型信息"""
  203. model_path = self.get_model_path(model_name)
  204. if not self.check_model_exists(model_name):
  205. return {"exists": False, "path": str(model_path)}
  206. # 计算模型大小
  207. total_size = 0
  208. file_count = 0
  209. for file_path in model_path.rglob("*"):
  210. if file_path.is_file():
  211. total_size += file_path.stat().st_size
  212. file_count += 1
  213. return {
  214. "exists": True,
  215. "path": str(model_path),
  216. "size_mb": total_size / (1024 * 1024),
  217. "file_count": file_count
  218. }