# ai-station-code/wudingpv/taihuyuan_roof/compared_experiment/utils/util.py
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as dist
from taihuyuan_roof.compared_experiment.utils.dice_coefficient_loss import multiclass_dice_coeff, build_target
import errno
import torch.nn.functional as F
import os
import logging
from pathlib import Path


def getLogger(savedir):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # master switch for the log level
    formatter = logging.Formatter(fmt="[%(asctime)s|%(filename)s|%(levelname)s] %(message)s",
                                  datefmt="%a %b %d %H:%M:%S %Y")
    # StreamHandler: log to the console
    sHandler = logging.StreamHandler()
    sHandler.setFormatter(formatter)
    logger.addHandler(sHandler)
    # FileHandler: log to a timestamped file under savedir
    Path(savedir).mkdir(parents=True, exist_ok=True)
    work_dir = os.path.join(savedir,
                            time.strftime("%Y-%m-%d-%H.%M", time.localtime()))  # directory the log file is written to
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)
    fHandler = logging.FileHandler(work_dir + '/log.txt', mode='w', encoding="utf-8")
    fHandler.setLevel(logging.DEBUG)  # log level for output to the file
    fHandler.setFormatter(formatter)  # output format for this handler
    logger.addHandler(fHandler)       # attach the file handler to the logger
    return logger
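

# Minimal usage sketch (the "./logs" path is an assumption, not from the
# original code base): build a logger that writes both to the console and to
# a timestamped log.txt, then log through the standard logging API.
#
#     logger = getLogger("./logs")
#     logger.info("training started")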


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{value:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
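

# Minimal usage sketch: track a running loss and print the windowed average
# alongside the global average (format string from __init__'s default).
#
#     loss_meter = SmoothedValue(window_size=20)
#     for step in range(100):
#         loss_meter.update(1.0 / (step + 1))
#     print(str(loss_meter))  # e.g. "0.0100 (0.0519)"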


class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            # create the confusion matrix on first use
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # mask of pixels whose ground-truth label is a valid class
            k = (a >= 0) & (a < n)
            # count how often ground-truth class a[k] is predicted as class
            # b[k] (a neat trick): encode each (gt, pred) pair as the single
            # index n * gt + pred, histogram the indices with bincount, then
            # reshape into an n x n matrix
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
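
    # Worked example of the encoding above (illustrative values only): with
    # n = 3, a ground-truth pixel of class 1 predicted as class 2 maps to
    # index 3 * 1 + 2 = 5; bincount over [5] with minlength 9, reshaped to
    # (3, 3), puts that count at row 1 (ground truth), column 2 (prediction):
    #
    #     a = torch.tensor([1]); b = torch.tensor([2]); n = 3
    #     torch.bincount(n * a + b, minlength=n**2).reshape(n, n)
    #     # -> [[0, 0, 0],
    #     #     [0, 0, 1],
    #     #     [0, 0, 0]]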

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        # Note the orientation of the confusion matrix: conventions are not
        # unified across sources (some put predictions on the rows, others
        # the ground truth). Following the reference implementation this code
        # is adapted from, rows are the ground-truth classes and columns are
        # the predicted classes (see update above).
        # sum(0) sums across rows, keeping the column structure; sum(1) sums
        # across columns, keeping the row structure.
        h = self.mat.float()
        # global pixel accuracy (the diagonal holds the correct predictions)
        acc_global = torch.diag(h).sum() / h.sum()
        # per-class accuracy
        acc = torch.diag(h) / h.sum(1)
        # per-class IoU between prediction and ground truth
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        # Added 2022-09-24: per-class precision and recall (recall equals the
        # acc above) plus F1. With rows as ground truth, columns as prediction:
        #   recall    = tp / (tp + fn) = diagonal / h.sum(1)
        #   precision = tp / (tp + fp) = diagonal / h.sum(0)
        #   f1        = 2 * (precision * recall) / (precision + recall)
        recall = torch.diag(h) / h.sum(1)
        precision = torch.diag(h) / h.sum(0)
        f1 = (2 * precision * recall) / (precision + recall)
        return acc_global, acc, iu, precision, recall, f1

    # def reduce_from_all_processes(self):
    #     if not torch.distributed.is_available():
    #         return
    #     if not torch.distributed.is_initialized():
    #         return
    #     torch.distributed.barrier()
    #     torch.distributed.all_reduce(self.mat)

    def re_zhib(self):
        # "zhib" abbreviates Chinese "zhibiao" (metrics): return all metrics
        # as percentages, plus mean IoU/F1 and the IoU/F1 averaged over the
        # non-background classes (classes 1..n-1, i.e. the PV/roof targets)
        acc_global, acc, iu, precision, recall, f1 = self.compute()
        miou = iu.mean().item() * 100
        meanf1 = f1.mean().item() * 100
        pv_iou = iu[1:].mean().item() * 100
        pv_f1 = f1[1:].mean().item() * 100
        acc_global = acc_global.item() * 100
        acc = [round(i, 1) for i in (acc * 100).tolist()]
        iu = [round(i, 1) for i in (iu * 100).tolist()]
        precision = [round(i, 1) for i in (precision * 100).tolist()]
        recall = [round(i, 1) for i in (recall * 100).tolist()]
        f1 = [round(i, 1) for i in (f1 * 100).tolist()]
        return acc_global, acc, iu, precision, recall, f1, miou, meanf1, pv_iou, pv_f1

    def __str__(self):
        acc_global, acc, iu, precision, recall, f1 = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'Precision: {}\n'
            'Recall: {}\n'
            'f1: {}\n'
            'mean IoU: {:.1f}\n'
        ).format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            ['{:.1f}'.format(i) for i in (precision * 100).tolist()],
            ['{:.1f}'.format(i) for i in (recall * 100).tolist()],
            ['{:.1f}'.format(i) for i in (f1 * 100).tolist()],
            iu.mean().item() * 100)
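

# Minimal usage sketch (the flat label tensors are illustrative; any integer
# tensors of matching shape work): accumulate ground-truth/prediction pairs
# batch by batch, then read the metrics off the accumulated matrix.
#
#     cm = ConfusionMatrix(num_classes=2)
#     gt = torch.tensor([0, 0, 1, 1])
#     pred = torch.tensor([0, 1, 1, 1])
#     cm.update(gt, pred)
#     print(str(cm))  # global correct, per-class accuracy, IoU, P/R/F1, mIoU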


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))
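

# Minimal usage sketch (data_loader and train_step are hypothetical names for
# any sized iterable and training step): wrap the loop in log_every to get
# periodic progress lines with ETA, iteration/data timings and all meters.
#
#     metric_logger = MetricLogger(delimiter="  ")
#     for batch in metric_logger.log_every(data_loader, print_freq=50, header='Epoch: [0]'):
#         loss = train_step(batch)
#         metric_logger.update(loss=loss)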


class DiceCoefficient(object):
    def __init__(self, num_classes: int = 2, ignore_index: int = -100):
        self.cumulative_dice = None
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.count = None

    def update(self, pred, target):
        if self.cumulative_dice is None:
            self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        if self.count is None:
            self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        # compute the Dice score, ignoring background
        pred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()
        dice_target = build_target(target, self.num_classes, self.ignore_index)
        self.cumulative_dice += multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_index=self.ignore_index)
        self.count += 1

    @property
    def value(self):
        if self.count == 0:
            return 0
        else:
            return self.cumulative_dice / self.count

    def reset(self):
        if self.cumulative_dice is not None:
            self.cumulative_dice.zero_()
        if self.count is not None:
            self.count.zero_()

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.cumulative_dice)
        torch.distributed.all_reduce(self.count)
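

# Minimal usage sketch (the shapes are assumptions: logits [N, C, H, W] and
# integer targets [N, H, W], matching what update's argmax/one_hot expects):
#
#     dice = DiceCoefficient(num_classes=2)
#     logits = torch.randn(1, 2, 4, 4)          # raw model output
#     target = torch.randint(0, 2, (1, 4, 4))   # ground-truth mask
#     dice.update(logits, target)
#     print(dice.value)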


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
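

# Note: after setup_for_distributed(False), print() is silenced on non-master
# ranks; passing force=True still prints, e.g.:
#
#     print("rank-local debug message", force=True)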


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    setup_for_distributed(args.rank == 0)
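

# Minimal usage sketch (train.py and the argparse setup are assumptions; args
# must provide dist_url, e.g. "env://", and gains rank/world_size/gpu/
# distributed as attributes). Launched under torchrun, the RANK and
# WORLD_SIZE environment variables select the first branch above:
#
#     torchrun --nproc_per_node=4 train.py
#
#     # inside train.py:
#     #     args = parser.parse_args()
#     #     init_distributed_mode(args)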