@AranKomat
Created April 2, 2021 13:53
import types
from typing import Any, Optional

import jax
import jax.numpy as jnp
from flax import linen as nn

from .moving_average import ExponentialMovingAverage

# Inspired by Haiku's corresponding implementation (VQ-VAE vector quantizer with EMA), ported to Flax.
class VectorQuantizerEMA(nn.Module):
    embedding_dim: int
    num_embeddings: int
    commitment_cost: float
    decay: float
    epsilon: float = 1e-5
    dtype: Any = jnp.float32
    cross_replica_axis: Optional[str] = None
    initialized: bool = False
    def setup(self):
        embedding_shape = [self.embedding_dim, self.num_embeddings]
        # lecun_uniform() returns an init fn (key, shape, dtype); draw the key
        # from the "params" RNG stream when the variable is first created.
        initializer = lambda: nn.initializers.lecun_uniform()(
            self.make_rng("params"), embedding_shape, self.dtype)
        # The codebook is non-trainable state ("stats"), updated by EMA below.
        self.embeddings = self.variable("stats", "embeddings", initializer)
        self.ema_cluster_size = ExponentialMovingAverage(
            [self.num_embeddings], self.dtype, decay=self.decay)
        self.ema_dw = ExponentialMovingAverage(
            embedding_shape, self.dtype, decay=self.decay)
    def __call__(self, inputs, is_training, initialized=True):
        flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
        embeddings = self.embeddings.value
        # Squared Euclidean distance from each input vector to each code.
        distances = (
            jnp.sum(flat_inputs**2, 1, keepdims=True) -
            2 * jnp.matmul(flat_inputs, embeddings) +
            jnp.sum(embeddings**2, 0, keepdims=True))
        encoding_indices = jnp.argmax(-distances, 1)
        encodings = jax.nn.one_hot(encoding_indices,
                                   self.num_embeddings,
                                   dtype=distances.dtype)
        encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
        quantized = self.quantize(encoding_indices)
        e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)
        if is_training:
            cluster_size = jnp.sum(encodings, axis=0)
            if self.cross_replica_axis:
                cluster_size = jax.lax.psum(
                    cluster_size, axis_name=self.cross_replica_axis)
            updated_ema_cluster_size = self.ema_cluster_size(
                cluster_size, update_stats=initialized)
            dw = jnp.matmul(flat_inputs.T, encodings)
            if self.cross_replica_axis:
                dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
            updated_ema_dw = self.ema_dw(dw, update_stats=initialized)
            n = jnp.sum(updated_ema_cluster_size)
            # Laplace smoothing of the cluster sizes to avoid division by zero.
            updated_ema_cluster_size = (
                (updated_ema_cluster_size + self.epsilon) /
                (n + self.num_embeddings * self.epsilon) * n)
            normalised_updated_ema_w = (
                updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))
            if initialized:
                self.embeddings.value = normalised_updated_ema_w
            loss = self.commitment_cost * e_latent_loss
        else:
            loss = self.commitment_cost * e_latent_loss
        # Straight Through Estimator
        quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
        avg_probs = jnp.mean(encodings, 0)
        if self.cross_replica_axis:
            avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
        perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))
        return {
            "quantize": quantized,
            "loss": loss,
            "perplexity": perplexity,
            "encodings": encodings,
            "encoding_indices": encoding_indices,
            "distances": distances,
        }
    def quantize(self, encoding_indices):
        """Returns embedding tensor for a batch of indices."""
        w = self.embeddings.value.swapaxes(1, 0)
        w = jax.device_put(w)  # Required when embeddings is a NumPy array.
        return w[(encoding_indices,)]
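
The sibling moving_average module imported above is not part of this gist. Below is a minimal, hypothetical stand-in sketch, assuming it mirrors Haiku's hk.ExponentialMovingAverage (a zero-debiased EMA whose buffers live in the "stats" collection) and matches the constructor and the update_stats keyword used above, followed by a made-up usage example. All names, shapes, and hyperparameters in this addendum are illustrative; for a self-contained test, the stand-in would replace the relative import.

# Hypothetical stand-in for the unshown moving_average.ExponentialMovingAverage,
# assumed to behave like Haiku's zero-debiased EMA.
class ExponentialMovingAverage(nn.Module):
    shape: Any
    dtype: Any = jnp.float32
    decay: float = 0.99

    def setup(self):
        self.hidden = self.variable(
            "stats", "hidden", lambda: jnp.zeros(self.shape, self.dtype))
        self.average = self.variable(
            "stats", "average", lambda: jnp.zeros(self.shape, self.dtype))
        self.counter = self.variable(
            "stats", "counter", lambda: jnp.zeros((), jnp.int32))

    def __call__(self, value, update_stats=True):
        counter = self.counter.value + 1
        hidden = self.hidden.value * self.decay + value * (1.0 - self.decay)
        # Zero-debiased average, as in Adam-style bias correction.
        average = hidden / (1.0 - jnp.power(self.decay, counter))
        if update_stats:
            self.counter.value = counter
            self.hidden.value = hidden
            self.average.value = average
        return average

# Hypothetical usage sketch (sizes and hyperparameters are made up):
vq = VectorQuantizerEMA(embedding_dim=64, num_embeddings=512,
                        commitment_cost=0.25, decay=0.99)
x = jnp.zeros((8, 16, 16, 64))  # last dim must equal embedding_dim
# is_training=True so the EMA buffers are created; initialized=False so
# nothing is updated during initialisation.
variables = vq.init(jax.random.PRNGKey(0), x,
                    is_training=True, initialized=False)
# Training step: "stats" must be mutable so the codebook and EMA state update.
out, new_state = vq.apply(variables, x, is_training=True,
                          initialized=True, mutable=["stats"])
print(out["loss"], out["perplexity"])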