eval_pretrain.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import os
  2. import torch
  3. import hydra
  4. from omegaconf import DictConfig
  5. from clean_llm.eval.eval_pretrain import evaluate
  6. from clean_llm.models.qwen2_5 import Qwen2_5
  7. from clean_llm.models.cs336_lm import BasicsTransformerLM
  8. from clean_llm.tokenizer.tokenizer import get_custom_tokenizer
  9. from clean_llm.utils import _to_device_and_compile
  10. @hydra.main(config_path="configs", config_name="evaluate_cs336_lm", version_base=None)
  11. def main(cfg: DictConfig):
  12. model_config, eval_config, tokenizer_config = cfg.model, cfg.eval, cfg.tokenizer
  13. tokenizer = get_custom_tokenizer(**tokenizer_config)
  14. print(tokenizer.vocab_size)
  15. if cfg.model_type == "qwen2_5":
  16. model_config.vocab_size = tokenizer.vocab_size
  17. model_config.eos_token_id = tokenizer.eos_token_id
  18. model = Qwen2_5.from_config(model_config)
  19. elif cfg.model_type == "cs336_lm":
  20. model = BasicsTransformerLM(**model_config)
  21. model, device = _to_device_and_compile(model)
  22. tokenizer = get_custom_tokenizer(**tokenizer_config)
  23. with open(os.path.join(eval_config.save_path, f"ckpt_iter{eval_config.iteration}.pt"), 'rb') as f:
  24. checkpoint = torch.load(f, weights_only=False)
  25. model.load_state_dict(checkpoint['model_state_dict'])
  26. # 生成与输出
  27. result_text = evaluate(
  28. model=model,
  29. tokenizer=tokenizer,
  30. device=device,
  31. prompt=eval_config.prompt,
  32. max_new_tokens=eval_config.max_new_tokens,
  33. temperature=eval_config.temperature,
  34. top_k=eval_config.top_k,
  35. eos_token_id=tokenizer.eos_token_id # 视你的tokenizer设置而定
  36. )
  37. print("输入:", eval_config.prompt)
  38. print("生成结果:", result_text)
  39. if __name__ == "__main__":
  40. main()