| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import torch
- import hydra
- from omegaconf import DictConfig
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from clean_llm.train.sft import run_tokenize_prompt_and_output
- from clean_llm.train.sft import run_get_response_log_probs
- from clean_llm.train.sft import run_sft_microbatch_train_step
@hydra.main(config_path="configs", config_name="sft_gsm8k", version_base=None)
def main(cfg: DictConfig):
    """Run a single SFT microbatch step on a handful of training examples.

    Loads a causal LM (in bfloat16) and its tokenizer from ``cfg.model_path``,
    loads the ``'main'`` configuration of the dataset at ``cfg.dataset_path``,
    tokenizes the question/answer pairs of both splits, and then pushes the
    first three training examples through the log-prob and microbatch
    train-step helpers, printing the resulting loss.

    Args:
        cfg: Hydra config; this function reads ``model_path`` and
            ``dataset_path`` from it.
    """
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_path,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
    dataset = load_dataset(cfg.dataset_path, 'main')

    def _column(split, field):
        # Collect one field from every example of the given split.
        return [row[field] for row in dataset[split]]

    train_questions = _column('train', 'question')
    train_answers = _column('train', 'answer')
    test_questions = _column('test', 'question')
    test_answers = _column('test', 'answer')

    # Reported before truncation, so this is the size of the full train split.
    print(f"Num of train examples = {len(train_questions)}")

    # Keep only a few examples so the step below stays small.
    subset_size = 3
    train_questions = train_questions[:subset_size]
    train_answers = train_answers[:subset_size]

    train_batch = run_tokenize_prompt_and_output(train_questions, train_answers, tokenizer)
    # The tokenized test split is built here but not consumed below.
    test_batch = run_tokenize_prompt_and_output(test_questions, test_answers, tokenizer)

    scored = run_get_response_log_probs(
        model,
        train_batch['input_ids'],
        train_batch['labels'],
        return_token_entropy=True,
    )
    policy_log_probs = scored['log_probs']
    # Looked up alongside the log-probs; not used further in this function.
    token_entropy = scored['token_entropy']

    loss, metadata = run_sft_microbatch_train_step(
        policy_log_probs,
        train_batch['response_mask'],
        gradient_accumulation_steps=1,
        normalize_constant=1,
    )

    print('loss', loss)


if __name__ == '__main__':
    main()
|