| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 多模型使用示例
- 本示例展示如何在项目中使用不同的模型进行推理和训练。
- """
- import sys
- from pathlib import Path
- # 添加项目根目录到Python路径
- project_root = Path(__file__).parent.parent
- sys.path.insert(0, str(project_root))
- from src.config.settings import Config
- from src.config.model_configs import ModelRegistry
- from src.model.inference import SecurityModelInference
- from src.model.trainer import SecurityModelTrainer
- from src.model.downloader import ModelDownloader
- def list_available_models():
- """列出所有可用的模型"""
- registry = ModelRegistry()
- models = registry.list_models()
-
- print("=== 可用模型列表 ===")
- for model_key, model_config in models.items():
- print(f"模型键: {model_key}")
- print(f" 名称: {model_config.name}")
- print(f" 架构: {model_config.architecture}")
- print(f" 最大长度: {model_config.max_length}")
- print(f" 支持Chat模板: {model_config.supports_chat_template}")
- print(f" 支持量化: {model_config.supports_quantization}")
- print(f" 支持LoRA: {model_config.supports_lora}")
- print()
- def test_model_inference(model_key: str):
- """测试指定模型的推理功能"""
- print(f"\n=== 测试模型推理: {model_key} ===")
-
- try:
- # 创建配置和推理实例
- config = Config()
- inference = SecurityModelInference(config, model_key=model_key)
-
- # 检查模型是否已下载
- downloader = ModelDownloader(config, model_key=model_key)
- model_path = downloader.ensure_model_downloaded()
- print(f"模型路径: {model_path}")
-
- # 加载模型
- inference.load_model(model_key=model_key)
- print("模型加载成功")
-
- # 测试对话
- test_message = "请介绍一下网络安全的基本概念"
- response, history = inference.chat(test_message)
-
- print(f"用户: {test_message}")
- print(f"助手: {response}")
-
- return True
-
- except Exception as e:
- print(f"模型 {model_key} 测试失败: {e}")
- return False
- def compare_models():
- """比较不同模型的回复"""
- print("\n=== 模型回复比较 ===")
-
- test_question = "什么是SQL注入攻击?如何防护?"
- print(f"测试问题: {test_question}\n")
-
- registry = ModelRegistry()
- models = registry.list_models()
-
- for model_key in models.keys():
- print(f"--- {model_key} ---")
- try:
- config = Config()
- inference = SecurityModelInference(config, model_key=model_key)
-
- # 检查模型是否存在
- model_path = Path(config.get_model_path())
- if not model_path.exists():
- print(f"模型未下载,跳过: {model_path}")
- continue
-
- inference.load_model(model_key=model_key)
- response, _ = inference.chat(test_question)
-
- print(f"回复: {response[:200]}..." if len(response) > 200 else f"回复: {response}")
-
- except Exception as e:
- print(f"错误: {e}")
-
- print()
- def switch_model_demo():
- """演示模型切换功能"""
- print("\n=== 模型切换演示 ===")
-
- config = Config()
-
- # 显示当前模型
- current_model = config.get_current_model_key()
- print(f"当前模型: {current_model}")
-
- # 切换到不同的模型
- registry = ModelRegistry()
- models = list(registry.list_models().keys())
-
- for model_key in models[:3]: # 只测试前3个模型
- print(f"\n切换到模型: {model_key}")
- config.set_current_model(model_key)
-
- current_config = config.get_current_model_config()
- print(f"模型名称: {current_config.name}")
- print(f"模型架构: {current_config.architecture}")
- print(f"本地路径: {config.get_model_path()}")
- def download_model_demo(model_key: str):
- """演示模型下载功能"""
- print(f"\n=== 下载模型演示: {model_key} ===")
-
- try:
- config = Config()
- downloader = ModelDownloader(config, model_key=model_key)
-
- # 获取下载信息
- modelscope_id = config.get_model_id_for_download('modelscope')
- huggingface_id = config.get_model_id_for_download('huggingface')
-
- print(f"ModelScope ID: {modelscope_id}")
- print(f"HuggingFace ID: {huggingface_id}")
- print(f"本地路径: {config.get_model_path()}")
-
- # 检查是否已下载
- model_path = Path(config.get_model_path())
- if model_path.exists() and any(model_path.iterdir()):
- print("模型已存在")
- else:
- print("模型需要下载")
- # 注意:实际下载可能需要很长时间,这里只是演示
- # model_path = downloader.ensure_model_downloaded()
- # print(f"下载完成: {model_path}")
-
- except Exception as e:
- print(f"下载演示失败: {e}")
- def main():
- """主函数"""
- print("神机多模型支持演示")
- print("=" * 50)
-
- # 1. 列出可用模型
- list_available_models()
-
- # 2. 演示模型切换
- switch_model_demo()
-
- # 3. 演示下载功能
- download_model_demo('qwen2.5-1.5b')
- download_model_demo('chatglm3-6b')
-
- # 4. 测试模型推理(如果模型已下载)
- registry = ModelRegistry()
- for model_key in list(registry.list_models().keys())[:2]: # 只测试前2个
- config = Config()
- config.set_current_model(model_key)
- model_path = Path(config.get_model_path())
-
- if model_path.exists() and any(model_path.iterdir()):
- test_model_inference(model_key)
- else:
- print(f"\n模型 {model_key} 未下载,跳过推理测试")
-
- # 5. 比较模型回复(如果有多个模型已下载)
- # compare_models()
-
- print("\n演示完成!")
- if __name__ == "__main__":
- main()
|