ai-station-code/wudingpv/taihuyuan_roof/manet/model/resunet.py

149 lines
4.9 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : resunet
@Author : qiqq
@create_time : 2023/7/20 18:45
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from wudingpv.taihuyuan_roof.manet.model.resnet import resnet50
from wudingpv.taihuyuan_roof.manet.model.decoder import *
class unetUp(nn.Module):
    """One UNet decoder step: upsample the deeper map, concatenate it with the skip map, then Conv-BN-ReLU.

    Args:
        in_size: channel count of the concatenation, i.e. channels(inputs1) + channels(inputs2).
        out_size: number of output channels produced by the 3x3 convolution.
    """

    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        # nn.UpsamplingBilinear2d is deprecated; nn.Upsample with
        # mode='bilinear' and align_corners=True is its exact equivalent.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_size, out_size, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs1, inputs2):
        """Fuse skip features ``inputs1`` with deeper features ``inputs2``.

        Args:
            inputs1: skip-connection tensor (N, C1, H, W).
            inputs2: deeper tensor (N, C2, h, w), nominally h = H / 2 and w = W / 2.

        Returns:
            Tensor of shape (N, out_size, H, W).
        """
        upsampled = self.up(inputs2)
        # If H or W is odd (e.g. the input size is not a multiple of the
        # backbone stride), a plain x2 upsample is off by one pixel and
        # torch.cat would raise. Resample straight to the skip's size instead.
        if upsampled.shape[2:] != inputs1.shape[2:]:
            upsampled = F.interpolate(
                inputs2, size=inputs1.shape[2:], mode='bilinear', align_corners=True
            )
        outputs = torch.cat([inputs1, upsampled], 1)
        return self.cbr(outputs)
class resUnet(nn.Module):
    """ResNet-50 encoder followed by ``unetDecoder``; output is resized to the input's H x W.

    Args:
        num_classes: stored as ``self.nclas``; not passed to the decoder in this class.
        pretrained: forwarded to ``resnet50``.
        backbone: accepted but never read; a ResNet-50 backbone is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        self.nclas = num_classes
        self.finnal_channel = 512
        # Keep backbone registered before decoder so parameter order and
        # state_dict layout stay the same.
        self.backbone = resnet50(pretrained=pretrained)
        # NOTE(review): 3072 presumably = 1024 + 2048 channels of the two
        # deepest ResNet-50 stages concatenated — confirm against unetDecoder.
        self.decoder = unetDecoder(in_filters=[512, 1024, 3072], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Run backbone and decoder, then bilinearly resize the result to ``inputs``' spatial size."""
        # Original note: backbone outputs are at strides 2, 4, 8, 16, 32;
        # the first (stride-2) entry is skipped.
        stage_maps = self.backbone(inputs)
        decoded = self.decoder(stage_maps[1:])
        return F.interpolate(decoded, size=inputs.size()[2:], mode='bilinear', align_corners=True)
class resUnet2(nn.Module):
    """ResNet-50 + ``unetDecoder`` variant whose deepest decoder input has 1536 channels.

    Args:
        num_classes: stored as ``self.nclas``; not passed to the decoder in this class.
        pretrained: forwarded to ``resnet50``.
        backbone: accepted but never read; a ResNet-50 backbone is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        # Original note (translated): the 3x3 convolution from the CARB
        # ablation reduces the channel count in the middle, hence 1536 here
        # instead of the 3072 used by resUnet.
        self.decoder = unetDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Encode, decode, and upsample the decoder output back to the input resolution."""
        target_hw = inputs.shape[2:]
        # Backbone outputs (original note: strides 2 4 8 16 32); drop the first.
        pyramid = self.backbone(inputs)[1:]
        result = self.decoder(pyramid)
        return F.interpolate(result, size=target_hw, mode='bilinear', align_corners=True)
class resUnetcarb(nn.Module):
    """ResNet-50 encoder with ``unetCARBDecoder`` (built with its own default arguments).

    Args:
        num_classes: stored as ``self.nclas``; not passed to the decoder in this class.
        pretrained: forwarded to ``resnet50``.
        backbone: accepted but never read; a ResNet-50 backbone is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        # Original reference note (not used here):
        # in_filters = [192, 384, 768], out_filters = [64, 128, 256]
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetCARBDecoder()

    def forward(self, inputs):
        """Return the decoder output bilinearly resized to ``inputs``' H x W."""
        # The backbone's first feature map (stride 2 per the original note) is skipped.
        return F.interpolate(
            self.decoder(self.backbone(inputs)[1:]),
            size=inputs.size()[2:],
            mode='bilinear',
            align_corners=True,
        )
class resUnetpamcarb(nn.Module):
    """ResNet-50 encoder with ``unetpamCARBDecoder`` (built with its own default arguments).

    Args:
        num_classes: stored as ``self.nclas``; not passed to the decoder in this class.
        pretrained: forwarded to ``resnet50``.
        backbone: accepted but never read; a ResNet-50 backbone is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        # Original reference note (not used here):
        # in_filters = [192, 384, 768], out_filters = [64, 128, 256]
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetpamCARBDecoder()

    def forward(self, inputs):
        """Encode ``inputs``, decode the deeper feature maps, and resize to the input resolution."""
        all_feats = self.backbone(inputs)  # original note: strides 2 4 8 16 32
        deep_feats = all_feats[1:]
        fused = self.decoder(deep_feats)
        out_size = inputs.size()[2:]
        return F.interpolate(fused, size=out_size, mode='bilinear', align_corners=True)
class resUnetPAM(nn.Module):
    """ResNet-50 encoder with ``unetpamDecoder``; output is resized to the input's H x W.

    Unlike its sibling classes, this one does not set ``finnal_channel``.

    Args:
        num_classes: stored as ``self.nclas``; not passed to the decoder in this class.
        pretrained: forwarded to ``resnet50``.
        backbone: accepted but never read; a ResNet-50 backbone is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetpamDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Run backbone + decoder and bilinearly upsample to the spatial size of ``inputs``."""
        spatial = inputs.shape[2:]
        # Skip the first backbone output (stride 2 per the original note).
        decoder_in = self.backbone(inputs)[1:]
        prediction = self.decoder(decoder_in)
        return F.interpolate(prediction, size=spatial, mode='bilinear', align_corners=True)
# class resUnetpamcarb(nn.Module):
# def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
# super(resUnetpamcarb, self).__init__()
#
# self.nclas = num_classes
# self.finnal_channel = 512
# self.backbone = resnet50(pretrained=pretrained)
# self.decoder = unetpamCARBDecoder()
#
# def forward(self, inputs):
#
# feaureslist = self.backbone(inputs) #2 4 8 16 32
# feaureslist=feaureslist[1:]
# out = self.decoder(feaureslist)
# out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
# return out
if __name__ == '__main__':
    # Smoke test: push a random batch through resUnet and report the output type and shape.
    # NOTE(review): pretrained defaults to True and is forwarded to resnet50,
    # which may need to load/download weights — confirm before running offline.
    dummy_input = torch.rand(2, 3, 512, 512)
    model = resUnet()
    model.eval()
    # Inference only: skip autograd bookkeeping so activations are not retained.
    with torch.no_grad():
        out = model(dummy_input)
    print(type(out))
    print(out.shape)