| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669 |
- from __future__ import annotations
- import os
- import re
- import torch
- import torch.nn.functional as F
- import numpy as np
- import pandas as pd
- from typing import Any, Callable, Literal, Dict, List
- from torch import Tensor
- from torch.utils.data import Dataset
- from vllm.model_executor import set_random_seed as vllm_set_random_seed
- from vllm import LLM, SamplingParams
- from unittest.mock import patch
- from transformers import PreTrainedModel, AutoTokenizer, PreTrainedTokenizerBase
def init_vllm(model_id: str, device: str, seed: int, gpu_memory_utilization: float = 0.85):
    """Build a vLLM engine used for generation, separate from the training policy.

    Adapted from the HuggingFace TRL implementation.

    Args:
        model_id: Model name or path handed to ``vllm.LLM``.
        device: Target device string, e.g. ``"cuda:0"``.
        seed: Seed passed to vLLM's ``set_random_seed`` before construction.
        gpu_memory_utilization: Fraction of GPU memory vLLM is allowed to reserve.

    Returns:
        The constructed ``vllm.LLM`` (bfloat16 weights, prefix caching enabled).
    """
    vllm_set_random_seed(seed)

    # Make torch.distributed report a world size of 1 so vLLM behaves as a
    # single-GPU deployment.
    single_process = patch("torch.distributed.get_world_size", return_value=1)
    # No-op an internal vLLM profiling assertion. This is a version-specific
    # internal path and may need updating when vLLM is upgraded.
    skip_profiling_check = patch(
        "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
        return_value=None
    )

    # Both patches are active only while the engine is being constructed.
    with single_process, skip_profiling_check:
        return LLM(
            model=model_id,
            device=device,
            dtype=torch.bfloat16,
            enable_prefix_caching=True,
            gpu_memory_utilization=gpu_memory_utilization,
        )
def load_policy_into_vllm_instance(policy: PreTrainedModel, llm: LLM):
    """Copy the weights of a transformers model into a running vLLM engine.

    Adapted from HuggingFace TRL. The policy's ``state_dict()`` items are fed
    to the engine model's ``load_weights``.

    Args:
        policy: The (trained) transformers model whose weights are copied.
        llm: The vLLM instance that receives the weights.

    Returns:
        None.
    """
    weights = policy.state_dict().items()
    # NOTE: this attribute chain reaches into vLLM internals; if vLLM changes
    # its engine layout, update it to match the vLLM source.
    engine_model = llm.llm_engine.model_executor.driver_worker.model_runner.model
    engine_model.load_weights(weights)
def _gsm8k_answers_match(predicted: str | None, reference: str | None) -> bool:
    """Return True when a parsed prediction matches the parsed reference answer.

    A failed parse on either side (None) never counts as correct. Strings are
    compared exactly first, then numerically after removing commas, so that
    e.g. "18.0" matches "18".
    """
    if predicted is None or reference is None:
        return False
    if predicted == reference:
        return True
    try:
        return float(predicted.replace(",", "")) == float(reference.replace(",", ""))
    except ValueError:
        # Non-numeric prediction (e.g. free text inside the output block).
        return False


def evaluate_gsm8k(
    model_id: str,
    policy: PreTrainedModel,
    tokenizer: AutoTokenizer,
    prompt_strs: List[str],
    output_strs: List[str],
    save_path,
    device="cuda:0",
    seed=0,
    gpu_memory_utilization=0.5,
    max_tokens: int = 1024,
):
    """Greedy-decode GSM8K prompts with vLLM, score them, and write a CSV report.

    Args:
        model_id: Model name/path used to start the vLLM engine.
        policy: Model whose in-memory weights are loaded into the engine before
            generation. If None, the weights at ``model_id`` are evaluated as-is.
        tokenizer: Tokenizer whose ``eos_token`` is used as the stop string.
        prompt_strs: Prompts to generate from.
        output_strs: Reference solutions, aligned index-by-index with
            ``prompt_strs``. The last number in each is the gold answer.
        save_path: Destination for the per-example CSV (``DataFrame.to_csv``).
        device: Device for the vLLM engine.
        seed: Seed for vLLM.
        gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
        max_tokens: Maximum number of generated tokens per prompt.

    Returns:
        dict with ``parse_fail_rate`` and ``correct_rate`` (both 0.0 when
        there are no prompts).

    NOTE: a new vLLM engine is created on every call, and this function does
    not shut it down explicitly.
    """
    llm = init_vllm(model_id, device, seed, gpu_memory_utilization)
    if policy is not None:
        # Evaluate the passed-in (possibly fine-tuned) weights rather than the
        # checkpoint stored at model_id.
        load_policy_into_vllm_instance(policy, llm)
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=max_tokens, stop=tokenizer.eos_token)
    outputs = llm.generate(prompt_strs, sampling_params)

    columns = ['Prompt', 'Generated_Text', 'Correct_Answer', 'Parsed_Answer', 'Parsed_Correct_Answer', 'Evaluation_Score', 'ParseFail']
    rows = []
    correct, parse_fail_cnt = 0, 0
    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text
        correct_answer = output_strs[i]

        parsed_answer = parse_gsm8k_qwen_response(generated_text)
        parsed_correct_answer = run_parse_gsm8k_response(correct_answer)

        parse_fail = parsed_answer is None
        evaluation_score = 1 if _gsm8k_answers_match(parsed_answer, parsed_correct_answer) else 0
        correct += evaluation_score
        parse_fail_cnt += int(parse_fail)

        rows.append({
            'Prompt': output.prompt,
            'Generated_Text': generated_text,
            'Correct_Answer': correct_answer,
            'Parsed_Answer': parsed_answer,
            'Parsed_Correct_Answer': parsed_correct_answer,
            'Evaluation_Score': evaluation_score,
            'ParseFail': parse_fail,
        })

    # Build the frame once; concatenating row by row is quadratic.
    result_df = pd.DataFrame(rows, columns=columns)
    total = len(outputs)
    denom = total if total else 1  # avoid ZeroDivisionError on an empty batch
    print(f"Parse fail {parse_fail_cnt}/{total}")
    print(f"Correct {correct}/{total} Accuracy is {round(correct / denom * 100, 2)}%")
    result_df.to_csv(save_path)

    return {
        'parse_fail_rate': parse_fail_cnt / denom,
        'correct_rate': correct / denom
    }
def parse_gsm8k_qwen_response(
    model_output: str,
) -> str | None:
    """Return the stripped contents of the first ```output ...``` block, or None.

    Only the first such fenced block in ``model_output`` is considered; the
    match may span multiple lines.
    """
    first_block = re.search(r'```output(.*?)```', model_output, re.DOTALL)
    if first_block is None:
        return None
    return first_block.group(1).strip()
