| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import os
- import sys
- import mlflow
- import torch
- import torch.nn.functional as F
- import pathlib
- import numpy as np
- from tqdm import tqdm
- from hydra.core.hydra_config import HydraConfig
- from .adapters import *
def get_memmap_dataset(path, dtype=np.int32):
    """Open the token file at ``path`` as a read-only NumPy memory map.

    Args:
        path: Location of the raw binary file to map.
        dtype: Element type used to interpret the file's bytes
            (``np.int32`` by default).

    Returns:
        The ``np.memmap`` opened with ``mode="r"``. Per the original
        author's note, the file holds a single flat sequence of token ids.
    """
    token_ids = np.memmap(path, mode="r", dtype=dtype)
    return token_ids
def get_batch(memmap_arr, batch_size, context_length):
    """Sample a random batch of (input, target) windows for next-token prediction.

    Each sample is a window of ``context_length`` consecutive tokens starting
    at a uniformly drawn offset; its target is the same window shifted one
    token to the right.

    Args:
        memmap_arr: 1-D sequence of token ids supporting ``len`` and slicing
            (e.g. an ``np.memmap`` or ``np.ndarray``).
        batch_size: Number of windows to draw.
        context_length: Number of tokens in every window.

    Returns:
        A tuple ``(x, y)`` of ``torch.long`` tensors, each of shape
        ``(batch_size, context_length)``, with ``y[b, t]`` being the token
        that follows ``x[b, t]`` in ``memmap_arr``.

    Raises:
        ValueError: If ``memmap_arr`` holds fewer than ``context_length + 1``
            tokens, i.e. not even one input/target pair fits.
    """
    # A start offset i is valid while the target slice still fits, i.e.
    # i + context_length <= N - 1, giving N - context_length possible starts.
    num_starts = len(memmap_arr) - context_length
    if num_starts <= 0:
        raise ValueError(
            f"need at least context_length + 1 = {context_length + 1} tokens, "
            f"got {len(memmap_arr)}"
        )

    # `high` is exclusive in np.random.randint, so pass the number of valid
    # starts directly; subtracting one more would never sample the last window.
    starts = np.random.randint(0, num_starts, size=(batch_size,))
    x = np.stack([memmap_arr[i:i + context_length] for i in starts])
    y = np.stack([memmap_arr[i + 1:i + context_length + 1] for i in starts])
    return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
def memmap_val_iterator(memmap_arr, batch_size, context_length):
    """Yield deterministic (input, target) batches walking ``memmap_arr`` in order.

    Window start offsets are consecutive integers: batch ``b`` contains the
    windows starting at ``b * batch_size`` through ``(b + 1) * batch_size - 1``.
    Neighbouring windows are therefore offset by a single token and overlap
    heavily. Only full batches are produced; a trailing partial batch is dropped.

    Args:
        memmap_arr: 1-D sequence of token ids supporting ``len`` and slicing.
        batch_size: Number of windows per yielded batch.
        context_length: Number of tokens in every window.

    Yields:
        Tuples ``(x, y)`` of ``torch.long`` tensors of shape
        ``(batch_size, context_length)``, where ``y`` is ``x`` shifted one
        token to the right in ``memmap_arr``.
    """
    # Start i is valid while i + context_length <= N - 1, so there are
    # N - context_length usable starts; counting one fewer would drop the
    # last window (and a whole batch when the starts divide evenly).
    num_starts = len(memmap_arr) - context_length
    num_batches = num_starts // batch_size

    for batch_index in range(num_batches):
        first = batch_index * batch_size
        starts = range(first, first + batch_size)
        x = np.stack([memmap_arr[i:i + context_length] for i in starts])
        y = np.stack([memmap_arr[i + 1:i + context_length + 1] for i in starts])
        yield torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
