198 lines
6.5 KiB
Python
198 lines
6.5 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
@project:
|
|||
|
@File : def_selfv5
|
|||
|
@Author : qiqq
|
|||
|
@create_time : 2023/3/22 22:16
|
|||
|
Built on the `my` version of this loader, with data augmentation added.
|
|||
|
"""
|
|||
|
from __future__ import print_function, division
|
|||
|
import os
|
|||
|
from PIL import Image
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from torch.utils.data import Dataset
|
|||
|
from taihuyuan_roof.dataloaders.mypath import Path #注意一下这个path
|
|||
|
from torchvision import transforms
|
|||
|
from taihuyuan_roof.dataloaders import custom_transforms as tr
|
|||
|
|
|||
|
import numpy as np
|
|||
|
import random
|
|||
|
import torch
|
|||
|
|
|||
|
# Fix the global RNGs used by this module to a single seed so that runs are
# repeatable: Python's `random`, NumPy, the PyTorch CPU generator and the
# PyTorch generator of the current GPU (seeded in that order).
seed = 7
for _seed_fn in (random.seed,              # Python stdlib RNG
                 np.random.seed,           # NumPy global RNG
                 torch.manual_seed,        # PyTorch CPU RNG
                 torch.cuda.manual_seed):  # PyTorch RNG of the current GPU
    _seed_fn(seed)
del _seed_fn
|
|||
|
|
|||
|
class datasets_rooftaihuyuan(Dataset):
    """Image/label dataset for the Taihuyuan roof data (``"thy_roof"``).

    Sample names are read from ``<splits_dir>/<split>.txt`` (one name per
    line; blank lines are ignored).  For each name the image is
    ``<base_dir>/images1/<name>.png`` and the label is
    ``<base_dir>/labels1/<name>.png``.  ``__getitem__`` returns the dict
    ``{'image': ..., 'label': ...}`` after the split's transform pipeline.
    """

    # Directory holding the ``<split>.txt`` name lists when ``splits_dir``
    # is not passed to the constructor.
    DEFAULT_SPLITS_DIR = "/home/qiqq/q3dl/datalinan/taihuyuan_roof/"
    # File extensions of images and labels.  Some image sets mix .jpg and
    # .png files; override these (e.g. in a subclass) if that applies.
    IMAGE_EXT = ".png"
    LABEL_EXT = ".png"

    def __init__(self,
                 args,
                 base_dir=None,
                 split='train',
                 isAug=False,
                 splits_dir=None,
                 ):
        """Index the image/label files of the requested split(s).

        :param args: namespace providing ``resize``, ``crop_size`` (a single
            int, e.g. 256) and ``flip_prob`` (a probability in [0, 1]).
        :param base_dir: dataset root containing ``images1/`` and
            ``labels1/``.  ``None`` (the default) resolves to
            ``Path.db_root_dir("thy_roof")``.
        :param split: a split name or an iterable of split names.  Each name
            ``s`` is read from ``<splits_dir>/<s>.txt``.  Multiple splits are
            processed in sorted order.  The caller's list is not modified.
        :param isAug: if true, the ``'trainbat'`` split uses ``AugTrain``
            instead of ``transform_tr``.
        :param splits_dir: directory of the split lists.  ``None`` (the
            default) uses ``DEFAULT_SPLITS_DIR``.
        :raises FileNotFoundError: if a split list, an image or a label file
            is missing.
        """
        super(datasets_rooftaihuyuan, self).__init__()

        if base_dir is None:
            # Resolved at construction time instead of at import time.
            base_dir = Path.db_root_dir("thy_roof")
        if splits_dir is None:
            splits_dir = self.DEFAULT_SPLITS_DIR

        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single int, e.g. 256
        self.flip_prob = args.flip_prob  # probability in [0, 1]

        self.isAug = isAug
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'images1')
        self._cat_dir = os.path.join(self._base_dir, 'labels1')

        if isinstance(split, str):
            self.split = [split]
        else:
            # sorted() returns a new list, so the caller's object is untouched
            # (and tuples / other iterables work too).
            self.split = sorted(split)

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            list_path = os.path.join(splits_dir, splt + '.txt')
            with open(list_path, "r", encoding="utf-8") as f:
                names = [ln.strip() for ln in f.read().splitlines()]

            for name in names:
                if not name:
                    # Skip blank lines (e.g. a trailing empty line in the list).
                    continue
                _image = os.path.join(self._image_dir, name + self.IMAGE_EXT)
                _cat = os.path.join(self._cat_dir, name + self.LABEL_EXT)
                # Explicit checks rather than `assert`, which `python -O` strips.
                if not os.path.isfile(_image):
                    raise FileNotFoundError(
                        "Image file not found: {}".format(_image))
                if not os.path.isfile(_cat):
                    raise FileNotFoundError(
                        "Label file not found: {}".format(_cat))

                self.im_ids.append(name)
                self.images.append(_image)
                self.categories.append(_cat)

        # Display stats.
        shown = split if isinstance(split, str) else self.split
        print('Number of images in {}: {:d}'.format(shown, len(self.images)))

    def __len__(self):
        """Return the number of indexed image/label pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Load pair ``index`` and apply the transform of the first known split.

        The transform depends on the split name.
        * ``'trainbat'`` uses ``AugTrain`` when ``isAug`` is set, otherwise
          ``transform_tr``.
        * ``'valbat'`` and ``'test1'`` use ``transform_val``.

        Unknown split names are skipped.  The first recognised entry of
        ``self.split`` decides the transform for every item.

        :raises ValueError: if no entry of ``self.split`` has a transform,
            rather than silently returning ``None``.
        """
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        for split in self.split:
            if split == 'trainbat':
                # Augment only for training; validation needs no augmentation.
                if self.isAug:
                    return self.AugTrain(sample)
                return self.transform_tr(sample)
            if split in ('valbat', 'test1'):
                return self.transform_val(sample)

        raise ValueError(
            "No transform defined for split(s) {}; expected one of "
            "'trainbat', 'valbat', 'test1'".format(self.split))

    def _make_img_gt_point_pair(self, index):
        """Open image ``index`` with PIL (converted to RGB) and its label.

        The label image is returned exactly as opened, with no mode conversion.
        """
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def transform_tr(self, sample):
        """Non-augmented training pipeline: resize, normalize, to tensor."""
        composed_transforms = transforms.Compose([
            # Resize first; the original images are too large to feed in as-is.
            tr.Resize(self.args.resize),
            # Alternative normalization with WHDLD dataset statistics:
            # tr.Normalize(mean=(0.231, 0.217, 0.22), std=(0.104, 0.086, 0.085)),
            tr.Normalize_simple(),
            tr.ToTensor()])

        return composed_transforms(sample)

    def AugTrain(self, sample):
        '''Augmented training pipeline operating on numpy arrays.

        This version uses Zhenchao Jin's numpy-based transforms.  The steps
        are random flip, photometric distortion and random rotation, then a
        resize to ``self.resize`` (``keep_ratio=False``), normalization and
        tensor conversion.  Random cropping is currently disabled.

        Author's note: random crop, random flip, random colour augmentation
        and random rotation are all the augmentation this needs.

        Note: ``RandomFlip()`` is constructed with its own defaults, so
        ``self.flip_prob`` and ``self.crop_size`` are not used here.

        :param sample: dict with ``'image'`` and ``'label'`` entries.
        :return: the transformed sample.
        '''
        composed_transforms = transforms.Compose([
            tr.Tonumpy(),
            # tr.RandomCrop(crop_size=self.crop_size),
            tr.RandomFlip(),
            tr.PhotoMetricDistortion(),
            tr.RandomRotation(),
            tr.Resize2(self.resize, scale_range=None, keep_ratio=False),
            tr.Normalize_simple(),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation/test pipeline: resize, normalize, to tensor (no augmentation)."""
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),
            # Alternative normalization with dataset-specific statistics:
            # tr.Normalize(mean=(0.052, 0.05, 0.05), std=(0.027, 0.026, 0.026)),
            tr.Normalize_simple(),
            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):
        """Return a short description that includes the split list."""
        return 'linan(split=' + str(self.split) + ')'
