qwen2_5.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import math
  3. import json
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from types import SimpleNamespace
  8. from transformers import AutoModelForCausalLM
  9. from .basics import RMSNorm, SwiGLU, apply_rotary_pos_emb, _compute_rope_params, repeat
  10. class CasualGroupQueryAttention(nn.Module):
  11. def __init__(self, n_embd, n_head, n_group, proj_bias=False):
  12. super(CasualGroupQueryAttention, self).__init__()
  13. self.n_embd = n_embd
  14. self.n_head = n_head
  15. self.n_group = n_group
  16. self.hs = n_embd // n_head
  17. self.q_proj = nn.Linear(n_embd, n_embd)
  18. self.kv_proj = nn.Linear(n_embd, 2 * self.n_group * self.hs)
  19. self.o_proj = nn.Linear(n_embd, n_embd, bias=proj_bias)
  20. def forward(self, x, cos=None, sin=None):
  21. B, T, D = x.shape
  22. q = self.q_proj(x).view(B, T, self.n_head, self.hs).transpose(1, 2) # (B, N, T, H)
  23. k, v = self.kv_proj(x).chunk(2, dim=-1) # (B, T, N_G * H)
  24. k = k.view(B, T, self.n_group, self.hs).transpose(1, 2) # (B, N_G, T, H)
  25. v = v.view(B, T, self.n_group, self.hs).transpose(1, 2)
  26. k = repeat(k, self.n_head // self.n_group)
  27. v = repeat(v, self.n_head // self.n_group)
  28. if cos is not None and sin is not None:
  29. q, k = apply_rotary_pos_emb(q, k, cos, sin)
  30. attn = q @ k.transpose(-1,-2) / math.sqrt(self.hs)
  31. mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T).to(attn.device)
  32. attn = attn.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))
  33. attn = F.softmax(attn, dim=-1) # (B, N, T, T)
  34. o = attn @ v # (B, N, T, H)
  35. o = o.transpose(1,2).contiguous().view(B, T, D)
  36. o = self.o_proj(o)
  37. return o
  38. class Block(nn.Module):
  39. def __init__(self, n_embd, n_head, n_group, immediate_dim):
  40. super(Block, self).__init__()
  41. self.attn = CasualGroupQueryAttention(n_embd, n_head, n_group)
  42. self.ffn = SwiGLU(n_embd, immediate_dim)
  43. self.ln1 = RMSNorm(n_embd)
  44. self.ln2 = RMSNorm(n_embd)
  45. def forward(self, x, cos=None, sin=None):
  46. x = x + self.attn(self.ln1(x), cos, sin) # pre-norm
  47. x = x + self.ffn(self.ln2(x))
  48. return x
  49. class Qwen2_5(nn.Module):
  50. def __init__(self, config_dict):
  51. super(Qwen2_5, self).__init__()
  52. config = SimpleNamespace(**config_dict)
  53. self.block_size = config.max_position_embeddings
  54. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  55. self.layers = nn.ModuleList(
  56. [Block(n_embd=config.hidden_size, n_head=config.num_attention_heads, n_group=config.num_key_value_heads, immediate_dim=config.intermediate_size) for _ in range(config.num_hidden_layers)]
  57. )
  58. self.last_norm = RMSNorm(embed_dim=config.hidden_size)
  59. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  60. if config.tie_word_embeddings:
  61. self.embed_tokens.weight = self.lm_head.weight
  62. cos_cached, sin_cached = _compute_rope_params(config)
  63. self.register_buffer('cos_cached', cos_cached)
  64. self.register_buffer('sin_cached', sin_cached)
  65. @torch.no_grad()
  66. def generate(self, input_ids, max_new_tokens, eos_token_id=None, temperature=1.0, top_k=50):
  67. idx = input_ids
  68. for _ in range(max_new_tokens):
  69. idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
  70. logits, _ = self(idx_cond)
  71. logits = logits[:, -1, :] / temperature
  72. if top_k is not None:
  73. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  74. logits[logits < v[:, [-1]]] = -float('Inf')
  75. probs = F.softmax(logits, dim=-1)
  76. idx_next = torch.multinomial(probs, num_samples=1)
  77. idx = torch.cat((idx, idx_next), dim=1)
  78. if eos_token_id is not None:
  79. if idx_next.item() == eos_token_id:
  80. break
  81. return idx
  82. def forward(self, input_ids):
  83. B, T = input_ids.shape
  84. x = self.embed_tokens(input_ids) # (B, T, D)
  85. cos = self.cos_cached[:T, :].unsqueeze(0).expand(B, -1, -1)
  86. sin = self.sin_cached[:T, :].unsqueeze(0).expand(B, -1, -1)
  87. for layer in self.layers:
  88. x = layer(x, cos, sin)
  89. x = self.last_norm(x)
  90. logits = self.lm_head(x)
  91. return logits, None
  92. @classmethod
  93. def from_config(cls, config): # <class 'omegaconf.dictconfig.DictConfig'>
  94. model = cls(config)
  95. total_params = sum(p.numel() for p in model.parameters())
  96. print(f"Total parameters: {total_params:,}")
  97. return model
  98. @classmethod
  99. def from_pretrained(cls, model_path):
  100. config_path = os.path.join(model_path, "config.json")
  101. with open(config_path, "r") as f:
  102. config = json.load(f)
  103. model = Qwen2_5(config)
  104. sd = model.state_dict()
  105. sd_keys = sd.keys()
  106. model_hf = AutoModelForCausalLM.from_pretrained(
  107. model_path,
  108. torch_dtype="auto",
  109. device_map="auto"
  110. )
  111. sd_hf = model_hf.state_dict()
  112. key_map = {'embed_tokens': 'embed_tokens', 'attn': 'self_attn', 'q_proj': 'q_proj', 'o_proj': 'o_proj', 'ffn': 'mlp',
  113. 'ln1': 'input_layernorm', 'ln2': 'post_attention_layernorm', 'last_norm': 'norm'}
  114. def to_hf_key(key):
  115. components = key.split('.')
  116. for i, c in enumerate(components):
  117. if c in key_map.keys():
  118. components[i] = key_map[c]
  119. if not key == 'lm_head.weight':
  120. key = 'model.' + '.'.join(components)
  121. return key
  122. for key in sd_keys:
  123. if key in ['cos_cached', 'sin_cached']:
  124. continue
  125. hf_key = to_hf_key(key)
  126. if 'kv_proj' in hf_key:
  127. hf_key_k, hf_key_v = hf_key.replace('kv_proj', 'k_proj'), hf_key.replace('kv_proj', 'v_proj')
  128. sd[key].copy_(torch.concat((sd_hf[hf_key_k], sd_hf[hf_key_v]), dim=0))
  129. else:
  130. # print("=" * 20)
  131. # print(key, hf_key)
  132. # print(sd[key].shape, sd_hf[hf_key].shape)
  133. sd[key].copy_(sd_hf[hf_key])
  134. return model