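# File-viewer metadata captured together with this source (kept as comments so the file parses):
# ai-station-code/wudingpv/taihuyuan_pv/mitunet/TestAndPredict/wholelinan/predict_wholelinan.py
#
# 233 lines
# 9.4 KiB
# Python
# Raw Normal View History
#
# 2025-05-06 11:18:48 +08:00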
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : predict
@Author : qiqq
@create_time : 2023/4/24 20:08
单纯的生成预测结果图
单纯的单通道图和合成的fusion图
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : predict
@Author : qiqq
@create_time : 2023/3/27 16:29
所有测试集都在一个文件夹中
"""
import argparse
import scipy
from scipy import ndimage
import cv2
import numpy as np
import sys
import copy
import random
import json
from tqdm import tqdm
import torch
from torch.autograd import Variable
import torchvision.models as models
import torch.nn.functional as F
from torch.utils import data
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from collections import OrderedDict
import os
import scipy.ndimage as nd
from math import ceil
from PIL import Image as PILImage
import torch.nn as nn
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
'''专门为预测整个临安建立的
/home/qiqq/q3dl/test_temple/wholelinan/
分了t1 t2 t3
'''
class mypv_bigmap:
    """Sliding-window prediction over every region folder of the whole-Lin'an test set.

    Directory layout read by ``get_resut`` / ``get_imgs``::

        <basedir>/<batch>/<region>/images/*png

    One paletted PNG mask per input image is written to ``<model_name>/<region>/``.
    """

    def __init__(self, model_name="tongzhou",):
        """Set up paths and create the output root directory.

        Args:
            model_name: output root directory (relative to the current working
                directory); one sub-folder per region is created beneath it.
        """
        self.basedir = "/home/qiqq/q3dl/test_temple/wholelinan/"
        self.modelnamedir = model_name
        os.makedirs(self.modelnamedir, exist_ok=True)
        self.model_name = model_name
        # Palette of the saved mask: index 0 -> black, indices 1 and 2 -> green.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []
        # Kept for compatibility; not read anywhere in this class.
        self.fusin = True

    def pad_image(self, img, target_size):
        """Zero-pad an (N, C, h, w) array at the bottom/right up to ``target_size``.

        Args:
            img: array of shape (N, C, h, w) with h <= target_size[0] and
                w <= target_size[1].
            target_size: (H, W) of the padded result.

        Returns:
            Array of shape (N, C, H, W) with the same dtype as ``img``.
        """
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        return np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant')

    @staticmethod
    def _net_device(net):
        """Return the device of ``net``'s first parameter, or CUDA if it has none."""
        parameters = getattr(net, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cuda")

    def predict_sliding(self, net, image, tile_size, classes):
        """Predict a large image tile by tile and average overlapping predictions.

        Args:
            net: network mapping a (1, C, tile_h, tile_w) tensor to a
                (1, classes, h, w) tensor, or to a list/tuple whose first
                element is that tensor (auxiliary heads are ignored).
            image: float array of shape (1, C, H, W).
            tile_size: (tile_h, tile_w) window fed to ``net``.
            classes: number of channels produced by ``net``.

        Returns:
            float64 array of shape (H, W, classes): for every pixel, the mean of
            the raw network outputs of all tiles covering it.
        """
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        image_size = image.shape
        overlap = 1 / 3
        # Windows overlap by 1/3, e.g. a 512 px window advances ceil(512 * 2/3) = 342 px.
        # Each axis gets its own stride so non-square tiles leave no uncovered gaps.
        stride_h = ceil(tile_size[0] * (1 - overlap))
        stride_w = ceil(tile_size[1] * (1 - overlap))
        # At least one tile per axis: without max(..., 1) an image smaller than
        # (tile - stride) got 0 tiles, so every count stayed 0 and the result was NaN.
        tile_rows = max(int(ceil((image_size[2] - tile_size[0]) / stride_h) + 1), 1)
        tile_cols = max(int(ceil((image_size[3] - tile_size[1]) / stride_w) + 1), 1)
        full_probs = np.zeros((image_size[2], image_size[3], classes))
        # One count per pixel is enough; it broadcasts over the class axis.
        count_predictions = np.zeros((image_size[2], image_size[3], 1))
        # Feed tiles to whatever device the network lives on (was hard-coded .cuda()).
        device = self._net_device(net)
        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride_w)
                y1 = int(row * stride_h)
                x2 = min(x1 + tile_size[1], image_size[3])
                y2 = min(y1 + tile_size[0], image_size[2])
                # Shift the window back so it is a full tile whenever the image is big enough.
                x1 = max(int(x2 - tile_size[1]), 0)
                y1 = max(int(y2 - tile_size[0]), 0)
                img = image[:, :, y1:y2, x1:x2]
                # Tiles cut at a border of a small image are zero-padded to tile_size.
                padded_img = self.pad_image(img, tile_size)
                with torch.no_grad():
                    padded_prediction = net(torch.from_numpy(padded_img).to(device))
                    if isinstance(padded_prediction, (list, tuple)):
                        padded_prediction = padded_prediction[0]
                    # Resize to the tile size, then (classes, h, w) -> (h, w, classes).
                    padded_prediction = interp(padded_prediction).detach().cpu()[0].numpy().transpose(1, 2, 0)
                # Drop the padded part before accumulating.
                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1
                full_probs[y1:y2, x1:x2] += prediction
        # Overlapping pixels were summed once per covering tile; divide to get the mean.
        full_probs /= count_predictions
        return full_probs

    def get_resut(self, net, titlesize=(512, 512), classes=2):
        """Predict every image of every region folder and save paletted masks.

        Walks ``<basedir>/<batch>/<region>``, skipping the regions listed in
        ``hacc`` and any entry that is not a directory. Masks are written to
        ``<model_name>/<region>/<image name>.png``.

        Args:
            net: segmentation network, already set up for inference.
            titlesize: sliding-window tile size (H, W).
            classes: number of output classes of ``net``.

        Raises:
            RuntimeError: if a region folder contains no matching images.
        """
        hacc = {"gaohong", "jinbei"}  # region folders excluded from prediction
        for batch in sorted(os.listdir(self.basedir)):  # t1 t2 t3
            batch_dir = os.path.join(self.basedir, batch)
            if not os.path.isdir(batch_dir):
                continue
            for region in sorted(os.listdir(batch_dir)):
                fpath = os.path.join(batch_dir, region)
                if region in hacc or not os.path.isdir(fpath):
                    continue
                print("没问题过了", region)
                self.predict_dir = os.path.join(self.model_name, region)
                os.makedirs(self.predict_dir, exist_ok=True)
                self.get_imgs(fpath)
                print("正在预测", region, " ", "一共", len(self.imglist))
                # Explicit check instead of assert, which is stripped under `python -O`.
                if not self.imglist:
                    raise RuntimeError("no png images found in %s" % os.path.join(fpath, "images"))
                for item in tqdm(self.imglist):
                    self._predict_one(net, item, titlesize, classes)
                print("预测结果图生成")

    def _predict_one(self, net, item, tile_size, classes):
        """Predict one ``get_imgs`` entry and save its mask into ``self.predict_dir``."""
        array = self.preprocess_input(np.array(item["image"], np.float32), md=False)
        # (H, W, C) -> (1, C, H, W)
        imagedata = np.expand_dims(np.transpose(array, (2, 0, 1)), 0)
        output = self.predict_sliding(net=net, image=imagedata, tile_size=tile_size, classes=classes)
        output = F.softmax(torch.from_numpy(output), dim=2).numpy()
        seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        output_im = PILImage.fromarray(seg_pred)
        output_im.putpalette(self.palette)
        output_im.save(os.path.join(self.predict_dir, item["name"] + '.png'))

    def cvtColor(self, image):
        """Return ``image`` unchanged if it has 3 channels, else ``image.convert('RGB')``.

        2-D inputs (e.g. grayscale or palette images) are converted as well;
        previously they raised IndexError.
        """
        # Computed once: for non-ndarray inputs np.shape may need to build an array.
        shape = np.shape(image)
        if len(shape) == 3 and shape[2] == 3:
            return image
        return image.convert('RGB')

    def preprocess_input(self, image, md=False):
        """Scale ``image`` to [0, 1] in place, optionally standardising per channel.

        Args:
            image: float array whose last axis has 3 channels; modified in place.
            md: if True, also subtract ``mean`` and divide by ``std``.

        Returns:
            The same array object, normalised.
        """
        image /= 255.0
        if md:
            # Per-channel statistics of the Beijing dataset.
            mean = (0.231, 0.217, 0.22)
            std = (0.104, 0.086, 0.085)
            image -= mean
            image /= std
        return image

    def get_imgs(self, fpath):
        """Load every file ending in ``png`` from ``<fpath>/images`` into ``self.imglist``.

        Each entry is a dict with keys ``name`` (file name without its final
        extension), ``orininal_h`` / ``orininal_w`` (size as opened) and
        ``image`` (the image after ``cvtColor``).

        Returns:
            ``self.imglist``.
        """
        self.imglist = []
        image_dir = os.path.join(fpath, "images")
        for file_name in sorted(os.listdir(image_dir)):
            if not file_name.endswith("png"):
                continue
            image = Image.open(os.path.join(image_dir, file_name))
            orininal_w, orininal_h = image.size
            image = self.cvtColor(image)
            # splitext keeps dots inside the stem; split(".")[0] truncated such names
            # and could make different images overwrite each other's masks.
            name = os.path.splitext(file_name)[0]
            item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
            self.imglist.append(item)
        print("共监测到{}张测试图像".format(len(self.imglist)))
        return self.imglist
if __name__ == '__main__':
    # Entry point: load the trained weights and predict every region under basedir.
    from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb
    model = resUnetpamcarb()
    model_path = "best.pth"
    # Deserialize onto the CPU first. A checkpoint saved on another GPU index
    # may not be loadable directly, because CUDA_VISIBLE_DEVICES="1" is set at
    # import time and leaves only cuda:0 visible. The weights are moved to the
    # GPU by model.cuda() below.
    model_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(model_dict['net'])
    print("权重加载")
    model.eval()
    model.cuda()
    mypvtest = mypv_bigmap(model_name="manet_wholelinan_pv")
    mypvtest.get_resut(model)