#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@project:
@File : setr
@Author : qiqq
@create_time : 2023/5/31 20:33
"""
import torch
from torch import nn

from taihuyuan_pv.compared_experiment.mySETR.model.vit import VisionTransformer
from taihuyuan_pv.compared_experiment.mySETR.model.vit_up_head import VisionTransformerUpHead

norm_cfg = dict(type='BN', requires_grad=True)

# Reference config: SETR-Naive with ViT-Large at 768x768 for the
# Cityscapes 80k schedule. Kept for comparison; the model built below
# uses a ViT-Base variant at 512x512 instead.
backbone = dict(
    type='VisionTransformer',
    model_name='vit_large_patch16_384',
    img_size=768,
    patch_size=16,
    in_chans=3,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    num_classes=19,
    drop_rate=0.1,
    norm_cfg=norm_cfg,
    pos_embed_interp=True,
    align_corners=False,
)

decode_head = dict(
    type='VisionTransformerUpHead',
    in_channels=1024,
    channels=512,
    in_index=23,
    img_size=768,
    embed_dim=1024,
    num_classes=19,
    norm_cfg=norm_cfg,
    num_conv=2,
    upsampling_method='bilinear',
    align_corners=False,
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))


class MSETR_naive(nn.Module):
    """SETR-Naive: a ViT-Base/16 encoder plus a simple convolution-and-
    bilinear-upsampling head, configured for 2-class segmentation at 512x512."""

    def __init__(self, pretrained="/home/qiqq/q3dl/code/pretrain_weight/pretrained/vit_checkpoint/setrvit/vit_base_p16_224-4e355ebd.pth"):
        super(MSETR_naive, self).__init__()
        # ViT-Base/16 backbone; positional embeddings are interpolated from
        # the 224x224 pretraining resolution to 512x512 (pos_embed_interp).
        self.backbone = VisionTransformer(
            model_name='vit_base_patch16_224',
            img_size=512,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            depth=12,
            num_heads=12,
            num_classes=2,
            drop_rate=0.1,
            qkv_bias=False,
            norm_cfg=norm_cfg,
            pos_embed_interp=True,
            align_corners=False)
        self.backbone.init_weights(pretrained=pretrained)
        print("Loaded SETR pretrained weights from {}".format(pretrained))
        # Naive decoder: reads the output of the final (12th) transformer
        # block (in_index=11) and upsamples bilinearly back to the input size.
        self.decoder = VisionTransformerUpHead(
            in_channels=768,
            channels=512,
            in_index=11,
            img_size=512,
            embed_dim=768,
            num_classes=2,
            norm_cfg=norm_cfg,
            num_conv=2,
            upsampling_method='bilinear',
            align_corners=False)

    def forward(self, x):
        out = self.backbone(x)
        out = self.decoder(out)
        return out


if __name__ == "__main__":
    # Smoke test: two random 512x512 RGB images should produce 2-class
    # logits at full resolution, i.e. torch.Size([2, 2, 512, 512]).
    inputs = torch.rand(2, 3, 512, 512)
    model = MSETR_naive()
    out = model(inputs)
    print(out.shape)
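    # A minimal post-processing sketch, assuming the standard (N, C, H, W)
    # logit layout: argmax over the class dimension yields a hard per-pixel
    # mask. Real inference would also call model.eval() and wrap the forward
    # pass in torch.no_grad(), and feed images normalized as during training.
    masks = out.argmax(dim=1)  # (2, 512, 512) per-pixel class indices
    print(masks.shape)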