def run_tokenize_prompt_and_output(
    prompt_strs: list[str],
    output_strs: list[str],
    tokenizer: PreTrainedTokenizerBase,
) -> dict[str, Tensor]:
    """Tokenize prompt/output pairs and build a mask selecting response tokens.

    Each prompt and output is tokenized separately (no special tokens) and
    concatenated. The batch is right-padded to the longest sequence and then
    shifted for next-token prediction.

    Args:
        prompt_strs: The prompt strings.
        output_strs: The output strings, aligned with ``prompt_strs``.
        tokenizer: Tokenizer to use. Padding uses ``pad_token_id``; if that is
            None, ``eos_token_id`` (or 0) is used instead. Padded positions are
            always masked out, so their value does not matter.

    Returns:
        dict[str, torch.Tensor], each of shape (batch_size, max_len - 1):
            "input_ids": the tokenized prompt+output with the final token removed.
            "labels": the input ids shifted left by one (first token removed).
            "response_mask": 1 where the corresponding label is a response
                token, 0 for prompt or padding.

    Raises:
        ValueError: if ``prompt_strs`` and ``output_strs`` differ in length.
    """
    if len(prompt_strs) != len(output_strs):
        raise ValueError(
            f"prompt_strs ({len(prompt_strs)}) and output_strs ({len(output_strs)}) must have the same length"
        )

    sequences: list[list[int]] = []
    masks: list[list[int]] = []
    for prompt, output in zip(prompt_strs, output_strs):
        prompt_ids = tokenizer(prompt, add_special_tokens=False)['input_ids']
        output_ids = tokenizer(output, add_special_tokens=False)['input_ids']
        sequences.append(list(prompt_ids) + list(output_ids))
        masks.append([0] * len(prompt_ids) + [1] * len(output_ids))

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Some tokenizers define no pad token; torch.full would fail on None.
        eos_id = getattr(tokenizer, "eos_token_id", None)
        pad_id = eos_id if eos_id is not None else 0

    batch_size = len(sequences)
    max_len = max((len(ids) for ids in sequences), default=0)
    input_ids_batch = torch.full((batch_size, max_len), pad_id, dtype=torch.long)
    response_mask_batch = torch.zeros((batch_size, max_len), dtype=torch.long)
    for row, (ids, mask) in enumerate(zip(sequences, masks)):
        input_ids_batch[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        response_mask_batch[row, :len(mask)] = torch.tensor(mask, dtype=torch.long)

    return {
        "input_ids": input_ids_batch[:, :-1],        # (batch, max_len-1)
        "labels": input_ids_batch[:, 1:],            # (batch, max_len-1)
        "response_mask": response_mask_batch[:, 1:]  # aligned with labels
    }
def run_compute_group_normalized_rewards(
    reward_fn: Callable,
    rollout_responses: list[str],
    repeated_ground_truths: list[str],
    group_size: int,
    advantage_eps: float,
    normalize_by_std: bool,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
    """Score rollouts and compute group-normalized advantages (GRPO).

    Consecutive runs of ``group_size`` rollouts form one group. Each reward
    has its group's mean subtracted and, optionally, is divided by the group's
    (unbiased) std plus ``advantage_eps``.

    For more on GRPO, see:
        DeepSeekMath: https://arxiv.org/abs/2402.03300
        DeepSeek-R1: https://arxiv.org/abs/2501.12948

    Args:
        reward_fn: Callable[[str, str], dict[str, float]] that scores a response
            against its ground truth. It must return a dict with keys "reward",
            "format_reward", and "answer_reward".
        rollout_responses: Rollouts from the policy, of length
            ``rollout_batch_size = n_prompts_per_rollout_batch * group_size``.
        repeated_ground_truths: Ground truths, same length as
            ``rollout_responses``; each example's truth is repeated ``group_size``
            times.
        group_size: Number of rollouts per group.
        advantage_eps: Epsilon added to the std to avoid division by zero.
        normalize_by_std: Whether to divide by the per-group std.

    Returns:
        tuple of:
            torch.Tensor (rollout_batch_size,): group-normalized rewards.
            torch.Tensor (rollout_batch_size,): raw rewards.
            dict[str, float]: summary statistics of the batch's rewards.
    """
    n_rollouts = len(rollout_responses)
    # Validate before calling reward_fn, which may be expensive.
    assert len(repeated_ground_truths) == n_rollouts, \
        "rollout_responses and repeated_ground_truths must have the same length"
    assert n_rollouts % group_size == 0, "Rollout batch size must be divisible by group_size"

    # 1. Raw rewards for every response.
    reward_dicts = [reward_fn(response, gt) for response, gt in zip(rollout_responses, repeated_ground_truths)]
    raw_rewards = torch.tensor([d["reward"] for d in reward_dicts], dtype=torch.float32)
    format_rewards = [d["format_reward"] for d in reward_dicts]
    answer_rewards = [d["answer_reward"] for d in reward_dicts]

    # 2. Group normalization on a (n_groups, group_size) view.
    n_groups = n_rollouts // group_size
    group_rewards = raw_rewards.view(n_groups, group_size)
    group_means = group_rewards.mean(dim=1, keepdim=True)
    if normalize_by_std:
        denom = group_rewards.std(dim=1, keepdim=True) + advantage_eps
    else:
        denom = 1.0
    normalized_rewards = ((group_rewards - group_means) / denom).view(n_rollouts)

    # 3. Logging statistics.
    metadata = {
        "reward_mean": float(raw_rewards.mean()),
        "reward_std": float(raw_rewards.std()),
        "reward_max": float(raw_rewards.max()),
        "reward_min": float(raw_rewards.min()),
        "format_reward_mean": float(np.mean(format_rewards)),
        "answer_reward_mean": float(np.mean(answer_rewards)),
    }
    return normalized_rewards, raw_rewards, metadata
def run_compute_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax(logits) over the last dimension.

    Uses the identity H = logsumexp(z) - sum_i softmax(z)_i * z_i, which avoids
    taking the log of small probabilities. Note that a logit of -inf gives
    0 * -inf = nan in the expectation term.
    """
    p = torch.softmax(logits, dim=-1)
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition - (p * logits).sum(dim=-1)
def run_get_response_log_probs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    return_token_entropy: bool,
) -> dict[str, torch.Tensor]:
    """Per-token log-probs of ``labels`` under ``model``, plus optional entropy.

    Args:
        model: Callable returning an object with ``.logits`` of shape
            (batch_size, sequence_length, vocab_size).
        input_ids: (batch_size, sequence_length) tokenized prompt and output.
        labels: (batch_size, sequence_length) integer next-token targets
            (shifted input_ids).
        return_token_entropy: Whether to also return next-token entropy.

    Returns:
        dict[str, torch.Tensor]:
            "log_probs": (batch_size, sequence_length) log p(label_t | prefix).
            "token_entropy" (only if requested): (batch_size, sequence_length)
                entropy of each next-token distribution.
        Prompt and padding positions are NOT masked here; the train loop
        applies the response mask.
    """
    logits = model(input_ids).logits  # (batch_size, seq_len, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)
    # Select the log-prob of each label token: (batch_size, seq_len).
    token_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    result = {"log_probs": token_log_probs}
    if return_token_entropy:
        result["token_entropy"] = run_compute_entropy(logits)  # (batch_size, seq_len)
    return result
def run_compute_naive_policy_gradient_loss(
    raw_rewards_or_advantages: torch.Tensor,
    policy_log_probs: torch.Tensor,
) -> torch.Tensor:
    """Per-token REINFORCE loss: -(reward or advantage) * log_prob.

    Args:
        raw_rewards_or_advantages: (batch_size, 1) scalar weight per rollout;
            it is broadcast across the sequence dimension.
        policy_log_probs: (batch_size, sequence_length) policy log-probs.

    Returns:
        (batch_size, sequence_length) per-token policy-gradient loss.
    """
    per_token_weight = raw_rewards_or_advantages.expand(-1, policy_log_probs.shape[1])
    return -policy_log_probs * per_token_weight
def run_compute_grpo_clip_loss(
    advantages: torch.Tensor,
    policy_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    cliprange: float,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Per-token GRPO-Clip (PPO-style clipped surrogate) loss.

    loss = -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), where
    ratio = exp(policy_log_probs - old_log_probs).

    Args:
        advantages: (batch_size, 1) advantage per rollout, broadcast over tokens.
        policy_log_probs: (batch_size, sequence_length) current policy log-probs.
        old_log_probs: (batch_size, sequence_length) old policy log-probs.
        cliprange: eps, the clip range for the ratio.

    Returns:
        The (batch_size, sequence_length) per-token loss, and a metadata dict
        whose "clipped" entry is 1.0 where the clipped term was strictly
        smaller (i.e. selected by the min), else 0.0.
    """
    adv = advantages.expand(-1, policy_log_probs.shape[1])
    prob_ratio = torch.exp(policy_log_probs - old_log_probs)
    unclipped_obj = prob_ratio * adv
    clipped_obj = prob_ratio.clamp(1 - cliprange, 1 + cliprange) * adv
    per_token_loss = -torch.minimum(unclipped_obj, clipped_obj)
    return per_token_loss, {"clipped": (clipped_obj < unclipped_obj).float()}
def run_compute_policy_gradient_loss(
    policy_log_probs: torch.Tensor,
    loss_type: str,
    raw_rewards: torch.Tensor | None = None,
    advantages: torch.Tensor | None = None,
    old_log_probs: torch.Tensor | None = None,
    cliprange: float | None = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Dispatch to the per-token policy-gradient loss selected by ``loss_type``.

    Args:
        policy_log_probs: (batch_size, sequence_length) current policy log-probs.
        loss_type: One of:
            "no_baseline": REINFORCE weighted by ``raw_rewards`` (batch_size, 1).
            "reinforce_with_baseline": REINFORCE weighted by ``advantages``
                (batch_size, 1).
            "grpo_clip": clipped surrogate using ``advantages``,
                ``old_log_probs`` (batch_size, sequence_length), and ``cliprange``.
        raw_rewards: Required for "no_baseline".
        advantages: Required for "reinforce_with_baseline" and "grpo_clip".
        old_log_probs: Required for "grpo_clip".
        cliprange: Required for "grpo_clip".

    Returns:
        The (batch_size, sequence_length) per-token loss and a metadata dict
        (empty except for "grpo_clip").

    Raises:
        AssertionError: if a required argument is missing or mis-shaped.
        ValueError: if ``loss_type`` is unknown.
    """
    batch_size, seq_length = policy_log_probs.shape
    if loss_type == "no_baseline":
        assert raw_rewards is not None, "raw_rewards required for no_baseline"
        assert raw_rewards.shape == (batch_size, 1), "raw_rewards must have shape (batch_size, 1)"
        return run_compute_naive_policy_gradient_loss(raw_rewards, policy_log_probs), {}
    if loss_type == "reinforce_with_baseline":
        assert advantages is not None, "advantages required for reinforce_with_baseline"
        assert advantages.shape == (batch_size, 1), "advantages must have shape (batch_size, 1)"
        return run_compute_naive_policy_gradient_loss(advantages, policy_log_probs), {}
    if loss_type == "grpo_clip":
        assert advantages is not None, "advantages required for grpo_clip"
        assert old_log_probs is not None, "old_log_probs required for grpo_clip"
        assert cliprange is not None, "cliprange required for grpo_clip"
        assert old_log_probs.shape == (batch_size, seq_length), "old_log_probs must have shape (batch_size, seq_length)"
        return run_compute_grpo_clip_loss(advantages, policy_log_probs, old_log_probs, cliprange)
    raise ValueError(f"Unknown loss_type: {loss_type}")
def run_masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int | None = None) -> torch.Tensor:
    """Mean of ``tensor`` over the positions where ``mask`` is 1.

    Args:
        tensor: Values to average.
        mask: Same-shape (or broadcastable) mask; only elements with mask 1
            contribute to the sum and the count.
        dim: Dimension to reduce. If None, all masked-in elements are summed
            and divided by their total count.

    Returns:
        The masked mean. A slice with no masked-in elements yields 0/0 = nan.
    """
    numerator = (tensor * mask).sum(dim=dim)
    denominator = mask.sum(dim=dim)
    return numerator / denominator
def run_sft_microbatch_train_step(
    policy_log_probs: torch.Tensor,
    response_mask: torch.Tensor,
    gradient_accumulation_steps: int,
    normalize_constant: float | None = 1.0,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the masked SFT (negative log-likelihood) loss for a microbatch and backprop it.

    The loss is sum(-log_prob over response tokens) / normalize_constant,
    divided by the batch size and by ``gradient_accumulation_steps``.

    Args:
        policy_log_probs: (batch_size, sequence_length) per-token log-probs.
        response_mask: (batch_size, sequence_length) 1 for response tokens.
        gradient_accumulation_steps: Number of microbatches per optimizer step.
        normalize_constant: Constant the masked sum is divided by; None is
            treated as 1.0.

    Returns:
        The detached scalar microbatch loss, and a metadata dict of detached
        tensors for logging.
    """
    if normalize_constant is None:
        normalize_constant = 1.0
    batch_size, _ = policy_log_probs.shape
    # Cross-entropy (negative log-likelihood) per token.
    ce_loss = -policy_log_probs
    # Masked sum over batch and tokens, divided by normalize_constant.
    loss_sum = run_masked_normalize(ce_loss, response_mask, normalize_constant=normalize_constant)
    loss = loss_sum / batch_size / gradient_accumulation_steps
    loss.backward()

    # Logging. Undo normalize_constant so the per-token CE is a true average.
    n_tokens = response_mask.sum()
    avg_token_ce = loss_sum * normalize_constant / (n_tokens + 1e-8)
    mean_log_prob = (policy_log_probs * response_mask).sum() / (n_tokens + 1e-8)
    metadata = {
        "loss_sum": loss_sum.detach(),
        "n_tokens": n_tokens.detach(),
        "avg_ce_per_token": avg_token_ce.detach(),
        # Detached like the other entries so logging does not hold the graph.
        "mean_log_prob": mean_log_prob.detach(),
    }
    return loss.detach(), metadata
def run_grpo_microbatch_train_step(
    policy_log_probs: torch.Tensor,
    response_mask: torch.Tensor,
    gradient_accumulation_steps: int,
    loss_type: Literal["no_baseline", "reinforce_with_baseline", "grpo_clip"],
    raw_rewards: torch.Tensor | None = None,
    advantages: torch.Tensor | None = None,
    old_log_probs: torch.Tensor | None = None,
    cliprange: float | None = None,
    constant_normalize_factor: float | None = None,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the policy-gradient loss for a microbatch and backprop its gradients.

    Args:
        policy_log_probs: (batch_size, sequence_length) policy log-probs.
        response_mask: (batch_size, sequence_length) mask for response tokens.
        gradient_accumulation_steps: Number of microbatches per optimizer step.
        loss_type: "no_baseline", "reinforce_with_baseline", or "grpo_clip".
        raw_rewards: Per-rollout raw rewards; needed for "no_baseline".
        advantages: Per-rollout advantages; needed for
            "reinforce_with_baseline" and "grpo_clip".
        old_log_probs: Old-policy log-probs; needed for "grpo_clip".
        cliprange: Clip range for the ratio; needed for "grpo_clip".
        constant_normalize_factor: If given, each example's loss is the masked
            SUM over tokens divided by this constant (as in Dr. GRPO). If None
            (default), it is the masked MEAN over that example's response tokens.

    Returns:
        The detached scalar microbatch loss (already divided by
        ``gradient_accumulation_steps``), and metadata containing the
        loss-function metadata plus "microbatch_loss" and "loss_per_example".
    """
    # 1. Per-token loss: (batch, seq).
    loss_per_token, meta = run_compute_policy_gradient_loss(
        policy_log_probs=policy_log_probs,
        loss_type=loss_type,
        raw_rewards=raw_rewards,
        advantages=advantages,
        old_log_probs=old_log_probs,
        cliprange=cliprange,
    )
    # 2. Aggregate over response tokens: (batch,).
    if constant_normalize_factor is None:
        loss_per_example = run_masked_mean(loss_per_token, response_mask, dim=1)
    else:
        loss_per_example = run_masked_normalize(
            loss_per_token, response_mask, dim=1, normalize_constant=constant_normalize_factor
        )
    # 3. Average over the batch, then scale for gradient accumulation.
    loss = loss_per_example.mean() / gradient_accumulation_steps
    # 4. Backward pass.
    loss.backward()
    # 5. Logging metadata.
    meta = dict(meta)
    meta["microbatch_loss"] = loss.detach()
    meta["loss_per_example"] = loss_per_example.detach()
    return loss.detach(), meta
def run_masked_normalize(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    dim: int | None = None,
    normalize_constant: float = 1.0,
) -> torch.Tensor:
    """Masked sum of ``tensor`` divided by a fixed constant.

    Args:
        tensor: Values to sum.
        mask: Mask; elements where mask is 0 contribute nothing.
        dim: Dimension to sum over; None sums every element.
        normalize_constant: Divisor applied to the masked sum.

    Returns:
        sum(tensor * mask, dim) / normalize_constant.
    """
    masked_sum = (tensor * mask).sum(dim=dim)
    return masked_sum / normalize_constant
# The adapters below are used in the optional RLHF / safety part of the
# alignment assignment.
def get_packed_sft_dataset(
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str | os.PathLike,
    seq_length: int,
    shuffle: bool,
) -> Dataset:
    """Build a packed language-modeling Dataset from instruction-tuning examples.

    Not implemented yet: calling it raises ``NotImplementedError``.

    Intended contract: every example is a dict with "input_ids" and "labels",
    both tensors of shape (seq_length,). The examples are packed so that every
    sequence has exactly ``seq_length`` tokens.

    Args:
        tokenizer: Tokenizer used to encode the text.
        dataset_path: Path to the instruction-tuning examples.
        seq_length: Number of tokens per packed example.
        shuffle: Whether to shuffle documents before packing.
    """
    raise NotImplementedError
def run_iterate_batches(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
):
    """Return an iterable over batches of ``dataset``; one pass is one epoch.

    Implemented with ``torch.utils.data.DataLoader`` and its default collation,
    so a dataset of dicts of tensors yields dicts of stacked tensors.

    Args:
        dataset: Map-style dataset (supports ``__len__`` and ``__getitem__``).
        batch_size: Number of examples per batch.
        shuffle: If true, reshuffle example order on every pass.

    Returns:
        A DataLoader. Every batch has ``batch_size`` examples except possibly
        the last one, so that every example is visited exactly once per epoch.
        ``len()`` is ceil(len(dataset) / batch_size).
    """
    from torch.utils.data import DataLoader

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)
def run_parse_gsm8k_response(
    model_output: str,
) -> str | None:
    """Parse a GSM8K output into its predicted answer: the last number in the text.

    Thousands separators are removed first ("1,000" -> "1000"). A trailing
    sentence period is not treated as part of the number ("18." -> "18").

    Args:
        model_output: The model's output (or a reference solution) for a
            GSM8K example.

    Returns:
        The last number as a string, or None if the text contains no digits.
    """
    # Remove only commas that are followed by exactly three digits, so that
    # "1,000,000" is joined while a list such as "1,2,3" keeps separate numbers.
    without_separators = re.sub(r"(?<=\d),(?=\d{3}\b)", "", model_output)
    # Optional sign, integer part, optional fractional part. Unlike the old
    # pattern r'-?\d+\.?\d*', this does not capture a dangling trailing dot.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", without_separators)
    if not numbers:
        return None
    return numbers[-1]
|