basics.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. def repeat(x, rep):
  5. """
  6. GQA 共享 KV 需要repeat.
  7. (B, N_G, T, H) -> (B, N, T, H)
  8. rep: N // N_G
  9. """
  10. B, N_G, T, H = x.shape
  11. x = x.unsqueeze(2)
  12. x = x.expand(-1, -1, rep, -1, -1).contiguous()
  13. x = x.view(B, N_G * rep, T, H)
  14. return x
  15. class RMSNorm(nn.Module):
  16. def __init__(self, embed_dim, eps=1e-6):
  17. super(RMSNorm, self).__init__()
  18. self.weight = nn.Parameter(torch.ones(embed_dim))
  19. self.eps = eps
  20. def forward(self, x):
  21. input_dtype = x.dtype
  22. x = x.to(torch.float32)
  23. var = x.pow(2).mean(-1, keepdims=True)
  24. x = x * torch.rsqrt(var + self.eps)
  25. return self.weight * x.to(input_dtype)
  26. class SwiGLU(nn.Module):
  27. def __init__(self, embed_dim, immediate_dim, bias=False):
  28. super(SwiGLU, self).__init__()
  29. self.up_proj = nn.Linear(embed_dim, immediate_dim, bias=bias)
  30. self.down_proj = nn.Linear(immediate_dim, embed_dim, bias=bias)
  31. self.gate_proj = nn.Linear(embed_dim, immediate_dim, bias=bias)
  32. def forward(self, x):
  33. x, gate = self.up_proj(x), self.gate_proj(x)
  34. x = F.silu(gate) * x
  35. x = self.down_proj(x)
  36. return x
  37. def _compute_rope_params(config):
  38. base = config.rope_theta
  39. if 'Qwen' in config.architectures[0]:
  40. rope_dim = config.hidden_size // config.num_attention_heads
  41. elif 'Deepseek' in config.architectures[0]:
  42. rope_dim = config.qk_rope_head_dim
  43. inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2) / rope_dim)) # (dim // 2)
  44. T = config.max_position_embeddings
  45. position_ids_expanded = torch.arange(0, T).reshape(1, T) # (1, T)
  46. inv_freq_expanded = inv_freq.reshape(-1, 1) # (dim // 2, 1)
  47. # (dim // 2, T) -- transpose --> (T, dim // 2)
  48. freqs = (inv_freq_expanded @ position_ids_expanded.float()).transpose(0, 1)
  49. emb = torch.cat((freqs, freqs), dim=-1) # (T, dim)
  50. cos = emb.cos()
  51. sin = emb.sin()
  52. return cos, sin
  53. class RotaryEmbedding(nn.Module):
  54. def __init__(self, config):
  55. super().__init__()
  56. base, head_dim = config.rope_theta, config.hidden_size // config.num_attention_heads
  57. inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim))
  58. self.register_buffer("inv_freq", inv_freq, persistent=False) # (dim // 2)
  59. @torch.no_grad()
  60. def forward(self, x, position_ids):
  61. B, T = position_ids.shape
  62. inv_freq_expanded = self.inv_freq[None, :, None].expand(B, -1, 1) # (B, dim // 2, 1)
  63. position_ids_expanded = position_ids[:, None, :].float() # (B, 1, T)
  64. # (B, dim // 2, T) -- transpose --> (B, T, dim // 2)
  65. freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
  66. emb = torch.cat((freqs, freqs), dim=-1) # (B, T, dim)
  67. cos = emb.cos()
  68. sin = emb.sin()
  69. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  70. def rotate_half(x):
  71. """Rotates half the hidden dims of the input."""
  72. x1 = x[..., : x.shape[-1] // 2]
  73. x2 = x[..., x.shape[-1] // 2 :]
  74. return torch.cat((-x2, x1), dim=-1)
  75. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  76. cos = cos.unsqueeze(unsqueeze_dim)
  77. sin = sin.unsqueeze(unsqueeze_dim)
  78. q_embed = (q * cos) + (rotate_half(q) * sin)
  79. k_embed = (k * cos) + (rotate_half(k) * sin)
  80. return q_embed, k_embed