ai-station-code/wudingpv/taihuyuan_pv/compared_experiment/mySwinUnet/test/xaiotu.py


#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
import matplotlib.pyplot as plt
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
import argparse
from torch.utils.data import Dataset
from torchvision import transforms
from taihuyuan_pv.dataloaders import custom_transforms as tr
import torch
import torch.nn.functional as F
from PIL import Image
'''VOC-format dataset.'''
class datasets_pvtaihuyuan(Dataset):
    def __init__(self,
                 args,
                 split='val',
                 isAug=False
                 ):
        super(datasets_pvtaihuyuan, self).__init__()
        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single number, e.g. 256
        self.flip_prob = args.flip_prob  # in [0, 1]
        self.isAug = isAug
        self._base_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        self._image_dir = os.path.join(self._base_dir, 'images')
        self._cat_dir = os.path.join(self._base_dir, 'labels')
        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        self.im_ids = []
        self.images = []
        self.categories = []
        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + ".png")  # note: some images may be .jpg rather than .png; adjust as needed
                _cat = os.path.join(self._cat_dir, line + ".png")  # mind the format
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
        assert (len(self.images) == len(self.categories))
        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target, _name = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target, "name": _name}
        for split in self.split:
            if split == 'val':
                return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])
        # Use the basename without extension as the sample name; splitting the
        # full path on "." would keep the directory prefix in the name.
        name = os.path.basename(self.images[index]).split(".")[0]
        return _img, _target, name

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

    def __str__(self):
        return 'linan(split=' + str(self.split) + ')'
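
# Usage sketch (hypothetical values; `args` only needs the fields read above):
#   args = argparse.Namespace(resize=(512, 512), crop_size=480, flip_prob=0.5)
#   ds = datasets_pvtaihuyuan(args, split='val')
#   sample = ds[0]  # {'image': tensor, 'label': tensor, 'name': str}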
class predictandeval():
    def __init__(self, model_name="xx"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]  # index 0 -> black, indices 1 and 2 -> green
        self.fusin = True

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix for the given ground truth and prediction.
        :param gt_label: the ground truth label
        :param pred_label: the predicted label
        :param class_num: the number of classes
        :return: the confusion matrix
        """
        index = (gt_label * class_num + pred_label).astype('int32')
        label_count = np.bincount(index)
        confusion_matrix = np.zeros((class_num, class_num))
        for i_label in range(class_num):
            for i_pred_label in range(class_num):
                cur_index = i_label * class_num + i_pred_label
                if cur_index < len(label_count):
                    confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
        return confusion_matrix
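
    # Example (hypothetical values) for class_num=2:
    #   gt   = np.array([0, 0, 1, 1])
    #   pred = np.array([0, 1, 1, 1])
    # index = gt*2 + pred = [0, 1, 3, 3], bincount = [1, 1, 0, 2], so the
    # matrix is [[1, 1], [0, 2]]: rows index ground truth, columns prediction.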
    # Reuses the training dataset format.
    # def get_datasets(self, ):
    #     parser = argparse.ArgumentParser()
    #     args = parser.parse_args()
    #     args.resize = (512, 512)
    #     args.crop_size = 480  # crop and flip are unused during validation; set only to avoid errors
    #     args.flip_prob = 0.5
    #     datasets = datasets_pvtaihuyuan(args, split='val', isAug=False)
    #     print(len(datasets))
    #     return datasets
    # Rewritten: VOC-format val or test split.
    def get_images(self, ):
        basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        split = 'val'
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        imglist = []
        with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
            lines = f.read().splitlines()
        for ii, line in enumerate(lines):
            name = line
            _imagepath = os.path.join(basedir, 'images', line + ".png")
            assert os.path.isfile(_imagepath)
            # Convert to RGB so Image.blend() later sees matching modes.
            image = Image.open(_imagepath).convert('RGB')
            orininal_h = image.size[1]
            orininal_w = image.size[0]
            item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
            imglist.append(item)
        print("Detected {} source images and labels".format(len(imglist)))
        return imglist

    def get_labels(self, ):
        basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        split = 'val'
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        labellist = []
        with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
            lines = f.read().splitlines()
        for ii, line in enumerate(lines):
            name = line
            _labelpath = os.path.join(basedir, 'labels', line + ".png")
            assert os.path.isfile(_labelpath)
            label = Image.open(_labelpath)
            item = {"name": name, "label": label}
            labellist.append(item)
        print("Detected {} labels".format(len(labellist)))
        return labellist
    # def get_result1(self, net):
    #     datase = self.get_datasets1()
    #     dataloader = DataLoader(datase, batch_size=1, shuffle=False, num_workers=0)
    #     for ii, sample in enumerate(dataloader):
    #         image, mask_true = sample['image'], sample['label']
    #         name = sample["name"]
    #         image = image.to(device=self.device, dtype=torch.float32)
    #         orininal_h, orininal_w = image.shape[2:]
    #         with torch.no_grad():
    #             out = net(image)
    #             if isinstance(out, list) or isinstance(out, tuple):  # there may be several outputs (the auxiliary decode head is returned too)
    #                 out = out[0]  # take the first one
    #             pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
    #             result = pr.argmax(axis=-1)
    #         if not os.path.exists(self.predict_dir):
    #             os.makedirs(self.predict_dir)
    #         output_im = Image.fromarray(np.uint8(result)).convert('P')
    #         output_im.putpalette(self.palette)
    #         output_im.save(self.predict_dir + name + '.png')

    # Blend ("hunhe"); the commented-out version above produces no fusion overlay.
    def get_result(self, net):
        imglist = self.get_images()
        assert len(imglist) != 0
        model = net.to(device="cuda")  # move the model once, outside the loop
        for i in tqdm(imglist):
            image = i["image"]
            name = i["name"]
            orininal_w = i["orininal_w"]
            orininal_h = i["orininal_h"]
            old_img = copy.deepcopy(image)
            imaged = cv.resize(np.array(image), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
            image_data = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)
            with torch.no_grad():
                images = torch.from_numpy(image_data)
                image = images.to(device="cuda", dtype=torch.float32)
                out = model(image)
                if isinstance(out, list) or isinstance(out, tuple):  # there may be several outputs (the auxiliary decode head is returned too)
                    out = out[0]  # take the first one
                out = out[0]  # drop the batch dimension: (C, H, W)
                pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
                result = pr.argmax(axis=-1)
            # Resize the 512x512 prediction back to the original image size with
            # nearest-neighbour so the saved mask matches the ground-truth
            # resolution (and the fusion reshape below does not fail).
            result = cv.resize(result.astype(np.uint8), dsize=(orininal_w, orininal_h),
                               interpolation=cv.INTER_NEAREST)
            # Prediction map
            if not os.path.exists(self.predict_dir):
                os.makedirs(self.predict_dir)
            output_im = Image.fromarray(np.uint8(result)).convert('P')
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')
            # Fusion overlay
            if self.fusin:
                if not os.path.exists(self.fusion_path):
                    os.makedirs(self.fusion_path)
                PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
                seg_img0 = np.reshape(np.array(PALETTE, np.uint8)[np.reshape(result, [-1])],
                                      [orininal_h, orininal_w, -1])
                image0 = Image.fromarray(np.uint8(seg_img0))
                fusion = Image.blend(old_img, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')
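
    # Note: .convert('P') + putpalette stores the raw class indices (0/1/2) as
    # pixel values; the palette only controls display colors, so compute_evalue
    # can read the indices back directly via np.array(Image.open(...)).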
    def preprocess_input(self, image, md=False):
        mean = (0.231, 0.217, 0.22)  # statistics computed for the Beijing dataset
        std = (0.104, 0.086, 0.085)
        if md:
            image /= 255.0
            image -= mean
            image /= std
            return image
        else:
            image /= 255.0
            return image
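
    # Usage sketch (hypothetical input): `image` is an HxWx3 float32 array in
    # [0, 255]. With md=False only the /255 rescale is applied (matching the
    # call in get_result); with md=True the per-channel mean/std above are used:
    #   x = preprocess_input(np.zeros((512, 512, 3), np.float32), md=True)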
    def compute_evalue(self, predict_dir, ):
        labellist = self.get_labels()  # holds the PNG labels, not yet converted to numpy
        predictresultlist = os.listdir(predict_dir)
        assert len(labellist) == len(predictresultlist)
        num_classes = 2
        confusion_matrix = np.zeros((num_classes, num_classes))
        for i in tqdm(labellist):
            name = i["name"]
            seg_gt = np.array(i["label"])
            seg_pred = np.array(Image.open(os.path.join(predict_dir, name + ".png")))
            # Drop ignored pixels (value 255) before accumulating.
            ignore_index = seg_gt != 255
            seg_gt = seg_gt[ignore_index]
            seg_pred = seg_pred[ignore_index]
            confusion_matrix += self.get_confusion_matrix(seg_gt, seg_pred, num_classes)
        pos = confusion_matrix.sum(1)  # per-class count of ground-truth pixels (tp + fn)
        res = confusion_matrix.sum(0)  # per-class count of pixels predicted as that class (tp + fp)
        tp = np.diag(confusion_matrix)
        IU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IU = IU_array.mean()
        pv_iou = IU_array[1:].mean()
        precion = tp / res
        recall = tp / pos
        f1 = (2 * precion * recall) / (precion + recall)
        mf1 = f1.mean()
        acc_global = tp.sum() / pos.sum()
        pv_f1 = f1[1:].mean()
        pv_precion = precion[1:].mean()
        pv_recall = recall[1:].mean()
        print("Test results")
        print("acc_global:", round(acc_global, 4))
        print("IU_array:", np.round(IU_array, 4))
        print("precion:", np.round(precion, 4))
        print("recall:", np.round(recall, 4))
        print("f1:", np.round(f1, 4))
        print("miou:", round(mean_IU, 4))
        print("pv_iou:", round(pv_iou, 4))
        print("pv_precion:", round(pv_precion, 4))
        print("pv_recall:", round(pv_recall, 4))
        print("pv_f1:", round(pv_f1, 4))
        return pv_iou, pv_f1, pv_precion, pv_recall, acc_global
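
    # Sanity check for the metric algebra above (hypothetical 2x2 matrix):
    #   cm = [[90, 10], [5, 95]] -> tp = [90, 95], pos = [100, 100], res = [95, 105]
    #   IoU = tp / (pos + res - tp) = [90/105, 95/110] ~= [0.8571, 0.8636]
    #   precision = tp / res, recall = tp / pos, F1 = 2PR / (P + R)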
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train_Batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Val_Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')  # pretrained weights, if any
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='ignore index, default 255')
    # Was action='store_true' with a tuple default, which would overwrite the
    # tuple with True whenever the flag was passed; take two ints instead.
    parser.add_argument('--origin_shape', type=int, nargs=2, default=(512, 512),
                        help='Original input size (H W)')
    parser.add_argument('--resume', '-r', type=str, default="", help='resume from this checkpoint if set')
    parser.add_argument('--useDice', '-ud', type=str, default=False, help='whether to use Dice loss during training')
    parser.add_argument('--valIndex', '-vI', type=str, default=["Valloss", "IouMiouP"],
                        help='which evaluation metrics to use; note IouMiouP = acc_global, acc, iu, precion, recall, f1, miou; Dice is only recommended for binary segmentation, otherwise it may fail')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    parser.add_argument('--num_classes', type=int,
                        default=2, help='output channel of network')
    parser.add_argument('--img_size', type=int,
                        default=512, help='input patch size of network input')
    parser.add_argument('--cfg', type=str,
                        default="/home/qiqq/q3dl/code/rooftoprecognition/pv_recognition/compared_experiment/mySwinUnet/configs/swin_tiny_patch4_window7_224_lite.yaml",
                        metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    return parser.parse_args()
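
# Example invocation (hypothetical paths; the defaults above already point at a
# config, so flags are only needed to override them):
#   python xaiotu.py --cfg ./configs/swin_tiny_patch4_window7_224_lite.yaml --img_size 512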
if __name__ == '__main__':
    from taihuyuan_pv.compared_experiment.mySwinUnet.model.vision_transformer import SwinUnet
    from taihuyuan_pv.compared_experiment.mySwinUnet.configs.config import get_config

    args = get_args()
    config = get_config(args)
    model = SwinUnet(config, img_size=args.img_size, num_classes=args.num_classes)
    model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySwinUnet/train_val_test/checkpoints/2023-08-09-00.08/best.pth"
    model_dict = torch.load(model_path)
    model.load_state_dict(model_dict['net'])
    print("Weights loaded")
    model.eval()
    model.cuda()
    testss = predictandeval(model_name="swinunethypv")
    testss.get_result(net=model)
    predictoutpath = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySwinUnet/test/swinunethypv_outputs/"
    pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath)
    resultsavepath = "./swinunethypv.txt"
    with open(resultsavepath, "a") as f:
        f.write("Test results\n")
        f.write(f"acc_global:{round(acc_global, 4)}\n")
        f.write(f"pv_iou:{round(pv_iou, 4)}\n")
        f.write(f"pv_f1:{round(pv_f1, 4)}\n")
        f.write(f"pv_precion:{round(pv_precion, 4)}\n")
        f.write(f"pv_recall:{round(pv_recall, 4)}\n")
    print("Write complete")