Skip to content

Instantly share code, notes, and snippets.

@noahtren
Created June 22, 2021 17:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noahtren/fc4147105eaba38f19fb2ed4c2a0555b to your computer and use it in GitHub Desktop.
Save noahtren/fc4147105eaba38f19fb2ed4c2a0555b to your computer and use it in GitHub Desktop.
differentiable waveguide forward pass (flax)
@functools.partial(nn.scan,
                   variable_broadcast='params',
                   split_rngs={'params': False})
@nn.remat
@nn.compact
def __call__(self, carry, t: int):
    """Advance the waveguide by one time step; `nn.scan` loops this over `t`.

    Reads a window of `receptive_field` past samples from `memory`, mixes them
    through latent-conditioned delay lines, adds a noise excitation, squashes
    the sum with tanh, and writes the result back into `memory` at index
    `t + receptive_field` (the slot right after the window that was read).

    The decorators share one set of parameters across all scan steps
    (`variable_broadcast='params'`) and rematerialize activations on the
    backward pass (`nn.remat`).

    Args:
      carry: dict with keys 'noise', 'memory' and 'latent_code'. `memory` is
        rank 3 and `noise` is rank 2 (fixed by the dynamic_slice calls below).
        Presumably memory is (batch, time, num_modules) and noise is
        (time, num_modules), shared across the batch. TODO confirm with caller.
      t: current step index. Under `nn.scan` this presumably arrives as a
        traced scalar sliced from the scanned input, not a Python int.

    Returns:
      A `(new_carry, intermediates)` pair:
      - `new_carry` has the same keys as `carry`; only 'memory' changes.
      - `intermediates` holds this step's 'exciter_coeffs', 'delay_logits'
        and 'delay_coeffs', which the scan stacks across steps.
    """
    noise = carry['noise']
    memory = carry['memory']
    latent_code = carry['latent_code']

    n_mod = self.num_modules
    field = self.receptive_field
    batch = memory.shape[0]
    intermediates = {}

    # Read window memory[:, t:t + field, :] and the noise row for step t.
    # NOTE(review): lax.dynamic_slice clamps out-of-range starts rather than
    # failing, so memory presumably needs >= T + receptive_field time steps.
    window = jax.lax.dynamic_slice(memory, [0, t, 0], [batch, field, n_mod])
    noise_t = jax.lax.dynamic_slice(noise, [t, 0], [1, n_mod])[0]

    # exciter: per-module gain on the shared noise sample
    exciter = nn.Dense(features=n_mod, name='exciter_coeffs')(latent_code)
    intermediates['exciter_coeffs'] = exciter
    signal = exciter * noise_t[np.newaxis]

    # delay: one logit per (tap position, source module, target module)
    logits = np.reshape(
        nn.Dense(features=field * n_mod * n_mod,
                 name='delay_logits')(latent_code),
        [batch, field, n_mod, n_mod])
    # Subtract a linear ramp in [0, 1) over tap positions. Taps nearer the
    # write position (shorter delays) are penalized more ("reduce power").
    ramp = np.arange(field) / field
    logits = logits - ramp[np.newaxis, :, np.newaxis, np.newaxis]
    intermediates['delay_logits'] = logits

    # delay coeffs: sigmoid gain in (0, 1) per (source, target) module pair
    gains = nn.sigmoid(np.reshape(
        nn.Dense(features=n_mod * n_mod, name='delay_coeffs')(latent_code),
        [batch, n_mod, n_mod]))
    intermediates['delay_coeffs'] = gains

    # Softmax over tap positions gives a soft choice of delay. Summing over
    # taps yields one delayed value per (source, target) pair. Averaging over
    # sources gives the feedback input to each target module.
    weights = nn.softmax(logits, axis=1)
    taps = np.sum(window[..., np.newaxis] * weights, axis=1) * gains
    signal = signal + np.mean(taps, axis=1)

    # nonlinearity: learned pre-activation gain, then tanh
    pre_act = nn.Dense(features=n_mod,
                       name='pre_activation_coeffs')(latent_code)
    signal = np.tanh(signal * pre_act)

    # Write this step's output right after the window that was read.
    memory = jax.lax.dynamic_update_index_in_dim(memory, signal,
                                                 t + field, 1)
    new_carry = {
        'noise': noise,
        'memory': memory,
        'latent_code': latent_code
    }
    return new_carry, intermediates
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment