@sayakpaul
Created April 2, 2022 09:44
"""
Reference:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/deit.py
"""
import tensorflow as tf
from tensorflow.keras import layers
from .vit_models import ViTClassifier
class ViTDistilled(ViTClassifier):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_tokens = 2
# CLS and distillation tokens, positional embedding.
init_value = tf.zeros((1, 1, self.config.projection_dim))
self.dist_token = tf.Variable(init_value, name="dist_token")
self.positional_embedding = tf.Variable(
tf.zeros(
(
1,
self.config.num_patches + self.num_tokens,
self.config.projection_dim,
)
),
name="pos_embedding",
)
# Head layers.
if not self.config.pre_logits:
self.head = (
layers.Dense(
self.config.num_classes, name="classification_head"
)
if self.config.num_classes > 0
                else tf.identity
)
self.head_dist = (
layers.Dense(self.config.num_classes, name="distillation_head")
if self.config.num_classes > 0
                else tf.identity
)
def call(self, inputs, training=True):
n = tf.shape(inputs)[0]
# Create patches and project the patches.
projected_patches = self.projection(inputs)
# Append the tokens.
cls_token = tf.tile(self.cls_token, (n, 1, 1))
dist_token = tf.tile(self.dist_token, (n, 1, 1))
        if (
            cls_token.dtype != projected_patches.dtype
            or dist_token.dtype != projected_patches.dtype
        ):
            cls_token = tf.cast(cls_token, projected_patches.dtype)
            dist_token = tf.cast(dist_token, projected_patches.dtype)
projected_patches = tf.concat(
[cls_token, dist_token, projected_patches], axis=1
)
# Add positional embeddings to the projected patches.
encoded_patches = (
self.positional_embedding + projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = self.dropout(encoded_patches)
# Initialize a dictionary to store attention scores from each transformer
# block.
attention_scores = dict()
# Iterate over the number of layers and stack up blocks of
# Transformer.
for transformer_module in self.transformer_blocks:
# Add a Transformer block.
encoded_patches, attention_score = transformer_module(
encoded_patches
)
attention_scores[f"{transformer_module.name}_att"] = attention_score
# Final layer normalization.
representation = self.layer_norm(encoded_patches)
# Pool representation.
if self.config.pre_logits:
return (
representation[:, 0] + representation[:, 1]
) / 2, attention_scores
# Classification heads.
else:
x, x_dist = self.head(representation[:, 0]), self.head_dist(
representation[:, 1]
)
if "distilled" in self.config.name and training:
# Only return separate classification predictions when training in distilled mode.
return x, x_dist, attention_scores
else:
# During standard train / finetune, inference average the classifier predictions.
# Additionally, return the attention scores too.
return (x + x_dist) / 2, attention_scores
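
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). The ConfigDict fields below
# cover the attributes this class and its parent read; the concrete values
# are hypothetical DeiT-Ti/16-style settings, not the author's configs.
#
#   import ml_collections
#
#   config = ml_collections.ConfigDict()
#   config.name = "deit_tiny_distilled_patch16_224"
#   config.patch_size = 16
#   config.num_patches = (224 // config.patch_size) ** 2  # 196
#   config.projection_dim = 192
#   config.num_heads = 3
#   config.num_layers = 12
#   config.mlp_units = [4 * config.projection_dim, config.projection_dim]
#   config.dropout_rate = 0.0
#   config.drop_path_rate = 0.1
#   config.layer_norm_eps = 1e-6
#   config.initializer_range = 0.02
#   config.init_values = None  # LayerScale disabled, as in vanilla DeiT.
#   config.classifier = "token"
#   config.pre_logits = False
#   config.num_classes = 1000
#
#   model = ViTDistilled(config)
#   dummy = tf.zeros((2, 224, 224, 3))
#   # Distilled mode + training=True returns both heads plus attention maps.
#   cls_logits, dist_logits, attn_scores = model(dummy, training=True)
#   # At inference the two head outputs are averaged.
#   avg_logits, attn_scores = model(dummy, training=False)
# ---------------------------------------------------------------------------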
# Copied and modified from:
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_tf_vit.py
import math
from typing import Tuple
import tensorflow as tf
from ml_collections import ConfigDict
from tensorflow import keras
class TFViTSelfAttention(keras.layers.Layer):
def __init__(self, config: ConfigDict, **kwargs):
super().__init__(**kwargs)
if config.projection_dim % config.num_heads != 0:
raise ValueError(
f"The hidden size ({config.projection_dim}) is not a multiple of the number "
f"of attention heads ({config.num_heads})"
)
self.num_attention_heads = config.num_heads
self.attention_head_size = int(config.projection_dim / config.num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = keras.layers.Dense(
units=self.all_head_size,
kernel_initializer=keras.initializers.TruncatedNormal(
stddev=config.initializer_range
),
name="query",
)
self.key = keras.layers.Dense(
units=self.all_head_size,
kernel_initializer=keras.initializers.TruncatedNormal(
stddev=config.initializer_range
),
name="key",
)
self.value = keras.layers.Dense(
units=self.all_head_size,
kernel_initializer=keras.initializers.TruncatedNormal(
stddev=config.initializer_range
),
name="value",
)
self.dropout = keras.layers.Dropout(rate=config.dropout_rate)
def transpose_for_scores(
self, tensor: tf.Tensor, batch_size: int
) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(
tensor=tensor,
shape=(
batch_size,
-1,
self.num_attention_heads,
self.attention_head_size,
),
)
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor = None,
output_attentions: bool = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = tf.shape(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(
inputs=attention_probs, training=training
)
# Mask heads if we want to
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(
tensor=attention_output, shape=(batch_size, -1, self.all_head_size)
)
outputs = (
(attention_output, attention_probs)
if output_attentions
else (attention_output,)
)
return outputs
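
# Note on the return convention above (kept from the Hugging Face TF ViT
# port): `outputs` is always a tuple whose first element has shape
# (batch_size, seq_len, all_head_size); when `output_attentions=True`, the
# second element carries the post-softmax attention probabilities of shape
# (batch_size, num_heads, seq_len, seq_len), which the transformer block
# later exposes as its per-layer attention scores.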
class TFViTSelfOutput(keras.layers.Layer):
"""
The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ConfigDict, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.projection_dim,
kernel_initializer=keras.initializers.TruncatedNormal(
stddev=config.initializer_range
),
name="dense",
)
self.dropout = keras.layers.Dropout(rate=config.dropout_rate)
def call(
self,
hidden_states: tf.Tensor,
training: bool = False,
) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
return hidden_states
class TFViTAttention(keras.layers.Layer):
def __init__(self, config: ConfigDict, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFViTSelfAttention(config, name="attention")
self.dense_output = TFViTSelfOutput(config, name="output")
def call(
self,
input_tensor: tf.Tensor,
head_mask: tf.Tensor = None,
output_attentions: bool = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor,
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], training=training
        )
        # Prepend the projected attention output; the attention probabilities
        # (present only when requested) ride along in the rest of the tuple.
        outputs = (attention_output,) + self_outputs[1:]
        return outputs
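

if __name__ == "__main__":
    # Quick standalone shape check (a sketch, not part of the original gist).
    # Only the config fields read above are set, with hypothetical values.
    cfg = ConfigDict()
    cfg.projection_dim = 192
    cfg.num_heads = 3
    cfg.dropout_rate = 0.0
    cfg.initializer_range = 0.02

    attn = TFViTAttention(cfg)
    tokens = tf.random.normal((2, 198, cfg.projection_dim))  # 196 patches + 2 tokens
    out, probs = attn(tokens, output_attentions=True)
    print(out.shape)    # (2, 198, 192)
    print(probs.shape)  # (2, 3, 198, 198)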
"""
Vision Transformer model class put together by
Aritra (ariG23498) and Sayak (sayakpaul)
Reference:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import List
import ml_collections
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from mha import TFViTAttention
def mlp(x: tf.Tensor, dropout_rate: float, hidden_units: List[int]) -> tf.Tensor:
"""FFN for a Transformer block."""
# Iterate over the hidden units and
# add Dense => Dropout.
for (idx, units) in enumerate(hidden_units):
x = layers.Dense(
units,
activation=tf.nn.gelu if idx == 0 else None,
kernel_initializer="glorot_uniform",
bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
)(x)
x = layers.Dropout(dropout_rate)(x)
return x
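
# Note (not in the original): the last entry of `hidden_units` must equal
# `config.projection_dim`, because the transformer block below adds the MLP
# output back onto its input through a residual connection. A typical choice
# is hidden_units=[4 * projection_dim, projection_dim] (e.g. [768, 192] for a
# DeiT-Ti-sized model).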
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super(StochasticDepth, self).__init__(**kwargs)
        self.drop_prob = drop_prob
def call(self, x, training=None):
if training:
keep_prob = 1 - self.drop_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
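
# Stochastic depth drops the entire residual branch for a given sample with
# probability `drop_prob` during training and rescales the surviving samples
# by 1 / keep_prob, so no change is needed at inference time. A small
# illustration (hypothetical values, not part of the original file):
#
#   sd = StochasticDepth(0.5)
#   x = tf.ones((4, 3, 8))
#   y = sd(x, training=True)  # roughly half the samples in the batch are zeroed
#   tf.debugging.assert_equal(sd(x, training=False), x)  # identity at inference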
# Referred from: github.com:rwightman/pytorch-image-models.
class LayerScale(layers.Layer):
def __init__(self, config: ml_collections.ConfigDict, **kwargs):
super().__init__(**kwargs)
self.gamma = tf.Variable(
config.init_values * tf.ones((config.projection_dim,)),
name="layer_scale",
)
def call(self, x):
return x * self.gamma
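
# LayerScale (from CaiT) multiplies each residual branch by a learnable
# per-channel `gamma`, initialized to the small constant `config.init_values`
# (typically 1e-5 or 1e-6). Configs that leave `init_values` as None/0 skip
# this layer entirely inside `transformer` below.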
def transformer(
config: ml_collections.ConfigDict, name: str, drop_prob=0.0
) -> keras.Model:
"""Transformer block with pre-norm."""
    # Number of tokens the block will see: patches plus the CLS (and
    # distillation) tokens, depending on the classifier configuration.
    if "distilled" in config.name:
        num_patches = config.num_patches + 2
    elif config.classifier == "token":
        num_patches = config.num_patches + 1
    elif (
        config.classifier == "gap"
    ):  # This setting should not be used during weight porting.
        assert (
            "distilled" not in config.name
        ), "Distillation token is not suitable for GAP."
        num_patches = config.num_patches
encoded_patches = layers.Input((num_patches, config.projection_dim))
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(
encoded_patches
)
# Multi Head Self Attention layer 1.
attention_output, attention_score = TFViTAttention(config)(
x1, output_attentions=True
)
attention_output = (
LayerScale(config)(attention_output)
if config.init_values
else attention_output
)
attention_output = (
StochasticDepth(drop_prob)(attention_output)
if drop_prob
else attention_output
)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(x2)
# MLP layer 1.
x4 = mlp(
x3, hidden_units=config.mlp_units, dropout_rate=config.dropout_rate
)
x4 = LayerScale(config)(x4) if config.init_values else x4
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4
# Skip connection 2.
outputs = layers.Add()([x2, x4])
return keras.Model(encoded_patches, [outputs, attention_score], name=name)
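
# Usage sketch (not part of the original file): each block is a standalone
# keras.Model mapping a token sequence to (new tokens, attention scores).
# Assuming a "token"-classifier config without "distilled" in its name, so the
# block expects num_patches + 1 tokens:
#
#   block = transformer(config, name="transformer_block_0", drop_prob=0.0)
#   tokens = tf.random.normal(
#       (2, config.num_patches + 1, config.projection_dim)
#   )
#   new_tokens, attn = block(tokens)  # attn: (2, num_heads, seq_len, seq_len)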
class ViTClassifier(keras.Model):
"""Vision Transformer base class."""
def __init__(self, config: ml_collections.ConfigDict, **kwargs):
super().__init__(**kwargs)
self.config = config
# Patchify + embedding.
self.projection = keras.Sequential(
[
layers.Conv2D(
filters=config.projection_dim,
kernel_size=(config.patch_size, config.patch_size),
strides=(config.patch_size, config.patch_size),
padding="VALID",
name="conv_projection",
),
layers.Reshape(
target_shape=(config.num_patches, config.projection_dim),
name="flatten_projection",
),
],
name="projection",
)
# Positional embedding.
init_value = tf.ones(
(
1,
config.num_patches + 1
if self.config.classifier == "token"
else config.num_patches,
config.projection_dim,
)
)
self.positional_embedding = tf.Variable(
init_value, name="pos_embedding"
) # This will be loaded with the pre-trained positional embeddings later.
# Transformer blocks.
dpr = [
x
for x in tf.linspace(
0.0, self.config.drop_path_rate, self.config.num_layers
)
]
self.transformer_blocks = [
transformer(config, name=f"transformer_block_{i}", drop_prob=dpr[i])
for i in range(config.num_layers)
]
# CLS token or GAP.
if config.classifier == "token":
initial_value = tf.zeros((1, 1, config.projection_dim))
self.cls_token = tf.Variable(
initial_value=initial_value, trainable=True, name="cls"
)
if config.classifier == "gap":
self.gap_layer = layers.GlobalAvgPool1D()
# Other layers.
self.dropout = layers.Dropout(config.dropout_rate)
self.layer_norm = layers.LayerNormalization(
epsilon=config.layer_norm_eps
)
if not self.config.pre_logits:
self.head = layers.Dense(
config.num_classes,
kernel_initializer="zeros",
dtype="float32",
name="classification_head",
)
def call(self, inputs, training=None):
n = tf.shape(inputs)[0]
# Create patches and project the patches.
projected_patches = self.projection(inputs)
# Append class token if needed.
if self.config.classifier == "token":
cls_token = tf.tile(self.cls_token, (n, 1, 1))
if cls_token.dtype != projected_patches.dtype:
cls_token = tf.cast(cls_token, projected_patches.dtype)
projected_patches = tf.concat(
[cls_token, projected_patches], axis=1
)
# Add positional embeddings to the projected patches.
encoded_patches = (
self.positional_embedding + projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = self.dropout(encoded_patches)
# Initialize a dictionary to store attention scores from each transformer
# block.
attention_scores = dict()
# Iterate over the number of layers and stack up blocks of
# Transformer.
for transformer_module in self.transformer_blocks:
# Add a Transformer block.
encoded_patches, attention_score = transformer_module(
encoded_patches
)
attention_scores[f"{transformer_module.name}_att"] = attention_score
# Final layer normalization.
representation = self.layer_norm(encoded_patches)
# Pool representation.
if self.config.classifier == "token":
encoded_patches = representation[:, 0]
elif self.config.classifier == "gap":
encoded_patches = self.gap_layer(representation)
if self.config.pre_logits:
return encoded_patches, attention_scores
# Classification head.
else:
output = self.head(encoded_patches)
return output, attention_scores
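

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): the plain, non-distilled
# classifier with a CLS token. The config needs the same fields as the one
# sketched under ViTDistilled above, but its `name` must not contain
# "distilled" (otherwise the transformer blocks expect an extra token) and
# `classifier` should be "token" or "gap".
#
#   model = ViTClassifier(config)
#   logits, attn_scores = model(tf.zeros((2, 224, 224, 3)))
#   # logits: (2, config.num_classes); attn_scores: a dict with one entry per
#   # block, each of shape (2, num_heads, num_patches + 1, num_patches + 1).
# ---------------------------------------------------------------------------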