def _run_validation(model, device, val_data, args):
    """Return the mean cross-entropy over the leading validation batches.

    Switches ``model`` to eval mode and, with gradients disabled, evaluates
    batches from ``memmap_val_iterator`` until ``args.val_batches`` batches
    have been processed or the iterator is exhausted.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        batches = memmap_val_iterator(val_data, args.batch_size, args.context_length)
        for seen, (x_val, y_val) in enumerate(batches, start=1):
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            val_logits, _ = model(x_val)
            vocab_dim = val_logits.shape[-1]
            batch_loss = F.cross_entropy(
                val_logits.reshape(-1, vocab_dim),
                y_val.reshape(-1),
            )
            batch_losses.append(batch_loss.item())
            # Checked after the batch is processed, so at least one batch is
            # always evaluated when the iterator yields anything.
            if seen >= args.val_batches:
                break
    return np.mean(batch_losses)


def train(model, device, args):
    """Run the training loop with periodic validation and checkpointing.

    Args:
        model: Module whose call ``model(x)`` returns a 2-tuple; the first
            element is used as logits with the class dimension last.
        device: Device the sampled batches are moved to.
        args: Configuration object. Attributes read here: ``save_path``,
            ``train_data_path``, ``val_data_path``, ``lr``, ``weight_decay``,
            ``resume_checkpoint``, ``train_steps``, ``batch_size``,
            ``context_length``, ``clip_grad_norm``, ``min_lr``,
            ``warmup_iters``, ``cosine_iters``, ``val_interval``,
            ``val_batches`` and ``save_interval``.

    Side effects:
        Creates ``args.save_path``, logs ``loss``, ``lr`` and ``val_loss`` to
        MLflow, prints progress messages, and writes checkpoints via
        ``run_save_checkpoint``.
    """
    os.makedirs(args.save_path, exist_ok=True)

    # Load the datasets.
    train_data = get_memmap_dataset(args.train_data_path)
    val_data = get_memmap_dataset(args.val_data_path)

    # Build the optimizer.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # Restore from a checkpoint when one is requested.
    start_iter = 0
    if args.resume_checkpoint:
        print(f"Resuming from checkpoint {args.resume_checkpoint}")
        # NOTE(review): the resume path is rooted at Hydra's runtime output
        # dir, while checkpoints below are written relative to the current
        # working directory -- confirm these resolve to the same location.
        hydra_output_dir = pathlib.Path(HydraConfig.get().runtime.output_dir)
        resume_ckpt_path = hydra_output_dir / f"{args.save_path}/ckpt_iter{args.resume_checkpoint}.pt"
        # The loader's return value is used as the iteration to resume from.
        start_iter = run_load_checkpoint(resume_ckpt_path, model, optimizer)
        print(f"Resumed at iteration {start_iter} from path {resume_ckpt_path}")

    # Training loop.
    progress = tqdm(range(start_iter, args.train_steps), desc="Training", leave=False)
    for iteration in progress:
        completed = iteration + 1

        # Validation may have left the model in eval mode on the previous pass.
        model.train()
        inputs, targets = get_batch(train_data, args.batch_size, args.context_length)
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits, _ = model(inputs)
        flat_logits = logits.reshape(-1, logits.shape[-1])
        loss = F.cross_entropy(flat_logits, targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        run_gradient_clipping(model.parameters(), args.clip_grad_norm)

        # Update the learning rate on every parameter group before stepping.
        lr = run_get_lr_cosine_schedule(
            iteration,
            args.lr,
            args.min_lr,
            args.warmup_iters,
            args.cosine_iters,
        )
        for group in optimizer.param_groups:
            group['lr'] = lr
        optimizer.step()

        loss_value = loss.item()
        progress.set_postfix(loss=loss_value, lr=lr)
        mlflow.log_metric("loss", loss_value, step=iteration)
        mlflow.log_metric("lr", lr, step=iteration)

        # Validation.
        if completed % args.val_interval == 0:
            val_loss_mean = _run_validation(model, device, val_data, args)
            mlflow.log_metric("val_loss", val_loss_mean, step=iteration)
            print(f"iter {completed:05d}: VALID loss = {val_loss_mean:.4f}")

        # Checkpointing.
        if completed % args.save_interval == 0:
            ckpt_name = os.path.join(args.save_path, f"ckpt_iter{completed}.pt")
            run_save_checkpoint(model, optimizer, completed, ckpt_name)
            print(f"Checkpoint saved to {ckpt_name}")
|