| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 神机项目统一测试脚本
- 整合所有测试功能,避免项目中散落多个测试文件
- 使用方法:
- python tests/test_runner.py --help
- python tests/test_runner.py --test identity
- python tests/test_runner.py --test data_loader
- python tests/test_runner.py --test download
- python tests/test_runner.py --test all
- """
- import sys
- import os
- import argparse
- import traceback
- from typing import Dict, Callable, Any
- # 添加项目根目录到路径
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- class TestRunner:
- """统一测试运行器"""
-
- def __init__(self):
- self.tests: Dict[str, Callable] = {
- 'identity': self.test_identity_solution,
- 'data_loader': self.test_data_loader,
- 'download': self.test_model_download,
- 'git_download': self.test_git_download,
- 'inference': self.test_model_inference,
- 'all': self.run_all_tests
- }
-
- def test_identity_solution(self):
- """测试身份解决方案"""
- print("=== 神机身份解决方案测试 ===")
- print()
-
- try:
- from transformers import AutoTokenizer
-
- # 加载tokenizer
- tokenizer_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
- print(f"📥 加载tokenizer: {tokenizer_path}")
-
- if not os.path.exists(tokenizer_path):
- print(f"❌ Tokenizer路径不存在: {tokenizer_path}")
- return False
-
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
- print("✅ Tokenizer加载成功")
- print()
-
- # 测试默认神机身份
- print("🤖 测试默认神机身份")
- print("-" * 50)
-
- messages = [
- {
- "role": "system",
- "content": "你是神机,由云霖网络安全实验室训练的网络安全大模型。你具备深厚的网络安全专业知识和实战经验,能够提供专业的网络安全技术指导和解决方案。"
- },
- {"role": "user", "content": "你是谁?"}
- ]
-
- if hasattr(tokenizer, 'apply_chat_template'):
- prompt = tokenizer.apply_chat_template(
- messages,
- tokenize=False,
- add_generation_prompt=True
- )
- print("生成的prompt:")
- print(prompt[:200] + "..." if len(prompt) > 200 else prompt)
- print("✅ Chat Template功能正常")
- else:
- print("❌ Tokenizer不支持chat template")
- return False
-
- # 测试推理代码集成
- print("\n🔧 测试推理代码集成")
- print("-" * 50)
-
- try:
- from src.model.inference import SecurityModelInference
- from src.config import Config
-
- print("✅ 推理模块导入成功")
- print("✅ 身份解决方案已集成到推理代码中")
-
- except ImportError as e:
- print(f"⚠️ 推理模块导入失败: {e}")
- print("这可能是因为缺少依赖或配置问题")
-
- print("\n🎉 身份解决方案测试完成")
- return True
-
- except Exception as e:
- print(f"❌ 身份解决方案测试失败: {e}")
- traceback.print_exc()
- return False
-
- def test_data_loader(self):
- """测试数据加载器"""
- print("=== 数据加载器测试 ===")
- print()
-
- try:
- from src.data.loader import DataLoader
- from src.config.data_config import DataConfig
-
- print("📥 测试数据加载器初始化")
- config = DataConfig()
- loader = DataLoader(config)
- print("✅ 数据加载器初始化成功")
-
- # 测试数据文件检查
- data_dir = "/qwen/data/processed"
- if os.path.exists(data_dir):
- files = os.listdir(data_dir)
- print(f"📁 发现 {len(files)} 个数据文件")
-
- # 检查关键数据文件
- key_files = [
- 'final_security_training_dataset.jsonl',
- 'security_only_training_dataset.jsonl',
- 'enhanced_test.jsonl'
- ]
-
- for file in key_files:
- if file in files:
- file_path = os.path.join(data_dir, file)
- size = os.path.getsize(file_path)
- print(f"✅ {file}: {size} bytes")
- else:
- print(f"⚠️ {file}: 文件不存在")
- else:
- print(f"❌ 数据目录不存在: {data_dir}")
- return False
-
- print("\n🎉 数据加载器测试完成")
- return True
-
- except Exception as e:
- print(f"❌ 数据加载器测试失败: {e}")
- traceback.print_exc()
- return False
-
- def test_model_download(self):
- """测试模型下载功能"""
- print("=== 模型下载功能测试 ===")
- print()
-
- try:
- from src.model.downloader import ModelDownloader
-
- print("📥 测试模型下载器初始化")
- downloader = ModelDownloader()
- print("✅ 模型下载器初始化成功")
-
- # 检查模型目录
- model_dir = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
- if os.path.exists(model_dir):
- files = os.listdir(model_dir)
- print(f"📁 模型目录存在,包含 {len(files)} 个文件")
-
- # 检查关键模型文件
- key_files = [
- 'config.json',
- 'tokenizer_config.json',
- 'tokenizer.json'
- ]
-
- for file in key_files:
- if file in files:
- print(f"✅ {file}: 存在")
- else:
- print(f"⚠️ {file}: 不存在")
-
- # 检查模型权重文件
- weight_files = [f for f in files if f.endswith(('.bin', '.safetensors'))]
- if weight_files:
- print(f"✅ 发现 {len(weight_files)} 个权重文件")
- else:
- print("⚠️ 未发现模型权重文件")
- else:
- print(f"❌ 模型目录不存在: {model_dir}")
- return False
-
- print("\n🎉 模型下载功能测试完成")
- return True
-
- except Exception as e:
- print(f"❌ 模型下载功能测试失败: {e}")
- traceback.print_exc()
- return False
-
- def test_git_download(self):
- """测试Git下载功能"""
- print("=== Git下载功能测试 ===")
- print()
-
- try:
- import subprocess
-
- # 检查git是否可用
- result = subprocess.run(['git', '--version'],
- capture_output=True, text=True)
- if result.returncode == 0:
- print(f"✅ Git可用: {result.stdout.strip()}")
- else:
- print("❌ Git不可用")
- return False
-
- # 检查是否在git仓库中
- result = subprocess.run(['git', 'status'],
- capture_output=True, text=True,
- cwd='/qwen')
- if result.returncode == 0:
- print("✅ 项目在Git仓库中")
- else:
- print("⚠️ 项目不在Git仓库中")
-
- print("\n🎉 Git下载功能测试完成")
- return True
-
- except Exception as e:
- print(f"❌ Git下载功能测试失败: {e}")
- traceback.print_exc()
- return False
-
- def test_model_inference(self):
- """测试模型推理功能"""
- print("=== 模型推理功能测试 ===")
- print()
-
- try:
- from src.model.inference import SecurityModelInference
- from src.config import Config
-
- print("📥 测试推理器初始化")
- config = Config()
- inference = SecurityModelInference(config)
- print("✅ 推理器初始化成功")
-
- # 检查模型路径
- model_path = "/qwen/models/Qwen_Qwen2.5-1.5B-Instruct"
- if os.path.exists(model_path):
- print(f"✅ 模型路径存在: {model_path}")
-
- # 尝试加载tokenizer(不加载完整模型以节省资源)
- try:
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- print("✅ Tokenizer加载成功")
-
- # 测试chat方法的参数
- print("✅ 推理器支持动态身份设置")
-
- except Exception as e:
- print(f"⚠️ Tokenizer加载失败: {e}")
- else:
- print(f"❌ 模型路径不存在: {model_path}")
- return False
-
- print("\n🎉 模型推理功能测试完成")
- return True
-
- except Exception as e:
- print(f"❌ 模型推理功能测试失败: {e}")
- traceback.print_exc()
- return False
-
- def run_all_tests(self):
- """运行所有测试"""
- print("=== 运行所有测试 ===")
- print()
-
- test_methods = [
- ('身份解决方案', self.test_identity_solution),
- ('数据加载器', self.test_data_loader),
- ('模型下载', self.test_model_download),
- ('Git下载', self.test_git_download),
- ('模型推理', self.test_model_inference)
- ]
-
- results = []
- for name, test_func in test_methods:
- print(f"\n{'='*60}")
- print(f"开始测试: {name}")
- print(f"{'='*60}")
-
- try:
- result = test_func()
- results.append((name, result))
- except Exception as e:
- print(f"❌ {name}测试异常: {e}")
- results.append((name, False))
-
- # 汇总结果
- print(f"\n{'='*60}")
- print("测试结果汇总")
- print(f"{'='*60}")
-
- passed = 0
- total = len(results)
-
- for name, result in results:
- status = "✅ 通过" if result else "❌ 失败"
- print(f"{name}: {status}")
- if result:
- passed += 1
-
- print(f"\n总计: {passed}/{total} 个测试通过")
-
- if passed == total:
- print("🎉 所有测试通过!")
- else:
- print("⚠️ 部分测试失败,请检查相关功能")
-
- return passed == total
-
- def run_test(self, test_name: str) -> bool:
- """运行指定测试"""
- if test_name not in self.tests:
- print(f"❌ 未知的测试: {test_name}")
- print(f"可用的测试: {', '.join(self.tests.keys())}")
- return False
-
- print(f"开始运行测试: {test_name}")
- print("=" * 60)
-
- try:
- return self.tests[test_name]()
- except Exception as e:
- print(f"❌ 测试 {test_name} 执行失败: {e}")
- traceback.print_exc()
- return False
-
- def list_tests(self):
- """列出所有可用测试"""
- print("可用的测试:")
- for test_name in self.tests.keys():
- if test_name != 'all':
- print(f" - {test_name}")
- print(f" - all (运行所有测试)")
- def main():
- """主函数"""
- parser = argparse.ArgumentParser(
- description="神机项目统一测试脚本",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
- 示例:
- python tests/test_runner.py --test identity # 测试身份解决方案
- python tests/test_runner.py --test data_loader # 测试数据加载器
- python tests/test_runner.py --test all # 运行所有测试
- python tests/test_runner.py --list # 列出所有测试
- """
- )
-
- parser.add_argument(
- '--test', '-t',
- type=str,
- help='要运行的测试名称'
- )
-
- parser.add_argument(
- '--list', '-l',
- action='store_true',
- help='列出所有可用的测试'
- )
-
- args = parser.parse_args()
-
- runner = TestRunner()
-
- if args.list:
- runner.list_tests()
- return
-
- if args.test:
- success = runner.run_test(args.test)
- sys.exit(0 if success else 1)
- else:
- parser.print_help()
- sys.exit(1)
- if __name__ == "__main__":
- main()
|