# -*-coding:utf-8-*-
import functools

from logzero import logger
import albumentations as albu
import torch
import segmentation_models_pytorch as smp

DEVICE = 'cpu'
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'


# ---------------------------------------------------------------
### Data loading / preprocessing

def get_validation_augmentation():
    """Pad the image so its height and width are divisible by 32.

    The first pad guarantees both sides are at least 1024 px (1024 is a
    multiple of 32). On its own it leaves any side already larger than 1024
    untouched, so the second pad rounds such a side up to the next multiple
    of 32. For inputs whose sides are already multiples of 32 after the
    first pad, the second pad is a no-op.

    Returns:
        albumentations.Compose
    """
    test_transform = [
        albu.PadIfNeeded(1024, 1024),
        albu.PadIfNeeded(
            min_height=None,
            min_width=None,
            pad_height_divisor=32,
            pad_width_divisor=32,
        ),
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Reorder an HxWxC array to CxHxW and cast it to float32.

    ``**kwargs`` is accepted and ignored — presumably because albumentations'
    ``Lambda`` forwards extra keyword arguments to the wrapped function.
    """
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the image preprocessing pipeline.

    Args:
        preprocessing_fn (callable): data normalization function (specific
            to each pretrained encoder).

    Return:
        transform: albumentations.Compose
    """
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor),
    ]
    return albu.Compose(_transform)


@functools.lru_cache(maxsize=4)
def _get_pipeline(encoder, encoder_weights):
    """Build and cache the (augmentator, preprocessor) pair for an encoder."""
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)
    return get_validation_augmentation(), get_preprocessing(preprocessing_fn)


def run_seg(img, best_model):
    """Segment a single image and return the inverted, scaled mask.

    Args:
        img: input image, assumed to be an HxWxC numpy array (TODO confirm);
            ``to_tensor`` reorders it with ``transpose(2, 0, 1)``.
        best_model: object exposing ``predict(tensor)`` that returns a torch
            tensor — presumably a segmentation_models_pytorch model.

    Returns:
        numpy array with 0 where the rounded prediction is 1 and 220 where it
        is 0 (assuming the model outputs values in [0, 1]).

    NOTE(review): the mask has the *padded* size produced by
    ``get_validation_augmentation``, not the size of ``img``; callers that
    need a mask aligned with the input must crop it themselves.
    """
    # Keyed on the current globals so reassigning ENCODER/ENCODER_WEIGHTS
    # still takes effect, while repeated calls reuse the built pipelines.
    augmentator, preprocessor = _get_pipeline(ENCODER, ENCODER_WEIGHTS)

    img = augmentator(image=img)['image']
    img = preprocessor(image=img)['image']
    logger.info("img.shape %s", img.shape)

    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    x_tensor = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()
    # Invert and scale (1 -> 0, 0 -> 220). Written as (1 - p) * 220 rather
    # than (p - 1) * -220 so 1-pixels come out as +0.0 instead of -0.0.
    return (1 - pr_mask) * 220


if __name__ == '__main__':
    pass