#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : resunet
@Author : qiqq
@create_time : 2023/7/20 18:45

ResNet-50 backed U-Net style segmentation models.

Every model in this module runs the ResNet-50 backbone, drops the first
feature map it returns, hands the remaining feature maps to a decoder and
bilinearly resizes the decoder output back to the spatial size of the input.
The models differ only in which decoder (from ``decoder.py``) they use.
"""
import torch
import torch.nn as nn
# ``F`` used to be available only if the star import below happened to
# re-export it; import it explicitly so this module does not depend on that.
import torch.nn.functional as F

from taihuyuan_pv.mitunet.model.resnet import resnet50
from taihuyuan_pv.mitunet.model.decoder import *


class unetUp(nn.Module):
    """Upsample one feature map by 2x, concatenate with a skip feature, fuse.

    Args:
        in_size: Number of input channels of the 3x3 fusion convolution, i.e.
            the channel count of the concatenated tensor.
        out_size: Number of output channels after fusion.
    """

    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        # conv(3x3, stride 1, pad 1, no bias) -> BatchNorm -> ReLU
        self.cbr = nn.Sequential(
            nn.Conv2d(in_size, out_size, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs1, inputs2):
        """Fuse ``inputs1`` with the 2x-upsampled ``inputs2``.

        ``inputs2`` is upsampled by a factor of two and concatenated with
        ``inputs1`` along dim 1 before the conv-BN-ReLU block is applied.
        """
        merged = torch.cat([inputs1, self.up(inputs2)], 1)
        return self.cbr(merged)


class _ResUnetBase(nn.Module):
    """Shared plumbing for the ResNet-50 U-Net variants in this module.

    Subclasses must assign ``self.decoder`` in their ``__init__`` after
    calling ``super().__init__``, so that submodules are registered in the
    order ``backbone`` then ``decoder``.

    Args:
        num_classes: Stored as ``self.nclas``. It is not forwarded to the
            backbone or to the decoder by this module.
        pretrained: Forwarded to ``resnet50(pretrained=...)``.
        backbone: Accepted for interface compatibility; currently unused,
            the backbone is always ``resnet50``.
    """

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__()
        self.nclas = num_classes
        self.backbone = resnet50(pretrained=pretrained)

    def forward(self, inputs):
        """Run backbone and decoder, then resize to the input's spatial size.

        Args:
            inputs: Input tensor; its dims from index 2 onwards are used as
                the target output size.

        Returns:
            The decoder output, bilinearly interpolated
            (``align_corners=True``) to ``inputs.size()[2:]``.
        """
        # Original note on the backbone outputs: "2 4 8 16 32"
        # (presumably the downsampling strides of the returned feature maps
        # -- TODO confirm against resnet.py).
        features = self.backbone(inputs)
        # The first backbone feature map is not used by any decoder here.
        features = features[1:]
        decoded = self.decoder(features)
        return F.interpolate(
            decoded, size=inputs.size()[2:], mode='bilinear', align_corners=True
        )


class resUnet(_ResUnetBase):
    """ResNet-50 backbone with the plain ``unetDecoder``."""

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__(num_classes, pretrained, backbone)
        self.finnal_channel = 512
        # Earlier variant: a "trans" layer (2048 -> 512) was added on the last
        # stage and the decoder used two 3*3 "depth" CBR blocks:
        # self.decoder = unetDecoder(in_filters=[512, 1024, 1536], out_filters=[128, 256, 512])
        self.decoder = unetDecoder(
            in_filters=[512, 1024, 3072], out_filters=[128, 256, 512]
        )


class resUnetPAM(_ResUnetBase):
    """ResNet-50 backbone with ``unetpamDecoder``."""

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__(num_classes, pretrained, backbone)
        self.decoder = unetpamDecoder(
            in_filters=[512, 1024, 1536], out_filters=[128, 256, 512]
        )


class resUnetpamcarb(_ResUnetBase):
    """ResNet-50 backbone with ``unetpamCARBDecoder`` (default arguments)."""

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__(num_classes, pretrained, backbone)
        self.finnal_channel = 512
        self.decoder = unetpamCARBDecoder()


class resUnetaspp(_ResUnetBase):
    """ResNet-50 backbone with ``unetasppDecoder``."""

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__(num_classes, pretrained, backbone)
        # Original notes: the ASPP output has 256 channels.
        # 256 512 1024 2048_256 -> a 1*1 (conv) turns them into 128 256 512
        # in_filters=[448, 640, 768], out_filters=[128, 320, 384]
        self.decoder = unetasppDecoder(
            in_filters=[448, 640, 768], out_filters=[128, 320, 384]
        )


# Disabled variants, kept for reference.
#
# class resUnetpamcam(nn.Module):
#     def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
#         super(resUnetpamcam, self).__init__()
#
#         self.nclas = num_classes
#         self.finnal_channel = 512
#         self.backbone = resnet50(pretrained=pretrained)
#         self.decoder = unetpamcamDecoder()
#
#     def forward(self, inputs):
#
#         feaureslist = self.backbone(inputs)  # 2 4 8 16 32
#         feaureslist = feaureslist[1:]
#         out = self.decoder(feaureslist)
#         out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
#         return out
#
#
# # 1/4 1/8 1/16 1/32 are denoted 1 2 3 4
# class resUnetpamcarb_4(nn.Module):
#     def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
#         super(resUnetpamcarb_4, self).__init__()
#
#         self.nclas = num_classes
#         self.finnal_channel = 512
#         self.backbone = resnet50(pretrained=pretrained)
#         self.decoder = unetpamDecoderzuhe()
#
#     def forward(self, inputs):
#
#         feaureslist = self.backbone(inputs)  # 2 4 8 16 32
#         feaureslist = feaureslist[1:]
#         out = self.decoder(feaureslist)
#         out = F.interpolate(out, size=inputs.size()[2:], mode='bilinear', align_corners=True)
#         return out


# ---------------- resunet, heavily re-tuned version 2 ----------------
class resUnet_version2(_ResUnetBase):
    """ResNet-50 backbone with ``unetDecoder2`` (re-tuned variant)."""

    def __init__(self, num_classes=2, pretrained=True, backbone='resnet50'):
        super().__init__(num_classes, pretrained, backbone)
        self.finnal_channel = 512
        self.decoder = unetDecoder2(
            in_filters=[384, 768, 1536], out_filters=[128, 256, 512]
        )


if __name__ == '__main__':
    # Smoke test: push a random batch through the model and print the
    # output type and shape.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    indd = torch.rand(2, 3, 512, 512).to(device)
    modl = resUnet_version2().to(device)
    out = modl(indd)
    print(type(out))
    print(out.shape)