train_sft.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import torch
  2. import hydra
  3. from omegaconf import DictConfig
  4. from datasets import load_dataset
  5. from transformers import AutoModelForCausalLM, AutoTokenizer
  6. from clean_llm.train.sft import run_tokenize_prompt_and_output
  7. from clean_llm.train.sft import run_get_response_log_probs
  8. from clean_llm.train.sft import run_sft_microbatch_train_step
  9. @hydra.main(config_path="configs", config_name="sft_gsm8k", version_base=None)
  10. def main(cfg: DictConfig):
  11. model = AutoModelForCausalLM.from_pretrained(
  12. cfg.model_path,
  13. torch_dtype=torch.bfloat16
  14. )
  15. tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
  16. dataset = load_dataset(cfg.dataset_path, 'main')
  17. train_prompt_strs = [example['question'] for example in dataset['train']]
  18. train_output_strs = [example['answer'] for example in dataset['train']]
  19. test_prompt_strs = [example['question'] for example in dataset['test']]
  20. test_output_strs = [example['answer'] for example in dataset['test']]
  21. print(f"Num of train examples = {len(train_prompt_strs)}")
  22. train_prompt_strs = train_prompt_strs[:3]
  23. train_output_strs = train_output_strs[:3]
  24. train_inputs = run_tokenize_prompt_and_output(train_prompt_strs, train_output_strs, tokenizer)
  25. test_inputs = run_tokenize_prompt_and_output(test_prompt_strs, test_output_strs, tokenizer)
  26. # import pdb; pdb.set_trace()
  27. res = run_get_response_log_probs(model,
  28. train_inputs['input_ids'],
  29. train_inputs['labels'],
  30. return_token_entropy=True)
  31. policy_log_probs = res['log_probs']
  32. token_entropy = res['token_entropy']
  33. loss, metadata = run_sft_microbatch_train_step(policy_log_probs,
  34. train_inputs['response_mask'],
  35. gradient_accumulation_steps=1,
  36. normalize_constant=1)
  37. print('loss', loss)
  38. if __name__ == '__main__':
  39. main()