| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586 |
- from __future__ import annotations
- import os
- import json
- import einx
- import math
- import torch
- import logging
- import torch.nn as nn
- from torch import Tensor
- from einops import rearrange, einsum
- from jaxtyping import Float, Bool, Int
# Module-level logger; BasicsTransformerLM.__init__ reports its parameter count through it.
logger: logging.Logger = logging.getLogger(__name__)
def softmax(x, dim=-1):
    """Numerically stable softmax of ``x`` along ``dim``.

    The per-slice maximum is subtracted before exponentiating so large logits
    cannot overflow ``exp``; softmax is invariant to that constant shift.

    Args:
        x: Tensor of logits.
        dim: Dimension to normalize over (default: the last one).

    Returns:
        Tensor with the same shape as ``x`` whose slices along ``dim`` sum to 1.
    """
    shifted = x - torch.max(x, dim=dim, keepdim=True).values
    exps = torch.exp(shifted)
    normalizer = torch.sum(exps, dim=dim, keepdim=True)
    return exps / normalizer
class Linear(nn.Module):
    """Bias-free linear map ``y = x @ W^T`` with a truncated-normal initialization.

    The weight of shape (d_out, d_in) is drawn from N(0, std^2) with
    ``std = sqrt(2 / (d_in + d_out))``, truncated to ``[-3 * std, 3 * std]``.

    Args:
        d_in: Number of input features.
        d_out: Number of output features.
    """

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        init_std = math.sqrt(2 / (d_in + d_out))
        bound = 3 * init_std
        w = torch.empty(d_out, d_in)
        nn.init.trunc_normal_(w, std=init_std, a=-bound, b=bound)
        self.weight: Float[Tensor, " d_out d_in"] = nn.Parameter(w, requires_grad=True)

    def forward(self, x: Float[Tensor, " ... d_in"]) -> Float[Tensor, " ... d_out"]:
        """Project the trailing feature dimension of ``x`` from d_in to d_out."""
        return einsum(x, self.weight, "... i, o i -> ... o")

    def extra_repr(self):
        out_features, in_features = self.weight.shape
        return f"d_out={out_features}, d_in={in_features}"
