Skip to content

Instantly share code, notes, and snippets.

@Ryu1845
Created September 28, 2023 08:01
Show Gist options
  • Save Ryu1845/8da008ac7e0520aee1313df4fe906145 to your computer and use it in GitHub Desktop.
Save Ryu1845/8da008ac7e0520aee1313df4fe906145 to your computer and use it in GitHub Desktop.
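Finite Scalar Quantization (FSQ) implementation from the paper "Finite Scalar Quantization: VQ-VAE Made Simple" (Mentzer et al., 2023)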
def round_ste(z):
    """Round `z` to the nearest integer with a straight-through gradient.

    The forward value is ``z + (round(z) - z)``, i.e. the rounded value. The
    rounding residual is wrapped in ``stop_gradient``, so the derivative with
    respect to ``z`` is the identity.
    """
    residual = jax.lax.stop_gradient(jnp.round(z) - z)
    return z + residual
class FSQ:
    """Finite Scalar Quantization with an implicit, product-structured codebook.

    Channel ``i`` of the last axis is quantized to one of ``levels[i]``
    values. The codebook is every combination of per-channel values, so it
    has ``prod(levels)`` entries and is enumerated rather than learned.

    Args:
        levels: Number of quantization levels per channel. The last axis of
            the arrays passed to the methods is expected to have length
            ``len(levels)``.
    """

    def __init__(self, levels: list[int]):
        self._levels = levels
        self._levels_np = np.asarray(levels)
        # Mixed-radix place values [1, L0, L0*L1, ...]: an index is packed
        # as sum_i digit_i * basis_i with one digit per channel.
        self._basis = np.concatenate(
            ([1], np.cumprod(self._levels_np[:-1]))
        ).astype(np.uint32)

        codebook_size = np.prod(levels)
        # Every code, shape (codebook_size, len(levels)); row i decodes index i.
        self.implicit_codebook = self.indexes_to_codes(np.arange(codebook_size))

    def bound(self, z):
        """Bound `z`, an array of shape (..., d).

        Each channel is squashed with ``tanh`` so that rounding the result
        yields ``levels[i]`` integers. Even level counts are offset by half a
        unit, so their integer grid is asymmetric (4 levels -> {-2, -1, 0, 1}).
        """
        eps = 1e-3
        # Half of each channel's output range, shrunk slightly by eps.
        half_l = (self._levels_np - 1) * (1 - eps) / 2
        # Even level counts need a 0.5 offset to land on exactly L integers.
        offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)
        # Shift the tanh input so bound(0) is approximately 0. An exact cancel
        # would use arctanh, which is undefined for 2 levels (ratio > 1); tan
        # matches arctanh closely for small ratios.
        shift = jnp.tan(offset / half_l)
        return jnp.tanh(z + shift) * half_l - offset

    def quantize(self, z):
        """Quantizes `z`; returns quantized `zhat` with the same shape as `z`.

        Rounding uses straight-through gradients. The result is divided by
        ``levels // 2`` so that each channel lies in [-1, 1].
        """
        quantized = round_ste(self.bound(z))
        half_width = self._levels_np // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        # Map normalized codes to non-negative digits {0, ..., L_i - 1}.
        half_width = self._levels_np // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        # Map digits {0, ..., L_i - 1} back to normalized codes.
        half_width = self._levels_np // 2
        return (zhat - half_width) / half_width

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: Normalized codes, as returned by `quantize` or stored in
                `implicit_codebook`, whose last axis has length
                ``len(levels)``.

        Returns:
            ``uint32`` codebook indices with the last axis reduced away.
        """
        assert zhat.shape[-1] == len(self._levels)
        # Round each digit before casting. Float error in the rescaling can
        # produce e.g. 2.9999998, which a bare cast would truncate to 2.
        # Packing in integers also keeps the index exact beyond float32's
        # 2**24 integer range.
        digits = jnp.round(self._scale_and_shift(zhat)).astype(jnp.uint32)
        return (digits * self._basis).sum(axis=-1).astype(jnp.uint32)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`: map codebook indices to codes.

        Args:
            indices: Integer array of codebook indices, of any shape ``(...)``.

        Returns:
            Normalized codes of shape ``(..., len(levels))``.
        """
        indices = indices[..., jnp.newaxis]
        # Extract each channel's mixed-radix digit.
        codes_non_centered = np.mod(
            np.floor_divide(indices, self._basis), self._levels_np
        )
        return self._scale_and_shift_inverse(codes_non_centered)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment