#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : xaiotu
@Author : qiqq
@create_time : 2023/6/29 8:57

Sliding-window segmentation of new, unlabelled (possibly very large) images.
"""
# Dedicated to predicting on new images (no labels are read).
import os

# Restrict CUDA to GPU 0; set before torch is imported so it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# NOTE(review): plt, copy, cv, DataLoader, argparse, Dataset, transforms and tr
# are not used in this file; kept so import-time behavior is unchanged.
import matplotlib.pyplot as plt
import numpy as np
import copy
import cv2 as cv
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
from math import ceil
import argparse
from torch.utils.data import Dataset
from torchvision import transforms
from taihuyuan_pv.dataloaders import custom_transforms as tr
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from PIL import Image as PILImage

# VOC dataset format.
# Disable Pillow's decompression-bomb size limit so very large maps can be opened.
Image.MAX_IMAGE_PIXELS = None


class mypv_bigmap:
    """Segment every image in ``basedir`` with a sliding window and save the results.

    Outputs go under a directory named after ``model_name``:

    * ``<model_name>/<model_name>_outputs/<name>.png``: class-index mask with
      ``self.palette`` applied.
    * ``<model_name>/<model_name>_fusion/<model_name>_<name>.png``: the mask
      colour-coded and alpha-blended onto the input (only when ``self.fusin``).
    """

    def __init__(self, model_name="tongzhou", basedir="/home/qiqq/q3dl/datalinan/paper2_groundpv/b1"):
        """
        Args:
            model_name: used as the output directory name and as a file-name prefix.
            basedir: directory holding the images to predict.
        """
        self.basedir = basedir
        self.modelnamedir = model_name
        os.makedirs(self.modelnamedir, exist_ok=True)
        self.predict_dir = self.modelnamedir + "/" + model_name + "_" + "outputs/"
        self.fusion_path = self.modelnamedir + "/" + model_name + "_" + "fusion/"
        self.model_name = model_name
        # Flat RGB palette for the saved mask: class 0 black, classes 1 and 2 green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []
        # When true, also save the prediction blended onto the input image.
        self.fusin = True

    def pad_image(self, img, target_size):
        """Zero-pad an NCHW array on the bottom/right up to ``target_size`` = (rows, cols)."""
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        return np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant')

    def predict_sliding(self, net, image, tile_size, classes):
        """Run ``net`` over ``image`` in overlapping tiles and average the overlaps.

        Args:
            net: model called on a ``(1, C, tile_h, tile_w)`` CUDA tensor; if it
                returns a list, only the first element is used.
            image: NCHW float32 numpy array; only batch item 0 of the output is kept.
            tile_size: window size ``(tile_h, tile_w)``.
            classes: number of channels ``net`` outputs.

        Returns:
            ``(H, W, classes)`` float array: for each pixel, the mean of the network
            outputs of all tiles that cover it.
        """
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        image_size = image.shape
        overlap = 1 / 3
        # Step size per axis, rounded up: a 512 px window with 1/3 overlap moves 342 px.
        stride_y = ceil(tile_size[0] * (1 - overlap))
        stride_x = ceil(tile_size[1] * (1 - overlap))
        # Enough steps that the last window reaches the bottom/right edge
        # (at least one window, even when the image is smaller than a tile).
        tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride_y) + 1)
        tile_cols = int(ceil((image_size[3] - tile_size[1]) / stride_x) + 1)
        # Sum of predictions and number of predictions per pixel/class.
        full_probs = np.zeros((image_size[2], image_size[3], classes))
        count_predictions = np.zeros((image_size[2], image_size[3], classes))
        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride_x)
                y1 = int(row * stride_y)
                x2 = min(x1 + tile_size[1], image_size[3])
                y2 = min(y1 + tile_size[0], image_size[2])
                # Slide the last window back inside the image so it stays tile-sized;
                # clamp at 0 when the image is smaller than the tile.
                x1 = max(int(x2 - tile_size[1]), 0)
                y1 = max(int(y2 - tile_size[0]), 0)
                img = image[:, :, y1:y2, x1:x2]
                # Pad edge crops that are smaller than the tile.
                padded_img = self.pad_image(img, tile_size)
                with torch.no_grad():
                    padded_prediction = net(torch.from_numpy(padded_img).cuda())
                    # Models with auxiliary heads return several outputs; keep the main one.
                    if isinstance(padded_prediction, list):
                        padded_prediction = padded_prediction[0]
                    # Resize to the tile size and convert to HWC numpy.
                    padded_prediction = interp(padded_prediction).cpu().data[0].numpy().transpose(1, 2, 0)
                # Drop the padded area.
                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction
        # A pixel covered by k overlapping tiles has accumulated k predictions
        # (e.g. 0.8 + 0.7 + 0.6 with k = 3); dividing by k gives their mean.
        # Every pixel is covered at least once, so count_predictions has no zeros.
        full_probs /= count_predictions
        return full_probs

    def get_resut(self, net, titlesize=(512, 512), classes=2):
        """Predict every image in ``basedir`` and write the output files (main entry point).

        Args:
            net: segmentation model; it must already be on CUDA because
                ``predict_sliding`` moves its inputs there.
            titlesize: sliding-window tile size ``(h, w)``.
            classes: number of classes the model outputs.
        """
        self.get_imgs()
        assert len(self.imglist) != 0, "no images found in {}".format(self.basedir)
        # Colour per class for the fusion overlay: 0 black, 1 red, 2 green.
        fusion_palette = np.array([(0, 0, 0), (255, 0, 0), (0, 255, 0)], np.uint8)
        for item in tqdm(self.imglist):
            image = item["image"]
            name = item["name"]
            imagedata = np.expand_dims(
                np.transpose(self.preprocess_input(np.array(image, np.float32), md=False), (2, 0, 1)), 0)
            output = self.predict_sliding(net=net, image=imagedata, tile_size=titlesize, classes=classes)
            output = F.softmax(torch.from_numpy(output), dim=2).numpy()
            seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            # Palettised class-index mask.
            os.makedirs(self.predict_dir, exist_ok=True)
            output_im = PILImage.fromarray(seg_pred)
            output_im.putpalette(self.palette)
            output_im.save(self.predict_dir + name + '.png')

            # Overlay of the colour-coded mask on the input image.
            if self.fusin:
                os.makedirs(self.fusion_path, exist_ok=True)
                image0 = Image.fromarray(fusion_palette[seg_pred])
                # Image.blend does not modify its inputs, so no copy of `image` is needed.
                fusion = Image.blend(image, image0, 0.4)
                fusion.save(self.fusion_path + self.model_name + "_" + name + '.png')
        print("预测结果图生成")

    def cvtColor(self, image):
        """Return ``image`` as an RGB PIL image, converting from any other mode."""
        if isinstance(image, Image.Image):
            # Check the mode instead of np.shape(): np.shape() decodes the whole
            # image into an array and fails on single-band (L/P/1/I/F) images.
            return image if image.mode == 'RGB' else image.convert('RGB')
        # Non-PIL input (e.g. an ndarray) is returned unchanged when already H x W x 3.
        if np.ndim(image) == 3 and np.shape(image)[2] == 3:
            return image
        return image.convert('RGB')

    def preprocess_input(self, image, md=False):
        """Scale a float HxWx3 array to [0, 1] in place.

        With ``md=True`` it is also standardised with the per-channel mean/std below.
        """
        mean = (0.231, 0.217, 0.22)  # statistics for the Beijing dataset
        std = (0.104, 0.086, 0.085)
        image /= 255.0
        if md:
            image -= mean
            image /= std
        return image

    def get_imgs(self, ):
        """Open every image file in ``basedir`` (no labels) and store it in ``self.imglist``.

        Files ending in ``.xml``, ``pgw`` or ``prj`` (presumably GIS metadata
        sidecars) and non-files are skipped. Each entry is a dict with keys
        ``name`` (file name without its last extension), ``orininal_h``,
        ``orininal_w`` and ``image`` (RGB PIL image).

        Returns:
            ``self.imglist``.
        """
        # Reset so that repeated calls do not duplicate every image.
        self.imglist = []
        skip_suffixes = (".xml", "pgw", "prj")
        for fname in sorted(os.listdir(self.basedir)):
            path = os.path.join(self.basedir, fname)
            if fname.endswith(skip_suffixes) or not os.path.isfile(path):
                continue
            # splitext keeps inner dots, so "a.1.tif" and "a.2.tif" do not collide.
            name = os.path.splitext(fname)[0]
            image = Image.open(path)
            orininal_w, orininal_h = image.size
            image = self.cvtColor(image)
            self.imglist.append({"name": name, "orininal_h": orininal_h,
                                 "orininal_w": orininal_w, "image": image})
        print("共监测到{}张测试图像".format(len(self.imglist)))
        return self.imglist


if __name__ == '__main__':
    from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb

    model = resUnetpamcarb()
    model_path = "/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/mitunet/train_val_test/checkpoints/2023-10-16-14.34//best.pth"
    # torch.load unpickles the file: only load trusted checkpoints. Loading to CPU
    # first avoids depending on the GPU the checkpoint was saved from.
    model_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.cuda()
    testss = mypv_bigmap(model_name="respamcbam2PV_version31b1")
    testss.get_resut(net=model)