class Embedding(nn.Module):
    """Lookup table mapping integer token IDs to ``d_model``-dimensional vectors.

    Rows are initialized from a standard normal truncated to [-3, 3].

    Args:
        vocab_size: Number of rows (distinct token IDs).
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        sigma = 1.0
        table = torch.empty(vocab_size, d_model)
        nn.init.trunc_normal_(table, std=sigma, a=-3 * sigma, b=3 * sigma)
        self.weight = nn.Parameter(table, requires_grad=True)

    def forward(self, token_ids: Int[Tensor, " ..."]) -> Float[Tensor, " ... d_model"]:
        """Gather one embedding row per ID; output shape is token_ids.shape + (d_model,)."""
        return self.weight[token_ids]

    def extra_repr(self):
        n_rows, n_cols = self.weight.shape
        return f"vocab_size={n_rows}, d={n_cols}"
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Eq. 4 of https://arxiv.org/abs/1910.07467).

    Computes ``weight * x / sqrt(mean(x ** 2) + eps)``, where the mean is taken
    over the last dimension and the learned gain ``weight`` starts at ones.

    Args:
        hidden_size: Size of the last dimension of the input.
        eps: Added inside the square root for numerical stability (default 1e-5).
        device: Optional device on which to allocate the gain parameter.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        device=None,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension; the output keeps x's shape and dtype.

        The statistic is computed in float32, because squaring half-precision
        inputs can overflow (https://github.com/pytorch/pytorch/issues/66707).
        The result is then cast back to the input dtype.
        """
        original_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = x_fp32 * inv_rms
        return (self.weight * normalized).to(original_dtype)

    def extra_repr(self):
        (hidden_size,) = self.weight.shape
        return f"hidden_size={hidden_size}, eps={self.eps}"
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) backed by a precomputed cos/sin table.

    The feature pair ``(x[..., 2i], x[..., 2i + 1])`` at position ``p`` is rotated
    by the angle ``p * theta ** (-2i / dim)``.

    Args:
        context_length: Number of positions to precompute; positions passed to
            ``forward`` must be smaller than this.
        dim: Size of the feature dimension being rotated; must be even.
        theta: Base of the geometric frequency schedule.
    """

    def __init__(self, context_length: int, dim: int, theta: float = 10000.0):
        super().__init__()
        # Non-persistent buffer: it follows .to(device) but is not saved in
        # state_dict, because it is fully determined by the constructor arguments.
        self.register_buffer(
            "_freq_cis_cache",
            RotaryEmbedding._init_cache(context_length, dim, theta),
            persistent=False,
        )

    @staticmethod
    def _init_cache(context_length: int, dim: int, theta: float) -> Float[Tensor, " 2 context_length half_dim"]:
        """Return the stacked ``(cos, sin)`` table of shape (2, context_length, dim // 2).

        Raises:
            ValueError: if ``dim`` is odd.
        """
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        # Per-pair inverse frequencies theta ** (-2i / dim), i = 0 .. dim/2 - 1.
        inv_freq = theta ** -(torch.arange(0, dim, 2) / dim)
        # Positions in the same float dtype, so the outer product needs no promotion.
        positions = torch.arange(context_length, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)  # (context_length, dim // 2)
        return torch.stack((torch.cos(angles), torch.sin(angles)))

    def forward(self, x: Float[Tensor, " ... seq d"], pos_ids: Int[Tensor, " ... seq"]) -> Float[Tensor, " ... seq d"]:
        """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

        Args:
            x: Features whose last dimension equals the ``dim`` given at construction.
            pos_ids: Integer positions. The gathered cos/sin tensors have shape
                ``pos_ids.shape + (dim // 2,)`` and must broadcast against
                ``x[..., 0::2]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # De-interleave pairs: x1 = x[..., 0::2], x2 = x[..., 1::2].
        x1, x2 = rearrange(x, '... (half_d xy) -> xy ... half_d', xy=2)
        # Equivalent to: cos, sin = self._freq_cis_cache[:, pos_ids, :]
        cos, sin = einx.get_at('cos_sin [pos] half_dim, ... -> cos_sin ... half_dim', self._freq_cis_cache, pos_ids)
        # 2D rotation applied to each (x1, x2) pair.
        x1_rot = cos * x1 - sin * x2
        x2_rot = sin * x1 + cos * x2
        # Re-interleave the rotated pairs into the last dimension.
        return einx.rearrange('... x_half, ... x_half -> ... (x_half (1 + 1))', x1_rot, x2_rot).contiguous()

    def extra_repr(self):
        # The cache is laid out as (2, context_length, dim // 2).
        _, context_length, half_dim = self._freq_cis_cache.shape
        return f"context_length={context_length}, dim/2={half_dim}"
class BasicsTransformerLM(nn.Module):
    """A decoder-only Transformer language model.

    Args:
        vocab_size: Number of distinct token IDs (input vocabulary and output logits).
        context_length: Maximum number of tokens fed to the model at once; also
            the number of positions precomputed by the rotary encoder.
        d_model: Width of the embeddings and of every sublayer output.
        num_layers: Number of stacked ``TransformerBlock``s.
        num_heads: Number of attention heads; ``d_model`` must be divisible by it.
        d_ff: Inner width of each block's feed-forward sublayer.
        rope_theta: Base frequency of the rotary positional encoding.

    ``forward`` returns a tuple ``(logits, None)``. ``logits`` has shape
    (batch_size, sequence_length, vocab_size) and holds unnormalized
    next-token scores.
    """

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
    ):
        # Snapshot the constructor arguments; from_pretrained rebuilds the model
        # with cls(**config). This must run before any other local is bound.
        # Dunder entries (e.g. the __class__ cell introduced by super()) are skipped.
        self.config = {
            k: v for k, v in locals().items() if k != "self" and not (k.startswith("__") and k.endswith("__"))
        }
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.token_embeddings = Embedding(vocab_size, d_model)
        d_head = d_model // num_heads
        # One RotaryEmbedding instance, shared by every block.
        self.positional_encoder = RotaryEmbedding(
            context_length=context_length,
            dim=d_head,
            theta=rope_theta,
        )
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    positional_encoder=self.positional_encoder,
                )
                for _ in range(num_layers)
            ]
        )
        self.ln_final = RMSNorm(d_model)
        self.lm_head = Linear(d_model, vocab_size)
        logger.info("number of non-embedding parameters: %.2fM", self.get_num_params() / 1e6)

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        With ``non_embedding=True`` (the default), ``lm_head.weight`` is excluded
        from the count. ``token_embeddings`` are always counted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.lm_head.weight.numel()
        return n_params

    def forward(
        self, x: Int[Tensor, " ... sequence_length"]
    ) -> tuple[Float[Tensor, " ... sequence_length vocab_size"], None]:
        """Compute next-token logits.

        Args:
            x: Token IDs of shape (batch_size, sequence_length).

        Returns:
            A tuple ``(logits, None)``. ``logits`` has shape
            (batch_size, sequence_length, vocab_size). The second element is
            always ``None``; callers such as ``generate`` unpack two values.
        """
        h = self.token_embeddings(x)  # (batch, seq, d_model)
        for layer in self.layers:
            h = layer(h)  # (batch, seq, d_model)
        h = self.ln_final(h)
        return self.lm_head(h), None  # logits: (batch, seq, vocab_size)

    @torch.no_grad()
    def generate(
        self,
        x: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
        eos_token_id: int | None = None,
    ):
        """Autoregressively sample up to ``max_new_tokens`` tokens after a prompt.

        Args:
            x: LongTensor prompt of shape (1, sequence_length) or (sequence_length,).
            max_new_tokens: Maximum number of tokens to generate.
            temperature: The next-token logits are divided by this value before
                sampling; it must be > 0.
            top_k: If given and non-zero, sample only among the ``top_k``
                highest-scoring vocabulary items.
            eos_token_id: If given, stop as soon as this ID is sampled. The EOS
                token is not included in the output.

        Returns:
            LongTensor of shape (1, num_generated), with num_generated <= max_new_tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        prompt_length = x.size(-1)
        for _ in range(max_new_tokens):
            # Feed the model at most the last `context_length` tokens. Keep the
            # full sequence in `x`, so the prompt/generated split at the end stays valid.
            x_cond = x[:, -self.context_length:]
            logits, _ = self.forward(x_cond)
            next_token_logits = logits[:, -1] / temperature  # (batch, vocab)
            if top_k:
                k = min(top_k, next_token_logits.size(-1))
                topk_values, _ = torch.topk(next_token_logits, k)
                # Per-row score of the k-th best item, kept 2-D (batch, 1) so it
                # broadcasts across the vocabulary axis.
                threshold = topk_values[:, -1:]
                # masked_fill is out-of-place: its result must be reassigned.
                next_token_logits = next_token_logits.masked_fill(next_token_logits < threshold, float("-inf"))
            next_token_probabilities = softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(next_token_probabilities, 1)  # (batch, 1)
            # .item() requires a single sequence (batch size 1).
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break
            x = torch.cat((x, next_token_id), dim=-1)
        return x[:, prompt_length:]

    @classmethod
    def from_pretrained(cls, pretrained_model_path: str):
        """Build a model from ``model_config.json`` and ``model.pt`` in a directory.

        The JSON file supplies the constructor keyword arguments, and
        ``model.pt`` holds the state dict. Keys carrying the ``_orig_mod.``
        prefix, which is written when saving a compiled model, are renamed
        before loading.
        """
        config_path = os.path.join(pretrained_model_path, "model_config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        model = cls(**config)
        weights_path = os.path.join(pretrained_model_path, "model.pt")
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch supports it.
        state_dict = torch.load(weights_path)
        # Rename in place (rather than rebuilding the dict) so any metadata
        # attached to the loaded mapping is preserved.
        unwanted_prefix = "_orig_mod."
        for key in list(state_dict.keys()):
            if key.startswith(unwanted_prefix):
                state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
        model.load_state_dict(state_dict)
        return model
class TransformerBlock(nn.Module):
    """One pre-norm Transformer layer: causal self-attention, then a SwiGLU FFN.

    Each sublayer reads an RMS-normalized copy of the residual stream, and its
    output is added back to the residual stream:

        h   = x + attn(ln1(x))
        out = h + ffn(ln2(h))

    This differs from the original Transformer paper (section 3.1), which
    normalizes after each residual add (post-norm).

    Args:
        d_model: Width of the residual stream.
        num_heads: Number of attention heads; must divide ``d_model``.
        d_ff: Inner width of the SwiGLU feed-forward network.
        positional_encoder: RoPE module handed to the attention sublayer.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        positional_encoder: RotaryEmbedding,
    ):
        super().__init__()
        # Construction order (attn, ffn, ln1, ln2) fixes the RNG draw order and
        # the state_dict key order; keep it stable.
        self.attn = CausalMultiHeadSelfAttention(
            d_model=d_model,
            num_heads=num_heads,
            positional_encoder=positional_encoder,
        )
        self.ffn = SwiGLU(d_model=d_model, d_ff=d_ff)
        # Ablation alternative: self.ffn = SiLU(d_model)
        self.ln1 = RMSNorm(d_model)
        self.ln2 = RMSNorm(d_model)

    def forward(self, x: torch.Tensor):
        """Apply the block to x of shape (batch_size, sequence_length, d_model); the output has the same shape."""
        h = x + self.attn(self.ln1(x))
        return h + self.ffn(self.ln2(h))

    # Alternative forward passes that were kept here as commented-out code:
    #   post-norm:  h = ln1(x + attn(x));  out = ln2(h + ffn(h))
    #   no RMSNorm: h = x + attn(x);       out = h + ffn(x)
    #     NOTE(review): that no-RMSNorm variant fed `x` (not `h`) to the FFN,
    #     i.e. a parallel rather than sequential block. Confirm this was intended.
