ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/resunet.py

188 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : resunet
@Author : qiqq
@create_time : 2023/7/20 18:45
"""
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from taihuyuan_pv.mitunet.model.resnet import resnet50
from taihuyuan_pv.mitunet.model.decoder import *
class unetUp(nn.Module):
    """UNet decoder stage: upsample the deeper map 2x, concatenate, then Conv-BN-ReLU.

    Args:
        in_size: channel count of ``cat([inputs1, up(inputs2)], dim=1)``, i.e. the
            sum of the channel counts of ``inputs1`` and ``inputs2``.
        out_size: channel count produced by the 3x3 convolution.
    """

    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        # nn.UpsamplingBilinear2d is deprecated; nn.Upsample with mode='bilinear'
        # and align_corners=True is its documented, numerically identical replacement.
        # The module has no parameters, so the state_dict is unchanged.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.cbr = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs1, inputs2):
        """Fuse skip features ``inputs1`` with 2x-upsampled deeper features ``inputs2``.

        Both inputs are 4-D (N, C, H, W) tensors. ``torch.cat`` on dim 1 requires
        the upsampled ``inputs2`` to match ``inputs1`` spatially, so ``inputs2``
        must be exactly half of ``inputs1`` in height and width.
        """
        upsampled = self.up(inputs2)
        fused = torch.cat([inputs1, upsampled], dim=1)
        return self.cbr(fused)
class resUnet(nn.Module):
    """ResNet-50 encoder followed by a ``unetDecoder`` segmentation head.

    Args:
        num_classes: stored as ``self.nclas``. NOTE(review): it is not passed to
            the decoder here, so the output channel count is whatever
            ``unetDecoder`` produces -- confirm it matches ``num_classes``.
        pretrained: forwarded to ``resnet50``.
        backbone: only ``'resnet50'`` is supported. Any other value emits a
            ``UserWarning`` and resnet50 is still used.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super(resUnet, self).__init__()
        if backbone != 'resnet50':
            # The encoder is hard-wired below; surface a mismatched argument
            # instead of silently ignoring it.
            warnings.warn(
                f"resUnet only supports backbone='resnet50'; got {backbone!r}, "
                "using resnet50 instead.",
                stacklevel=2,
            )
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        # Original author's note (translated), presumably for the commented-out
        # config: an extra 2048->512 transition layer on the last stage, with a
        # decoder using two 3x3 CBR blocks.
        # self.decoder = unetDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])
        self.decoder = unetDecoder(in_filters=[512, 1024, 3072], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Return decoder output bilinearly resized to the input's spatial size."""
        # Per the original note the backbone returns features at strides
        # 2, 4, 8, 16, 32; the first (stride-2) map is not fed to the decoder.
        features = self.backbone(inputs)[1:]
        out = self.decoder(features)
        return F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
#
#
class resUnetPAM(nn.Module):
    """ResNet-50 encoder followed by a ``unetpamDecoder`` segmentation head.

    Args:
        num_classes: stored as ``self.nclas``. NOTE(review): it is not passed to
            the decoder here -- confirm the decoder's output channels match.
        pretrained: forwarded to ``resnet50``.
        backbone: only ``'resnet50'`` is supported. Any other value emits a
            ``UserWarning`` and resnet50 is still used.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super(resUnetPAM, self).__init__()
        if backbone != 'resnet50':
            # The encoder is hard-wired below; surface a mismatched argument.
            warnings.warn(
                f"resUnetPAM only supports backbone='resnet50'; got {backbone!r}, "
                "using resnet50 instead.",
                stacklevel=2,
            )
        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetpamDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Return decoder output bilinearly resized to the input's spatial size."""
        # Per the original note the backbone returns features at strides
        # 2, 4, 8, 16, 32; the first (stride-2) map is not fed to the decoder.
        features = self.backbone(inputs)[1:]
        out = self.decoder(features)
        return F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
class resUnetpamcarb(nn.Module):
    """ResNet-50 encoder followed by a ``unetpamCARBDecoder`` (default config).

    Args:
        num_classes: stored as ``self.nclas``. NOTE(review): it is not passed to
            the decoder here -- confirm the decoder's output channels match.
        pretrained: forwarded to ``resnet50``.
        backbone: only ``'resnet50'`` is supported. Any other value emits a
            ``UserWarning`` and resnet50 is still used.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super(resUnetpamcarb, self).__init__()
        if backbone != 'resnet50':
            # The encoder is hard-wired below; surface a mismatched argument.
            warnings.warn(
                f"resUnetpamcarb only supports backbone='resnet50'; got {backbone!r}, "
                "using resnet50 instead.",
                stacklevel=2,
            )
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetpamCARBDecoder()

    def forward(self, inputs):
        """Return decoder output bilinearly resized to the input's spatial size."""
        # Per the original note the backbone returns features at strides
        # 2, 4, 8, 16, 32; the first (stride-2) map is not fed to the decoder.
        features = self.backbone(inputs)[1:]
        out = self.decoder(features)
        return F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
class resUnetaspp(nn.Module):
    """ResNet-50 encoder followed by a ``unetasppDecoder`` segmentation head.

    Args:
        num_classes: stored as ``self.nclas``. NOTE(review): it is not passed to
            the decoder here -- confirm the decoder's output channels match.
        pretrained: forwarded to ``resnet50``.
        backbone: only ``'resnet50'`` is supported. Any other value emits a
            ``UserWarning`` and resnet50 is still used.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super(resUnetaspp, self).__init__()
        if backbone != 'resnet50':
            # The encoder is hard-wired below; surface a mismatched argument.
            warnings.warn(
                f"resUnetaspp only supports backbone='resnet50'; got {backbone!r}, "
                "using resnet50 instead.",
                stacklevel=2,
            )
        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)
        # Original author's notes (translated; previously a no-op string literal):
        #   The ASPP output has 256 channels.
        #   Encoder channels 256, 512, 1024, 2048 (-> 256 after ASPP);
        #   a 1x1 conv maps them to 128, 256, 512.
        #   in_filters=[448, 640, 768], out_filters=[128, 320, 384]
        self.decoder = unetasppDecoder(in_filters=[448, 640, 768], out_filters=[128, 320, 384])

    def forward(self, inputs):
        """Return decoder output bilinearly resized to the input's spatial size."""
        # Per the original note the backbone returns features at strides
        # 2, 4, 8, 16, 32; the first (stride-2) map is not fed to the decoder.
        features = self.backbone(inputs)[1:]
        out = self.decoder(features)
        return F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
# class resUnetpamcam(nn.Module):
# def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
# super(resUnetpamcam, self).__init__()
#
# self.nclas = num_classes
# self.finnal_channel = 512
# self.backbone = resnet50(pretrained=pretrained)
# self.decoder = unetpamcamDecoder()
#
# def forward(self, inputs):
#
# feaureslist = self.backbone(inputs) #2 4 8 16 32
# feaureslist=feaureslist[1:]
# out = self.decoder(feaureslist)
# out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
# return out
#
#
#
# # Feature scales 1/4, 1/8, 1/16, 1/32 are labelled 1, 2, 3, 4
# class resUnetpamcarb_4(nn.Module):
# def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
# super(resUnetpamcarb_4, self).__init__()
#
# self.nclas = num_classes
# self.finnal_channel = 512
# self.backbone = resnet50(pretrained=pretrained)
# self.decoder = unetpamDecoderzuhe()
#
# def forward(self, inputs):
#
# feaureslist = self.backbone(inputs) #2 4 8 16 32
# feaureslist=feaureslist[1:]
# out = self.decoder(feaureslist)
# out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
# return out
#
################# resUnet, heavily re-tuned version 2 (last-ditch attempt) #################
class resUnet_version2(nn.Module):
    """ResNet-50 encoder followed by ``unetDecoder2`` (re-tuned channel widths).

    Args:
        num_classes: stored as ``self.nclas``. NOTE(review): it is not passed to
            the decoder here -- confirm the decoder's output channels match.
        pretrained: forwarded to ``resnet50``.
        backbone: only ``'resnet50'`` is supported. Any other value emits a
            ``UserWarning`` and resnet50 is still used.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super(resUnet_version2, self).__init__()
        if backbone != 'resnet50':
            # The encoder is hard-wired below; surface a mismatched argument.
            warnings.warn(
                f"resUnet_version2 only supports backbone='resnet50'; got {backbone!r}, "
                "using resnet50 instead.",
                stacklevel=2,
            )
        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetDecoder2(in_filters=[384, 768, 1536], out_filters=[128, 256, 512])

    def forward(self, inputs):
        """Return decoder output bilinearly resized to the input's spatial size."""
        # Per the original note the backbone returns features at strides
        # 2, 4, 8, 16, 32; the first (stride-2) map is not fed to the decoder.
        features = self.backbone(inputs)[1:]
        out = self.decoder(features)
        return F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
#
#############################################
if __name__ == '__main__':
    # Smoke test: one forward pass on random data, then report output type and shape.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.rand(2, 3, 512, 512).to(device)
    model = resUnet_version2().to(device)
    # Only the shape is inspected, so skip autograd graph construction to save memory.
    with torch.no_grad():
        out = model(dummy_input)
    print(type(out))
    print(out.shape)