# ai-station-code/wudingpv/test_pv.py
#
# Repository viewer metadata captured with this file: 143 lines, 5.0 KiB, Python.
# ("Raw | Blame | History" were viewer links, not file content.)
#
# Viewer warning: this file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other
# characters. If this is intentional, the warning can be safely ignored; use
# the viewer's Escape button to reveal them.

import os
# Expose only GPU 0 to this process. This is set before torch is imported
# (further below) so the restriction is in place before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import matplotlib.pyplot as plt  # NOTE(review): not used in the visible part of this file
import os  # NOTE(review): duplicate of the first import; harmless
import sys
# print(sys.path)
# # Directory that contains this script; used later to locate the model
# # weights under models/.
current_dir = os.path.dirname(os.path.abspath(__file__))
print(current_dir)
# # module_path = os.path.join(current_dir, '../my_module')
# # Add the module path to sys.path (disabled). The project packages imported
# # below are presumably importable because the script is run from the
# # wudingpv directory -- confirm.
# sys.path.append("d:\\project\\ai_station\\wudingpv")
from taihuyuan_pv.dataloaders import custom_transforms as tr  # NOTE(review): `tr` unused in the visible code
import torch
import torch.nn.functional as F  # NOTE(review): `F` unused in the visible code
# The PV and roof networks are both named resUnetpamcarb in their own
# packages, so each is imported under a distinguishing alias.
from taihuyuan_pv.mitunet.model.resunet import resUnetpamcarb as pv_resUnetpamcarb
from taihuyuan_roof.manet.model.resunet import resUnetpamcarb as roof_resUnetpamcarb
from predictandeval_util import segmentation
"""util"""
# 测试只有一张512*512的光伏识别
def single_image_pv(testss,model):
path_file = "D:\\project\\ai_station\\wudingpv\\tmp\\7b2ca8e5-e228-498b-ade1-d72515993b40\\jinbei_4-2_71.png"
testss.start_segmentation(path_file,net = model)
# Sample large image that big_image_pv processes when no path is given.
_BIG_PV_IMAGE = "D:\\project\\ai_station\\wudingpv\\tmp\\9c804d27-5076-4356-ac0e-69d933ba2170\\taihuyuan_7-8.png"


# Test: PV recognition on a large image.
def big_image_pv(testss, model, path_file=_BIG_PV_IMAGE, ori_path=None):
    """Split a large image, segment the pieces for PV, and merge the result.

    Args:
        testss: segmentation helper; uses ``cut_big_image``,
            ``folder_segment`` and ``merge_pic``.
        model: network forwarded to ``folder_segment`` as ``net``.
        path_file: large input image. Defaults to the sample image this
            script was originally hard-coded to.
        ori_path: folder holding the split results, which is passed to
            ``folder_segment``. When None it is the ``ori`` folder next to
            ``path_file``, the layout used by every hard-coded path pair in
            this script. On Windows the default reproduces the original
            hard-coded folder exactly.
    """
    # testss.create_tmp_path(current_dir)
    if ori_path is None:
        # Keep image and tile folder in lockstep instead of hard-coding both.
        ori_path = os.path.join(os.path.dirname(path_file), "ori")
    # Split the image.
    testss.cut_big_image(path_file)
    # Segment everything in the folder where the split results are stored.
    testss.folder_segment(ori_path, net=model)
    # Merge the pieces back together.
    testss.merge_pic(path_file)
# Sample image that single_image_roof segments when no path is given.
_SINGLE_ROOF_IMAGE = "D:\\project\\ai_station\\wudingpv\\tmp\\18fb2a4d-7a45-4f2f-9176-11e0e3b3ffee\\jinbei_4-2_71.png"


# Test: roof recognition on a standard-size image.
def single_image_roof(testss, model, path_file=_SINGLE_ROOF_IMAGE):
    """Run roof segmentation on one standard-size image.

    Args:
        testss: segmentation helper; only its ``start_segmentation`` method
            is used.
        model: network forwarded to ``start_segmentation`` as ``net``.
        path_file: image to segment. Defaults to the sample image this
            script was originally hard-coded to, so existing two-argument
            calls behave exactly as before.
    """
    testss.start_segmentation(path_file, net=model)
# Sample large image that big_image_roof processes when no path is given.
_BIG_ROOF_IMAGE = "D:\\project\\ai_station\\wudingpv\\tmp\\99f06853-4788-4608-8884-2a3b7bc33768\\taihuyuan_7-8.png"


