import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
import numpy as np
import os
import math
import random
import logging
import logging.handlers
from matplotlib import pyplot as plt

from scipy.ndimage import zoom
import SimpleITK as sitk
from medpy import metric


def set_seed(seed):
    # for hash
    os.environ['PYTHONHASHSEED'] = str(seed)
    # for python and numpy
    random.seed(seed)
    np.random.seed(seed)
    # for cpu and gpu
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # for cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

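# Usage sketch (hedged): call set_seed once at startup, before building the
# model and dataloaders, e.g. set_seed(42). Note that DataLoader worker
# processes need their own seeding (e.g. via a worker_init_fn) for data
# augmentation to be fully reproducible.
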
def get_logger(name, log_dir):
    '''
    Args:
        name(str): name of logger
        log_dir(str): path of log
    '''
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    info_name = os.path.join(log_dir, '{}.info.log'.format(name))
    info_handler = logging.handlers.TimedRotatingFileHandler(info_name,
                                                             when='D',
                                                             encoding='utf-8')
    info_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')

    info_handler.setFormatter(formatter)

    logger.addHandler(info_handler)

    return logger

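# Usage sketch (hedged; the log directory is hypothetical):
#
#     logger = get_logger('train', 'results/log')
#     logger.info('training started')
#
# Since logging.getLogger returns the same object for a given name, calling
# get_logger twice with the same name attaches a second handler and duplicates
# every record.
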
def log_config_info(config, logger):
    config_dict = config.__dict__
    log_info = '#----------Config info----------#'
    logger.info(log_info)
    for k, v in config_dict.items():
        if k[0] == '_':
            continue
        else:
            log_info = f'{k}: {v},'
            logger.info(log_info)

def get_optimizer(config, model):
    assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax',
                          'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'

    if config.opt == 'Adadelta':
        return torch.optim.Adadelta(
            model.parameters(),
            lr=config.lr,
            rho=config.rho,
            eps=config.eps,
            weight_decay=config.weight_decay
        )
    elif config.opt == 'Adagrad':
        return torch.optim.Adagrad(
            model.parameters(),
            lr=config.lr,
            lr_decay=config.lr_decay,
            eps=config.eps,
            weight_decay=config.weight_decay
        )
    elif config.opt == 'Adam':
        return torch.optim.Adam(
            model.parameters(),
            lr=config.lr,
            betas=config.betas,
            eps=config.eps,
            weight_decay=config.weight_decay,
            amsgrad=config.amsgrad
        )
    elif config.opt == 'AdamW':
        return torch.optim.AdamW(
            model.parameters(),
            lr=config.lr,
            betas=config.betas,
            eps=config.eps,
            weight_decay=config.weight_decay,
            amsgrad=config.amsgrad
        )
    elif config.opt == 'Adamax':
        return torch.optim.Adamax(
            model.parameters(),
            lr=config.lr,
            betas=config.betas,
            eps=config.eps,
            weight_decay=config.weight_decay
        )
    elif config.opt == 'ASGD':
        return torch.optim.ASGD(
            model.parameters(),
            lr=config.lr,
            lambd=config.lambd,
            alpha=config.alpha,
            t0=config.t0,
            weight_decay=config.weight_decay
        )
    elif config.opt == 'RMSprop':
        return torch.optim.RMSprop(
            model.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            alpha=config.alpha,
            eps=config.eps,
            centered=config.centered,
            weight_decay=config.weight_decay
        )
    elif config.opt == 'Rprop':
        return torch.optim.Rprop(
            model.parameters(),
            lr=config.lr,
            etas=config.etas,
            step_sizes=config.step_sizes
        )
    elif config.opt == 'SGD':
        return torch.optim.SGD(
            model.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            dampening=config.dampening,
            nesterov=config.nesterov
        )
    else:  # default opt is SGD (unreachable given the assert above)
        return torch.optim.SGD(
            model.parameters(),
            lr=0.01,
            momentum=0.9,
            weight_decay=0.05
        )

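# Usage sketch (hedged): `config` can be any object exposing the attributes the
# chosen branch reads; the values below are hypothetical:
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(opt='AdamW', lr=1e-3, betas=(0.9, 0.999),
#                              eps=1e-8, weight_decay=1e-2, amsgrad=False)
#     optimizer = get_optimizer(config, model)
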
def get_scheduler(config, optimizer):
    assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR',
                          'ReduceLROnPlateau', 'CosineAnnealingWarmRestarts',
                          'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
    if config.sch == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.step_size,
            gamma=config.gamma,
            last_epoch=config.last_epoch
        )
    elif config.sch == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=config.milestones,
            gamma=config.gamma,
            last_epoch=config.last_epoch
        )
    elif config.sch == 'ExponentialLR':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=config.gamma,
            last_epoch=config.last_epoch
        )
    elif config.sch == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.T_max,
            eta_min=config.eta_min,
            last_epoch=config.last_epoch
        )
    elif config.sch == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config.mode,
            factor=config.factor,
            patience=config.patience,
            threshold=config.threshold,
            threshold_mode=config.threshold_mode,
            cooldown=config.cooldown,
            min_lr=config.min_lr,
            eps=config.eps
        )
    elif config.sch == 'CosineAnnealingWarmRestarts':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=config.T_0,
            T_mult=config.T_mult,
            eta_min=config.eta_min,
            last_epoch=config.last_epoch
        )
    elif config.sch == 'WP_MultiStepLR':
        # linear warm-up, then MultiStepLR-style step decay
        lr_func = lambda epoch: (epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs
                                 else config.gamma ** len([m for m in config.milestones if m <= epoch]))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
    elif config.sch == 'WP_CosineLR':
        # linear warm-up, then cosine decay to zero over the remaining epochs
        lr_func = lambda epoch: (epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs
                                 else 0.5 * (math.cos((epoch - config.warm_up_epochs)
                                                      / (config.epochs - config.warm_up_epochs) * math.pi) + 1))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)

    return scheduler

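# Usage sketch (hedged): most schedulers are stepped once per epoch;
# ReduceLROnPlateau instead expects a monitored metric:
#
#     scheduler = get_scheduler(config, optimizer)
#     for epoch in range(config.epochs):
#         train_one_epoch(...)          # hypothetical
#         if config.sch == 'ReduceLROnPlateau':
#             scheduler.step(val_loss)  # hypothetical validation loss
#         else:
#             scheduler.step()
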
def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
    img = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    img = img / 255. if img.max() > 1.1 else img
    if datasets == 'retinal':
        msk = np.squeeze(msk, axis=0)
        msk_pred = np.squeeze(msk_pred, axis=0)
    else:
        msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
        msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)

    plt.figure(figsize=(7, 15))

    plt.subplot(3, 1, 1)
    plt.imshow(img)
    plt.axis('off')

    plt.subplot(3, 1, 2)
    plt.imshow(msk, cmap='gray')
    plt.axis('off')

    plt.subplot(3, 1, 3)
    plt.imshow(msk_pred, cmap='gray')
    plt.axis('off')

    if test_data_name is not None:
        save_path = save_path + test_data_name + '_'
    plt.savefig(save_path + str(i) + '.png')
    plt.close()

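# Usage sketch (hedged; shapes inferred from the squeeze calls above): img is a
# (1, C, H, W) tensor, msk and msk_pred are (1, H, W) numpy arrays, and
# save_path acts as a filename prefix:
#
#     save_imgs(images[0:1], msk_np, pred_np, i, 'outputs/', 'isic18')
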
class BCELoss(nn.Module):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred, target):
        size = pred.size(0)
        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)

        return self.bceloss(pred_, target_)

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target):
        smooth = 1
        size = pred.size(0)

        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)
        intersection = pred_ * target_
        dice_score = (2 * intersection.sum(1) + smooth) / (pred_.sum(1) + target_.sum(1) + smooth)
        dice_loss = 1 - dice_score.sum() / size

        return dice_loss

class nDiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(nDiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), \
            'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes

class CeDiceLoss(nn.Module):
    def __init__(self, num_classes, loss_weight=[0.4, 0.6]):
        super(CeDiceLoss, self).__init__()
        self.celoss = nn.CrossEntropyLoss()
        self.diceloss = nDiceLoss(num_classes)
        self.loss_weight = loss_weight

    def forward(self, pred, target):
        loss_ce = self.celoss(pred, target[:].long())
        loss_dice = self.diceloss(pred, target, softmax=True)
        loss = self.loss_weight[0] * loss_ce + self.loss_weight[1] * loss_dice
        return loss

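# Usage sketch (hedged): CeDiceLoss takes raw logits of shape (B, C, H, W) and
# integer class labels of shape (B, H, W); softmax and one-hot encoding are
# applied internally. num_classes below is hypothetical:
#
#     criterion = CeDiceLoss(num_classes=9)
#     loss = criterion(logits, labels)
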
class BceDiceLoss(nn.Module):
    def __init__(self, wb=1, wd=1):
        super(BceDiceLoss, self).__init__()
        self.bce = BCELoss()
        self.dice = DiceLoss()
        self.wb = wb
        self.wd = wd

    def forward(self, pred, target):
        bceloss = self.bce(pred, target)
        diceloss = self.dice(pred, target)

        loss = self.wd * diceloss + self.wb * bceloss
        return loss

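# Usage sketch (hedged): both BCELoss and DiceLoss above operate on
# probabilities, so `pred` must already be passed through a sigmoid
# (nn.BCELoss, unlike nn.BCEWithLogitsLoss, rejects raw logits):
#
#     criterion = BceDiceLoss(wb=1, wd=1)
#     loss = criterion(torch.sigmoid(logits), masks)
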
class GT_BceDiceLoss(nn.Module):
    def __init__(self, wb=1, wd=1):
        super(GT_BceDiceLoss, self).__init__()
        self.bcedice = BceDiceLoss(wb, wd)

    def forward(self, gt_pre, out, target):
        bcediceloss = self.bcedice(out, target)
        gt_pre5, gt_pre4, gt_pre3, gt_pre2, gt_pre1 = gt_pre
        # deep-supervision term: the five side outputs are weighted 0.1 to 0.5
        gt_loss = (self.bcedice(gt_pre5, target) * 0.1
                   + self.bcedice(gt_pre4, target) * 0.2
                   + self.bcedice(gt_pre3, target) * 0.3
                   + self.bcedice(gt_pre2, target) * 0.4
                   + self.bcedice(gt_pre1, target) * 0.5)
        return bcediceloss + gt_loss

class myToTensor:
    def __init__(self):
        pass

    def __call__(self, data):
        image, mask = data
        return torch.tensor(image).permute(2, 0, 1), torch.tensor(mask).permute(2, 0, 1)


class myResize:
    def __init__(self, size_h=256, size_w=256):
        self.size_h = size_h
        self.size_w = size_w

    def __call__(self, data):
        image, mask = data
        return TF.resize(image, [self.size_h, self.size_w]), TF.resize(mask, [self.size_h, self.size_w])


class myRandomHorizontalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        image, mask = data
        if random.random() < self.p:
            return TF.hflip(image), TF.hflip(mask)
        else:
            return image, mask


class myRandomVerticalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data):
        image, mask = data
        if random.random() < self.p:
            return TF.vflip(image), TF.vflip(mask)
        else:
            return image, mask

class myRandomRotation:
    def __init__(self, p=0.5, degree=[0, 360]):
        self.degree = degree
        self.p = p

    def __call__(self, data):
        image, mask = data
        if random.random() < self.p:
            # sample a fresh angle per call so the rotation varies per sample,
            # and apply the same angle to both image and mask
            angle = random.uniform(self.degree[0], self.degree[1])
            return TF.rotate(image, angle), TF.rotate(mask, angle)
        else:
            return image, mask

class myNormalize:
    def __init__(self, data_name, train=True):
        if data_name == 'isic18':
            if train:
                self.mean = 157.561
                self.std = 26.706
            else:
                self.mean = 149.034
                self.std = 32.022
        elif data_name == 'isic17':
            if train:
                self.mean = 159.922
                self.std = 28.871
            else:
                self.mean = 148.429
                self.std = 25.748
        elif data_name == 'isic18_82':
            if train:
                self.mean = 156.2899
                self.std = 26.5457
            else:
                self.mean = 149.8485
                self.std = 35.3346
        else:
            # fail fast rather than raising AttributeError on first use
            raise ValueError(f'Unknown dataset for normalization: {data_name}')

    def __call__(self, data):
        img, msk = data
        img_normalized = (img - self.mean) / self.std
        # standardize, then rescale to [0, 255]
        img_normalized = ((img_normalized - np.min(img_normalized))
                          / (np.max(img_normalized) - np.min(img_normalized))) * 255.
        return img_normalized, msk

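# Usage sketch (hedged): each transform above maps an (image, mask) pair to an
# (image, mask) pair, so they compose with torchvision's Compose; a possible
# pipeline (normalize on numpy first, then tensor ops):
#
#     from torchvision import transforms
#     train_transformer = transforms.Compose([
#         myNormalize('isic18', train=True),
#         myToTensor(),
#         myRandomHorizontalFlip(p=0.5),
#         myRandomVerticalFlip(p=0.5),
#         myRandomRotation(p=0.5, degree=[0, 360]),
#         myResize(256, 256),
#     ])
#     image_t, mask_t = train_transformer((image_np, mask_np))
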
from thop import profile  # thop provides FLOPs/params profiling


def cal_params_flops(model, size, logger):
    input = torch.randn(1, 3, size, size).cuda()
    flops, params = profile(model, inputs=(input,))
    print('flops', flops / 1e9)    # print FLOPs (in GFLOPs)
    print('params', params / 1e6)  # print parameter count (in millions)

    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total / 1e6))
    logger.info(f'flops: {flops / 1e9}, params: {params / 1e6}, Total params: {total / 1e6:.4f}')

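# Usage sketch (hedged; assumes the model already lives on the GPU, matching
# the .cuda() call above):
#
#     cal_params_flops(model, 256, logger)
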
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    else:
        return 0, 0

def test_single_volume(image, label, net, classes, patch_size=[256, 256],
                       test_save_path=None, case=None, z_spacing=1, val_or_test=False):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if len(image.shape) == 3:
        # 3D volume: run slice-by-slice inference along the depth axis
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice = image[ind, :, :]
            x, y = slice.shape[0], slice.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
            input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                outputs = net(input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    # resize the prediction back to the original slice size
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
    else:
        input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None and val_or_test is True:
        img_itk = sitk.GetImageFromArray(image.astype(np.float32))
        prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
        lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
        img_itk.SetSpacing((1, 1, z_spacing))
        prd_itk.SetSpacing((1, 1, z_spacing))
        lab_itk.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
        sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
        sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
        # cv2.imwrite(test_save_path + '/' + case + '.png', prediction*255)
    return metric_list

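# Usage sketch (hedged; `testloader` is hypothetical): each call returns one
# (dice, hd95) pair per foreground class, which can be averaged over volumes:
#
#     metrics = np.zeros((classes - 1, 2))
#     for image, label in testloader:
#         metrics += np.array(test_single_volume(image, label, net, classes))
#     metrics /= len(testloader)
#     mean_dice, mean_hd95 = metrics[:, 0].mean(), metrics[:, 1].mean()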