ai-station-code/wudingpv/taihuyuan_roof/dataloaders/datasets/def_taihuyuan.py

198 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : def_selfv5
@Author : qiqq
@create_time : 2023/3/22 22:16
Augmented variant built on top of the earlier "my" loader (原文: 在my的基础上aug后的)
"""
from __future__ import print_function, division
import os
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from taihuyuan_roof.dataloaders.mypath import Path #注意一下这个path
from torchvision import transforms
from taihuyuan_roof.dataloaders import custom_transforms as tr
import numpy as np
import random
import torch
seed = 7
# NOTE(review): importing this module reseeds the global RNGs below, which
# affects any other code in the same process that draws from them.
# Order: Python `random`, NumPy global RNG, torch default generator,
# torch generator for the current CUDA device.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
del _seed_fn
class datasets_rooftaihuyuan(Dataset):
    """Segmentation dataset of (RGB image, label mask) pairs for the Taihuyuan roof data.

    Sample ids are read from ``<splits_dir>/<split>.txt``, one id per line.
    Each id ``x`` maps to ``<base_dir>/images1/x.png`` and ``<base_dir>/labels1/x.png``.
    ``__getitem__`` returns the ``{'image': ..., 'label': ...}`` sample after it
    has passed through the pipeline selected from ``self.split``.
    """

    # Directory holding the ``<split>.txt`` id lists when ``splits_dir`` is not given.
    DEFAULT_SPLITS_DIR = "/home/qiqq/q3dl/datalinan/taihuyuan_roof/"

    def __init__(self,
                 args,
                 base_dir=None,
                 split='train',
                 isAug=False,
                 splits_dir=None,
                 ):
        """
        :param args: namespace providing ``resize``, ``crop_size`` (a single int,
            e.g. 256) and ``flip_prob`` (a probability in [0, 1]).
        :param base_dir: dataset root containing ``images1/`` and ``labels1/``.
            Defaults to ``Path.db_root_dir("thy_roof")``, which is resolved at
            construction time.
        :param split: a split name, or an iterable of split names. Each name needs a
            ``<name>.txt`` file in ``splits_dir``. An iterable is copied and sorted,
            so the caller's object is not modified.
        :param isAug: if true, the ``"trainbat"`` split uses the augmenting pipeline.
        :param splits_dir: directory containing the split files. Defaults to
            ``DEFAULT_SPLITS_DIR``.
        :raises FileNotFoundError: if a split file, or an image or label it lists,
            is missing.
        """
        super(datasets_rooftaihuyuan, self).__init__()
        if base_dir is None:
            # Resolved here rather than in the signature, so importing this module
            # does not require the dataset path to be configured.
            base_dir = Path.db_root_dir("thy_roof")
        if splits_dir is None:
            splits_dir = self.DEFAULT_SPLITS_DIR
        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single number, e.g. 256
        self.flip_prob = args.flip_prob  # in [0, 1]
        self.isAug = isAug
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'images1')
        self._cat_dir = os.path.join(self._base_dir, 'labels1')
        if isinstance(split, str):
            self.split = [split]
        else:
            # sorted() returns a new list, so the caller's list is left untouched.
            self.split = sorted(split)

        self.im_ids = []
        self.images = []
        self.categories = []
        for splt in self.split:
            with open(os.path.join(splits_dir, splt + '.txt'), "r", encoding="utf-8") as f:
                lines = f.read().splitlines()
            for line in lines:
                if not line:
                    # Skip blank lines instead of looking up "images1/.png".
                    continue
                # NOTE: images and labels are assumed to be .png files. Some datasets
                # mix .jpg and .png; adjust the extension if needed.
                _image = os.path.join(self._image_dir, line + ".png")
                _cat = os.path.join(self._cat_dir, line + ".png")
                # Raise rather than assert, so the check survives ``python -O``.
                if not os.path.isfile(_image):
                    raise FileNotFoundError("image not found: {}".format(_image))
                if not os.path.isfile(_cat):
                    raise FileNotFoundError("label not found: {}".format(_cat))
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
        assert (len(self.images) == len(self.categories))
        # Display stats. A list argument is shown in sorted order.
        shown = split if isinstance(split, str) else self.split
        print('Number of images in {}: {:d}'.format(shown, len(self.images)))

    def __len__(self):
        """Return the number of (image, label) pairs across all requested splits."""
        return len(self.images)

    def __getitem__(self, index):
        """Load pair ``index`` and run it through this dataset's transform pipeline.

        :raises ValueError: if no entry of ``self.split`` is ``"trainbat"``,
            ``"valbat"`` or ``"test1"``.
        """
        transform = self._select_transform()
        _img, _target = self._make_img_gt_point_pair(index)
        return transform({'image': _img, 'label': _target})

    def _select_transform(self):
        """Return the pipeline for the first recognised entry of ``self.split``.

        The choice is made per dataset, not per item. A dataset built from several
        splits applies the same pipeline to every sample.
        """
        for split in self.split:
            if split == "trainbat":
                # Augmentation is used for training only; validation does not need it.
                return self.AugTrain if self.isAug else self.transform_tr
            if split in ('valbat', 'test1'):
                return self.transform_val
        raise ValueError(
            "no transform pipeline for split(s) {}; expected one of "
            "'trainbat', 'valbat', 'test1'".format(self.split))

    def _make_img_gt_point_pair(self, index):
        """Open the RGB image and the label image for ``index``."""
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])
        return _img, _target

    def _resize_and_normalize(self, sample):
        """Resize to ``args.resize``, apply ``tr.Normalize_simple`` and convert to tensors."""
        composed_transforms = transforms.Compose([
            # Resize first; the original images are too large to feed in directly.
            tr.Resize(self.args.resize),
            # Reference statistics, currently unused:
            # WHDLD:  tr.Normalize(mean=(0.231, 0.217, 0.22), std=(0.104, 0.086, 0.085))
            # other:  tr.Normalize(mean=(0.052, 0.05, 0.05), std=(0.027, 0.026, 0.026))
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_tr(self, sample):
        """Non-augmenting training pipeline: resize, normalize, to tensor."""
        return self._resize_and_normalize(sample)

    def AugTrain(self, sample):
        '''
        Augmenting training pipeline. It uses Zhenchao Jin's transforms, which
        operate on numpy arrays.

        Random flip, photometric (colour) distortion and random rotation are
        applied, followed by a resize. The random crop is currently disabled.
        :param sample: dict with 'image' and 'label' entries
        :return: the transformed sample
        '''
        composed_transforms = transforms.Compose([
            tr.Tonumpy(),
            # tr.RandomCrop(crop_size=self.crop_size),
            tr.RandomFlip(),
            tr.PhotoMetricDistortion(),
            tr.RandomRotation(),
            tr.Resize2(self.resize, scale_range=None, keep_ratio=False),
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation/test pipeline: resize, normalize, to tensor."""
        return self._resize_and_normalize(sample)

    def __str__(self):
        return 'linan(split=' + str(self.split) + ')'
#
if __name__ == '__main__':
    # Visual sanity check: build the train and val datasets, then show every
    # training image followed by its colour-decoded label mask.
    from taihuyuan_roof.dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.resize = (512, 512)
    # crop and flip are not used for validation; these are set only so that
    # construction does not fail.
    args.crop_size = 480
    args.flip_prob = 0.5
    voc_train = datasets_rooftaihuyuan(args, split='trainbat', isAug=True)
    voc_val = datasets_rooftaihuyuan(args, split='valbat', isAug=False)
    print(len(voc_train))
    print(len(voc_val))
    dataloader = DataLoader(voc_train, batch_size=1, shuffle=False, num_workers=0)
    for sample in dataloader:
        # Convert the whole batch once instead of once per element.
        imgs = sample['image'].numpy()
        gts = sample['label'].numpy()
        for img, gt in zip(imgs, gts):
            segmap = decode_segmap(gt.astype(np.uint8), dataset='pascal_customer')
            # CHW -> HWC, then undo the pixel scaling. This presumes Normalize_simple
            # maps pixels to [0, 1] — TODO confirm.
            img_tmp = (np.transpose(img, axes=[1, 2, 0]) * 255.0).astype(np.uint8)
            plt.figure()
            plt.title('img')
            plt.imshow(img_tmp)
            plt.show()
            plt.figure()
            plt.title('label')
            plt.imshow(segmap)
            plt.show()