import types
from typing import Any, Optional

from .moving_average import ExponentialMovingAverage

from flax import linen as nn
import jax
import jax.numpy as jnp


# Inspired by Haiku's corresponding implementation, ported to Flax.
class VectorQuantizerEMA(nn.Module):
  """EMA-updated vector quantizer (VQ-VAE) as a Flax linen module."""
  embedding_dim: int
  num_embeddings: int
  commitment_cost: float
  decay: float
  epsilon: float = 1e-5
  dtype: Any = jnp.float32
  cross_replica_axis: Optional[str] = None
  initialized: bool = False

  def setup(self):
    embedding_shape = [self.embedding_dim, self.num_embeddings]
    # The codebook lives in the "stats" collection because it is updated by
    # EMA rather than by gradients. Drawing the initial values consumes a
    # "params" RNG at init time.
    initializer = lambda: nn.initializers.lecun_uniform()(
        self.make_rng("params"), embedding_shape, self.dtype)
    self.embeddings = self.variable("stats", "embeddings", initializer)
    self.ema_cluster_size = ExponentialMovingAverage(
        [self.num_embeddings], self.dtype, decay=self.decay)
    self.ema_dw = ExponentialMovingAverage(
        embedding_shape, self.dtype, decay=self.decay)

  def __call__(self, inputs, is_training, initialized=True):
    # Flatten everything but the feature dimension: [N, embedding_dim].
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    embeddings = self.embeddings.value

    # Squared Euclidean distance between each input vector and each code,
    # expanded as ||x||^2 - 2 x.e + ||e||^2.
    distances = (
        jnp.sum(flat_inputs**2, 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings) +
        jnp.sum(embeddings**2, 0, keepdims=True))

    # Nearest codebook entry for each input vector.
    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)
    e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

    if is_training:
      # EMA update of per-code cluster sizes.
      cluster_size = jnp.sum(encodings, axis=0)
      if self.cross_replica_axis:
        cluster_size = jax.lax.psum(
            cluster_size, axis_name=self.cross_replica_axis)
      updated_ema_cluster_size = self.ema_cluster_size(
          cluster_size, update_stats=initialized)

      # EMA update of the summed inputs assigned to each code.
      dw = jnp.matmul(flat_inputs.T, encodings)
      if self.cross_replica_axis:
        dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
      updated_ema_dw = self.ema_dw(dw, update_stats=initialized)

      # Laplace smoothing of the cluster sizes to avoid division by zero.
      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))
      if initialized:
        self.embeddings.value = normalised_updated_ema_w
      loss = self.commitment_cost * e_latent_loss
    else:
      loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)

    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

  def quantize(self, encoding_indices):
    """Returns embedding tensor for a batch of indices."""
    w = self.embeddings.value.swapaxes(1, 0)
    w = jax.device_put(w)  # Required when embeddings is a NumPy array.
    return w[(encoding_indices,)]
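
A minimal usage sketch (not part of the original gist), assuming the companion moving_average.py defines ExponentialMovingAverage with the constructor and __call__ signatures used above and stores its state in Flax variable collections; the shapes, hyperparameters, and the __main__ guard below are illustrative only.

if __name__ == "__main__":
  rng = jax.random.PRNGKey(0)
  vq = VectorQuantizerEMA(
      embedding_dim=64,
      num_embeddings=512,
      commitment_cost=0.25,
      decay=0.99)

  # Dummy encoder output: [batch, height, width, embedding_dim].
  x = jax.random.normal(rng, (8, 4, 4, 64))

  # Init with is_training=True so the EMA branch is traced and its state is
  # created; the "params" RNG is consumed by the codebook initializer.
  variables = vq.init({"params": rng}, x, is_training=True)

  # One training step: mark every collection mutable so the EMA statistics
  # and the codebook can be updated; the new state is returned separately.
  outputs, new_state = vq.apply(
      variables, x, is_training=True, mutable=list(variables.keys()))
  print(outputs["loss"], outputs["perplexity"])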