415 lines
14 KiB
Python
415 lines
14 KiB
Python
|
|
|||
|
from collections import defaultdict, deque
|
|||
|
import datetime
|
|||
|
import time
|
|||
|
import torch
|
|||
|
import torch.distributed as dist
|
|||
|
from taihuyuan_pv.compared_experiment.utils.dice_coefficient_loss import multiclass_dice_coeff, build_target
|
|||
|
import errno
|
|||
|
import torch.nn.functional as F
|
|||
|
import os
|
|||
|
|
|||
|
import logging
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
|
|||
|
def getLogger(savedir):
    """Configure the root logger to log to the console and to a timestamped file.

    A sub-directory named after the current local time (``%Y-%m-%d-%H.%M``) is
    created under ``savedir`` (parents included) and ``log.txt`` is opened in
    it with ``mode='w'``.

    Handlers installed by an earlier call to this function are removed and
    closed first, so calling it more than once does not emit every record
    several times.

    Args:
        savedir: Directory under which the timestamped log directory is created.

    Returns:
        The configured root ``logging.Logger``.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # master switch for the logger's level

    # Drop handlers this function installed previously; without this each call
    # would stack another console/file handler on the (shared) root logger.
    for old in [h for h in logger.handlers if getattr(h, "_getlogger_managed", False)]:
        logger.removeHandler(old)
        old.close()

    formatter = logging.Formatter(fmt="[%(asctime)s|%(filename)s|%(levelname)s] %(message)s",
                                  datefmt="%a %b %d %H:%M:%S %Y")

    # Console output.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    stream_handler._getlogger_managed = True
    logger.addHandler(stream_handler)

    # File output in a per-run directory; exist_ok avoids the check-then-create race.
    work_dir = os.path.join(savedir, time.strftime("%Y-%m-%d-%H.%M", time.localtime()))
    os.makedirs(work_dir, exist_ok=True)

    file_handler = logging.FileHandler(os.path.join(work_dir, 'log.txt'), mode='w', encoding="utf-8")
    file_handler.setLevel(logging.DEBUG)  # level switch for the file output
    file_handler.setFormatter(formatter)
    file_handler._getlogger_managed = True
    logger.addHandler(file_handler)

    return logger
|
|||
|
|
|||
|
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The most recent ``window_size`` values are kept for the windowed
    statistics (``median``, ``avg``, ``max``, ``value``); a weighted running
    ``total`` and ``count`` back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{value:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``; ``n`` weights it in the global average only."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` over all processes.

        Warning: does not synchronize the deque! Windowed statistics stay
        per-process. Does nothing when torch.distributed is not initialized.
        """
        if not is_dist_avail_and_initialized():
            return
        # NCCL can only reduce CUDA tensors; other backends (e.g. gloo) accept
        # CPU tensors, which also keeps CPU-only runs from failing here.
        device = 'cuda' if dist.get_backend() == dist.Backend.NCCL else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean (computed in float32) of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update; 0.0 before the first update."""
        if self.count == 0:
            return 0.0
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
|
|||
|
|
|||
|
class ConfusionMatrix(object):
    """Accumulate a ``num_classes x num_classes`` confusion matrix.

    ``mat[t, p]`` counts elements whose ground-truth class is ``t`` and whose
    predicted class is ``p``: rows index the ground truth, columns index the
    prediction.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None  # created lazily on the first update(), on the input's device

    def update(self, a, b):
        """Add one batch of ground-truth labels ``a`` and predicted labels ``b``.

        Entries of ``a`` outside ``[0, num_classes)`` (e.g. an ignore label)
        are skipped. ``b`` is masked with a mask built from ``a``, so it must
        have a compatible shape; presumably it holds class indices in
        ``[0, num_classes)`` — TODO confirm at call sites.
        """
        n = self.num_classes
        if self.mat is None:
            # create the confusion matrix
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # mask of elements whose ground-truth label is a valid class
            k = (a >= 0) & (a < n)
            # encode each (true, pred) pair as true * n + pred so a single
            # bincount tallies the whole matrix; reshape puts truth on rows
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts in place (no-op before the first update)."""
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Derive per-class and global metrics from the accumulated matrix.

        Rows are ground truth and columns are predictions, so for class ``c``:

        * recall    = mat[c, c] / row sum    -> TP / (TP + FN)
        * precision = mat[c, c] / column sum -> TP / (TP + FP)
        * IoU       = TP / (TP + FP + FN)
        * f1        = 2 * precision * recall / (precision + recall)

        A zero denominator yields NaN for that class's metric.

        Returns:
            ``(acc_global, acc, iu, precision, recall, f1)`` as float tensors;
            ``acc`` is per-class accuracy, which equals ``recall``.
        """
        h = self.mat.float()
        tp = torch.diag(h)
        # sum(1) collapses columns -> per-true-class totals (TP + FN);
        # sum(0) collapses rows    -> per-predicted-class totals (TP + FP)
        gt_total = h.sum(1)
        pred_total = h.sum(0)
        # global accuracy: the diagonal holds the correctly classified counts
        acc_global = tp.sum() / h.sum()
        # per-class accuracy
        acc = tp / gt_total
        # per-class intersection over union
        iu = tp / (gt_total + pred_total - tp)
        recall = tp / gt_total
        precision = tp / pred_total
        f1 = (2 * precision * recall) / (precision + recall)
        return acc_global, acc, iu, precision, recall, f1

    # NOTE: a distributed all-reduce of self.mat was previously sketched here but
    # left disabled, so these metrics reflect only the local process's updates.

    def re_zhib(self):
        """Return the metrics of ``compute`` scaled to percentages.

        Returns:
            ``(acc_global, acc, iu, precision, recall, f1, miou, meanf1, pv_iou, pv_f1)``.
            ``acc_global``, ``miou``, ``meanf1``, ``pv_iou`` and ``pv_f1`` are
            floats; the per-class entries are lists rounded to one decimal.
            ``miou``/``meanf1`` average over all classes, while ``pv_iou``/``pv_f1``
            average over every class except index 0 (presumably background —
            confirm against the label map).
        """
        acc_global, acc, iu, precision, recall, f1 = self.compute()
        miou = iu.mean().item() * 100
        meanf1 = f1.mean().item() * 100
        pv_iou = iu[1:].mean().item() * 100
        pv_f1 = f1[1:].mean().item() * 100

        acc_global = acc_global.item() * 100
        acc = [round(i, 1) for i in (acc * 100).tolist()]
        iu = [round(i, 1) for i in (iu * 100).tolist()]

        precision = [round(i, 1) for i in (precision * 100).tolist()]
        recall = [round(i, 1) for i in (recall * 100).tolist()]
        f1 = [round(i, 1) for i in (f1 * 100).tolist()]

        return acc_global, acc, iu, precision, recall, f1, miou, meanf1, pv_iou, pv_f1

    def __str__(self):
        acc_global, acc, iu, precision, recall, f1 = self.compute()

        def _pct_list(t):
            # per-class tensor -> list of percentage strings with one decimal
            return ['{:.1f}'.format(i) for i in (t * 100).tolist()]

        # Per-class fields receive lists, so they must use '{}' — a '{:.1f}'
        # spec on a list raises TypeError.
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'Precion: {}\n'
            'Recall: {}\n'
            'f1: {}\n'
            'mean IoU: {:.1f}\n'
        ).format(
            acc_global.item() * 100,
            _pct_list(acc),
            _pct_list(iu),
            _pct_list(precision),
            _pct_list(recall),
            _pct_list(f1),
            iu.mean().item() * 100)
|
|||
|
|
|||
|
class MetricLogger(object):
    """Collect named meters and print periodic progress lines."""

    def __init__(self, delimiter="\t"):
        # unknown meter names get a fresh SmoothedValue on first access
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Append one value per keyword argument to the meter of that name.

        Tensors are converted with ``.item()``; the resulting value must be a
        float or int (checked with ``assert``).
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes (``logger.loss`` -> ``logger.meters['loss']``)."""
        # Read ``meters`` through __dict__: on an instance that has no
        # ``meters`` yet (copy/pickle create objects without __init__),
        # ``self.meters`` would re-enter __getattr__ and recurse forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Ask every meter to synchronize itself across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name`` (replacing any existing one)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress every ``print_freq`` items.

        Each progress line shows the item index, an ETA extrapolated from the
        mean step time so far, all meters, the windowed mean step and
        data-loading times, and (when CUDA is available) peak allocated CUDA
        memory in MB. A total-time line is printed after the last item.
        ``iterable`` must support ``len()``.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_items = len(iterable)
        # pad the index to the width of the total so the columns line up
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        use_cuda = torch.cuda.is_available()
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # time spent waiting for the iterable to produce this item
            data_time.update(time.time() - end)
            yield obj
            # full step time, including the consumer's work on obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {}
                if use_cuda:
                    extra['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(
                    i, num_items, eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time), **extra))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))
|
|||
|
|
|||
|
class DiceCoefficient(object):
    """Running mean of a multi-class Dice score, excluding channel 0 (background)."""

    def __init__(self, num_classes: int = 2, ignore_index: int = -100):
        self.cumulative_dice = None  # 1-element tensor, created on the first update()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.count = None  # number of update() calls, as a 1-element tensor

    def update(self, pred, target):
        """Accumulate the Dice score of one batch.

        ``pred`` is reduced with ``argmax(dim=1)``, one-hot encoded and permuted
        back to channel-first, so it must be 4-D; presumably (N, C, H, W)
        class scores — TODO confirm at call sites.
        """
        if self.cumulative_dice is None:
            self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        if self.count is None:
            self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        # one-hot encode the arg-max prediction in channel-first layout
        pred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()
        dice_target = build_target(target, self.num_classes, self.ignore_index)
        # compute the Dice score, ignoring background (channel 0)
        self.cumulative_dice += multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_index=self.ignore_index)
        self.count += 1

    @property
    def value(self):
        """Mean Dice over all updates, or 0 when nothing has been accumulated."""
        # count is None until the first update() creates it
        if self.count is None or self.count == 0:
            return 0
        return self.cumulative_dice / self.count

    def reset(self):
        """Zero both accumulators in place (keeping their device and dtype)."""
        if self.cumulative_dice is not None:
            self.cumulative_dice.zero_()

        if self.count is not None:
            # Tensor.zero_ is the in-place zeroing method (there is no zeros_)
            self.count.zero_()

    def reduce_from_all_processes(self):
        """Sum the accumulators over all processes; no-op outside distributed mode."""
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.cumulative_dice)
        torch.distributed.all_reduce(self.count)
|
|||
|
|
|||
|
|
|||
|
def mkdir(path):
    """Create ``path`` with any missing parents; an existing entry is not an error."""
    try:
        os.makedirs(path)
    except FileExistsError:
        # EEXIST is tolerated, even if the existing entry is not a directory;
        # every other OSError propagates to the caller.
        pass
|
|||
|
|
|||
|
|
|||
|
def setup_for_distributed(is_master):
    """Silence ``print`` on non-master processes.

    Replaces ``builtins.print`` for the whole process with a wrapper that
    forwards to the previously installed ``print`` only when ``is_master`` is
    true or the call passes ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        # 'force' is always consumed so it never reaches the real print()
        if kwargs.pop('force', False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = gated_print
|
|||
|
|
|||
|
|
|||
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is built in and a process group is initialized."""
    # short-circuit: is_initialized() is only consulted when the package is available
    return dist.is_available() and dist.is_initialized()
|
|||
|
|
|||
|
|
|||
|
def get_world_size():
    """Number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
|||
|
|
|||
|
|
|||
|
def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
|||
|
|
|||
|
|
|||
|
def is_main_process():
    """Return whether ``get_rank()`` reports rank 0 for this process."""
    rank = get_rank()
    return rank == 0
|
|||
|
|
|||
|
|
|||
|
def save_on_master(*args, **kwargs):
    """Forward to ``torch.save(*args, **kwargs)`` on the main process only; other ranks do nothing."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
|||
|
|
|||
|
|
|||
|
def init_distributed_mode(args):
    """Initialize torch.distributed from the launcher's environment.

    Rank information is taken from, in order of precedence:

    1. ``RANK`` and ``WORLD_SIZE`` environment variables (plus ``LOCAL_RANK``
       for the GPU index when present).
    2. ``SLURM_PROCID``; ``args.world_size`` must already be set.
    3. An existing ``args.rank`` attribute; ``args.gpu`` and
       ``args.world_size`` must already be set.

    If none applies, sets ``args.distributed = False`` and returns. Otherwise
    sets ``args.distributed = True`` and ``args.dist_backend = 'nccl'``, binds
    this process to CUDA device ``args.gpu``, initializes the default process
    group with ``args.dist_url``, and silences ``print`` on non-zero ranks.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        if 'LOCAL_RANK' in os.environ:
            args.gpu = int(os.environ['LOCAL_RANK'])
        else:
            # Some launchers export RANK/WORLD_SIZE without LOCAL_RANK; use the
            # same rank-modulo-device mapping as the SLURM branch instead of
            # failing with KeyError.
            args.gpu = args.rank % torch.cuda.device_count()
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass  # rank/gpu/world_size were configured by the caller
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    setup_for_distributed(args.rank == 0)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|