Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import functools
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_function
# Batch-normalization hyperparameters shared by every BNReLU call below.
BATCH_NORM_DECAY = 0.9  # passed as `momentum` to tf.layers.batch_normalization
BATCH_NORM_EPSILON = 1e-5  # passed as `epsilon` to tf.layers.batch_normalization
def Activation(inputs, activation='relu'):
    """Applies a pointwise nonlinearity to `inputs`.

    Only 'relu' and 'silu' are accepted; 'silu' is computed with
    tf.nn.swish (SiLU/Swish).
    """
    assert activation in ['relu', 'silu']
    nonlinearity = tf.nn.relu if activation == 'relu' else tf.nn.swish
    return nonlinearity(inputs)
def BNReLU(
        inputs, is_training, nonlinearity=True,
        init_zero=False, activation='relu'):
    """Batch-normalizes `inputs` along axis 3, then optionally activates.

    Args:
      inputs: tensor to normalize; statistics are taken per index of axis 3
        (the channel axis, assuming NHWC — TODO confirm with callers).
      is_training: forwarded as `training` to tf.layers.batch_normalization.
      nonlinearity: when True, the result is passed through Activation.
      init_zero: when True, the BN scale (gamma) starts at zero instead of one.
      activation: 'relu' or 'silu', forwarded to Activation.

    Returns:
      The normalized tensor, activated when `nonlinearity` is True.
    """
    gamma_init = (tf.zeros_initializer() if init_zero
                  else tf.ones_initializer())
    normalized = tf.layers.batch_normalization(
        inputs=inputs,
        axis=3,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        center=True,
        scale=True,
        training=is_training,
        fused=True,
        gamma_initializer=gamma_init)
    if not nonlinearity:
        return normalized
    return Activation(normalized, activation=activation)
def fixed_padding(inputs, kernel_size):
    """Zero-pads the two spatial axes by an amount that depends only on kernel_size.

    Each of axes 1 and 2 receives `kernel_size - 1` zeros in total: the floor
    half before the data and the remainder after it. Axes 0 and 3 are not
    padded (NHWC layout assumed — TODO confirm).

    Args:
      inputs: rank-4 tensor.
      kernel_size: integer convolution kernel size.

    Returns:
      The padded tensor.
    """
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    # One [before, after] pair per axis of the rank-4 input.
    return tf.pad(
        inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
def Conv2D(inputs, *, filters, kernel_size, strides=1):
    """Bias-free 2-D convolution with explicit padding for strided cases.

    With strides == 1 the convolution uses 'SAME' padding. For strides > 1
    the input is first padded with fixed_padding; any stride other than 1
    convolves with 'VALID'. Kernels use a variance-scaling initializer
    (scale=2.0, fan_in, untruncated normal).
    """
    if strides > 1:
        inputs = fixed_padding(inputs, kernel_size)
    padding_mode = 'SAME' if strides == 1 else 'VALID'
    initializer = tf.variance_scaling_initializer(
        scale=2., mode='fan_in', distribution='untruncated_normal')
    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding_mode,
        use_bias=False,
        kernel_initializer=initializer)
# Functions `rel_to_abs`, `relative_logits_1d`, `relative_logits`
# and `relpos_self_attention` are fully based on
# https://github.com/tensorflow/tensor2tensor/blob/21dba2c1bdcc7ab582a2bfd8c0885c217963bb4f/tensor2tensor/layers/common_attention.py#L2225.
def rel_to_abs(x):
    """Converts relative-offset logits to absolute-position logits.

    Input:  [bs, heads, length, 2*length - 1]; the last axis indexes the
            relative offset, with offset 0 at index length - 1.
    Output: [bs, heads, length, length], where
            output[..., i, j] == x[..., i, j - i + length - 1].

    The conversion appends one zero column and flattens each row group. It
    then appends `length - 1` zeros and re-views the result as `length + 1`
    rows of width `2*length - 1`. This shifts row i right by i positions, so
    the final slice selects the absolute-index entries.

    `heads` and `length` must be statically known. The batch dimension may be
    dynamic, because zeros are added with tf.pad rather than tf.zeros of a
    batch-sized static shape.
    """
    _, heads, length, _ = [tf.dimension_value(d) for d in x.shape]
    # [bs, heads, length, 2*length - 1] -> [bs, heads, length, 2*length]
    x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, 1]])
    # -> [bs, heads, length * 2*length]
    flat_x = tf.reshape(x, [-1, heads, length * 2 * length])
    # -> [bs, heads, (length + 1) * (2*length - 1)]
    flat_x_padded = tf.pad(flat_x, [[0, 0], [0, 0], [0, length - 1]])
    final_x = tf.reshape(
        flat_x_padded, [-1, heads, length + 1, 2 * length - 1])
    return final_x[:, :, :length, length - 1:]
def relative_logits_1d(*, q, rel_k, transpose_mask):
    """Computes relative-position logits along one spatial dimension.

    Args:
      q: queries, [bs, heads, height, width, dim].
      rel_k: relative embeddings along the width axis, [2*width - 1, dim].
      transpose_mask: permutation applied to the rank-6 result.

    Returns:
      Rank-6 logits. Before `transpose_mask` is applied the axes are
      [bs, heads, query_row, key_row, query_col, key_col]. The key_row axis is
      a tiled copy, because these logits depend only on the width offset.
    """
    _, num_heads, height, width, _ = q.shape
    logits = tf.einsum('bhxyd,md->bhxym', q, rel_k)
    # Merge heads and rows so rel_to_abs sees a rank-4 tensor.
    logits = tf.reshape(logits, [-1, num_heads * height, width, 2*width-1])
    logits = rel_to_abs(logits)
    logits = tf.reshape(logits, [-1, num_heads, height, width, width])
    # Insert the key_row axis at position 3 and repeat it `height` times.
    logits = tf.tile(tf.expand_dims(logits, axis=3), [1, 1, 1, height, 1, 1])
    return tf.transpose(logits, transpose_mask)
