464 lines
19 KiB
Python
464 lines
19 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
@project:
|
||
@File : xaiotu
|
||
@Author : qiqq
|
||
@create_time : 2023/6/29 8:57
|
||
"""
|
||
'''专门应对新图预测的'''
|
||
import os
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||
import numpy as np
|
||
import copy
|
||
import cv2 as cv
|
||
from tqdm import tqdm
|
||
import os
|
||
from math import ceil
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import torch.nn as nn
|
||
from PIL import Image
|
||
from PIL import Image as PILImage
|
||
'''voc数据集格式的'''
|
||
|
||
|
||
|
||
class predictandeval():
    """Run a segmentation network over a folder of images and score the masks.

    ``get_result`` writes one palettised PNG mask per input image (and,
    when ``self.fusin`` is true, an overlay of the mask on the input).
    ``compute_evalue`` compares a directory of such masks with the label
    PNGs listed in the ``val`` split file.
    """

    def __init__(self, model_name="xx"):
        """Derive the output locations from ``model_name``.

        :param model_name: prefix used for the output directories and for
            the overlay file names.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.predict_dir = model_name + "_" + "outputs/"
        self.fusion_path = model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat RGB palette attached to the saved masks.
        # NOTE(review): indices 1 and 2 both map to (0, 255, 0) here, while
        # the overlay in get_result paints index 1 red -- confirm intent.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]

        # When true, get_result also writes image/mask overlays.
        self.fusin = True

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix from ground-truth and predicted labels.

        Pixels whose ground-truth or predicted label lies outside
        ``[0, class_num)`` are ignored instead of being counted in an
        unrelated cell.

        :param gt_label: the ground truth labels (any shape)
        :param pred_label: the predicted labels (same shape as ``gt_label``)
        :param class_num: the number of classes
        :return: ``(class_num, class_num)`` float array; rows are ground
            truth, columns are predictions
        """
        # Widen before multiplying: labels read from PNGs are uint8 and
        # ``gt * class_num`` would otherwise wrap around.
        gt = np.asarray(gt_label).ravel().astype(np.int64)
        pred = np.asarray(pred_label).ravel().astype(np.int64)

        valid = (gt >= 0) & (gt < class_num) & (pred >= 0) & (pred < class_num)
        index = gt[valid] * class_num + pred[valid]

        # minlength guarantees every (gt, pred) cell exists even if unseen.
        label_count = np.bincount(index, minlength=class_num * class_num)
        return label_count.reshape(class_num, class_num).astype(np.float64)

    # Image loader: every regular file inside a fixed test directory.
    def get_images(self, ):
        """Open every file in the hard-coded image directory.

        :return: list of dicts with keys ``name`` (file name up to the first
            dot), ``orininal_h``, ``orininal_w`` and ``image`` (PIL image).
        """
        basedir = "/home/qiqq/q3dl/test_temple/ground_pv/"
        imglist = []

        # Sorted so that the processing order is reproducible.
        for line in sorted(os.listdir(basedir)):
            _imagepath = os.path.join(basedir, line)
            # Skip sub-directories instead of aborting the whole run.
            if not os.path.isfile(_imagepath):
                continue

            name = line.split(".")[0]
            image = Image.open(_imagepath)
            orininal_w, orininal_h = image.size
            item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
            imglist.append(item)

        print("共监测到{}张原始图像和标签".format(len(imglist)))
        return imglist

    def get_labels(self, ):
        """Open the label PNG of every sample listed in the ``val`` split file.

        :return: list of dicts with keys ``name`` and ``label`` (PIL image).
        """
        basedir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        split = 'val'
        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"
        labellist = []

        with open(os.path.join(_splits_dir, split + '.txt'), "r") as f:
            lines = f.read().splitlines()

        for line in lines:
            _labelpath = os.path.join(basedir, 'labels', line + ".png")
            assert os.path.isfile(_labelpath), _labelpath

            label = Image.open(_labelpath)
            labellist.append({"name": line, "label": label})

        print("共监测到{}张标签".format(len(labellist)))
        return labellist

    def get_result(self, net):
        """Predict a mask for every image returned by ``get_images``.

        Each image is resized to 512x512, scaled to ``[0, 1]`` and passed
        through ``net``; the arg-max label map is resized back to the
        original image size (nearest neighbour) before it is saved, so the
        mask and the overlay always match the input resolution.

        :param net: the segmentation network; if it returns a list/tuple
            only the first element is used.
        """
        imglist = self.get_images()
        assert len(imglist) != 0

        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)

        # Moving the model is loop-invariant; use the device chosen in
        # __init__ rather than assuming CUDA is present.
        model = net.to(self.device)
        overlay_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)

        for item in tqdm(imglist):
            name = item["name"]
            orininal_w = item["orininal_w"]
            orininal_h = item["orininal_h"]
            # A 3-channel RGB copy keeps the network input and the overlay
            # blend working for grayscale / RGBA / palette files as well.
            old_img = item["image"].convert("RGB")

            imaged = cv.resize(np.array(old_img), dsize=(512, 512), interpolation=cv.INTER_LINEAR)
            image_data = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(imaged, np.float32), md=False), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data).to(device=self.device, dtype=torch.float32)
                out = model(images)
                if isinstance(out, (list, tuple)):  # auxiliary heads may add outputs
                    out = out[0]  # keep the first one only
                out = out[0]  # drop the batch dimension

                # Softmax is monotonic, so the arg-max over the raw class
                # scores gives the same labels as arg-max over softmax.
                result = out.argmax(dim=0).cpu().numpy().astype(np.uint8)

            # Back to the input resolution; nearest neighbour keeps the
            # values valid class indices.
            result = cv.resize(result, dsize=(orininal_w, orininal_h), interpolation=cv.INTER_NEAREST)

            # Mask image
            output_im = Image.fromarray(result).convert('P')
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            # Overlay of the coloured mask on the input image
            if self.fusin:
                image0 = Image.fromarray(overlay_palette[result])
                fusion = Image.blend(old_img, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')

    def preprocess_input(self, image, md=False):
        """Scale ``image`` to ``[0, 1]`` and optionally standardise it.

        The array is modified in place and also returned.

        :param image: float array whose last axis has three channels
            (required only when ``md`` is true)
        :param md: when true, additionally subtract ``mean`` and divide by
            ``std``
        :return: the same array object, after scaling
        """
        mean = (0.231, 0.217, 0.22)  # for the Beijing dataset
        std = (0.104, 0.086, 0.085)

        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    @staticmethod
    def _compute_metrics(confusion_matrix):
        """Derive per-class IoU / precision / recall / F1 and global accuracy.

        Every denominator is guarded the same way as the IoU, so a class
        that never occurs scores 0 instead of NaN.

        :return: ``(IU_array, precion, recall, f1, acc_global)``
        """
        pos = confusion_matrix.sum(1)  # true pixels per class (tp + fn)
        res = confusion_matrix.sum(0)  # predicted pixels per class (tp + fp)
        tp = np.diag(confusion_matrix)

        IU_array = tp / np.maximum(1.0, pos + res - tp)
        precion = tp / np.maximum(1.0, res)
        recall = tp / np.maximum(1.0, pos)

        denom = precion + recall
        f1 = np.divide(2 * precion * recall, denom, out=np.zeros_like(denom), where=denom > 0)
        acc_global = tp.sum() / max(1.0, pos.sum())
        return IU_array, precion, recall, f1, acc_global

    def compute_evalue(self, predict_dir, ):
        """Score the masks in ``predict_dir`` against the ``val`` labels.

        Label pixels equal to 255 are excluded. Metrics whose name starts
        with ``pv_`` are averaged over every class except class 0.

        :param predict_dir: directory holding ``<name>.png`` predictions
        :return: ``(pv_iou, pv_f1, pv_precion, pv_recall, acc_global)``
        """
        labellist = self.get_labels()  # labels are still PIL images here
        predictresultlist = os.listdir(predict_dir)
        assert len(labellist) == len(predictresultlist)
        num_classes = 2

        confusion_matrix = np.zeros((num_classes, num_classes))

        for item in tqdm(labellist):
            name = item["name"]
            seg_gt = np.array(item["label"])
            with Image.open(os.path.join(predict_dir, name + ".png")) as pred_im:
                seg_pred = np.array(pred_im)

            keep = seg_gt != 255  # drop "ignore" pixels
            seg_gt = seg_gt[keep]
            seg_pred = seg_pred[keep]

            confusion_matrix += self.get_confusion_matrix(seg_gt, seg_pred, num_classes)

        IU_array, precion, recall, f1, acc_global = self._compute_metrics(confusion_matrix)

        mean_IU = IU_array.mean()
        pv_iou = IU_array[1:].mean()
        pv_f1 = f1[1:].mean()
        pv_precion = precion[1:].mean()
        pv_recall = recall[1:].mean()

        print("测试结果")
        print("acc_global:", round(acc_global, 4))
        print("IU_array:", np.round(IU_array, 4))
        print("precion:", np.round(precion, 4))
        print("recall:", np.round(recall, 4))
        print("f1:", np.round(f1, 4))
        print("miou:", round(mean_IU, 4))
        print("pv_iou:", round(pv_iou, 4))
        print("pv_precion:", round(pv_precion, 4))
        print("pv_recall:", round(pv_recall, 4))
        print("pv_f1:", round(pv_f1, 4))

        return pv_iou, pv_f1, pv_precion, pv_recall, acc_global
|
||
class mypv_bigmap:
    """Sliding-window prediction for large images that have no labels.

    ``get_resut`` is the entry point: it loads every image in
    ``self.basedir``, predicts it tile by tile and writes a palettised
    mask (plus an overlay when ``self.fusin`` is true).
    """

    def __init__(self, model_name="tongzhou",):
        """Create the output root directory and derive the output paths.

        :param model_name: name of the output root directory and prefix of
            the sub-directories / overlay file names.
        """
        self.basedir = "/home/qiqq/q3dl/test_temple/ground_pv/"

        self.modelnamedir = model_name

        if not os.path.exists(self.modelnamedir):
            os.makedirs(self.modelnamedir)
        self.predict_dir = self.modelnamedir + "/" + model_name + "_" + "outputs/"
        self.fusion_path = self.modelnamedir + "/" + model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat RGB palette attached to the saved masks.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []

        # When true, get_resut also writes image/mask overlays.
        self.fusin = True

    def pad_image(self, img, target_size):
        """Zero-pad the last two axes of a 4-D array up to ``target_size``.

        :param img: array indexed as ``[batch, channel, row, col]``
        :param target_size: ``(rows, cols)`` to pad to
        :return: the padded array (padding is added at the bottom / right)
        """
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        padded_img = np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant')
        return padded_img

    @staticmethod
    def _model_device(net):
        """Return the device of ``net``'s first parameter.

        Falls back to CUDA when available (else CPU) for callables that
        expose no parameters.
        """
        try:
            return next(net.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict_sliding(self, net, image, tile_size, classes):
        """Predict ``image`` tile by tile and average overlapping outputs.

        Windows of ``tile_size`` slide with one third overlap. Windows that
        would cross the border are shifted back inside the image; images
        smaller than a tile are zero-padded. Every pixel's output is the
        mean of the network outputs of all windows covering it.

        :param net: network called on a ``[1, C, tile_h, tile_w]`` tensor;
            if it returns a list/tuple only the first element is used
        :param image: array indexed as ``[1, C, H, W]``
        :param tile_size: ``(tile_h, tile_w)``
        :param classes: number of output channels of ``net``
        :return: ``(H, W, classes)`` array of averaged network outputs
        """
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        image_size = image.shape
        overlap = 1 / 3
        device = self._model_device(net)

        # One stride per axis so that non-square tiles leave no gaps
        # (for a 512 tile the stride is 342 px).
        stride_y = ceil(tile_size[0] * (1 - overlap))
        stride_x = ceil(tile_size[1] * (1 - overlap))
        # Strided-convolution formula; at least one tile per axis so that
        # images much smaller than a tile are still predicted.
        tile_rows = max(1, int(ceil((image_size[2] - tile_size[0]) / stride_y) + 1))
        tile_cols = max(1, int(ceil((image_size[3] - tile_size[1]) / stride_x) + 1))

        full_probs = np.zeros((image_size[2], image_size[3], classes))  # running sum
        count_predictions = np.zeros((image_size[2], image_size[3], classes))  # hits per pixel

        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride_x)
                y1 = int(row * stride_y)
                # Clamp the window end to the border ...
                x2 = min(x1 + tile_size[1], image_size[3])
                y2 = min(y1 + tile_size[0], image_size[2])
                # ... and move its start back so it keeps the full tile size
                # whenever the image is large enough.
                x1 = max(int(x2 - tile_size[1]), 0)
                y1 = max(int(y2 - tile_size[0]), 0)

                img = image[:, :, y1:y2, x1:x2]
                # Only pads when the image itself is smaller than a tile.
                padded_img = self.pad_image(img, tile_size)

                with torch.no_grad():
                    padded_prediction = net(torch.from_numpy(padded_img).to(device))
                    if isinstance(padded_prediction, (list, tuple)):  # auxiliary heads
                        padded_prediction = padded_prediction[0]
                    # Bring the output back to the tile size.
                    padded_prediction = interp(padded_prediction).cpu()[0].numpy().transpose(1, 2, 0)

                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction

        # A pixel covered by three windows holds the sum of three outputs
        # (e.g. 0.8 + 0.7 + 0.6) and a count of 3; the division below turns
        # that into the mean.
        full_probs /= count_predictions

        return full_probs

    # Entry point: call this directly.
    def get_resut(self, net, titlesize=(512, 512), classes=2):
        """Predict every image in ``self.basedir`` and write the results.

        :param net: segmentation network, see ``predict_sliding``
        :param titlesize: ``(tile_h, tile_w)`` of the sliding window
        :param classes: number of output channels of ``net``
        """
        self.get_imgs()
        assert len(self.imglist) != 0

        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)
        overlay_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)

        for item in tqdm(self.imglist):
            image = item["image"]
            name = item["name"]

            # np.array(...) copies, so the PIL image itself stays untouched
            # and can be reused for the overlay.
            imagedata = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(image, np.float32), md=False), (2, 0, 1)), 0)
            output = self.predict_sliding(net=net, image=imagedata, tile_size=titlesize, classes=classes)

            # Softmax is monotonic, so the arg-max of the averaged outputs
            # equals the arg-max of their softmax.
            seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            # Mask image
            output_im = PILImage.fromarray(seg_pred)
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            # Overlay of the coloured mask on the input image
            if self.fusin:
                image0 = Image.fromarray(overlay_palette[seg_pred])
                fusion = Image.blend(image, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')

        print("预测结果图生成")

    def cvtColor(self, image):
        """Return ``image`` unchanged if it has 3 channels, else as RGB.

        :param image: object offering ``convert('RGB')`` (e.g. a PIL image)
        :return: ``image`` itself when its array shape is ``(H, W, 3)``,
            otherwise ``image.convert('RGB')``
        """
        shape = np.shape(image)
        # 2-D shapes (grayscale / palette) must be tested by length first;
        # indexing shape[2] on them would raise IndexError.
        if len(shape) == 3 and shape[2] == 3:
            return image
        return image.convert('RGB')

    def preprocess_input(self, image, md=False):
        """Scale ``image`` to ``[0, 1]`` and optionally standardise it.

        The array is modified in place and also returned.

        :param image: float array whose last axis has three channels
            (required only when ``md`` is true)
        :param md: when true, additionally subtract ``mean`` and divide by
            ``std``
        :return: the same array object, after scaling
        """
        mean = (0.231, 0.217, 0.22)  # for the Beijing dataset
        std = (0.104, 0.086, 0.085)

        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    # Loader for unlabeled images (pure prediction).
    def get_imgs(self,):
        """(Re)build ``self.imglist`` from the files in ``self.basedir``.

        Files ending in ``.xml``, ``pgw`` or ``prj`` and non-file entries
        are skipped. The list is reset first, so repeated calls do not
        duplicate entries.

        :return: ``self.imglist``; each item has keys ``name`` (file name
            up to the first dot), ``orininal_h``, ``orininal_w`` and
            ``image`` (the output of ``cvtColor``).
        """
        self.imglist = []

        skipped_suffixes = (".xml", "pgw", "prj")
        for entry in sorted(os.listdir(self.basedir)):
            if entry.endswith(skipped_suffixes):
                continue
            path = os.path.join(self.basedir, entry)
            if not os.path.isfile(path):
                continue

            name = os.path.basename(path).split(".")[0]
            image = Image.open(path)
            orininal_w, orininal_h = image.size
            image = self.cvtColor(image, )
            item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
            self.imglist.append(item)

        print("共监测到{}张测试图像".format(len(self.imglist)))
        return self.imglist
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: load a trained DeepLabV3+ checkpoint and run the
    # sliding-window predictor over the images in mypv_bigmap.basedir.
    from taihuyuan_pv.compared_experiment.deeplabv3Plus.model.modeling import deeplabv3plus_resnet50

    model = deeplabv3plus_resnet50(num_classes=2, output_stride=8, pretrained_backbone=False)

    model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/deeplabv3Plus/train_val_test/checkpoints/2023-06-30-22.06/best.pth"
    # 2023-06-30-22.06
    # Load onto the CPU first: the checkpoint may have been saved from a GPU
    # index that is not visible under the CUDA_VISIBLE_DEVICES set above.
    model_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.cuda()

    testss = mypv_bigmap(model_name="d3pgp1_pv")
    testss.get_resut(net=model)

    # Disabled evaluation snippet, kept for reference.
    # NOTE(review): mypv_bigmap defines no compute_evalue; the method lives
    # on predictandeval, so this needs a predictandeval instance to run.
    # predictoutpath = "./PVReunetpamcarb2_outputs/"
    # pv_iou, pv_f1, pv_precion, pv_recall, acc_global = testss.compute_evalue(predictoutpath)
    #
    # resultsavepath = "./PVReunetpamcarb2.txt"
    # with open(resultsavepath, "a") as f:
    #     f.write("测试结果\n")
    #     f.write(f"acc_global:{round(acc_global, 4)}\n")
    #     f.write(f"pv_iou:{round(pv_iou, 4)}\n")
    #     f.write(f"pv_f1:{round(pv_f1, 4)}\n")
    #     f.write(f"pv_precion:{round(pv_precion, 4)}\n")
    #     f.write(f"pv_recall:{round(pv_recall, 4)}\n")
    #     print("写入完成")
|
||
|
||
|
||
|
||
|