import numpy as np from tqdm import tqdm import torch from torch.cuda.amp import autocast as autocast from sklearn.metrics import confusion_matrix from utils import save_imgs def train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, step, logger, config, writer): ''' train model for one epoch ''' # switch to train mode model.train() loss_list = [] for iter, data in enumerate(train_loader): step += iter optimizer.zero_grad() images, targets = data images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() out = model(images) loss = criterion(out, targets) loss.backward() optimizer.step() loss_list.append(loss.item()) now_lr = optimizer.state_dict()['param_groups'][0]['lr'] writer.add_scalar('loss', loss, global_step=step) if iter % config.print_interval == 0: log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' print(log_info) logger.info(log_info) scheduler.step() return step def val_one_epoch(test_loader, model, criterion, epoch, logger, config): # switch to evaluate mode model.eval() preds = [] gts = [] loss_list = [] with torch.no_grad(): for data in tqdm(test_loader): img, msk = data img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() out = model(img) loss = criterion(out, msk) loss_list.append(loss.item()) gts.append(msk.squeeze(1).cpu().detach().numpy()) if type(out) is tuple: out = out[0] out = out.squeeze(1).cpu().detach().numpy() preds.append(out) if epoch % config.val_interval == 0: preds = np.array(preds).reshape(-1) gts = np.array(gts).reshape(-1) y_pre = np.where(preds>=config.threshold, 1, 0) y_true = np.where(gts>=0.5, 1, 0) confusion = confusion_matrix(y_true, y_pre) TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' print(log_info) logger.info(log_info) else: log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' print(log_info) logger.info(log_info) return np.mean(loss_list) def test_one_epoch(test_loader, model, criterion, logger, config, test_data_name=None): # switch to evaluate mode model.eval() preds = [] gts = [] loss_list = [] with torch.no_grad(): for i, data in enumerate(tqdm(test_loader)): img, msk = data img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() out = model(img) loss = criterion(out, msk) loss_list.append(loss.item()) msk = msk.squeeze(1).cpu().detach().numpy() gts.append(msk) if type(out) is tuple: out = out[0] out = out.squeeze(1).cpu().detach().numpy() preds.append(out) if i % config.save_interval == 0: save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) preds = np.array(preds).reshape(-1) gts = np.array(gts).reshape(-1) y_pre = np.where(preds>=config.threshold, 1, 0) y_true = np.where(gts>=0.5, 1, 0) confusion = confusion_matrix(y_true, y_pre) TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 if test_data_name is not None: log_info = f'test_datasets_name: {test_data_name}' print(log_info) logger.info(log_info) log_info = f'test of best model, loss: {np.mean(loss_list):.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' print(log_info) logger.info(log_info) return np.mean(loss_list)