Skip to content

Instantly share code, notes, and snippets.

@mtsokol
Created June 13, 2021 10:31
Show Gist options
  • Save mtsokol/cc10c0d57ac0050d4cf8ea6774acbde5 to your computer and use it in GitHub Desktop.
Save mtsokol/cc10c0d57ac0050d4cf8ea6774acbde5 to your computer and use it in GitHub Desktop.
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging
import pandas as pd
import torch
import pyro
from pyro.distributions import Binomial, HalfCauchy, Pareto, Uniform
from pyro.distributions.util import scalar_like
from pyro.infer import MCMC, NUTS
from pyro.infer.mcmc import StreamingMCMC
from pyro.infer.mcmc.util import initialize_model
logging.basicConfig(format='%(message)s', level=logging.INFO)
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt"
# ===================================
# MODELS
# ===================================
def fully_pooled(at_bats, hits):
    r"""
    Fully pooled model: every player shares one success probability $\phi$,
    and each player's hit count is Binomial in their number of at bats.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    # Uniform(0, 1) prior on phi; the bounds are built with scalar_like
    # from `at_bats`.
    phi = pyro.sample(
        "phi",
        Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)),
    )
    # One plate slot per entry along the first dimension of `at_bats`.
    with pyro.plate("num_players", at_bats.shape[0]):
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def train_test_split(pd_dataframe):
    """
    Split the baseball table into training and validation tensors.

    Training data - 45 initial at-bats and hits for each player
    (columns "At-Bats" and "Hits").
    Validation data - Full season at-bats and hits for each player
    (columns "SeasonAt-Bats" and "SeasonHits").

    :param pd_dataframe: DataFrame holding the four columns above plus
        "FirstName" and "LastName".
    :return: Tuple ``(train_data, test_data, player_names)``; the first two
        are float tensors with one row per player, the last is a list of
        "First Last" strings.
    """
    # An empty tensor reports the device of the current default tensor type.
    device = torch.Tensor().device

    def _as_float_tensor(columns):
        # Copy the selected columns into a float tensor on `device`.
        return torch.tensor(pd_dataframe[columns].values, dtype=torch.float, device=device)

    train_data = _as_float_tensor(["At-Bats", "Hits"])
    test_data = _as_float_tensor(["SeasonAt-Bats", "SeasonHits"])

    name_pairs = zip(pd_dataframe["FirstName"].values, pd_dataframe["LastName"].values)
    player_names = [" ".join(pair) for pair in name_pairs]
    return train_data, test_data, player_names
def main(args):
    """
    Fit the fully pooled model with NUTS and print a summary of ``phi``.

    Downloads the dataset from ``DATA_URL``, runs either ``MCMC`` or
    ``StreamingMCMC`` (selected by ``args.stream``) and prints the
    resulting statistics to stdout.

    :param args: Parsed command-line namespace. This function reads
        ``num_chains``, ``jit``, ``stream``, ``num_samples`` and
        ``warmup_steps``.
    """
    # Make torch.multiprocessing report a single CPU.
    # NOTE(review): presumably this makes Pyro draw the chains sequentially
    # rather than in worker processes -- confirm against Pyro's MCMC code.
    torch.multiprocessing.cpu_count = lambda: 1

    # Pass the separator by keyword: newer pandas releases accept only the
    # path positionally, so `read_csv(DATA_URL, "\t")` raises TypeError there.
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled,
        model_args=(at_bats, hits),
        num_chains=args.num_chains,
        jit_compile=args.jit,
        skip_jit_warnings=True,
    )
    nuts_kernel = NUTS(potential_fn=potential_fn)

    # Both classes are constructed and run with the same arguments; only the
    # way results are read back differs below.
    mcmc_cls = StreamingMCMC if args.stream else MCMC
    mcmc = mcmc_cls(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run(at_bats, hits)

    if args.stream:
        stats = mcmc.get_statistics(group_by_chain=False)
        print(stats)
    else:
        samples = mcmc.get_samples(group_by_chain=False)
        mean = samples['phi'].mean()
        var = samples['phi'].var()
        print(f'mean: {mean}, var: {var}')
if __name__ == "__main__":
    # The example targets one specific Pyro release.
    assert pyro.__version__.startswith('1.6.0')

    parser = argparse.ArgumentParser(description="Baseball batting average using HMC")

    # Integer-valued options, registered in the order they appear in --help.
    for option_names, default_value in (
        (("-n", "--num-samples"), 20000),
        (("--num-chains",), 4),
        (("--warmup-steps",), 100),
        (("--rng_seed",), 0),
    ):
        parser.add_argument(*option_names, nargs="?", default=default_value, type=int)

    # Boolean switches; a help text of None is argparse's default (no help).
    for switch, help_text in (
        ("--jit", "use PyTorch jit"),
        ("--cuda", "run this example in GPU"),
        ("--stream", None),
    ):
        parser.add_argument(switch, action="store_true", default=False, help=help_text)

    args = parser.parse_args()

    # Use the "spawn" start method to work around
    # "CUDA error: initialization error";
    # see https://github.com/pytorch/pytorch/issues/2517
    torch.multiprocessing.set_start_method("spawn")
    pyro.set_rng_seed(args.rng_seed)

    # Share tensors through the file system to work around
    # "RuntimeError: received 0 items of ancdata";
    # see https://discuss.pytorch.org/t/received-0-items-of-ancdata-pytorch-0-4-0/19823
    torch.multiprocessing.set_sharing_strategy("file_system")

    if args.cuda:
        # Make CUDA float tensors the default tensor type.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    main(args)
@mtsokol
Copy link
Author

mtsokol commented Jun 13, 2021

Profile the memory usage of the regular `MCMC` run with:

$ mprof run baseball_pool.py

and of the `StreamingMCMC` run with:

$ mprof run baseball_pool.py --stream

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment