# ai-station-code/wudingpv/taihuyuan_pv/compared_experiment/deeplabv3Plus/test/testpv.py
#
# 228 lines
# 9.5 KiB
# Python
# Raw Normal View History
#
# 2025-05-06 11:18:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57
"""
'''专门应对新图预测的'''
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import matplotlib.pyplot as plt
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
from math import ceil
import argparse
from torch.utils.data import Dataset
from torchvision import transforms
from taihuyuan_pv.dataloaders import custom_transforms as tr
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from PIL import Image as PILImage
'''voc数据集格式的'''
Image.MAX_IMAGE_PIXELS = None
class mypv_bigmap:
    """Sliding-window segmentation inference over a folder of large images.

    Every image found in ``basedir`` is predicted tile by tile, the
    overlapping tile outputs are averaged, and the per-pixel argmax is saved
    as a palettised PNG.  When ``self.fusin`` is true a blended overlay of
    the input and the coloured mask is saved as well.
    """

    def __init__(self, model_name="tongzhou", basedir=None):
        """Set up the input folder and the output locations.

        Args:
            model_name: Used as the output root directory (created if
                missing) and as a prefix for the output sub-folders/files.
            basedir: Folder holding the images to predict.  ``None`` keeps
                the path that was previously hard-coded here.
        """
        if basedir is None:
            basedir = "/home/qiqq/q3dl/datalinan/paper2_groundpv/b1"
        self.basedir = basedir
        self.modelnamedir = model_name
        os.makedirs(self.modelnamedir, exist_ok=True)
        self.predict_dir = self.modelnamedir + "/" + model_name + "_" + "outputs/"
        self.fusion_path = self.modelnamedir + "/" + model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat RGB palette handed to PIL's putpalette() for the mask PNGs.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []
        # When true, get_resut() also writes an input/mask blend.
        self.fusin = True

    @staticmethod
    def _model_device(net):
        """Return the device of ``net``'s first parameter.

        Falls back to CUDA when available (CPU otherwise) for callables that
        expose no parameters.
        """
        try:
            return next(net.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def pad_image(self, img, target_size):
        """Zero-pad a 4-D array at the bottom/right up to ``target_size``.

        Args:
            img: Array indexed as (batch, channel, rows, cols).
            target_size: ``(rows, cols)`` to reach.

        Returns:
            The padded array (a new array produced by ``np.pad``).
        """
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        padded_img = np.pad(
            img,
            ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)),
            'constant',
        )
        return padded_img

    def predict_sliding(self, net, image, tile_size, classes):
        """Predict ``image`` tile by tile and average overlapping outputs.

        Args:
            net: Callable network; receives a tensor built from one padded
                tile.  If it returns a list/tuple only the first element is
                used.
            image: NumPy array indexed as (batch, channel, rows, cols); only
                batch element 0 of the network output is kept.
            tile_size: ``(rows, cols)`` of the window fed to the network.
            classes: Number of channels in the network output.

        Returns:
            ``(rows, cols, classes)`` float array holding, for every pixel,
            the mean of the network outputs of all tiles covering it.
        """
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        height, width = image.shape[2], image.shape[3]
        overlap = 1 / 3
        # One stride per axis so non-square tiles still cover every pixel
        # (e.g. a 512 tile with 1/3 overlap moves 342 px per step).
        stride_rows = ceil(tile_size[0] * (1 - overlap))
        stride_cols = ceil(tile_size[1] * (1 - overlap))
        # Strided-convolution formula; clamp to one tile so images much
        # smaller than the window are still predicted instead of yielding
        # an all-NaN result from a 0/0 division.
        tile_rows = max(int(ceil((height - tile_size[0]) / stride_rows) + 1), 1)
        tile_cols = max(int(ceil((width - tile_size[1]) / stride_cols) + 1), 1)

        device = self._model_device(net)
        full_probs = np.zeros((height, width, classes))  # accumulated outputs
        # Number of predictions per pixel; broadcast over the class axis.
        count_predictions = np.zeros((height, width, 1))

        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride_cols)
                y1 = int(row * stride_rows)
                # Clip the window end to the image border ...
                x2 = min(x1 + tile_size[1], width)
                y2 = min(y1 + tile_size[0], height)
                # ... then shift the start back so the window keeps its full
                # size whenever the image is large enough.
                x1 = max(int(x2 - tile_size[1]), 0)
                y1 = max(int(y2 - tile_size[0]), 0)

                img = image[:, :, y1:y2, x1:x2]
                # Pad crops that are still smaller than the tile.
                padded_img = self.pad_image(img, tile_size)

                with torch.no_grad():
                    padded_prediction = net(torch.from_numpy(padded_img).to(device))
                    # Networks with auxiliary heads may return several
                    # outputs; keep the first one.
                    if isinstance(padded_prediction, (list, tuple)):
                        padded_prediction = padded_prediction[0]
                    # Resize the output back to the tile size.
                    padded_prediction = (
                        interp(padded_prediction).cpu()[0].numpy().transpose(1, 2, 0)
                    )

                # Drop the padded area before accumulating.
                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction

        # A pixel covered by k tiles holds the sum of k outputs; dividing by
        # the count turns that into their mean.
        full_probs /= count_predictions
        return full_probs

    # Main entry point: call this to run the whole prediction.
    def get_resut(self, net, titlesize=(512, 512), classes=2):
        """Predict every image in ``self.basedir`` and save the results.

        Args:
            net: Network passed through to :meth:`predict_sliding`.
            titlesize: ``(rows, cols)`` of the sliding window.
            classes: Number of output channels of ``net``.

        Raises:
            AssertionError: If no image was found.
        """
        self.get_imgs()
        assert len(self.imglist) != 0, "no test images found in basedir"

        os.makedirs(self.predict_dir, exist_ok=True)
        if self.fusin:
            os.makedirs(self.fusion_path, exist_ok=True)

        for item in tqdm(self.imglist):
            image = item["image"]
            name = item["name"]

            # HWC image scaled to [0, 1] -> 1 x C x H x W float32 batch.
            imagedata = np.expand_dims(
                np.transpose(
                    self.preprocess_input(np.array(image, np.float32), md=False),
                    (2, 0, 1),
                ),
                0,
            )
            output = self.predict_sliding(
                net=net, image=imagedata, tile_size=titlesize, classes=classes
            )
            output = F.softmax(torch.from_numpy(output), dim=2).numpy()
            seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            # Palettised class-index mask.
            output_im = PILImage.fromarray(seg_pred)
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            # Overlay of the coloured mask on the input image.
            if self.fusin:
                PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
                # Index the colour table with the H x W class map -> H x W x 3.
                seg_img0 = np.array(PALETTE, np.uint8)[seg_pred]
                image0 = Image.fromarray(np.uint8(seg_img0))
                # The input image is never mutated, so no copy is needed.
                fusion = Image.blend(image, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')

        print("预测结果图生成")

    def cvtColor(self, image):
        """Return ``image`` unchanged if it has three channels, else as RGB.

        PIL images are inspected through ``getbands()`` so the pixel data is
        not decoded just to count channels; other inputs are inspected
        through ``np.shape``.  Single-band images (grayscale, palette) are
        converted instead of raising ``IndexError``.
        """
        if hasattr(image, "getbands"):
            three_channels = len(image.getbands()) == 3
        else:
            shape = np.shape(image)
            three_channels = len(shape) == 3 and shape[2] == 3
        if three_channels:
            return image
        return image.convert('RGB')

    def preprocess_input(self, image, md=False):
        """Scale pixel values to [0, 1] and optionally standardise them.

        Floating-point arrays are modified in place and returned; integer
        arrays are first converted to a new float32 array (in-place true
        division is not possible on integer dtypes).

        Args:
            image: Array whose last axis has three channels when ``md`` is
                true.
            md: When true, also subtract ``mean`` and divide by ``std``.

        Returns:
            The normalised array.
        """
        mean = (0.231, 0.217, 0.22)  # statistics for the Beijing dataset
        std = (0.104, 0.086, 0.085)
        image = np.asarray(image)
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    # Loader for unlabeled images (pure prediction).
    def get_imgs(self):
        """Open every image file in ``self.basedir``.

        Files ending in ``.xml``, ``pgw`` or ``prj`` and anything that is not
        a regular file are skipped.  ``self.imglist`` is rebuilt from scratch
        on every call so repeated calls do not duplicate entries.

        Returns:
            ``self.imglist``: a list of dicts with keys ``name``,
            ``orininal_h``, ``orininal_w`` and ``image``.
        """
        skipped_suffixes = (".xml", "pgw", "prj")
        self.imglist = []
        for filename in sorted(os.listdir(self.basedir)):
            if filename.endswith(skipped_suffixes):
                continue
            path = os.path.join(self.basedir, filename)
            if not os.path.isfile(path):
                continue
            # Text before the first dot of the file name.
            name = filename.split(".")[0]
            image = Image.open(path)
            orininal_h = image.size[1]
            orininal_w = image.size[0]
            image = self.cvtColor(image)
            self.imglist.append(
                {
                    "name": name,
                    "orininal_h": orininal_h,
                    "orininal_w": orininal_w,
                    "image": image,
                }
            )
        print("共监测到{}张测试图像".format(len(self.imglist)))
        return self.imglist
if __name__ == '__main__':
    # Script entry point: build the network, restore its weights from the
    # checkpoint's 'net' entry, and run the prediction with the class defaults.
    from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb

    network = resUnetpamcarb()
    checkpoint_file = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/mitunet/train_val_test/checkpoints/2023-10-16-14.34//best.pth"
    checkpoint = torch.load(checkpoint_file)
    network.load_state_dict(checkpoint['net'])
    print("权重加载")

    # Inference mode, then move the weights to the GPU.
    network.eval()
    network.cuda()

    predictor = mypv_bigmap(model_name="respamcbam2PV_version31b1")
    predictor.get_resut(net=network)