#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : segfomer
@Author : qiqq
@create_time : 2023/5/22 9:17
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from compared_experiment.mySegformer.model.mix_transformer import mit_b1
from compared_experiment.mySegformer.model.segformer_head import SegFormerHead


class MSegformer(nn.Module):
    """Segmentation network: `mit_b1` backbone followed by a `SegFormerHead`.

    The decoder output is bilinearly resized to the spatial size of the
    input before being returned.
    """

    def __init__(self, pretrained: str = ""):
        """Build the backbone and decoder.

        Args:
            pretrained: Value forwarded unchanged to the backbone's
                ``init_weights``. Presumably a checkpoint path (the example
                at the bottom of this file passes a ``.pth`` path); how an
                empty string is handled depends on ``mit_b1`` -- TODO confirm.
        """
        super().__init__()

        # NOTE: the attribute name "bacbone" (sic) is kept as-is on purpose;
        # renaming it would change the parameter names in state_dict().
        self.bacbone = mit_b1()
        self.bacbone.init_weights(pretrained)

        self.decoder = SegFormerHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone and decoder, then resize to the input resolution.

        Args:
            x: Input tensor. Dimensions from index 2 onward are used as the
                target output size; assumed to be (N, C, H, W) since
                bilinear interpolation is applied -- TODO confirm.

        Returns:
            The decoder output resized (bilinear, ``align_corners=True``) to
            the trailing dimensions of ``x``.
        """
        backbone_out = self.bacbone(x)
        decoded = self.decoder(backbone_out)

        target_size = x.shape[2:]
        return F.interpolate(
            decoded,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )


# Example usage (left disabled, kept for reference):
# inpii=torch.rand(2,3,512,512)
# pdp="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"
# model=MSegformer(pretrained=pdp)
# out=model(inpii)
# print(type(out))
# print(out.shape)