ai-station-code/wudingpv/taihuyuan_pv/mitunet/model/mitunet.py

53 lines
1.5 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : mitunet
@Author : qiqq
@create_time : 2023/9/17 16:46
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from taihuyuan_pv.mitunet.model.mit_backbone import mit_b1
from taihuyuan_pv.mitunet.model.decoder import *
class MitUnet(nn.Module):
    """Segmentation network: MiT-B1 encoder followed by ``unetDecoder``.

    Author's note (translated from the original Chinese docstring): this variant
    was abandoned — its best PV IoU stayed below 88, while the target was at
    least 90.

    The encoder is built with ``mit_b1()`` and ``init_weights`` is called on it
    with ``pretrained`` (presumably a checkpoint path — confirm against
    ``mit_backbone``). Whatever the encoder returns is handed unchanged to the
    decoder, and the decoder output is bilinearly resized (``align_corners=True``)
    to the spatial size of the input.

    Args:
        pretrained: forwarded to the encoder's ``init_weights``. The default is an
            absolute path on the original author's machine; override it elsewhere.
    """

    def __init__(self, pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"):
        super().__init__()
        encoder = mit_b1()
        # NOTE: keep the misspelled name ``bacbone``. nn.Module attribute names
        # become state_dict key prefixes, so renaming it would break loading of
        # existing checkpoints.
        self.bacbone = encoder
        encoder.init_weights(pretrained)
        self.decoder = unetDecoder()

    def forward(self, x):
        """Run encoder + decoder and resize the result to ``x``'s spatial size.

        ``x`` is assumed to be an (N, C, H, W) batch — every dim from index 2
        onward is used as the interpolation target size.
        """
        target_size = x.size()[2:]
        features = self.bacbone(x)
        decoded = self.decoder(features)
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
class MitCAB0Unet(nn.Module):
    """Segmentation network: MiT-B1 encoder followed by ``unetCAB0Decoder``.

    The encoder is built with ``mit_b1()`` and ``init_weights`` is called on it
    with ``pretrained`` (presumably a checkpoint path — confirm against
    ``mit_backbone``). The encoder output is passed unchanged to
    ``unetCAB0Decoder``, and the decoder output is bilinearly resized
    (``align_corners=True``) to the spatial size of the input.

    Args:
        pretrained: forwarded to the encoder's ``init_weights``. The default is an
            absolute path on the original author's machine; override it elsewhere.
    """

    def __init__(self, pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/segformer_mit/mit_b1.pth"):
        super().__init__()
        encoder = mit_b1()
        # NOTE: keep the misspelled name ``bacbone``. nn.Module attribute names
        # become state_dict key prefixes, so renaming it would break loading of
        # existing checkpoints.
        self.bacbone = encoder
        encoder.init_weights(pretrained)
        self.decoder = unetCAB0Decoder()

    def forward(self, x):
        """Run encoder + decoder and resize the result to ``x``'s spatial size.

        ``x`` is assumed to be an (N, C, H, W) batch — every dim from index 2
        onward is used as the interpolation target size.
        """
        target_size = x.size()[2:]
        features = self.bacbone(x)
        decoded = self.decoder(features)
        return F.interpolate(decoded, size=target_size, mode='bilinear', align_corners=True)
if __name__ == '__main__':
    # Smoke test: push a random batch of 2 RGB 512x512 images through the model
    # and report the output shape.
    # NOTE(review): the default ``pretrained`` path is an absolute path on the
    # original author's machine; pass a valid path when running elsewhere.
    inpp = torch.randn(2, 3, 512, 512)
    model = MitCAB0Unet()
    # Inference mode: no autograd graph is needed just to check shapes.
    model.eval()
    with torch.no_grad():
        out = model(inpp)
    print(out.shape)