100 lines
4.1 KiB
Python
100 lines
4.1 KiB
Python
|
import os
|
|||
|
import random
|
|||
|
|
|||
|
import numpy as np
|
|||
|
from PIL import Image
|
|||
|
from tqdm import tqdm
|
|||
|
|
|||
|
#-------------------------------------------------------#
#   Dataset split configuration.
#
#   trainval_percent: fraction of all labelled samples placed in
#       trainval.txt; the remainder goes to test.txt. Lower it
#       below 1 to carve out a separate test set.
#   train_percent: fraction of the trainval samples placed in
#       train.txt; the rest go to val.txt (0.9 -> train:val = 9:1).
#
#   With the default trainval_percent = 1, test.txt is left empty:
#   this repository uses the validation set as the test set rather
#   than splitting off a separate test set.
#-------------------------------------------------------#
trainval_percent: float = 1
train_percent: float = 0.9
#-------------------------------------------------------#
#   Folder containing the VOC dataset. This is a relative path,
#   so it is resolved against the current working directory
#   (by default: the VOCdevkit folder when run from the project
#   root).
#-------------------------------------------------------#
VOCdevkit_path: str = 'VOCdevkit'


def list_label_files(seg_dir):
    """Return the names of all ``.png`` files directly inside *seg_dir*, sorted.

    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent
    order; sorting makes the seeded random split reproducible across machines.
    """
    return sorted(name for name in os.listdir(seg_dir) if name.endswith(".png"))


def split_indices(num, trainval_percent, train_percent):
    """Randomly choose which of ``num`` samples go to trainval and to train.

    ``int(num * trainval_percent)`` indices are drawn for trainval, then
    ``int(len(trainval) * train_percent)`` of those are drawn for train.
    Uses the global ``random`` state, so seed it first for a reproducible split.

    Returns:
        ``(trainval, train)`` as sets of indices into ``range(num)``; sets
        give O(1) membership tests when the split files are written.
    """
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(range(num), tv)
    train = random.sample(trainval, tr)
    return set(trainval), set(train)


def write_split_files(save_dir, names, trainval, train):
    """Write trainval.txt, train.txt, val.txt and test.txt into *save_dir*.

    Each line is a label file name with its ``.png`` suffix removed, written
    in the order of *names*. Index ``i`` is written to trainval.txt (and to
    train.txt or val.txt depending on ``i in train``) when ``i in trainval``;
    otherwise it is written to test.txt.

    Args:
        save_dir: Output directory; created if it does not exist yet.
        names: Label file names, all ending in ``.png``.
        trainval: Indices into *names* that belong to train + val.
        train: Subset of *trainval* that belongs to train.
    """
    os.makedirs(save_dir, exist_ok=True)
    # `with` guarantees all four files are closed even if a write fails.
    with open(os.path.join(save_dir, 'trainval.txt'), 'w') as ftrainval, \
            open(os.path.join(save_dir, 'test.txt'), 'w') as ftest, \
            open(os.path.join(save_dir, 'train.txt'), 'w') as ftrain, \
            open(os.path.join(save_dir, 'val.txt'), 'w') as fval:
        for i, name in enumerate(names):
            line = name[:-len(".png")] + '\n'
            if i in trainval:
                ftrainval.write(line)
                if i in train:
                    ftrain.write(line)
                else:
                    fval.write(line)
            else:
                ftest.write(line)


def count_pixel_values(seg_dir, names):
    """Accumulate a 256-bin histogram of pixel values over all label images.

    Images with more than two dimensions (i.e. not grayscale / 8-bit palette)
    trigger a warning but are still counted, as before.

    Raises:
        ValueError: if a listed label file no longer exists (e.g. a broken
            symlink or a file removed after listing).
    """
    # int64 explicitly: `dtype=int` is int32 on Windows with NumPy < 2, and the
    # in-place `+=` of int64 bincounts would then silently overflow on large
    # datasets.
    classes_nums = np.zeros([256], dtype=np.int64)
    for name in tqdm(names):
        png_file_name = os.path.join(seg_dir, name)
        if not os.path.exists(png_file_name):
            raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))

        # Close the image handle promptly instead of leaking it until GC.
        with Image.open(png_file_name) as img:
            png = np.array(img, np.uint8)
        if len(np.shape(png)) > 2:
            print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))

        classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
    return classes_nums


def report_pixel_values(classes_nums):
    """Print a table of the non-zero pixel-value counts and format diagnostics.

    Warns when labels contain only 0 and 255 (binary masks that should be
    0/1), or only background pixels (value 0).
    """
    print("打印像素点的值与数量。")
    print('-' * 37)
    print("| %15s | %15s |"%("Key", "Value"))
    print('-' * 37)
    for value, count in enumerate(classes_nums):
        if count > 0:
            print("| %15s | %15s |"%(str(value), str(count)))
            print('-' * 37)

    if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
        print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
        print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
    elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
        print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")

    print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
    print("如果格式有误,参考:")
    print("https://github.com/bubbliiiing/segmentation-format-fix")


if __name__ == "__main__":
    # Fixed seed so the train/val/test split is reproducible between runs.
    random.seed(0)
    print("Generate txt in ImageSets.")
    segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
    saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')

    total_seg = list_label_files(segfilepath)
    if not total_seg:
        print("No .png label files found in %s." % segfilepath)

    trainval, train = split_indices(len(total_seg), trainval_percent, train_percent)
    print("train and val size", len(trainval))
    print("train size", len(train))
    write_split_files(saveBasePath, total_seg, trainval, train)
    print("Generate txt in ImageSets done.")

    print("Check datasets format, this may take a while.")
    print("检查数据集格式是否符合要求,这可能需要一段时间。")
    classes_nums = count_pixel_values(segfilepath, total_seg)
    report_pixel_values(classes_nums)