Skip to content

Instantly share code, notes, and snippets.

@ypeleg
Created Apr 30, 2021
Embed
What would you like to do?
Minimal keras implementation: "Perceiver: General Perception with Iterative Attention. Jaegle et al"
# Cleaned and minimal perceiver transformer, originally from code https://github.com/Rishit-dagli/Perceiver
# Original paper: Perceiver: General Perception with Iterative Attention. Jaegle et al. https://arxiv.org/pdf/2103.03206.pdf.
import math
import tensorflow as tf
from typing import Callable
from einops import rearrange, repeat
def fourier_encode(x, max_freq, num_bands = 4, base = 2):
    """Fourier-feature encoding of position values.

    Every element of ``x`` is multiplied by ``num_bands`` frequencies that are
    log-spaced (in ``base``) from ``base ** 1`` up to ``max_freq / 2``, then by
    pi, and encoded as sine and cosine; the raw value is appended last.

    Args:
        x: Tensor of positions; it is cast to float32.
        max_freq: Twice the highest frequency (the top frequency is ``max_freq / 2``).
        num_bands: Number of frequencies per element.
        base: Logarithm base used to space the frequencies.

    Returns:
        float32 tensor of shape ``x.shape + (2 * num_bands + 1,)`` laid out as
        ``[sin(...), cos(...), x]`` along the last axis.
    """
    values = tf.cast(tf.expand_dims(x, -1), dtype = tf.float32)
    top_exponent = math.log(max_freq / 2) / math.log(base)
    freqs = tf.experimental.numpy.logspace(1.0, top_exponent, num = num_bands, base = base, dtype = tf.float32)
    # Left-pad the (num_bands,) frequency vector with singleton axes so it
    # broadcasts against every leading axis of `values`.
    broadcast_index = (None,) * (len(values.shape) - 1) + (Ellipsis,)
    angles = values * freqs[broadcast_index] * math.pi
    return tf.concat([tf.math.sin(angles), tf.math.cos(angles), values], axis = -1)
class PreNorm(tf.keras.layers.Layer):
    """Layer-normalize the input (and an optional ``context``) before calling ``fn``.

    Args:
        dim: Feature size of the input. Not used: ``LayerNormalization`` infers
            its size on the first call; kept for API compatibility.
        fn: Layer applied to the normalized input. Keyword arguments given to
            this layer (e.g. ``context``, ``mask``) are forwarded to it.
        context_dim: If not None, a second ``LayerNormalization`` is created and
            applied to the ``context`` keyword argument before forwarding.
    """

    def __init__(self, dim, fn, context_dim = None):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = tf.keras.layers.LayerNormalization(axis = -1)
        if context_dim is None:
            self.norm_context = None
        else:
            self.norm_context = tf.keras.layers.LayerNormalization(axis = -1)

    def call(self, x, **kwargs):
        x = self.norm(x)
        # Previously every keyword argument was dropped, so cross-attention never
        # received its ``context`` and ``norm_context`` was never applied.
        context = kwargs.get("context")
        if self.norm_context is not None and context is not None:
            kwargs["context"] = self.norm_context(context)
        return self.fn(x, **kwargs)


class Perceiver(tf.keras.Model):
    """Perceiver classifier (Jaegle et al., 2021).

    A learned latent array repeatedly cross-attends to the flattened input (with
    Fourier positional features appended) and then self-attends, each sub-block
    wrapped in a residual connection; the latents are mean-pooled and projected
    to class logits.

    Args:
        num_freq_bands: Fourier bands per spatial axis for the positional encoding.
        depth: Number of (cross-attn, cross-ff, latent-attn, latent-ff) groups.
        max_freq: Passed to ``fourier_encode``; the top frequency is ``max_freq / 2``.
        freq_base: Log base used to space the Fourier frequencies.
        input_channels: Channels in the last axis of the input.
        input_axis: Number of spatial axes between the batch and channel axes.
        num_latents, latent_dim: Shape of the learned latent array.
        cross_heads, cross_dim_head: Head configuration of the cross-attention.
        latent_heads, latent_dim_head: Head configuration of the latent self-attention.
        num_classes: Size of the output logits.
        attn_dropout, ff_dropout: Dropout rates in the attention / feed-forward blocks.
    """

    def __init__(self, num_freq_bands, depth, max_freq, freq_base = 2, input_channels = 3, input_axis = 2,
                 num_latents = 512, latent_dim = 512, cross_heads = 1, latent_heads = 8, cross_dim_head = 64,
                 latent_dim_head = 64, num_classes = 1000, attn_dropout = 0.0, ff_dropout = 0.0, ):
        super(Perceiver, self).__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.freq_base = freq_base
        # Features per input element after appending positional features: each
        # spatial axis contributes 2 * num_freq_bands (sin, cos) + 1 (raw position).
        input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
        self.latents = tf.Variable(tf.random.normal([num_latents, latent_dim]))
        get_cross_attn: Callable[[], PreNorm] = lambda: PreNorm(
            latent_dim,
            Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout),
            context_dim = input_dim,
        )
        get_cross_ff: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_attn: Callable[[], PreNorm] = lambda: PreNorm(
            latent_dim,
            Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout),
        )
        get_latent_ff: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        # Flat list (auto-tracked by Keras) of `depth` groups of four layers in the
        # order cross-attn, cross-ff, latent-attn, latent-ff. It used to be wrapped
        # in a `tf.keras.Sequential`, which chains single-input layers and so could
        # neither pass the input as cross-attention context nor add residuals.
        blocks = []
        for _ in range(depth):
            blocks.extend([get_cross_attn(), get_cross_ff(), get_latent_attn(), get_latent_ff()])
        self.existing_layers = blocks
        self.to_logits = tf.keras.Sequential([tf.keras.layers.LayerNormalization(axis = -1), tf.keras.layers.Dense(num_classes)])

    def call(self, data, mask = None):
        """Classify ``data``.

        Args:
            data: Tensor of shape ``(batch, *spatial, input_channels)`` with
                ``input_axis`` spatial axes of static size; must be float32 so it
                can be concatenated with the float32 positional features.
            mask: Optional mask over input positions, forwarded to the
                cross-attention (truthy entries are attended to).

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: if ``data`` does not have ``input_axis`` spatial axes.
        """
        # Dynamic batch size, so the model also works when the static batch
        # dimension is unknown; the old code unpacked it from `data.shape`.
        batch = tf.shape(data)[0]
        *axis, _ = data.shape[1:]
        if len(axis) != self.input_axis:
            raise ValueError(f"Expected input with {self.input_axis} spatial axes, got shape {data.shape}")
        axis_pos = [tf.linspace(-1.0, 1.0, num = size) for size in axis]
        pos = tf.stack(tf.meshgrid(*axis_pos, indexing = "ij"), axis = -1)
        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
        # (*spatial, input_axis, 2 * num_freq_bands + 1) -> (*spatial, features)
        enc_pos = tf.reshape(enc_pos, [*axis, -1])
        enc_pos = tf.broadcast_to(enc_pos[tf.newaxis], tf.concat([[batch], tf.shape(enc_pos)], axis = 0))
        data = tf.concat((data, enc_pos), axis = -1)
        # Flatten all spatial axes into one axis of input elements.
        data = tf.reshape(data, [batch, -1, data.shape[-1]])

        x = tf.tile(self.latents[tf.newaxis], [batch, 1, 1])
        attn_kwargs = {} if mask is None else {"mask": mask}
        # The old code ran the layers as a plain Sequential over the latents only:
        # `data` was never attended to (the logits ignored the input) and there
        # were no residual connections.
        for i in range(0, len(self.existing_layers), 4):
            cross_attn, cross_ff, latent_attn, latent_ff = self.existing_layers[i:i + 4]
            x = cross_attn(x, context = data, **attn_kwargs) + x
            x = cross_ff(x) + x
            x = latent_attn(x) + x
            x = latent_ff(x) + x
        x = tf.math.reduce_mean(x, axis = -2)
        return self.to_logits(x)


class GEGLU(tf.keras.layers.Layer):
    """Gated GELU: split the last axis in half into (x, gates) and return ``x * gelu(gates)``.

    The output's last axis is half the size of the input's.
    """

    def call(self, x):
        x, gates = tf.split(x, 2, axis = -1)
        return x * tf.nn.gelu(gates)


class Attention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention from ``x`` to ``context``.

    Args:
        query_dim: Feature size of ``x``; the output is projected back to it so
            callers can add the result residually to ``x``.
        context_dim: Feature size of ``context`` (defaults to ``query_dim``).
            Informational only: the Dense layers infer their input sizes.
        heads: Number of attention heads.
        dim_head: Feature size per head; ``heads * dim_head`` is the inner width.
        dropout: Dropout rate applied after the output projection.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.0):
        super(Attention, self).__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads
        # No `input_dim` arguments: on the first layer of a Sequential it pre-builds
        # the model for rank-2 input of that width. The old `to_out` declared
        # `query_dim` there although it receives `inner_dim` features.
        self.to_queries = tf.keras.layers.Dense(inner_dim, use_bias = False)
        self.to_keys_values = tf.keras.layers.Dense(inner_dim * 2, use_bias = False)
        # Project back to `query_dim` (the old code projected to `inner_dim`, which
        # breaks residual connections whenever heads * dim_head != query_dim).
        self.to_out = tf.keras.Sequential([tf.keras.layers.Dense(query_dim), tf.keras.layers.Dropout(dropout)])

    def call(self, x, context = None, mask = None):
        """Attend from ``x`` to ``context`` (self-attention when ``context`` is None).

        Args:
            x: Queries; the head split below requires rank 3, presumably
                ``(batch, n, query_dim)``.
            context: Keys/values source, rank 3 like ``x``; defaults to ``x``.
            mask: Optional mask over context positions. It is flattened over all
                non-batch axes, so its per-example element count must equal the
                number of context positions; truthy entries are attended to.

        Returns:
            Tensor of shape ``(batch, n, query_dim)``.
        """
        h = self.heads
        if context is None:
            context = x
        queries = self.to_queries(x)
        keys, values = tf.split(self.to_keys_values(context), num_or_size_splits = 2, axis = -1)
        queries, keys, values = (rearrange(t, "b n (h d) -> (b h) n d", h = h) for t in (queries, keys, values))
        sim = tf.einsum("b i d, b j d -> b i j", queries, keys) * self.scale
        if mask is not None:
            mask = tf.cast(rearrange(mask, "b ... -> b (...)"), tf.bool)
            mask = repeat(mask, "b j -> (b h) () j", h = h)
            # `tf.bitwise.invert` (used before) only accepts integer dtypes, so any
            # mask crashed here; `tf.logical_not` is the boolean negation. Masked
            # keys get the lowest finite value so softmax gives them ~0 weight.
            sim = tf.where(tf.logical_not(mask), tf.constant(sim.dtype.min, dtype = sim.dtype), sim)
        attn = tf.nn.softmax(sim, axis = -1)
        out = tf.einsum("b i j, b j d -> b i d", attn, values)
        out = rearrange(out, "(b h) n d -> b n (h d)", h = h)
        return self.to_out(out)
class FeedForward(tf.keras.layers.Layer):
    """Position-wise GEGLU feed-forward block.

    ``Dense(dim * mult * 2) -> GEGLU -> Dropout -> Dense(dim)``. GEGLU halves the
    hidden width to ``dim * mult``, and the output has ``dim`` features.

    Args:
        dim: Output feature size (and the expected input feature size).
        mult: Hidden-width multiplier.
        dropout: Dropout rate applied after the GEGLU activation.
    """

    def __init__(self, dim, mult = 4, dropout = 0.0):
        super(FeedForward, self).__init__()
        # No `input_dim` on the first Dense: in a Sequential it pre-builds the model
        # for rank-2 input of that width, while in this file the block is applied
        # to rank-3 (batch, n, dim) latents. Keras infers the width on first call.
        self.net = tf.keras.Sequential([
            tf.keras.layers.Dense(dim * mult * 2),
            GEGLU(),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(dim),
        ])

    def call(self, inputs):
        return self.net(inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment