Skip to content

Instantly share code, notes, and snippets.

@roblem
Created May 2, 2020 09:56
Show Gist options
  • Save roblem/be18414429eff8f2d751b3add494be95 to your computer and use it in GitHub Desktop.
Save roblem/be18414429eff8f2d751b3add494be95 to your computer and use it in GitHub Desktop.
ROCm TensorFlow Probability benchmarks for NUTS
# Print the interpreter path before importing TensorFlow, so it is already
# on screen if one of the imports below fails.
import sys
print("Running in :", sys.executable)
import tensorflow as tf
# List the physical devices TensorFlow detects (used to see whether a GPU is visible).
print("TF devices: ", tf.config.list_physical_devices())
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
import pandas as pd
import time as time
# set tensorflow data type
# (passed as dtype= to every tf.constant created below)
dtype = tf.float32
##
## simple OLS Data Generation Process
##
# True beta: K standard-normal coefficients, intercept (b[0]) shifted by +3
N = 50000
K = 200
b = np.random.randn(K)
b[0] = b[0] + 3
# True error std deviation
sigma_e = 1
# Design matrix: a column of ones (intercept) plus K-1 standard-normal regressors
x = np.c_[np.ones(N), np.random.randn(N, K - 1)]
y = x.dot(b) + sigma_e * np.random.randn(N)
# Estimate parameter vector, errors, sd of errors, and se of parameters.
# (X'X)^-1 and the residual variance are each computed once and reused for
# both the coefficients and their standard errors: forming X'X is O(N*K^2),
# so recomputing it for the standard errors was wasted work.
xtx_inv = np.linalg.inv(x.T.dot(x))
bols = xtx_inv.dot(x.T.dot(y))
err = y - x.dot(bols)
# Residual variance with N - K degrees of freedom
sigma2_ols = err.dot(err) / (x.shape[0] - x.shape[1])
sigma_ols = np.sqrt(sigma2_ols)
se = np.sqrt(sigma2_ols * np.diagonal(xtx_inv))
# put results together for easy viewing
ols_parms = np.r_[bols, sigma_ols]
ols_se = np.r_[se, np.nan]  # no standard error is reported for sigma
print("\n")
indexn = ['b' + str(i) for i in range(K)]
indexn.extend(['sigma'])
print(pd.DataFrame(np.c_[ols_parms, ols_se], columns=['estimate', 'std err'],
                   index=indexn))
print("\n\n")
# Convert the simulated data and the model constants to TensorFlow tensors
# of the configured dtype.
X = tf.constant(x, dtype=dtype)
Y = tf.constant(y, dtype=dtype)
N_ = tf.constant(N, dtype=dtype)  # sample size as a float tensor
pi = tf.constant(np.pi, dtype=dtype)

# NUTS settings
init_step_size = 0.05
nuts_burnin = 100
nuts_samples = 1000

# Starting state [beta, sigma]: random beta of length K and sigma = 1
init_beta = tf.constant(np.random.randn(K), dtype=dtype)
init_sigma = tf.constant(1.0, dtype=dtype)
init = [init_beta, init_sigma]
##
## Model Log-Likelihood/Posterior
##
@tf.function(experimental_compile=True, experimental_relax_shapes=True)
def ols_loglike(beta, sigma):
    """Gaussian log-likelihood of the linear model ``Y = X beta + e``.

    Reads the module-level tensors ``X``, ``Y``, ``N_`` and ``pi``. No prior
    term is added here (the section header's "Posterior" presumably means a
    flat prior).

    Args:
        beta: coefficients; ``matvec(X, beta)`` gives each observation's mean.
        sigma: error standard deviation.

    Returns:
        The normal log-density of ``Y`` summed over all observations.
    """
    fitted = tf.linalg.matvec(X, beta)
    residuals = Y - fitted
    log_norm_const = -(N_ / 2.) * tf.math.log(2. * pi * sigma**2)
    scaled_sse = (1. / (2. * sigma**2.)) * tf.math.reduce_sum(residuals**2., axis=-1)
    return log_norm_const - scaled_sse
#
# Evaluate speed of function evals (no tfp required)
#
def _time_loglike_on(device):
    """Return the time in milliseconds of one ``ols_loglike`` call at ``init`` on ``device``.

    An untimed call with identical arguments runs first so that tracing and
    XLA compilation are excluded from the measurement. Each result is copied
    to the host with ``.numpy()``, which waits for the device to finish, so
    the timer does not stop while queued GPU work is still running.
    """
    with tf.device(device):
        ols_loglike(init[0], init[1]).numpy()
        start = time.perf_counter()
        ols_loglike(init[0], init[1]).numpy()
        return (time.perf_counter() - start) * 1000


