ai-station-code/wudingpv/post_processing/roofpv_generate.py

141 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : roofpv_generate
@Author : qiqq
@create_time : 2023/7/2 16:28
"""
#根据屋顶和光伏结果合成多目标图
from collections import namedtuple #namedtuple创建一个和tuple类似的对象而且对象拥有可访问的属性
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Disable Pillow's decompression-bomb pixel limit so very large images can be
# opened. NOTE(review): this also removes the safeguard against maliciously
# oversized files; only acceptable when every input image is trusted.
Image.MAX_IMAGE_PIXELS = None
# Class table used to build an 8-bit palette: each entry is (name, label id,
# RGB colour). get_putpalette reads only `color`, assigning entry i to palette
# index i, so ids are kept equal to the list position.
# NOTE(review): these colours (1=green, 2=red) differ from PALETTE used by the
# __main__ script (1=red, 2=green); confirm which mapping is intended.
Cls = namedtuple('cls', ['name', 'id', 'color'])
Clss = [
    Cls('c0', 0, (0, 0, 0)),
    Cls('c1', 1, (0, 255, 0)),
    Cls('c2', 2, (255, 0, 0)),
]
def get_putpalette(Clss, color_other=(0, 0, 0)):
    '''
    Build a flat RGB palette list suitable for ``PIL.Image.putpalette``.

    :param Clss: sequence of class entries; the ``color`` (RGB triple) of
        entry ``i`` becomes palette index ``i``.
    :param color_other: RGB triple used for every remaining palette index.
    :return: list ``[r0, g0, b0, r1, g1, b1, ...]``; when ``len(Clss) <= 256``
        it holds exactly 256 colours (768 ints).
    '''
    putpalette = []
    for cls in Clss:
        putpalette += list(cls.color)
    # An 8-bit palette has 256 entries (indices 0..255); pad every index not
    # covered by Clss. (A negative count yields no padding.)
    putpalette += list(color_other) * (256 - len(Clss))
    return putpalette
def Convert_get_color8bit(gt):
    '''
    Wrap a label array as an 8-bit palette ("P" mode) image coloured by Clss.

    :param gt: label array accepted by ``Image.fromarray`` (presumably a 2-D
        array of integer class ids; confirm at call sites).
    :return: ``PIL.Image.Image`` in "P" mode whose palette is
        ``get_putpalette(Clss)``.
    '''
    colour_table = get_putpalette(Clss)
    indexed = Image.fromarray(gt).convert("P")
    indexed.putpalette(colour_table)
    return indexed
# Module-level palette built from the Clss table.
# NOTE(review): not referenced anywhere in the visible code; Convert_get_color8bit
# rebuilds the same palette on every call.
bin_colormap = get_putpalette(Clss)
# RGB colour per fused label produced by roofpvgener:
# 0 background (black), 1 roof without PV (red), 2 PV off any roof (green),
# 3 PV on a roof (blue).
PALETTE = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
# The same colours flattened to the [r0, g0, b0, r1, ...] layout putpalette expects.
palette = [channel for rgb in PALETTE for channel in rgb]
def roofpvgener(roofresult, pvresult, size):
    '''
    Fuse a roof mask and a PV mask into one 4-class label map.

    Both masks are compared element-wise against 0 and 1:

    ====  ====  =====  ===========================
    roof  pv    label  meaning
    ====  ====  =====  ===========================
    0     0     0      background (black)
    1     0     1      roof without PV (red)
    0     1     2      PV not on a roof (green)
    1     1     3      PV on a roof (blue)
    ====  ====  =====  ===========================

    Pixels holding any value other than 0/1 in either mask stay 0.

    :param roofresult: roof segmentation array (0 = background, 1 = roof).
    :param pvresult: PV segmentation array (0 = background, 1 = PV).
    :param size: shape of the output; must match the mask shape.
    :return: float64 ``numpy`` array of shape ``size`` with labels 0..3.
    '''
    pv_off, pv_on = pvresult == 0, pvresult == 1
    roof_off, roof_on = roofresult == 0, roofresult == 1

    fused = np.zeros(size)
    label_rules = (
        (0, pv_off & roof_off),
        (1, pv_off & roof_on),
        (2, pv_on & roof_off),
        (3, pv_on & roof_on),
    )
    for label, where in label_rules:
        fused[where] = label
    return fused
if __name__ == '__main__':
    # ---- configuration: edit these paths before running ----
    model_name = "manet"
    pv_path = "/"  # directory of PV segmentation masks
    roof_path = "/"  # directory of roof segmentation masks
    imagedir = "/"  # directory of the original images (read only when useFusion)
    rnrpv_outdir = "./" + model_name + "/" + model_name + "_output"
    rnrpv_fusindir = "./" + model_name + "/" + model_name + "_fusion"
    useFusion = False

    pv_list = os.listdir(pv_path)
    roof_list = os.listdir(roof_path)
    # Each PV mask is paired with the roof mask of the SAME file name, so the two
    # directories must contain identical names, not merely the same count.
    # An explicit raise (rather than assert) still runs under `python -O`.
    pv_names, roof_names = set(pv_list), set(roof_list)
    if pv_names != roof_names:
        raise ValueError(
            "PV and roof mask directories do not match: only in {!r}: {}; "
            "only in {!r}: {}".format(
                pv_path, sorted(pv_names - roof_names),
                roof_path, sorted(roof_names - pv_names),
            )
        )

    # makedirs also creates the model_name parent; exist_ok avoids the
    # check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(rnrpv_outdir, exist_ok=True)
    if useFusion:
        os.makedirs(rnrpv_fusindir, exist_ok=True)

    # label (0..3) -> RGB lookup table for the blended visualisation.
    palette_lut = np.array(PALETTE, np.uint8)

    for i in tqdm(pv_list):
        # splitext keeps dots inside the stem ("tile.01.png" -> "tile.01").
        imagename = os.path.splitext(i)[0]
        with Image.open(os.path.join(pv_path, i)) as pvimage:
            pv = np.array(pvimage)
        with Image.open(os.path.join(roof_path, i)) as roofimage:
            roof = np.array(roofimage)
        original_h, original_w = pv.shape
        label = np.uint8(roofpvgener(roof, pv, size=(original_h, original_w)))

        result = Image.fromarray(label).convert('P')
        result.putpalette(palette)
        result.save(os.path.join(rnrpv_outdir, imagename + ".png"))

        if useFusion:
            with Image.open(os.path.join(imagedir, imagename + ".png")) as sourceimage:
                # Image.blend requires both images in the same mode; the overlay
                # is RGB, so normalise the source (it may be RGBA, L, P, ...).
                source_rgb = sourceimage.convert("RGB")
            overlay = Image.fromarray(palette_lut[label])
            fusion = Image.blend(source_rgb, overlay, 0.3)
            fusion.save(os.path.join(rnrpv_fusindir, imagename + ".png"))
    print("完成")