Tan_pytorch_segmentation/pytorch_segmentation/PV_FuseDisNet/train.py

148 lines
5.3 KiB
Python
Raw Normal View History

2025-05-19 20:48:24 +08:00
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tools.cfg import py2cfg
import os
import torch
from torch import nn
import numpy as np
import argparse
from pathlib import Path
from tools.metric import Evaluator
from pytorch_lightning.loggers import CSVLogger
import random
# Seed every random number generator so that experiment results are reproducible.
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every generator. It is also exported as
            ``PYTHONHASHSEED`` (as a string).
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. spawned
    # dataloader workers); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick (possibly nondeterministic)
    # kernels per run, which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False
# Parse the command-line arguments to obtain the path of the config file.
def get_args():
    """Parse ``sys.argv`` and return the resulting namespace.

    The only option is the required ``-c/--config_path``, converted to a
    :class:`pathlib.Path`.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_path",
        type=Path,
        required=True,
        help="Path to the config.",
    )
    return cli.parse_args()
# Custom training module, derived from the PyTorch Lightning module.
class Supervision_Train(pl.LightningModule):
    """Lightning wrapper that trains ``config.net`` with ``config.loss``.

    Network, loss, optimizer, scheduler, dataloaders and class names all come
    from ``config``; this class wires them into Lightning hooks and tracks
    segmentation metrics with one ``Evaluator`` per stage.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.net = config.net
        self.loss = config.loss
        self.metrics_train = Evaluator(num_class=config.num_classes)
        self.metrics_val = Evaluator(num_class=config.num_classes)

    def forward(self, x):
        """Run the wrapped network and return its raw output."""
        seg_pre = self.net(x)
        return seg_pre

    @staticmethod
    def _update_metrics(evaluator, logits, mask):
        """Feed each (ground truth, predicted label) pair of a batch to *evaluator*.

        The predicted label is the argmax over dim 1. Softmax is strictly
        monotonic, so applying it first would not change the argmax. Both
        tensors are moved to the host once per batch rather than once per
        sample.
        """
        pred = logits.argmax(dim=1).cpu().numpy()
        gt = mask.cpu().numpy()
        for gt_sample, pred_sample in zip(gt, pred):
            evaluator.add_batch(gt_sample, pred_sample)

    def _log_epoch_metrics(self, evaluator, stage):
        """Print, log and reset the metrics accumulated in *evaluator*.

        Logged keys are ``{stage}_mIoU``, ``{stage}_F1`` and ``{stage}_OA``.
        """
        iou_per_class = evaluator.Intersection_over_Union()
        # NOTE(review): the last class is excluded from the mIoU/F1 means,
        # presumably an ignore/clutter class for this dataset -- confirm
        # against config.classes.
        mIoU = np.nanmean(iou_per_class[:-1])
        F1 = np.nanmean(evaluator.F1()[:-1])
        OA = np.nanmean(evaluator.OA())
        print(f'{stage}:', {'mIoU': mIoU, 'F1': F1, 'OA': OA})
        print(dict(zip(self.config.classes, iou_per_class)))
        evaluator.reset()
        self.log_dict(
            {f'{stage}_mIoU': mIoU, f'{stage}_F1': F1, f'{stage}_OA': OA},
            prog_bar=True,
        )

    def training_step(self, batch, batch_idx):
        """Compute the training loss and accumulate training metrics."""
        img, mask = batch['img'], batch['gt_semantic_seg']
        prediction = self.forward(img)
        loss = self.loss(prediction, mask)
        # With auxiliary supervision the loss receives the full network output,
        # while metrics use only element 0 of it.
        logits = prediction[0] if self.config.use_aux_loss else prediction
        self._update_metrics(self.metrics_train, logits, mask)
        return {"loss": loss}

    # Compute, print and log the evaluation metrics at the end of each training epoch.
    def on_train_epoch_end(self):
        self._log_epoch_metrics(self.metrics_train, 'train')

    # Validation step: process each validation batch, compute the loss and update the metrics.
    def validation_step(self, batch, batch_idx):
        img, mask = batch['img'], batch['gt_semantic_seg']
        prediction = self.forward(img)
        # NOTE(review): unlike training_step, the output is used directly even
        # when use_aux_loss is set -- assumes the net returns a single tensor in
        # eval mode; confirm.
        self._update_metrics(self.metrics_val, prediction, mask)
        loss_val = self.loss(prediction, mask)
        return {"loss_val": loss_val}

    # Compute, print and log the evaluation metrics at the end of each validation epoch.
    def on_validation_epoch_end(self):
        self._log_epoch_metrics(self.metrics_val, 'val')

    # Configure the optimizer and the learning-rate scheduler.
    def configure_optimizers(self):
        """Return the optimizer and LR scheduler already built by the config."""
        optimizer = self.config.optimizer
        lr_scheduler = self.config.lr_scheduler
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return self.config.train_loader

    def val_dataloader(self):
        return self.config.val_loader
# Main function: runs the whole training pipeline.
def main():
    """Train a model from the config file given on the command line.

    RNGs are seeded *before* the config is loaded. The config exposes
    ready-built objects (``config.net``, ``config.optimizer``,
    ``config.train_loader``, ...), so they are constructed while it is loaded.
    Seeding afterwards would leave weight initialisation unseeded.
    """
    args = get_args()
    seed_everything(42)
    config = py2cfg(args.config_path)
    checkpoint_callback = ModelCheckpoint(
        save_top_k=config.save_top_k,
        monitor=config.monitor,
        save_last=config.save_last,
        mode=config.monitor_mode,
        dirpath=config.weights_path,
        filename=config.weights_name,
    )
    logger = CSVLogger('lightning_logs', name=config.log_name)
    if config.pretrained_ckpt_path:
        # Start from pretrained weights; ``config`` must be passed explicitly
        # because ``Supervision_Train.__init__`` requires it.
        model = Supervision_Train.load_from_checkpoint(
            config.pretrained_ckpt_path, config=config
        )
    else:
        model = Supervision_Train(config)
    trainer = pl.Trainer(
        devices=config.gpus,
        max_epochs=config.max_epoch,
        accelerator='auto',
        check_val_every_n_epoch=config.check_val_every_n_epoch,
        callbacks=[checkpoint_callback],
        strategy='auto',
        logger=logger,
    )
    # ckpt_path resumes full training state (optimizer, epoch, ...) when set.
    trainer.fit(model=model, ckpt_path=config.resume_ckpt_path)


if __name__ == "__main__":
    main()