ai-station-code/dimaoshibie/segformer.py

236 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import colorsys
import copy
import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
import os
from dimaoshibie.nets.segformer import SegFormer
from dimaoshibie.utils.utils import cvtColor, preprocess_input, resize_image, show_config
# Absolute or relative location of this module, as reported by the interpreter.
file_name = __file__
# Directory that contains this module; used to locate bundled resources
# (e.g. the default weights file) relative to the source tree.
path = os.path.split(file_name)[0]
class SegFormer_Segmentation(object):
    """Inference wrapper around a SegFormer semantic-segmentation network.

    The instance loads the weights named by ``model_path`` at construction
    time and then offers single-image prediction (``detect_image``), a simple
    speed benchmark (``get_FPS``), ONNX export (``convert_to_onnx``) and a
    raw class-index prediction for mIoU evaluation (``get_miou_png``).
    """

    # Default configuration. Every key can be overridden through the keyword
    # arguments of ``__init__``.
    _defaults = {
        "model_path": path + "/logs/best_epoch_weights_voc_12000.pth",  # path of the model weights
        "num_classes": 10 + 1,  # number of classes (background included)
        "phi": "b0",  # model scale, b0-b5
        "input_shape": [512, 512],  # network input size; index 0 is used for rows, index 1 for columns
        "mix_type": 1,  # visualisation: 0 blend with source, 1 segmentation map only, 2 target region only
        "cuda": True,  # whether to use the GPU
    }

    def __init__(self, **kwargs):
        """
        Build the predictor and load the network weights.

        :param kwargs: overrides for the keys of ``_defaults``; every keyword
            is stored as an instance attribute.
        """
        # Start from the defaults, then apply the caller's overrides.
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Preset palette, one RGB colour per class index.
        self.colors = [
            (0, 0, 0),  # Background (black)
            (252, 250, 205),  # Cropland (pale yellow)
            (0, 123, 79),  # Forest (dark green)
            (157, 221, 106),  # Grass (light green)
            (77, 208, 159),  # Shrub (light blue-green)
            (111, 208, 242),  # Wetland (light blue)
            (10, 78, 151),  # Water (dark blue)
            (92, 106, 55),  # Tundra (khaki)
            (155, 36, 22),  # Impervious surface (red)
            (205, 205, 205),  # Bareland (grey)
            (211, 242, 255),  # Ice/snow (light sky blue)
        ]
        # Up to 11 classes (10 labels + background) use the preset palette;
        # larger class counts get evenly spaced hues instead.
        if self.num_classes <= 11:
            self.colors = self.colors[:self.num_classes]
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = [
                tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
                for hsv in hsv_tuples
            ]

        # Initialise the model.
        self.generate()
        # Report the configuration actually in effect (defaults merged with
        # the caller's overrides) rather than the untouched class defaults.
        show_config(**{key: getattr(self, key) for key in self._defaults})

    def generate(self, onnx=False):
        """
        Create the network and load its weights.

        :param onnx: when true the model is being prepared for ONNX export,
            so the DataParallel/CUDA step is skipped.
        """
        self.net = SegFormer(num_classes=self.num_classes, phi=self.phi, pretrained=False)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if onnx:
            return
        # The weights above are already mapped to CPU when CUDA is missing;
        # keep the rest of the pipeline consistent with that instead of
        # failing inside ``.cuda()`` on CPU-only hosts.
        if self.cuda and not torch.cuda.is_available():
            print('CUDA was requested but is not available; running on CPU instead.')
            self.cuda = False
        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def _preprocess(self, image):
        """Resize *image* to the network input size and return ``(batch, nw, nh)``.

        ``batch`` is a float32 array of shape (1, C, H, W); ``nw`` and ``nh``
        are the values returned by ``resize_image`` and are later used to crop
        the prediction back to the un-padded region.
        """
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
        return image_data, nw, nh

    def _to_tensor(self, image_data):
        """Wrap a numpy batch as a tensor, moved to the GPU when ``self.cuda`` is set."""
        images = torch.from_numpy(image_data)
        if self.cuda:
            images = images.cuda()
        return images

    def _class_probabilities(self, images):
        """Run the network on *images* and return the first item's softmax map (H, W, classes) as numpy."""
        pr = self.net(images)[0]
        return F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

    def _crop_padding(self, pr, nw, nh):
        """Cut the centred ``nh`` x ``nw`` window out of *pr*, dropping the padded border."""
        top = int((self.input_shape[0] - nh) // 2)
        bottom = int((self.input_shape[0] - nh) // 2 + nh)
        left = int((self.input_shape[1] - nw) // 2)
        right = int((self.input_shape[1] - nw) // 2 + nw)
        return pr[top:bottom, left:right]

    def _colorize(self, pr, height, width):
        """Map a class-index mask to a PIL image using ``self.colors``."""
        seg_img = np.reshape(
            np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [height, width, -1])
        return Image.fromarray(np.uint8(seg_img))

    def detect_image(self, image, count=False, name_classes=None):
        """
        Predict the segmentation of a single image.

        :param image: input image (PIL format)
        :param count: whether to count the pixels of every class
        :param name_classes: list of class names used as keys of the count
            dictionary; when omitted the class indices are used as keys
        :return: result image, count dictionary, per-class pixel count array
            (the dictionary is empty and the array all zeros when ``count``
            is false)
        """
        # Convert to RGB and keep a copy of the source for the visualisation.
        image = cvtColor(image)
        old_img = copy.deepcopy(image)
        original_h, original_w = np.array(image).shape[:2]

        # Preprocess: resize and normalise.
        image_data, nw, nh = self._preprocess(image)
        with torch.no_grad():
            pr = self._class_probabilities(self._to_tensor(image_data))
            # Drop the padded border, then go back to the source resolution.
            pr = self._crop_padding(pr, nw, nh)
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
            # Most probable class per pixel.
            pr = pr.argmax(axis=-1)

        # Optional per-class pixel statistics.
        count_dict = {}
        classes_nums = np.zeros([self.num_classes])
        if count:
            # One pass over the mask instead of one comparison per class.
            pixel_counts = np.bincount(
                pr.reshape(-1), minlength=self.num_classes)[:self.num_classes]
            for i in range(self.num_classes):
                key = i if name_classes is None else name_classes[i]
                count_dict[key] = pixel_counts[i]
                classes_nums[i] = pixel_counts[i]

        # Visualisation.
        if self.mix_type == 0:
            # Blend the source image with the segmentation map.
            # NOTE(review): alpha=1.0 makes Image.blend return the second
            # image alone, so this mode currently renders the same as
            # mix_type 1; lower the alpha if a real overlay is intended.
            image = Image.blend(old_img, self._colorize(pr, original_h, original_w), 1.0)
        elif self.mix_type == 1:
            # Segmentation map only.
            image = self._colorize(pr, original_h, original_w)
        elif self.mix_type == 2:
            # Keep only the pixels whose predicted class is not background (index 0).
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            image = Image.fromarray(np.uint8(seg_img))
        return image, count_dict, classes_nums

    def get_FPS(self, image, test_interval):
        """
        Benchmark the model.

        :param image: test image
        :param test_interval: number of timed repetitions
        :return: average time per frame in seconds
        """
        image = cvtColor(image)
        image_data, nw, nh = self._preprocess(image)
        # One untimed pass first; the tensor is reused by the timed loop.
        with torch.no_grad():
            images = self._to_tensor(image_data)
            pr = self._class_probabilities(images).argmax(axis=-1)
            pr = self._crop_padding(pr, nw, nh)

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                pr = self._class_probabilities(images).argmax(axis=-1)
                pr = self._crop_padding(pr, nw, nh)
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval  # average time per frame
        return tact_time

    def convert_to_onnx(self, simplify, model_path):
        """
        Export the model to ONNX.

        Calls ``generate(onnx=True)``, so ``self.net`` is rebuilt without the
        DataParallel/CUDA step before the export.

        :param simplify: whether to simplify the exported model with onnx-simplifier
        :param model_path: path the ONNX model is written to
        """
        import onnx
        self.generate(onnx=True)

        im = torch.zeros(1, 3, *self.input_shape).to('cpu')  # dummy input tensor
        input_layer_names = ["images"]
        output_layer_names = ["output"]

        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                          im,
                          f=model_path,
                          verbose=False,
                          opset_version=12,
                          training=torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=True,
                          input_names=input_layer_names,
                          output_names=output_layer_names,
                          dynamic_axes=None)

        model_onnx = onnx.load(model_path)  # load the exported model
        onnx.checker.check_model(model_onnx)  # validate it

        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)  # overwrite with the simplified model

        print('Onnx model save as {}'.format(model_path))

    def get_miou_png(self, image):
        """
        Produce the prediction used for mIoU evaluation.

        :param image: input image
        :return: PIL image whose pixel values are the predicted class indices
        """
        image = cvtColor(image)
        original_h, original_w = np.array(image).shape[:2]

        image_data, nw, nh = self._preprocess(image)
        with torch.no_grad():
            pr = self._class_probabilities(self._to_tensor(image_data))
            pr = self._crop_padding(pr, nw, nh)
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
            pr = pr.argmax(axis=-1)

        image = Image.fromarray(np.uint8(pr))  # class-index mask as a PIL image
        return image