188 lines
6.1 KiB
Python
188 lines
6.1 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
@project:
|
|||
|
@File : resunet
|
|||
|
@Author : qiqq
|
|||
|
@create_time : 2023/7/20 18:45
|
|||
|
"""
|
|||
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
|
|||
|
from taihuyuan_pv.mitunet.model.resnet import resnet50
|
|||
|
from taihuyuan_pv.mitunet.model.decoder import *
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class unetUp(nn.Module):
    """Decoder stage: upsample one tensor, concatenate it with another, then fuse.

    The fusion is a single 3x3 convolution (no bias) followed by batch
    normalisation and an in-place ReLU.

    Args:
        in_size: Channel count of the concatenated tensor fed to the convolution.
        out_size: Channel count produced by this stage.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        # Bilinear x2 upsampling applied to the second forward() argument.
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # Conv -> BN -> ReLU ("cbr"); the submodule order fixes the state_dict keys.
        fuse_layers = [
            nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        self.cbr = nn.Sequential(*fuse_layers)

    def forward(self, inputs1, inputs2):
        """Fuse ``inputs1`` with the x2-upsampled ``inputs2``.

        Args:
            inputs1: Tensor concatenated as-is along dim 1.
            inputs2: Tensor upsampled by a factor of 2 before concatenation.

        Returns:
            The result of the conv-BN-ReLU stack applied to the concatenation.
        """
        upsampled = self.up(inputs2)
        merged = torch.cat((inputs1, upsampled), dim=1)
        return self.cbr(merged)
|
|||
|
|
|||
|
|
|||
|
class resUnet(nn.Module):
    """Segmentation network: ``resnet50`` encoder followed by ``unetDecoder``.

    Args:
        num_classes: Stored on ``self.nclas``; it is not forwarded to the
            decoder in this class.
        pretrained: Passed straight to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility.
            NOTE(review): the value is never read; ``resnet50`` is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()

        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)

        # Earlier variant noted by the author: a transition layer (2048 -> 512) was
        # added on the last stage and the decoder used two 3x3 "deep" CBR blocks:
        #   self.decoder = unetDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])
        decoder_in = [512, 1024, 3072]
        decoder_out = [128, 256, 512]
        self.decoder = unetDecoder(in_filters=decoder_in, out_filters=decoder_out)

    def forward(self, inputs):
        """Run encoder and decoder, then resize the result to the input's spatial size.

        Args:
            inputs: Batch tensor; dims 2 onward are taken as the target output size.

        Returns:
            Decoder output bilinearly interpolated (``align_corners=True``) to
            ``inputs.shape[2:]``.
        """
        target_size = inputs.shape[2:]
        # The author's note lists backbone outputs as strides 2, 4, 8, 16, 32;
        # the first entry is dropped before decoding.
        skip_features = self.backbone(inputs)[1:]
        decoded = self.decoder(skip_features)
        # NOTE(review): `F` is not imported explicitly in this file; presumably it
        # arrives through the wildcard import of the decoder module -- confirm.
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
|
|||
|
|
|||
|
#
|
|||
|
#
|
|||
|
class resUnetPAM(nn.Module):
    """Segmentation network: ``resnet50`` encoder followed by ``unetpamDecoder``.

    Args:
        num_classes: Stored on ``self.nclas``; it is not forwarded to the
            decoder in this class.
        pretrained: Passed straight to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility.
            NOTE(review): the value is never read; ``resnet50`` is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()

        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)

        decoder_in = [512, 1024, 1536]
        decoder_out = [128, 256, 512]
        self.decoder = unetpamDecoder(in_filters=decoder_in, out_filters=decoder_out)

    def forward(self, inputs):
        """Run encoder and decoder, then resize the result to the input's spatial size.

        Args:
            inputs: Batch tensor; dims 2 onward are taken as the target output size.

        Returns:
            Decoder output bilinearly interpolated (``align_corners=True``) to
            ``inputs.shape[2:]``.
        """
        target_size = inputs.shape[2:]
        # The author's note lists backbone outputs as strides 2, 4, 8, 16, 32;
        # the first entry is dropped before decoding.
        skip_features = self.backbone(inputs)[1:]
        decoded = self.decoder(skip_features)
        # NOTE(review): `F` is not imported explicitly in this file; presumably it
        # arrives through the wildcard import of the decoder module -- confirm.
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
|
|||
|
|
|||
|
class resUnetpamcarb(nn.Module):
    """Segmentation network: ``resnet50`` encoder followed by ``unetpamCARBDecoder``.

    The decoder is constructed with its own defaults (no arguments are passed).

    Args:
        num_classes: Stored on ``self.nclas``; it is not forwarded to the
            decoder in this class.
        pretrained: Passed straight to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility.
            NOTE(review): the value is never read; ``resnet50`` is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()

        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)
        self.decoder = unetpamCARBDecoder()

    def forward(self, inputs):
        """Run encoder and decoder, then resize the result to the input's spatial size.

        Args:
            inputs: Batch tensor; dims 2 onward are taken as the target output size.

        Returns:
            Decoder output bilinearly interpolated (``align_corners=True``) to
            ``inputs.shape[2:]``.
        """
        target_size = inputs.shape[2:]
        # The author's note lists backbone outputs as strides 2, 4, 8, 16, 32;
        # the first entry is dropped before decoding.
        skip_features = self.backbone(inputs)[1:]
        decoded = self.decoder(skip_features)
        # NOTE(review): `F` is not imported explicitly in this file; presumably it
        # arrives through the wildcard import of the decoder module -- confirm.
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class resUnetaspp(nn.Module):
    """Segmentation network: ``resnet50`` encoder followed by ``unetasppDecoder``.

    Args:
        num_classes: Stored on ``self.nclas``; it is not forwarded to the
            decoder in this class.
        pretrained: Passed straight to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility.
            NOTE(review): the value is never read; ``resnet50`` is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()

        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)

        # Author's notes (translated):
        #   - The ASPP output has 256 channels.
        #   - Stage channels are 256, 512, 1024, and 2048 (-> 256 after ASPP).
        #   - A 1x1 convolution turns them into 128, 256, 512.
        #   - in_filters=[448, 640, 768], out_filters=[128, 320, 384]
        decoder_in = [448, 640, 768]
        decoder_out = [128, 320, 384]
        self.decoder = unetasppDecoder(in_filters=decoder_in, out_filters=decoder_out)

    def forward(self, inputs):
        """Run encoder and decoder, then resize the result to the input's spatial size.

        Args:
            inputs: Batch tensor; dims 2 onward are taken as the target output size.

        Returns:
            Decoder output bilinearly interpolated (``align_corners=True``) to
            ``inputs.shape[2:]``.
        """
        target_size = inputs.shape[2:]
        # The author's note lists backbone outputs as strides 2, 4, 8, 16, 32;
        # the first entry is dropped before decoding.
        skip_features = self.backbone(inputs)[1:]
        decoded = self.decoder(skip_features)
        # NOTE(review): `F` is not imported explicitly in this file; presumably it
        # arrives through the wildcard import of the decoder module -- confirm.
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
|
|||
|
|
|||
|
# class resUnetpamcam(nn.Module):
|
|||
|
# def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
|
|||
|
# super(resUnetpamcam, self).__init__()
|
|||
|
#
|
|||
|
# self.nclas = num_classes
|
|||
|
# self.finnal_channel = 512
|
|||
|
# self.backbone = resnet50(pretrained=pretrained)
|
|||
|
# self.decoder = unetpamcamDecoder()
|
|||
|
#
|
|||
|
# def forward(self, inputs):
|
|||
|
#
|
|||
|
# feaureslist = self.backbone(inputs) #2 4 8 16 32
|
|||
|
# feaureslist=feaureslist[1:]
|
|||
|
# out = self.decoder(feaureslist)
|
|||
|
# out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
|
|||
|
# return out
|
|||
|
#
|
|||
|
#
|
|||
|
#
|
|||
|
# # 1/4 1/8 1/16 1/32 are denoted 1 2 3 4
|
|||
|
# class resUnetpamcarb_4(nn.Module):
|
|||
|
# def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
|
|||
|
# super(resUnetpamcarb_4, self).__init__()
|
|||
|
#
|
|||
|
# self.nclas = num_classes
|
|||
|
# self.finnal_channel = 512
|
|||
|
# self.backbone = resnet50(pretrained=pretrained)
|
|||
|
# self.decoder = unetpamDecoderzuhe()
|
|||
|
#
|
|||
|
# def forward(self, inputs):
|
|||
|
#
|
|||
|
# feaureslist = self.backbone(inputs) #2 4 8 16 32
|
|||
|
# feaureslist=feaureslist[1:]
|
|||
|
# out = self.decoder(feaureslist)
|
|||
|
# out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
|
|||
|
# return out
|
|||
|
#
|
|||
|
|
|||
|
|
|||
|
################# resunet heavily re-tuned variant, version 2 (last attempt) #################
|
|||
|
|
|||
|
|
|||
|
class resUnet_version2(nn.Module):
    """Segmentation network: ``resnet50`` encoder followed by ``unetDecoder2``.

    Args:
        num_classes: Stored on ``self.nclas``; it is not forwarded to the
            decoder in this class.
        pretrained: Passed straight to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility.
            NOTE(review): the value is never read; ``resnet50`` is always built.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()

        self.nclas = num_classes
        self.finnal_channel = 512
        self.backbone = resnet50(pretrained=pretrained)

        decoder_in = [384, 768, 1536]
        decoder_out = [128, 256, 512]
        self.decoder = unetDecoder2(in_filters=decoder_in, out_filters=decoder_out)

    def forward(self, inputs):
        """Run encoder and decoder, then resize the result to the input's spatial size.

        Args:
            inputs: Batch tensor; dims 2 onward are taken as the target output size.

        Returns:
            Decoder output bilinearly interpolated (``align_corners=True``) to
            ``inputs.shape[2:]``.
        """
        target_size = inputs.shape[2:]
        # The author's note lists backbone outputs as strides 2, 4, 8, 16, 32;
        # the first entry is dropped before decoding.
        skip_features = self.backbone(inputs)[1:]
        decoded = self.decoder(skip_features)
        # NOTE(review): `F` is not imported explicitly in this file; presumably it
        # arrives through the wildcard import of the decoder module -- confirm.
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
|
|||
|
|
|||
|
#
|
|||
|
|
|||
|
|
|||
|
#############################################
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Smoke test: push one random batch through resUnet_version2 and report
    # the type and shape of what comes back.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # The input is created before the model, as in the original script.
    sample_batch = torch.rand(2, 3, 512, 512).to(device)
    network = resUnet_version2().to(device)

    out = network(sample_batch)
    print(type(out))
    print(out.shape)
|