sft.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669
  1. from __future__ import annotations
  2. import os
  3. import re
  4. import torch
  5. import torch.nn.functional as F
  6. import numpy as np
  7. import pandas as pd
  8. from typing import Any, Callable, Literal, Dict, List
  9. from torch import Tensor
  10. from torch.utils.data import Dataset
  11. from vllm.model_executor import set_random_seed as vllm_set_random_seed
  12. from vllm import LLM, SamplingParams
  13. from unittest.mock import patch
  14. from transformers import PreTrainedModel, AutoTokenizer, PreTrainedTokenizerBase
  15. def init_vllm(model_id: str, device: str, seed: int, gpu_memory_utilization: float = 0.85):
  16. """
  17. 初始化 vLLM LLM(大语言模型),可选择设备及显存利用率,在推理时与训练策略分离。
  18. 参考自 HuggingFace TRL 实现。
  19. Args:
  20. model_id (str): 模型标识
  21. device (str): 目标设备,如 'cuda:0'
  22. seed (int): 随机种子
  23. gpu_memory_utilization (float): 显存占用比例
  24. Returns:
  25. LLM: vLLM初始化好的对象
  26. """
  27. vllm_set_random_seed(seed)
  28. # Patch 1:让vllm假装“集群”只有1卡
  29. world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
  30. # Patch 2:跳过vllm内部显存剖析的某个断言
  31. profiling_patch = patch(
  32. "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
  33. return_value=None
  34. )
  35. with world_size_patch, profiling_patch:
  36. llm = LLM(
  37. model=model_id,
  38. device=device,
  39. dtype=torch.bfloat16,
  40. enable_prefix_caching=True,
  41. gpu_memory_utilization=gpu_memory_utilization,
  42. )
  43. return llm
  44. def load_policy_into_vllm_instance(policy: PreTrainedModel, llm: LLM):
  45. """
  46. 把训练好的PyTorch模型参数装载进vLLM实例。
  47. 参考自 HuggingFace TRL。
  48. Args:
  49. policy (PreTrainedModel): 已训练好的transformers模型
  50. llm (LLM): vLLM实例
  51. Returns:
  52. None
  53. """
  54. state_dict = policy.state_dict()
  55. # 下面这一行依赖vllm当前内部实现,如有变动请根据vllm源代码调整!
  56. llm_model = llm.llm_engine.model_executor.driver_worker.model_runner.model
  57. llm_model.load_weights(state_dict.items())
  58. def evaluate_gsm8k(
  59. model_id: str,
  60. policy: PreTrainedModel,
  61. tokenizer: AutoTokenizer,
  62. prompt_strs: List[str],
  63. output_strs: List[str],
  64. save_path,
  65. device="cuda:0",
  66. seed=0,
  67. gpu_memory_utilization=0.5
  68. ):
  69. llm = init_vllm(model_id, device, seed, gpu_memory_utilization)
  70. sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=1024, stop=tokenizer.eos_token)
  71. outputs = llm.generate(prompt_strs, sampling_params)
  72. # import pdb; pdb.set_trace()
  73. result_df = pd.DataFrame(columns=['Prompt', 'Generated_Text', 'Correct_Answer', 'Parsed_Answer', 'Parsed_Correct_Answer', 'Evaluation_Score', 'ParseFail'])
  74. correct, parse_fail_cnt = 0, 0
  75. for i, output in enumerate(outputs):
  76. prompt = output.prompt
  77. generated_text = output.outputs[0].text
  78. correct_answer = output_strs[i]
  79. parsed_answer = parse_gsm8k_qwen_response(generated_text)
  80. parsed_correct_answer = run_parse_gsm8k_response(correct_answer)
  81. parse_fail = parsed_answer == None
  82. evaluation_score = 1 if parsed_correct_answer == parsed_answer else 0
  83. if evaluation_score == 1:
  84. correct += 1
  85. if parse_fail:
  86. parse_fail_cnt += 1
  87. temp_df = pd.DataFrame({
  88. 'Prompt': [prompt],
  89. 'Generated_Text': [generated_text],
  90. 'Correct_Answer': [correct_answer],
  91. 'Parsed_Answer': [parsed_answer],
  92. 'Parsed_Correct_Answer': [parsed_correct_answer],
  93. 'Evaluation_Score': [evaluation_score],
  94. 'ParseFail': [parse_fail]
  95. })
  96. result_df = pd.concat([result_df, temp_df], ignore_index=True)
  97. print(f"Parse fail {parse_fail_cnt}/{len(outputs)}")
  98. print(f"Correct {correct}/{len(outputs)} Accuracy is {round(correct/len(outputs)*100, 2)}%")
  99. result_df.to_csv(save_path)
  100. return {
  101. 'parse_fail_rate': parse_fail_cnt / len(outputs),
  102. 'correct_rate': correct / len(outputs)
  103. }
  104. def parse_gsm8k_qwen_response(
  105. model_output: str,
  106. ) -> str | None:
  107. matches = re.findall(r'```output(.*?)```', model_output, re.DOTALL)
  108. if matches:
  109. res = matches[0].strip()
  110. return res
  111. return None
  112. def run_tokenize_prompt_and_output(
  113. prompt_strs: list[str],
  114. output_strs: list[str],
  115. tokenizer: PreTrainedTokenizerBase,
  116. ) -> dict[str, Tensor]:
  117. """Tokenize the prompt and output strings, and construct a mask that is 1
  118. for the response tokens and 0 for other tokens (prompt or padding).
  119. Args:
  120. prompt_strs: list[str], the prompt strings.
  121. output_strs: list[str], the output strings.
  122. tokenizer: PreTrainedTokenizer, the tokenizer to use.
  123. Returns:
  124. dict[str, torch.Tensor]:
  125. "input_ids": torch.Tensor of shape (batch_size, max(prompt_and_output_lens) - 1):
  126. the tokenized prompt and output strings, with the final token sliced off.
  127. "labels": torch.Tensor of shape (batch_size, max(prompt_and_output_lens) - 1):
  128. shifted input_ids (i.e., the input_ids without the first token).
  129. "response_mask": torch.Tensor of shape (batch_size, max(prompt_and_output_lens) - 1):
  130. a mask on the response tokens in `labels`.
  131. """
  132. input_ids_list = []
  133. response_mask_list = []
  134. for prompt, output in zip(prompt_strs, output_strs):
  135. prompt_enc = tokenizer(prompt, add_special_tokens=False)
  136. output_enc = tokenizer(output, add_special_tokens=False)
  137. full_input = prompt_enc['input_ids'] + output_enc['input_ids']
  138. response_mask = [0] * len(prompt_enc['input_ids']) + [1] * len(output_enc['input_ids'])
  139. input_ids_list.append(torch.tensor(full_input, dtype=torch.long))
  140. response_mask_list.append(torch.tensor(response_mask, dtype=torch.long))
  141. batch_size = len(input_ids_list)
  142. max_len = max(len(ids) for ids in input_ids_list)
  143. input_ids_batch = torch.full((batch_size, max_len), tokenizer.pad_token_id, dtype=torch.long)
  144. response_mask_batch = torch.zeros((batch_size, max_len), dtype=torch.long)
  145. for i, (ids, mask) in enumerate(zip(input_ids_list, response_mask_list)):
  146. seq_len = len(ids)
  147. input_ids_batch[i, :seq_len] = ids
  148. response_mask_batch[i, :seq_len] = mask
  149. return {
  150. "input_ids": input_ids_batch[:, :-1], # (batch, max_len-1)
  151. "labels": input_ids_batch[:, 1:], # (batch, max_len-1)
  152. "response_mask": response_mask_batch[:, 1:] # (batch, max_len-1)
  153. }
  154. def run_compute_group_normalized_rewards(
  155. reward_fn: Callable,
  156. rollout_responses: list[str],
  157. repeated_ground_truths: list[str],
  158. group_size: int,
  159. advantage_eps: float,
  160. normalize_by_std: bool,
  161. ) -> tuple[torch.Tensor, dict[str, float]]:
  162. """
  163. Compute rewards for each group of rollout responses,
  164. normalized by the group size.
  165. For more on GRPO, see:
  166. DeepSeekMath: https://arxiv.org/abs/2402.03300
  167. DeepSeek-R1: https://arxiv.org/abs/2501.12948
  168. Args:
  169. reward_fn: Callable[[str, str], dict[str, float]],
  170. scores the rollout responses against the ground truths,
  171. producing a dict with keys
  172. "reward", "format_reward", and "answer_reward".
  173. rollout_responses: list[str], rollouts from the policy.
  174. The length of this list is
  175. `rollout_batch_size = n_prompts_per_rollout_batch * group_size`.
  176. repeated_ground_truths: list[str], the ground truths for the examples.
  177. The length of this list is `rollout_batch_size`,
  178. because the ground truth for each example is repeated `group_size` times.
  179. group_size: int, number of rollouts per group.
  180. advantage_eps: float, epsilon to avoid division by zero
  181. during group normalization.
  182. normalize_by_std: bool, whether to normalize the rewards by
  183. std(rewards).
  184. Returns:
  185. tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
  186. torch.Tensor of shape (rollout_batch_size,):
  187. group-normalized rewards for each rollout response.
  188. torch.Tensor of shape (rollout_batch_size,):
  189. raw rewards for each rollout response.
  190. dict[str, float]: metadata for the rewards of the rollout batch.
  191. You may choose what you wish to log here
  192. (some statistics of the rewards, etc.).
  193. """
  194. # 1. Compute raw rewards for all responses
  195. raw_rewards = []
  196. format_rewards = []
  197. answer_rewards = []
  198. for response, gt in zip(rollout_responses, repeated_ground_truths):
  199. reward_dict = reward_fn(response, gt)
  200. raw_rewards.append(reward_dict["reward"])
  201. format_rewards.append(reward_dict["format_reward"])
  202. answer_rewards.append(reward_dict["answer_reward"])
  203. raw_rewards = torch.tensor(raw_rewards, dtype=torch.float32)
  204. # 2. Group normalization
  205. N = len(raw_rewards)
  206. assert N % group_size == 0, "Rollout batch size must be divisible by group_size"
  207. n_groups = N // group_size
  208. # 分组: [n_groups, group_size]
  209. group_rewards = raw_rewards.view(n_groups, group_size)
  210. group_means = group_rewards.mean(dim=1, keepdim=True)
  211. if normalize_by_std:
  212. group_stds = group_rewards.std(dim=1, keepdim=True)
  213. denom = group_stds + advantage_eps
  214. else:
  215. denom = 1.0
  216. # 归一化
  217. normalized_groups = (group_rewards - group_means) / denom # [n_groups, group_size]
  218. normalized_rewards = normalized_groups.view(N) # 还原回(N,)
  219. # 3. Optional: Collect some statistics
  220. metadata = {
  221. "reward_mean": float(raw_rewards.mean()),
  222. "reward_std": float(raw_rewards.std()),
  223. "reward_max": float(raw_rewards.max()),
  224. "reward_min": float(raw_rewards.min()),
  225. "format_reward_mean": float(np.mean(format_rewards)),
  226. "answer_reward_mean": float(np.mean(answer_rewards)),
  227. }
  228. return normalized_rewards, raw_rewards, metadata
  229. def run_compute_entropy(logits: torch.Tensor) -> torch.Tensor:
  230. """Get the entropy of the logits (i.e., entropy of the final dimension)."""
  231. # Numerically stable computation of entropy
  232. lse = torch.logsumexp(logits, dim=-1)
  233. probs = torch.softmax(logits, dim=-1)
  234. expected_logit = (probs * logits).sum(dim=-1)
  235. entropy = lse - expected_logit
  236. return entropy
  237. def run_get_response_log_probs(
  238. model: torch.nn.Module,
  239. input_ids: torch.Tensor,
  240. labels: torch.Tensor,
  241. return_token_entropy: bool,
  242. ) -> torch.Tensor:
  243. """Get the conditional log-probs of the response given the prompt,
  244. and optionally the entropy of the next token predictions.
  245. Args:
  246. model: PreTrainedModel, the model to score.
  247. input_ids: torch.Tensor of shape (batch_size, sequence_length):
  248. the tokenized prompt and output.
  249. labels: torch.Tensor of shape (batch_size, sequence_length):
  250. shifted input_ids.
  251. return_token_entropy: bool, whether to return the entropy of the
  252. next token predictions.
  253. Returns:
  254. dict[str, torch.Tensor]:
  255. "log_probs": torch.Tensor of shape (batch_size, sequence_length):
  256. the conditional log-probs of the response given the prompt.
  257. Note that we have not masked out the token indices corresponding
  258. to the prompt or padding; that is done in the train loop.
  259. "token_entropy": Optional[torch.Tensor] of shape (batch_size, sequence_length):
  260. the entropy of the next token predictions. As with the log-probs,
  261. we have not masked out the token indices corresponding to the prompt
  262. or padding; that is done in the train loop.
  263. """
  264. # Get logits from model
  265. outputs = model(input_ids)
  266. logits = outputs.logits # (batch_size, seq_len, vocab_size)
  267. # Compute log-probabilities for each token in the labels
  268. log_probs = F.log_softmax(logits, dim=-1) # (batch_size, seq_len, vocab_size)
  269. # Gather log-probabilities at the correct label indices
  270. # Unsqueeze `labels` to match log_probs' last dim for gather
  271. # Result shape: (batch_size, seq_len)
  272. log_probs_for_labels = torch.gather(
  273. log_probs, dim=2, index=labels.unsqueeze(-1)
  274. ).squeeze(-1)
  275. result = {
  276. "log_probs": log_probs_for_labels
  277. }
  278. if return_token_entropy:
  279. token_entropy = run_compute_entropy(logits) # (batch_size, seq_len)
  280. result["token_entropy"] = token_entropy
  281. return result
  282. def run_compute_naive_policy_gradient_loss(
  283. raw_rewards_or_advantages: torch.Tensor,
  284. policy_log_probs: torch.Tensor,
  285. ) -> torch.Tensor:
  286. """Compute policy gradient loss using either raw rewards or advantages.
  287. Args:
  288. raw_rewards_or_advantages: torch.Tensor of shape (batch_size, 1):
  289. the raw rewards or advantages for each rollout response.
  290. policy_log_probs: torch.Tensor of shape (batch_size, sequence_length):
  291. the log-probs of the policy.
  292. Returns:
  293. torch.Tensor of shape (batch_size, sequence_length):
  294. the policy gradient per-token loss.
  295. """
  296. seq_length = policy_log_probs.shape[1]
  297. rewards_or_advantages = raw_rewards_or_advantages.expand(-1, seq_length)
  298. pg_loss = -policy_log_probs * rewards_or_advantages
  299. return pg_loss
  300. def run_compute_grpo_clip_loss(
  301. advantages: torch.Tensor,
  302. policy_log_probs: torch.Tensor,
  303. old_log_probs: torch.Tensor,
  304. cliprange: float,
  305. ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
  306. """Compute the GRPO-Clip loss.
  307. Args:
  308. advantages: torch.Tensor of shape (batch_size, 1):
  309. the advantages for each rollout response.
  310. policy_log_probs: torch.Tensor of shape (batch_size, sequence_length):
  311. the log-probs of the policy.
  312. old_log_probs: torch.Tensor of shape (batch_size, sequence_length):
  313. the log-probs of the old policy.
  314. cliprange: float, the clip range for the ratio.
  315. Returns:
  316. tuple[torch.Tensor, dict[str, torch.Tensor]]:
  317. torch.Tensor of shape (batch_size, sequence_length):
  318. the GRPO-Clip per-token loss.
  319. dict[str, torch.Tensor]: metadata for the GRPO-Clip loss
  320. (used to compute clip fraction).
  321. """
  322. seq_length = policy_log_probs.shape[1]
  323. advantages = advantages.expand(-1, seq_length)
  324. ratio = torch.exp(policy_log_probs - old_log_probs) # (batch_size, seq_length)
  325. clipped_ratio = torch.clamp(ratio, 1 - cliprange, 1 + cliprange)
  326. lhs, rhs = ratio * advantages, clipped_ratio * advantages
  327. loss = -torch.min(lhs, rhs)
  328. metadata = {
  329. "clipped": (rhs < lhs).float(),
  330. }
  331. return loss, metadata
  332. def run_compute_policy_gradient_loss(
  333. policy_log_probs: torch.Tensor,
  334. loss_type: str,
  335. raw_rewards: torch.Tensor,
  336. advantages: torch.Tensor,
  337. old_log_probs: torch.Tensor,
  338. cliprange: float,
  339. ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
  340. """
  341. Wrapper that delegates to the appropriate policy gradient loss function above.
  342. """
  343. batch_size, seq_length = policy_log_probs.shape
  344. if loss_type == "no_baseline":
  345. assert raw_rewards is not None, "raw_rewards required for no_baseline"
  346. assert raw_rewards.shape == (batch_size, 1), "raw_rewards must have shape (batch_size, 1)"
  347. loss = run_compute_naive_policy_gradient_loss(raw_rewards, policy_log_probs)
  348. meta = {}
  349. elif loss_type == "reinforce_with_baseline":
  350. assert advantages is not None, "advantages required for reinforce_with_baseline"
  351. assert advantages.shape == (batch_size, 1), "advantages must have shape (batch_size, 1)"
  352. loss = run_compute_naive_policy_gradient_loss(advantages, policy_log_probs)
  353. meta = {}
  354. elif loss_type == "grpo_clip":
  355. assert advantages is not None, "advantages required for grpo_clip"
  356. assert old_log_probs is not None, "old_log_probs required for grpo_clip"
  357. assert cliprange is not None, "cliprange required for grpo_clip"
  358. assert old_log_probs.shape == (batch_size, seq_length), "old_log_probs must have shape (batch_size, seq_length)"
  359. loss, meta = run_compute_grpo_clip_loss(advantages, policy_log_probs, old_log_probs, cliprange)
  360. else:
  361. raise ValueError(f"Unknown loss_type: {loss_type}")
  362. return loss, meta
  363. def run_masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int | None = None) -> torch.Tensor:
  364. """Compute the mean of the tensor along a dimension,
  365. considering only the elements with mask value 1.
  366. Args:
  367. tensor: torch.Tensor, the tensor to compute the mean of.
  368. mask: torch.Tensor, the mask. We only take the mean over
  369. the elements with mask value 1.
  370. dim: int | None, the dimension to compute the mean along.
  371. If None, sum over all non-masked elements and average
  372. by their total count.
  373. Returns:
  374. torch.Tensor, the mean of the tensor along the specified
  375. dimension, considering only the elements with mask value 1.
  376. """
  377. n_tokens = mask.sum(dim=dim)
  378. masked_tensor = tensor * mask
  379. return masked_tensor.sum(dim=dim) / n_tokens
  380. def run_sft_microbatch_train_step(
  381. policy_log_probs: torch.Tensor,
  382. response_mask: torch.Tensor,
  383. gradient_accumulation_steps: int,
  384. normalize_constant: int | None = 1.0,
  385. ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
  386. """Compute the policy gradient loss and backprop its gradients for a microbatch.
  387. """
  388. batch_size, seq_length = policy_log_probs.shape
  389. # Cross-entropy loss (neg log likelihood), per token and masked
  390. ce_loss = -policy_log_probs # (batch_size, seq_length)
  391. # Sum over tokens and batch, only response tokens count
  392. loss_sum = run_masked_normalize(ce_loss, response_mask, normalize_constant=normalize_constant)
  393. loss = loss_sum / batch_size / gradient_accumulation_steps
  394. loss.backward()
  395. # For logging
  396. n_tokens = response_mask.sum()
  397. avg_token_ce = loss_sum / (n_tokens + 1e-8)
  398. metadata = {
  399. "loss_sum": loss_sum.detach(),
  400. "n_tokens": n_tokens.detach(),
  401. "avg_ce_per_token": avg_token_ce.detach(),
  402. "mean_log_prob": (policy_log_probs * response_mask).sum() / (n_tokens + 1e-8)
  403. }
  404. return loss.detach(), metadata
  405. def run_grpo_microbatch_train_step(
  406. policy_log_probs: torch.Tensor,
  407. response_mask: torch.Tensor,
  408. gradient_accumulation_steps: int,
  409. loss_type: Literal["no_baseline", "reinforce_with_baseline", "grpo_clip"],
  410. raw_rewards: torch.Tensor | None = None,
  411. advantages: torch.Tensor | None = None,
  412. old_log_probs: torch.Tensor | None = None,
  413. cliprange: float | None = None,
  414. ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
  415. """Compute the policy gradient loss and backprop its gradients for a microbatch.
  416. Args:
  417. policy_log_probs: torch.Tensor of shape (batch_size, sequence_length):
  418. the log-probs of the policy.
  419. response_mask: torch.Tensor of shape (batch_size, sequence_length):
  420. the mask for the response.
  421. gradient_accumulation_steps: int, the number of gradient accumulation steps.
  422. loss_type: Literal["no_baseline", "reinforce_with_baseline", "grpo_clip"],
  423. the type of loss function to use.
  424. raw_rewards: torch.Tensor | None, the raw rewards for each rollout response.
  425. Needed for loss_type="no_baseline".
  426. advantages: torch.Tensor | None, the advantages for each rollout response.
  427. Needed for loss_type in {"reinforce_with_baseline", "grpo_clip"}.
  428. old_log_probs: torch.Tensor | None, the log-probs of the old policy.
  429. Needed for loss_type="grpo_clip".
  430. cliprange: float | None, the clip range for the ratio.
  431. Needed for loss_type="grpo_clip".
  432. constant_normalize_factor: int | None, provided if we want to sum over
  433. the sequence dimension and normalize by this constant factor
  434. (as in Dr. GRPO).
  435. Returns:
  436. tuple[torch.Tensor, dict[str, torch.Tensor]]:
  437. the policy gradient loss and its metadata.
  438. """
  439. # 1. 调用run_compute_policy_gradient_loss 得到逐token损失 (batch, seq)
  440. loss_per_token, meta = run_compute_policy_gradient_loss(
  441. policy_log_probs=policy_log_probs,
  442. loss_type=loss_type,
  443. raw_rewards=raw_rewards,
  444. advantages=advantages,
  445. old_log_probs=old_log_probs,
  446. cliprange=cliprange,
  447. )
  448. # 2. 用response_mask对loss按token聚合: (batch, )
  449. loss_per_example = run_masked_mean(loss_per_token, response_mask, dim=1) # (batch, )
  450. # 3. 平均到batch: 标量
  451. loss = loss_per_example.mean()
  452. # 4. adjust for grad accumulation
  453. loss = loss / gradient_accumulation_steps
  454. # 5. backward
  455. loss.backward()
  456. # 6. meta 里可以增加点logging的信息
  457. meta = meta.copy()
  458. meta["microbatch_loss"] = loss.detach()
  459. meta["loss_per_example"] = loss_per_example.detach()
  460. return loss, meta
  461. def run_masked_normalize(
  462. tensor: torch.Tensor,
  463. mask: torch.Tensor,
  464. dim: int | None = None,
  465. normalize_constant: float = 1.0,
  466. ) -> torch.Tensor:
  467. """Sum over a dimension and normalize by a constant,
  468. considering only the elements with mask value 1.
  469. Args:
  470. tensor: torch.Tensor, the tensor to sum and normalize.
  471. mask: torch.Tensor, the mask. We only consider elements
  472. with mask value 1.
  473. dim: int | None, the dimension to sum along before
  474. normalization. If None, sum over all dimensions.
  475. normalize_constant: float, the constant to divide by
  476. for normalization.
  477. Returns:
  478. torch.Tensor, the normalized sum, where masked elements
  479. (mask=0) don't contribute to the sum.
  480. """
  481. masked_tensor = tensor * mask
  482. sum_vals = masked_tensor.sum(dim=dim)
  483. normalized = sum_vals / normalize_constant
  484. return normalized
  485. """
  486. The below adapters are used in the optional
  487. RLHF / safety part of the Alignment assignment.
  488. """
  489. def get_packed_sft_dataset(
  490. tokenizer: PreTrainedTokenizerBase,
  491. dataset_path: str | os.PathLike,
  492. seq_length: int,
  493. shuffle: bool,
  494. ) -> Dataset:
  495. """
  496. Given a tokenizer and a path to a dataset with instruction-tuning examples,
  497. construct a PyTorch Dataset for language modeling. The examples should be
  498. packed, i.e., all sequences in the dataset are of a constant length (`seq_length`).
  499. Args:
  500. tokenizer: transformers.PreTrainedTokenizerBase
  501. Transformers tokenizer to use in tokenizing and encoding text.
  502. dataset_path: str
  503. Path to file with instruction-tuning examples.
  504. seq_length: int
  505. Number of tokens to include in each example.
  506. shuffle: bool
  507. If true, shuffle the documents before packing them into examples.
  508. Returns:
  509. PyTorch Dataset for language modeling. Each example in this dataset is a dictionary of
  510. with keys "input_ids" and "labels" (both tensors of shape (seq_length, )).
  511. "input_ids" contains the token IDs for the language modeling inputs, and "labels" contains
  512. the token IDs for the language modeling labels.
  513. """
  514. raise NotImplementedError
  515. def run_iterate_batches(
  516. dataset: Dataset,
  517. batch_size: int,
  518. shuffle: bool,
  519. ):
  520. """
  521. Given a PyTorch Dataset, return an iterable over batches of size `batch_size`.
  522. Iterating through the returned iterable should constitute one epoch over the Dataset.
  523. Args:
  524. dataset: Dataset
  525. Dataset to emit batches from.
  526. batch_size: int
  527. Number of examples to include per batch.
  528. shuffle: bool
  529. If true, shuffle examples before batching them.
  530. Returns:
  531. Iterable over batches, where each batch has size `batch_size`.
  532. """
  533. raise NotImplementedError
  534. def run_parse_gsm8k_response(
  535. model_output: str,
  536. ) -> str | None:
  537. """
  538. Given a GSM8K model output, parse the model output into a predicted numeric answer by
  539. taking the last number that occurs in the output.
  540. model_output: str
  541. str with the model's output to a GSM8K example.
  542. Returns:
  543. str with the predicted numeric answer if the model output can be parsed into a prediction,
  544. else None.
  545. """
  546. # raise NotImplementedError
  547. import re
  548. numbers = re.findall(r'-?\d+\.?\d*', model_output)
  549. # 如果没有找到任何数字,返回 None
  550. if not numbers:
  551. return None
  552. # 返回最后一个数字
  553. return numbers[-1]