| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- import os
- import math
- import json
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from types import SimpleNamespace
- from transformers import AutoModelForCausalLM
- from .basics import RMSNorm, SwiGLU, apply_rotary_pos_emb, _compute_rope_params, repeat
class CasualGroupQueryAttention(nn.Module):
    """Causal self-attention in which several query heads share one K/V head.

    The input is projected to ``n_head`` query heads and ``n_group`` key/value
    heads, each of width ``n_embd // n_head``. Keys and values are passed
    through ``repeat`` with factor ``n_head // n_group`` before the attention
    product, and a lower-triangular mask stops every position from attending
    to later positions.

    Args:
        n_embd: Model (embedding) width.
        n_head: Number of query heads; must divide ``n_embd``.
        n_group: Number of key/value heads; must divide ``n_head``.
        proj_bias: Whether the output projection has a bias term. The query
            and key/value projections always use ``nn.Linear``'s default bias.

    Raises:
        ValueError: If the head counts are not positive or do not divide
            evenly as described above.
    """

    def __init__(self, n_embd, n_head, n_group, proj_bias=False):
        # Fail early with a readable message instead of an opaque shape error
        # in the first forward pass.
        if n_head <= 0 or n_group <= 0:
            raise ValueError(
                f"n_head and n_group must be positive, got {n_head} and {n_group}"
            )
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )
        if n_head % n_group != 0:
            raise ValueError(
                f"n_head ({n_head}) must be divisible by n_group ({n_group})"
            )

        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_group = n_group
        self.hs = n_embd // n_head  # per-head width
        self.q_proj = nn.Linear(n_embd, n_embd)
        # Keys and values come out of one fused projection and are split later.
        self.kv_proj = nn.Linear(n_embd, 2 * self.n_group * self.hs)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=proj_bias)

    def forward(self, x, cos=None, sin=None):
        """Apply causal grouped-query attention to ``x``.

        Args:
            x: Input of shape ``(B, T, D)``.
            cos: Optional rotary cosine table forwarded to
                ``apply_rotary_pos_emb``.
            sin: Optional rotary sine table forwarded to
                ``apply_rotary_pos_emb``.

        Returns:
            Tensor of shape ``(B, T, D)``.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.hs).transpose(1, 2)  # (B, N, T, H)
        k, v = self.kv_proj(x).chunk(2, dim=-1)  # each (B, T, N_G * H)
        k = k.view(B, T, self.n_group, self.hs).transpose(1, 2)  # (B, N_G, T, H)
        v = v.view(B, T, self.n_group, self.hs).transpose(1, 2)

        # NOTE(review): `repeat` lives in .basics; it presumably expands the
        # N_G key/value heads so they line up with the N query heads - confirm.
        n_rep = self.n_head // self.n_group
        k = repeat(k, n_rep)
        v = repeat(v, n_rep)

        # Rotary embedding is applied only when both tables are supplied.
        if cos is not None and sin is not None:
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn = q @ k.transpose(-1, -2) / math.sqrt(self.hs)

        # Boolean "future position" mask built directly on the score tensor's
        # device: entry (i, j) is True when key j lies after query i.
        pos = torch.arange(T, device=attn.device)
        future = pos.unsqueeze(0) > pos.unsqueeze(1)  # (T, T), broadcasts over (B, N)
        attn = attn.masked_fill(future, float('-inf'))
        attn = F.softmax(attn, dim=-1)  # (B, N, T, T)

        o = attn @ v  # (B, N, T, H)
        o = o.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(o)
class Block(nn.Module):
    """One pre-norm transformer layer: attention, then a SwiGLU feed-forward.

    Both sub-layers are wrapped in a residual connection and each one is fed
    an ``RMSNorm``-normalised copy of its input.

    Args:
        n_embd: Model (embedding) width.
        n_head: Number of query heads handed to the attention module.
        n_group: Number of key/value heads handed to the attention module.
        immediate_dim: Hidden width handed to ``SwiGLU``.
    """

    def __init__(self, n_embd, n_head, n_group, immediate_dim):
        super().__init__()
        # Registration order is kept as attn, ffn, ln1, ln2 so that the
        # state_dict key order does not change.
        self.attn = CasualGroupQueryAttention(n_embd, n_head, n_group)
        self.ffn = SwiGLU(n_embd, immediate_dim)
        self.ln1 = RMSNorm(n_embd)
        self.ln2 = RMSNorm(n_embd)

    def forward(self, x, cos=None, sin=None):
        """Run the layer on ``x``; ``cos``/``sin`` are passed on to attention."""
        # Attention sub-layer (pre-norm) with residual connection.
        attn_out = self.attn(self.ln1(x), cos, sin)
        hidden = x + attn_out
        # Feed-forward sub-layer (pre-norm) with residual connection.
        return hidden + self.ffn(self.ln2(hidden))
class Qwen2_5(nn.Module):
    """Decoder-only language model assembled from a stack of ``Block`` layers.

    The constructor takes a mapping of configuration values and reads these
    keys from it: ``max_position_embeddings``, ``vocab_size``,
    ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
    ``intermediate_size``, ``num_hidden_layers`` and (optionally)
    ``tie_word_embeddings``. The whole configuration is also handed to
    ``_compute_rope_params`` to obtain the cached cos/sin tables.
    """

    def __init__(self, config_dict):
        super().__init__()
        config = SimpleNamespace(**config_dict)
        # Longest context `generate` feeds to the model in one forward pass.
        self.block_size = config.max_position_embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [
                Block(
                    n_embd=config.hidden_size,
                    n_head=config.num_attention_heads,
                    n_group=config.num_key_value_heads,
                    immediate_dim=config.intermediate_size,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.last_norm = RMSNorm(embed_dim=config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # A missing flag is treated as "not tied" instead of raising.
        if getattr(config, "tie_word_embeddings", False):
            # Input embedding and output projection share one parameter.
            self.embed_tokens.weight = self.lm_head.weight
        cos_cached, sin_cached = _compute_rope_params(config)
        self.register_buffer('cos_cached', cos_cached)
        self.register_buffer('sin_cached', sin_cached)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, eos_token_id=None, temperature=1.0, top_k=50):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens``.

        Args:
            input_ids: Integer tensor of shape ``(B, T)`` holding the prompt.
            max_new_tokens: Upper bound on the number of appended tokens.
            eos_token_id: Optional end-of-sequence id. A row that has produced
                it is padded with the same id afterwards, and generation stops
                once every row has produced it.
            temperature: Divisor applied to the last-step logits before
                sampling. ``0`` selects greedy (argmax) decoding.
            top_k: If not ``None``, sampling is restricted to the ``top_k``
                highest-scoring tokens.

        Returns:
            Tensor of shape ``(B, T + n)`` with the prompt followed by the
            ``n <= max_new_tokens`` generated tokens.
        """
        idx = input_ids
        # Per-row flag, so batches larger than one can stop on EOS.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            # Keep only the most recent `block_size` tokens as context.
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            if temperature == 0:
                # Dividing by zero would give inf/nan logits; decode greedily.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Drop everything scoring below the k-th best logit.
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)

            idx = torch.cat((idx, idx_next), dim=1)

            if eos_token_id is not None:
                finished |= idx_next.squeeze(1) == eos_token_id
                if finished.all():
                    break
        return idx

    def forward(self, input_ids):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Integer tensor of shape ``(B, T)``.

        Returns:
            A ``(logits, None)`` pair; ``logits`` is the ``lm_head`` output
            for every position.
        """
        B, T = input_ids.shape
        x = self.embed_tokens(input_ids)  # (B, T, D)
        # The first T rows of the cached tables, broadcast over the batch.
        cos = self.cos_cached[:T, :].unsqueeze(0).expand(B, -1, -1)
        sin = self.sin_cached[:T, :].unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            x = layer(x, cos, sin)

        x = self.last_norm(x)
        logits = self.lm_head(x)
        return logits, None

    @classmethod
    def from_config(cls, config):  # <class 'omegaconf.dictconfig.DictConfig'>
        """Build a freshly initialised model from ``config`` and print its size."""
        model = cls(config)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total parameters: {total_params:,}")
        return model

    @classmethod
    def from_pretrained(cls, model_path):
        """Build a model and fill it with weights from a Hugging Face checkpoint.

        ``config.json`` inside ``model_path`` supplies the configuration. The
        checkpoint is loaded with ``AutoModelForCausalLM`` and its tensors are
        copied into this model's state dict after translating parameter names.
        The separate ``k_proj``/``v_proj`` tensors are concatenated along
        dim 0 to fill the fused ``kv_proj``.

        Args:
            model_path: Directory containing ``config.json`` and the weights.

        Returns:
            The populated model instance.
        """
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # `cls` (not a hard-coded class name) so subclasses load as themselves.
        model = cls(config)
        sd = model.state_dict()

        model_hf = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto"
        )
        sd_hf = model_hf.state_dict()

        # Local module name -> Hugging Face module name.
        key_map = {'embed_tokens': 'embed_tokens', 'attn': 'self_attn', 'q_proj': 'q_proj', 'o_proj': 'o_proj', 'ffn': 'mlp',
                   'ln1': 'input_layernorm', 'ln2': 'post_attention_layernorm', 'last_norm': 'norm'}

        def to_hf_key(key):
            """Translate one local state-dict key to its Hugging Face name."""
            # The output head is the only key without the 'model.' prefix.
            if key == 'lm_head.weight':
                return key
            return 'model.' + '.'.join(key_map.get(part, part) for part in key.split('.'))

        for key, target in sd.items():
            # The cos/sin buffers are computed locally and have no HF counterpart.
            if key in ('cos_cached', 'sin_cached'):
                continue
            hf_key = to_hf_key(key)
            if 'kv_proj' in hf_key:
                hf_key_k = hf_key.replace('kv_proj', 'k_proj')
                hf_key_v = hf_key.replace('kv_proj', 'v_proj')
                # Keys first, then values: the order `chunk(2, dim=-1)` expects.
                target.copy_(torch.cat((sd_hf[hf_key_k], sd_hf[hf_key_v]), dim=0))
            else:
                target.copy_(sd_hf[hf_key])

        return model
|