# ai-station-code/wudingpv/taihuyuan_pv/dataloaders/custom_transforms.py
import torch
import random
import numpy as np
import torchvision.transforms as transforms
#定义了一些数据增强的方法
from PIL import Image, ImageOps, ImageFilter
from scipy.ndimage.interpolation import shift
from scipy.ndimage.morphology import distance_transform_edt
import numpy as np
import torchvision.transforms as transforms
#定义了一些数据增强的方法
from PIL import Image, ImageOps, ImageFilter,ImageEnhance
#一定要设置随机种子
import random
import torch
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class RelaxedBoundary(object):
    """Relaxed one-hot label encoding for boundary supervision.

    Each pixel's one-hot vector is OR-ed with those of its 8 neighbours
    (border = 1), so class channels are "relaxed" around boundaries.
    Pixels equal to ignore_id are folded into an extra channel at index
    num_classes. Takes a PIL image / array-like label map and returns a
    uint8 array of shape (num_classes + 1, H, W).
    """
    def __init__(self, ignore_id, num_classes):
        self.ignore_id = ignore_id
        self.num_classes = num_classes
    def new_one_hot_converter(self, a):
        # Identity-matrix lookup: row a[y, x] is the one-hot vector of that pixel.
        eye = np.eye(self.num_classes + 1, dtype=np.uint8)
        return eye[a]
    def __call__(self, img):
        label = np.array(img)
        # Map ignored pixels onto the extra channel.
        label[label == self.ignore_id] = self.num_classes
        border = 1
        accum = 0
        for dy in range(-border, border + 1):
            for dx in range(-border, border + 1):
                shifted = shift(label, (dy, dx), cval=self.num_classes)
                accum = accum + self.new_one_hot_converter(shifted)
        accum[accum > 1] = 1
        # channels-last -> channels-first
        return np.moveaxis(accum, -1, 0)
def onehot_to_binary_edges(mask, radius, num_classes):
    """
    Convert a one-hot segmentation mask (K, H, W) into a binary edge map
    (1, H, W): a pixel is 1 when it lies within `radius` of a class
    boundary. A negative radius returns the input unchanged.
    """
    if radius < 0:
        return mask
    # Pad one pixel on each spatial side so image borders count as boundaries.
    padded = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
    edgemap = np.zeros(mask.shape[1:])
    for cls in range(num_classes):
        channel = padded[cls, :]
        # Distance to the class boundary, measured from inside plus outside.
        dist = distance_transform_edt(channel) + distance_transform_edt(1.0 - channel)
        dist = dist[1:-1, 1:-1]
        dist[dist > radius] = 0
        edgemap += dist
    edgemap = np.expand_dims(edgemap, axis=0)
    return (edgemap > 0).astype(np.uint8)
def generate_edge_map(mask, num_classes=3):
    '''Build a binary edge map of shape (1, H, W) from a np.array label map.

    The relaxed one-hot encoding carries an extra channel for ignored
    pixels (id 255), which is dropped before edge extraction.
    '''
    encoder = RelaxedBoundary(ignore_id=255, num_classes=num_classes)  # c, h, w
    onehot = encoder(mask)[:-1, :, :]  # drop the ignore channel
    return onehot_to_binary_edges(onehot, 2, num_classes)
'''自定义一些数据增强的方法'''
class Normalize(object):
    """Normalize an image with per-channel mean and std after scaling to [0, 1].
    Args:
        mean (tuple): per-channel means.
        std (tuple): per-channel standard deviations.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std
    def __call__(self, sample):
        image = np.array(sample['image']).astype(np.float32)
        target = np.array(sample['label']).astype(np.float32)
        # Scale to [0, 1], then standardize in place (keeps float32 dtype).
        image /= 255.0
        image -= self.mean
        image /= self.std
        return {'image': image,
                'label': target}
class Normalize_simple(object):
    """Simple normalization: divide the image by 255; the label (and the
    optional edge map) are only converted to float32."""
    def __init__(self, use_edge=False):
        self.num = 255.0
        self.use_edge = use_edge
    def __call__(self, sample):
        image = np.array(sample['image']).astype(np.float32)
        image /= 255.0
        target = np.array(sample['label']).astype(np.float32)
        result = {'image': image,
                  'label': target}
        if self.use_edge:
            result['edge'] = np.array(sample['edge']).astype(np.float32)
        return result
class Resize(object):
    '''Resize both image and label to a fixed shape (the originals are too
    large to fit into memory).'''
    def __init__(self, resizeshape=(512, 512)):
        self.resizeshape = resizeshape
    def __call__(self, sample):
        # Bilinear for the image; nearest for the label so class ids stay intact.
        resized_img = sample['image'].resize(self.resizeshape, Image.BILINEAR)
        resized_mask = sample['label'].resize(self.resizeshape, Image.NEAREST)
        return {'image': resized_img,
                'label': resized_mask}
class ResizeforValTest(object):
    '''Resize only the image for validation/testing: the ground-truth label
    must never be rescaled when computing mIoU, so it is returned untouched
    (only the image side may shrink to fit into GPU memory).'''
    def __init__(self, resizeshape=(1024, 512)):
        self.resizeshape = resizeshape
    def __call__(self, sample):
        resized_img = sample['image'].resize(self.resizeshape, Image.BILINEAR)
        return {'image': resized_img,
                'label': sample['label']}
class ToTensor(object):
    """Convert ndarrays in a sample to float torch tensors (image HWC -> CHW)."""
    def __init__(self, use_edge=False):
        self.use_edge = use_edge
    def __call__(self, sample):
        # numpy image: H x W x C  ->  torch image: C x H x W
        image = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        target = np.array(sample['label']).astype(np.float32)
        result = {'image': torch.from_numpy(image).float(),
                  'label': torch.from_numpy(target).float()}
        if self.use_edge:
            # The edge map is expected to already be a numpy array.
            result['edge'] = torch.from_numpy(sample['edge']).float()
        return result
class RandomHorizontalFlip(object):
    '''Flip image and label left-right together with probability 0.5.'''
    def __call__(self, sample):
        if random.random() >= 0.5:
            return {'image': sample['image'],
                    'label': sample['label']}
        return {'image': sample['image'].transpose(Image.FLIP_LEFT_RIGHT),
                'label': sample['label'].transpose(Image.FLIP_LEFT_RIGHT)}
#随机色度增强
class Enhance_Color(object):
    '''Random colour (saturation) enhancement: the factor is drawn from
    [0.4, 2.6); a factor of 1.0 reproduces the original image, smaller
    values desaturate, larger values saturate. The label passes through.'''
    def __call__(self, sample):
        enhancer = ImageEnhance.Color(sample['image'])
        factor = np.random.uniform(0.4, 2.6)
        return {'image': enhancer.enhance(factor),
                'label': sample['label']}
#随机对比度增强
class Enhance_contrasted(object):
    """Random contrast enhancement.

    Draws a factor from [0.6, 1.6) and applies PIL contrast enhancement:
    factor 1.0 returns a copy of the original, smaller values reduce
    contrast, larger values increase it. The label passes through.
    """
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # Bug fix: this previously used ImageEnhance.Color, which adjusts
        # saturation, not contrast (Enhance_Color already covers that case).
        enh_con = ImageEnhance.Contrast(img)
        contrast = np.random.uniform(0.6, 1.6)
        img_contrasted = enh_con.enhance(contrast)
        return {'image': img_contrasted,
                'label': mask}
#随机锐度增强
class Enhance_sharped(object):
    '''Random sharpness enhancement with a factor drawn from [0.4, 4);
    a factor of 1.0 reproduces the original image.'''
    def __call__(self, sample):
        enhancer = ImageEnhance.Sharpness(sample['image'])
        factor = np.random.uniform(0.4, 4)
        return {'image': enhancer.enhance(factor),
                'label': sample['label']}
class RandomRotate(object):
    '''With probability 0.5, rotate image and label together by a uniform
    random angle in [-degree, degree].'''
    def __init__(self, degree):
        self.degree = degree
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            # Bilinear for the image, nearest for the label.
            angle = random.uniform(-1 * self.degree, self.degree)
            img = img.rotate(angle, Image.BILINEAR)
            mask = mask.rotate(angle, Image.NEAREST)
        return {'image': img,
                'label': mask}
class RandomGaussianBlur(object):
    '''With probability 0.5, blur the image with a Gaussian kernel whose
    radius is drawn from [0, 1); the label is left untouched.'''
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            blur = ImageFilter.GaussianBlur(radius=random.random())
            img = img.filter(blur)
        return {'image': img,
                'label': mask}
#缩放尺寸原图2048*1024太大了
class FixedResize(object):
    '''Resize image and label to a fixed square size (bilinear for the
    image, nearest for the label so class ids stay intact).'''
    def __init__(self, size):
        self.size = (size, size)  # size: (h, w)
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size
        return {'image': img.resize(self.size, Image.BILINEAR),
                'label': mask.resize(self.size, Image.NEAREST)}
class RandomCropResize(object):
    """Randomly crop a symmetric border region and resize the crop back to
    the original size, applied with probability 0.5.

    Bug fix: the original implementation mixed numpy-array access
    (``img.shape``, slicing) with PIL calls (``resize`` with a nonexistent
    ``interpolation=`` keyword), so it raised on either input type. The
    samples in this pipeline are PIL images, so everything now uses the
    PIL API (``crop`` box is (left, upper, right, lower)).
    """
    def __init__(self, crop_area):
        '''
        :param crop_area: maximum border (in pixels) that may be cropped
            away on each side; the actual offset is drawn from
            [0, crop_area]. Must be smaller than half the image size.
        '''
        self.cw = crop_area
        self.ch = crop_area
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            w, h = img.size
            x1 = random.randint(0, self.ch)
            y1 = random.randint(0, self.cw)
            # Crop the same symmetric border from image and label, then
            # stretch both back to the original size.
            box = (x1, y1, w - x1, h - y1)
            img_crop = img.crop(box).resize((w, h), Image.BILINEAR)
            label_crop = mask.crop(box).resize((w, h), Image.NEAREST)
            return {'image': img_crop,
                    'label': label_crop}
        else:
            return {'image': img,
                    'label': mask}
class FixScaleCrop(object):
    '''Scale the shorter edge to crop_size, then take a centred
    crop_size x crop_size crop (no padding afterwards).'''
    def __init__(self, crop_size):
        self.crop_size = crop_size
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        # Scale so the shorter edge becomes crop_size.
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Centre crop: box is (left, upper, right, lower).
        w, h = img.size
        left = int(round((w - self.crop_size) / 2.))
        top = int(round((h - self.crop_size) / 2.))
        box = (left, top, left + self.crop_size, top + self.crop_size)
        return {'image': img.crop(box),
                'label': mask.crop(box)}
class RandomFixScaleCropMy(object):
    '''With probability 0.5, take a centred crop of crop_size and pad it
    back to the original size (image padded with 0, label with `fill`).
    crop_size must be smaller than the (already resized) input.'''
    def __init__(self, crop_size, fill=0):
        self.crop_size = crop_size
        self.fill = fill
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() >= 0.5:
            return {'image': img,
                    'label': mask}
        w, h = img.size
        left = int(round((w - self.crop_size) / 2.))
        top = int(round((h - self.crop_size) / 2.))
        box = (left, top, left + self.crop_size, top + self.crop_size)
        img = img.crop(box)
        mask = mask.crop(box)
        # Pad the crop back out to the (scaled) original size.
        padw = w - self.crop_size - left
        padh = h - self.crop_size - top
        img = ImageOps.expand(img, border=(left, top, padw, padh), fill=0)  # left, top, right, bottom
        mask = ImageOps.expand(mask, border=(left, top, padw, padh), fill=self.fill)
        return {'image': img,
                'label': mask}
class RandomScaleCrop(object):
    '''Random-scale then random-crop augmentation.

    The shorter image edge is first resized to a random length drawn from
    [0.5 * base_size, 2 * base_size]; if the scaled result is smaller than
    crop_size it is padded (image with 0, label with `fill`), and finally
    a random crop_size x crop_size window is cut out.
    '''
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # Pick a random target length for the shorter edge.
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Pad up to crop_size when the scaled image is too small.
        if short_size < self.crop_size:
            padh = max(self.crop_size - oh, 0)
            padw = max(self.crop_size - ow, 0)
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # Random crop_size x crop_size crop.
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        box = (x1, y1, x1 + self.crop_size, y1 + self.crop_size)
        return {'image': img.crop(box),
                'label': mask.crop(box)}
####################################################
################################################################
'''
下边这些的使用前提你的img和label都转换成了numpy但是还没有进行归一化
'''
'''
Function:
Define the transforms for data augmentations
Author:
Zhenchao Jin
'''
import cv2
import torch
import numpy as np
import torch.nn.functional as F
'''Resize'''
class Resize2(object):
    # Resize to a fixed output size, optionally scaled by a random factor from scale_range.
    def __init__(self, output_size, scale_range=(0.5, 2.0), img_interpolation='bilinear', seg_interpolation='nearest', keep_ratio=True, min_size=None):
        # set attribute
        '''
        :param output_size: target output size (int or (h, w) tuple)
        :param scale_range: random scale range applied to output_size, or None for no random scaling
        :param img_interpolation: cv2 interpolation name used for the image
        :param seg_interpolation: cv2 interpolation name used for the label (nearest keeps class ids intact)
        :param keep_ratio: whether to preserve the original aspect ratio
        :param min_size: lower bound for the shorter resized edge (keep_ratio path only)
        Behaviour:
          scale_range=None, keep_ratio=False: resize directly to output_size
          scale_range=None, keep_ratio=True: aspect-preserving resize, e.g. (501, 750, 3) with output 512 -> (342, 512, 3)
          scale_range=(0.5, 2), keep_ratio=True: aspect-preserving resize at a random scale
          scale_range=(0.5, 2), keep_ratio=False: direct resize at a random scale
        '''
        self.output_size = output_size
        if isinstance(output_size, int): self.output_size = (output_size, output_size)
        self.scale_range = scale_range
        self.img_interpolation = img_interpolation
        self.seg_interpolation = seg_interpolation
        self.keep_ratio = keep_ratio
        self.min_size = min_size
        # map interpolation names to cv2 flags
        self.interpolation_dict = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'bicubic': cv2.INTER_CUBIC,
            'area': cv2.INTER_AREA,
            'lanczos': cv2.INTER_LANCZOS4
        }
    '''call'''
    def __call__(self, sample):
        # parse; copies so the arrays held by the sample are not resized in place
        image, segmentation = sample['image'].copy(), sample['label'].copy()
        if self.scale_range is not None:
            # uniform random scale in [scale_range[0], scale_range[1])
            rand_scale = np.random.random_sample() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0]
            output_size = int(self.output_size[0] * rand_scale), int(self.output_size[1] * rand_scale)
        else:
            output_size = self.output_size[0], self.output_size[1]
        # resize image and segmentation
        if self.keep_ratio:
            # single scale factor for both axes, chosen so the result fits inside output_size
            scale_factor = min(max(output_size) / max(image.shape[:2]), min(output_size) / min(image.shape[:2]))
            # NOTE: cv2.resize takes dsize as (width, height), hence shape[1] first
            dsize = int(image.shape[1] * scale_factor + 0.5), int(image.shape[0] * scale_factor + 0.5)
            if self.min_size is not None and min(dsize) < self.min_size:
                # enforce the lower bound on the shorter edge
                scale_factor = self.min_size / min(image.shape[:2])
                dsize = int(image.shape[1] * scale_factor + 0.5), int(image.shape[0] * scale_factor + 0.5)
            image = cv2.resize(image, dsize=dsize, interpolation=self.interpolation_dict[self.img_interpolation])
            segmentation = cv2.resize(segmentation, dsize=dsize, interpolation=self.interpolation_dict[self.seg_interpolation])
        else:
            # direct resize: orient (w, h) of dsize to match the input's portrait/landscape orientation
            if image.shape[0] > image.shape[1]:
                dsize = min(output_size), max(output_size)
            else:
                dsize = max(output_size), min(output_size)
            image = cv2.resize(image, dsize=dsize, interpolation=self.interpolation_dict[self.img_interpolation])
            segmentation = cv2.resize(segmentation, dsize=dsize, interpolation=self.interpolation_dict[self.seg_interpolation])
        # update and return sample
        sample['image'], sample['label'] = image, segmentation
        return sample
class Tonumpy(object):
    '''Convert the image and label in a sample to numpy arrays.'''
    def __init__(self,):
        pass
    def __call__(self, sample):
        return {'image': np.array(sample['image']),
                'label': np.array(sample['label'])}
'''RandomCrop'''
class RandomCrop(object):
#原来是没有概率,对每张图都进行随机裁剪
#现在我给他改成对每张图片以一定的概率进行随机裁剪,一般就默认0.5
#prob 0就是不进行
'''
随机裁剪具体代码就别读了它大概的原理
比如你要输入网络的尺寸为512*512
你的原图是500750
这个随机裁剪比如随机裁剪的尺寸是crop_size=256*256
那么将会从原图500,750随机裁剪出一块256*256的区域然后再resize成网络输入的大小
注意这个class出来就是crop_size大小的区域然后根据网络要求的输入在把这个crop的区域resize成512
这个代码是在crop size的长和高都不超过范围的的情况下裁剪出来是crop size大小的
比如原图500750 cropsize=256 那就会随机从原图中裁剪出256*256
如果crop size超过原图了 比如原图500,750 你的crop size为800*800那就直接返回原图500750
如果只有一个尺寸超了范围比如crop size 700,700那么对应的那一边就会返回原图最小的那一边 500,700)
'''
def __init__(self, crop_size, crop_prob = 0.5,ignore_index=255, one_category_max_ratio=0.75):
self.crop_size = crop_size
self.crop_prob = crop_prob
if isinstance(crop_size, int): self.crop_size = (crop_size, crop_size)
self.ignore_index = ignore_index
self.one_category_max_ratio = one_category_max_ratio
'''call'''
def __call__(self, sample):
# avoid the cropped image is filled by only one category
if np.random.rand() > self.crop_prob: return sample
for _ in range(10):
# --parse
image, segmentation = sample['image'].copy(), sample['label'].copy()
h_ori, w_ori = image.shape[:2]
h_out, w_out = min(self.crop_size[0], h_ori), min(self.crop_size[1], w_ori)
# --random crop
top, left = np.random.randint(0, h_ori - h_out + 1), np.random.randint(0, w_ori - w_out + 1)
image = image[top: top + h_out, left: left + w_out]
segmentation = segmentation[top: top + h_out, left: left + w_out]
# --judge
labels, counts = np.unique(segmentation, return_counts=True)
counts = counts[labels != self.ignore_index]
if len(counts) > 1 and np.max(counts) / np.sum(counts) < self.one_category_max_ratio: break
# update and return sample
if len(counts) == 0: return sample
sample['image'], sample['label'] = image, segmentation
return sample
'''RandomFlip'''
class RandomFlip(object):
    '''Horizontally flip image and label with probability flip_prob;
    optionally swap paired label ids (e.g. left/right parts) listed in
    fix_ann_pairs after the flip.'''
    def __init__(self, flip_prob=0.5, fix_ann_pairs=None):
        self.flip_prob = flip_prob
        self.fix_ann_pairs = fix_ann_pairs
    '''call'''
    def __call__(self, sample):
        if np.random.rand() > self.flip_prob:
            return sample
        image = np.flip(sample['image'].copy(), axis=1)
        segmentation = np.flip(sample['label'].copy(), axis=1)
        if self.fix_ann_pairs:
            for (pair_a, pair_b) in self.fix_ann_pairs:
                # Locate both ids first, then swap, so the two writes
                # cannot interfere with each other.
                pos_a = np.where(segmentation == pair_a)
                pos_b = np.where(segmentation == pair_b)
                segmentation[pos_a[0], pos_a[1]] = pair_b
                segmentation[pos_b[0], pos_b[1]] = pair_a
        sample['image'], sample['label'] = image, segmentation
        return sample
'''PhotoMetricDistortion'''
class PhotoMetricDistortion(object):
    '''Random photometric distortions (brightness / contrast / saturation /
    hue), each applied independently with probability 0.5.

    Contrast is applied either before (mode 1) or after (mode 0) the
    saturation/hue steps, so the exact order of the np.random draws is part
    of the behaviour. Works on uint8 images; assumes RGB channel order --
    TODO confirm against the loader.
    '''
    def __init__(self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta
    '''call'''
    def __call__(self, sample):
        image = sample['image'].copy()
        image = self.brightness(image)
        # randomly place contrast before (mode 1) or after (mode 0) the colour ops
        mode = np.random.randint(2)
        if mode == 1: image = self.contrast(image)
        image = self.saturation(image)
        image = self.hue(image)
        if mode == 0: image = self.contrast(image)
        sample['image'] = image
        return sample
    '''brightness distortion: additive shift drawn from [-delta, delta], 50% chance'''
    def brightness(self, image):
        if not np.random.randint(2): return image
        return self.convert(image, beta=np.random.uniform(-self.brightness_delta, self.brightness_delta))
    '''contrast distortion: multiplicative factor drawn from [lower, upper], 50% chance'''
    def contrast(self, image):
        if not np.random.randint(2): return image
        return self.convert(image, alpha=np.random.uniform(self.contrast_lower, self.contrast_upper))
    '''convert an RGB image to the HSV colour space'''
    def rgb2hsv(self, image):
        return cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    '''convert an HSV image back to RGB'''
    def hsv2rgb(self, image):
        return cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
    '''saturation distortion: scale the HSV S channel, 50% chance'''
    def saturation(self, image):
        if not np.random.randint(2): return image
        image = self.rgb2hsv(image)
        image[..., 1] = self.convert(image[..., 1], alpha=np.random.uniform(self.saturation_lower, self.saturation_upper))
        image = self.hsv2rgb(image)
        return image
    '''hue distortion: rotate the HSV H channel modulo 180 (OpenCV 8-bit hue range), 50% chance'''
    def hue(self, image):
        if not np.random.randint(2): return image
        image = self.rgb2hsv(image)
        image[..., 0] = (image[..., 0].astype(int) + np.random.randint(-self.hue_delta, self.hue_delta)) % 180
        image = self.hsv2rgb(image)
        return image
    '''multiply by alpha, add beta, then clip back to uint8 [0, 255]'''
    def convert(self, image, alpha=1, beta=0):
        image = image.astype(np.float32) * alpha + beta
        image = np.clip(image, 0, 255)
        return image.astype(np.uint8)
'''RandomRotation'''
class RandomRotation(object):
    '''Rotate image and label (numpy arrays) around the image centre by a
    random integer angle in [-angle_upper, angle_upper), applied with
    probability rotation_prob.'''
    def __init__(self, angle_upper=30, rotation_prob=0.5, img_fill_value=0.0, seg_fill_value=255, img_interpolation='bicubic', seg_interpolation='nearest'):
        '''
        :param angle_upper: maximum rotation angle (degrees)
        :param rotation_prob: probability of applying the rotation
        :param img_fill_value: border fill value for the rotated image
        :param seg_fill_value: border fill value for the rotated label (ignore index)
        :param img_interpolation: cv2 interpolation name for the image
        :param seg_interpolation: cv2 interpolation name for the label
        '''
        self.angle_upper = angle_upper
        self.rotation_prob = rotation_prob
        self.img_fill_value = img_fill_value
        self.seg_fill_value = seg_fill_value
        self.img_interpolation = img_interpolation
        self.seg_interpolation = seg_interpolation
        # map interpolation names to cv2 flags
        self.interpolation_dict = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'bicubic': cv2.INTER_CUBIC,
            'area': cv2.INTER_AREA,
            'lanczos': cv2.INTER_LANCZOS4
        }
    '''call'''
    def __call__(self, sample):
        if np.random.rand() > self.rotation_prob:
            return sample
        image = sample['image'].copy()
        segmentation = sample['label'].copy()
        rows, cols = image.shape[:2]
        angle = np.random.randint(-self.angle_upper, self.angle_upper)
        matrix = cv2.getRotationMatrix2D(center=(cols / 2, rows / 2), angle=angle, scale=1)
        image = cv2.warpAffine(image, matrix, (cols, rows), flags=self.interpolation_dict[self.img_interpolation], borderValue=self.img_fill_value)
        segmentation = cv2.warpAffine(segmentation, matrix, (cols, rows), flags=self.interpolation_dict[self.seg_interpolation], borderValue=self.seg_fill_value)
        sample['image'], sample['label'] = image, segmentation
        return sample
'''Padding'''
class Padding(object):
    '''Pad image / segmentation / edge maps up to output_size, centring the
    original content; the image is padded with img_fill_value and the
    label/edge maps with seg_fill_value.

    NOTE(review): unlike the other transforms in this file, this one reads
    the sample keys 'segmentation' and 'edge' (not 'label') -- presumably it
    belongs to a different pipeline; confirm against the caller.
    '''
    def __init__(self, output_size, data_type='numpy', img_fill_value=0, seg_fill_value=255, output_size_auto_adaptive=True):
        # output_size: (h, w) target size; an int means a square target
        self.output_size = output_size
        if isinstance(output_size, int): self.output_size = (output_size, output_size)
        assert data_type in ['numpy', 'tensor'], 'unsupport data type %s' % data_type
        self.data_type = data_type
        self.img_fill_value = img_fill_value
        self.seg_fill_value = seg_fill_value
        self.output_size_auto_adaptive = output_size_auto_adaptive
    '''call'''
    def __call__(self, sample):
        output_size = self.output_size[0], self.output_size[1]
        if self.output_size_auto_adaptive:
            # swap the target h/w when its orientation disagrees with the input's
            if self.data_type == 'numpy':
                h_ori, w_ori = sample['image'].shape[:2]
            else:
                h_ori, w_ori = sample['image'].shape[1:]
            h_out, w_out = output_size
            if (h_ori > w_ori and h_out < w_out) or (h_ori < w_ori and h_out > w_out):
                output_size = (w_out, h_out)
        if self.data_type == 'numpy':
            # numpy layout: H x W (x C); centre the content, pad the borders
            image, segmentation, edge = sample['image'].copy(), sample['segmentation'].copy(), sample['edge'].copy()
            h_ori, w_ori = image.shape[:2]
            top = (output_size[0] - h_ori) // 2
            bottom = output_size[0] - h_ori - top
            left = (output_size[1] - w_ori) // 2
            right = output_size[1] - w_ori - left
            image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[self.img_fill_value, self.img_fill_value, self.img_fill_value])
            segmentation = cv2.copyMakeBorder(segmentation, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[self.seg_fill_value])
            edge = cv2.copyMakeBorder(edge, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[self.seg_fill_value])
            sample['image'], sample['segmentation'], sample['edge'] = image, segmentation, edge
        else:
            # tensor layout: C x H x W; F.pad takes (left, right, top, bottom)
            image, segmentation, edge = sample['image'], sample['segmentation'], sample['edge']
            h_ori, w_ori = image.shape[1:]
            top = (output_size[0] - h_ori) // 2
            bottom = output_size[0] - h_ori - top
            left = (output_size[1] - w_ori) // 2
            right = output_size[1] - w_ori - left
            image = F.pad(image, pad=(left, right, top, bottom), value=self.img_fill_value)
            segmentation = F.pad(segmentation, pad=(left, right, top, bottom), value=self.seg_fill_value)
            edge = F.pad(edge, pad=(left, right, top, bottom), value=self.seg_fill_value)
            sample['image'], sample['segmentation'], sample['edge'] = image, segmentation, edge
        return sample
################################################################
class pad_imagetrain(object):
    '''Pad an image/label pair up to the network input size while keeping
    the original content (an alternative to resizing when the source is
    smaller than the target). The image is zero-padded; the label is
    padded with 255 (ignore index).
    '''
    def __init__(self, target_size=(512, 512)):
        # target_size: (rows, cols) the sample is padded up to
        self.target_size = target_size
    def __call__(self, sample):
        img = np.array(sample['image'])
        mask = np.array(sample['label'])
        # Clamp to 0 so inputs already at/above the target size pass through
        # unchanged instead of crashing np.pad with negative widths.
        rows_missing = max(self.target_size[0] - img.shape[0], 0)
        cols_missing = max(self.target_size[1] - img.shape[1], 0)
        # Split the padding between the two sides; the bottom/right side
        # takes the extra pixel when the difference is odd (the old //2 on
        # both sides silently produced an output smaller than target_size).
        top, bottom = rows_missing // 2, rows_missing - rows_missing // 2
        left, right = cols_missing // 2, cols_missing - cols_missing // 2
        padded_img = np.pad(img, ((top, bottom), (left, right), (0, 0)),
                            'constant', constant_values=(0, 0))
        padded_label = np.pad(mask, ((top, bottom), (left, right)),
                              'constant', constant_values=(255, 255))
        return {'image': padded_img,
                'label': padded_label}