|
|||
|
#
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Visual sanity check: run the augmented training split through a
    # DataLoader and show each image next to its colour-decoded label.
    from taihuyuan_roof.dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--max-batches', type=int, default=None,
                        help='stop after this many batches '
                             '(default: iterate the whole split)')
    args = parser.parse_args()
    args.resize = (512, 512)
    # Crop and flip are not used for validation; set only so construction does not fail.
    args.crop_size = 480
    args.flip_prob = 0.5

    voc_train = datasets_rooftaihuyuan(args, split='trainbat', isAug=True)
    voc_val = datasets_rooftaihuyuan(args, split='valbat', isAug=False)
    print(len(voc_train))
    print(len(voc_val))

    dataloader = DataLoader(voc_train, batch_size=1, shuffle=False, num_workers=0)

    for ii, sample in enumerate(dataloader):
        if args.max_batches is not None and ii >= args.max_batches:
            break
        # Convert the whole batch once instead of once per item.
        img = sample['image'].numpy()
        gt = sample['label'].numpy()
        for jj in range(img.shape[0]):
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal_customer')

            # CHW -> HWC for matplotlib.  Clip before the uint8 cast so values
            # outside [0, 255] saturate instead of wrapping around.
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp = np.clip(img_tmp * 255.0, 0, 255).astype(np.uint8)

            plt.figure()
            plt.title('img')
            plt.imshow(img_tmp)
            plt.show()

            plt.figure()
            plt.title('label')
            plt.imshow(segmap)
            plt.show()
|
|||
|
|