ai-station-code/wudingpv/taihuyuan_pv/mitunet/train_val_test/train.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : train_module
@Author : root
@create_time : 2022/9/13 22:57
训练文件模板
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
root_dir='/home/qiqq/q3dl/code/applicationproject/'
sys.path.append(root_dir)
from taihuyuan_pv.utils.index_manager import addLogger,addTensorboard
from taihuyuan_pv.utils.util import getLogger
from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb
from taihuyuan_pv.mitunet.train_val_test.evaluate_train import Evalue
import argparse
import logging
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
import time
"由于服务器用不了wandb所以这里删除了用wandb记录的代码"
import os
import numpy as np
import random
import torch
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seed the CPU RNG
torch.cuda.manual_seed(seed)  # seed the current GPU's RNG
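# A small optional addition (an assumption, not in the original) for fuller
# reproducibility: cuDNN autotuning is nondeterministic, so these flags trade
# some speed for repeatable runs.
torch.cuda.manual_seed_all(seed)  # seed all visible GPUs, in case more than one is used
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False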
# TensorBoard logging has been added.
dir_checkpoint = Path('./checkpoints/')
tensorboarddir= Path('./TensorboardLog/')
logsavedir = Path('./logs/')
# Class weights on train: [0.5, 2.5, 2.5]
def criterion(inputs, target, loss_weight=None, num_classes: int = 2, dice: bool = True, ignore_index: int = -100):
    # Note: only cross-entropy is computed here; the num_classes and dice
    # parameters are accepted but currently unused.
    l_loss = F.cross_entropy(inputs, target, ignore_index=ignore_index, weight=loss_weight)
    return l_loss
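# A minimal sketch (an assumption, not part of the original) of how a Dice term
# could be folded in when dice=True. The helper below is hypothetical and is not
# called anywhere; pixels equal to ignore_index would need masking beforehand.
def _dice_loss_sketch(inputs, target, num_classes: int = 2, eps: float = 1e-6):
    """Soft Dice loss over one-hot targets, averaged across classes."""
    probs = F.softmax(inputs, dim=1)  # (N, C, H, W)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    inter = torch.sum(probs * one_hot, dims)
    union = torch.sum(probs + one_hot, dims)
    return 1.0 - torch.mean((2.0 * inter + eps) / (union + eps))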
# For Cityscapes I changed it directly to 20 classes: the original 19 plus 1 background class for the former 255 pixels.
def train_net(net,
              device,
              resume=False,
              isPretrain=False,
              epochs: int = 5,
              train_batch_size: int = 2,
              val_batch_size: int = 2,
              learning_rate: float = 1e-3,
              save_checkpoint: bool = True,
              ignoreindex: int = 100,
              num_class=21,
              useDice=False,
              valIndex=[],
              ):
    logger = getLogger(logsavedir)
    # 1. Create dataset
    # load my own dataset
    from taihuyuan_pv.dataloaders.datasets import def_taihuyuan
    # Build plain namespaces for the dataset options; calling parse_args() on a
    # fresh parser here would choke on the script's real command-line flags.
    argstrain = argparse.Namespace()
    argstrain.resize = (512, 512)
    argstrain.crop_size = 480  # crop_size and flip_prob are unused unless data augmentation is enabled
    argstrain.flip_prob = 0.5
    argsval = argparse.Namespace()
    argsval.resize = (512, 512)
    argsval.crop_size = 480  # crop and flip are not used during validation; set only to avoid errors
    argsval.flip_prob = 0.5
    imgshape_base = None
    train_d = def_taihuyuan.datasets_pvtaihuyuan(argstrain, split='trainbat', isAug=False)
    val_d = def_taihuyuan.datasets_pvtaihuyuan(argsval, split='valbat', )
    n_val = len(val_d)
    n_train = len(train_d)
    train_loader = DataLoader(train_d, batch_size=train_batch_size, shuffle=True, num_workers=4, pin_memory=False, drop_last=False)
    val_loader = DataLoader(val_d, shuffle=False, num_workers=4, pin_memory=False, batch_size=val_batch_size, drop_last=False)
    # initialize TensorBoard
    # Path(tensorboarddir).mkdir(parents=True, exist_ok=True)
    workTensorboard_dir = os.path.join(tensorboarddir,
                                       time.strftime("%Y-%m-%d-%H.%M", time.localtime()))  # directory for TensorBoard event files
    if not os.path.exists(workTensorboard_dir):
        os.makedirs(workTensorboard_dir)
    workcheckpoint_dir = os.path.join(dir_checkpoint,
                                      time.strftime("%Y-%m-%d-%H.%M", time.localtime()))  # directory for checkpoints
    if not os.path.exists(workcheckpoint_dir):
        os.makedirs(workcheckpoint_dir)
    writer = SummaryWriter(workTensorboard_dir)
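    # To inspect the curves later, use the standard TensorBoard CLI; the logdir
    # below matches tensorboarddir defined at the top of this file:
    #   tensorboard --logdir ./TensorboardLog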
    description = "Optimizer: Adam, LR schedule: LinoPolyScheduler,\n" \
                  "other strategies: no data augmentation, no class weighting, normalization by /255\n" \
                  "network: resUnetpamcarb\n" \
                  "dataset: the original Taihuyuan PV data\n" \
                  "the metric used to save the best weights is the PV IoU,\n"
    logger.info(f'''Starting training:
        Description: {description}
        Epochs: {epochs}
        train_Batch size: {train_batch_size}
        val_Batch size: {val_batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
    ''')
    net = net.to(device)
    from taihuyuan_pv.schedulers.polyscheduler import LinoPolyScheduler
    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = LinoPolyScheduler(optimizer, epochs=epochs, steps_per_epoch=len(train_loader), min_lr=0, )
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98, )
    # scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)
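    # For reference, a standard poly schedule decays the rate per step as
    #     lr = base_lr * (1 - step / total_steps) ** power
    # with total_steps = epochs * steps_per_epoch and power commonly 0.9.
    # (Assumption: LinoPolyScheduler follows this family; its source is not shown here.)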
    # we = np.array([1, 3, 16], np.float32)
    # we = torch.from_numpy(we).to(device)
    start_epoch = 1
    # # whether pretrained weights are given
    # if isPretrain:
    #     model_dict = net.state_dict()
    #     pretrained_dict = torch.load(isPretrain, map_location=device)
    #     load_key, no_load_key, temp_dict = [], [], {}
    #     for k, v in pretrained_dict.items():
    #         if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
    #             temp_dict[k] = v
    #             load_key.append(k)
    #         else:
    #             no_load_key.append(k)
    #     model_dict.update(temp_dict)
    #     net.load_state_dict(model_dict)
    #     print("Pretrained weights loaded in train")
    #     logger.info(f'Model loaded from {isPretrain}')
    # else:  # no pretrained weights: use one of PyTorch's weight-initialization schemes
    #     strweight = 'normal'
    #     weights_init(net, init_type=strweight)
    #     logging.info(f'No pretrained weights; {strweight} initialization done')
    # # whether to resume from a checkpoint
    if resume:
        path_checkpoint = resume  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint
        net.load_state_dict(checkpoint['net'])  # load the model's learnable parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # load the optimizer state
        start_epoch = checkpoint['epoch'] + 1  # set the starting epoch
        scheduler.load_state_dict(checkpoint['lr_schedule'])  # load the lr_scheduler
        print("Resumed successfully")
    else:
        start_epoch = 1
    # 5. Begin training
    epochs_score = []  # record each epoch's miou
    best_miou = 0.0  # track the best one
    earlyStop_count = 0
    time1 = time.time()
    start = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time1))
    print("Start time for training {} epochs: {}".format(epochs, start))
    logger.info("Start time for training {} epochs: {}".format(epochs, start))
    log = addLogger(logger)
    tenb = addTensorboard(writer)
    # early_stopping = EarlyStopping(patience=100)
    for epoch in range(start_epoch, epochs + 1):
        logger.info("train|epoch:{epoch}\t".format(epoch=epoch))
        current_miou = 0.0
        total_train_loss = 0
        net.train()
        print('Start Train')
        time1_1 = time.time()
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batch') as pbar:
            for iteration, batch in enumerate(train_loader):  # with batch size 2 that is 1487.5 batches in total
                images = batch['image']
                true_masks = batch['label']
                # true_masks = true_masks - 1  # for 8-bit color maps whose indices start at 1
                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)
                # out0, out1 = net(images, use128aux=True, state="train")
                out = net(images)
                loss = 0.0
                if useDice:
                    loss = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=True, ignore_index=ignoreindex)
                else:
                    if isinstance(out, (list, tuple)):  # there may be several outputs (the auxiliary decoder head's output is also returned)
                        loss0 = criterion(out[0], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss1 = criterion(out[1], true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0 + 0.4 * loss1  # deep-supervision style weighting of the auxiliary loss
                    else:
                        loss0 = criterion(out, true_masks, loss_weight=None, num_classes=num_class, dice=False,
                                          ignore_index=ignoreindex)
                        loss = loss0
                # 1. compute the loss  2. zero the gradients  3. backward  4. optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update()  # advance the progress bar by one batch
                total_train_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})
        writer.add_scalar("train_total_loss", total_train_loss / (iteration + 1), epoch)
        time1_2 = time.time()
        one_epoch_time = round((time1_2 - time1_1) / 60, 3)  # measured in minutes
        logger.info("One epoch takes about {} min".format(one_epoch_time))
        print("Finish Train")
        print('start Validation')
        isevalue = True
        if isevalue:
            val_score = 0
            current_lr = optimizer.param_groups[0]['lr']  # current learning rate
            evalueModel = Evalue(net, val_loader, device, num_class, isResize=imgshape_base, ignore_index=ignoreindex, weight=None)
            for i in valIndex:
                # this implicitly assumes that only IouMiouP can return arrays of per-class values
                if i == "IouMiouP":
                    indextring = ["acc_global", "acc", "iu", "precion", "recall", "f1", "miou", "meanf1", "pv_iou", "pv_f1"]
                    result = getattr(evalueModel, 'evalue_' + i)()
                    for j in range(len(indextring)):
                        log.logger_index(result[j], indextring[j])
                        if indextring[j] in ("acc_global", "miou", "meanf1", "pv_iou", "pv_f1"):
                            if indextring[j] == "pv_iou":
                                val_score = result[j]
                                current_miou = val_score
                                epochs_score.append(val_score)
                            tenb.writer_singleindex(indextring[j], index=result[j], epoch=epoch)
                        else:
                            tenb.writer_classindex(indextring[j], classindexs=result[j], epoch=epoch)
                else:
                    result = getattr(evalueModel, 'evalue_' + i)()
                    log.logger_index(result, i)
                    tenb.writer_singleindex(i, index=result, epoch=epoch)
            writer.add_scalar("lr", current_lr, epoch)
            logger.info("lr at epoch {}: {}".format(epoch, current_lr))
        scheduler.step()
        # save the best-miou weights and the latest ones
        if save_checkpoint:
            checkpoint = {
                "net": net.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
                'lr_schedule': scheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'last.pth'))  # save the latest
            # save the best
            if current_miou >= best_miou:
                best_miou = current_miou
                torch.save(checkpoint, os.path.join(workcheckpoint_dir, 'best.pth'))
            writer.add_scalar('best_epoch_index', epochs_score.index(max(epochs_score)) + 1, epoch)  # track how the best epoch index evolves
        # here the learning rate is adjusted on a per-iteration schedule
        print('Finish Validation')
        # when the early-stop condition is met, early_stop is set to True
        # early_stopping(val_score)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     time2 = time.time()
        #     end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time2))
        #     logger.info("Training stopped early: end time after {} epochs: {}".format(epoch, end))
        #     break
    time2 = time.time()
    end = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time2))
    print("Training finished: end time for training {} epochs: {}".format(epochs, end))
    logger.info("Training finished: end time for training {} epochs: {}".format(epochs, end))
    writer.close()
    logger.info("Training complete")
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=200, help='Number of epochs')
    parser.add_argument('--train_batch-size', '-tb', dest='train_batch_size', metavar='TB', type=int, default=8,
                        help='Train_Batch size')
    parser.add_argument('--val_batch-size', '-vb', dest='val_batch_size', metavar='VB', type=int, default=1,
                        help='Val_Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-3,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str,
                        default="/home/qiqq/q3dl/code/pretrain_weight/pretrained/resnet/resnet50-0676ba61.pth",
                        help='Load model from a .pth file')  # whether pretrained weights are given
    # parser.add_argument('--load', '-f', type=str,
    #                     default="/home/qiqq/q3dl/code/applicationproject/taihuyuan_pv/mitunet/train_val_test/checkpoints/2023-09-19-17.34/best.pth",
    #                     help='Load model from a .pth file')  # whether pretrained weights are given
    parser.add_argument('--ignore_index', '-i', type=int, dest='ignore_index', default=255,
                        help='ignore index (default 255)')
    parser.add_argument('--origin_shape', action='store_true', default=(512, 512), help='original input size')
    parser.add_argument('--resume', '-r', type=str, default="", help='resume from this checkpoint path if set')
    parser.add_argument('--useDice', '-ud', action='store_true', default=False,
                        help='whether to use dice during training')
    parser.add_argument('--valIndex', '-vI', type=str, nargs='+', default=["Valloss", "IouMiouP"],
                        help='which evaluation metrics to use; note IouMiouP = acc_global, acc, iu, precion, recall, f1, miou; '
                             'dice is recommended only for binary classification, otherwise it may fail')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    return parser.parse_args()
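# Example invocation (illustrative flag values; paths and data depend on the machine):
#   python train.py -e 200 -tb 8 -vb 1 -l 1e-3 -i 255 -c 2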
if __name__ == '__main__':
    # a bit messy here
    args = get_args()  # holds the non-network parameters
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    # net igin_shape,pretrain_backbone=args.load)
    net = resUnetpamcarb()
    try:
        train_net(net=net,
                  device=device,
                  resume=args.resume,
                  epochs=args.epochs,
                  isPretrain=args.load,
                  train_batch_size=args.train_batch_size,
                  val_batch_size=args.val_batch_size,
                  learning_rate=args.lr,
                  ignoreindex=args.ignore_index,
                  num_class=args.classes,
                  useDice=args.useDice,
                  valIndex=args.valIndex
                  )
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        raise