```python
import copy

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm


def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader,
               epochs, iters_to_accumulate):
    best_loss = np.inf
    best_ep = 1
    nb_iterations = len(train_loader)
    print_every = nb_iterations // 5  # print the training loss 5 times per epoch
    iters = []          # collected for optional loss curves (unused below)
    train_losses = []
    val_losses = []

    scaler = GradScaler()  # gradient scaler for mixed-precision training

    for ep in range(epochs):
        net.train()
        running_loss = 0.0

        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):
            # Move the batch to the GPU; `device` is assumed to be defined globally
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)

            with autocast():  # forward pass in mixed precision
                logits = net(seq, attn_masks, token_type_ids)
                loss = criterion(logits.squeeze(-1), labels.float())
                # Normalize the loss because gradients are accumulated over several batches
                loss = loss / iters_to_accumulate

            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # Step the optimizer only after `iters_to_accumulate` backward passes
                scaler.step(opti)
                scaler.update()
                lr_scheduler.step()
                opti.zero_grad()

            running_loss += loss.item()

            if (it + 1) % print_every == 0:
                print()
                print(f"Iteration {it+1}/{nb_iterations} of epoch {ep+1} complete. "
                      f"Loss : {running_loss / print_every}")
                running_loss = 0.0

        val_loss = evaluate_loss(net, device, criterion, val_loader)
        print()
        print(f"Epoch {ep+1} complete! Validation Loss : {val_loss}")

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            net_copy = copy.deepcopy(net)  # keep a copy of the best model so far
            best_loss = val_loss
            best_ep = ep + 1

    # Save the best checkpoint; `bert_model` is assumed to be defined globally
    path_to_model = f'models/{bert_model}_lr_{lr}_val_loss_{round(best_loss, 5)}_ep_{best_ep}.pt'
    torch.save(net_copy.state_dict(), path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    del loss
    torch.cuda.empty_cache()


def evaluate_loss(net, device, criterion, dataloader):
    """Compute the mean loss over a dataloader."""
    net.eval()

    mean_loss = 0
    count = 0

    with torch.no_grad():
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(dataloader)):
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            mean_loss += criterion(logits.squeeze(-1), labels.float()).item()
            count += 1

    return mean_loss / count
```
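For context, here is a minimal sketch of how these two helpers might be wired together. The hyperparameter values, the `AdamW` optimizer, and the linear warmup schedule are illustrative assumptions, not taken from this code; `net`, `train_loader`, and `val_loader` are presumed to be built elsewhere.

```python
# Hypothetical driver code for train_bert (assumed values, not from the original post)
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

lr = 2e-5
epochs = 4
iters_to_accumulate = 2  # effective batch size = batch_size * iters_to_accumulate

opti = AdamW(net.parameters(), lr=lr, weight_decay=1e-2)

# One scheduler step happens per optimizer step, not per batch
num_training_steps = (len(train_loader) // iters_to_accumulate) * epochs
lr_scheduler = get_linear_schedule_with_warmup(
    opti,
    num_warmup_steps=int(0.1 * num_training_steps),  # warm up over the first 10% of steps
    num_training_steps=num_training_steps,
)

# BCEWithLogitsLoss matches the `logits.squeeze(-1)` vs `labels.float()` usage above
criterion = torch.nn.BCEWithLogitsLoss()

train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader,
           epochs, iters_to_accumulate)
```

Note that `num_training_steps` counts optimizer steps rather than batches, since `lr_scheduler.step()` is only called once every `iters_to_accumulate` iterations inside the training loop.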