Mike Hughes (michaelchughes)

@michaelchughes
michaelchughes / HowToOptimizeBPR.ipynb
Created March 15, 2023 17:28
Reflection on how to pick tracts to optimize BPR
michaelchughes / jax_demo_stacked_list.py
Created February 24, 2023 19:05
Demo of JAX stacking a list of arrays to avoid item assignment
import numpy as np
import jax
import jax.numpy as jnp
import jax.nn

def calc_trans_mat(x_TD, r_KD, p_KK):
    ''' Compute transition matrix for each timestep

    Args
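The gist preview above is truncated, so its actual transition-matrix formula is not shown. The sketch below illustrates the technique named in the title under a hypothetical model of our own choosing: row j of the transition matrix at time t is assumed to be softmax(log p_KK[j] + r_KD @ x_TD[t]). The function name `calc_trans_mat_sketch` and that formula are illustrative assumptions, not the gist's code. Because JAX arrays are immutable, `trans_TKK[t] = ...` is not allowed; the sketch collects per-timestep matrices in a Python list and calls `jnp.stack` once.

```python
import jax
import jax.numpy as jnp
import jax.nn


def calc_trans_mat_sketch(x_TD, r_KD, p_KK):
    """Hypothetical sketch: one K x K transition matrix per timestep.

    Assumes (not from the gist) that row j of the matrix at time t is
    softmax(log p_KK[j] + r_KD @ x_TD[t]), i.e. the features x_t shift
    the log-probability of moving into each destination state.
    """
    T = x_TD.shape[0]
    mats = []  # plain Python list; we never assign into a JAX array
    for t in range(T):
        bias_K = jnp.dot(r_KD, x_TD[t])                     # shape (K,)
        logits_KK = jnp.log(p_KK) + bias_K[jnp.newaxis, :]  # broadcast over rows
        mats.append(jax.nn.softmax(logits_KK, axis=1))      # each row sums to 1
    # Stack once at the end instead of writing trans_TKK[t] = ...,
    # which JAX rejects because its arrays are immutable.
    return jnp.stack(mats, axis=0)                          # shape (T, K, K)


if __name__ == "__main__":
    x_TD = jnp.ones((4, 3))
    r_KD = jnp.zeros((2, 3))
    p_KK = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    print(calc_trans_mat_sketch(x_TD, r_KD, p_KK).shape)
```

If a preallocated array is preferred, JAX's functional update `trans_TKK = trans_TKK.at[t].set(mat_KK)` is the documented alternative to in-place assignment; the list-then-stack pattern avoids creating a new array on every step.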
michaelchughes / test_gpu_jax.py
Created November 29, 2022 00:03
Simple script that verifies JAX has access to a GPU and can do basic ops (matrix multiply)
import jax
import jax.numpy as jnp

if __name__ == '__main__':
    print("jax.devices()")
    print(jax.devices())
    a = jnp.asarray([[1.0, 2.0, 3.0], [4., 5., 6.]])
    b = jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5., 6.]])

Install notes from the BDL2022f env install on 2022-11-28

  • For CUDA Toolkit 11.3, which can be used on older devices but may not be optimal:
  1. Set up a basic conda env without any torch or jax packages, via
       conda env create -f bdl_2022f.yml
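After installing, a check in the spirit of test_gpu_jax.py can be written as below. This is a sketch, not the gist's script: the helper name `check_jax_matmul` is hypothetical, and it reuses the same 2x3 and 3x2 matrices shown above. `jax.default_backend()` reports the highest-priority platform JAX found, so `'cpu'` here suggests the GPU was not picked up.

```python
import jax
import jax.numpy as jnp


def check_jax_matmul():
    """Hypothetical helper: return JAX's default backend and a small matmul result."""
    backend = jax.default_backend()  # e.g. 'cpu', 'gpu', or 'tpu'
    a = jnp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])    # shape (2, 3)
    b = jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (3, 2)
    # JAX dispatches work asynchronously; blocking makes device errors surface here.
    c = jnp.matmul(a, b).block_until_ready()               # shape (2, 2)
    return backend, c


if __name__ == "__main__":
    backend, c = check_jax_matmul()
    print("default backend:", backend)
    print(c)  # [[22, 28], [49, 64]] by hand: e.g. 1*1 + 2*3 + 3*5 = 22
```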
michaelchughes / estimating_expectations_under_correlated_distributions.ipynb
Created August 31, 2022 17:37
Formal analysis and implementation sanity check of noisy minibatch estimates for expectations under correlated distributions
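The notebook itself does not render here, so its exact setup is unknown. As an illustrative sanity check of the same idea, the sketch below assumes "correlated distribution" means a standard bivariate Normal with correlation `rho`, and checks that minibatch averages of f(x) = x1 * x2 are unbiased for the full-data mean, with spread matching the textbook without-replacement formula. All function names are hypothetical.

```python
import numpy as np


def sample_correlated_pairs(n, rho, rng):
    """Draw n pairs from a standard bivariate Normal with correlation rho (assumed setup)."""
    cov_22 = np.array([[1.0, rho], [rho, 1.0]])
    return rng.multivariate_normal(np.zeros(2), cov_22, size=n)


def minibatch_mean_estimates(f_N, batch_size, n_trials, rng):
    """Repeatedly estimate mean(f_N) from a random minibatch drawn without replacement."""
    N = f_N.shape[0]
    ests = np.empty(n_trials)
    for trial in range(n_trials):
        idx_B = rng.choice(N, size=batch_size, replace=False)
        ests[trial] = f_N[idx_B].mean()
    return ests


def predicted_minibatch_std(f_N, batch_size):
    """Std of a without-replacement minibatch mean, with finite-population correction."""
    N = f_N.size
    return f_N.std() / np.sqrt(batch_size) * np.sqrt((N - batch_size) / (N - 1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x_N2 = sample_correlated_pairs(10000, rho=0.6, rng=rng)
    f_N = x_N2[:, 0] * x_N2[:, 1]  # E[x1 * x2] = rho for this distribution
    ests = minibatch_mean_estimates(f_N, batch_size=100, n_trials=2000, rng=rng)
    print("full-data mean:", f_N.mean(), "avg minibatch estimate:", ests.mean())
    print("empirical std:", ests.std(), "predicted std:", predicted_minibatch_std(f_N, 100))
```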
michaelchughes / transform_images_to_match_source.py
Created August 5, 2022 02:49
Create a function that will monotonically transform the intensity values of images from a "target" distribution to match a desired "source" distribution
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from statsmodels.distributions.empirical_distribution import ECDF

def create_transform_func_to_match_source(target_x_ND, src_x_MD, n_quantiles=1000):
    '''
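The gist's body is truncated after its signature. A common way to build such a monotone transform is quantile matching: map each target intensity to its estimated CDF value under the target data, then to the source quantile at that CDF value. The sketch below swaps the gist's `ECDF` import for piecewise-linear `np.interp` over `n_quantiles` evenly spaced quantiles; the name `make_quantile_matching_transform` is hypothetical, and it assumes roughly continuous intensities (heavily tied integer values would make the quantile grid non-strictly increasing).

```python
import numpy as np


def make_quantile_matching_transform(target_x_ND, src_x_MD, n_quantiles=1000):
    """Hypothetical sketch: monotone map from target intensities to the source distribution.

    Pools all pixels of each image set, estimates n_quantiles quantiles of each,
    and returns a function that sends x -> F_src^{-1}(F_target(x)).
    """
    probs_Q = np.linspace(0.0, 1.0, n_quantiles)
    target_q_Q = np.quantile(np.ravel(target_x_ND), probs_Q)  # non-decreasing
    src_q_Q = np.quantile(np.ravel(src_x_MD), probs_Q)

    def transform(x):
        # Piecewise-linear estimate of the target CDF at x (clamped to [0, 1]).
        u = np.interp(x, target_q_Q, probs_Q)
        # Piecewise-linear inverse CDF of the source; monotone because src_q_Q is sorted.
        return np.interp(u, probs_Q, src_q_Q)

    return transform


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    target_x_ND = rng.normal(0.0, 1.0, size=(20, 500))
    src_x_MD = rng.normal(5.0, 2.0, size=(30, 500))
    f = make_quantile_matching_transform(target_x_ND, src_x_MD)
    out = f(target_x_ND)
    print(out.mean(), out.std())  # should be close to the source's mean and std
```

Composing two non-decreasing piecewise-linear maps keeps the overall transform monotone, which is the property the gist description asks for; values outside the target's observed range are clamped to the source's extreme quantiles.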
michaelchughes / PowerLawFitToViewClassifierPerformance.ipynb
Created February 4, 2022 16:34
Power Law Fit to the trend in Error vs Dev Set size in View Classifier Performance
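The notebook does not render, so its fitting procedure is unknown. One standard way to fit error = a * n^(-b) to error-vs-dev-set-size points is ordinary least squares on log(error) vs. log(n), since the power law becomes the line log(error) = log(a) - b log(n). The sketch below uses that swapped-in log-log method; `fit_power_law` is a hypothetical name, and it ignores any irreducible-error offset a fuller model might include.

```python
import numpy as np


def fit_power_law(n_N, err_N):
    """Hypothetical sketch: fit err = a * n**(-b) by least squares in log-log space."""
    log_n = np.log(np.asarray(n_N, dtype=float))
    log_err = np.log(np.asarray(err_N, dtype=float))
    slope, intercept = np.polyfit(log_n, log_err, deg=1)  # highest degree first
    a, b = np.exp(intercept), -slope
    return a, b


def predict_power_law(n, a, b):
    """Evaluate the fitted curve at (possibly unseen) dataset sizes n."""
    return a * np.asarray(n, dtype=float) ** (-b)


if __name__ == "__main__":
    n_N = np.array([100.0, 200.0, 400.0, 800.0, 1600.0])
    err_N = 2.0 * n_N ** -0.5  # synthetic points lying exactly on a power law
    a, b = fit_power_law(n_N, err_N)
    print(a, b)
```

Least squares in log space weights relative errors equally across sizes; a nonlinear fit in the original scale would instead weight the small-n points (with the largest absolute errors) most heavily.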
michaelchughes / vi_for_poisson_normal
Created November 1, 2021 04:04
Demonstration of ELBO computation using Monte Carlo estimation
''' VI for Poisson Normal
Model
-----
Latent variable z is drawn from a Normal prior: z ~ Normal(40, 10)
Data y is drawn iid from a Poisson likelihood: y_n ~ Poisson(z)
Approx Posterior
----------------
Posterior on z is assumed to be Normal with unknown mean and stddev
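The preview stops before the ELBO code, so the following is a hedged sketch rather than the gist's implementation. It draws samples z_s from q(z) = Normal(q_mean, q_stddev) and averages log p(z_s) + sum_n log Poisson(y_n | z_s) - log q(z_s). It assumes the 10 in Normal(40, 10) is a standard deviation (the docstring does not say), and treats any non-positive draw as impossible because a Poisson rate must be positive. Function names are hypothetical.

```python
import math
import numpy as np

# Assumption: the second argument of Normal(40, 10) is a standard deviation.
PRIOR_MEAN, PRIOR_STDDEV = 40.0, 10.0


def log_normal_pdf(z, mean, stddev):
    """Log density of Normal(mean, stddev**2), elementwise."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(stddev) - 0.5 * ((z - mean) / stddev) ** 2


def log_poisson_pmf(y_N, z_S):
    """Log Poisson pmf for every (sample s, datum n) pair; requires z_S > 0."""
    log_y_factorial_N = np.array([math.lgamma(y + 1.0) for y in y_N])
    return (y_N[np.newaxis, :] * np.log(z_S[:, np.newaxis])
            - z_S[:, np.newaxis] - log_y_factorial_N[np.newaxis, :])


def estimate_elbo(y_N, q_mean, q_stddev, n_samples=1000, seed=0):
    """Monte Carlo estimate of E_q[log p(z) + sum_n log p(y_n | z) - log q(z)]."""
    y_N = np.asarray(y_N, dtype=np.float64)
    rng = np.random.default_rng(seed)
    z_S = q_mean + q_stddev * rng.standard_normal(n_samples)  # draws from q
    if np.any(z_S <= 0.0):
        return -np.inf  # non-positive Poisson rates are treated as impossible
    log_joint_S = (log_normal_pdf(z_S, PRIOR_MEAN, PRIOR_STDDEV)
                   + log_poisson_pmf(y_N, z_S).sum(axis=1))
    log_q_S = log_normal_pdf(z_S, q_mean, q_stddev)
    return float(np.mean(log_joint_S - log_q_S))


if __name__ == "__main__":
    y_N = np.array([38, 42, 45, 40, 41])  # toy counts for illustration only
    print("q near posterior:", estimate_elbo(y_N, 41.0, 2.8, n_samples=5000))
    print("q far away:      ", estimate_elbo(y_N, 20.0, 1.0, n_samples=5000))
```

A q close to the true posterior should give a higher ELBO than a poorly placed one, which is a quick sanity check on the estimator.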
michaelchughes / accuracy_by_window_size.csv
Last active August 12, 2021 02:55
Demo of half-violin plots for showing distributions
window_size,sample_id,accuracy
5.0,0,0.6236094882645041
5.0,1,0.593111865845944
5.0,2,0.6060493252962028
5.0,3,0.6342719738873018
5.0,4,0.6259239448289695
5.0,5,0.5623114809268821
5.0,6,0.6054087015122116
5.0,7,0.5807796285836049
5.0,8,0.5818560349582886
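The gist's plotting code is not shown. One way to draw a half-violin is to estimate a 1-D Gaussian kernel density of the samples and fill only the left side of the centerline, leaving room on the right for the raw points. The sketch below does that with a hand-rolled KDE (Scott's rule-of-thumb bandwidth) and Matplotlib's `Axes.fill_betweenx`; the function names are hypothetical and this is not necessarily how the gist builds its plots.

```python
import numpy as np


def half_violin_outline(vals_N, position, max_width=0.4, n_grid=200):
    """Hypothetical sketch: left half of a violin as (y_G, x_left_G) for fill_betweenx."""
    vals_N = np.asarray(vals_N, dtype=float)
    # Scott's rule-of-thumb bandwidth for a 1-D Gaussian KDE.
    h = vals_N.std(ddof=1) * vals_N.size ** (-1.0 / 5.0)
    y_G = np.linspace(vals_N.min() - 3.0 * h, vals_N.max() + 3.0 * h, n_grid)
    z_GN = (y_G[:, np.newaxis] - vals_N[np.newaxis, :]) / h
    dens_G = np.exp(-0.5 * z_GN ** 2).mean(axis=1) / (h * np.sqrt(2.0 * np.pi))
    # Scale so the widest point extends max_width to the left of the centerline.
    x_left_G = position - max_width * dens_G / dens_G.max()
    return y_G, x_left_G


def plot_half_violins(ax, vals_by_key, max_width=0.4):
    """Draw one left-half violin per key on a Matplotlib Axes, raw samples to its right."""
    keys = sorted(vals_by_key)
    for pos, key in enumerate(keys):
        vals_N = np.asarray(vals_by_key[key], dtype=float)
        y_G, x_left_G = half_violin_outline(vals_N, pos, max_width)
        ax.fill_betweenx(y_G, x_left_G, pos, alpha=0.5)           # density shape
        ax.scatter(np.full(vals_N.size, pos + 0.05), vals_N, s=8)  # individual samples
    ax.set_xticks(range(len(keys)))
    ax.set_xticklabels([str(k) for k in keys])
```

For example, with `fig, ax = matplotlib.pyplot.subplots()`, calling `plot_half_violins(ax, {5.0: accuracies})` on the nine window_size 5.0 accuracies listed above draws a single half-violin; more window sizes simply become more keys in the dict.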
import numpy as np

vals_float32 = np.logspace(0, 5, dtype=np.float32)
vals_float64 = np.logspace(0, 5, dtype=np.float64)

# Pretty-print an array so that each float takes the same number of characters
def pprint_arr(arr, n_per_line=6):
    for s in range(0, arr.size, n_per_line):
        chunk = arr[s:s+n_per_line]
        print(" ".join(["%10s" % np.format_float_scientific(x, precision=2, unique=False, exp_digits=3) for x in chunk]))
    print()
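To quantify what the float32 vs. float64 comparison above should show, the sketch below measures the largest relative gap between the two `np.logspace` arrays and compares it with float32 machine epsilon from `np.finfo`. The helper name `max_relative_gap` is hypothetical. Since these float32 values are float64 results rounded once to float32, the gap is expected to stay within half an epsilon.

```python
import numpy as np


def max_relative_gap(arr_lo, arr_hi):
    """Hypothetical helper: largest |lo - hi| / |hi|, computed in float64."""
    lo = np.asarray(arr_lo).astype(np.float64)
    hi = np.asarray(arr_hi, dtype=np.float64)
    return float(np.max(np.abs(lo - hi) / np.abs(hi)))


vals_float32 = np.logspace(0, 5, dtype=np.float32)  # 50 values from 1e0 to 1e5
vals_float64 = np.logspace(0, 5, dtype=np.float64)
eps32 = float(np.finfo(np.float32).eps)             # spacing of float32 just above 1.0

if __name__ == "__main__":
    gap = max_relative_gap(vals_float32, vals_float64)
    print("max relative gap:", gap, "float32 eps:", eps32)
```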