-
-
Save bruce-willis/977d965d6214f9e54546a5aeccf0d057 to your computer and use it in GitHub Desktop.
Backbone for the `dinov2` model for use with the `mmseg` library
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Replacement contents for the dinov2/eval/segmentation/models/backbones/vision_transformer.py file.
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.runner import BaseModule | |
from mmseg.models.builder import BACKBONES | |
from dinov2.models.vision_transformer import \ | |
DinoVisionTransformer as _DinoVisionTransformer | |
from dinov2.utils.utils import load_pretrained_weights | |
@BACKBONES.register_module()
class DinoVisionTransformer(_DinoVisionTransformer, BaseModule):
    """DINOv2 Vision Transformer exposed as an ``mmseg`` backbone.

    Subclasses ``dinov2.models.vision_transformer.DinoVisionTransformer`` and
    mmcv's ``BaseModule``. It is registered in mmseg's ``BACKBONES`` registry
    so it can be built from an mmseg config. ``forward`` returns the
    intermediate block outputs selected by ``out_indices``.

    Example backbone settings for the small model::

        img_size=518, patch_size=14, block_chunks=0, init_values=1,
        embed_dim=384, num_heads=6
    """

    def __init__(
        self,
        *args,
        out_indices,
        pretrained=None,
        frozen=True,
        **kwargs,
    ):
        """Build the backbone.

        Args:
            *args: Positional arguments forwarded to the parent
                ``DinoVisionTransformer`` constructor.
            out_indices: Passed as ``n`` to ``get_intermediate_layers`` in
                ``forward`` to choose which blocks' outputs are returned.
            pretrained: Weights reference handed to ``load_pretrained_weights``
                by ``init_weights``. ``None`` means the parent's own
                ``init_weights`` is used instead.
            frozen: When True (the default, matching the original behavior),
                ``requires_grad_(False)`` is applied to the whole module after
                construction. Pass False to keep the backbone trainable.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        # Assigned before super().__init__() on purpose: if the parent
        # constructor calls init_weights() (dispatched to the override below),
        # self.pretrained must already exist.
        # NOTE(review): confirm whether the parent constructor does call it.
        self.pretrained = pretrained
        self.out_indices = out_indices
        # Forward positional args too; previously they were silently dropped.
        super().__init__(*args, **kwargs)
        if frozen:
            self.requires_grad_(False)

    def init_weights(self):
        """Initialize parameters, loading pretrained weights when configured.

        If ``self.pretrained`` is ``None``, defer to the parent's
        ``init_weights``. Otherwise call ``load_pretrained_weights`` with
        ``self.pretrained`` (and ``None`` as the third argument) and print
        where the weights came from.
        """
        if self.pretrained is None:
            super().init_weights()
            return
        load_pretrained_weights(self, self.pretrained, None)
        print("Loaded pretrained weights from", self.pretrained)

    def forward(self, x):
        """Return intermediate-layer features for ``x``.

        Delegates to ``get_intermediate_layers`` with ``n=self.out_indices``,
        ``reshape=True`` and ``norm=False``.
        """
        # NOTE(review): presumably x is an image batch and the result is a
        # sequence of spatial feature maps consumed by an mmseg decode head;
        # confirm against get_intermediate_layers.
        return self.get_intermediate_layers(
            x, n=self.out_indices, reshape=True, norm=False
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment