Tan_pytorch_segmentation/pytorch_segmentation/PV_VM-UNet-main/engine.py

161 lines
5.6 KiB
Python
Raw Permalink Normal View History

2025-05-19 20:48:24 +08:00
import numpy as np
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import confusion_matrix
from utils import save_imgs
def train_one_epoch(train_loader,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    step,
                    logger,
                    config,
                    writer):
    '''
    Train the model for one epoch.

    Args:
        train_loader: iterable yielding ``(images, targets)`` batches; both
            items are moved to the GPU and cast to float before use.
        model: network to optimise; it is switched to train mode first.
        criterion: callable ``criterion(out, targets)`` returning the loss.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once, after the last batch.
        epoch: epoch number, used only in the log messages.
        step: global step under which the first batch of this epoch is
            written to ``writer``.
        logger: object whose ``info`` method receives the progress lines.
        config: object providing ``print_interval`` (batches between two
            progress lines).
        writer: object with an ``add_scalar`` method; receives the loss of
            every batch under the tag ``'loss'``.

    Returns:
        The global step to pass to the next call, i.e. ``step`` plus the
        number of batches processed in this epoch.
    '''
    # switch to train mode
    model.train()

    loss_list = []

    for batch_idx, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda(non_blocking=True).float()
        targets = targets.cuda(non_blocking=True).float()

        out = model(images)
        loss = criterion(out, targets)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_list.append(loss_value)

        writer.add_scalar('loss', loss_value, global_step=step)
        # Advance by exactly one per batch. Adding the batch index instead
        # would reuse the previous epoch's last step for the first batch and
        # space the remaining points unevenly.
        step += 1

        if batch_idx % config.print_interval == 0:
            # Read the learning rate only when it is actually reported.
            now_lr = optimizer.param_groups[0]['lr']
            log_info = f'train: epoch {epoch}, iter:{batch_idx}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
            print(log_info)
            logger.info(log_info)

    scheduler.step()
    return step
def val_one_epoch(test_loader,
                  model,
                  criterion,
                  epoch,
                  logger,
                  config):
    '''
    Evaluate the model on a loader and log the loss and, periodically, metrics.

    The mean loss is logged on every call. When ``epoch`` is a multiple of
    ``config.val_interval`` the outputs are additionally binarised with
    ``config.threshold`` (masks with 0.5) and pixel-wise accuracy,
    sensitivity, specificity, F1/Dice and IoU of the positive class (logged
    as ``miou``) are derived from the confusion matrix.

    Args:
        test_loader: iterable yielding ``(img, msk)`` batches.
        model: network to evaluate; it is switched to eval mode first. If it
            returns a tuple, only the first element is scored.
        criterion: callable ``criterion(out, msk)`` returning the loss.
        epoch: current epoch number.
        logger: object whose ``info`` method receives the summary line.
        config: object providing ``val_interval`` and ``threshold``.

    Returns:
        The mean of the per-batch losses.
    '''
    # switch to evaluate mode
    model.eval()

    # Outputs only need to be copied to the CPU on epochs that report metrics.
    compute_metrics = epoch % config.val_interval == 0

    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for img, msk in tqdm(test_loader):
            img = img.cuda(non_blocking=True).float()
            msk = msk.cuda(non_blocking=True).float()

            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())

            if not compute_metrics:
                continue

            if isinstance(out, tuple):
                out = out[0]
            # Flatten per batch so that batches of different size (e.g. a
            # smaller last batch) can still be joined afterwards.
            gts.append(msk.squeeze(1).cpu().detach().numpy().reshape(-1))
            preds.append(out.squeeze(1).cpu().detach().numpy().reshape(-1))

    mean_loss = np.mean(loss_list)

    if compute_metrics:
        # np.concatenate rejects an empty list, hence the explicit fallback.
        preds = np.concatenate(preds) if preds else np.zeros(0)
        gts = np.concatenate(gts) if gts else np.zeros(0)

        y_pre = np.where(preds >= config.threshold, 1, 0)
        y_true = np.where(gts >= 0.5, 1, 0)

        # Fix the label set so the matrix is always 2x2, even when only one
        # class occurs in both the masks and the predictions.
        confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
        TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

        total = float(np.sum(confusion))
        accuracy = float(TN + TP) / total if total != 0 else 0
        sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
        specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
        f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
        miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

        # Adjacent f-strings instead of a backslash continuation, so the
        # message text does not pick up the source indentation.
        log_info = (
            f'val epoch: {epoch}, loss: {mean_loss:.4f}, miou: {miou}, '
            f'f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, '
            f'specificity: {specificity}, sensitivity: {sensitivity}, '
            f'confusion_matrix: {confusion}'
        )
    else:
        log_info = f'val epoch: {epoch}, loss: {mean_loss:.4f}'

    print(log_info)
    logger.info(log_info)

    return mean_loss
