59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
# -*-coding:utf-8-*-
|
|
from logzero import logger
|
|
import albumentations as albu
|
|
import torch
|
|
import segmentation_models_pytorch as smp
|
|
|
|
# Device that run_seg moves the input tensor to before calling the model.
DEVICE = 'cpu'

# Encoder architecture and its pretrained-weights identifier. Both are passed
# to smp.encoders.get_preprocessing_fn in run_seg to pick the normalization
# that matches the encoder's training data.
ENCODER, ENCODER_WEIGHTS = 'se_resnext50_32x4d', 'imagenet'
|
|
|
|
# ---------------------------------------------------------------
|
|
### 加载数据 (data loading: test-time padding and preprocessing pipelines)
|
|
|
|
def get_validation_augmentation(min_height=1024, min_width=1024):
    """Build the test-time padding pipeline.

    The image is padded so that it is at least ``min_height`` x ``min_width``
    pixels. ``albu.PadIfNeeded`` only pads; it never crops or resizes, so:

    * an input no larger than the minimums comes out exactly
      ``min_height`` x ``min_width``, which is divisible by 32 because both
      minimums are required to be multiples of 32;
    * a larger input keeps its own size, which is NOT guaranteed to be
      divisible by 32.

    The divisible-by-32 goal comes from the original intent of this helper
    (the encoder downsamples the image, so sides should be multiples of 32).

    Args:
        min_height (int): minimum output height; must be a positive multiple
            of 32. Defaults to 1024 (the previously hard-coded value).
        min_width (int): minimum output width; must be a positive multiple
            of 32. Defaults to 1024 (the previously hard-coded value).

    Returns:
        albumentations.Compose: a pipeline containing a single PadIfNeeded.

    Raises:
        ValueError: if ``min_height`` or ``min_width`` is not a positive
            multiple of 32.
    """
    # Validate first, so bad sizes fail loudly instead of silently producing
    # tensors the encoder cannot handle.
    for name, value in (("min_height", min_height), ("min_width", min_width)):
        if value <= 0 or value % 32:
            raise ValueError(
                f"{name} must be a positive multiple of 32, got {value}"
            )

    test_transform = [
        albu.PadIfNeeded(min_height=min_height, min_width=min_width),
    ]
    return albu.Compose(test_transform)
|
|
|
|
|
|
def to_tensor(x, **kwargs):
    """Move the last axis of ``x`` to the front and cast it to float32.

    For a 3-D array laid out as (H, W, C) this yields (C, H, W), the
    channel-first layout torch models consume. ``**kwargs`` is accepted and
    ignored so the function can be used as an ``albu.Lambda`` target.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')
|
|
|
|
|
|
def get_preprocessing(preprocessing_fn):
    """Build the image preprocessing pipeline.

    The pipeline first applies ``preprocessing_fn`` to the image and then
    ``to_tensor`` (channel-first transpose plus float32 cast).

    Args:
        preprocessing_fn (callable): normalization function matching the
            pretrained encoder (one per pretrained network).

    Returns:
        albumentations.Compose: the two-step pipeline.
    """
    steps = [albu.Lambda(image=fn) for fn in (preprocessing_fn, to_tensor)]
    return albu.Compose(steps)
|
|
|
|
|
|
def run_seg(img, best_model):
    """Segment one image with ``best_model`` and return a scaled mask.

    Args:
        img: input image for the albumentations pipelines. ``to_tensor``
            transposes axes (2, 0, 1), so this is presumably an (H, W, C)
            numpy array; TODO confirm the channel order the encoder expects.
        best_model: object exposing ``predict(tensor)`` that returns a torch
            tensor (presumably an smp segmentation model; verify with callers).

    Returns:
        numpy array equal to ``(round(prediction) - 1) * -220``. Pixels whose
        rounded prediction is 1 become 0 and pixels rounded to 0 become 220.
        If the model outputs are not in [0, 1], other values can appear.
        The mask keeps the padded spatial size; no cropping back to the
        input size is done here.
    """
    # Test-time pipelines: encoder-specific normalization, padding, and the
    # normalize + channel-first conversion.
    normalize_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    padder = get_validation_augmentation()
    prep = get_preprocessing(normalize_fn)

    chw = prep(image=padder(image=img)['image'])['image']
    logger.info(f"img.shape {chw.shape}")

    # Add a batch dimension of 1 before handing the tensor to the model.
    batch = torch.from_numpy(chw).to(DEVICE).unsqueeze(0)
    prediction = best_model.predict(batch)

    mask = prediction.squeeze().cpu().numpy().round()
    return (mask - 1) * (-220)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script execution is intentionally a no-op; callers import run_seg.
    pass