"""Smoke-test script for the ``clean_llm`` Qwen2.5 model (``test_qwen2_5.py``).

Loads the model and tokenizer from ``model_path``, renders a two-message chat
(system + user) with the tokenizer's chat template, asks the model for up to
50 new tokens, and prints the prompt followed by the decoded response.
"""
import warnings

import torch

# Installed before the model/tokenizer packages are imported, matching the
# original statement order, so the filter is already active during those
# imports.
warnings.filterwarnings("ignore")

from clean_llm.models.qwen2_5 import Qwen2_5  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402


def _select_device() -> str:
    """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that order of preference."""
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` does not exist on older PyTorch builds; look it
    # up defensively so the script falls back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


device = _select_device()

model_path = "huggingface_models/Qwen/Qwen2.5-0.5B-Instruct"
model = Qwen2_5.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"[INFO] Load {model_path.split('/')[-1]} model on device {device}")

prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt},
]

# tokenize=False: get the rendered conversation back as text so it can be
# tokenized (and moved to the device) in a separate step below.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

model_inputs = tokenizer([text], return_tensors="pt").to(device)
input_ids = model_inputs["input_ids"]

# Pure inference: no backward pass follows, so skip autograd bookkeeping.
with torch.no_grad():
    generated_idx = model.generate(
        input_ids,
        max_new_tokens=50,
        eos_token_id=tokenizer.eos_token_id,
    )

# Skip the first len(input_ids[0]) positions of the first output sequence.
# NOTE(review): assumes Qwen2_5.generate returns prompt tokens followed by the
# continuation -- confirm against the clean_llm implementation.
response_ids = generated_idx[0][len(input_ids[0]):]
response = tokenizer.decode(response_ids, skip_special_tokens=True)

print("Prompt:")
print(prompt)
print("Response:")
print(response)