def relative_logits(q):
    """Computes 2-D relative-position logits for queries `q`.

    Creates (or reuses, via tf.AUTO_REUSE in the 'relative' scope) two
    embedding tables: 'r_width' [2*w - 1, dim] and 'r_height' [2*h - 1, dim].
    It returns the sum of the per-axis logits from relative_logits_1d.

    Args:
      q: queries, [bs, heads, h, w, dim], with h, w and dim statically known.

    Returns:
      Logits laid out as [bs, heads, query_row, query_col, key_row, key_col].
      This matches the 'bhHWPQ' layout used by relpos_self_attention.
    """
    with tf.variable_scope('relative', reuse=tf.AUTO_REUSE):
        # tf.dimension_value accepts v1 Dimension objects and plain ints alike.
        _, _, h, w, int_dim = [tf.dimension_value(d) for d in q.shape]
        # Pass the scale as `stddev`. Passed positionally it would be the
        # initializer's *mean*, with stddev left at its default of 1.0.
        rel_emb_w = tf.get_variable(
            'r_width', shape=(2 * w - 1, int_dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(stddev=int_dim ** -0.5))
        # Relative logits along the width axis.
        rel_logits_w = relative_logits_1d(
            q=q, rel_k=rel_emb_w,
            transpose_mask=[0, 1, 2, 4, 3, 5])
        rel_emb_h = tf.get_variable(
            'r_height', shape=(2 * h - 1, int_dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(stddev=int_dim ** -0.5))
        # Relative logits along the height axis: swap h and w, reuse the 1-D
        # routine, then permute back to the shared layout.
        rel_logits_h = relative_logits_1d(
            q=tf.transpose(q, [0, 1, 3, 2, 4]),
            rel_k=rel_emb_h,
            transpose_mask=[0, 1, 4, 2, 5, 3])
        return rel_logits_h + rel_logits_w
def relpos_self_attention(
        *, q, k, v, relative=True, fold_heads=False):
    """2-D multi-head self-attention with optional relative-position logits.

    Args:
      q, k, v: [bs, heads, h, w, dim]. heads, h, w and dim must be static.
      relative: when True, add relative_logits(q) to the content logits.
      fold_heads: when True, merge heads into channels.

    Returns:
      [bs, h, w, heads, dim], or [bs, h, w, heads * dim] if fold_heads.
    """
    # Convert static dims to Python ints so the arithmetic below works under
    # both v1 (Dimension objects) and v2 (plain ints) shape behavior.
    _, heads, h, w, dim = [tf.dimension_value(d) for d in q.shape]
    q = q * (dim ** -0.5)  # scaled dot-product
    logits = tf.einsum('bhHWd,bhPQd->bhHWPQ', q, k)
    if relative:
        logits += relative_logits(q)
    # Softmax jointly over all h*w key positions of each query.
    weights = tf.reshape(logits, [-1, heads, h, w, h * w])
    weights = tf.nn.softmax(weights)
    weights = tf.reshape(weights, [-1, heads, h, w, h, w])
    attn_out = tf.einsum('bhHWPQ,bhPQd->bHWhd', weights, v)
    if fold_heads:
        attn_out = tf.reshape(attn_out, [-1, h, w, heads * dim])
    return attn_out
def absolute_logits(q):
    """Computes absolute-position logits for queries `q`.

    Creates (or reuses, via tf.AUTO_REUSE in the 'absolute' scope) a width
    table 'r_width' [w, dim] and a height table 'r_height' [h, dim]. These
    are broadcast-summed into an [h, w, dim] position embedding that every
    query is dotted against.

    Args:
      q: queries, [bs, heads, h, w, dim], with h, w and dim statically known.

    Returns:
      Logits laid out as [bs, heads, query_row, query_col, key_row, key_col].
    """
    with tf.variable_scope('absolute', reuse=tf.AUTO_REUSE):
        # Table sizes come from q's static shape (the names used here were
        # otherwise undefined).
        _, _, h, w, dim = [tf.dimension_value(d) for d in q.shape]
        emb_w = tf.get_variable(
            'r_width', shape=(w, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(stddev=dim ** -0.5))
        emb_h = tf.get_variable(
            'r_height', shape=(h, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(stddev=dim ** -0.5))
        # [h, 1, dim] + [1, w, dim] -> [h, w, dim]
        emb = emb_h[:, None, :] + emb_w[None, :, :]
        return tf.einsum('bhxyd,pqd->bhxypq', q, emb)
def abspos_self_attention(*, q, k, v, absolute=True, fold_heads=False,
                          absolue=None):
    """2-D multi-head self-attention with optional absolute-position logits.

    Args:
      q, k, v: [bs, heads, h, w, dim]. heads, h, w and dim must be static.
      absolute: when True, add absolute_logits(q) to the content logits.
      fold_heads: when True, merge heads into channels.
      absolue: misspelled legacy alias of `absolute`. When given, it takes
        precedence, so callers using the old keyword keep working.

    Returns:
      [bs, h, w, heads, dim], or [bs, h, w, heads * dim] if fold_heads.
    """
    if absolue is not None:
        absolute = absolue
    # Plain ints make the arithmetic work under v1 and v2 shape behavior.
    _, heads, h, w, dim = [tf.dimension_value(d) for d in q.shape]
    q = q * (dim ** -0.5)  # scaled dot-product
    logits = tf.einsum('bhHWd,bhPQd->bhHWPQ', q, k)
    if absolute:
        logits += absolute_logits(q)
    # Softmax jointly over all h*w key positions of each query.
    weights = tf.reshape(logits, [-1, heads, h, w, h * w])
    weights = tf.nn.softmax(weights)
    weights = tf.reshape(weights, [-1, heads, h, w, h, w])
    attn_out = tf.einsum('bhHWPQ,bhPQd->bHWhd', weights, v)
    if fold_heads:
        attn_out = tf.reshape(attn_out, [-1, h, w, heads * dim])
    return attn_out
def group_pointwise(
        featuremap, proj_factor=1, name='grouppoint',
        heads=4, target_dimension=None):
    """Per-head 1x1 projection: [b, H, W, C] -> [b, heads, H, W, d].

    The total projected width is `target_dimension // proj_factor`, or
    `C // proj_factor` when target_dimension is None. Each head receives
    d = width // heads channels. The weight 'w' of shape [C, heads, d] is
    created (or reused) in variable scope `name` with a normal initializer
    of stddev 0.01.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        in_channels = featuremap.shape[-1]
        base_width = in_channels if target_dimension is None else target_dimension
        proj_channels = base_width // proj_factor
        kernel = tf.get_variable(
            'w',
            [in_channels, heads, proj_channels // heads],
            dtype=featuremap.dtype,
            initializer=tf.random_normal_initializer(stddev=0.01))
        return tf.einsum('bHWD,Dhd->bhHWd', featuremap, kernel)
def MHSA(featuremap, pos_enc_type='relative', use_pos=True,
         heads=4, bottleneck_dimension=None):
    """Multi-head self-attention over a 2-D feature map.

    Projects `featuremap` to per-head queries, keys and values with
    group_pointwise (scopes 'q_proj', 'k_proj', 'v_proj'). It then runs
    relative- or absolute-position self-attention and folds the heads back
    into the channel axis.

    Args:
      featuremap: [bs, H, W, C] input.
      pos_enc_type: 'relative' or 'absolute'.
      use_pos: whether to add positional logits to the content logits.
      heads: number of attention heads.
      bottleneck_dimension: total projection width across all heads. None
        keeps the input channel count C.

    Returns:
      [bs, H, W, heads * (width // heads)], where width is
      bottleneck_dimension (or C when None).
    """
    assert pos_enc_type in ['relative', 'absolute']
    q = group_pointwise(
        featuremap, proj_factor=1, name='q_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    k = group_pointwise(
        featuremap, proj_factor=1, name='k_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    v = group_pointwise(
        featuremap, proj_factor=1, name='v_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    if pos_enc_type == 'relative':
        return relpos_self_attention(
            q=q, k=k, v=v, relative=use_pos, fold_heads=True)
    return abspos_self_attention(
        q=q, k=k, v=v, absolute=use_pos, fold_heads=True)


def BoT_Block(
        featuremap, is_training=False,
        heads=4, proj_factor=4,
        activation='relu',
        pos_enc_type='relative',
        name='all2all', strides=1,
        target_dimension=2048):
    """Bottleneck Transformer (BoT) block.

    Main path:
      1x1 conv to target_dimension // proj_factor -> BN + activation
      -> MHSA with `heads` heads
      -> 2x2 average pool (only when strides == 2) -> BN + activation
      -> 1x1 conv to target_dimension -> BN (gamma initialized to zero).

    The shortcut is projected with a strided 1x1 conv + BN + activation
    whenever the stride or the channel count changes. The block returns
    Activation(shortcut + main_path).
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        shortcut = featuremap
        # Plain int (or None if unknown); an unknown width forces projection.
        in_dimension = tf.dimension_value(featuremap.shape[-1])
        if strides != 1 or in_dimension != target_dimension:
            shortcut = Conv2D(
                shortcut, filters=target_dimension, kernel_size=1,
                strides=strides)
            shortcut = BNReLU(
                shortcut, is_training, activation=activation,
                nonlinearity=True)

        bottleneck_dimension = target_dimension // proj_factor
        featuremap = Conv2D(
            featuremap, filters=bottleneck_dimension, kernel_size=1, strides=1)
        featuremap = BNReLU(
            featuremap, is_training, activation=activation, nonlinearity=True)
        # MHSA needs this block's head count and projection width.
        featuremap = MHSA(
            featuremap, pos_enc_type=pos_enc_type, heads=heads,
            bottleneck_dimension=bottleneck_dimension)
        if strides != 1:
            assert strides == 2
            featuremap = tf.keras.layers.AveragePooling2D(
                pool_size=(2, 2), strides=(2, 2), padding='same')(featuremap)
        # The post-attention BN + activation applies whether or not we pooled.
        featuremap = BNReLU(
            featuremap, is_training, activation=activation, nonlinearity=True)
        featuremap = Conv2D(
            featuremap, filters=target_dimension, kernel_size=1, strides=1)
        featuremap = BNReLU(
            featuremap, is_training, nonlinearity=False, init_zero=True)
        return Activation(shortcut + featuremap, activation=activation)
def BoT_Stack(
        featuremap, *,
        blocks_so_far,
        total_blocks,
        is_training=False,
        heads=4, proj_factor=4,
        activation='relu',
        pos_enc_type='relative',
        name='all2all_stack',
        strides=2, num_layers=3,
        target_dimension=2048):
    """Builds the c5 block group: `num_layers` consecutive BoT blocks.

    Only the first block uses `strides`; every later block uses stride 1.
    Blocks are scoped 'all2all_layer_0', 'all2all_layer_1', ... inside `name`.
    `blocks_so_far` and `total_blocks` are required keywords but are not
    used in this function.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        for layer_idx in range(num_layers):
            layer_strides = 1 if layer_idx else strides
            featuremap = BoT_Block(
                featuremap,
                is_training=is_training,
                heads=heads,
                proj_factor=proj_factor,
                activation=activation,
                pos_enc_type=pos_enc_type,
                strides=layer_strides,
                target_dimension=target_dimension,
                name='all2all_layer_{}'.format(layer_idx))
    return featuremap
@lartpang
Copy link

lartpang commented Jan 29, 2021

This is a very solid work.
But I am very confused about the way the relative position enc logits is used here, why should it be set like this?
Can you give a more detailed explanation?

@ShoufaChen
Copy link

ShoufaChen commented Jan 31, 2021

A PyTorch implementation of botnet.py.

@leondgarse
Copy link

leondgarse commented Feb 3, 2021

A Keras version on tensorflow 2.5.0 botnet.py

  • A layers.Layer class MHSAWithPositionEmbedding implemented based on keras.layers.MultiHeadAttention.
  • bot_block based on keras.applications.ResNet50.
  • BotNet50 / BotNet101 / BotNet152 based on keras.applications.ResNet50 / ResNet101 / ResNet152.

@Cyril9227
Copy link

Cyril9227 commented Feb 8, 2021

A Keras version on tensorflow 2.4.0 botnet.py

  • A layers.Layer class MHSAWithRelativePosition implemented based on keras.layers.MultiHeadAttention.
  • bot_block based on keras.applications.ResNet50.
  • BotNet50 / BotNet101 / BotNet152 based on keras.applications.ResNet50 / ResNet101 / ResNet152.
  • Only relative position.

Great work, thanks for sharing. Do you have any pretrained models available ? What training strategy do you recommend ? Should we use Swish or Relu activation for optimal accuracy in images classification ?

@leondgarse
Copy link

leondgarse commented Feb 8, 2021

@Cyril9227 Sorry, but I don't have pretrained weights either; everything I know comes from the article. I think the article shows that strides=1 with activation="swish" works better. Other strategies, such as using the SGD optimizer with weight decay, are also recommended in the article.

@axhiao
Copy link

axhiao commented Feb 19, 2021

A PyTorch implementation of botnet.py.

Thanks!

@BIGBALLON
Copy link

BIGBALLON commented Mar 16, 2021

A pytorch version: https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/models/botnet.py,

The results (the model trained by distribuuuu) :

model epoch total batch lr policy base lr Acc@1 Acc@5 model / config
resnet18 100 256 (32*8GPUs) cos 0.2 70.902 89.894 Google Drive / cfg
resnet18 100 1024 (128*8GPUs) cos 0.8 70.994 89.892
resnet18 100 8192 (128*64GPUs) cos 6.4 70.165 89.374
resnet18 100 16384 (256*64GPUs) cos 12.8 68.766 88.381
resnet50 100 256 (32*8GPUs) cos 0.2 77.252 93.430 Google Drive / cfg
botnet50 100 256 (32*8GPUs) cos 0.2 77.604 93.682 Google Drive / cfg

Training log: https://gist.github.com/BIGBALLON/3d53c81b2b11ea5dd66417c2a985cd89

@mickvdspoel
Copy link

mickvdspoel commented May 1, 2021

When will you release the pretrained model?

@leimao
Copy link

leimao commented May 15, 2021

Looks like the multi-head self-attention positional encoding implementation only supports inputs of static constant shapes. However, in the paper, the authors described that they used multi-scale image inputs for training. So I wonder whether this released code is the one that is used for the authors' experiments.

@mickvdspoel
Copy link

mickvdspoel commented Jun 11, 2021

Aravind Srinivas, congrats on your PhD. Do you think you have more time now to publish the pretrained BotNet model? If so, do you have an indication as from when?

@tenpha
Copy link

tenpha commented Jul 6, 2021

This is a very solid work.
But I am very confused about the way the relative position enc logits is used here, why should it be set like this?
Can you give a more detailed explanation?

+1

@lartpang
Copy link

lartpang commented Jul 9, 2021

This is a very solid work.
But I am very confused about the way the relative position enc logits is used here, why should it be set like this?
Can you give a more detailed explanation?

+1

Hi, I found a good explanation of relative position embedding:
https://theaisummer.com/positional-embeddings/

And here is a Chinese version of the explanation I wrote:
https://www.yuque.com/lart/ugkv9f/oazsec

@leondgarse
Copy link

leondgarse commented Jul 13, 2021

Understanding how rel_to_abs computes its result gave me an idea for simplifying it. I think the zero paddings here can be removed, though I may have missed some scenarios:

def rel_to_abs_2(rel_pos):
    """Padding-free conversion of relative-offset logits to absolute positions.

    Takes [bs, heads, height, width, 2 * width - 1] and returns
    [bs, heads, height, width, width]. Element [..., r, c] of the result is
    rel_pos[..., r, width - 1 + c - r], which is the same mapping rel_to_abs
    produces. The conversion flattens each row group, drops the first
    `width - 1` and the last element, and re-views the rest as rows of length
    2 * (width - 1). The first `width` columns of each row are the wanted
    logits.

    NOTE(review): when width == 1 the slice `[ww - 1:-1]` is empty, so this
    cannot produce the [..., 1, 1] result rel_to_abs gives; confirm callers
    never pass width 1.
    """
    _, heads, hh, ww, dim = rel_pos.shape # [bs, heads, height, width, 2 * width - 1]
    # [bs, heads, height, width * (2 * width - 1)] --> [bs, heads, height, width * (2 * width - 1) - width]
    flat_x = tf.reshape(rel_pos, [-1, heads, hh, ww * (ww * 2 - 1)])[:, :, :, ww - 1:-1]
    # [bs, heads, height, width, 2 * (width - 1)] --> [bs, heads, height, width, width]
    return tf.reshape(flat_x, [-1, heads, hh, ww, 2 * (ww - 1)])[:, :, :, :, :ww]

Test

rel_pos = tf.random.uniform([12, 6, 14, 16, 2 * 16 - 1])
orignal_rel_to_abs = tf.reshape(rel_to_abs(tf.reshape(rel_pos, [-1, 6 * 14, 16, 2 * 16 - 1])), [-1, 6, 14, 16, 16])
print(np.allclose(orignal_rel_to_abs, rel_to_abs_2(rel_pos)))
# True

rel_to_abs

  • Add here keras_cv_attention_models/botnet is my botnet with weights loaded from timm
  • I think this relative positional embedding still makes sense in some future works...

@sayakpaul
Copy link

sayakpaul commented Aug 31, 2021

@BIGBALLON the Drive link you provided for the .pth weights is not in the right format it seems:

image

Could you clarify a bit?

@bsun0802
Copy link

bsun0802 commented Dec 28, 2021

@leondgarse

I agree with you that the zeros paddings can be omitted, and your implementation seems more concise and easy-to-understand.

Would you consider pushing your version to PyTorch Image Models (also known as the timm package), to see whether the author agrees to replace the current version with yours (no padding)?

And also, could the Relative Positional Embedding in HaloNet also be replaced with no padding?

@leondgarse
Copy link

leondgarse commented Dec 28, 2021

@bsun0802 I have been using this implementation for a long time. My keras_cv_attention_models/botnet and keras_cv_attention_models/halonet both share this no-padding version. Their model weights were all ported from timm and produce closely matching outputs. I may discuss this with rwightman.

@bsun0802
Copy link

bsun0802 commented Dec 28, 2021

@leondgarse
Thanks for your reply. I just verified that your padding-free idea works for HaloNet as well, with a slight difference.

The code need to be changed to:

b = 6 # block size
h = 1 # halo size
w = b + 2 * h # window size

To visualize: indices 1 to 8 are the ones we want.

x = torch.tensor([[0] * (w-1-i) + list(range(1,1+w)) + [0] * i for i in range(b)])
assert x.shape == (b, 2*w-1)
x
tensor([[0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
        [0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0],
        [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0],
        [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
        [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]])

For halo attention, we need to append a single 0 to the end of the flattened tensor. (If I'm not wrong, this implementation should work for any block size b and window size w, though it may need adjusting if the halo size h != 1.)

x = F.pad(x.flatten(), [0, 1])[w-1:].reshape(b, -1)  # rel_to_abs
x
tensor([[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0]])

Then simply slice out the intended positions.

out = x[:, :w]
out
tensor([[1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8]])

@leondgarse
Copy link

leondgarse commented Dec 29, 2021

You may refer to my rel_to_abs implementation; that padding is not necessary either. I call the leftover offset full_rank_gap in this scenario, and you just need to clip it:

hh = 1
ww, dim = x.shape

pos_dim = (dim + 1) // 2
full_rank_gap = pos_dim - ww
print(f"{pos_dim = }, {full_rank_gap = }")
# pos_dim = 8, full_rank_gap = 2

flat_x = x.reshape([-1, hh, ww * dim])[:, :, ww - 1 : -1]
out = flat_x.reshape([-1, hh, ww, 2 * (pos_dim - 1)])
out
# tensor([[[[0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
#           [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
#           [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
#           [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
#           [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
#           [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]]]])

out[:, :, :, full_rank_gap : pos_dim + full_rank_gap]
# tensor([[1, 2, 3, 4, 5, 6, 7, 8],
#         [1, 2, 3, 4, 5, 6, 7, 8],
#         [1, 2, 3, 4, 5, 6, 7, 8],
#         [1, 2, 3, 4, 5, 6, 7, 8],
#         [1, 2, 3, 4, 5, 6, 7, 8],
#         [1, 2, 3, 4, 5, 6, 7, 8]])

@leondgarse
Copy link

leondgarse commented Dec 29, 2021

@bsun0802 I created a PR in timm (#1061), but then closed it... The problem isn't the logic: torch.fx currently does not support dynamic control flow that depends on inputs, and I have no idea how to fix that... We can discuss it there if you are still interested. Anyway, it's just a tiny change.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment