multi_model_example.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 多模型使用示例
  5. 本示例展示如何在项目中使用不同的模型进行推理和训练。
  6. """
  7. import sys
  8. from pathlib import Path
  9. # 添加项目根目录到Python路径
  10. project_root = Path(__file__).parent.parent
  11. sys.path.insert(0, str(project_root))
  12. from src.config.settings import Config
  13. from src.config.model_configs import ModelRegistry
  14. from src.model.inference import SecurityModelInference
  15. from src.model.trainer import SecurityModelTrainer
  16. from src.model.downloader import ModelDownloader
  17. def list_available_models():
  18. """列出所有可用的模型"""
  19. registry = ModelRegistry()
  20. models = registry.list_models()
  21. print("=== 可用模型列表 ===")
  22. for model_key, model_config in models.items():
  23. print(f"模型键: {model_key}")
  24. print(f" 名称: {model_config.name}")
  25. print(f" 架构: {model_config.architecture}")
  26. print(f" 最大长度: {model_config.max_length}")
  27. print(f" 支持Chat模板: {model_config.supports_chat_template}")
  28. print(f" 支持量化: {model_config.supports_quantization}")
  29. print(f" 支持LoRA: {model_config.supports_lora}")
  30. print()
  31. def test_model_inference(model_key: str):
  32. """测试指定模型的推理功能"""
  33. print(f"\n=== 测试模型推理: {model_key} ===")
  34. try:
  35. # 创建配置和推理实例
  36. config = Config()
  37. inference = SecurityModelInference(config, model_key=model_key)
  38. # 检查模型是否已下载
  39. downloader = ModelDownloader(config, model_key=model_key)
  40. model_path = downloader.ensure_model_downloaded()
  41. print(f"模型路径: {model_path}")
  42. # 加载模型
  43. inference.load_model(model_key=model_key)
  44. print("模型加载成功")
  45. # 测试对话
  46. test_message = "请介绍一下网络安全的基本概念"
  47. response, history = inference.chat(test_message)
  48. print(f"用户: {test_message}")
  49. print(f"助手: {response}")
  50. return True
  51. except Exception as e:
  52. print(f"模型 {model_key} 测试失败: {e}")
  53. return False
  54. def compare_models():
  55. """比较不同模型的回复"""
  56. print("\n=== 模型回复比较 ===")
  57. test_question = "什么是SQL注入攻击?如何防护?"
  58. print(f"测试问题: {test_question}\n")
  59. registry = ModelRegistry()
  60. models = registry.list_models()
  61. for model_key in models.keys():
  62. print(f"--- {model_key} ---")
  63. try:
  64. config = Config()
  65. inference = SecurityModelInference(config, model_key=model_key)
  66. # 检查模型是否存在
  67. model_path = Path(config.get_model_path())
  68. if not model_path.exists():
  69. print(f"模型未下载,跳过: {model_path}")
  70. continue
  71. inference.load_model(model_key=model_key)
  72. response, _ = inference.chat(test_question)
  73. print(f"回复: {response[:200]}..." if len(response) > 200 else f"回复: {response}")
  74. except Exception as e:
  75. print(f"错误: {e}")
  76. print()
  77. def switch_model_demo():
  78. """演示模型切换功能"""
  79. print("\n=== 模型切换演示 ===")
  80. config = Config()
  81. # 显示当前模型
  82. current_model = config.get_current_model_key()
  83. print(f"当前模型: {current_model}")
  84. # 切换到不同的模型
  85. registry = ModelRegistry()
  86. models = list(registry.list_models().keys())
  87. for model_key in models[:3]: # 只测试前3个模型
  88. print(f"\n切换到模型: {model_key}")
  89. config.set_current_model(model_key)
  90. current_config = config.get_current_model_config()
  91. print(f"模型名称: {current_config.name}")
  92. print(f"模型架构: {current_config.architecture}")
  93. print(f"本地路径: {config.get_model_path()}")
  94. def download_model_demo(model_key: str):
  95. """演示模型下载功能"""
  96. print(f"\n=== 下载模型演示: {model_key} ===")
  97. try:
  98. config = Config()
  99. downloader = ModelDownloader(config, model_key=model_key)
  100. # 获取下载信息
  101. modelscope_id = config.get_model_id_for_download('modelscope')
  102. huggingface_id = config.get_model_id_for_download('huggingface')
  103. print(f"ModelScope ID: {modelscope_id}")
  104. print(f"HuggingFace ID: {huggingface_id}")
  105. print(f"本地路径: {config.get_model_path()}")
  106. # 检查是否已下载
  107. model_path = Path(config.get_model_path())
  108. if model_path.exists() and any(model_path.iterdir()):
  109. print("模型已存在")
  110. else:
  111. print("模型需要下载")
  112. # 注意:实际下载可能需要很长时间,这里只是演示
  113. # model_path = downloader.ensure_model_downloaded()
  114. # print(f"下载完成: {model_path}")
  115. except Exception as e:
  116. print(f"下载演示失败: {e}")
  117. def main():
  118. """主函数"""
  119. print("神机多模型支持演示")
  120. print("=" * 50)
  121. # 1. 列出可用模型
  122. list_available_models()
  123. # 2. 演示模型切换
  124. switch_model_demo()
  125. # 3. 演示下载功能
  126. download_model_demo('qwen2.5-1.5b')
  127. download_model_demo('chatglm3-6b')
  128. # 4. 测试模型推理(如果模型已下载)
  129. registry = ModelRegistry()
  130. for model_key in list(registry.list_models().keys())[:2]: # 只测试前2个
  131. config = Config()
  132. config.set_current_model(model_key)
  133. model_path = Path(config.get_model_path())
  134. if model_path.exists() and any(model_path.iterdir()):
  135. test_model_inference(model_key)
  136. else:
  137. print(f"\n模型 {model_key} 未下载,跳过推理测试")
  138. # 5. 比较模型回复(如果有多个模型已下载)
  139. # compare_models()
  140. print("\n演示完成!")
  141. if __name__ == "__main__":
  142. main()