#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : predict
@Author : qiqq
@create_time : 2023/4/24 20:08

Simply generates the prediction result maps: the raw single-channel class maps
(and, optionally, the composited "fusion" overlays). All test sets live in a
single folder.
"""
import argparse
import scipy
from scipy import ndimage
import cv2

import numpy as np
import sys
import copy
import random
import json
from tqdm import tqdm
import torch
from torch.autograd import Variable
import torchvision.models as models
import torch.nn.functional as F
from torch.utils import data

from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from collections import OrderedDict
import os
import scipy.ndimage as nd
from math import ceil
from PIL import Image as PILImage

import torch.nn as nn

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
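# Note: "1" pins the script to the second visible GPU; change it to "0" or drop
# the line entirely on single-GPU machines.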

'''Built specifically for predicting the whole Lin'an area.
/home/qiqq/q3dl/test_temple/wholelinan/
is split into t1, t2, t3.
'''


class mypv_bigmap:
    def __init__(self, model_name="tongzhou"):

        self.basedir = "/home/qiqq/q3dl/test_temple/wholelinan/"

        self.modelnamedir = model_name

        if not os.path.exists(self.modelnamedir):
            os.makedirs(self.modelnamedir)

        self.model_name = model_name
        # Palette for the saved "P"-mode PNGs: index 0 -> black background,
        # indices 1 and 2 -> green foreground.
        self.palette = [0, 0, 0, 0, 255, 0, 0, 255, 0]
        self.imglist = []
        self.labellist = []

        self.fusin = True

    def pad_image(self, img, target_size):
        """Pad an (N, C, H, W) image with zeros on the bottom/right up to the target size."""
        rows_missing = target_size[0] - img.shape[2]
        cols_missing = target_size[1] - img.shape[3]
        padded_img = np.pad(img, ((0, 0), (0, 0), (0, rows_missing), (0, cols_missing)), 'constant')
        return padded_img
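
    # For illustration (not in the original): with tile_size=(512, 512), a crop of
    # shape (1, 3, 500, 480) comes back as (1, 3, 512, 512), zero-padded only on
    # the bottom and right edges, so the valid region stays at the top-left.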

    def predict_sliding(self, net, image, tile_size, classes):
        interp = nn.Upsample(size=tile_size, mode='bilinear', align_corners=True)
        image_size = image.shape
        overlap = 1 / 3

        stride = ceil(tile_size[0] * (1 - overlap))  # with a 512 tile and 1/3 overlap: stride = ceil(512 * 2/3) = 342 px per step
        tile_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)  # strided-convolution formula
        tile_cols = int(
            ceil((image_size[3] - tile_size[1]) / stride) + 1)  # number of window positions along the width
        # print("Need %i x %i prediction tiles @ stride %i px" % (tile_cols, tile_rows, stride))
        full_probs = np.zeros((image_size[2], image_size[3], classes))  # accumulated class scores for the full image
        count_predictions = np.zeros((image_size[2], image_size[3], classes))  # how many times each pixel was predicted
        tile_counter = 0

        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride)  # window start
                y1 = int(row * stride)
                x2 = min(x1 + tile_size[1], image_size[3])  # window end, clipped to the image border
                y2 = min(y1 + tile_size[0], image_size[2])
                x1 = max(int(x2 - tile_size[1]), 0)  # for portrait images x1 underflows sometimes; re-anchor the window start
                y1 = max(int(y2 - tile_size[0]), 0)  # for very few rows y1 underflows

                img = image[:, :, y1:y2, x1:x2]  # crop fed to the network
                padded_img = self.pad_image(img, tile_size)  # pad the crop up to tile_size if it is smaller
                # plt.imshow(padded_img)
                # plt.show()
                tile_counter += 1
                # print("Predicting tile %i" % tile_counter)
                with torch.no_grad():
                    # padded_prediction = net(Variable(torch.from_numpy(padded_img), volatile=True).cuda())
                    padded_prediction = net(torch.from_numpy(padded_img).cuda())  # forward pass
                    if isinstance(padded_prediction, list):  # the model may return several outputs (e.g. an auxiliary head)
                        padded_prediction = padded_prediction[0]  # keep the main one
                    # padded_prediction = padded_prediction[0]

                padded_prediction = interp(padded_prediction).cpu().data[0].numpy().transpose(1, 2, 0)  # resize back to the tile size
                prediction = padded_prediction[0:img.shape[2], 0:img.shape[3], :]
                count_predictions[y1:y2, x1:x2] += 1  # these pixels were predicted one more time
                full_probs[y1:y2, x1:x2] += prediction  # accumulate the scores, also in the overlapping regions
                '''
                How overlaps are handled by full_probs[y1:y2, x1:x2] += prediction:
                a pixel covered by only one window keeps its single score; a pixel
                covered by, say, three windows with (roof-class) scores 0.8, 0.7 and 0.6
                accumulates 0.8 + 0.7 + 0.6 in full_probs while count_predictions holds 3,
                so the division below (full_probs /= count_predictions) yields the average.
                '''
        # average the predictions in the overlapping regions
        full_probs /= count_predictions

        return full_probs
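
    # Worked example (illustration only, using the numbers from the comments above):
    # tile_size = (512, 512) and overlap = 1/3 give stride = ceil(512 * 2/3) = 342.
    # For a 2160 x 3840 input this yields
    #   tile_rows = ceil((2160 - 512) / 342) + 1 = 5 + 1 = 6
    #   tile_cols = ceil((3840 - 512) / 342) + 1 = 10 + 1 = 11
    # i.e. 66 tiles, each predicted once and averaged where windows overlap.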

    # Call this directly to run prediction.
    def get_resut(self, net, titlesize=(512, 512), classes=2):
        d0 = os.listdir(self.basedir)
        hacc = ["gaohong", "jinbei"]  # sub-folders to skip
        for i in d0:  # t1 t2 t3
            d1 = os.path.join(self.basedir, i)
            d2 = os.listdir(d1)  # e.g. taiyang, gaohong
            for j in d2:
                if j in hacc:
                    continue
                # fpath = os.path.join(self.basedir, i, j)
                fpath = os.path.join(d1, j)
                print("passed the check, processing", j)

                self.predict_dir = os.path.join(self.model_name, j)
                if not os.path.exists(self.predict_dir):
                    os.makedirs(self.predict_dir)

                imglist = self.get_imgs(fpath)
                print("predicting", j, " ", "total", len(self.imglist))
                assert len(self.imglist) != 0
                for item in tqdm(self.imglist):
                    image = item["image"]
                    name = item["name"]
                    orininal_w = item["orininal_w"]
                    orininal_h = item["orininal_h"]
                    old_img = copy.deepcopy(image)
                    imagedata = np.expand_dims(np.transpose(self.preprocess_input(np.array(image, np.float32), md=False), (2, 0, 1)), 0)
                    output = self.predict_sliding(net=net, image=imagedata, tile_size=titlesize, classes=classes)
                    output = F.softmax(torch.from_numpy(output), dim=2).numpy()
                    seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                    output_im = PILImage.fromarray(seg_pred)
                    output_im.putpalette(self.palette)
                    output_im.save(os.path.join(self.predict_dir, name + '.png'))

        print("prediction maps generated")

    def cvtColor(self, image):
        # Return the image unchanged if it is already 3-channel RGB,
        # otherwise (RGBA, grayscale, palette, ...) convert it to RGB.
        if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
            return image
        else:
            image = image.convert('RGB')
            return image

    def preprocess_input(self, image, md=False):
        mean = (0.231, 0.217, 0.22)  # statistics computed for the Beijing dataset
        std = (0.104, 0.086, 0.085)
        if md:
            image /= 255.0
            image -= mean
            image /= std
            return image
        else:
            image /= 255.0
            return image

    # Reads images without labels; used for pure prediction.
    def get_imgs(self, fpath):
        self.imglist = []

        # single-folder mode
        imagescomplete = []
        fpth = os.path.join(fpath, "images")
        lists1 = [i for i in os.listdir(fpth) if i.endswith("png")]
        for i in lists1:
            imagescomplepath = os.path.join(fpth, i)
            imagescomplete.append(imagescomplepath)
        for j in imagescomplete:
            # name = j.split(".")[0]
            name = j.split("/")[-1].split(".")[0]
            image = Image.open(j)
            orininal_h = image.size[1]
            orininal_w = image.size[0]
            image = self.cvtColor(image)
            item = {"name": name, "orininal_h": orininal_h, "orininal_w": orininal_w, "image": image}
            self.imglist.append(item)

        print("found {} test images".format(len(self.imglist)))
        return self.imglist
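
    # --- Sketch only, not part of the original script -------------------------
    # The module docstring mentions a composited "fusion" output and __init__
    # sets self.fusin = True, but no fusion image is ever produced. A minimal
    # sketch of such an overlay follows; the method name save_fusion and the
    # alpha value are assumptions introduced here for illustration.
    def save_fusion(self, old_img, seg_pred, save_path, alpha=0.4):
        """Blend the colorized prediction mask onto the original RGB image and save it."""
        mask = PILImage.fromarray(seg_pred)      # single-channel class map (uint8)
        mask.putpalette(self.palette)            # colorize with the same palette as the raw maps
        mask = mask.convert("RGB")               # blending requires matching modes and sizes
        fused = PILImage.blend(old_img.convert("RGB"), mask, alpha)
        fused.save(save_path)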


if __name__ == '__main__':
    from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb

    model = resUnetpamcarb()

    model_path = "best.pth"
    model_dict = torch.load(model_path)
    model.load_state_dict(model_dict['net'])
    print("weights loaded")
    model.eval()
    model.cuda()

    mypvtest = mypv_bigmap(model_name="manet_wholelinan_pv")
    mypvtest.get_resut(model)