#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 神机项目统一测试脚本 整合所有测试功能,避免项目中散落多个测试文件 使用方法: python tests/test_runner.py --help python tests/test_runner.py --test identity python tests/test_runner.py --test data_loader python tests/test_runner.py --test download python tests/test_runner.py --test all """ import sys import os import argparse import traceback from typing import Dict, Callable, Any # 添加项目根目录到路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class TestRunner: """统一测试运行器""" def __init__(self): self.tests: Dict[str, Callable] = { 'identity': self.test_identity_solution, 'data_loader': self.test_data_loader, 'download': self.test_model_download, 'git_download': self.test_git_download, 'inference': self.test_model_inference, 'all': self.run_all_tests } def test_identity_solution(self): """测试身份解决方案""" print("=== 神机身份解决方案测试 ===") print() try: from transformers import AutoTokenizer # 加载tokenizer tokenizer_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct" print(f"📥 加载tokenizer: {tokenizer_path}") if not os.path.exists(tokenizer_path): print(f"❌ Tokenizer路径不存在: {tokenizer_path}") return False tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) print("✅ Tokenizer加载成功") print() # 测试默认神机身份 print("🤖 测试默认神机身份") print("-" * 50) messages = [ { "role": "system", "content": "你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备深厚的网络安全专业知识和实战经验,能够提供专业的网络安全技术指导和解决方案。" }, {"role": "user", "content": "你是谁?"} ] if hasattr(tokenizer, 'apply_chat_template'): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) print("生成的prompt:") print(prompt[:200] + "..." if len(prompt) > 200 else prompt) print("✅ Chat Template功能正常") else: print("❌ Tokenizer不支持chat template") return False # 测试推理代码集成 print("\n🔧 测试推理代码集成") print("-" * 50) try: from src.model.inference import SecurityModelInference from src.config import Config print("✅ 推理模块导入成功") print("✅ 身份解决方案已集成到推理代码中") except ImportError as e: print(f"⚠️ 推理模块导入失败: {e}") print("这可能是因为缺少依赖或配置问题") print("\n🎉 身份解决方案测试完成") return True except Exception as e: print(f"❌ 身份解决方案测试失败: {e}") traceback.print_exc() return False def test_data_loader(self): """测试数据加载器""" print("=== 数据加载器测试 ===") print() try: from src.data.loader import DataLoader from src.config.data_config import DataConfig print("📥 测试数据加载器初始化") config = DataConfig() loader = DataLoader(config) print("✅ 数据加载器初始化成功") # 测试数据文件检查 data_dir = "/qwen/data/processed" if os.path.exists(data_dir): files = os.listdir(data_dir) print(f"📁 发现 {len(files)} 个数据文件") # 检查关键数据文件 key_files = [ 'final_security_training_dataset.jsonl', 'security_only_training_dataset.jsonl', 'enhanced_test.jsonl' ] for file in key_files: if file in files: file_path = os.path.join(data_dir, file) size = os.path.getsize(file_path) print(f"✅ {file}: {size} bytes") else: print(f"⚠️ {file}: 文件不存在") else: print(f"❌ 数据目录不存在: {data_dir}") return False print("\n🎉 数据加载器测试完成") return True except Exception as e: print(f"❌ 数据加载器测试失败: {e}") traceback.print_exc() return False def test_model_download(self): """测试模型下载功能""" print("=== 模型下载功能测试 ===") print() try: from src.model.downloader import ModelDownloader print("📥 测试模型下载器初始化") downloader = ModelDownloader() print("✅ 模型下载器初始化成功") # 检查模型目录 model_dir = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct" if os.path.exists(model_dir): files = os.listdir(model_dir) print(f"📁 模型目录存在,包含 {len(files)} 个文件") # 检查关键模型文件 key_files = [ 'config.json', 'tokenizer_config.json', 'tokenizer.json' ] for file in key_files: if file in files: print(f"✅ {file}: 存在") else: print(f"⚠️ {file}: 不存在") # 检查模型权重文件 weight_files = [f for f in files if f.endswith(('.bin', '.safetensors'))] if weight_files: print(f"✅ 发现 {len(weight_files)} 个权重文件") else: print("⚠️ 未发现模型权重文件") else: print(f"❌ 模型目录不存在: {model_dir}") return False print("\n🎉 模型下载功能测试完成") return True except Exception as e: print(f"❌ 模型下载功能测试失败: {e}") traceback.print_exc() return False def test_git_download(self): """测试Git下载功能""" print("=== Git下载功能测试 ===") print() try: import subprocess # 检查git是否可用 result = subprocess.run(['git', '--version'], capture_output=True, text=True) if result.returncode == 0: print(f"✅ Git可用: {result.stdout.strip()}") else: print("❌ Git不可用") return False # 检查是否在git仓库中 result = subprocess.run(['git', 'status'], capture_output=True, text=True, cwd='/qwen') if result.returncode == 0: print("✅ 项目在Git仓库中") else: print("⚠️ 项目不在Git仓库中") print("\n🎉 Git下载功能测试完成") return True except Exception as e: print(f"❌ Git下载功能测试失败: {e}") traceback.print_exc() return False def test_model_inference(self): """测试模型推理功能""" print("=== 模型推理功能测试 ===") print() try: from src.model.inference import SecurityModelInference from src.config import Config print("📥 测试推理器初始化") config = Config() inference = SecurityModelInference(config) print("✅ 推理器初始化成功") # 检查模型路径 model_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct" if os.path.exists(model_path): print(f"✅ 模型路径存在: {model_path}") # 尝试加载tokenizer(不加载完整模型以节省资源) try: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_path) print("✅ Tokenizer加载成功") # 测试chat方法的参数 print("✅ 推理器支持动态身份设置") except Exception as e: print(f"⚠️ Tokenizer加载失败: {e}") else: print(f"❌ 模型路径不存在: {model_path}") return False print("\n🎉 模型推理功能测试完成") return True except Exception as e: print(f"❌ 模型推理功能测试失败: {e}") traceback.print_exc() return False def run_all_tests(self): """运行所有测试""" print("=== 运行所有测试 ===") print() test_methods = [ ('身份解决方案', self.test_identity_solution), ('数据加载器', self.test_data_loader), ('模型下载', self.test_model_download), ('Git下载', self.test_git_download), ('模型推理', self.test_model_inference) ] results = [] for name, test_func in test_methods: print(f"\n{'='*60}") print(f"开始测试: {name}") print(f"{'='*60}") try: result = test_func() results.append((name, result)) except Exception as e: print(f"❌ {name}测试异常: {e}") results.append((name, False)) # 汇总结果 print(f"\n{'='*60}") print("测试结果汇总") print(f"{'='*60}") passed = 0 total = len(results) for name, result in results: status = "✅ 通过" if result else "❌ 失败" print(f"{name}: {status}") if result: passed += 1 print(f"\n总计: {passed}/{total} 个测试通过") if passed == total: print("🎉 所有测试通过!") else: print("⚠️ 部分测试失败,请检查相关功能") return passed == total def run_test(self, test_name: str) -> bool: """运行指定测试""" if test_name not in self.tests: print(f"❌ 未知的测试: {test_name}") print(f"可用的测试: {', '.join(self.tests.keys())}") return False print(f"开始运行测试: {test_name}") print("=" * 60) try: return self.tests[test_name]() except Exception as e: print(f"❌ 测试 {test_name} 执行失败: {e}") traceback.print_exc() return False def list_tests(self): """列出所有可用测试""" print("可用的测试:") for test_name in self.tests.keys(): if test_name != 'all': print(f" - {test_name}") print(f" - all (运行所有测试)") def main(): """主函数""" parser = argparse.ArgumentParser( description="神机项目统一测试脚本", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例: python tests/test_runner.py --test identity # 测试身份解决方案 python tests/test_runner.py --test data_loader # 测试数据加载器 python tests/test_runner.py --test all # 运行所有测试 python tests/test_runner.py --list # 列出所有测试 """ ) parser.add_argument( '--test', '-t', type=str, help='要运行的测试名称' ) parser.add_argument( '--list', '-l', action='store_true', help='列出所有可用的测试' ) args = parser.parse_args() runner = TestRunner() if args.list: runner.list_tests() return if args.test: success = runner.run_test(args.test) sys.exit(0 if success else 1) else: parser.print_help() sys.exit(1) if __name__ == "__main__": main()