#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : train_module
@Author : root
@create_time : 2022/9/13 22:57
训练文件模板
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
root_dir = '/home/qiqq/q3dl/codeR/applicationprojectR/'
sys.path.append(root_dir)  # make the project packages importable
from taihuyuan_roof.utils.index_manager import addLogger,addTensorboard
from taihuyuan_roof.utils.util import getLogger
from taihuyuan_roof.manet.model.resunet import resUnetpamcarb
from taihuyuan_roof.manet.train_val_test.evaluate_train import Evalue
import argparse
import logging
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
import time
"由于服务器用不了wandb所以这里删除了用wandb记录的代码"
import os
import numpy as np
import random
import torch
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # set the random seed for the CPU
torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
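# For stricter reproducibility one could additionally seed all GPUs and pin cuDNN to
# deterministic kernels (optional, not used in the original run; deterministic mode
# can slow training down):
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False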
# TensorBoard logging added
dir_checkpoint = Path('./checkpoints/')
tensorboarddir = Path('./TensorboardLog/')
logsavedir = Path('./logs/')
# class weights on the train split: [0.5, 2.5, 2.5]
def criterion(inputs, target, loss_weight=None, num_classes: int = 2, dice: bool = True, ignore_index: int = -100):
    # num_classes and dice are accepted for interface compatibility but are currently
    # unused; only the (optionally weighted) cross-entropy loss is computed.
    l_loss = F.cross_entropy(inputs, target, ignore_index=ignore_index, weight=loss_weight)
    return l_loss
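# The `dice` flag in criterion() is currently a no-op: only cross-entropy is computed.
# Below is a minimal sketch of a dice term that could be added to the loss when dice=True.
# It is a hypothetical helper, not wired into the training loop; it assumes `inputs` are
# raw logits of shape [N, C, H, W], `target` holds class indices in [0, num_classes),
# and it omits ignore_index handling for brevity.
def dice_loss_sketch(inputs, target, num_classes: int = 2, eps: float = 1e-6):
    probs = F.softmax(inputs, dim=1)                # [N, C, H, W] class probabilities
    one_hot = F.one_hot(target, num_classes)        # [N, H, W, C]
    one_hot = one_hot.permute(0, 3, 1, 2).float()   # [N, C, H, W]
    dims = (0, 2, 3)                                # reduce over batch and spatial dims
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()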
# Cityscapes note: collapsed to 20 classes, i.e. the original 19 plus one background class for pixels formerly labeled 255.
def train_net(net,
              device,
              resume=False,
              isPretrain=False,
              epochs: int = 5,
              train_batch_size: int = 2,
              val_batch_size: int = 2,
              learning_rate: float = 1e-3,
              save_checkpoint: bool = True,
              ignoreindex: int = 100,
              num_class: int = 21,
              useDice=False,
              valIndex=None,
              ):
    if valIndex is None:  # avoid a mutable default argument
        valIndex = []
    logger = getLogger(logsavedir)
    # 1. Create dataset (custom data loading)
    from taihuyuan_roof.dataloaders.datasets import def_taihuyuan
    # Build plain namespaces instead of calling parse_args() on empty parsers, which
    # would re-parse sys.argv and fail on this script's own command-line flags.
    argstrain = argparse.Namespace()
    argstrain.resize = (512, 512)
    argstrain.crop_size = 480  # unused unless data augmentation is enabled
    argstrain.flip_prob = 0.5
    argsval = argparse.Namespace()
    argsval.resize = (512, 512)
    argsval.crop_size = 480  # crop and flip are not used during validation; set only to avoid errors
    argsval.flip_prob = 0.5
    imgshape_base = None
    train_d = def_taihuyuan.datasets_rooftaihuyuan(argstrain, split='trainbat', isAug=True)
    val_d = def_taihuyuan.datasets_rooftaihuyuan(argsval, split='valbat')
    n_val = len(val_d)
    n_train = len(train_d)
train_loader = DataLoader(train_d, batch_size=train_batch_size, shuffle=True, num_workers=4, pin_memory=False,
drop_last=False)
val_loader = DataLoader(val_d, shuffle=False, num_workers=4, pin_memory=False, batch_size=val_batch_size,
drop_last=False)
    workTensorboard_dir = os.path.join(tensorboarddir,
                                       time.strftime("%Y-%m-%d-%H.%M", time.localtime()))  # TensorBoard log directory
    os.makedirs(workTensorboard_dir, exist_ok=True)
    workcheckpoint_dir = os.path.join(dir_checkpoint,
                                      time.strftime("%Y-%m-%d-%H.%M", time.localtime()))  # checkpoint directory
    os.makedirs(workcheckpoint_dir, exist_ok=True)
writer = SummaryWriter(workTensorboard_dir)
description="优化器:Adam,学习率:LinoPolyScheduler,\n" \
"其他策略:无数据增强,无类加权,归一化方式/255\n" \
"res50的预训练\n" \
"网络res+Unetpam+carb(MANET) \n" \
"32倍数下次采样。 \n" \
"数据集太湖源roof\n " \
"保存最好的权重的指标是roof的iou,\n"
    logger.info(f'''Starting training:
        Description:      {description}
        Epochs:           {epochs}
        train_Batch size: {train_batch_size}
        val_Batch size:   {val_batch_size}
        Learning rate:    {learning_rate}
        Training size:    {n_train}
        Validation size:  {n_val}
        Checkpoints:      {save_checkpoint}
        Device:           {device.type}
    ''')
net = net.to(device)
from taihuyuan_roof.schedulers.polyscheduler import LinoPolyScheduler
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
# optimizer = optim.SGD(net.parameters(),lr=learning_rate,momentum=0.9,weight_decay=1e-4)
    scheduler = LinoPolyScheduler(optimizer, epochs=epochs, steps_per_epoch=len(train_loader), min_lr=0)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98, )
# scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=5,T_mult=2,eta_min=1e-6)
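    # LinoPolyScheduler is a project-local scheduler. Assuming it follows the usual
    # poly policy, the per-step learning rate would be roughly:
    #   lr = (base_lr - min_lr) * (1 - step / (epochs * steps_per_epoch)) ** power + min_lr
    # (a sketch of the conventional formula, not a guarantee about the exact implementation).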
# we = np.array([1,3,16], np.float32)
# we = torch.from_numpy(we).to(device)
start_epoch = 1
    # # Load pretrained weights if a path is given
    # if isPretrain:
    #     model_dict = net.state_dict()
    #     pretrained_dict = torch.load(isPretrain, map_location=device)
    #     load_key, no_load_key, temp_dict = [], [], {}
    #     for k, v in pretrained_dict.items():
    #         if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
    #             temp_dict[k] = v
    #             load_key.append(k)
    #         else:
    #             no_load_key.append(k)
    #     model_dict.update(temp_dict)
    #     net.load_state_dict(model_dict)
    #     print("Pretrained weights loaded in train")
    #     logger.info(f'Model loaded from {isPretrain}')
    # else:  # without pretrained weights, fall back to one of PyTorch's initialization schemes
    #     strweight = 'normal'
    #     weights_init(net, init_type=strweight)
    #     logging.info(f'No pretrained weights; {strweight} initialization finished')
    # Optionally resume from a checkpoint
    if resume:
        path_checkpoint = resume  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint
        net.load_state_dict(checkpoint['net'])  # restore the model's learnable parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
        start_epoch = checkpoint['epoch'] + 1  # epoch to resume from
        scheduler.load_state_dict(checkpoint['lr_schedule'])  # restore the lr scheduler
        print("Resume successful")
    else:
        start_epoch = 1
    # 5. Begin training
    epochs_score = []  # record each epoch's miou
    best_miou = 0.0  # track the best score so far
    earlyStop_count = 0
    time1 = time.time()
    start = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time1))
    print("Training for {} epochs, start time: {}".format(epochs, start))
    logger.info("Training for {} epochs, start time: {}".format(epochs, start))
    log = addLogger(logger)
    tenb = addTensorboard(writer)
    # early_stopping = EarlyStopping(patience=100)
for epoch in range(start_epoch, epochs + 1):
logger.info("train|epoch:{epoch}\t".format(epoch=epoch))
current_miou = 0.0
total_train_loss = 0
net.train()
print('Start Train')
time1_1 = time.time()
with tqdm(total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batch') as pbar:
            for iteration, batch in enumerate(train_loader):  # with batch size 2 this is ~1487.5 batches
                images = batch['image']
                true_masks = batch['label']
                # true_masks = true_masks - 1  # for 8-bit palette images whose class indices start at 1
                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)
                # out0, out1 = net(images, use128aux=True, state="train")
                out = net(images)
                loss = 0.0
                if useDice:
                    loss = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=True,
                                     ignore_index=ignoreindex)
                else:
                    # The network may return several outputs (the auxiliary decoder head is returned as well).
                    if isinstance(out, (list, tuple)):
                        loss0 = criterion(out[0], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss1 = criterion(out[1], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0 + 0.4 * loss1  # auxiliary head loss weighted by 0.4
                    else:
                        loss0 = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0
                # 1. compute the loss  2. zero the gradients  3. backward pass  4. optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update()  # advance the progress bar by one batch
                total_train_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})
writer.add_scalar("train_total_loss", total_train_loss / (iteration + 1), epoch)
        time1_2 = time.time()
        one_epoch_time = round((time1_2 - time1_1) / 60, 3)  # measured in minutes
        logger.info("One training epoch took about {} min".format(one_epoch_time))
        print("Finish Train")
        print('Start Validation')
        isevalue = True
        if isevalue:
            val_score = 0
            current_lr = optimizer.param_groups[0]['lr']  # current learning rate
            evalueModel = Evalue(net, val_loader, device, num_class, isResize=imgshape_base,
                                 ignore_index=ignoreindex, weight=None)
            for i in valIndex:
                # This implicitly assumes that only IouMiouP returns an array of metrics.
                if i == "IouMiouP":
                    indextring = ["acc_global", "acc", "iu", "precion", "recall", "f1", "miou", "meanf1",
                                  "roof_iou", "roof_f1"]
                    result = getattr(evalueModel, 'evalue_' + i)()
                    for j in range(len(indextring)):
                        log.logger_index(result[j], indextring[j])
                        if indextring[j] in ("acc_global", "miou", "meanf1", "roof_iou", "roof_f1"):
                            if indextring[j] == "roof_iou":
                                val_score = result[j]
                                current_miou = val_score
                                epochs_score.append(val_score)
                            tenb.writer_singleindex(indextring[j], index=result[j], epoch=epoch)
                        else:
                            tenb.writer_classindex(indextring[j], classindexs=result[j], epoch=epoch)
                else:
                    result = getattr(evalueModel, 'evalue_' + i)()
                    log.logger_index(result, i)
                    tenb.writer_singleindex(i, index=result, epoch=epoch)
            writer.add_scalar("lr", current_lr, epoch)
            logger.info("lr at epoch {}: {}".format(epoch, current_lr))
        scheduler.step()
        # Save the best and the most recent checkpoints
        if save_checkpoint:
            checkpoint = {
                "net": net.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
                'lr_schedule': scheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'last.pth'))  # save the most recent
            # Save the best one
            if current_miou >= best_miou:
                best_miou = current_miou
                torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'best.pth'))
            if epochs_score:  # guard against an empty list when IouMiouP is not among the metrics
                writer.add_scalar('best_epoch_index', epochs_score.index(max(epochs_score)) + 1, epoch)  # track how the best epoch evolves
print('Finish Validation')
        # When the early-stopping condition is met, early_stop is set to True
        # early_stopping(val_score)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     time2 = time.time()
        #     end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time2))
        #     logger.info("Training stopped early: {} epochs finished at: {}".format(epoch, end))
        #     break
    time2 = time.time()
    end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time2))
    print("Training finished: {} epochs, end time: {}".format(epochs, end))
    logger.info("Training finished: {} epochs, end time: {}".format(epochs, end))
    writer.close()
    logger.info("Training complete")
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train_Batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Val_Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')  # optional pretrained weights
    # parser.add_argument('--load', '-f', type=str,
    #                     default="/home/qiqq/q3dl/code/applicationproject/taihuyuan_roof/manet/train_val_test/checkpoints/2023-09-20-14.07/best.pth",
    #                     help='Load model from a .pth file')  # optional pretrained weights
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='ignore index, default 255')
    parser.add_argument('--origin_shape', action='store_true', default=(512, 512), help='original input size')
    parser.add_argument('--resume', '-r', type=str, default="", help='path of a checkpoint to resume from')
    parser.add_argument('--useDice', '-ud', action='store_true', default=False,
                        help='whether to use the dice loss during training')
    parser.add_argument('--valIndex', '-vI', nargs='+', type=str, default=["Valloss", "IouMiouP"],
                        help='which evaluation metrics to use; note that IouMiouP covers acc_global, acc, iu, '
                             'precion, recall, f1 and miou, and dice is recommended only for 2-class problems, '
                             'otherwise it may fail')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    return parser.parse_args()
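# Example invocation (a sketch based on the defaults above; paths are machine-specific):
#   python train.py -e 200 -tb 8 -vb 1 -l 1e-3 -c 2 -i 255 \
#       -f /home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth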
if __name__ == '__main__':
    args = get_args()  # a bit messy; this handles the non-network parameters
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    # net = resUnetpamcarb(origin_shape=args.origin_shape, pretrain_backbone=args.load)
    net = resUnetpamcarb()
try:
train_net(net=net,
device = device,
resume=args.resume,
epochs=args.epochs,
isPretrain=args.load,
train_batch_size=args.train_batch_size,
val_batch_size=args.val_batch_size,
learning_rate=args.lr,
ignoreindex=args.ignore_index,
num_class=args.classes,
useDice=args.useDice,
valIndex=args.valIndex
)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
raise
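# A minimal sketch of how a saved checkpoint could be reloaded for inference, based on
# the dict written in train_net ({'net', 'optimizer', 'epoch', 'lr_schedule'}); the
# checkpoint path is hypothetical:
#   checkpoint = torch.load('./checkpoints/<run-timestamp>/best.pth', map_location=device)
#   net.load_state_dict(checkpoint['net'])
#   net.eval()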