| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import os
- import torch
- import hydra
- from omegaconf import DictConfig
- from clean_llm.eval.eval_pretrain import evaluate
- from clean_llm.models.qwen2_5 import Qwen2_5
- from clean_llm.models.cs336_lm import BasicsTransformerLM
- from clean_llm.tokenizer.tokenizer import get_custom_tokenizer
- from clean_llm.utils import _to_device_and_compile
@hydra.main(config_path="configs", config_name="evaluate_cs336_lm", version_base=None)
def main(cfg: DictConfig) -> None:
    """Load a saved checkpoint and print the model's continuation of a prompt.

    Config groups read from ``cfg``:

    * ``cfg.tokenizer`` -- keyword arguments for ``get_custom_tokenizer``.
    * ``cfg.model_type`` -- ``"qwen2_5"`` or ``"cs336_lm"``; selects the model class.
    * ``cfg.model`` -- model configuration passed to the selected model class.
    * ``cfg.eval`` -- ``save_path`` and ``iteration`` locate the checkpoint file
      ``ckpt_iter{iteration}.pt``; ``prompt``, ``max_new_tokens``, ``temperature``
      and ``top_k`` are forwarded to ``evaluate``.

    Raises:
        ValueError: if ``cfg.model_type`` is not one of the supported values.
    """
    model_config, eval_config, tokenizer_config = cfg.model, cfg.eval, cfg.tokenizer

    # Build the tokenizer once; it is reused for model config and generation.
    tokenizer = get_custom_tokenizer(**tokenizer_config)
    print(tokenizer.vocab_size)

    if cfg.model_type == "qwen2_5":
        # Override the configured vocab size / EOS id with the tokenizer's values.
        model_config.vocab_size = tokenizer.vocab_size
        model_config.eos_token_id = tokenizer.eos_token_id
        model = Qwen2_5.from_config(model_config)
    elif cfg.model_type == "cs336_lm":
        model = BasicsTransformerLM(**model_config)
    else:
        # Fail fast: without this branch `model` would be unbound below.
        raise ValueError(
            f"Unsupported model_type {cfg.model_type!r}; "
            "expected 'qwen2_5' or 'cs336_lm'."
        )

    model, device = _to_device_and_compile(model)

    ckpt_path = os.path.join(
        eval_config.save_path, f"ckpt_iter{eval_config.iteration}.pt"
    )
    # map_location="cpu" lets a checkpoint written on a GPU host load anywhere;
    # load_state_dict then copies the tensors onto the model's own device.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Generate a continuation of the prompt and print it.
    result_text = evaluate(
        model=model,
        tokenizer=tokenizer,
        device=device,
        prompt=eval_config.prompt,
        max_new_tokens=eval_config.max_new_tokens,
        temperature=eval_config.temperature,
        top_k=eval_config.top_k,
        eos_token_id=tokenizer.eos_token_id,  # stop token; depends on the tokenizer setup
    )
    print("输入:", eval_config.prompt)
    print("生成结果:", result_text)
if __name__ == "__main__":
    # The @hydra.main decorator builds `cfg` from the config files and any
    # command-line overrides, so main() is invoked without arguments here.
    main()
|