import argparse
import os
import random
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch import nn

from tools.cfg import py2cfg
from tools.metric import Evaluator


# Seed every random number generator so experiment results are reproducible.
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one, so multi-device runs
    # start from the same state.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick a different convolution
    # algorithm from run to run, which defeats reproducibility; keep it off.
    torch.backends.cudnn.benchmark = False


# Parse the command-line arguments to obtain the config file path.
def get_args():
    """Parse command-line arguments.

    Returns:
        The parsed namespace; ``config_path`` is a required ``Path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=Path,
        help="Path to the config.",
        required=True,
    )
    return parser.parse_args()


# Custom training class built on the PyTorch Lightning module.
class Supervision_Train(pl.LightningModule):
    """Lightning module wrapping the network, loss and metrics from ``config``.

    The config object supplies ``net``, ``loss``, ``num_classes``, ``classes``,
    ``use_aux_loss``, ``optimizer``, ``lr_scheduler``, ``train_loader`` and
    ``val_loader``; all of them are read as attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.net = config.net

        self.loss = config.loss

        # Separate accumulators so train and validation statistics never mix.
        self.metrics_train = Evaluator(num_class=config.num_classes)
        self.metrics_val = Evaluator(num_class=config.num_classes)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.net(x)

    @staticmethod
    def _accumulate(evaluator, mask, logits):
        """Add one batch of predictions to ``evaluator``, sample by sample.

        Args:
            evaluator: Metric accumulator exposing ``add_batch(gt, pred)``.
            mask: Ground-truth label tensor, batch on dim 0.
            logits: Network output with the class scores on dim 1.
        """
        # Softmax is monotonic along the class axis, so taking argmax on the
        # raw scores selects the same class without the extra pass.
        pred = logits.argmax(dim=1)
        # One device-to-host copy per tensor instead of one per sample.
        mask_np = mask.cpu().numpy()
        pred_np = pred.cpu().numpy()
        for gt_sample, pred_sample in zip(mask_np, pred_np):
            evaluator.add_batch(gt_sample, pred_sample)

    def _report_epoch(self, evaluator, stage):
        """Print and log the epoch metrics for ``stage``, then reset ``evaluator``.

        Args:
            evaluator: Metric accumulator filled during the epoch.
            stage: Prefix used for the printed line and the logged keys
                (``'train'`` or ``'val'``).
        """
        iou_per_class = evaluator.Intersection_over_Union()
        # The mean IoU / F1 leave out the last class.
        # NOTE(review): presumably a background/clutter class - confirm
        # against the dataset config.
        mIoU = np.nanmean(iou_per_class[:-1])
        F1 = np.nanmean(evaluator.F1()[:-1])
        OA = np.nanmean(evaluator.OA())

        eval_value = {'mIoU': mIoU, 'F1': F1, 'OA': OA}
        print(f'{stage}:', eval_value)

        iou_value = dict(zip(self.config.classes, iou_per_class))
        print(iou_value)

        evaluator.reset()
        log_dict = {
            f'{stage}_mIoU': mIoU,
            f'{stage}_F1': F1,
            f'{stage}_OA': OA,
        }
        self.log_dict(log_dict, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and update train metrics.

        Returns:
            A dict with the loss under the key ``"loss"``.
        """
        img, mask = batch['img'], batch['gt_semantic_seg']

        prediction = self.net(img)
        loss = self.loss(prediction, mask)

        # With an auxiliary loss the network output is indexable and its first
        # element is used for the metrics.
        logits = prediction[0] if self.config.use_aux_loss else prediction
        self._accumulate(self.metrics_train, mask, logits)

        return {"loss": loss}

    # Compute and print the evaluation metrics at the end of each training epoch.
    def on_train_epoch_end(self):
        """Report the accumulated training metrics and reset them."""
        self._report_epoch(self.metrics_train, 'train')

    # Validation step: process one validation batch, compute the loss and
    # update the evaluation metrics.
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and update val metrics.

        Returns:
            A dict with the loss under the key ``"loss_val"``.
        """
        img, mask = batch['img'], batch['gt_semantic_seg']
        prediction = self.forward(img)
        self._accumulate(self.metrics_val, mask, prediction)

        loss_val = self.loss(prediction, mask)
        return {"loss_val": loss_val}

    # Compute and print the evaluation metrics at the end of each validation epoch.
    def on_validation_epoch_end(self):
        """Report the accumulated validation metrics and reset them."""
        self._report_epoch(self.metrics_val, 'val')

    # Configure the optimizer and the learning-rate scheduler.
    def configure_optimizers(self):
        """Return the optimizer and scheduler taken from the config."""
        optimizer = self.config.optimizer
        lr_scheduler = self.config.lr_scheduler

        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        """Return the training dataloader taken from the config."""
        return self.config.train_loader

    def val_dataloader(self):
        """Return the validation dataloader taken from the config."""
        return self.config.val_loader


# Main function that runs the whole training pipeline.
def main():
    """Load the config, build callbacks/logger/model and run ``trainer.fit``."""
    args = get_args()
    config = py2cfg(args.config_path)
    seed_everything(42)

    checkpoint_callback = ModelCheckpoint(
        save_top_k=config.save_top_k,
        monitor=config.monitor,
        save_last=config.save_last,
        mode=config.monitor_mode,
        dirpath=config.weights_path,
        filename=config.weights_name,
    )
    logger = CSVLogger('lightning_logs', name=config.log_name)

    # Build the module once: from the pretrained checkpoint when one is
    # configured, from scratch otherwise.
    if config.pretrained_ckpt_path:
        model = Supervision_Train.load_from_checkpoint(
            config.pretrained_ckpt_path,
            config=config,
        )
    else:
        model = Supervision_Train(config)

    trainer = pl.Trainer(
        devices=config.gpus,
        max_epochs=config.max_epoch,
        accelerator='auto',
        check_val_every_n_epoch=config.check_val_every_n_epoch,
        callbacks=[checkpoint_callback],
        strategy='auto',
        logger=logger,
    )
    trainer.fit(model=model, ckpt_path=config.resume_ckpt_path)


if __name__ == "__main__":
    main()