35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
@project:
|
|
@File : segfomer
|
|
@Author : qiqq
|
|
@create_time : 2023/5/22 9:17
|
|
"""
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from compared_experiment.mySegformer.model.mix_transformer import mit_b1
|
|
from compared_experiment.mySegformer.model.segformer_head import SegFormerHead
|
|
|
|
class MSegformer(nn.Module):
    """Segmentation network built from an ``mit_b1`` encoder and a ``SegFormerHead`` decoder.

    The decoder output is resized back to the spatial size of the input
    batch, so ``forward`` returns a map with the same trailing dims as ``x``.

    Args:
        pretrained: Value forwarded unchanged to ``mit_b1().init_weights``.
            NOTE(review): the default is the empty string (not ``None``) and it
            is always forwarded; whether ``init_weights("")`` is a no-op or an
            attempted checkpoint load depends on mix_transformer — confirm.
    """

    def __init__(self, pretrained=""):
        super().__init__()
        # "bacbone" (sic) is deliberately kept: the attribute name becomes the
        # "bacbone.*" prefix of state_dict keys, so renaming it would break
        # loading of previously saved checkpoints.
        encoder = mit_b1()
        encoder.init_weights(pretrained)
        self.bacbone = encoder
        self.decoder = SegFormerHead()

    def forward(self, x):
        """Run encoder and decoder, then upsample to the input's spatial size.

        Args:
            x: Input batch; presumably (N, C, H, W) — everything from dim 2 on
                is used as the interpolation target size, and bilinear mode
                needs exactly two spatial dims.

        Returns:
            The decoder output resized with bilinear interpolation
            (``align_corners=True``) to ``x.shape[2:]``.
        """
        target_size = x.shape[2:]
        # Encoder output is fed straight to the decoder (presumably a list of
        # multi-scale feature maps — depends on mit_b1; verify).
        decoded = self.decoder(self.bacbone(x))
        return F.interpolate(
            decoded,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )
|
|
|
|
# inpii=torch.rand(2,3,512,512)
|
|
# pdp="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"
|
|
# model=MSegformer(pretrained=pdp)
|
|
# out=model(inpii)
|
|
# print(type(out))
|
|
# print(out.shape) |