"""Single-image inference helpers for a segmentation_models_pytorch model.

Pipeline used by ``run_seg``: pad the image, apply the encoder-specific
normalisation, convert HWC -> CHW float32, then call ``model.predict``.
"""
# NOTE(review): this is asyncio's internal logger and appears unused in the
# visible code; consider ``logging.getLogger(__name__)`` if logging is needed.
from asyncio.log import logger
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from PIL import Image

DEVICE = 'cpu'
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'

# Input height/width are padded up to a multiple of this value so the
# encoder/decoder feature maps line up.
SIZE_DIVISOR = 32
# Minimum padded height/width (the value the original pipeline always used).
MIN_PADDED_SIZE = 256


# ---------------------------------------------------------------
### Data loading / preprocessing

def get_validation_augmentation(min_height=MIN_PADDED_SIZE, min_width=MIN_PADDED_SIZE):
    """Build the padding transform applied before inference.

    Pads the image (albumentations' default, centred position) so it is at
    least ``min_height`` x ``min_width``. Sides already larger than the
    minimum are left unchanged. Callers that need sizes divisible by 32 must
    therefore pass suitably rounded minimums; ``run_seg`` does this.

    Args:
        min_height (int): minimum output height. Defaults to 256.
        min_width (int): minimum output width. Defaults to 256.

    Return:
        transform: albumentations.Compose
    """
    test_transform = [
        albu.PadIfNeeded(min_height, min_width),
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Convert an HWC image array to a CHW float32 array (extra kwargs ignored)."""
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the image preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalisation function (specific to
            each pretrained encoder).

    Return:
        transform: albumentations.Compose
    """
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor),
    ]
    return albu.Compose(_transform)


def _padded_size(size, minimum=MIN_PADDED_SIZE, divisor=SIZE_DIVISOR):
    """Return the smallest multiple of ``divisor`` that is >= max(size, minimum)."""
    return -(-max(size, minimum) // divisor) * divisor


def run_seg(img, best_model, *, crop_to_input=False):
    """Segment one image with ``best_model`` and return a display mask.

    Args:
        img (np.ndarray): image of shape (H, W, C).
            NOTE(review): the ImageNet preprocessing function presumably
            expects RGB order, while ``cv2.imread`` yields BGR. Confirm which
            order the model was trained on.
        best_model: model exposing ``predict(tensor)`` (e.g. an smp model).
        crop_to_input (bool): if True, strip the padding added before
            inference so the mask's last two dims equal ``img``'s height and
            width. The default False keeps returning the padded-size mask.

    Return:
        np.ndarray: the rounded prediction ``m`` mapped to ``(1 - m) * 220``,
        so 0 -> 220 and 1 -> 0.
    """
    height, width = img.shape[:2]
    # Pad each side to at least 256 AND to a multiple of 32. PadIfNeeded(256, 256)
    # alone would leave e.g. a 300-px side unpadded and not divisible by 32.
    padded_h = _padded_size(height)
    padded_w = _padded_size(width)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    augmentator = get_validation_augmentation(padded_h, padded_w)
    preprocessor = get_preprocessing(preprocessing_fn)

    img = augmentator(image=img)['image']
    img = preprocessor(image=img)['image']

    x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()

    if crop_to_input:
        # PadIfNeeded centres the image: top/left padding is floor(total / 2).
        top = (padded_h - height) // 2
        left = (padded_w - width) // 2
        pr_mask = pr_mask[..., top:top + height, left:left + width]

    # Same values as ``(pr_mask - 1) * (-220)``, without producing -0.0.
    return (1 - pr_mask) * 220


if __name__ == '__main__':
    # NOTE(review): torch.load unpickles a full model object -- only load
    # trusted checkpoints. On torch >= 2.6 the default ``weights_only=True``
    # may reject full-model pickles; pass ``weights_only=False`` there.
    best_model = torch.load('/home/zhaojh/workspace/computer_vision/segmentation/models/best_model_pvgc.pth', map_location=DEVICE)
    # NOTE(review): cv2.imread returns None if the file cannot be read, and
    # BGR channel order on success -- check/convert before calling run_seg.
    input_img = cv2.imread('/home/zhaojh/datasets/photovoltaic/PV03/PV03_Ground_Cropland/test/PV03_316626_1211836.bmp')