# Source: ai_platform_cv/segmentation/segment_pred.py (59 lines, 1.6 KiB, Python)

# -*-coding:utf-8-*-
from logzero import logger
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
# Device that run_seg moves the input tensor to before prediction.
DEVICE = 'cpu'
# Encoder name and pretrained-weights identifier; run_seg passes both to
# smp.encoders.get_preprocessing_fn to obtain the matching preprocessing
# function.
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
# ---------------------------------------------------------------
### Data preparation: padding and preprocessing helpers
def get_validation_augmentation():
    """Build the padding pipeline applied to an image before inference.

    The pipeline holds a single ``albu.PadIfNeeded`` step configured with
    1024 for both height and width, so smaller images are padded up to
    1024 x 1024 (a multiple of 32).

    NOTE(review): the original comment states the goal is to make height and
    width divisible by 32. A side that is already larger than 1024 is
    presumably left unpadded and may not be divisible by 32 -- confirm
    against the image sizes callers actually pass.

    Returns:
        albumentations.Compose: the composed padding transform.
    """
    pad_step = albu.PadIfNeeded(min_height=1024, min_width=1024)
    return albu.Compose([pad_step])
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Args:
        x: array exposing ``transpose`` and ``astype``; presumably an
            image in (height, width, channels) layout -- TODO confirm.
        **kwargs: accepted and ignored, so the function can be called with
            extra keyword arguments.

    Returns:
        The array with axes reordered as (2, 0, 1) and dtype ``float32``.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')
def get_preprocessing(preprocessing_fn):
    """Build the image preprocessing pipeline.

    Args:
        preprocessing_fn (callable): data normalization function
            (specific to each pretrained network).

    Returns:
        albumentations.Compose: applies ``preprocessing_fn`` to the image,
        then ``to_tensor``.
    """
    normalize_step = albu.Lambda(image=preprocessing_fn)
    tensor_step = albu.Lambda(image=to_tensor)
    return albu.Compose([normalize_step, tensor_step])
def run_seg(img, best_model):
    """Run segmentation on one image and return a scaled, inverted mask.

    Steps: pad the image, apply the encoder-specific preprocessing, convert
    to a batched tensor on ``DEVICE``, call ``best_model.predict`` and
    post-process the prediction.

    Args:
        img: input image as a numpy array; presumably (height, width,
            channels) -- TODO confirm against callers.
        best_model: object exposing ``predict(tensor)`` that returns a
            torch tensor.

    Returns:
        numpy array computed as ``(rounded_prediction - 1) * (-220)``: a
        rounded value of 1 becomes 0 and a rounded value of 0 becomes 220.
        NOTE(review): the result has the padded size, not the original
        image size -- confirm callers expect that.
    """
    # Build the transforms for the configured encoder.
    normalize_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    pad_pipeline = get_validation_augmentation()
    prep_pipeline = get_preprocessing(normalize_fn)

    # Pad first, then normalize and reorder the axes.
    padded = pad_pipeline(image=img)['image']
    img = prep_pipeline(image=padded)['image']
    logger.info(f"img.shape {img.shape}")

    # Add a leading batch dimension of size 1 before prediction.
    batch = torch.from_numpy(img).to(DEVICE).unsqueeze(0)
    prediction = best_model.predict(batch)

    # Drop singleton dimensions and round to whole values.
    rounded_mask = prediction.squeeze().cpu().numpy().round()
    return (rounded_mask - 1) * (-220)
# Nothing runs when the file is executed directly; the guard is a placeholder.
if __name__ == '__main__':
    pass