def test_one_epoch(test_loader,
                   model,
                   criterion,
                   logger,
                   config,
                   test_data_name=None):
    '''
    Run the model over a test loader, save sample outputs and log metrics.

    Outputs are binarised with ``config.threshold`` (masks with 0.5) and
    pixel-wise accuracy, sensitivity, specificity, F1/Dice and IoU of the
    positive class (logged as ``miou``) are derived from the confusion
    matrix. Every ``config.save_interval``-th batch is handed to
    ``save_imgs`` with the target directory ``config.work_dir + 'outputs/'``.

    Args:
        test_loader: iterable yielding ``(img, msk)`` batches.
        model: network to evaluate; it is switched to eval mode first. If it
            returns a tuple, only the first element is scored and saved.
        criterion: callable ``criterion(out, msk)`` returning the loss.
        logger: object whose ``info`` method receives the summary lines.
        config: object providing ``save_interval``, ``work_dir``,
            ``datasets`` and ``threshold``.
        test_data_name: optional dataset name; when given it is logged and
            forwarded to ``save_imgs``.

    Returns:
        The mean of the per-batch losses.
    '''
    # switch to evaluate mode
    model.eval()

    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for i, (img, msk) in enumerate(tqdm(test_loader)):
            img = img.cuda(non_blocking=True).float()
            msk = msk.cuda(non_blocking=True).float()

            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())

            if isinstance(out, tuple):
                out = out[0]
            msk = msk.squeeze(1).cpu().detach().numpy()
            out = out.squeeze(1).cpu().detach().numpy()

            # Flatten per batch so that batches of different size (e.g. a
            # smaller last batch) can still be joined afterwards.
            gts.append(msk.reshape(-1))
            preds.append(out.reshape(-1))

            if i % config.save_interval == 0:
                save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets,
                          config.threshold, test_data_name=test_data_name)

    # np.concatenate rejects an empty list, hence the explicit fallback.
    preds = np.concatenate(preds) if preds else np.zeros(0)
    gts = np.concatenate(gts) if gts else np.zeros(0)

    y_pre = np.where(preds >= config.threshold, 1, 0)
    y_true = np.where(gts >= 0.5, 1, 0)

    # Fix the label set so the matrix is always 2x2, even when only one class
    # occurs in both the masks and the predictions.
    confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
    TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

    total = float(np.sum(confusion))
    accuracy = float(TN + TP) / total if total != 0 else 0
    sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
    specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
    f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
    miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

    if test_data_name is not None:
        log_info = f'test_datasets_name: {test_data_name}'
        print(log_info)
        logger.info(log_info)

    mean_loss = np.mean(loss_list)
    # Adjacent f-strings instead of a backslash continuation, so the message
    # text does not pick up the source indentation.
    log_info = (
        f'test of best model, loss: {mean_loss:.4f},miou: {miou}, '
        f'f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, '
        f'specificity: {specificity}, sensitivity: {sensitivity}, '
        f'confusion_matrix: {confusion}'
    )
    print(log_info)
    logger.info(log_info)

    return mean_loss