ai-station-code/wudingpv/taihuyuan_roof/compared_experiment/mySETR/model/setr.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : setr
@Author : qiqq
@create_time : 2023/5/31 20:33
"""
import torch
from torch import nn
from taihuyuan_roof.compared_experiment.mySETR.model.vit import VisionTransformer
from taihuyuan_roof.compared_experiment.mySETR.model.vit_up_head import VisionTransformerUpHead
norm_cfg = dict(type='BN', requires_grad=True)
# Reference config: SETR-Naive, 768x768 input, 80k schedule, Cityscapes (19 classes).
backbone = dict(
    type='VisionTransformer',
    model_name='vit_large_patch16_384',
    img_size=768,
    patch_size=16,
    in_chans=3,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    num_classes=19,
    drop_rate=0.1,
    norm_cfg=norm_cfg,
    pos_embed_interp=True,
    align_corners=False,
)
decode_head = dict(
    type='VisionTransformerUpHead',
    in_channels=1024,
    channels=512,
    in_index=23,
    img_size=768,
    embed_dim=1024,
    num_classes=19,
    norm_cfg=norm_cfg,
    num_conv=2,
    upsampling_method='bilinear',
    align_corners=False,
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
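
# The two dicts above are kept only as a reference for the original SETR-Naive
# recipe; nothing in this file consumes them. A minimal sketch of how they
# could be instantiated, assuming the constructors accept the remaining keys
# (which the calls in MSETR_naive below suggest); 'type' and 'loss_decode' are
# config-only keys and are dropped first:
#
#   backbone_kwargs = {k: v for k, v in backbone.items() if k != 'type'}
#   head_kwargs = {k: v for k, v in decode_head.items()
#                  if k not in ('type', 'loss_decode')}
#   reference_backbone = VisionTransformer(**backbone_kwargs)
#   reference_head = VisionTransformerUpHead(**head_kwargs)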


class MSETR_naive(nn.Module):
    """SETR-Naive variant: ViT-Base/16 encoder plus a naive upsampling head,
    configured for 512x512 inputs and 2 output classes."""

    def __init__(self, pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/vit_checkpoint/setrvit/vit_base_p16_224-4e355ebd.pth"):
        super(MSETR_naive, self).__init__()
        # ViT-Base/16 encoder; pos_embed_interp=True interpolates the position
        # embeddings from the 224x224 pretraining grid to the 512x512 input.
        self.backbone = VisionTransformer(
            model_name='vit_base_patch16_224',
            img_size=512,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            depth=12,
            num_heads=12,
            num_classes=2,
            drop_rate=0.1,
            qkv_bias=False,
            norm_cfg=norm_cfg,
            pos_embed_interp=True,
            align_corners=False)
        self.backbone.init_weights(pretrained=pretrained)
        print("SETR pretrained weights loaded from {}".format(pretrained))
        # Naive head: reads the last transformer block (in_index=11), applies
        # num_conv=2 convolutions, and bilinearly upsamples to the input size.
        self.decoder = VisionTransformerUpHead(
            in_channels=768,
            channels=512,
            in_index=11,
            img_size=512,
            embed_dim=768,
            num_classes=2,
            norm_cfg=norm_cfg,
            num_conv=2,
            upsampling_method='bilinear',
            align_corners=False)

    def forward(self, x):
        out = self.backbone(x)
        out = self.decoder(out)
        return out


if __name__ == '__main__':
    # Smoke test; the default checkpoint path above must exist for the
    # pretrained weights to load. Expected output: torch.Size([2, 2, 512, 512]).
    inp = torch.rand(2, 3, 512, 512)
    model = MSETR_naive()
    out = model(inp)
    print(out.shape)
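
# Training-loss sketch matching the reference decode_head config above
# (CrossEntropyLoss with use_sigmoid=False): the head emits per-pixel logits,
# so a plain nn.CrossEntropyLoss over an integer label map applies. Assumed
# shapes, for illustration only:
#
#   criterion = nn.CrossEntropyLoss()
#   logits = model(images)            # (N, 2, 512, 512) float logits
#   loss = criterion(logits, labels)  # labels: (N, 512, 512), dtype torch.long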