376 lines
13 KiB
Python
376 lines
13 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
@project:
|
||
@File : xaiotu
|
||
@Author : qiqq
|
||
@create_time : 2023/6/29 8:57
|
||
"""
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import copy
|
||
import cv2 as cv
|
||
from tqdm import tqdm
|
||
import os
|
||
from torch.utils.data import DataLoader
|
||
|
||
import argparse
|
||
from torch.utils.data import Dataset
|
||
|
||
from torchvision import transforms
|
||
from taihuyuan_pv.dataloaders import custom_transforms as tr
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
from PIL import Image
|
||
|
||
|
||
|
||
'''voc数据集格式的'''
|
||
|
||
|
||
|
||
class datasets_pvtaihuyuan(Dataset):
    """VOC-style dataset for the Taihuyuan PV segmentation data.

    Each ``<split>.txt`` under the dataset root lists one image id per line;
    id ``x`` maps to ``images/x.png`` and ``labels/x.png``. Only the ``'val'``
    split has a transform pipeline defined.
    """

    # Default dataset root, kept so existing callers need not pass ``base_dir``.
    DEFAULT_BASE_DIR = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"

    def __init__(self,
                 args,
                 split='val',
                 isAug=False,
                 base_dir=None,
                 ):
        """
        :param args: namespace providing ``resize``, ``crop_size`` (a single
            int such as 256) and ``flip_prob`` (a value in [0, 1]).
        :param split: a split name or a list of split names; each names a
            ``<split>.txt`` file directly under the dataset root.
        :param isAug: stored on the instance; not otherwise used by this class.
        :param base_dir: dataset root; defaults to ``DEFAULT_BASE_DIR``.
        :raises FileNotFoundError: if a listed image or label file is missing.
        """
        super(datasets_pvtaihuyuan, self).__init__()

        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single int, e.g. 256
        self.flip_prob = args.flip_prob  # a probability in [0, 1]

        self.isAug = isAug
        self._base_dir = base_dir if base_dir is not None else self.DEFAULT_BASE_DIR
        self._image_dir = os.path.join(self._base_dir, 'images')
        self._cat_dir = os.path.join(self._base_dir, 'labels')

        # sorted() copies: the caller's list must not be reordered in place.
        self.split = [split] if isinstance(split, str) else sorted(split)

        # Split files live directly under the dataset root.
        _splits_dir = self._base_dir

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r", encoding="utf-8") as f:
                lines = f.read().splitlines()

            for raw in lines:
                line = raw.strip()
                if not line:
                    # A blank line would otherwise become the bogus path "images/.png".
                    continue
                # NOTE: images may be .jpg in other datasets; adjust the extension if so.
                _image = os.path.join(self._image_dir, line + ".png")
                _cat = os.path.join(self._cat_dir, line + ".png")
                # Explicit checks instead of assert, which is stripped under `python -O`.
                if not os.path.isfile(_image):
                    raise FileNotFoundError("image file not found: {}".format(_image))
                if not os.path.isfile(_cat):
                    raise FileNotFoundError("label file not found: {}".format(_cat))
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert len(self.images) == len(self.categories)

        # Display stats (a list split is shown sorted, as before).
        shown = split if isinstance(split, str) else self.split
        print('Number of images in {}: {:d}'.format(shown, len(self.images)))

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the transformed sample dict (keys 'image', 'label', 'name').

        :raises NotImplementedError: if none of the configured splits is 'val',
            the only split with a transform pipeline (previously this silently
            returned None, which only failed later in DataLoader collation).
        """
        if 'val' not in self.split:
            raise NotImplementedError(
                'no transform defined for split(s) {}'.format(self.split))
        _img, _target, _name = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target, "name": _name}
        return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        """Load sample ``index``; return (RGB image, label image, image path without extension)."""
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])
        # splitext strips only the extension; split(".")[0] cut the path at the
        # first dot anywhere (dotted ids like "tile.001" or dotted directories).
        name = os.path.splitext(self.images[index])[0]
        return _img, _target, name

    def transform_val(self, sample):
        """Apply the validation pipeline: resize to ``args.resize``, normalize, convert to tensors."""
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

    def __str__(self):
        return 'linan(split=' + str(self.split) + ')'
|
||
#
|
||
|
||
|
||
class predictandeval():
    """Predict masks for a Taihuyuan PV split and evaluate saved predictions.

    ``get_result`` writes palette PNG predictions to ``predict_dir`` (and, when
    ``fusin`` is true, image/prediction overlays to ``fusion_path``).
    ``compute_evalue`` compares a directory of prediction PNGs against the labels.
    """

    # Default dataset root; it contains images/, labels/ and the <split>.txt files.
    DEFAULT_BASE_DIR = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"

    def __init__(self, model_name="xx", base_dir=None, input_size=(512, 512)):
        """
        :param model_name: prefix for the output directories and overlay file names.
        :param base_dir: dataset root; defaults to ``DEFAULT_BASE_DIR``.
        :param input_size: ``(width, height)`` each image is resized to before
            inference (passed to ``cv.resize`` as ``dsize``).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat RGB palette for the saved 'P'-mode masks: index 0 black, indices 1 and 2 green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]

        # Whether get_result also writes overlays ("fusin" = fusion; name kept for callers).
        self.fusin = True

        self.base_dir = base_dir if base_dir is not None else self.DEFAULT_BASE_DIR
        self.input_size = tuple(input_size)

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix for the given labels and predictions.

        :param gt_label: array of ground-truth class indices
        :param pred_label: array of predicted class indices, same shape as gt_label
        :param class_num: the number of classes
        :return: float64 array of shape (class_num, class_num); entry [i, j] counts
            elements with ground truth i predicted as j. Pairs whose flat index
            ``gt * class_num + pred`` lies at or beyond class_num**2 are dropped.
        """
        # Widen before the arithmetic: with uint8 inputs (as read from PNGs),
        # gt * class_num + pred wraps around once class_num exceeds 16.
        gt = np.asarray(gt_label, dtype=np.int64)
        pred = np.asarray(pred_label, dtype=np.int64)
        n_cells = class_num * class_num
        index = (gt * class_num + pred).ravel()
        counts = np.bincount(index, minlength=n_cells)[:n_cells]
        return counts.reshape(class_num, class_num).astype(np.float64)

    def _read_split_ids(self, split):
        """Return the non-blank, stripped ids listed in ``<base_dir>/<split>.txt``."""
        split_file = os.path.join(self.base_dir, split + '.txt')
        with open(split_file, "r", encoding="utf-8") as f:
            return [line.strip() for line in f.read().splitlines() if line.strip()]

    @staticmethod
    def _open_detached(path):
        """Open an image file, load its pixels and close the file handle.

        :raises FileNotFoundError: if ``path`` is not a file.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # Image.open is lazy and keeps the file open until the pixels are loaded;
        # one open handle per image can exhaust the fd limit on large splits.
        with Image.open(path) as im:
            return im.copy()

    def get_images(self, split='val'):
        """Load every image listed in the split file (VOC-style val/test list).

        :return: list of dicts with keys "name", "orininal_h", "orininal_w" and
            "image" (a fully loaded PIL image).
        :raises FileNotFoundError: if a listed image file is missing.
        """
        imglist = []
        for name in self._read_split_ids(split):
            image = self._open_detached(os.path.join(self.base_dir, 'images', name + ".png"))
            orininal_w, orininal_h = image.size  # PIL size is (width, height)
            imglist.append({"name": name, "orininal_h": orininal_h,
                            "orininal_w": orininal_w, "image": image})
        print("共监测到{}张原始图像和标签".format(len(imglist)))
        return imglist

    def get_labels(self, split='val'):
        """Load every label image listed in the split file.

        :return: list of dicts with keys "name" and "label" (a fully loaded PIL image).
        :raises FileNotFoundError: if a listed label file is missing.
        """
        labellist = []
        for name in self._read_split_ids(split):
            label = self._open_detached(os.path.join(self.base_dir, 'labels', name + ".png"))
            labellist.append({"name": name, "label": label})
        print("共监测到{}张标签".format(len(labellist)))
        return labellist

    def get_result(self, net):
        """Run ``net`` on every split image and save the predicted masks (and overlays).

        :param net: a torch module; its first output (or only output) is used as
            the per-class logits.
        :raises RuntimeError: if the split lists no images.
        """
        imglist = self.get_images()
        if not imglist:
            raise RuntimeError("no images found for prediction")

        model = net.to(self.device)
        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)

        for item in tqdm(imglist):
            name = item["name"]
            # Force 3 channels so both the network input and Image.blend match.
            image = item["image"].convert("RGB")
            result = self._predict_mask(model, image, item["orininal_h"], item["orininal_w"])
            self._save_prediction(result, name)
            if self.fusin:
                self._save_fusion(image, result, name)

    def _predict_mask(self, model, image, height, width):
        """Run ``model`` on an RGB PIL image; return a (height, width) array of class indices."""
        resized = cv.resize(np.array(image), dsize=self.input_size, interpolation=cv.INTER_LINEAR)
        # HWC float32 in [0, 1] -> CHW -> add a batch dimension of one.
        chw = np.transpose(self.preprocess_input(np.array(resized, np.float32), md=False), (2, 0, 1))
        batch = torch.from_numpy(np.expand_dims(chw, 0)).to(device=self.device, dtype=torch.float32)
        with torch.no_grad():
            out = model(batch)
            if isinstance(out, (list, tuple)):  # may include auxiliary-head outputs; keep the first
                out = out[0]
            # Bring the logits back to the original resolution so the saved mask
            # lines up with the label image and the overlay.
            out = F.interpolate(out, size=(height, width), mode='bilinear', align_corners=False)
            # argmax over logits equals argmax over their softmax.
            result = out[0].argmax(dim=0).cpu().numpy()
        return result

    def _save_prediction(self, result, name):
        """Save the class-index mask as a palette PNG under ``predict_dir``."""
        output_im = Image.fromarray(np.uint8(result)).convert('P')
        output_im.putpalette(self.palette)
        output_im.save(self.predict_dir + name + '.png')

    def _save_fusion(self, image, result, name):
        """Blend a colour rendering of ``result`` over ``image`` and save it under ``fusion_path``."""
        # Overlay colours: class 0 black, class 1 red, class 2 green.
        PALETTE = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        seg_img = Image.fromarray(PALETTE[result])
        fusion = Image.blend(image, seg_img, 0.4)
        fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')

    def preprocess_input(self, image, md=False):
        """Scale a float image array to [0, 1] in place and return it.

        With ``md=True`` it is additionally standardized with per-channel
        mean/std (described in the original as statistics for the Beijing
        dataset); this broadcasts over the last axis, so it expects 3 channels last.
        """
        mean = (0.231, 0.217, 0.22)
        std = (0.104, 0.086, 0.085)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    @staticmethod
    def _metrics_from_confusion(confusion_matrix):
        """Derive per-class and PV (mean over classes 1..n-1) metrics from a confusion matrix.

        Rows are ground truth, columns predictions. A zero denominator (a class
        absent from both labels and predictions) yields 0 instead of NaN.
        """
        pos = confusion_matrix.sum(1)  # per class: pixels truly of that class (TP + FN)
        res = confusion_matrix.sum(0)  # per class: pixels predicted as that class (TP + FP)
        tp = np.diag(confusion_matrix)

        iu = tp / np.maximum(1.0, pos + res - tp)
        precision = tp / np.maximum(1.0, res)
        recall = tp / np.maximum(1.0, pos)
        pr_sum = precision + recall
        f1 = np.divide(2 * precision * recall, pr_sum,
                       out=np.zeros_like(pr_sum), where=pr_sum > 0)
        acc_global = tp.sum() / max(1.0, pos.sum())

        return {
            "IU_array": iu,
            "mean_IU": iu.mean(),
            "precion": precision,
            "recall": recall,
            "f1": f1,
            "mf1": f1.mean(),
            "acc_global": acc_global,
            "pv_iou": iu[1:].mean(),
            "pv_precion": precision[1:].mean(),
            "pv_recall": recall[1:].mean(),
            "pv_f1": f1[1:].mean(),
        }

    def compute_evalue(self, predict_dir, num_classes=2):
        """Evaluate prediction PNGs in ``predict_dir`` against the split labels.

        Label value 255 is excluded from evaluation.

        :return: (pv_iou, pv_f1, pv_precion, pv_recall, acc_global)
        :raises ValueError: if the number of files in ``predict_dir`` differs from
            the number of labels, or a prediction's shape differs from its label's.
        """
        labellist = self.get_labels()  # PIL label images, not yet numpy arrays
        predictresultlist = os.listdir(predict_dir)
        if len(labellist) != len(predictresultlist):
            raise ValueError("{} labels but {} files in {}".format(
                len(labellist), len(predictresultlist), predict_dir))

        confusion_matrix = np.zeros((num_classes, num_classes))

        for item in tqdm(labellist):
            name = item["name"]
            seg_gt = np.array(item["label"])
            with Image.open(os.path.join(predict_dir, name + ".png")) as pred_im:
                seg_pred = np.array(pred_im)
            if seg_pred.shape != seg_gt.shape:
                raise ValueError("prediction {} has shape {} but label has shape {}".format(
                    name, seg_pred.shape, seg_gt.shape))

            valid = seg_gt != 255  # 255 marks pixels to ignore
            confusion_matrix += self.get_confusion_matrix(seg_gt[valid], seg_pred[valid], num_classes)

        m = self._metrics_from_confusion(confusion_matrix)

        print("测试结果")
        print("acc_global:", round(m["acc_global"], 4))
        print("IU_array:", np.round(m["IU_array"], 4))
        print("precion:", np.round(m["precion"], 4))
        print("recall:", np.round(m["recall"], 4))
        print("f1:", np.round(m["f1"], 4))
        print("miou:", round(m["mean_IU"], 4))
        print("pv_iou:", round(m["pv_iou"], 4))
        print("pv_precion:", round(m["pv_precion"], 4))
        print("pv_recall:", round(m["pv_recall"], 4))
        print("pv_f1:", round(m["pv_f1"], 4))

        return m["pv_iou"], m["pv_f1"], m["pv_precion"], m["pv_recall"], m["acc_global"]
|
||
|
||
|
||
|
||
|
||
def main():
    """Load the SETR checkpoint and evaluate saved predictions for the PV val split.

    Summary metrics are appended to ./setrthypv.txt.
    """
    from taihuyuan_pv.compared_experiment.mySETR.model.setr import MSETR_naive

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = MSETR_naive(pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/vit_checkpoint/setrvit/vit_base_p16_224-4e355ebd.pth")

    model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySETR/train_val_test/checkpoints/2023-08-08-19.46/best.pth"
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.to(device)

    testss = predictandeval(model_name="setrthypv")
    # Prediction is currently disabled; metrics are computed from an existing output directory.
    # testss.get_result(net=model)
    predictoutpath = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/compared_experiment/mySETR/test/setrthypv_outputs/"
    pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath)

    resultsavepath = "./setrthypv.txt"
    # Explicit UTF-8: the report contains CJK text.
    with open(resultsavepath, "a", encoding="utf-8") as f:
        f.write("测试结果\n")
        f.write(f"acc_global:{round(acc_global, 4)}\n")
        f.write(f"pv_iou:{round(pv_iou, 4)}\n")
        f.write(f"pv_f1:{round(pv_f1, 4)}\n")
        f.write(f"pv_precion:{round(pv_precion, 4)}\n")
        f.write(f"pv_recall:{round(pv_recall, 4)}\n")
    print("写入完成")


if __name__ == '__main__':
    main()
|
||
|
||
|
||
|
||
|