Skip to content

Instantly share code, notes, and snippets.

@theideasmith
Last active August 4, 2017 15:55
Show Gist options
  • Save theideasmith/f107ec97ba450fe5d508a04ce83509d0 to your computer and use it in GitHub Desktop.
Save theideasmith/f107ec97ba450fe5d508a04ce83509d0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 137,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from glarmm import *\n",
"from scipy.io import loadmat\n",
"import h5py\n",
"from numpy import *"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\"\"\"\n",
"Features: 5\n",
" .features(i)\n",
" for each interval in nfeatures(i)\n",
" \n",
"\n",
"\"\"\"\n",
"f = loadmat('/Volumes/murthy/akiva/summer-2017/songanalysis/intervalFeatures.mat');"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# indexes\n",
"# songIntervalIdx = 0;\n",
"pulseTimIntervalIdx = 0;\n",
"tDisIntervalIdx = 1;\n",
"pHeightIdx = 2;\n",
"pDisLaggedIdx = 3;\n",
"pAmpLaggedIdx = 4;\n",
"pTimSinceIdx = 5;\n",
"\n",
"ints = ((np.array([100, 250, 500, 750, 1000]).astype('float32')/1000)*144).round()\n",
"\n",
"\n",
"features = {\n",
" str(ivlLen): []\n",
" for ivlLen in ints\n",
"}\n",
"\n",
"tDists = [];\n",
"pAmps_autoregress = [];\n",
"pTimSince = [];\n",
"pAmps_predict = [];\n",
"for i in range(ints.shape[0]):\n",
" # for i in number of features\n",
" # songIntervalIdx = 0;\n",
" # pulseTimIntervalIdx = 1;\n",
" # tDisIntervalIdx = 2;\n",
" # pHeightIntervalIdx = 3;\n",
" \n",
" features_i = f['features'][0][i][0][0];\n",
" Nintervals = features_i.shape[0];\n",
"# print Nintervals\n",
" for j in range(Nintervals):\n",
"# print f['features'][0][i][0][0][0][pHeightIdx].shape\n",
" if f['features'][0][i][0][0][j][0].shape[0] != 0:\n",
" tDists.append(features_i[j][pDisLaggedIdx]);\n",
" pTimSince.append(features_i[j][pTimSinceIdx]);\n",
" pAmps_autoregress.append(features_i[j][pAmpLaggedIdx]);\n",
" pAmps_predict.append(features_i[j][pHeightIdx]);\n"
]
},
{
"cell_type": "code",
"execution_count": 162,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# My implementation was getting NaNs\n",
"tDists = np.vstack(tDists);\n",
"pTimSince = np.vstack(pTimSince)*(1e-4);\n",
"pTimSince[np.isnan(pTimSince)]=0;\n",
"pAmps_predict = np.vstack(pAmps_predict);\n",
"pAmps_autoregress = np.vstack(pAmps_autoregress);\n",
"\n",
"import glarmm \n",
"reload(glarmm)\n",
"model = glarmm.GLARMM(tDists, pAmps_autoregress, pTimSince, pAmps_predict, mixtures=2);\n",
"for i in range(10):\n",
" model.update()\n",
" print model.sigma\n",
" print model.beta\n",
" print model.alpha\n",
" print model.phi"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n",
"[[ 0.87713826 1.34756267]]\n",
"[[ nan nan nan]\n",
" [ nan nan nan]]\n",
"[[ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]\n",
" [ nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan]]\n",
"[[ nan nan]]\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 352,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Y: [[ 153.]\n",
" [ 187.]\n",
" [ 148.]\n",
" ..., \n",
" [ -70.]\n",
" [ -64.]\n",
" [ -96.]]\n",
"Sigma: [[ 106.44272614 106.44272614]]\n",
"Mus: [[ 54.15526962 116.13597107]\n",
" [ 151.90632629 159.61132812]\n",
" [ 240.17140198 214.29992676]\n",
" ..., \n",
" [ 123.46837616 38.27497864]\n",
" [ -1.74603879 -35.9292717 ]\n",
" [-103.4500351 -83.69264221]]\n",
"Likelihood: [[ 0.00243523 0.00352979]\n",
" [ 0.00354969 0.00362591]\n",
" [ 0.00257615 0.00308708]\n",
" ..., \n",
" [ 0.0007185 0.00223412]\n",
" [ 0.00315876 0.00361986]\n",
" [ 0.00373878 0.00372298]]\n",
"Y: [[ 153.]\n",
" [ 187.]\n",
" [ 148.]\n",
" ..., \n",
" [ -70.]\n",
" [ -64.]\n",
" [ -96.]]\n",
"Sigma: [[ nan nan]]\n",
"Mus: [[ nan nan]\n",
" [ nan nan]\n",
" [ nan nan]\n",
" ..., \n",
" [ nan nan]\n",
" [ nan nan]\n",
" [ nan nan]]\n",
"Likelihood: [[ nan nan]\n",
" [ nan nan]\n",
" [ nan nan]\n",
" ..., \n",
" [ nan nan]\n",
" [ nan nan]\n",
" [ nan nan]]\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"mixtures = 2;\n",
"z1 = tf.constant(tDists, name=\"tDists\", dtype='float32')\n",
"z2 = tf.constant(pAmps_autoregress, name=\"pAmps_autoregress\", dtype='float32')\n",
"z3 = tf.constant(pTimSince, name=\"pTimSince\", dtype='float32')\n",
"y = tf.constant(pAmps_predict, name=\"pAmps_predict\", dtype='float32')\n",
"\n",
"alpha = tf.Variable(tf.random_normal([2, z1.get_shape()[1].value]), dtype='float32');\n",
"beta = tf.Variable(tf.random_normal([2, z2.get_shape()[1].value]), dtype='float32');\n",
"gamma = tf.Variable(tf.random_normal([2, z3.get_shape()[1].value]), dtype='float32');\n",
"phi = tf.Variable(tf.random_normal([1, mixtures]), dtype='float32');\n",
"sigma = np.std(pAmps_predict)*tf.Variable(tf.ones([1, mixtures]), dtype='float32');\n",
"\n",
"alphaDotdistance = tf.matmul(z1,tf.transpose(alpha)); \n",
"betaDotpAmp = tf.matmul(z2, tf.transpose(beta)); \n",
"gammaDottimeSince = tf.matmul(z3, tf.transpose(gamma)); \n",
"mus = alphaDotdistance + betaDotpAmp + gammaDottimeSince; \n",
"\n",
"def N(x, mu, sigma):\n",
" dist = tf.contrib.distributions.Normal(mu,sigma)\n",
" return dist.prob(x);\n",
"\n",
"log_likelihood = -1*tf.reduce_sum(tf.log(tf.reduce_sum(phi*N(y, mus, sigma), axis=1)), axis=0); \n",
"train_step = tf.train.AdadeltaOptimizer(1.0).minimize(log_likelihood)\n",
"sess = tf.InteractiveSession()\n",
"tf.global_variables_initializer().run()\n",
"\n",
"for _ in range(2):\n",
" print \"Y: \", sess.run(y);\n",
" print \"Sigma: \", sess.run(sigma)\n",
" print \"Mus: \", sess.run(mus);\n",
" print \"Likelihood: \", sess.run(N(y, mus, sigma));\n",
" sess.run(train_step);"
]
},
{
"cell_type": "code",
"execution_count": 268,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"nan"
]
},
"execution_count": 268,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment