#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : train_module
@Author : root
@create_time : 2022/9/13 22:57
Training-script template.
"""
import os

# Pin the process to GPU 0; must happen before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys

# Project root appended to sys.path so the `taihuyuan_roof` package resolves.
root_dir='/home/qiqq/q3dl/codeR/applicationprojectR/'
sys.path.append(root_dir)

from taihuyuan_roof.utils.index_manager import addLogger,addTensorboard
from taihuyuan_roof.utils.util import getLogger


from taihuyuan_roof.manet.model.resunet import resUnetpamcarb
from taihuyuan_roof.manet.train_val_test.evaluate_train import Evalue


import argparse
import logging

from pathlib import Path
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter

import time


# Author note below, kept verbatim (a bare string statement, not a comment):
# "wandb is unavailable on the server, so the wandb-logging code was removed".
"由于服务器用不了wandb所以这里删除了用wandb记录的代码"
import os
import numpy as np
import random
import torch

# Fix every RNG seed up front so runs are reproducible.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seed the CPU RNG
torch.cuda.manual_seed(seed)  # seed the current GPU's RNG


# TensorBoard logging is enabled; all outputs go under these directories.
dir_checkpoint = Path('./checkpoints/')
tensorboarddir= Path('./TensorboardLog/')
logsavedir = Path('./logs/')
# Class weights previously used on train: [0.5, 2.5, 2.5]
def criterion(inputs, target, loss_weight=None, num_classes: int = 2, dice: bool = True, ignore_index: int = -100):
    """Return the cross-entropy loss between `inputs` (logits) and `target`.

    NOTE(review): `num_classes` and `dice` are accepted but never used here —
    the dice-loss branch appears to have been removed; confirm before
    relying on either parameter.
    """
    return F.cross_entropy(
        inputs,
        target,
        weight=loss_weight,
        ignore_index=ignore_index,
    )
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
## cityscapes note: changed to 20 classes (original 19 + 1 background for the old 255 label)
def train_net(net,
              device,
              resume=False,
              isPretrain=False,
              epochs: int = 5,
              train_batch_size: int = 2,
              val_batch_size: int = 2,
              learning_rate: float = 1e-3,
              save_checkpoint: bool = True,
              ignoreindex: int = 100,
              num_class=21,
              useDice=False,
              valIndex=None,
              ):
    """Train `net` on the taihuyuan roof dataset and validate every epoch.

    Parameters
    ----------
    net : torch.nn.Module — the segmentation network to train (moved to `device`).
    device : torch.device — cuda/cpu target.
    resume : str or False — path to a checkpoint to resume from; falsy to start fresh.
    isPretrain : path or False — kept for interface compatibility; the pretrain
        loading code was removed, so this argument is currently unused here.
    epochs, train_batch_size, val_batch_size, learning_rate : usual hyperparameters.
    save_checkpoint : bool — save 'last.pth' every epoch and 'best.pth' on improvement.
    ignoreindex : int — label value ignored by the loss.
    num_class : int — number of classes (forwarded to Evalue and criterion).
    useDice : bool — selects the single-output loss path (criterion itself
        currently ignores its `dice` flag).
    valIndex : list[str] or None — names of Evalue methods to run, e.g.
        ["Valloss", "IouMiouP"]; None means no validation metrics.

    Side effects: creates timestamped TensorBoard and checkpoint directories,
    writes logs via getLogger(logsavedir).
    """
    # Avoid the mutable-default-argument pitfall; behavior matches valIndex=[].
    if valIndex is None:
        valIndex = []

    logger = getLogger(logsavedir)

    # 1. Create the datasets (project-local loader).
    from taihuyuan_roof.dataloaders.datasets import def_taihuyuan

    # The dataset expects an argparse Namespace; build one and attach fields.
    # NOTE(review): parse_args() here re-reads sys.argv — works only because
    # the script is normally launched without extra flags; confirm.
    parsertrain = argparse.ArgumentParser()
    argstrain = parsertrain.parse_args()
    argstrain.resize = (512, 512)
    argstrain.crop_size = 480  # crop/flip are unused when augmentation is off
    argstrain.flip_prob = 0.5

    parserval = argparse.ArgumentParser()
    argsval = parserval.parse_args()
    argsval.resize = (512, 512)
    argsval.crop_size = 480  # unused during validation; set only to avoid errors
    argsval.flip_prob = 0.5
    imgshape_base = None

    train_d = def_taihuyuan.datasets_rooftaihuyuan(argstrain, split='trainbat', isAug=True)
    val_d = def_taihuyuan.datasets_rooftaihuyuan(argsval, split='valbat', )

    n_val = len(val_d)
    n_train = len(train_d)

    train_loader = DataLoader(train_d, batch_size=train_batch_size, shuffle=True, num_workers=4, pin_memory=False,
                              drop_last=False)
    val_loader = DataLoader(val_d, shuffle=False, num_workers=4, pin_memory=False, batch_size=val_batch_size,
                            drop_last=False)

    # Timestamped output directories for TensorBoard logs and checkpoints.
    workTensorboard_dir = os.path.join(tensorboarddir,
                                       time.strftime("%Y-%m-%d-%H.%M", time.localtime()))
    if not os.path.exists(workTensorboard_dir):
        os.makedirs(workTensorboard_dir)

    workcheckpoint_dir = os.path.join(dir_checkpoint,
                                      time.strftime("%Y-%m-%d-%H.%M", time.localtime()))
    if not os.path.exists(workcheckpoint_dir):
        os.makedirs(workcheckpoint_dir)

    writer = SummaryWriter(workTensorboard_dir)
    description="优化器:Adam,学习率:LinoPolyScheduler,\n" \
                "其他策略:无数据增强,无类加权,归一化方式/255\n" \
                "res50的预训练\n" \
                "网络:res+Unetpam+carb(MANET) \n" \
                "32倍数下次采样。 \n" \
                "数据集:太湖源roof\n " \
                "保存最好的权重的指标是roof的iou,\n"

    logger.info(f'''Starting training:
        "Description:" {description}
        Epochs: {epochs}
        train_Batch size: {train_batch_size}
        val_Batch size: {val_batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
    ''')
    net = net.to(device)
    from taihuyuan_roof.schedulers.polyscheduler import LinoPolyScheduler

    # 2. Optimizer and LR scheduler.
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    scheduler = LinoPolyScheduler(optimizer, epochs=epochs, steps_per_epoch=len(train_loader), min_lr=0, )

    # NOTE(review): the pretrain-weight loading block (isPretrain) was removed
    # upstream; `isPretrain` is accepted but not used in this function.

    # 3. Optionally resume from a checkpoint.
    start_epoch = 1
    if resume:
        path_checkpoint = resume  # checkpoint path
        checkpoint = torch.load(path_checkpoint)
        net.load_state_dict(checkpoint['net'])  # restore learnable parameters
        # FIX: the optimizer state was saved in the checkpoint but never
        # restored, which silently reset Adam's moment estimates on resume.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
        scheduler.load_state_dict(checkpoint['lr_schedule'])  # restore LR schedule
        print("重启成功")
    else:
        start_epoch = 1

    # 4. Training loop bookkeeping.
    epochs_score = []   # per-epoch validation score (roof_iou)
    best_miou = 0.0     # best score seen so far

    time1 = time.time()
    start = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time1))
    print("训练{}个epoch的开始时间为:{}".format(epochs, start))

    logger.info("训练{}个epoch的开始时间为:{}".format(epochs, start))

    log = addLogger(logger)
    tenb = addTensorboard(writer)

    for epoch in range(start_epoch, epochs + 1):
        logger.info("train|epoch:{epoch}\t".format(epoch=epoch))
        current_miou = 0.0
        total_train_loss = 0
        net.train()
        print('Start Train')
        time1_1 = time.time()

        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batch') as pbar:
            for iteration, batch in enumerate(train_loader):
                images = batch['image']
                true_masks = batch['label']
                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                out = net(images)
                loss = 0.0
                if useDice:
                    loss = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=True, ignore_index=ignoreindex)
                else:
                    if isinstance(out, (list, tuple)):
                        # Multiple outputs: main head + auxiliary head,
                        # aux loss down-weighted by 0.4.
                        loss0 = criterion(out[0], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss1 = criterion(out[1], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0 + 0.4 * loss1
                    else:
                        loss0 = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0

                # Standard step: zero grads, backprop, update.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update()
                total_train_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})

        writer.add_scalar("train_total_loss", total_train_loss / (iteration + 1), epoch)
        time1_2 = time.time()
        one_epoch_time = round((time1_2 - time1_1) / 60, 3)  # minutes
        logger.info("训练一个epoch大概:{} min".format(one_epoch_time))
        print("Finish Train")
        print('start Validation')
        isevalue = True

        if isevalue == True:
            val_score = 0
            current_lr = optimizer.param_groups[0]['lr']  # current learning rate
            evalueModel = Evalue(net, val_loader, device, num_class, isResize=imgshape_base, ignore_index=ignoreindex, weight=None)
            for i in valIndex:
                # Only "IouMiouP" returns a tuple of per-metric results.
                if i == "IouMiouP":
                    indextring = ["acc_global", "acc", "iu", "precion", "recall", "f1", "miou", "meanf1", "roof_iou", "roof_f1"]
                    # FIX: use getattr instead of eval for dynamic dispatch.
                    result = getattr(evalueModel, 'evalue_' + i)()
                    for j in range(len(indextring)):
                        log.logger_index(result[j], indextring[j])
                        # Scalar metrics go to single-value plots; the rest
                        # are per-class arrays.
                        if indextring[j] == "acc_global" or indextring[j] == "miou" or indextring[j] == "meanf1" or indextring[j] == "roof_iou" or indextring[j] == "roof_f1":
                            if indextring[j] == "roof_iou":
                                # roof_iou is the model-selection metric.
                                val_score = result[j]
                                current_miou = val_score
                                epochs_score.append(val_score)
                            tenb.writer_singleindex(indextring[j], index=result[j], epoch=epoch)
                        else:
                            tenb.writer_classindex(indextring[j], classindexs=result[j], epoch=epoch)
                else:
                    result = getattr(evalueModel, 'evalue_' + i)()
                    log.logger_index(result, i)
                    tenb.writer_singleindex(i, index=result, epoch=epoch)

            writer.add_scalar("lr", current_lr, epoch)
            logger.info("当前epoch{}的lr:{}".format(epoch, current_lr))
            scheduler.step()
            # Save the newest checkpoint, and the best one by roof_iou.
            if save_checkpoint:
                checkpoint = {
                    "net": net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    "epoch": epoch,
                    'lr_schedule': scheduler.state_dict()
                }
                torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'last.pth'))
                if current_miou >= best_miou:
                    best_miou = current_miou
                    torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'best.pth'))
                    # Track which epoch currently holds the best score.
                    writer.add_scalar('best_epoch_index', epochs_score.index(max(epochs_score)) + 1, epoch)
            print('Finish Validation')

    time2 = time.time()
    end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time2))
    print("训练结束:训练{}个epoch的结束时间为:{}".format(epochs, end))
    logger.info("训练结束:训练{}个epoch的结束时间为:{}".format(epochs, end))
    writer.close()
    logger.info("训练完成")
|
|||
|
|
|||
|
|
|||
|
def _str2bool(value):
    """Parse a CLI boolean value: 'true'/'1'/'yes' (any case) -> True, else False."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')


def get_args():
    """Build and parse the command-line arguments for training.

    Defaults are unchanged. Two flags are fixed to parse correctly:
    - --useDice used type=str, so `-ud False` produced the truthy string
      'False'; it now parses to a real bool.
    - --valIndex used type=str, so a user-supplied value became one string
      that train_net would iterate character-by-character; it now takes one
      or more names via nargs='+'.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train_Batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Val_Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    # Optional pretrained-backbone weights (.pth).
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='ignore index defult 100')
    # NOTE(review): store_true combined with a tuple default is odd — the flag
    # turns the tuple into a plain True; kept as-is since the value is unused here.
    parser.add_argument('--origin_shape', action='store_true', default=(512,512), help='原始输入尺寸')
    parser.add_argument('--resume', '-r', type=str, default="", help='is use Resume')
    parser.add_argument('--useDice', '-ud', type=_str2bool, default=False, help='训练的时候是否使用dice')
    parser.add_argument('--valIndex', '-vI', type=str, nargs='+', default=["Valloss","IouMiouP"], help='评价指标要使用哪些,注意IouMiouP= acc_global, acc, iu,precion,recall,f1,miou,并且建议旨在2分类的时候用dice否则会出错')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':

    # Parse the non-network hyperparameters from the command line.
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    # Prefer GPU when available (CUDA_VISIBLE_DEVICES pins the device set above).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Build the MANET-style network (res + Unet + pam + carb).
    net= resUnetpamcarb()
    try:
        # NOTE(review): args.load is passed as isPretrain, but the pretrain
        # loading inside train_net is commented out — verify whether the
        # backbone weights are loaded elsewhere (e.g. inside the model ctor).
        train_net(net=net,
                  device = device,
                  resume=args.resume,
                  epochs=args.epochs,
                  isPretrain=args.load,
                  train_batch_size=args.train_batch_size,
                  val_batch_size=args.val_batch_size,
                  learning_rate=args.lr,
                  ignoreindex=args.ignore_index,
                  num_class=args.classes,
                  useDice=args.useDice,
                  valIndex=args.valIndex

                  )
    except KeyboardInterrupt:
        # Save the current weights on Ctrl-C so the run can be salvaged.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        raise
|