26 lines
719 B
Python
26 lines
719 B
Python
|
import torch
|
||
|
from collections import Counter
|
||
|
|
||
|
|
||
|
def check_labels(data_loader):
|
||
|
all_labels = []
|
||
|
for batch in data_loader:
|
||
|
images, masks = batch
|
||
|
all_labels.extend(masks.view(-1).tolist())
|
||
|
|
||
|
label_counts = Counter(all_labels)
|
||
|
return label_counts
|
||
|
|
||
|
|
||
|
# 定义你的数据加载器
|
||
|
train_loader = ... # 替换为你的训练数据加载器
|
||
|
val_loader = ... # 替换为你的验证数据加载器
|
||
|
|
||
|
# 检查训练数据集中的标签
|
||
|
train_label_counts = check_labels(train_loader)
|
||
|
print("Training dataset label distribution:", train_label_counts)
|
||
|
|
||
|
# 检查验证数据集中的标签
|
||
|
val_label_counts = check_labels(val_loader)
|
||
|
print("Validation dataset label distribution:", val_label_counts)
|