renewable_eva/utils/utils.py

64 lines
2.1 KiB
Python
Raw Normal View History

2024-06-18 10:59:34 +08:00
import numpy as np
from PIL import Image
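#---------------------------------------------------------#
# Convert the image to RGB so that grayscale images do not
# raise errors during prediction.
# The code only supports prediction on RGB images; every
# other image type is converted to RGB.
#---------------------------------------------------------#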
def cvtColor(image):
    """Return ``image`` as a three-channel RGB image.

    The input is returned untouched when ``np.shape`` reports exactly three
    dimensions and the last one has length 3. Anything else is passed through
    ``image.convert('RGB')``, so it must provide a ``convert`` method.
    """
    dims = np.shape(image)
    is_three_channel = len(dims) == 3 and dims[2] == 3
    return image if is_three_channel else image.convert('RGB')
#---------------------------------------------------#
# Resize the input image
#---------------------------------------------------#
def resize_image(image, size):
    """Fit ``image`` inside ``size`` without distortion and pad the remainder.

    The image is scaled by the smaller of the width and height ratios, so the
    aspect ratio is kept. It is then pasted centred onto a new RGB canvas of
    ``size`` filled with (128, 128, 128).

    Args:
        image: object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (presumably a PIL image).
        size: target ``(width, height)``.

    Returns:
        A tuple ``(padded_image, scaled_width, scaled_height)``.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # The smaller ratio guarantees the scaled image fits in both dimensions.
    ratio = min(dst_w / src_w, dst_h / src_h)
    nw, nh = int(src_w * ratio), int(src_h * ratio)

    scaled = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    offset = ((dst_w - nw) // 2, (dst_h - nh) // 2)
    canvas.paste(scaled, offset)
    return canvas, nw, nh
#---------------------------------------------------#
# Get the learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Only the first element of ``optimizer.param_groups`` is consulted. If
    there are no groups, ``None`` is returned.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']
def preprocess_input(image):
    """Scale pixel values by 1/255 and return the result.

    The division is done out of place. The previous in-place ``/=`` modified
    the caller's array and raised a casting error for integer arrays (for
    example uint8 image data), because a float result cannot be written back
    into an integer buffer.

    Args:
        image: array-like supporting true division by a float (for example a
            NumPy array or a scalar).

    Returns:
        A new object holding ``image / 255.0``; the input is left unchanged.
    """
    return image / 255.0
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column text table.

    The table has a 70-character dashed rule above and below a
    ``keys``/``values`` header, and a closing rule after the last row. Keys
    are right-aligned in 25 columns and values in 40, both passed through
    ``str``.
    """
    rule = '-' * 70
    row_template = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row_template % ('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row_template % (str(name), str(setting)))
    print(rule)
def download_weights(backbone, model_dir="./model_data"):
    """Download the pretrained weights for ``backbone`` into ``model_dir``.

    Args:
        backbone: key selecting the weights file; ``'mobilenet'`` or
            ``'xception'``.
        model_dir: directory that receives the download. It is created if it
            does not exist.

    Raises:
        KeyError: if ``backbone`` is not one of the supported keys.
    """
    import os

    download_urls = {
        'mobilenet' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar',
        'xception'  : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth',
    }

    # Validate the key before the heavy torch import so a bad name fails fast.
    try:
        url = download_urls[backbone]
    except KeyError:
        raise KeyError(
            "Unsupported backbone %r; expected one of %s"
            % (backbone, sorted(download_urls))
        ) from None

    from torch.hub import load_state_dict_from_url

    # exist_ok avoids the race between a separate exists() check and creation.
    os.makedirs(model_dir, exist_ok=True)
    load_state_dict_from_url(url, model_dir)