renewable_eva/utils/utils.py

64 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from PIL import Image
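#---------------------------------------------------------#
#   Convert the image to RGB so that grayscale images do not
#   raise errors during prediction.
#   The code only supports prediction on RGB images; all other
#   image types are converted to RGB.
#---------------------------------------------------------#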
def cvtColor(image):
    """Return ``image`` as-is when it already has three channels, else convert it.

    The input is left untouched when ``np.shape(image)`` has three
    dimensions and the last one equals 3. In every other case the result
    of ``image.convert('RGB')`` is returned (presumably ``image`` is a PIL
    image in that case -- confirm against callers).
    """
    dims = np.shape(image)
    already_three_channel = len(dims) == 3 and dims[2] == 3
    if not already_three_channel:
        image = image.convert('RGB')
    return image
#---------------------------------------------------#
#   Resize the input image.
#---------------------------------------------------#
def resize_image(image, size):
    """Scale ``image`` to fit inside ``size`` and centre it on a grey canvas.

    Args:
        image: Object exposing ``.size`` as ``(width, height)`` and a
            ``.resize`` method (a PIL image, judging by the calls made).
        size: Target ``(width, height)`` of the returned canvas.

    Returns:
        A tuple ``(canvas, nw, nh)`` where ``canvas`` is a new RGB image of
        dimensions ``size`` filled with ``(128, 128, 128)`` and holding the
        resized input in its centre, and ``nw``/``nh`` are the width and
        height of the resized input.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # One common factor for both axes keeps the aspect ratio; taking the
    # smaller ratio guarantees the result fits within the target box.
    ratio = min(dst_w / src_w, dst_h / src_h)
    nw, nh = int(src_w * ratio), int(src_h * ratio)

    scaled = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    # Split the leftover space evenly so the scaled image sits in the middle.
    top_left = ((dst_w - nw) // 2, (dst_h - nh) // 2)
    canvas.paste(scaled, top_left)
    return canvas, nw, nh
#---------------------------------------------------#
#   Get the learning rate.
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` value of the first entry in ``optimizer.param_groups``.

    Only the first parameter group is consulted. ``None`` is returned when
    ``param_groups`` is empty.
    """
    rates = (group['lr'] for group in optimizer.param_groups)
    return next(rates, None)
def preprocess_input(image):
    """Scale pixel values by ``1 / 255``.

    Args:
        image: Array of pixel values. Floating-point arrays are divided in
            place, so the caller's array is modified and the same object is
            returned. Integer and boolean NumPy arrays are first copied to
            ``float32``; the caller's array is left unchanged in that case.

    Returns:
        The scaled array.
    """
    # NumPy refuses in-place true division on integer/bool buffers because
    # the float result cannot be stored back under its casting rules, so
    # promote those inputs (e.g. a raw uint8 image) to float first.
    if isinstance(image, np.ndarray) and image.dtype.kind in 'biu':
        image = image.astype(np.float32)
    image /= 255.0
    return image
def show_config(**kwargs):
    """Print the keyword arguments as a two-column ``keys``/``values`` table.

    Each key and value is passed through ``str`` and right-aligned in a
    column of width 25 and 40 respectively. Rows appear in the order the
    keyword arguments were given. Output goes to standard output via
    ``print``; nothing is returned.
    """
    rule = '-' * 70
    row_template = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row_template % ('keys', 'values'))
    print(rule)
    for name in kwargs:
        print(row_template % (str(name), str(kwargs[name])))
    print(rule)
def download_weights(backbone, model_dir="./model_data"):
    """Download the pretrained weights for ``backbone`` into ``model_dir``.

    Args:
        backbone: Key selecting the weight file; ``'mobilenet'`` or
            ``'xception'``.
        model_dir: Directory passed to
            ``torch.hub.load_state_dict_from_url`` as the download
            location. It is created if it does not exist yet.

    Raises:
        KeyError: If ``backbone`` is not one of the supported keys.
    """
    download_urls = {
        'mobilenet' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar',
        'xception' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth',
    }

    # Validate before importing torch or touching the filesystem so that a
    # misspelled backbone fails fast with a message listing the valid keys.
    if backbone not in download_urls:
        raise KeyError(
            'Unsupported backbone %r; expected one of %s'
            % (backbone, sorted(download_urls))
        )

    # Imported lazily so that loading this module does not require torch.
    import os
    from torch.hub import load_state_dict_from_url

    # exist_ok=True replaces the separate existence check, which could race
    # with another process creating the same directory in between.
    os.makedirs(model_dir, exist_ok=True)
    load_state_dict_from_url(download_urls[backbone], model_dir)