53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
@project:
|
||
|
@File : manet
|
||
|
@Author : qiqq
|
||
|
@create_time : 2023/9/17 16:46
|
||
|
"""
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
import numpy as np
|
||
|
from taihuyuan_roof.manet.model.mit_backbone import mit_b1
|
||
|
from taihuyuan_roof.manet.model.decoder import *
|
||
|
|
||
|
|
||
|
class MitUnet(nn.Module):
    """Segmentation network: `mit_b1` backbone followed by `unetDecoder`.

    Author's note (translated): this variant can probably be retired -- its
    best PV IoU stays below 88, and at least 90 is needed.

    Args:
        pretrained: Value forwarded to the backbone's ``init_weights``. The
            default is an absolute path on the original author's machine, so
            callers elsewhere will normally want to pass their own.
    """

    def __init__(self,pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"):
        super().__init__()

        # Build and initialise the encoder first, then register it.
        # NOTE: the attribute is spelled `bacbone` (sic); the spelling is kept
        # because it is part of the parameter names in saved state dicts.
        encoder = mit_b1()
        encoder.init_weights(pretrained)
        self.bacbone = encoder

        self.decoder = unetDecoder()

    def forward(self, x):
        """Run backbone and decoder, then resize the result to match ``x``.

        Args:
            x: Input tensor; presumably (N, C, H, W) -- only the dimensions
                from index 2 onward are used as the resize target.

        Returns:
            The decoder output bilinearly interpolated
            (``align_corners=True``) to the trailing dimensions of ``x``.
        """
        target_size = x.shape[2:]
        features = self.bacbone(x)
        decoded = self.decoder(features)
        return F.interpolate(
            decoded,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )
|
||
|
|
||
|
class MitCAB0Unet(nn.Module):
    """Segmentation network: `mit_b1` backbone followed by `unetCAB0Decoder`.

    Args:
        pretrained: Value forwarded to the backbone's ``init_weights``. The
            default is an absolute path on the original author's machine, so
            callers elsewhere will normally want to pass their own.
    """

    def __init__(self,pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"):
        super().__init__()

        # Build and initialise the encoder first, then register it.
        # NOTE: the attribute is spelled `bacbone` (sic); the spelling is kept
        # because it is part of the parameter names in saved state dicts.
        encoder = mit_b1()
        encoder.init_weights(pretrained)
        self.bacbone = encoder

        self.decoder = unetCAB0Decoder()

    def forward(self, x):
        """Run backbone and decoder, then resize the result to match ``x``.

        Args:
            x: Input tensor; presumably (N, C, H, W) -- only the dimensions
                from index 2 onward are used as the resize target.

        Returns:
            The decoder output bilinearly interpolated
            (``align_corners=True``) to the trailing dimensions of ``x``.
        """
        target_size = x.shape[2:]
        features = self.bacbone(x)
        decoded = self.decoder(features)
        return F.interpolate(
            decoded,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke run: push one random batch (2 samples, 3 channels, 512x512)
    # through MitCAB0Unet built with its default `pretrained` argument.
    # The result is computed but not inspected or printed.
    dummy_batch = torch.randn(2, 3, 512, 512)
    net = MitCAB0Unet()
    prediction = net(dummy_batch)
|
||
|
|
||
|
|