ai-station-code/wudingpv/taihuyuan_pv/compared_experiment/deeplabv3Plus/test/xaiotubu.py

464 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
'''专门应对新图预测的'''
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
from math import ceil
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from PIL import Image as PILImage
'''voc数据集格式的'''
class predictandeval:
    """Predict on new images at a fixed 512x512 network input size and
    evaluate saved predictions against VOC-style label PNGs.

    Predicted masks are written to ``<model_name>_outputs/`` and, when
    ``fusin`` is True, image/mask overlays to ``<model_name>_fusion/``.
    """

    def __init__(self, model_name="xx"):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat [R, G, B, ...] palette of the saved mask:
        # class 0 -> black, classes 1 and 2 -> green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        # When True, get_result also saves a 60% image / 40% colour-mask blend.
        self.fusin = True

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix for the given labels and predictions.

        :param gt_label: array of ground-truth class ids (any shape, flattened)
        :param pred_label: array of predicted class ids, same size as gt_label
        :param class_num: the number of classes
        :return: float64 array of shape (class_num, class_num); entry [i, j]
            counts elements whose ground truth is i and prediction is j.
            Ids are expected to lie in [0, class_num); pairs whose flattened
            index ``gt * class_num + pred`` falls outside the matrix are ignored.
        """
        # Widen before the arithmetic: with uint8 label images
        # ``gt * class_num`` would otherwise wrap around.
        gt = np.ravel(gt_label).astype(np.int64)
        pred = np.ravel(pred_label).astype(np.int64)
        n_cells = class_num * class_num
        counts = np.bincount(gt * class_num + pred, minlength=n_cells)[:n_cells]
        return counts.reshape(class_num, class_num).astype(np.float64)

    def get_images(self, basedir="/home/qiqq/q3dl/test_temple/ground_pv/"):
        """Open every regular file directly inside ``basedir`` as a test image.

        Sub-directories are skipped. No labels are read here, despite the
        wording of the log message.

        :return: list of dicts with keys ``name`` (file name without its last
            extension), ``orininal_h``, ``orininal_w`` and ``image`` (PIL image).
        """
        imglist = []
        for fname in sorted(os.listdir(basedir)):
            image_path = os.path.join(basedir, fname)
            if not os.path.isfile(image_path):
                continue
            image = Image.open(image_path)
            imglist.append({
                # splitext keeps dotted stems such as "a.b.png" -> "a.b", so
                # such files no longer collide on the same output name.
                "name": os.path.splitext(fname)[0],
                "orininal_h": image.size[1],
                "orininal_w": image.size[0],
                "image": image,
            })
        print("共监测到{}张原始图像和标签".format(len(imglist)))
        return imglist

    def get_labels(self, basedir="/home/qiqq/q3dl/datalinan/taihuyuan_pv/", split='val'):
        """Open the label PNG of every sample listed in ``<basedir>/<split>.txt``.

        Each line of the split file is a sample name; its label is read from
        ``<basedir>/labels/<name>.png``. Empty lines are skipped.

        :return: list of dicts with keys ``name`` and ``label`` (PIL image).
        :raises AssertionError: if a listed label file does not exist.
        """
        labellist = []
        with open(os.path.join(basedir, split + '.txt'), "r") as f:
            names = f.read().splitlines()
        for name in names:
            if not name:
                continue
            label_path = os.path.join(basedir, 'labels', name + ".png")
            assert os.path.isfile(label_path), label_path
            labellist.append({"name": name, "label": Image.open(label_path)})
        print("共监测到{}张标签".format(len(labellist)))
        return labellist

    def get_result(self, net):
        """Predict every image from ``get_images`` and save mask (and overlay) PNGs.

        Each image is resized to 512x512 for the network; the predicted class
        map is resized back to the original resolution (nearest neighbour, so
        class ids are preserved) before saving, which keeps the saved mask
        aligned with its source image and with the label used in evaluation.

        :param net: model returning class scores for an NCHW float32 batch,
            or a list/tuple whose first element holds them. Assumed layout of
            the scores is (1, classes, h, w) -- TODO confirm for other models.
        """
        imglist = self.get_images()
        assert len(imglist) != 0
        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)
        # Move the model once instead of once per image.
        model = net.to(device=self.device)
        fusion_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        for item in tqdm(imglist):
            name = item["name"]
            orininal_w = item["orininal_w"]
            orininal_h = item["orininal_h"]
            # The network expects 3 channels and Image.blend needs matching
            # modes, so normalise RGBA / grayscale / palette inputs to RGB.
            src_img = item["image"].convert("RGB")
            resized = cv.resize(np.array(src_img), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
            image_data = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(resized, np.float32), md=False), (2, 0, 1)), 0)
            with torch.no_grad():
                inputs = torch.from_numpy(image_data).to(device=self.device, dtype=torch.float32)
                out = model(inputs)
                # Models with auxiliary heads return several outputs; use the main one.
                if isinstance(out, (list, tuple)):
                    out = out[0]
                out = out[0]  # drop the batch dimension
                pr = F.softmax(out.permute(1, 2, 0), dim=-1).cpu().numpy()
            result = pr.argmax(axis=-1).astype(np.uint8)
            # cv.resize takes dsize as (width, height).
            result = cv.resize(result, dsize=(orininal_w, orininal_h), interpolation=cv.INTER_NEAREST)

            output_im = Image.fromarray(result).convert('P')
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            if self.fusin:
                seg_img0 = fusion_palette[result]  # (orininal_h, orininal_w, 3) colour mask
                fusion = Image.blend(src_img, Image.fromarray(seg_img0), 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')

    def preprocess_input(self, image, md=False):
        """Scale a float image array to [0, 1], optionally mean/std normalising it.

        The array is modified in place and also returned. The mean/std values
        were, per the original note, computed for a Beijing dataset.
        """
        mean = (0.231, 0.217, 0.22)
        std = (0.104, 0.086, 0.085)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    @staticmethod
    def _metrics_from_confusion(confusion_matrix):
        """Derive segmentation metrics from a (num_classes, num_classes) confusion matrix.

        Rows are ground truth, columns predictions. The ``pv_*`` entries
        average over classes 1.. (class 0 excluded). Classes absent from the
        ground truth / predictions get NaN recall / precision (and F1).
        """
        pos = confusion_matrix.sum(1)  # pixels per ground-truth class (tp + fn)
        res = confusion_matrix.sum(0)  # pixels per predicted class (tp + fp)
        tp = np.diag(confusion_matrix)
        with np.errstate(divide='ignore', invalid='ignore'):
            IU_array = tp / np.maximum(1.0, pos + res - tp)
            precion = tp / res
            recall = tp / pos
            f1 = (2 * precion * recall) / (precion + recall)
            acc_global = tp.sum() / pos.sum()
        return {
            "IU_array": IU_array,
            "mean_IU": IU_array.mean(),
            "pv_iou": IU_array[1:].mean(),
            "precion": precion,
            "recall": recall,
            "f1": f1,
            "mf1": f1.mean(),
            "acc_global": acc_global,
            "pv_f1": f1[1:].mean(),
            "pv_precion": precion[1:].mean(),
            "pv_recall": recall[1:].mean(),
        }

    def compute_evalue(self, predict_dir, num_classes=2):
        """Compare saved prediction PNGs in ``predict_dir`` with the labels and print metrics.

        Pixels labelled 255 in the ground truth are ignored.

        :return: (pv_iou, pv_f1, pv_precion, pv_recall, acc_global)
        :raises AssertionError: if the number of labels and prediction files differ.
        """
        labellist = self.get_labels()  # PIL images, converted to arrays below
        predictresultlist = os.listdir(predict_dir)
        assert len(labellist) == len(predictresultlist)
        confusion_matrix = np.zeros((num_classes, num_classes))
        for item in tqdm(labellist):
            name = item["name"]
            seg_gt = np.array(item["label"])
            seg_pred = np.array(Image.open(os.path.join(predict_dir, name + ".png")))
            valid = seg_gt != 255  # 255 marks pixels to ignore
            confusion_matrix += self.get_confusion_matrix(seg_gt[valid], seg_pred[valid], num_classes)
        m = self._metrics_from_confusion(confusion_matrix)
        print("测试结果")
        print("acc_global:", round(m["acc_global"], 4))
        print("IU_array:", np.round(m["IU_array"], 4))
        print("precion:", np.round(m["precion"], 4))
        print("recall:", np.round(m["recall"], 4))
        print("f1:", np.round(m["f1"], 4))
        print("miou:", round(m["mean_IU"], 4))
        print("pv_iou:", round(m["pv_iou"], 4))
        print("pv_precion:", round(m["pv_precion"], 4))
        print("pv_recall:", round(m["pv_recall"], 4))
        print("pv_f1:", round(m["pv_f1"], 4))
        return m["pv_iou"], m["pv_f1"], m["pv_precion"], m["pv_recall"], m["acc_global"]
class mypv_bigmap:
    """Sliding-window prediction for large, unlabelled images.

    Every image directly inside ``basedir`` is split into overlapping tiles,
    each tile is run through the network, and the averaged per-pixel scores
    are turned into a palette mask saved under
    ``<model_name>/<model_name>_outputs/`` and, when ``fusin`` is True, an
    image/mask overlay under ``<model_name>/<model_name>_fusion/``.
    """

    def __init__(self, model_name="tongzhou", basedir="/home/qiqq/q3dl/test_temple/ground_pv/"):
        """
        :param model_name: output root directory, and prefix of the output
            sub-directories and overlay file names.
        :param basedir: directory scanned (non-recursively) for input images.
        """
        # Input directories used in earlier runs:
        #   /home/qiqq/q3dl/applicationdata/beijing/tz/tz8
        #   /home/qiqq/q3dl/test_temple/re/
        self.basedir = basedir
        self.modelnamedir = model_name
        os.makedirs(self.modelnamedir, exist_ok=True)
        self.predict_dir = self.modelnamedir + "/" + model_name + "_" + "outputs/"
        self.fusion_path = self.modelnamedir + "/" + model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat [R, G, B, ...] palette of the saved mask:
        # class 0 -> black, classes 1 and 2 -> green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []
        # When True, also save a 60% image / 40% colour-mask blend.
        self.fusin = True

    def pad_image(self, img, target_size):
        """Zero-pad an NCHW array on the bottom/right up to ``target_size`` (H, W)."""
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        return np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant')

    @staticmethod
    def _model_device(net):
        """Device of ``net``'s first parameter; CUDA if that cannot be determined."""
        try:
            return next(net.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cuda")

    def predict_sliding(self, net, image, tile_size, classes):
        """Predict a (1, C, H, W) float array tile by tile and average overlaps.

        :param net: model mapping a (1, C, th, tw) tensor to class scores of
            shape (1, classes, h, w), or a list/tuple whose first element is
            that (assumed layout -- TODO confirm for other models). Scores are
            bilinearly resized to ``tile_size`` before accumulation.
        :param image: NCHW numpy array (batch of one).
        :param tile_size: (tile_h, tile_w) window size.
        :param classes: number of score channels.
        :return: float array of shape (H, W, classes) with the per-pixel mean
            of the scores from every window covering that pixel.
        """
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        image_size = image.shape
        overlap = 1 / 3
        # Window step, e.g. a 512 tile with 1/3 overlap advances ceil(512 * 2/3) = 342 px.
        stride = ceil(tile_size[0] * (1 - overlap))
        # At least one window per axis: otherwise sides no longer than
        # tile - stride get zero windows and the final division yields NaN.
        tile_rows = max(int(ceil((image_size[2] - tile_size[0]) / stride) + 1), 1)
        tile_cols = max(int(ceil((image_size[3] - tile_size[1]) / stride) + 1), 1)
        full_probs = np.zeros((image_size[2], image_size[3], classes))  # summed scores
        count_predictions = np.zeros((image_size[2], image_size[3], classes))  # windows per pixel
        device = self._model_device(net)
        for row in range(tile_rows):
            for col in range(tile_cols):
                # Clamp the window to the image, then shift it back so it keeps
                # its full size whenever the image is large enough.
                x2 = min(int(col * stride) + tile_size[1], image_size[3])
                y2 = min(int(row * stride) + tile_size[0], image_size[2])
                x1 = max(int(x2 - tile_size[1]), 0)
                y1 = max(int(y2 - tile_size[0]), 0)
                img = image[:, :, y1:y2, x1:x2]
                padded_img = self.pad_image(img, tile_size)  # only pads images smaller than a tile
                with torch.no_grad():
                    padded_prediction = net(torch.from_numpy(padded_img).to(device))
                    # Models with auxiliary heads return several outputs; use the main one.
                    if isinstance(padded_prediction, (list, tuple)):
                        padded_prediction = padded_prediction[0]
                    padded_prediction = interp(padded_prediction)[0].cpu().numpy().transpose(1, 2, 0)
                # Drop the scores that belong to the zero padding.
                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction
        # A pixel covered by k windows holds the sum of k score vectors;
        # dividing by the count turns that into their mean.
        full_probs /= count_predictions
        return full_probs

    # Entry point: load the images, predict them and save the result images.
    def get_resut(self, net, titlesize=(512, 512), classes=2):
        """Predict every image in ``self.basedir`` and save mask (and overlay) PNGs.

        :param net: model used by ``predict_sliding``.
        :param titlesize: sliding-window tile size (H, W).
        :param classes: number of score channels produced by ``net``.
        :raises AssertionError: if no input images were found.
        """
        imglist = self.get_imgs()
        assert len(imglist) != 0
        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)
        fusion_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        for item in tqdm(imglist):
            image = item["image"]
            name = item["name"]
            imagedata = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(image, np.float32), md=False), (2, 0, 1)), 0)
            output = self.predict_sliding(net=net, image=imagedata, tile_size=titlesize, classes=classes)
            output = F.softmax(torch.from_numpy(output), dim=2).numpy()
            seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_im = PILImage.fromarray(seg_pred)
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            if self.fusin:
                seg_img0 = fusion_palette[seg_pred]  # (H, W, 3) colour mask at full resolution
                fusion = Image.blend(image, Image.fromarray(seg_img0), 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')
        print("预测结果图生成")

    def cvtColor(self, image):
        """Return ``image`` unchanged if it has 3 channels, else ``image.convert('RGB')``."""
        # Computed once: for a PIL image np.shape materialises the pixel array.
        shape = np.shape(image)
        if len(shape) == 3 and shape[2] == 3:
            return image
        # Single-channel images (2-D shape) and 4-channel images both end up
        # here; indexing shape[2] on a 2-D shape would raise IndexError.
        return image.convert('RGB')

    def preprocess_input(self, image, md=False):
        """Scale a float image array to [0, 1], optionally mean/std normalising it.

        The array is modified in place and also returned. The mean/std values
        were, per the original note, computed for a Beijing dataset.
        """
        mean = (0.231, 0.217, 0.22)
        std = (0.104, 0.086, 0.085)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    # Reading without labels, i.e. pure prediction.
    def get_imgs(self, ):
        """Load every image file directly inside ``self.basedir``.

        Files ending in ``.xml``, ``pgw`` or ``prj`` (presumably georeferencing
        sidecars) and sub-directories are skipped. ``self.imglist`` is rebuilt
        on every call, so repeated calls do not accumulate duplicates.

        :return: ``self.imglist``, a list of dicts with keys ``name`` (file
            name without its last extension), ``orininal_h``, ``orininal_w``
            and ``image`` (3-channel PIL image).
        """
        # An earlier "all" mode walked one level of sub-directories of
        # basedir (e.g. tz1, tz2) instead of basedir itself.
        skipped_suffixes = (".xml", "pgw", "prj")
        self.imglist = []
        for fname in sorted(os.listdir(self.basedir)):
            image_path = os.path.join(self.basedir, fname)
            if fname.endswith(skipped_suffixes) or not os.path.isfile(image_path):
                continue
            image = Image.open(image_path)
            self.imglist.append({
                # splitext keeps dotted stems ("a.b.tif" -> "a.b") so such
                # files do not overwrite each other's outputs.
                "name": os.path.splitext(fname)[0],
                "orininal_h": image.size[1],
                "orininal_w": image.size[0],
                "image": self.cvtColor(image),
            })
        print("共监测到{}张测试图像".format(len(self.imglist)))
        return self.imglist
if __name__ == '__main__':
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import deeplabv3plus_resnet50
    model = deeplabv3plus_resnet50(num_classes=2, output_stride=8, pretrained_backbone=False)
    # Checkpoint from the 2023-06-30-22.06 training run.
    model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/deeplabv3Plus/train_val_test/checkpoints/2023-06-30-22.06/best.pth"
    # Deserialize onto the CPU first: CUDA_VISIBLE_DEVICES="1" (top of file)
    # leaves a single visible GPU, so tensors saved from a different GPU index
    # might not be restorable to their original device. The model is moved to
    # the GPU explicitly below.
    model_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.cuda()
    testss = mypv_bigmap(model_name="d3pgp1_pv")
    testss.get_resut(net=model)
    # Optional evaluation. NOTE: compute_evalue is defined on predictandeval,
    # not on mypv_bigmap, so these lines need a predictandeval instance.
    # predictoutpath = "./PVReunetpamcarb2_outputs/"
    # pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath)
    #
    # resultsavepath = "./PVReunetpamcarb2.txt"
    # with open(resultsavepath, "a") as f:
    #     f.write("测试结果\n")
    #     f.write(f"acc_global:{round(acc_global, 4)}\n")
    #     f.write(f"pv_iou:{round(pv_iou, 4)}\n")
    #     f.write(f"pv_f1:{round(pv_f1, 4)}\n")
    #     f.write(f"pv_precion:{round(pv_precion, 4)}\n")
    #     f.write(f"pv_recall:{round(pv_recall, 4)}\n")
    # print("写入完成")