print("\n\nLogL calculation in %2.2f MS on CPU" % _time_loglike_on('/CPU:0'))
print("\n\n")
# Check for a GPU explicitly: tf.device('/GPU:0') is not guaranteed to raise
# when none exists (with soft device placement TF may quietly run on CPU), and
# a bare `except:` would also report genuine GPU/XLA failures as "no GPU".
if tf.config.list_physical_devices('GPU'):
    try:
        gpu_ms = _time_loglike_on('/GPU:0')
    except (RuntimeError, tf.errors.OpError) as err:
        print("GPU LogL run failed:", err)
    else:
        print("\n\nLogL calculation in %2.2f MS on GPU" % gpu_ms)
        print("\n\n")
else:
    print("GPU not available in this python environment")
##
## NUTS (using inner step size averaging step)
##
@tf.function(experimental_compile=True, experimental_relax_shapes=True)
def nuts_sampler(init, nuts_samples=1, nuts_burnin=1):
    """Run a NUTS chain with dual-averaging step-size adaptation.

    Args:
        init: starting state ``[beta, sigma]`` for the chain.
        nuts_samples: number of draws returned after burn-in.
        nuts_burnin: number of burn-in steps; the step size is adapted over
            the same number of steps.

    Returns:
        ``(samples, kernel_results)`` as produced by ``tfp.mcmc.sample_chain``.

    ``nuts_samples`` and ``nuts_burnin`` are passed as Python ints, so every
    new value makes ``tf.function`` retrace (and XLA recompile).
    """
    # Same Gaussian log-likelihood as the module-level ``ols_loglike``; keep
    # the two definitions in sync.
    @tf.function
    def ols_loglike_(beta, sigma):
        # xb (mu_i for each observation)
        mu = tf.linalg.matvec(X, beta)
        # this is normal pdf logged and summed over all observations
        ll = - (N_/2.)*tf.math.log(2.*pi*sigma**2) - \
            (1./(2.*sigma**2.))*tf.math.reduce_sum((Y-mu)**2., axis=-1)
        return ll

    nuts_kernel = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=ols_loglike_,
        step_size=init_step_size)
    adapt_nuts_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=nuts_kernel,
        num_adaptation_steps=nuts_burnin,
        step_size_getter_fn=lambda pkr: pkr.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.log_accept_ratio,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(step_size=new_step_size))
    samples_nuts_, stats_nuts_ = tfp.mcmc.sample_chain(
        num_results=nuts_samples,
        current_state=init,
        kernel=adapt_nuts_kernel,
        # Burn-in follows the argument (previously a fixed 100), so it matches
        # num_adaptation_steps and a small warm-up call really is small.
        num_burnin_steps=nuts_burnin, parallel_iterations=10)
    return samples_nuts_, stats_nuts_
def _time_nuts_on(device):
    """Run ``nuts_sampler`` on ``device`` and time one sampling call.

    The untimed warm-up uses the same ``nuts_samples``/``nuts_burnin`` as the
    timed call. These are Python ints, and ``tf.function`` retraces (XLA
    recompiles) for every new Python scalar value, so a warm-up with other
    values would leave tracing/compilation inside the timed region. As a
    consequence the sampler runs twice per device.

    Returns:
        ``(samples, stats, seconds)`` from the timed call.
    """
    def run():
        samples, stats = nuts_sampler(init, nuts_samples, nuts_burnin)
        # Copying the draws to host waits for the device to finish.
        for draws in samples:
            draws.numpy()
        return samples, stats

    with tf.device(device):
        run()  # warm-up: trace + compile
        start = time.perf_counter()
        samples, stats = run()
        elapsed = time.perf_counter() - start
    return samples, stats, elapsed


samples_nuts, stats_nuts, nuts_secs = _time_nuts_on('/CPU:0')
print("\n\nNuts sampling completed in %2.2f seconds on CPU" % nuts_secs)
print("\n\n")
# Explicit GPU check: tf.device('/GPU:0') may silently fall back to CPU under
# soft device placement, and a bare `except:` hid real GPU/XLA errors.
if tf.config.list_physical_devices('GPU'):
    try:
        gpu_samples, gpu_stats, nuts_secs = _time_nuts_on('/GPU:0')
    except (RuntimeError, tf.errors.OpError) as err:
        print("GPU NUTS run failed:", err)
    else:
        samples_nuts, stats_nuts = gpu_samples, gpu_stats
        print("\n\nNuts sampling completed in %2.2f seconds on GPU" % nuts_secs)
        print("\n\n")
else:
    print("GPU not available in this python environment")
# Summarize the retained NUTS draws: samples_nuts is [beta_draws, sigma_draws].
# Each beta coefficient gets its column-wise mean/std over the draws, followed
# by sigma, in the same row order as indexn.
trace_betan = samples_nuts[0].numpy()
trace_sigman = samples_nuts[1].numpy()
est_nuts = np.append(trace_betan.mean(axis=0), trace_sigman.mean())
std_nuts = np.append(trace_betan.std(axis=0), trace_sigman.std())
# assemble and print
summary_table = np.column_stack([est_nuts, std_nuts])
print(pd.DataFrame(summary_table, columns=['estimate', 'std err'], index=indexn))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment