import torch from collections import Counter def check_labels(data_loader): all_labels = [] for batch in data_loader: images, masks = batch all_labels.extend(masks.view(-1).tolist()) label_counts = Counter(all_labels) return label_counts # 定义你的数据加载器 train_loader = ... # 替换为你的训练数据加载器 val_loader = ... # 替换为你的验证数据加载器 # 检查训练数据集中的标签 train_label_counts = check_labels(train_loader) print("Training dataset label distribution:", train_label_counts) # 检查验证数据集中的标签 val_label_counts = check_labels(val_loader) print("Validation dataset label distribution:", val_label_counts)