class SwiGLU(nn.Module):
    """Gated feed-forward network computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width of the two input projections.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Creation order w1, w2, w3 determines the RNG draws; keep it.
        self.w1 = Linear(d_model, d_ff)  # gated branch (passed through silu)
        self.w2 = Linear(d_ff, d_model)  # output projection
        self.w3 = Linear(d_model, d_ff)  # linear branch

    def forward(self, x):
        gate = silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class SiLU(nn.Module):
    """Non-gated feed-forward network ``w2(silu(w1(x)))`` with hidden width ``4 * d_model``.

    Not used by default: TransformerBlock constructs a SwiGLU instead.

    Args:
        d_model: Input and output width.
    """

    def __init__(self, d_model: int):
        super().__init__()
        hidden = 4 * d_model
        self.w1 = Linear(d_model, hidden)
        self.w2 = Linear(hidden, d_model)

    def forward(self, x):
        activated = silu(self.w1(x))
        return self.w2(activated)
def scaled_dot_product_attention(
    Q: Float[Tensor, " ... queries d_k"],
    K: Float[Tensor, " ... keys d_k"],
    V: Float[Tensor, " ... keys d_v"],
    mask: Bool[Tensor, " ... queries keys"] | None = None,
) -> Float[Tensor, " ... queries d_v"]:
    """Scaled dot-product attention, ``softmax(Q K^T / sqrt(d_k)) V`` (Eq. 1 of the Transformer paper).

    Args:
        Q: Queries; may have any number of leading dimensions.
        K: Keys, sharing leading dimensions with Q.
        V: Values, sharing leading dimensions with Q and K.
        mask: Optional boolean mask broadcastable to (..., queries, keys).
            ``True`` marks key positions a query may attend to. ``False``
            positions get a score of -inf and hence zero attention weight. A
            query whose keys are all masked yields NaN.

    Returns:
        Tensor of shape (..., queries, d_v).
    """
    scale = math.sqrt(K.shape[-1])
    scores = einsum(Q, K, "... q d, ... k d -> ... q k") / scale
    if mask is not None:
        scores = torch.where(mask, scores, float("-inf"))
    # Normalize over the key axis.
    weights = softmax(scores, dim=-1)
    return einsum(weights, V, "... q k, ... k v -> ... q v")
class CausalMultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention (section 3.2.2 of the Transformer paper).

    The forward pass works as follows:

    * The input of shape (..., sequence_length, d_model) is projected to
      queries, keys and values.
    * Each projection is split into ``num_heads`` heads of width
      ``d_model // num_heads``.
    * The heads are attended with a causal mask: query i sees keys j <= i.
    * The heads are merged back and passed through an output projection.

    Args:
        d_model: Input/output width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        positional_encoder: RoPE module. NOTE(review): the calls that apply it to
            Q and K are commented out in ``forward``, so no positional encoding
            is currently applied. Confirm this is an intentional ablation.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        positional_encoder: RotaryEmbedding,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = self.d_k
        # Creation order (q, k, v, output) determines the RNG draws; keep it.
        self.q_proj = Linear(self.d_model, self.num_heads * self.d_k)
        self.k_proj = Linear(self.d_model, self.num_heads * self.d_k)
        self.v_proj = Linear(self.d_model, self.num_heads * self.d_v)
        self.output_proj = Linear(self.num_heads * self.d_v, self.d_model)
        self.positional_encoder = positional_encoder  # RoPE

    def forward(
        self,
        x: Float[Tensor, " ... seq d_model"],
        token_positions: Int[Tensor, " ... seq"] | None = None,
    ) -> Float[Tensor, " ... seq d_model"]:
        """Apply causal multi-head self-attention.

        Args:
            x: Input of shape (..., seq, d_model), with any number of leading dims.
            token_positions: Optional positions of shape (..., seq) intended for
                RoPE; they default to 0..seq-1. They currently have no effect,
                because the RoPE calls below are commented out.

        Returns:
            Tensor of shape (..., seq, d_model).

        Raises:
            ValueError: if the last dimension of ``x`` is not ``d_model``.
        """
        *batch_dims, sequence_length, d_model = x.size()
        if d_model != self.d_model:
            raise ValueError(f"expected last dimension {self.d_model}, got {d_model}")

        # Project, then split the feature axis into heads: (..., heads, seq, d_head).
        Q, K, V = (
            rearrange(proj(x), "... seq (heads d) -> ... heads seq d", heads=self.num_heads)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        if token_positions is None:
            # Positions 0..seq-1 with singleton leading dims, so they broadcast.
            token_positions = torch.arange(sequence_length, device=x.device).reshape(
                *([1] * len(batch_dims)), sequence_length
            )
        # Singleton head axis so the positions broadcast across heads.
        token_positions = token_positions.unsqueeze(-2)
        # RoPE is currently not applied; uncomment to enable it:
        # Q = self.positional_encoder(Q, token_positions)
        # K = self.positional_encoder(K, token_positions)

        # Causal mask (query, key): True where key index <= query index. Its
        # (seq, seq) shape broadcasts over the leading and head dims of the scores.
        positions = torch.arange(sequence_length, device=x.device)
        causal_mask = positions[:, None] >= positions[None, :]

        attn_output = scaled_dot_product_attention(K=K, Q=Q, V=V, mask=causal_mask)
        # Merge heads for any number of leading dims:
        # (..., heads, seq, d_v) -> (..., seq, heads * d_v).
        attn_output = rearrange(attn_output, "... heads seq d_v -> ... seq (heads d_v)").contiguous()
        return self.output_proj(attn_output)
def silu(x: torch.Tensor):
    """SiLU (swish) activation ``x * sigmoid(x)``, applied elementwise."""
    gate = torch.sigmoid(x)
    return x * gate
|