Created April 2, 2022 09:44
""" | |
Reference: | |
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/deit.py | |
""" | |
import tensorflow as tf | |
from tensorflow.keras import layers | |
from .vit_models import ViTClassifier | |
class ViTDistilled(ViTClassifier):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_tokens = 2

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, self.config.projection_dim))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    self.config.num_patches + self.num_tokens,
                    self.config.projection_dim,
                )
            ),
            name="pos_embedding",
        )
        # Head layers.
        if not self.config.pre_logits:
            self.head = (
                layers.Dense(
                    self.config.num_classes, name="classification_head"
                )
                if self.config.num_classes > 0
                else tf.identity
            )
            self.head_dist = (
                layers.Dense(self.config.num_classes, name="distillation_head")
                if self.config.num_classes > 0
                else tf.identity
            )
    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        if (
            cls_token.dtype != projected_patches.dtype
            or dist_token.dtype != projected_patches.dtype
        ):
            cls_token = tf.cast(cls_token, projected_patches.dtype)
            dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, num_patches + 2, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Initialize a dictionary to store attention scores from each
        # Transformer block.
        attention_scores = dict()

        # Iterate over the number of layers and stack up blocks of Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches, attention_score = transformer_module(
                encoded_patches
            )
            attention_scores[f"{transformer_module.name}_att"] = attention_score

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        if self.config.pre_logits:
            return (
                representation[:, 0] + representation[:, 1]
            ) / 2, attention_scores

        # Classification heads.
        else:
            x = self.head(representation[:, 0])
            x_dist = self.head_dist(representation[:, 1])

            if "distilled" in self.config.name and training:
                # Only return the separate classification predictions when
                # training in distilled mode.
                return x, x_dist, attention_scores
            else:
                # During standard training / fine-tuning and at inference,
                # average the predictions of the two heads. Also return the
                # attention scores.
                return (x + x_dist) / 2, attention_scores
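Below is a minimal usage sketch (not part of the gist) for the distilled model above. The ConfigDict lists only the fields these classes actually read, and every value is an illustrative assumption (roughly DeiT-Tiny sized), not an official configuration.

import ml_collections
import tensorflow as tf

# Illustrative config; all values are assumptions.
config = ml_collections.ConfigDict()
config.name = "deit_tiny_distilled_patch16_224"  # "distilled" enables the dual heads.
config.patch_size = 16
config.num_patches = (224 // config.patch_size) ** 2  # 196
config.projection_dim = 192
config.num_heads = 3
config.num_layers = 12
config.mlp_units = [config.projection_dim * 4, config.projection_dim]
config.dropout_rate = 0.0
config.drop_path_rate = 0.0
config.layer_norm_eps = 1e-6
config.initializer_range = 0.02
config.init_values = None  # Set to e.g. 1e-5 to enable LayerScale.
config.classifier = "token"
config.pre_logits = False
config.num_classes = 1000

deit = ViTDistilled(config)
images = tf.random.normal((2, 224, 224, 3))

# Training mode: separate CLS and distillation logits.
cls_logits, dist_logits, _ = deit(images, training=True)
# Inference mode: the two head outputs are averaged.
logits, attention_scores = deit(images, training=False)
print(cls_logits.shape, dist_logits.shape, logits.shape)  # (2, 1000) each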
mha.py
# Copied and modified from:
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_tf_vit.py

import math
from typing import Tuple

import tensorflow as tf
from ml_collections import ConfigDict
from tensorflow import keras
class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, config: ConfigDict, **kwargs):
        super().__init__(**kwargs)

        if config.projection_dim % config.num_heads != 0:
            raise ValueError(
                f"The hidden size ({config.projection_dim}) is not a multiple of the number "
                f"of attention heads ({config.num_heads})"
            )

        self.num_attention_heads = config.num_heads
        self.attention_head_size = int(config.projection_dim / config.num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size,
            kernel_initializer=keras.initializers.TruncatedNormal(
                stddev=config.initializer_range
            ),
            name="query",
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size,
            kernel_initializer=keras.initializers.TruncatedNormal(
                stddev=config.initializer_range
            ),
            name="key",
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size,
            kernel_initializer=keras.initializers.TruncatedNormal(
                stddev=config.initializer_range
            ),
            name="value",
        )
        self.dropout = keras.layers.Dropout(rate=config.dropout_rate)
    def transpose_for_scores(
        self, tensor: tf.Tensor, batch_size: int
    ) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to
        # [batch_size, seq_length, num_attention_heads, attention_head_size].
        tensor = tf.reshape(
            tensor=tensor,
            shape=(
                batch_size,
                -1,
                self.num_attention_heads,
                self.attention_head_size,
            ),
        )

        # Transpose from [batch_size, seq_length, num_attention_heads, attention_head_size]
        # to [batch_size, num_attention_heads, seq_length, attention_head_size].
        return tf.transpose(tensor, perm=[0, 2, 1, 3])
    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = tf.shape(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch_size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(
            inputs=attention_probs, training=training
        )

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(
            tensor=attention_output, shape=(batch_size, -1, self.all_head_size)
        )
        outputs = (
            (attention_output, attention_probs)
            if output_attentions
            else (attention_output,)
        )

        return outputs
class TFViTSelfOutput(keras.layers.Layer):
    """
    The residual connection is defined in the `transformer()` block of
    `vit_models.py` instead of here (as is the case with other models),
    due to the layernorm applied before each block.
    """

    def __init__(self, config: ConfigDict, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.projection_dim,
            kernel_initializer=keras.initializers.TruncatedNormal(
                stddev=config.initializer_range
            ),
            name="dense",
        )
        self.dropout = keras.layers.Dropout(rate=config.dropout_rate)

    def call(
        self,
        hidden_states: tf.Tensor,
        training: bool = False,
    ) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states
class TFViTAttention(keras.layers.Layer):
    def __init__(self, config: ConfigDict, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFViTSelfAttention(config, name="attention")
        self.dense_output = TFViTSelfOutput(config, name="output")

    def call(
        self,
        input_tensor: tf.Tensor,
        head_mask: tf.Tensor = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], training=training
        )
        # Add the attention probabilities when they were requested.
        outputs = (attention_output,) + self_outputs[1:]

        return outputs
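A quick sketch (not part of the gist) that exercises the attention block on a dummy sequence. The config carries only the fields TFViTSelfAttention and TFViTSelfOutput read; the values are illustrative.

import tensorflow as tf
from ml_collections import ConfigDict

cfg = ConfigDict()
cfg.projection_dim = 64      # hidden size
cfg.num_heads = 4            # must divide projection_dim
cfg.dropout_rate = 0.0
cfg.initializer_range = 0.02

attention = TFViTAttention(cfg)
x = tf.random.normal((2, 197, cfg.projection_dim))  # (batch, seq_len, hidden)
out, probs = attention(x, output_attentions=True, training=False)
print(out.shape)    # (2, 197, 64)
print(probs.shape)  # (2, 4, 197, 197)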
vit_models.py
""" | |
Vision Transformer model class put together by | |
Aritra (ariG23498) and Sayak (sayakpaul) | |
Reference: | |
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py | |
""" | |
from typing import List | |
import ml_collections | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from mha import TFViTAttention | |
def mlp(
    x: tf.Tensor, dropout_rate: float, hidden_units: List[int]
) -> tf.Tensor:
    """FFN for a Transformer block."""
    # Iterate over the hidden units and add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
            kernel_initializer="glorot_uniform",
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            # One Bernoulli draw per example; broadcast over the remaining dims.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
# Referred from: github.com:rwightman/pytorch-image-models.
class LayerScale(layers.Layer):
    def __init__(self, config: ml_collections.ConfigDict, **kwargs):
        super().__init__(**kwargs)
        self.gamma = tf.Variable(
            config.init_values * tf.ones((config.projection_dim,)),
            name="layer_scale",
        )

    def call(self, x):
        return x * self.gamma
def transformer(
    config: ml_collections.ConfigDict, name: str, drop_prob=0.0
) -> keras.Model:
    """Transformer block with pre-norm."""
    if "distilled" in config.name:
        num_patches = config.num_patches + 2  # CLS and distillation tokens.
    elif config.classifier == "token":
        num_patches = config.num_patches + 1  # CLS token only.
    else:  # GAP classifier; should not be used during weight porting.
        assert (
            "distilled" not in config.name
        ), "Distillation token is not suitable for GAP."
        num_patches = config.num_patches

    encoded_patches = layers.Input((num_patches, config.projection_dim))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(
        encoded_patches
    )

    # Multi Head Self Attention layer 1.
    attention_output, attention_score = TFViTAttention(config)(
        x1, output_attentions=True
    )
    attention_output = (
        LayerScale(config)(attention_output)
        if config.init_values
        else attention_output
    )
    attention_output = (
        StochasticDepth(drop_prob)(attention_output)
        if drop_prob
        else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(x2)

    # MLP layer 1.
    x4 = mlp(
        x3, hidden_units=config.mlp_units, dropout_rate=config.dropout_rate
    )
    x4 = LayerScale(config)(x4) if config.init_values else x4
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, [outputs, attention_score], name=name)
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, config: ml_collections.ConfigDict, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        # Patchify + embedding.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=config.projection_dim,
                    kernel_size=(config.patch_size, config.patch_size),
                    strides=(config.patch_size, config.patch_size),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(config.num_patches, config.projection_dim),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_value = tf.ones(
            (
                1,
                config.num_patches + 1
                if self.config.classifier == "token"
                else config.num_patches,
                config.projection_dim,
            )
        )
        self.positional_embedding = tf.Variable(
            init_value, name="pos_embedding"
        )  # This will be loaded with the pre-trained positional embeddings later.

        # Transformer blocks.
        dpr = [
            x
            for x in tf.linspace(
                0.0, self.config.drop_path_rate, self.config.num_layers
            )
        ]
        self.transformer_blocks = [
            transformer(config, name=f"transformer_block_{i}", drop_prob=dpr[i])
            for i in range(config.num_layers)
        ]

        # CLS token or GAP.
        if config.classifier == "token":
            initial_value = tf.zeros((1, 1, config.projection_dim))
            self.cls_token = tf.Variable(
                initial_value=initial_value, trainable=True, name="cls"
            )

        if config.classifier == "gap":
            self.gap_layer = layers.GlobalAvgPool1D()

        # Other layers.
        self.dropout = layers.Dropout(config.dropout_rate)
        self.layer_norm = layers.LayerNormalization(
            epsilon=config.layer_norm_eps
        )

        if not self.config.pre_logits:
            self.head = layers.Dense(
                config.num_classes,
                kernel_initializer="zeros",
                dtype="float32",
                name="classification_head",
            )
    def call(self, inputs, training=None):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token if needed.
        if self.config.classifier == "token":
            cls_token = tf.tile(self.cls_token, (n, 1, 1))
            if cls_token.dtype != projected_patches.dtype:
                cls_token = tf.cast(cls_token, projected_patches.dtype)
            projected_patches = tf.concat(
                [cls_token, projected_patches], axis=1
            )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, num_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Initialize a dictionary to store attention scores from each
        # Transformer block.
        attention_scores = dict()

        # Iterate over the number of layers and stack up blocks of Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches, attention_score = transformer_module(
                encoded_patches
            )
            attention_scores[f"{transformer_module.name}_att"] = attention_score

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        if self.config.classifier == "token":
            encoded_patches = representation[:, 0]
        elif self.config.classifier == "gap":
            encoded_patches = self.gap_layer(representation)

        if self.config.pre_logits:
            return encoded_patches, attention_scores
        # Classification head.
        else:
            output = self.head(encoded_patches)
            return output, attention_scores
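A similar sketch (not part of the gist) for the plain classifier defined in this file; all config values are again illustrative assumptions, and `tf` / `ml_collections` are the modules imported at the top of the file.

# Illustrative config; values are assumptions, not an official configuration.
cfg = ml_collections.ConfigDict(
    dict(
        name="vit_base_patch16_224",  # no "distilled" in the name -> single head
        patch_size=16,
        num_patches=(224 // 16) ** 2,
        projection_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_units=[3072, 768],
        dropout_rate=0.0,
        drop_path_rate=0.0,
        layer_norm_eps=1e-6,
        initializer_range=0.02,
        init_values=None,
        classifier="token",
        pre_logits=False,
        num_classes=1000,
    )
)

vit = ViTClassifier(cfg)
logits, attention_scores = vit(tf.random.normal((2, 224, 224, 3)), training=False)
print(logits.shape)  # (2, 1000)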