#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : def_selfv5
@Author : qiqq
@create_time : 2023/3/22 22:16

Augmented version built on top of "my".
"""
from __future__ import print_function, division

import os

from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from taihuyuan_pv.dataloaders.mypath import Path  # note: double-check this path
from torchvision import transforms
from taihuyuan_pv.dataloaders import custom_transforms as tr

import numpy as np
import random
import torch

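# Fix all RNG seeds (python, numpy, torch CPU/GPU) so that shuffling and the random
# augmentations below are reproducible across runs.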
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # set the random seed for the CPU
torch.cuda.manual_seed(seed)  # set the random seed for the current GPU


class datasets_pvtaihuyuan(Dataset):

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir("thy_pv"),
                 split='train',
                 isAug=False
                 ):
        super(datasets_pvtaihuyuan, self).__init__()

        self.args = args
        self.resize = args.resize
        self.crop_size = args.crop_size  # a single integer, e.g. 256
        self.flip_prob = args.flip_prob  # a probability in the range 0-1

        self.isAug = isAug
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'images')
        self._cat_dir = os.path.join(self._base_dir, 'labels')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        _splits_dir = "/home/qiqq/q3dl/datalinan/taihuyuan_pv/"

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + ".png")  # note: some images may be .jpg and others .png; adjust as needed
                _cat = os.path.join(self._cat_dir, line + ".png")  # mind the file format
                # print(_image)
                # if line=="tuankou_11-6_2_aug5":
                # print("ddd")
                assert os.path.isfile(_image)
                # print("ok",_image)
                assert os.path.isfile(_cat)
                # print("ok",_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert len(self.images) == len(self.categories)

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):

        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        for split in self.split:
            if split == "trainbat":
                if self.isAug:  # use data augmentation during training; it is not needed for validation
                    return self.AugTrain(sample)
                else:
                    return self.transform_tr(sample)
            elif split == 'valbat':
                return self.transform_val(sample)
            elif split == 'test1':
                return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),  # resize first, otherwise the original images are too large to fit
            # tr.Normalize(mean=(0.231, 0.217, 0.22), std=(0.104, 0.086, 0.085)),  # for the WHDLD dataset
            tr.Normalize_simple(),
            tr.ToTensor()])

        return composed_transforms(sample)

    def AugTrain(self, sample):
        '''
        This version of AugTrain follows Zhenchao Jin's implementation and operates on numpy arrays.
        :param sample:
        :return:

        Random crop, random flip, random color augmentation and random rotation: these four are
        enough; don't waste any more time on this.
        '''
        composed_transforms = transforms.Compose([
            tr.Tonumpy(),
            # tr.RandomCrop(crop_size=self.crop_size),
            tr.RandomFlip(),
            tr.PhotoMetricDistortion(),
            tr.RandomRotation(),

            tr.Resize2(self.resize, scale_range=None, keep_ratio=False),
            tr.Normalize_simple(),
            tr.ToTensor()])
        return composed_transforms(sample)

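    # Note: every tr.* transform above comes from taihuyuan_pv.dataloaders.custom_transforms and is
    # applied to the whole {'image', 'label'} sample dict, so the geometric augmentations are
    # expected to transform the image and its mask together.
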
    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.Resize(self.args.resize),
            # tr.Normalize(mean=(0.052, 0.05, 0.05), std=(0.027, 0.026, 0.026)),
            tr.Normalize_simple(),

            tr.ToTensor()])
        return composed_transforms(sample)

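    # transform_val (used for the 'valbat' and 'test1' splits) only resizes and normalizes;
    # none of the random augmentations from AugTrain are applied at evaluation time.
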
    def __str__(self):
        return 'linan(split=' + str(self.split) + ')'

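
# Minimal usage sketch (assumes an `args` namespace providing resize, crop_size and flip_prob,
# mirroring the __main__ block below; batch size and worker count are illustrative):
#
#     train_set = datasets_pvtaihuyuan(args, split='trainbat', isAug=True)
#     loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
#     for sample in loader:
#         images, labels = sample['image'], sample['label']
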
if __name__ == '__main__':
    a = np.random.randn(3, 3)
    # print(a)

    from taihuyuan_pv.dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    # import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.resize = (512, 512)
    args.crop_size = 480  # crop and flip are not used during validation; set only so the code does not error
    args.flip_prob = 0.5

    voc_train = datasets_pvtaihuyuan(args, split='trainbat', isAug=False)
    voc_val = datasets_pvtaihuyuan(args, split='valbat', isAug=False)
    print(len(voc_train))
    print(len(voc_val))
    # for i in voc_train:
    #     print(type(i))

    dataloader = DataLoader(voc_train, batch_size=1, shuffle=False, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            # gt = gt - 1
            # print(gt)
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal_customer')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('img')
            plt.imshow(img_tmp)
            plt.show()
            plt.figure()
            plt.title('label')
            plt.imshow(segmap)
            plt.show()

        # if ii == 1:
        #     break