148 lines
5.3 KiB
Python
148 lines
5.3 KiB
Python
|
import pytorch_lightning as pl
|
||
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||
|
from tools.cfg import py2cfg
|
||
|
import os
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import numpy as np
|
||
|
import argparse
|
||
|
from pathlib import Path
|
||
|
from tools.metric import Evaluator
|
||
|
from pytorch_lightning.loggers import CSVLogger
|
||
|
import random
|
||
|
|
||
|
|
||
|
# Seed every random number generator in use so experiment results are reproducible.
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to ``random``, ``PYTHONHASHSEED``, NumPy
            and PyTorch (CPU and every CUDA device).
    """
    random.seed(seed)
    # Exported for subprocesses; the running interpreter's hash seed was
    # already fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current device (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different convolution algorithm on each
    # run, which would defeat the reproducibility this function exists for.
    torch.backends.cudnn.benchmark = False
|
||
|
|
||
|
# Parse the command-line arguments to obtain the path of the config file.
def get_args():
    """Parse the command line of the training script.

    Returns:
        argparse.Namespace whose ``config_path`` attribute holds the ``Path``
        passed through the mandatory ``-c`` / ``--config_path`` option.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_path",
        type=Path,
        help="Path to the config.",
        required=True,
    )
    return cli.parse_args()
|
||
|
|
||
|
# Custom training class derived from the PyTorch Lightning module.
class Supervision_Train(pl.LightningModule):
    """Lightning module that trains and validates the network held by a config object.

    The config object is read for ``net``, ``loss``, ``num_classes``,
    ``classes``, ``use_aux_loss``, ``optimizer``, ``lr_scheduler``,
    ``train_loader`` and ``val_loader``.
    """

    def __init__(self, config):
        """Store the config and build one metric evaluator per stage."""
        super().__init__()
        self.config = config
        self.net = config.net

        self.loss = config.loss

        self.metrics_train = Evaluator(num_class=config.num_classes)
        self.metrics_val = Evaluator(num_class=config.num_classes)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        seg_pre = self.net(x)
        return seg_pre

    @staticmethod
    def _update_evaluator(evaluator, logits, mask):
        """Feed the arg-max class map of every sample in the batch to ``evaluator``."""
        # Metric bookkeeping never needs gradients; avoid recording autograd nodes.
        with torch.no_grad():
            pre_mask = nn.Softmax(dim=1)(logits).argmax(dim=1)
        for i in range(mask.shape[0]):
            evaluator.add_batch(mask[i].cpu().numpy(), pre_mask[i].cpu().numpy())

    def _summarize_and_log(self, evaluator, stage):
        """Print and log mIoU / F1 / OA for ``stage``, then reset ``evaluator``.

        ``stage`` is used both as the printed prefix and as the prefix of the
        logged metric names (``<stage>_mIoU``, ``<stage>_F1``, ``<stage>_OA``).
        """
        # NOTE(review): assumes the evaluator's metric getters are side-effect
        # free, so reading the per-class IoU once is equivalent to reading it twice.
        iou_per_class = evaluator.Intersection_over_Union()
        # The last class entry is left out of the mIoU / F1 means
        # (presumably a background/clutter class -- confirm against the dataset).
        mIoU = np.nanmean(iou_per_class[:-1])
        F1 = np.nanmean(evaluator.F1()[:-1])
        OA = np.nanmean(evaluator.OA())

        print(f'{stage}:', {'mIoU': mIoU, 'F1': F1, 'OA': OA})
        print(dict(zip(self.config.classes, iou_per_class)))

        evaluator.reset()
        self.log_dict(
            {f'{stage}_mIoU': mIoU, f'{stage}_F1': F1, f'{stage}_OA': OA},
            prog_bar=True,
        )

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and update the training metrics.

        Returns:
            dict with the single key ``"loss"``.
        """
        img, mask = batch['img'], batch['gt_semantic_seg']
        prediction = self.net(img)
        loss = self.loss(prediction, mask)

        # With the auxiliary loss enabled the first element of the network
        # output is the one used for the metrics.
        main_output = prediction[0] if self.config.use_aux_loss else prediction
        self._update_evaluator(self.metrics_train, main_output, mask)

        return {"loss": loss}

    # Compute and print the evaluation metrics at the end of every training epoch.
    def on_train_epoch_end(self):
        """Report the accumulated training metrics and reset them."""
        self._summarize_and_log(self.metrics_train, 'train')

    # Validation step: process one validation batch, compute the loss and update the metrics.
    def validation_step(self, batch, batch_idx):
        """Update the validation metrics for one batch and compute its loss.

        Returns:
            dict with the single key ``"loss_val"``.
        """
        img, mask = batch['img'], batch['gt_semantic_seg']
        prediction = self.forward(img)
        self._update_evaluator(self.metrics_val, prediction, mask)

        loss_val = self.loss(prediction, mask)
        return {"loss_val": loss_val}

    # Compute and print the evaluation metrics at the end of every validation epoch.
    def on_validation_epoch_end(self):
        """Report the accumulated validation metrics and reset them."""
        self._summarize_and_log(self.metrics_val, 'val')

    # Configure the optimizer and the learning-rate scheduler.
    def configure_optimizers(self):
        """Return the optimizer and LR scheduler taken from the config."""
        optimizer = self.config.optimizer
        lr_scheduler = self.config.lr_scheduler

        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        """Return the training data loader taken from the config."""
        return self.config.train_loader

    def val_dataloader(self):
        """Return the validation data loader taken from the config."""
        return self.config.val_loader
|
||
|
|
||
|
# Main function that drives the whole training pipeline.
def main():
    """Load the config named on the command line, build the trainer and fit the model."""
    cli_args = get_args()
    config = py2cfg(cli_args.config_path)
    seed_everything(42)

    checkpoint_callback = ModelCheckpoint(
        save_top_k=config.save_top_k,
        monitor=config.monitor,
        save_last=config.save_last,
        mode=config.monitor_mode,
        dirpath=config.weights_path,
        filename=config.weights_name,
    )
    csv_logger = CSVLogger('lightning_logs', name=config.log_name)

    model = Supervision_Train(config)
    # When a pretrained checkpoint is configured, start from its weights instead.
    if config.pretrained_ckpt_path:
        model = Supervision_Train.load_from_checkpoint(
            config.pretrained_ckpt_path,
            config=config,
        )

    trainer = pl.Trainer(
        devices=config.gpus,
        max_epochs=config.max_epoch,
        accelerator='auto',
        check_val_every_n_epoch=config.check_val_every_n_epoch,
        callbacks=[checkpoint_callback],
        strategy='auto',
        logger=csv_logger,
    )
    # ckpt_path comes from config.resume_ckpt_path and is passed straight to fit().
    trainer.fit(model=model, ckpt_path=config.resume_ckpt_path)


# Run the training pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
|