import random

import numpy as np
import torch
from PIL import Image


#---------------------------------------------------------#
#   Convert an image to RGB so that grayscale (or other-mode)
#   images do not cause errors at prediction time.
#   Prediction only supports RGB images; every other image
#   type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return *image* as an RGB image.

    If ``np.shape(image)`` reports three dimensions with 3 entries in the
    last axis, *image* is returned unchanged. Otherwise ``image.convert('RGB')``
    is called, which requires *image* to be a PIL image.
    """
    # For objects without a ``.shape`` attribute (e.g. PIL images),
    # np.shape() converts the input to an array just to read its shape.
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image


#---------------------------------------------------#
#   Resize the input image (letterbox)
#---------------------------------------------------#
def resize_image(image, size):
    """Letterbox a PIL *image* into a canvas of ``size = (w, h)``.

    The image is scaled by ``min(w / iw, h / ih)`` (aspect ratio preserved)
    with bicubic resampling. It is then pasted centred onto a new RGB canvas
    filled with grey (128, 128, 128).

    Returns:
        (new_image, nw, nh): the padded canvas plus the width and height of
        the scaled image content inside it.
    """
    iw, ih = image.size
    w, h = size

    scale = min(w / iw, h / ih)
    nw = int(iw * scale)
    nh = int(ih * scale)

    image = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled content; the remaining border stays grey.
    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))

    return new_image, nw, nh
    # NOTE(original author): "颜色会变" ("the colours will change") —
    # intent unclear from this file; TODO confirm what it refers to.


#---------------------------------------------------#
#   Get the learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns None if ``optimizer.param_groups`` is empty.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']


#---------------------------------------------------#
#   Set random seeds
#---------------------------------------------------#
def seed_everything(seed=11):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    trading speed for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


#---------------------------------------------------#
#   Set seeds for DataLoader workers
#---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):
    """Seed Python, NumPy and PyTorch RNGs inside a DataLoader worker.

    Intended to be bound with ``functools.partial(worker_init_fn, rank=...,
    seed=...)`` and passed as ``DataLoader(worker_init_fn=...)``. Each
    (rank, worker) pair receives a distinct seed, so workers do not replay
    identical random augmentations.

    Args:
        worker_id: index of the DataLoader worker (supplied by PyTorch).
        rank: process rank in distributed training (0 when single-process).
        seed: base seed shared by all processes.
    """
    from torch.utils.data import get_worker_info

    # get_worker_info() is None when called outside a worker process.
    info = get_worker_info()
    num_workers = info.num_workers if info is not None else 1
    # Offset by worker_id so workers differ. Scaling rank by num_workers keeps
    # (rank, worker) pairs from colliding across processes. For a single
    # worker this reduces to the previous ``seed + rank``.
    worker_seed = (seed + rank * num_workers + worker_id) % 2**32  # np.random.seed needs < 2**32
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


def preprocess_input(image):
    """Normalise *image* in place with per-channel mean/std and return it.

    Subtracts (123.675, 116.28, 103.53) and divides by (58.395, 57.12, 57.375)
    along the last axis. Because the ops are in-place, *image* must be a float
    array; it is mutated and also returned.
    """
    image -= np.array([123.675, 116.28, 103.53], np.float32)
    image /= np.array([58.395, 57.12, 57.375], np.float32)
    return image


def show_config(**kwargs):
    """Print the given keyword arguments as a two-column key/value table."""
    print('Configurations:')
    print('-' * 70)
    print('|%25s | %40s|' % ('keys', 'values'))
    print('-' * 70)
    for key, value in kwargs.items():
        print('|%25s | %40s|' % (str(key), str(value)))
    print('-' * 70)


def download_weights(phi, model_dir="./model_data"):
    """Download the SegFormer backbone weights for variant *phi* into *model_dir*.

    Args:
        phi: backbone variant, one of ``'b0'`` … ``'b5'``.
        model_dir: directory to download into; created if missing.

    Raises:
        KeyError: if *phi* is not a known variant.
    """
    import os
    from torch.hub import load_state_dict_from_url

    download_urls = {
        'b0' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b0_backbone_weights.pth",
        'b1' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b1_backbone_weights.pth",
        'b2' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b2_backbone_weights.pth",
        'b3' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b3_backbone_weights.pth",
        'b4' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b4_backbone_weights.pth",
        'b5' : "https://github.com/bubbliiiing/segformer-pytorch/releases/download/v1.0/segformer_b5_backbone_weights.pth",
    }
    try:
        url = download_urls[phi]
    except KeyError:
        raise KeyError(
            "Unknown phi %r; expected one of %s" % (phi, sorted(download_urls))
        ) from None

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(model_dir, exist_ok=True)
    # Only the download side effect is needed; load on CPU so this also
    # works on machines without the device the weights were saved from.
    load_state_dict_from_url(url, model_dir, map_location="cpu")