#!/usr/bin/env python # -*- coding: utf-8 -*- """ @project: @File : predict @Author : qiqq @create_time : 2023/4/24 20:08 单纯的生成预测结果图 单纯的的单通道图和合成的fusion """ #!/usr/bin/env python # -*- coding: utf-8 -*- """ @project: @File : predict @Author : qiqq @create_time : 2023/3/27 16:29 所有测试集都在一个文件夹中 """ import argparse import scipy from scipy import ndimage import cv2 import numpy as np import sys import copy import random import json from tqdm import tqdm import torch from torch.autograd import Variable import torchvision.models as models import torch.nn.functional as F from torch.utils import data from PIL import Image Image.MAX_IMAGE_PIXELS = None from collections import OrderedDict import os import scipy.ndimage as nd from math import ceil from PIL import Image as PILImage import torch.nn as nn os.environ["CUDA_VISIBLE_DEVICES"] = "1" '''专门为预测整个临安建立的 /home/qiqq/q3dl/test_temple/wholelinan/ 分了t1 t2 t3 ''' class mypv_bigmap: def __init__(self,model_name="tongzhou",): self.basedir = "/home/qiqq/q3dl/test_temple/wholelinan/" self.modelnamedir=model_name if not os.path.exists(self.modelnamedir): os.makedirs(self.modelnamedir) self.model_name=model_name self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0] self.imglist=[] self.labellist=[] self.fusin=True def pad_image(self,img, target_size): """Pad an image up to the target size.""" rows_missing = target_size[0] - img.shape[2] cols_missing = target_size[1] - img.shape[3] padded_img = np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant') return padded_img def predict_sliding(self,net, image, tile_size, classes): interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True) image_size = image.shape overlap = 1 / 3 stride = ceil(tile_size[0] * (1 - overlap)) # ceil() 函数返回数字的上入整数 #512重叠1/3 的话这里就是 步长是32 也就是一个512的窗口滑动342个像素 tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1) # 6 strided convolution formula tile_cols = int( ceil((image_size[3] - tile_size[1]) / stride) + 1) # 11 也就是image_size[3]在image_size[3]这个方向上会滑动11次? # print("Need %i x %i prediction tiles @ stride %i px" % (tile_cols, tile_rows, stride)) full_probs = np.zeros((image_size[2], image_size[3], classes)) # 一个全零的 大图 #初始化全概率矩阵 count_predictions = np.zeros((image_size[2], image_size[3], classes)) # 初始化计数矩阵 shape(2160,3840,3) tile_counter = 0 for row in range(tile_rows): for col in range(tile_cols): x1 = int(col * stride) # 起始位置 y1 = int(row * stride) x2 = min(x1 + tile_size[1], image_size[3]) # 末位置 莫位置如果超出边界就以边界 y2 = min(y1 + tile_size[0], image_size[2]) x1 = max(int(x2 - tile_size[1]), 0) # for portrait images the x1 underflows sometimes #重新校准起始位置x1 = y1 = max(int(y2 - tile_size[0]), 0) # for very few rows y1 underflows img = image[:, :, y1:y2, x1:x2] # 要输入网络的小图 padded_img = self.pad_image(img, tile_size) # padding 确保扣下来的图像为tile_size如果不够的话就padding # plt.imshow(padded_img) # plt.show() tile_counter += 1 ##计数加1 # print("Predicting tile %i" % tile_counter) with torch.no_grad(): # padded_prediction = net(Variable(torch.from_numpy(padded_img), volatile=True).cuda()) # 预测 padded_prediction =net(torch.from_numpy(padded_img).cuda()) # 预测 if isinstance(padded_prediction, list): # 可能有多个输出(这里把辅助解码头的也输出的所以是多个) padded_prediction = padded_prediction[0] # 就取第一个 # padded_prediction = padded_prediction[0] padded_prediction = interp(padded_prediction).cpu().data[0].numpy().transpose(1, 2, 0) # 插值到原大小 prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :] count_predictions[y1:y2, x1:x2] += 1 # 这几个位置上的像素预测次数+1 ##窗口区域内的计数矩阵加 full_probs[y1:y2, x1:x2] += prediction # 把它搞到原图上# #窗口区域内的全概率矩阵叠加预测结果accumulate the predictions also in the overlapping regions ''' full_probs[y1:y2, x1:x2] += prediction 这个地方的如果涉及到重叠的而区域他是这么处理的 一般情况下不重叠的就是预测了1次,如果是重叠的一个像素 它被预测了3次比如这三次的概率(以屋顶类为例)是 0.8 0.7 0.6 那count_predictions矩阵上这个位置的值就是3因为是用了3次 那最后的结果取谁呢?取这三个的平均值full_probs上这个像素在三次后的值是 0.8+0.7+0.6 那么 下边的这一句 full_probs /= count_predictions就是取位置上的平均值 ''' # average the predictions in the overlapping regions full_probs /= count_predictions return full_probs #直接调用这个就可以 def get_resut(self,net,titlesize=(512,512),classes=2): d0 = os.listdir(self.basedir) hacc=["gaohong","jinbei"] for i in d0:#t1 t2 T3 d1 = os.path.join(self.basedir, i) d2 = os.listdir(d1) #taiyang gaohong for j in d2: # if j in hacc: continue # fpath=os.path.join(self.basedir ,i, j) fpath=os.path.join(d1,j) print("没问题过了",j) self.predict_dir=os.path.join(self.model_name,j) if not os.path.exists(self.predict_dir): os.makedirs(self.predict_dir) imglist=self.get_imgs(fpath) print("正在预测", j," ","一共",len(self.imglist)) assert len(self.imglist)!=0 for i in tqdm(self.imglist): image=i["image"] name=i["name"] orininal_w=i["orininal_w"] orininal_h=i["orininal_h"] old_img = copy.deepcopy(image) imagedata=np.expand_dims(np.transpose(self.preprocess_input(np.array(image, np.float32),md=False), (2, 0, 1)), 0) output=self.predict_sliding(net=net,image=imagedata,tile_size=titlesize,classes=classes) output = F.softmax(torch.from_numpy(output), dim=2).numpy() seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8) output_im = PILImage.fromarray(seg_pred) output_im.putpalette(self.palette) output_im.save(os.path.join(self.predict_dir , name + '.png')) print("预测结果图生成") def cvtColor(self,image): if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: return image elif np.shape(image)[2] == 4: image = image.convert('RGB') return image else: image = image.convert('RGB') return image def preprocess_input(self,image, md=False): mean = (0.231, 0.217, 0.22) # 针对北京的数据集 std = (0.104, 0.086, 0.085) if md: image /= 255.0 image -= mean image /= std return image else: image /= 255.0 return image #m没有标签的读取,通常就是纯预测的 def get_imgs(self,fpath): self.imglist=[] #单个模式 imagescomplete=[] fpth=os.path.join(fpath,"images") lists1 = [i for i in os.listdir(fpth) if i.endswith("png")] for i in lists1: imagescomplepath=os.path.join(fpth,i) imagescomplete.append(imagescomplepath) for j in imagescomplete: # name=j.split(".")[0] name=j.split("/")[-1].split(".")[0] image=Image.open(j) #zhuyu orininal_h = image.size[1] orininal_w = image.size[0] image = self.cvtColor(image, ) item={"name":name,"orininal_h":orininal_h,"orininal_w":orininal_w,"image":image} self.imglist.append(item) print("共监测到{}张测试图像".format(len(self.imglist))) return self.imglist if __name__ == '__main__': from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb model = resUnetpamcarb() model_path ="best.pth" model_dict = torch.load(model_path) model.load_state_dict(model_dict['net']) print("权重加载") model.eval() model.cuda() mypvtest=mypv_bigmap(model_name="manet_wholelinan_pv") mypvtest.get_resut(model)