201 lines
7.9 KiB
Python
201 lines
7.9 KiB
Python
import os
|
||
import time
|
||
import datetime
|
||
|
||
import torch
|
||
|
||
from src import lraspp_mobilenetv3_large
|
||
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
|
||
from my_dataset import VOCSegmentation
|
||
import transforms as T
|
||
|
||
|
||
class SegmentationPresetTrain:
    """Training-time transform pipeline applied jointly to an image and its target.

    The pipeline is composed, in order, of ``T.RandomResize``, an optional
    ``T.RandomHorizontalFlip``, ``T.RandomCrop``, ``T.ToTensor`` and
    ``T.Normalize``.

    Args:
        base_size: reference size; the resize bounds are derived from it.
        crop_size: value handed to ``T.RandomCrop``.
        hflip_prob: value handed to ``T.RandomHorizontalFlip``; the flip step
            is left out entirely when this is not greater than 0.
        mean: per-channel mean forwarded to ``T.Normalize``.
        std: per-channel std forwarded to ``T.Normalize``.
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Resize bounds are 0.5x and 2.0x of base_size, truncated to int.
        lower_bound = int(0.5 * base_size)
        upper_bound = int(2.0 * base_size)

        pipeline = [T.RandomResize(lower_bound, upper_bound)]
        if hflip_prob > 0:
            pipeline.append(T.RandomHorizontalFlip(hflip_prob))
        pipeline += [
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]

        self.transforms = T.Compose(pipeline)

    def __call__(self, img, target):
        """Apply the composed transforms to ``img`` and ``target`` and return the result."""
        composed = self.transforms
        return composed(img, target)
|
||
|
||
|
||
class SegmentationPresetEval:
    """Evaluation-time transform pipeline applied jointly to an image and its target.

    The pipeline is composed, in order, of ``T.RandomResize``, ``T.ToTensor``
    and ``T.Normalize``.

    Args:
        base_size: passed to ``T.RandomResize`` as both of its size arguments.
        mean: per-channel mean forwarded to ``T.Normalize``.
        std: per-channel std forwarded to ``T.Normalize``.
    """

    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Both resize arguments are the same value (base_size).
        steps = [
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Apply the composed transforms to ``img`` and ``target`` and return the result."""
        composed = self.transforms
        return composed(img, target)
|
||
|
||
|
||
def get_transform(train):
    """Return the transform preset for the requested split.

    Args:
        train: truthy selects ``SegmentationPresetTrain`` (base size 520,
            crop size 480); falsy selects ``SegmentationPresetEval``
            (base size 520).

    Returns:
        The constructed preset object.
    """
    base_size, crop_size = 520, 480

    if train:
        return SegmentationPresetTrain(base_size, crop_size)
    return SegmentationPresetEval(base_size)
|
||
|
||
|
||
def create_model(num_classes, pretrain=True):
    """Build the model via ``lraspp_mobilenetv3_large`` and optionally load local weights.

    Args:
        num_classes: forwarded to ``lraspp_mobilenetv3_large`` as ``num_classes``.
        pretrain: when truthy, weights are read from
            ``./lraspp_mobilenet_v3_large.pth`` and loaded non-strictly.

    Returns:
        The constructed model.
    """
    model = lraspp_mobilenetv3_large(num_classes=num_classes)
    if not pretrain:
        return model

    state = torch.load("./lraspp_mobilenet_v3_large.pth", map_location='cpu')

    if num_classes != 21:
        # The officially provided pretrained weights cover 21 classes
        # (background included). When training on a custom dataset, drop the
        # class-dependent weights so mismatched shapes do not raise an error.
        head_keys = [name for name in state
                     if "low_classifier" in name or "high_classifier" in name]
        # Delete in place so the loaded mapping object itself is kept.
        for name in head_keys:
            del state[name]

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print("missing_keys: ", missing)
        print("unexpected_keys: ", unexpected)

    return model
|
||
|
||
|
||
def main(args):
    """Train the segmentation model, validate after each epoch and save checkpoints.

    Per epoch this appends the training loss, learning rate and validation
    summary to a timestamped ``results*.txt`` file and writes
    ``save_weights/model_<epoch>.pth``.

    Args:
        args: parsed options. This function reads ``device``, ``batch_size``,
            ``num_classes``, ``data_path``, ``lr``, ``momentum``,
            ``weight_decay``, ``amp``, ``epochs``, ``resume``, ``start_epoch``
            and ``print_freq``; ``start_epoch`` is overwritten when resuming.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    # segmentation num_classes + background
    num_classes = args.num_classes + 1

    # File that accumulates training and validation information.
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    # VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txt
    train_dataset = VOCSegmentation(args.data_path,
                                    year="2012",
                                    transforms=get_transform(train=True),
                                    txt_name="train.txt")

    # VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txt
    val_dataset = VOCSegmentation(args.data_path,
                                  year="2012",
                                  transforms=get_transform(train=False),
                                  txt_name="val.txt")

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so min() never compares None with an int.
    cpu_total = os.cpu_count() or 1
    num_workers = min([cpu_total, batch_size if batch_size > 1 else 0, 8])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=num_classes)
    model.to(device)

    # Only parameters that require gradients are handed to the optimizer,
    # grouped by backbone and classifier.
    params_to_optimize = [
        {"params": [p for p in model.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model.classifier.parameters() if p.requires_grad]}
    ]

    optimizer = torch.optim.SGD(
        params_to_optimize,
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Create the learning-rate schedule; it is updated once per step
    # (not once per epoch).
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)

    if args.resume:
        # NOTE(review): checkpoints written below also store `args`; recent
        # PyTorch releases changed torch.load's `weights_only` default, so
        # confirm that resuming works on the torch version in use.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        # A checkpoint saved without AMP has no "scaler" entry; in that case
        # keep the freshly created scaler instead of raising KeyError.
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    # Checkpoints are written to this relative directory below; make sure it
    # exists even when main() is called from outside the script entry point.
    os.makedirs("save_weights", exist_ok=True)

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,
                                        lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

        confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)
        val_info = str(confmat)
        print(val_info)
        # write into txt
        with open(results_file, "a") as f:
            # Record this epoch's train_loss, lr and validation metrics.
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n"
            f.write(train_info + val_info + "\n\n")

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()
        torch.save(save_file, "save_weights/model_{}.pth".format(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
|
||
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads the process arguments, as before.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    import argparse

    def _str2bool(value):
        """Convert a command-line token to a bool.

        ``type=bool`` would treat every non-empty string, including
        "False", as True; this maps the usual spellings explicitly.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1", "on"):
            return True
        # The empty string is kept as False, matching what bool("") gave.
        if token in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(description="pytorch lraspp training")

    parser.add_argument("--data-path", default="/data/", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=20, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=30, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.0001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # Mixed precision training parameters.
    # nargs="?" with const=True lets a bare `--amp` enable it, while
    # `--amp True` / `--amp False` keep working and now parse correctly.
    parser.add_argument("--amp", default=False, type=_str2bool,
                        nargs="?", const=True,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args(argv)

    return args
|
||
|
||
|
||
if __name__ == '__main__':
    args = parse_args()

    # Create the checkpoint directory up front. exist_ok=True avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs("./save_weights", exist_ok=True)

    main(args)
|