@sloonz, September 1, 2023
Mixture density network with TensorFlow
import math
import numpy as np
import tensorflow as tf
# 1000 sequences of length 500, values in [0, 1]
data = np.zeros((1000, 500))
def get_mdn_loss(n_kernels, n_dims):
    """Negative log-likelihood loss for a mixture of n_kernels isotropic Gaussians in n_dims dimensions."""
    def mdn_loss(y_true, y_pred):
        # y_pred packs, per sample: n_kernels mixture logits, n_kernels log-scales, n_kernels*n_dims means
        logit_alpha, log_scale, mu = tf.split(y_pred, axis=1, num_or_size_splits=[n_kernels, n_kernels, n_dims*n_kernels])
        alpha = tf.nn.softmax(logit_alpha)
        scale = tf.math.exp(log_scale)
        var = tf.square(scale)
        mu = tf.reshape(mu, (-1, n_kernels, n_dims))  # (batch_size, n_dims*n_kernels) -> (batch_size, n_kernels, n_dims)
        y_true = tf.reshape(y_true, (-1, 1, n_dims))  # (batch_size, n_dims) -> (batch_size, 1, n_dims), now broadcast is possible
        gaussian_normalization = tf.math.pow(2*math.pi*var, -0.5*n_dims)
        gaussian_unnormalized = tf.exp(-0.5 * tf.reduce_sum((y_true - mu)**2, axis=2) / var)
        likelihood = tf.reduce_sum(alpha * gaussian_normalization * gaussian_unnormalized, axis=1)
        return tf.reduce_mean(-tf.math.log(likelihood))
    return mdn_loss
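# The likelihood above is computed in linear space, which can underflow when
# n_dims is large or the kernels are narrow. A numerically safer variant (a
# sketch added for illustration, not part of the original gist) does the same
# computation in log space with tf.reduce_logsumexp:
def get_mdn_loss_logspace(n_kernels, n_dims):
    def mdn_loss(y_true, y_pred):
        logit_alpha, log_scale, mu = tf.split(y_pred, axis=1, num_or_size_splits=[n_kernels, n_kernels, n_dims*n_kernels])
        log_alpha = tf.nn.log_softmax(logit_alpha)
        mu = tf.reshape(mu, (-1, n_kernels, n_dims))
        y_true = tf.reshape(y_true, (-1, 1, n_dims))
        # log N(y | mu, scale^2 I) = -n_dims/2*log(2*pi) - n_dims*log(scale) - ||y - mu||^2 / (2*scale^2)
        log_prob = (-0.5 * n_dims * math.log(2*math.pi)
                    - n_dims * log_scale
                    - 0.5 * tf.reduce_sum((y_true - mu)**2, axis=2) / tf.exp(2*log_scale))
        return tf.reduce_mean(-tf.reduce_logsumexp(log_alpha + log_prob, axis=1))
    return mdn_loss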
def get_mixture(v):
    """Ground-truth transition: given the previous value v, return the parameters
    of the 2-component Gaussian mixture that generates the next value."""
    beta = 4*(v - 1/2)**2
    gamma = 1/(1 + np.exp(-10*(v - 1/2)))
    delta_v = 0.05 + 0.2*beta**2
    return (1 - gamma,     # weight of the first component
            v + delta_v,   # mean of the first component
            v - delta_v,   # mean of the second component
            delta_v / 10,  # scale of the first component
            delta_v / 10)  # scale of the second component
# Generate the synthetic dataset: each column is drawn from the mixture
# conditioned on the previous column.
data[:, 0] = np.random.uniform(size=data.shape[0])
for i in range(1, data.shape[1]):
    v_prev = data[:, i-1]
    (alpha, mu1, mu2, scale1, scale2) = get_mixture(v_prev)
    v_next_1 = np.random.normal(mu1, scale1)
    v_next_2 = np.random.normal(mu2, scale2)
    v_next = np.where(np.random.uniform(size=v_prev.shape[0]) < alpha, v_next_1, v_next_2)
    data[:, i] = np.clip(v_next, 0, 1)
# Training pairs: predict the next value from the current one
X = data[:, :-1].flatten().astype(np.float32)
Y = data[:, 1:].flatten().astype(np.float32)
n_dims = 1
n_kernels = 2
model = tf.keras.models.Sequential()
model.add(tf.keras.Input(shape=(1,)))
model.add(tf.keras.layers.Dense(32, activation="tanh"))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(32, activation="relu"))
model.add(tf.keras.layers.Dense(n_kernels*(2+n_dims)))  # logits + log-scales + means, matching the split in mdn_loss
model.compile(optimizer="Adam", loss=get_mdn_loss(n_kernels=n_kernels, n_dims=n_dims))
model.summary()
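# The gist stops at model.summary(); a minimal training-and-sampling sketch
# follows. Everything below is an added illustration: the epoch count, batch
# size and the sample_next helper are arbitrary choices, not part of the
# original.
model.fit(X.reshape(-1, 1), Y.reshape(-1, 1), epochs=10, batch_size=256)

def sample_next(model, x):
    # Hypothetical helper: draw one sample from the predicted mixture for each
    # input value in x, using the same parameter layout as mdn_loss.
    params = model.predict(np.asarray(x, np.float32).reshape(-1, 1), verbose=0)
    logit_alpha = params[:, :n_kernels].astype(np.float64)
    scale = np.exp(params[:, n_kernels:2*n_kernels])
    mu = params[:, 2*n_kernels:].reshape(-1, n_kernels, n_dims)
    alpha = np.exp(logit_alpha - logit_alpha.max(axis=1, keepdims=True))
    alpha /= alpha.sum(axis=1, keepdims=True)  # softmax over kernels
    ks = np.array([np.random.choice(n_kernels, p=a) for a in alpha])
    rows = np.arange(len(ks))
    return np.random.normal(mu[rows, ks, :], scale[rows, ks, None])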