"""Idealisan — PyTorch model-training template framework."""

import numpy as np
import shutil
from torch import nn
import torch

class EarlyStopping():
    """Stop training when validation loss has not improved for `patience` epochs.

    Call the instance once per epoch with the current validation loss; it
    tracks the best score seen so far, snapshots the model weights on every
    improvement, and sets `early_stop` to True after `patience` consecutive
    non-improving epochs.
    """

    def __init__(self, patience=10, verbose=False, delta=0, checkpoint_path=None):
        """
        patience: number of non-improving epochs tolerated before stopping
        verbose: print a message whenever a new best checkpoint is recorded
        delta: minimum decrease in validation loss to count as an improvement
        checkpoint_path: file path for torch.save; if None, only the
            in-memory copy of the best weights is kept
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0                 # consecutive epochs without improvement
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf       # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.checkpoint_path = checkpoint_path
        self.best_weights = None         # in-memory copy of the best state_dict

    def __call__(self, val_loss, model):
        # Higher score == lower validation loss.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, self.checkpoint_path)
        elif score < self.best_score + self.delta:
            # Not enough improvement this epoch.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, self.checkpoint_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Record the model's weights as the new best; also write to `path` if given."""
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # Clone each tensor: state_dict() returns references that later
        # optimizer steps mutate in place, so a plain assignment would not
        # actually preserve the best weights.
        self.best_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # The original unconditionally called torch.save, which crashed when
        # checkpoint_path was left at its default of None.
        if path is not None:
            torch.save(model.state_dict(), path)
        self.val_loss_min = val_loss

    def restore_best_model(self, model):
        """Load the best in-memory weights back into `model` and return it."""
        if self.best_weights is None:
            raise RuntimeError(
                "no checkpoint recorded yet; call the EarlyStopping instance "
                "with a validation loss first")
        model.load_state_dict(self.best_weights)
        print("restore best model from in memory checkpoint.")
        return model

