# cs336_lm.py -- Transformer language-model building blocks (file-viewer listing, about 21 KB).
  1. from __future__ import annotations
  2. import os
  3. import json
  4. import einx
  5. import math
  6. import torch
  7. import logging
  8. import torch.nn as nn
  9. from torch import Tensor
  10. from einops import rearrange, einsum
  11. from jaxtyping import Float, Bool, Int
  12. logger = logging.getLogger(__name__)
  13. def softmax(x, dim=-1):
  14. rescaled_input = x - torch.max(x, dim=dim, keepdim=True)[0]
  15. exponentiated_rescaled_input = torch.exp(rescaled_input)
  16. return exponentiated_rescaled_input / torch.sum(exponentiated_rescaled_input, dim=dim, keepdim=True)
  17. class Linear(nn.Module):
  18. def __init__(self, d_in: int, d_out: int):
  19. """A linear layer initialized with truncated normal fan-in fan-out.
  20. Args:
  21. d_in: int
  22. The number of input features.
  23. d_out: int
  24. The number of output features.
  25. """
  26. super().__init__()
  27. std = math.sqrt(2 / (d_in + d_out))
  28. self.weight: Float[Tensor, " d_out d_in"] = nn.Parameter(
  29. nn.init.trunc_normal_(torch.empty(d_out, d_in), std=std, a=-3*std, b=3*std),
  30. requires_grad=True
  31. )
  32. def forward(self, x: Float[Tensor, " ... d_in"]) -> Float[Tensor, " ... d_out"]:
  33. return einsum(x, self.weight, "... d_in, d_out d_in -> ... d_out")
  34. def extra_repr(self):
  35. return f"d_out={self.weight.shape[0]}, d_in={self.weight.shape[1]}"
  36. class Embedding(nn.Module):
  37. def __init__(self, vocab_size: int, d_model: int):
  38. super().__init__()
  39. std = 1.0
  40. self.weight = nn.Parameter(
  41. nn.init.trunc_normal_(torch.empty(vocab_size, d_model), std=std, a=-3 * std, b=3 * std),
  42. requires_grad=True
  43. )
  44. def forward(self, token_ids: Int[Tensor, " ..."]) -> Float[Tensor, " ... d_model"]:
  45. return self.weight[token_ids, :]
  46. def extra_repr(self):
  47. return f"vocab_size={self.weight.shape[0]}, d={self.weight.shape[1]}"
  48. class RMSNorm(nn.Module):
  49. """
  50. This module implements root mean square layer normalization, as
  51. described in Eq. 4 of https://arxiv.org/abs/1910.07467
  52. Args:
  53. hidden_size: int
  54. Dimensionality of the input to normalize.
  55. eps: float, default is 1e-5
  56. A value added to the denominator for numerical stability.
  57. Returns:
  58. FloatTensor of same shape as input.
  59. """
  60. def __init__(
  61. self,
  62. hidden_size: int,
  63. eps: float = 1e-5,
  64. device=None,
  65. ):
  66. super().__init__()
  67. self.weight = nn.Parameter(torch.ones(hidden_size, device=device))
  68. self.eps = eps
  69. def forward(self, x):
  70. """
  71. Args:
  72. x: FloatTensor of shape `(batch_size, *)`.
  73. The input to apply root mean square layer normalization on.
  74. Returns:
  75. FloatTensor of same shape as input
  76. """
  77. # NOTE: in practice, many implementations will
  78. # manually upcast the input to fp32 here to prevent overflow when you
  79. # square the input.
  80. # https://github.com/pytorch/pytorch/issues/66707
  81. in_dtype = x.dtype
  82. x = x.to(torch.float32)
  83. rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  84. x = x * rms
  85. return (self.weight * x).to(in_dtype)
  86. def extra_repr(self):
  87. return f"hidden_size={self.weight.shape[0]}, eps={self.eps}"
  88. class RotaryEmbedding(nn.Module):
  89. def __init__(self, context_length: int, dim: int, theta: float = 10000.0):
  90. super().__init__()
  91. self.register_buffer(
  92. "_freq_cis_cache",
  93. RotaryEmbedding._init_cache(context_length, dim, theta), persistent=False
  94. )
  95. @staticmethod
  96. def _init_cache(context_length: int, dim: int, theta: float) -> Float[Tensor, " 2 context_length half_dim"]:
  97. assert dim % 2 == 0
  98. d = torch.arange(0, dim, 2) / dim
  99. freqs = theta ** -d
  100. t = torch.arange(context_length)
  101. freqs = einsum(t, freqs, "t, f -> t f")
  102. cos, sin = torch.cos(freqs), torch.sin(freqs)
  103. return torch.stack((cos, sin))
  104. def forward(self, x: Float[Tensor, " ... seq d"], pos_ids: Int[Tensor, " ... seq"]) -> Float[Tensor, " ... seq d"]:
  105. x1, x2 = rearrange(x, '... (half_d xy) -> xy ... half_d', xy=2)
  106. # Standard
  107. # cos, sin = self._freq_cis_cache[:, pos_ids, :]
  108. # einx
  109. cos, sin = einx.get_at('cos_sin [pos] half_dim, ... -> cos_sin ... half_dim', self._freq_cis_cache, pos_ids)
  110. # 2D rotation matrix applied to pairs in x
  111. x1_rot = cos * x1 - sin * x2
  112. x2_rot = sin * x1 + cos * x2
  113. result = einx.rearrange('... x_half, ... x_half -> ... (x_half (1 + 1))', x1_rot, x2_rot).contiguous()
  114. return result
  115. def extra_repr(self):
  116. return f"context_length={self._freq_cis_cache.shape[0]}, dim/2={self._freq_cis_cache.shape[1]}"
  117. class BasicsTransformerLM(nn.Module):
  118. """A Transformer language model.
  119. Args:
  120. vocab_size: int
  121. The number of unique items in the output vocabulary to be predicted.
  122. context_length: int,
  123. The maximum number of tokens to process at once.
  124. d_model: int
  125. The dimensionality of the model embeddings and sublayer outputs.
  126. num_layers: int
  127. The number of Transformer layers to use.
  128. num_heads: int
  129. Number of heads to use in multi-headed attention. `d_model` must be
  130. evenly divisible by `num_heads`.
  131. d_ff: int
  132. Dimensionality of the feed-forward inner layer (section 3.3).
  133. rope_theta: float
  134. The theta value for the RoPE positional encoding.
  135. Returns:
  136. FloatTensor of shape (batch size, sequence_length, vocab_size) with the
  137. predicted unnormalized next-word distribution for each token.
  138. """
  139. def __init__(
  140. self,
  141. vocab_size: int,
  142. context_length: int,
  143. d_model: int,
  144. num_layers: int,
  145. num_heads: int,
  146. d_ff: int,
  147. rope_theta: float,
  148. ):
  149. # Store the model configuration for serialization / deserialization
  150. self.config = {
  151. k: v for k, v in locals().items() if k != "self" and not (k.startswith("__") and k.endswith("__"))
  152. }
  153. super().__init__()
  154. self.vocab_size = vocab_size
  155. self.context_length = context_length
  156. self.d_model = d_model
  157. self.token_embeddings = Embedding(vocab_size, d_model)
  158. d_head = d_model // num_heads
  159. self.positional_encoder = RotaryEmbedding(
  160. context_length=context_length,
  161. dim=d_head,
  162. theta=rope_theta
  163. )
  164. self.layers = nn.ModuleList(
  165. [
  166. TransformerBlock(
  167. d_model=d_model,
  168. num_heads=num_heads,
  169. d_ff=d_ff,
  170. positional_encoder=self.positional_encoder,
  171. )
  172. for _ in range(num_layers)
  173. ]
  174. )
  175. self.ln_final = RMSNorm(d_model)
  176. self.lm_head = Linear(d_model, vocab_size)
  177. # report number of parameters
  178. logger.info(f"number of non-embedding parameters: {self.get_num_params() / 1e6:.2f}M")
  179. def get_num_params(self, non_embedding=True):
  180. """
  181. Return the number of parameters in the model.
  182. For non-embedding count (default), the lm_head parameters get subtracted.
  183. """
  184. n_params = sum(p.numel() for p in self.parameters())
  185. if non_embedding:
  186. n_params -= self.lm_head.weight.numel()
  187. return n_params
  188. def forward(self, x: Int[Tensor, " ... sequence_length"]) -> Float[Tensor, " ... sequence_length vocab_size"]:
  189. """
  190. Args:
  191. x: Input IDs for language modeling.
  192. Returns: A FloatTensor of shape
  193. (batch size, sequence_length, vocab_size) with the predicted unnormalized next-word
  194. distribution for each token.
  195. """
  196. _, sequence_length = x.size()
  197. # (batch size, sequence_length, d_model)
  198. x = self.token_embeddings(x)
  199. for layer in self.layers:
  200. # (batch size, sequence_length, d_model)
  201. x = layer(x)
  202. # (batch size, sequence_length, d_model)
  203. x = self.ln_final(x)
  204. # (batch size, sequence_length, vocab_size)
  205. return self.lm_head(x), None
  206. @torch.no_grad()
  207. def generate(
  208. self,
  209. x: torch.Tensor,
  210. max_new_tokens: int,
  211. temperature: float = 1.0,
  212. top_k: int | None = None,
  213. eos_token_id: int | None = None,
  214. ):
  215. """
  216. Args:
  217. x: LongTensor of shape `(1, sequence_length,)` or `(sequence_length, )`.
  218. Input IDs to condition on when generating.
  219. max_new_tokens: int
  220. Maximum number of tokens to generate.
  221. temperature: float
  222. Temperature to use during generation.
  223. top_k: int
  224. If provided, only sample from the `top_k` vocab items (by probability).
  225. eos_token_id: int
  226. If provided, stop generation when we generate this ID.
  227. Returns: A LongTensor of shape (max_new_tokens,) with the generated model output.
  228. """
  229. if x.dim() == 1:
  230. x = x.unsqueeze(0)
  231. original_sequence_length = x.size(-1)
  232. for _ in range(max_new_tokens):
  233. # Take the last `context_length` tokens if the input is
  234. # beyond the model's context length
  235. x = x[:, -self.context_length :] if x.size(1) > self.context_length else x
  236. # Get the logits from the model
  237. logits, _ = self.forward(x)
  238. # Take the logits for the next token
  239. next_token_logits = logits[:, -1]
  240. # apply temperature scaling
  241. temperature_scaled_next_token_logits = next_token_logits / temperature
  242. # If top-k is provided, take the tokens with the highest score
  243. if top_k:
  244. topk_values, _ = torch.topk(
  245. temperature_scaled_next_token_logits,
  246. min(top_k, temperature_scaled_next_token_logits.size(-1)),
  247. )
  248. # Get the score of the kth item that we kept---items with lower scores should be masked.
  249. threshold = topk_values[:, -1]
  250. topk_mask = temperature_scaled_next_token_logits < threshold
  251. temperature_scaled_next_token_logits.masked_fill(topk_mask, float("-inf"))
  252. next_token_probabilities = softmax(temperature_scaled_next_token_logits, dim=-1)
  253. next_token_id = torch.multinomial(next_token_probabilities, 1)
  254. # End generation if we see the EOS token ID
  255. if eos_token_id is not None and next_token_id.item() == eos_token_id:
  256. break
  257. x = torch.cat((x, next_token_id), dim=-1)
  258. new_token_ids = x[:, original_sequence_length:]
  259. return new_token_ids
  260. @classmethod
  261. def from_pretrained(cls, pretrained_model_path: str):
  262. config_path = os.path.join(pretrained_model_path, "model_config.json")
  263. with open(config_path) as f:
  264. config = json.load(f)
  265. model = cls(**config)
  266. weights_path = os.path.join(pretrained_model_path, "model.pt")
  267. state_dict = torch.load(weights_path)
  268. # Remove _orig_mod. prefix that comes from serializing a compiled model
  269. unwanted_prefix = "_orig_mod."
  270. for k, _ in list(state_dict.items()):
  271. if k.startswith(unwanted_prefix):
  272. state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
  273. model.load_state_dict(state_dict)
  274. return model
  275. class TransformerBlock(nn.Module):
  276. """A single Transformer layer.
  277. This implements a single layer of the Transformer, as described in section 3.1
  278. of the paper.
  279. Args:
  280. d_model: int
  281. The dimensionality of the model embeddings and sublayer outputs.
  282. num_heads: int
  283. Number of heads to use in multi-headed attention. `d_model` must be
  284. evenly divisible by `num_heads`.
  285. d_ff: int
  286. Dimensionality of the feed-forward inner layer (section 3.3).
  287. positional_encoder: RotaryEmbedding
  288. The RoPE module to use.
  289. Returns:
  290. FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  291. """
  292. def __init__(
  293. self,
  294. d_model: int,
  295. num_heads: int,
  296. d_ff: int,
  297. positional_encoder: RotaryEmbedding,
  298. ):
  299. super().__init__()
  300. self.attn = CausalMultiHeadSelfAttention(
  301. d_model=d_model,
  302. num_heads=num_heads,
  303. positional_encoder=positional_encoder,
  304. )
  305. self.ffn = SwiGLU(d_model=d_model, d_ff=d_ff)
  306. # self.ffn = SiLU(d_model)
  307. self.ln1 = RMSNorm(d_model)
  308. self.ln2 = RMSNorm(d_model)
  309. def forward(self, x: torch.Tensor):
  310. """
  311. Args:
  312. x: FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  313. The input to process with the Transformer block.
  314. Returns:
  315. FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  316. """
  317. # NOTE: this is a pre-norm Transformer, and differs from the original
  318. # description in the paper.
  319. # Apply the multi-head self-attention sublayer
  320. x_attn = self.attn(self.ln1(x))
  321. attn_sublayer_output = x + x_attn
  322. # Apply the feed-forward sublayer
  323. x_ffn = self.ffn(self.ln2(attn_sublayer_output))
  324. ffn_sublayer_output = attn_sublayer_output + x_ffn
  325. return ffn_sublayer_output
  326. # post norm
  327. # def forward(self, x: torch.Tensor):
  328. # """
  329. # Args:
  330. # x: FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  331. # The input to process with the Transformer block.
  332. # Returns:
  333. # FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  334. # """
  335. # # NOTE: this is a pre-norm Transformer, and differs from the original
  336. # # description in the paper.
  337. # # Apply the multi-head self-attention sublayer
  338. # # x_attn = self.attn(self.ln1(x))
  339. # x_attn = self.attn(x)
  340. # attn_sublayer_output = x + x_attn
  341. # attn_sublayer_output = self.ln1(attn_sublayer_output)
  342. # # Apply the feed-forward sublayer
  343. # # x_ffn = self.ffn(self.ln2(attn_sublayer_output))
  344. # x_ffn = self.ffn(attn_sublayer_output)
  345. # ffn_sublayer_output = attn_sublayer_output + x_ffn
  346. # ffn_sublayer_output = self.ln2(ffn_sublayer_output)
  347. # return ffn_sublayer_output
  348. # No RMSNorm
  349. # def forward(self, x: torch.Tensor):
  350. # """
  351. # Args:
  352. # x: FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  353. # The input to process with the Transformer block.
  354. # Returns:
  355. # FloatTensor of shape `(batch_size, sequence_length, d_model)`.
  356. # """
  357. # # NOTE: this is a pre-norm Transformer, and differs from the original
  358. # # description in the paper.
  359. # # Apply the multi-head self-attention sublayer
  360. # x_attn = self.attn(x)
  361. # attn_sublayer_output = x + x_attn
  362. # # Apply the feed-forward sublayer
  363. # x_ffn = self.ffn(x)
  364. # ffn_sublayer_output = attn_sublayer_output + x_ffn
  365. # return ffn_sublayer_output
  366. class SwiGLU(nn.Module):
  367. def __init__(self, d_model: int, d_ff: int):
  368. super().__init__()
  369. self.w1 = Linear(d_model, d_ff)
  370. self.w2 = Linear(d_ff, d_model)
  371. self.w3 = Linear(d_model, d_ff)
  372. def forward(self, x):
  373. return self.w2(silu(self.w1(x)) * self.w3(x))
  374. class SiLU(nn.Module):
  375. def __init__(self, d_model: int):
  376. super().__init__()
  377. d_ff = 4 * d_model
  378. self.w1 = Linear(d_model, d_ff)
  379. self.w2 = Linear(d_ff, d_model)
  380. def forward(self, x):
  381. return self.w2(silu(self.w1(x)))
  382. def scaled_dot_product_attention(
  383. Q: Float[Tensor, " ... queries d_k"],
  384. K: Float[Tensor, " ... keys d_k"],
  385. V: Float[Tensor, " ... keys d_v"],
  386. mask: Bool[Tensor, " ... queries keys"] | None = None,
  387. ) -> Float[Tensor, " ... queries d_v"]:
  388. """Scaled dot-product attention.
  389. This function implements Eq. 1 of the Transformer paper.
  390. Args:
  391. Q: Tensor of queries, may have any number of leading dimensions.
  392. K: Tensor of keys, sharing leading dimensions with Q.
  393. V: Tensor of values, sharding leading dimensions with Q and K.
  394. mask: An (optional) mask of shape (..., seq_len, seq_len).
  395. Attention scores for positions with a mask value of `False` should
  396. be masked out, i.e., not affect the softmaxed attention probabilities.
  397. Returns:
  398. torch.FloatTensor of shape (..., seq_len, value_dimension)
  399. with the output of running your scaled dot product attention
  400. implementation with the provided key, query, and value tensors.
  401. """
  402. d_k = K.shape[-1]
  403. attention_scores = einsum(Q, K, "... query d_k, ... key d_k -> ... query key") / math.sqrt(d_k)
  404. if mask is not None:
  405. attention_scores = torch.where(mask, attention_scores, float("-inf"))
  406. attention_weights = softmax(attention_scores, dim=-1) # Softmax over the key dimension
  407. return einsum(attention_weights, V, "... query key, ... key d_v -> ... query d_v")
  408. class CausalMultiHeadSelfAttention(nn.Module):
  409. """Multi-Head Self-Attention
  410. This function implements section 3.2.2 of the Transformer paper. In particular,
  411. given an input tensor of shape `(batch_size, sequence_length, d_model)`, we project
  412. it to create queries, keys, and values, and then perform causal multi-headed attention with
  413. those queries, keys, and values.
  414. Args:
  415. d_model: int
  416. The dimensionality of the model embeddings and sublayer outputs.
  417. num_heads: int
  418. Number of heads to use in multi-headed attention. `d_model` must be
  419. evenly divisible by `num_heads`.
  420. positional_encoder: RotaryEmbedding
  421. The RoPE module to use.
  422. Returns:
  423. Tensor of shape `(batch_size, sequence_length, d_model)`.
  424. """
  425. def __init__(
  426. self,
  427. d_model: int,
  428. num_heads: int,
  429. positional_encoder: RotaryEmbedding,
  430. ):
  431. super().__init__()
  432. assert d_model % num_heads == 0
  433. self.d_model = d_model
  434. self.num_heads = num_heads
  435. self.d_k = d_model // num_heads
  436. self.d_v = self.d_k
  437. self.q_proj = Linear(self.d_model, self.num_heads * self.d_k)
  438. self.k_proj = Linear(self.d_model, self.num_heads * self.d_k)
  439. self.v_proj = Linear(self.d_model, self.num_heads * self.d_v)
  440. self.output_proj = Linear(self.num_heads * self.d_v, self.d_model)
  441. self.positional_encoder = positional_encoder # RoPE
  442. def forward(self, x: Float[Tensor, " ... seq d_k"], token_positions: Int[Tensor, " ... seq"] | None = None) -> Float[Tensor, " ... seq d_v"]:
  443. """
  444. Args:
  445. x: The input to perform multi-headed self-attention on.
  446. positional_ids: The positional indices along the sequence dimension of the input embeddings.
  447. Returns:
  448. Self-attention outputs.
  449. """
  450. *b, sequence_length, d_model = x.size()
  451. assert d_model == self.d_model
  452. Q = self.q_proj(x)
  453. K = self.k_proj(x)
  454. V = self.v_proj(x)
  455. # Take apart each head from the embedding dimension of Q, K, V to shape (..., num_heads, seq_len, d_k).
  456. Q, K, V = (
  457. rearrange(X, "... seq (heads d) -> ... heads seq d", heads=self.num_heads)
  458. for X in (Q, K, V)
  459. ) # fmt: skip
  460. if token_positions is None:
  461. token_positions = einx.rearrange("seq -> b... seq", torch.arange(sequence_length, device=x.device), b=[1] * len(b))
  462. # Duplicate token positions for each head
  463. token_positions = rearrange(token_positions, "... seq -> ... 1 seq")
  464. # Q = self.positional_encoder(Q, token_positions)
  465. # K = self.positional_encoder(K, token_positions)
  466. # Construct causal mask
  467. seq = torch.arange(sequence_length, device=x.device)
  468. qi = einx.rearrange('query -> b... 1 query 1', seq, b=[1] * len(b))
  469. kj = einx.rearrange('key -> b... 1 1 key', seq, b=[1] * len(b))
  470. causal_mask = qi >= kj # (query, key)
  471. # Shape: (..., num_heads, sequence_length, d_k)
  472. attn_output = scaled_dot_product_attention(K=K, Q=Q, V=V, mask=causal_mask)
  473. # Concatenate the attention output from all heads.
  474. # (..., sequence_length, num_heads * d_v).
  475. attn_output = rearrange(attn_output, "batch heads seq d_v -> batch seq (heads d_v)").contiguous()
  476. # Apply the output projection
  477. output = self.output_proj(attn_output)
  478. return output
  479. def silu(x: torch.Tensor):
  480. return x * torch.sigmoid(x)