# Test: roof recognition on a large image.
# (The original header comment said "large-image PV recognition" -- a
# copy-paste slip; this function drives the roof pipeline.)
def big_image_roof(testss, model, path_file=_BIG_ROOF_IMAGE, ori_path=None):
    """Split a large image, segment the pieces for roofs, and merge the result.

    Args:
        testss: segmentation helper; uses ``cut_big_image``,
            ``folder_segment`` and ``merge_pic``.
        model: network forwarded to ``folder_segment`` as ``net``.
        path_file: large input image. Defaults to the sample image this
            script was originally hard-coded to.
        ori_path: folder holding the split results, which is passed to
            ``folder_segment``. When None it is the ``ori`` folder next to
            ``path_file``, the layout used by every hard-coded path pair in
            this script. On Windows the default reproduces the original
            hard-coded folder exactly.
    """
    # testss.create_tmp_path(current_dir)
    if ori_path is None:
        # Keep image and tile folder in lockstep instead of hard-coding both.
        ori_path = os.path.join(os.path.dirname(path_file), "ori")
    # Split the image.
    testss.cut_big_image(path_file)
    # Segment everything in the folder where the split results are stored.
    testss.folder_segment(ori_path, net=model)
    # Merge the pieces back together.
    testss.merge_pic(path_file)
# Sample large image that big_image_roofpv processes when no path is given.
_BIG_ROOFPV_IMAGE = "D:\\project\\ai_station\\wudingpv\\tmp\\99f06853-4788-4608-8884-2a3b7bc32131\\taihuyuan_7-8.png"


# Test: rooftop PV on a large image (roof model + PV model combined).
def big_image_roofpv(testss, model_pv, model_roof, path_file=_BIG_ROOFPV_IMAGE, ori_path=None):
    """Split a large image, run roof and PV segmentation, and merge everything.

    Pipeline, in order:
    1. ``cut_big_image(path_file)``.
    2. ``folder_segment_all`` on ``ori_path``: first with the roof model
       (``mode='roof'``), then with the PV model (``mode='pv'``).
    3. ``merge_pic_all(path_file)``, which returns a list of paths.
    4. ``merge_binary`` on that list, which returns one path.
    5. ``merge_final([path_file, binary_path])``.

    Args:
        testss: segmentation helper providing the methods above.
        model_pv: PV network, passed as ``net`` for the ``'pv'`` pass.
        model_roof: roof network, passed as ``net`` for the ``'roof'`` pass.
        path_file: large input image. Defaults to the sample image this
            script was originally hard-coded to.
        ori_path: folder holding the split results. When None it is the
            ``ori`` folder next to ``path_file``, the layout used by every
            hard-coded path pair in this script. On Windows the default
            reproduces the original hard-coded folder exactly.
    """
    # testss.create_tmp_path(current_dir)
    if ori_path is None:
        # Keep image and tile folder in lockstep instead of hard-coding both.
        ori_path = os.path.join(os.path.dirname(path_file), "ori")
    # Split the image.
    testss.cut_big_image(path_file)
    # Segment the split results: roofs first, then PV.
    testss.folder_segment_all(ori_path, net=model_roof, mode='roof')
    testss.folder_segment_all(ori_path, net=model_pv, mode='pv')
    # Merge the per-mode results, build the binary overlay, then combine it
    # with the original image.
    path_list = testss.merge_pic_all(path_file)
    binary_path = testss.merge_binary(path_list)
    testss.merge_final([path_file, binary_path])
def _load_model(model, weights_name):
    """Load ``models/<weights_name>`` (next to this script) into *model*.

    Reproduces the original inline sequence exactly:
    1. Read the checkpoint onto the CPU (``map_location``).
    2. Load its ``'net'`` entry as the state dict.
    3. Print the load message.
    4. Switch to eval mode.
    5. Move the model to the GPU.

    Args:
        model: freshly constructed network instance.
        weights_name: checkpoint file name inside the ``models`` folder.

    Returns:
        The same *model* object, ready for inference.
    """
    weights_path = os.path.join(current_dir, 'models', weights_name)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['net'])
    print("权重加载")
    model.eval()
    model.cuda()
    return model


if __name__ == '__main__':
    # PV model: single-image and large-image PV tests.
    model_pv = _load_model(pv_resUnetpamcarb(), 'pv_best.pth')
    testss_pv = segmentation(model_name="pv")
    single_image_pv(testss_pv, model_pv)
    big_image_pv(testss_pv, model_pv)

    # Roof model: large-image roof test (single-image roof test disabled).
    model_roof = _load_model(roof_resUnetpamcarb(), 'roof_best.pth')
    testss_roof = segmentation(model_name="roof")
    # single_image_roof(testss_roof, model_roof)
    big_image_roof(testss_roof, model_roof)

    # Rooftop-PV test (disabled). Both models and the PV segmentation helper
    # are already loaded above, so re-enabling it only needs this call:
    # big_image_roofpv(testss_pv, model_pv, model_roof)