ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/segformer.py

35 lines
1.0 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : segformer
@Author : qiqq
@create_time : 2023/5/22 9:17
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from compared_experiment.mySegformer.model.mix_transformer import mit_b1
from compared_experiment.mySegformer.model.segformer_head import SegFormerHead
class MSegformer(nn.Module):
    """SegFormer segmentation model: MiT-B1 encoder followed by SegFormerHead.

    The decoder output is resized back to the spatial size of the input
    with bilinear interpolation (``align_corners=True``).

    Args:
        pretrained: Forwarded unchanged to ``init_weights`` of the MiT-B1
            backbone. Presumably a checkpoint path; how an empty string is
            treated is decided by ``init_weights`` (not visible here) —
            confirm there.
    """

    def __init__(self, pretrained=""):
        super().__init__()
        # NOTE: the attribute name "bacbone" (sic) is intentionally kept:
        # nn.Module derives state_dict keys from attribute names, so renaming
        # it would break loading checkpoints saved from this class.
        self.bacbone = mit_b1()
        self.bacbone.init_weights(pretrained)
        self.decoder = SegFormerHead()

    def forward(self, x):
        """Run encoder + decoder and upsample to the input's spatial size.

        Assumes ``x`` is a batched image tensor (N, C, H, W) — TODO confirm;
        every dimension from index 2 onward is used as the resize target.
        """
        target_size = x.size()[2:]
        encoded = self.bacbone(x)
        decoded = self.decoder(encoded)
        return F.interpolate(
            decoded,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )
# inpii=torch.rand(2,3,512,512)
# pdp="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"
# model=MSegformer(pretrained=pdp)
# out=model(inpii)
# print(type(out))
# print(out.shape)