94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
@project:
|
|
@File : setr
|
|
@Author : qiqq
|
|
@create_time : 2023/5/31 20:33
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from taihuyuan_roof.compared_experiment.mySETR.model.vit import VisionTransformer
|
|
from taihuyuan_roof.compared_experiment.mySETR.model.vit_up_head import VisionTransformerUpHead
|
|
# Normalisation-layer settings shared by the config dicts below; the same object
# is also handed to the backbone and decoder built in MSETR_naive.
norm_cfg = {'type': 'BN', 'requires_grad': True}

# SETR-Naive reference configuration for 768x768 inputs (the original note reads
# "768*768 80k city" -- presumably an 80k-iteration Cityscapes setup; confirm).
# `backbone` and `decode_head` are not referenced elsewhere in the visible part
# of this file.
backbone = {
    'type': 'VisionTransformer',
    'model_name': 'vit_large_patch16_384',
    'img_size': 768,
    'patch_size': 16,
    'in_chans': 3,
    'embed_dim': 1024,
    'depth': 24,
    'num_heads': 16,
    'num_classes': 19,
    'drop_rate': 0.1,
    'norm_cfg': norm_cfg,
    'pos_embed_interp': True,
    'align_corners': False,
}

decode_head = {
    'type': 'VisionTransformerUpHead',
    'in_channels': 1024,
    'channels': 512,
    # depth - 1 for the 24-layer backbone configured above
    'in_index': 23,
    'img_size': 768,
    'embed_dim': 1024,
    'num_classes': 19,
    'norm_cfg': norm_cfg,
    'num_conv': 2,
    'upsampling_method': 'bilinear',
    'align_corners': False,
    'loss_decode': {
        'type': 'CrossEntropyLoss',
        'use_sigmoid': False,
        'loss_weight': 1.0,
    },
}
|
|
|
|
class MSETR_naive(nn.Module):
    """Segmentation network that chains a ViT backbone with an up-sampling head.

    The backbone is a ``VisionTransformer`` configured as
    ``vit_base_patch16_224`` (embed_dim 768, 12 layers, 12 heads, patch 16) and
    the decoder is a ``VisionTransformerUpHead`` with two conv stages and
    bilinear up-sampling. Both are built with the module-level ``norm_cfg``.

    Args:
        pretrained: Value forwarded unchanged to
            ``self.backbone.init_weights(pretrained=...)``. The default is an
            absolute, machine-specific checkpoint path; pass your own path on
            any other machine.
        num_classes: Keyword-only. Class count given to both the backbone and
            the decoder (default 2, the value previously hard-coded in both).
        img_size: Keyword-only. Image size given to both the backbone and the
            decoder (default 512, the value previously hard-coded in both).
    """

    def __init__(
        self,
        pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/vit_checkpoint/setrvit/vit_base_p16_224-4e355ebd.pth",
        *,
        num_classes=2,
        img_size=512,
    ):
        super().__init__()

        # Backbone geometry. The decoder's input width and the layer index it
        # taps are derived from these so the two halves cannot drift apart.
        embed_dim = 768
        depth = 12

        self.backbone = VisionTransformer(
            model_name='vit_base_patch16_224',
            img_size=img_size,
            patch_size=16,
            in_chans=3,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=12,
            num_classes=num_classes,
            drop_rate=0.1,
            # NOTE(review): confirm the checkpoint really has no QKV bias terms;
            # this cannot be verified from this file.
            qkv_bias=False,
            norm_cfg=norm_cfg,
            pos_embed_interp=True,
            align_corners=False,
        )
        self.backbone.init_weights(pretrained=pretrained)
        # Reports which checkpoint path was handed to the backbone. The message
        # text is runtime output and is kept verbatim.
        print("setr权重{}加载".format(pretrained))

        self.decoder = VisionTransformerUpHead(
            in_channels=embed_dim,
            channels=512,
            # index of the last backbone layer (11 for depth 12)
            in_index=depth - 1,
            img_size=img_size,
            embed_dim=embed_dim,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            num_conv=2,
            upsampling_method='bilinear',
            align_corners=False,
        )

    def forward(self, x):
        """Run ``x`` through the backbone, then the decoder, and return the result.

        Args:
            x: Input batch; presumably an image tensor shaped
                ``(N, 3, img_size, img_size)`` -- TODO confirm against callers.

        Returns:
            Whatever ``self.decoder`` produces from the backbone output.
        """
        features = self.backbone(x)
        return self.decoder(features)
|
|
|
|
|
|
|
|
|
|
# inpudd=torch.rand(2,3,512,512)
|
|
# model=MSETR_naive()
|
|
# our=model(inpudd)
|
|
# print(our.shape)
|