#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File   : manet
@Author : qiqq
@create_time : 2023/9/17 16:46

MiT-B1 (SegFormer-style encoder) backbones paired with U-Net-style decoders.
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from taihuyuan_roof.manet.model.mit_backbone import mit_b1
# Wildcard import kept as-is: this module relies on it for `unetDecoder` and
# `unetCAB0Decoder`; other modules may also depend on names re-exported from
# here, so verify before narrowing it to an explicit list.
from taihuyuan_roof.manet.model.decoder import *

# Default checkpoint passed to `mit_b1().init_weights` by both models. This is
# a machine-specific absolute path from the original author's setup; pass
# `pretrained=` explicitly to override it.
_DEFAULT_MIT_B1_WEIGHTS = "/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"


class _MitB1UnetBase(nn.Module):
    """Shared MiT-B1 encoder + decoder wrapper used by the public models below.

    Builds ``mit_b1()``, calls its ``init_weights(pretrained)``, then builds the
    decoder from ``decoder_factory``. ``forward`` runs backbone -> decoder and
    bilinearly resizes the decoder output back to the input's spatial size.

    Submodules are registered as ``bacbone`` (sic) and ``decoder``. Those
    attribute names are the ``state_dict`` key prefixes, so renaming them would
    break loading of existing checkpoints.
    """

    def __init__(self, decoder_factory, pretrained=_DEFAULT_MIT_B1_WEIGHTS):
        """
        Args:
            decoder_factory: zero-argument callable returning the decoder
                module. It is called only *after* the backbone has been built
                and initialised, so construction order (and therefore RNG
                consumption and submodule registration order) matches the
                original per-class implementations.
            pretrained: value forwarded unchanged to
                ``mit_b1().init_weights`` -- presumably a checkpoint path;
                confirm how that method handles a missing file or ``None``.
        """
        super().__init__()
        self.bacbone = mit_b1()
        self.bacbone.init_weights(pretrained)
        self.decoder = decoder_factory()

    def forward(self, x):
        """Return the decoder output resized to ``x``'s spatial size.

        Args:
            x: input batch. It must be 4-D (N, C, H, W), because its trailing
               two dims are used as the bilinear output size. The smoke test
               below uses shape (2, 3, 512, 512).

        Returns:
            The decoder output, bilinearly interpolated to (H, W) with
            ``align_corners=True``.
        """
        # Presumably a list of multi-scale feature maps; verify against mit_b1.
        features = self.bacbone(x)
        out = self.decoder(features)
        return F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=True)


class MitUnet(_MitB1UnetBase):
    """MiT-B1 backbone + ``unetDecoder``.

    Author's experiment note (translated from the original Chinese): "Not good
    enough -- PV IoU peaked below 88; it needs to reach at least 90."
    """

    def __init__(self, pretrained=_DEFAULT_MIT_B1_WEIGHTS):
        super().__init__(unetDecoder, pretrained)


class MitCAB0Unet(_MitB1UnetBase):
    """MiT-B1 backbone + ``unetCAB0Decoder``."""

    def __init__(self, pretrained=_DEFAULT_MIT_B1_WEIGHTS):
        super().__init__(unetCAB0Decoder, pretrained)


if __name__ == '__main__':
    # Smoke test: one forward pass on random input. no_grad avoids building an
    # autograd graph that would never be used.
    inpp = torch.randn(2, 3, 512, 512)
    model = MitCAB0Unet()
    with torch.no_grad():
        out = model(inpp)
    print(out.shape)