#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : EarlyStop
@Author : qiqq
@create_time : 2022/11/5 11:39
"""
import os

import numpy as np
import torch


class EarlyStopping:
    """Stop training when a validation score stops improving.

    The monitored value is treated as "higher is better": a call counts as
    an improvement when ``val_score >= best_score + delta``. After
    ``patience`` consecutive calls without improvement, ``early_stop`` is
    set to ``True``; the caller is expected to check that flag and break
    out of the training loop.
    """

    def __init__(self, save_path=None, patience: int = 10,
                 verbose: bool = False, delta: float = 0):
        """
        Args:
            save_path: Directory for model checkpoints. It is only stored on
                the instance; the active code path never writes to it.
            patience (int): Number of consecutive non-improving calls
                tolerated before ``early_stop`` is set.
                Default: 10
            verbose (bool): If True, prints a message each time the
                validation score improves.
                Default: False
            delta (float): Minimum increase over the best score required to
                qualify as an improvement.
                Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Kept for backward compatibility with code that reads this
        # attribute. ``np.inf`` is used because the ``np.Inf`` alias was
        # removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_score):
        """Record one validation result and update the stopping state.

        Args:
            val_score: Latest validation score (higher is better).
        """
        score = val_score

        if score != score:
            # NaN compares unequal to itself. Without this guard a NaN would
            # be accepted as the new best (every ``<`` comparison is False)
            # and would then block early stopping for good.
            stalled = True
        elif self.best_score is None:
            # The first valid score always becomes the baseline.
            stalled = False
        else:
            stalled = score < self.best_score + self.delta

        if stalled:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        if self.verbose and self.best_score is not None:
            print(f'Validation score improved ({self.best_score} --> {score}).')
        self.best_score = score
        self.counter = 0


# ---------------------------------------------------------------------------
# Commented-out legacy code, kept for reference: a checkpoint-saving helper
# and an earlier loss-based variant of the class (score = -val_loss).
# ---------------------------------------------------------------------------
#
#     def save_checkpoint(self, val_loss, model):
#         '''Saves model when validation loss decrease.'''
#         if self.verbose:
#             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
#         path = os.path.join(self.save_path, 'best_network.pth')
#         torch.save(model.state_dict(), path)  # stores the parameters of the best model so far
#         self.val_loss_min = val_loss
#
# class EarlyStopping:
#     """Early stops the training if validation loss doesn't improve after a given patience."""
#     def __init__(self, save_path, patience=7, verbose=False, delta=0):
#         """
#         Args:
#             save_path : directory where the model is saved
#             patience (int): How long to wait after last time validation loss improved.
#                             Default: 7
#             verbose (bool): If True, prints a message for each validation loss improvement.
#                             Default: False
#             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
#                             Default: 0
#         """
#         self.save_path = save_path
#         self.patience = patience
#         self.verbose = verbose
#         self.counter = 0
#         self.best_score = None
#         self.early_stop = False
#         self.val_loss_min = np.Inf
#         self.delta = delta
#
#     def __call__(self, val_loss, model):
#
#         score = -val_loss
#
#         if self.best_score is None:
#             self.best_score = score
#             self.save_checkpoint(val_loss, model)
#         elif score < self.best_score + self.delta:
#             self.counter += 1
#             print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_score = score
#             self.save_checkpoint(val_loss, model)
#             self.counter = 0
#
#     def save_checkpoint(self, val_loss, model):
#         '''Saves model when validation loss decrease.'''
#         if self.verbose:
#             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
#         path = os.path.join(self.save_path, 'best_network.pth')
#         torch.save(model.state_dict(), path)  # stores the parameters of the best model so far
#         self.val_loss_min = val_loss