class Framework:
    """Minimal PyTorch training/validation/prediction loop wrapper.

    Collects per-batch and per-epoch losses during `fit`, and raw model
    outputs / labels (as numpy arrays) during `predict`.
    """

    def __init__(self):
        self.model = None
        self.train_loss = []          # per-batch training losses
        self.train_epochs_loss = []   # per-epoch mean training loss
        self.val_loss = []            # per-batch validation losses
        self.val_epochs_loss = []     # per-epoch mean validation loss
        self.train_acc = []
        self.val_acc = []
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.optimizer = None
        self.criterion = None
        self.epochs = 0
        self.device = None
        self.test_output = []
        self.test_label = []

    def fit(self, model, train_loader, val_loader, epochs, optimizer, criterion, device='cpu', init_weight=False, early_stopping=None, lr_adjust=None):
        '''
        Train `model` for up to `epochs` epochs, validating after each epoch.

        model: model to train
        train_loader: train data loader
        val_loader: validation data loader
        epochs: number of epochs
        optimizer: optimizer
        criterion: loss function
        device: cpu or cuda:0 or cuda:1
        init_weight: whether to (re-)initialize nn.Linear layers (Xavier uniform)
        early_stopping: optional EarlyStopping-like callable; training stops
            once its `early_stop` attribute becomes True
        lr_adjust: optional {epoch: lr} dict applied at the end of that epoch

        Returns a dict of per-batch and per-epoch train/val loss histories.
        '''
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        if init_weight:
            def __init_weights(m):
                if isinstance(m, nn.Linear):
                    # xavier_uniform (no trailing underscore) is deprecated;
                    # the in-place variant is the supported API.
                    torch.nn.init.xavier_uniform_(m.weight)
                    m.bias.data.fill_(0.01)
            self.model.apply(__init_weights)

        model.to(self.device)

        for epoch in range(epochs):
            screen_width = shutil.get_terminal_size().columns
            # =====================train============================
            model.train()
            train_epoch_loss = []
            for idx, (data_x, data_y) in enumerate(train_loader, 0):
                data_x = data_x.to(device)
                data_y = data_y.to(device)
                outputs = model(data_x)
                optimizer.zero_grad()
                loss = criterion(outputs, data_y)
                loss.backward()
                optimizer.step()
                train_epoch_loss.append(loss.item())
                self.train_loss.append(loss.item())

                print("epoch={}/{},{}/{} batches of train, loss={:.6f}".format(
                    epoch, epochs, idx, len(train_loader), loss.item()).ljust(screen_width), end='\r')

            self.train_epochs_loss.append(np.average(train_epoch_loss))

            # =====================valid============================
            model.eval()
            valid_epoch_loss = []
            # no_grad: validation needs no autograd graph (saves memory/time)
            with torch.no_grad():
                for idx, (data_x, data_y) in enumerate(val_loader, 0):
                    data_x = data_x.to(device)
                    data_y = data_y.to(device)
                    outputs = model(data_x)
                    loss = criterion(outputs, data_y)
                    valid_epoch_loss.append(loss.item())
                    self.val_loss.append(loss.item())
                    print("epoch={}/{},{}/{} batches of val, loss={:.6f}".format(
                        epoch, epochs, idx, len(val_loader), loss.item()).ljust(screen_width), end='\r')
            self.val_epochs_loss.append(np.average(valid_epoch_loss))
            # Single clean summary line (the original chained
            # .ljust(...).ljust(50).strip(), which cancelled itself out).
            print("epoch={}/{}, epoch train loss={:.6f} ,".format(
                epoch, epochs, np.average(train_epoch_loss)),
                "epoch val loss={:.6f}".format(np.average(valid_epoch_loss)))
            # ==================early stopping======================
            if early_stopping is not None:
                early_stopping(self.val_epochs_loss[-1], model=model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break
            # ====================adjust lr=========================
            if lr_adjust is not None:
                if epoch in lr_adjust.keys():
                    lr = lr_adjust[epoch]
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Updating learning rate to {}'.format(lr))
        return {"train_loss": self.train_loss, "train_epochs_loss": self.train_epochs_loss, "val_loss": self.val_loss, "val_epochs_loss": self.val_epochs_loss}

    def predict(self, model, test_loader, device='cpu'):
        """Run `model` over `test_loader`; returns per-batch output and
        label lists (numpy arrays)."""
        model.eval()
        self.device = device
        model.to(self.device)
        self.test_loader = test_loader
        # Reset accumulators so repeated predict() calls don't keep
        # appending stale batches from earlier runs.
        self.test_output = []
        self.test_label = []
        with torch.no_grad():
            for idx, (data_x, data_y) in enumerate(test_loader, 0):
                data_x = data_x.to(self.device)
                data_y = data_y.to(self.device)
                outputs = model(data_x)
                self.test_output.append(outputs.cpu().numpy())
                self.test_label.append(data_y.cpu().numpy())
                print("predicting {}/{} batches of test".format(idx, len(test_loader)
                                                                ).ljust(shutil.get_terminal_size().columns), end='\r')
        return {"test_output": self.test_output, "test_label": self.test_label}

    def single_predict(self, model, data_x, device='cpu'):
        """Predict on a single input batch tensor; returns a numpy array."""
        model.eval()
        self.device = device
        model.to(self.device)
        data_x = data_x.to(self.device)
        with torch.no_grad():
            outputs = model(data_x)
        return outputs.cpu().numpy()


def accuracy(y_batch, y_hat_batch):
    """Fraction of samples whose argmax prediction matches the label."""
    predictions = np.argmax(y_hat_batch, axis=-1)
    correct = np.sum(y_batch == predictions)
    return correct / y_batch.shape[0]
    
# (Scraped blog-page footer removed: category links and comment-form
# boilerplate were not part of the source code and broke Python syntax.)