# eval_pretrain.py

  1. import torch
  2. def evaluate(
  3. model,
  4. tokenizer,
  5. device,
  6. prompt,
  7. max_new_tokens,
  8. temperature,
  9. top_k,
  10. eos_token_id,
  11. ):
  12. model.eval()
  13. input_ids = tokenizer.encode(prompt)
  14. input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
  15. with torch.no_grad():
  16. output_tokens = model.generate(
  17. input_tensor,
  18. max_new_tokens=max_new_tokens,
  19. temperature=temperature,
  20. top_k=top_k,
  21. eos_token_id=eos_token_id,
  22. )
  23. output_ids = output_tokens[0].cpu().numpy().tolist()
  24. full_ids = input_ids + output_ids
  25. text = tokenizer.decode(full_ids)
  26. return text