model_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型管理CLI工具
  5. 提供命令行界面来管理多个模型,包括列出、下载、切换和测试模型。
  6. 使用方法:
  7. python scripts/model_manager.py list # 列出所有可用模型
  8. python scripts/model_manager.py download <model_key> # 下载指定模型
  9. python scripts/model_manager.py switch <model_key> # 切换当前模型
  10. python scripts/model_manager.py current # 显示当前模型
  11. python scripts/model_manager.py test <model_key> # 测试模型推理
  12. python scripts/model_manager.py chat <model_key> # 与模型对话
  13. python scripts/model_manager.py info <model_key> # 显示模型详细信息
  14. """
  15. import argparse
  16. import sys
  17. from pathlib import Path
  18. from typing import Optional
  19. # 添加项目根目录到Python路径
  20. project_root = Path(__file__).parent.parent
  21. sys.path.insert(0, str(project_root))
  22. from src.config.settings import Config
  23. from src.config.model_configs import ModelRegistry
  24. from src.model.inference import SecurityModelInference
  25. from src.model.downloader import ModelDownloader
  26. from src.utils.logger import get_logger
  27. logger = get_logger(__name__)
  28. class ModelManager:
  29. """模型管理器"""
  30. def __init__(self):
  31. self.config = Config()
  32. self.registry = ModelRegistry()
  33. def list_models(self):
  34. """列出所有可用模型"""
  35. models = self.registry.list_models()
  36. print("\n=== 可用模型列表 ===")
  37. print(f"{'模型键':<20} {'名称':<30} {'架构':<15} {'状态':<10}")
  38. print("-" * 80)
  39. current_model = self.config.get_current_model_key()
  40. for model_key, model_config in models.items():
  41. # 检查模型是否已下载
  42. self.config.set_current_model(model_key)
  43. model_path = Path(self.config.get_model_path())
  44. status = "已下载" if model_path.exists() and any(model_path.iterdir()) else "未下载"
  45. # 标记当前模型
  46. marker = "*" if model_key == current_model else " "
  47. print(f"{marker}{model_key:<19} {model_config.name:<30} {model_config.architecture:<15} {status:<10}")
  48. print(f"\n当前模型: {current_model}")
  49. print("注: * 表示当前选中的模型")
  50. def download_model(self, model_key: str):
  51. """下载指定模型"""
  52. if model_key not in self.registry.list_models():
  53. print(f"错误: 未知的模型键 '{model_key}'")
  54. self.list_available_keys()
  55. return False
  56. try:
  57. print(f"\n开始下载模型: {model_key}")
  58. downloader = ModelDownloader(self.config, model_key=model_key)
  59. model_path = downloader.ensure_model_downloaded()
  60. print(f"模型下载成功: {model_path}")
  61. return True
  62. except Exception as e:
  63. print(f"模型下载失败: {e}")
  64. logger.error(f"下载模型 {model_key} 失败", exc_info=True)
  65. return False
  66. def switch_model(self, model_key: str):
  67. """切换当前模型"""
  68. if model_key not in self.registry.list_models():
  69. print(f"错误: 未知的模型键 '{model_key}'")
  70. self.list_available_keys()
  71. return False
  72. try:
  73. old_model = self.config.get_current_model_key()
  74. self.config.set_current_model(model_key)
  75. print(f"模型切换成功: {old_model} -> {model_key}")
  76. return True
  77. except Exception as e:
  78. print(f"模型切换失败: {e}")
  79. logger.error(f"切换到模型 {model_key} 失败", exc_info=True)
  80. return False
  81. def show_current(self):
  82. """显示当前模型信息"""
  83. current_key = self.config.get_current_model_key()
  84. current_config = self.config.get_current_model_config()
  85. print(f"\n=== 当前模型信息 ===")
  86. print(f"模型键: {current_key}")
  87. print(f"名称: {current_config.name}")
  88. print(f"架构: {current_config.architecture}")
  89. print(f"最大长度: {current_config.max_length}")
  90. print(f"本地路径: {self.config.get_model_path()}")
  91. # 检查模型状态
  92. model_path = Path(self.config.get_model_path())
  93. if model_path.exists() and any(model_path.iterdir()):
  94. print(f"状态: 已下载")
  95. else:
  96. print(f"状态: 未下载")
  97. def show_model_info(self, model_key: str):
  98. """显示指定模型的详细信息"""
  99. if model_key not in self.registry.list_models():
  100. print(f"错误: 未知的模型键 '{model_key}'")
  101. self.list_available_keys()
  102. return False
  103. model_config = self.registry.get_model_config(model_key)
  104. print(f"\n=== 模型信息: {model_key} ===")
  105. print(f"名称: {model_config.name}")
  106. print(f"架构: {model_config.architecture}")
  107. print(f"最大长度: {model_config.max_length}")
  108. print(f"支持Chat模板: {model_config.supports_chat_template}")
  109. print(f"支持量化: {model_config.supports_quantization}")
  110. print(f"支持LoRA: {model_config.supports_lora}")
  111. # 下载信息
  112. self.config.set_current_model(model_key)
  113. modelscope_id = self.config.get_model_id_for_download('modelscope')
  114. huggingface_id = self.config.get_model_id_for_download('huggingface')
  115. print(f"\n下载信息:")
  116. print(f" ModelScope: {modelscope_id or '不支持'}")
  117. print(f" HuggingFace: {huggingface_id or '不支持'}")
  118. print(f" 本地路径: {self.config.get_model_path()}")
  119. # 检查状态
  120. model_path = Path(self.config.get_model_path())
  121. if model_path.exists() and any(model_path.iterdir()):
  122. print(f" 状态: 已下载")
  123. else:
  124. print(f" 状态: 未下载")
  125. # 特殊配置
  126. if model_config.special_tokens:
  127. print(f"\n特殊Token: {model_config.special_tokens}")
  128. if model_config.lora_target_modules:
  129. print(f"LoRA目标模块: {model_config.lora_target_modules}")
  130. if model_config.generation_config:
  131. print(f"生成配置: {model_config.generation_config}")
  132. return True
  133. def test_model(self, model_key: str):
  134. """测试模型推理"""
  135. if model_key not in self.registry.list_models():
  136. print(f"错误: 未知的模型键 '{model_key}'")
  137. self.list_available_keys()
  138. return False
  139. try:
  140. print(f"\n测试模型: {model_key}")
  141. # 检查模型是否已下载
  142. self.config.set_current_model(model_key)
  143. model_path = Path(self.config.get_model_path())
  144. if not model_path.exists() or not any(model_path.iterdir()):
  145. print(f"模型未下载,请先下载: python {sys.argv[0]} download {model_key}")
  146. return False
  147. # 创建推理实例
  148. inference = SecurityModelInference(self.config, model_key=model_key)
  149. inference.load_model(model_key=model_key)
  150. # 测试对话
  151. test_message = "请简单介绍一下你自己"
  152. print(f"\n测试问题: {test_message}")
  153. print("生成回复中...")
  154. response, _ = inference.chat(test_message)
  155. print(f"\n模型回复: {response}")
  156. print(f"\n测试完成!模型 {model_key} 工作正常。")
  157. return True
  158. except Exception as e:
  159. print(f"模型测试失败: {e}")
  160. logger.error(f"测试模型 {model_key} 失败", exc_info=True)
  161. return False
  162. def chat_with_model(self, model_key: str):
  163. """与模型进行交互式对话"""
  164. if model_key not in self.registry.list_models():
  165. print(f"错误: 未知的模型键 '{model_key}'")
  166. self.list_available_keys()
  167. return False
  168. try:
  169. print(f"\n启动与模型 {model_key} 的对话")
  170. # 检查模型是否已下载
  171. self.config.set_current_model(model_key)
  172. model_path = Path(self.config.get_model_path())
  173. if not model_path.exists() or not any(model_path.iterdir()):
  174. print(f"模型未下载,请先下载: python {sys.argv[0]} download {model_key}")
  175. return False
  176. # 创建推理实例
  177. print("加载模型中...")
  178. inference = SecurityModelInference(self.config, model_key=model_key)
  179. inference.load_model(model_key=model_key)
  180. print("模型加载完成!")
  181. print("输入 'quit' 或 'exit' 退出对话")
  182. print("-" * 50)
  183. history = []
  184. while True:
  185. try:
  186. user_input = input("\n用户: ").strip()
  187. if user_input.lower() in ['quit', 'exit', '退出']:
  188. print("对话结束")
  189. break
  190. if not user_input:
  191. continue
  192. print("助手: ", end="", flush=True)
  193. response, history = inference.chat(user_input, history)
  194. print(response)
  195. except KeyboardInterrupt:
  196. print("\n\n对话被中断")
  197. break
  198. except Exception as e:
  199. print(f"\n对话出错: {e}")
  200. continue
  201. return True
  202. except Exception as e:
  203. print(f"启动对话失败: {e}")
  204. logger.error(f"与模型 {model_key} 对话失败", exc_info=True)
  205. return False
  206. def list_available_keys(self):
  207. """列出可用的模型键"""
  208. models = self.registry.list_models()
  209. print(f"\n可用的模型键: {', '.join(models.keys())}")
  210. def main():
  211. """主函数"""
  212. parser = argparse.ArgumentParser(
  213. description="神机模型管理工具",
  214. formatter_class=argparse.RawDescriptionHelpFormatter,
  215. epilog="""
  216. 示例:
  217. %(prog)s list # 列出所有模型
  218. %(prog)s download qwen2.5-1.5b # 下载Qwen模型
  219. %(prog)s switch chatglm3-6b # 切换到ChatGLM模型
  220. %(prog)s current # 显示当前模型
  221. %(prog)s test qwen2.5-1.5b # 测试模型
  222. %(prog)s chat qwen2.5-1.5b # 与模型对话
  223. %(prog)s info baichuan2-7b # 显示模型信息
  224. """
  225. )
  226. subparsers = parser.add_subparsers(dest='command', help='可用命令')
  227. # list命令
  228. subparsers.add_parser('list', help='列出所有可用模型')
  229. # download命令
  230. download_parser = subparsers.add_parser('download', help='下载指定模型')
  231. download_parser.add_argument('model_key', help='模型键')
  232. # switch命令
  233. switch_parser = subparsers.add_parser('switch', help='切换当前模型')
  234. switch_parser.add_argument('model_key', help='模型键')
  235. # current命令
  236. subparsers.add_parser('current', help='显示当前模型信息')
  237. # test命令
  238. test_parser = subparsers.add_parser('test', help='测试模型推理')
  239. test_parser.add_argument('model_key', help='模型键')
  240. # chat命令
  241. chat_parser = subparsers.add_parser('chat', help='与模型进行交互式对话')
  242. chat_parser.add_argument('model_key', help='模型键')
  243. # info命令
  244. info_parser = subparsers.add_parser('info', help='显示模型详细信息')
  245. info_parser.add_argument('model_key', help='模型键')
  246. args = parser.parse_args()
  247. if not args.command:
  248. parser.print_help()
  249. return
  250. manager = ModelManager()
  251. try:
  252. if args.command == 'list':
  253. manager.list_models()
  254. elif args.command == 'download':
  255. success = manager.download_model(args.model_key)
  256. sys.exit(0 if success else 1)
  257. elif args.command == 'switch':
  258. success = manager.switch_model(args.model_key)
  259. sys.exit(0 if success else 1)
  260. elif args.command == 'current':
  261. manager.show_current()
  262. elif args.command == 'test':
  263. success = manager.test_model(args.model_key)
  264. sys.exit(0 if success else 1)
  265. elif args.command == 'chat':
  266. success = manager.chat_with_model(args.model_key)
  267. sys.exit(0 if success else 1)
  268. elif args.command == 'info':
  269. success = manager.show_model_info(args.model_key)
  270. sys.exit(0 if success else 1)
  271. else:
  272. print(f"未知命令: {args.command}")
  273. parser.print_help()
  274. sys.exit(1)
  275. except KeyboardInterrupt:
  276. print("\n操作被中断")
  277. sys.exit(1)
  278. except Exception as e:
  279. print(f"执行失败: {e}")
  280. logger.error("命令执行失败", exc_info=True)
  281. sys.exit(1)
  282. if __name__ == "__main__":
  283. main()