93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
|
import random
|
|||
|
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
#---------------------------------------------------------#
|
|||
|
# 将图像转换成RGB图像,防止灰度图在预测时报错。
|
|||
|
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
|||
|
#---------------------------------------------------------#
|
|||
|
def cvtColor(image):
    """Return ``image`` as an RGB image so grayscale/other inputs don't break prediction.

    The prediction code only supports RGB images, so every other image type is
    converted to RGB.

    For PIL images the decision is made from ``image.mode``. The previous
    shape-only check let 3-channel non-RGB modes (e.g. ``'YCbCr'``, ``'LAB'``,
    ``'HSV'``) through unconverted. It also materialised the full pixel array
    just to inspect its shape.

    Args:
        image: a PIL image, or an array-like (e.g. ``np.ndarray``).

    Returns:
        ``image`` itself if it is already RGB (or an H x W x 3 array), otherwise
        ``image.convert('RGB')``.
    """
    mode = getattr(image, "mode", None)
    if mode is not None:
        # PIL image: only a true RGB image is returned untouched.
        return image if mode == "RGB" else image.convert("RGB")

    # Non-PIL input (e.g. an ndarray): keep the original shape-based check.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
|||
|
|
|||
|
#---------------------------------------------------#
|
|||
|
# 对输入图像进行resize
|
|||
|
#---------------------------------------------------#
|
|||
|
def resize_image(image, size):
    """Letterbox-resize a PIL image onto a gray canvas of the requested size.

    The image is scaled by a single factor so that it fits inside ``size`` while
    keeping its aspect ratio. It is resized with bicubic interpolation and
    pasted centred on a new RGB canvas filled with ``(128, 128, 128)``.

    Args:
        image: PIL image to resize.
        size: target ``(width, height)`` of the returned canvas.

    Returns:
        ``(new_image, nw, nh)``: the padded RGB canvas, plus the width and
        height of the resized content placed inside it.
    """
    iw, ih = image.size
    w, h = size

    scale_w = w / iw
    scale_h = h / ih
    scale = min(scale_w, scale_h)
    # Snap the limiting side(s) to the target exactly. int(iw * (w / iw)) can
    # truncate to w - 1 through float rounding, which would leave a stray
    # 1-pixel gray stripe. Non-limiting sides are computed as before.
    nw = w if scale_w == scale else int(iw * scale)
    nh = h if scale_h == scale else int(ih * scale)

    image = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    # Centre the resized content; the remaining border stays gray.
    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))

    # NOTE(original author): "colours will change" -- presumably due to the
    # bicubic resampling / gray padding; TODO confirm.
    return new_image, nw, nh
|
|||
|
|
|||
|
#---------------------------------------------------#
|
|||
|
# 获得学习率
|
|||
|
#---------------------------------------------------#
|
|||
|
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only the first entry of ``optimizer.param_groups`` is consulted. If there
    are no parameter groups at all, ``None`` is returned.
    """
    missing = object()
    first_group = next(iter(optimizer.param_groups), missing)
    if first_group is missing:
        return None
    return first_group['lr']
|
|||
|
|
|||
|
#---------------------------------------------------#
|
|||
|
# 设置种子
|
|||
|
#---------------------------------------------------#
|
|||
|
def seed_everything(seed=11):
    """Seed all random number generators and force deterministic cuDNN.

    The same ``seed`` is applied to Python's ``random`` module, NumPy's global
    RNG, torch's CPU generator and torch's CUDA generators. cuDNN is then set
    to deterministic mode with benchmark autotuning disabled, which trades
    some speed for reproducibility.

    Args:
        seed: integer seed shared by every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|||
|
|
|||
|
#---------------------------------------------------#
|
|||
|
# 设置Dataloader的种子
|
|||
|
#---------------------------------------------------#
|
|||
|
def worker_init_fn(worker_id, rank, seed):
    """Seed ``random``, NumPy and torch inside a DataLoader worker process.

    DataLoader calls ``worker_init_fn(worker_id)``. ``rank`` and ``seed`` are
    presumably bound beforehand, e.g. with ``functools.partial``; confirm
    against the caller.

    The previous version used ``rank + seed`` and ignored ``worker_id``. As a
    result, every worker on a rank drew the same random (augmentation)
    sequence. Each (rank, worker) pair now gets a distinct seed:
    ``seed + rank * num_workers + worker_id``.

    Worker 0 on rank 0 keeps the old seed. When called outside a worker
    process, ``num_workers`` falls back to 1.

    Args:
        worker_id: index of the worker within this rank's DataLoader.
        rank: process rank (0 for single-process training).
        seed: base seed.
    """
    from torch.utils.data import get_worker_info

    info = get_worker_info()
    num_workers = info.num_workers if info is not None else 1
    worker_seed = seed + rank * num_workers + worker_id

    random.seed(worker_seed)
    # np.random.seed only accepts values in [0, 2**32).
    np.random.seed(worker_seed % 2**32)
    torch.manual_seed(worker_seed)
|
|||
|
|
|||
|
def preprocess_input(image):
    """Normalise ``image`` per channel, in place, and return it.

    The mean and std are applied along the last axis, so ``image`` must
    broadcast against a length-3 vector. It must also support in-place
    float arithmetic (e.g. a float32 ndarray); an integer array cannot be
    modified in place here.

    The constants equal (0.485, 0.456, 0.406) * 255 and
    (0.229, 0.224, 0.225) * 255. These are the commonly used ImageNet
    statistics, expressed on a 0-255 pixel scale.
    """
    channel_mean = np.array([123.675, 116.28, 103.53], np.float32)
    channel_std = np.array([58.395, 57.12, 57.375], np.float32)
    image -= channel_mean
    image /= channel_std
    return image
|
|||
|
|
|||
|
def show_config(**kwargs):
    """Print the keyword arguments as a two-column table, 70 characters wide.

    Keys are right-aligned in a 25-character column and values in a
    40-character column. Both are rendered with ``str()``.
    """
    rule = '-' * 70
    row = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row % ('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row % (str(name), str(setting)))
    print(rule)
|
|||
|
|
|||
|
def download_weights(phi, model_dir="./model_data"):
    """Download the pretrained SegFormer backbone weights for variant ``phi``.

    Args:
        phi: backbone variant key, one of ``'b0'`` to ``'b5'``.
        model_dir: directory the checkpoint is saved into. It is created if
            missing.

    Raises:
        KeyError: if ``phi`` is not a known variant. This is checked before
            anything is created on disk.
    """
    import os

    from torch.hub import load_state_dict_from_url

    download_urls = {
        'b0' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b0_backbone_weights.pth",
        'b1' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b1_backbone_weights.pth",
        'b2' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b2_backbone_weights.pth",
        'b3' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b3_backbone_weights.pth",
        'b4' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b4_backbone_weights.pth",
        'b5' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b5_backbone_weights.pth",
    }
    try:
        url = download_urls[phi]
    except KeyError:
        # Same exception type as before, but name the valid choices.
        raise KeyError(
            "unknown backbone variant %r; expected one of %s"
            % (phi, sorted(download_urls))
        ) from None

    # exist_ok avoids the race between a separate exists() check and
    # makedirs() when several processes call this concurrently.
    os.makedirs(model_dir, exist_ok=True)
    # Saves the checkpoint into model_dir (torch.hub reuses an already
    # downloaded file); the returned state dict is not used here.
    load_state_dict_from_url(url, model_dir)
|