Tan_pytorch_segmentation/pytorch_segmentation/PV_FuseDisNet/确定标签值.py

26 lines
719 B
Python
Raw Permalink Normal View History

2025-05-19 20:48:24 +08:00
import torch
from collections import Counter
def check_labels(data_loader):
all_labels = []
for batch in data_loader:
images, masks = batch
all_labels.extend(masks.view(-1).tolist())
label_counts = Counter(all_labels)
return label_counts
# 定义你的数据加载器
train_loader = ... # 替换为你的训练数据加载器
val_loader = ... # 替换为你的验证数据加载器
# 检查训练数据集中的标签
train_label_counts = check_labels(train_loader)
print("Training dataset label distribution:", train_label_counts)
# 检查验证数据集中的标签
val_label_counts = check_labels(val_loader)
print("Validation dataset label distribution:", val_label_counts)