# Source: GitHub gist mtsokol/cc10c0d57ac0050d4cf8ea6774acbde5
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging

import pandas as pd
import torch

import pyro
from pyro.distributions import Binomial, HalfCauchy, Pareto, Uniform
from pyro.distributions.util import scalar_like
from pyro.infer import MCMC, NUTS
from pyro.infer.mcmc import StreamingMCMC
from pyro.infer.mcmc.util import initialize_model
# Log bare messages (no level/timestamp prefix) at INFO and above.
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Tab-separated batting table that ``main`` downloads with pandas.
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt"


# ===================================
#               MODELS
# ===================================
def fully_pooled(at_bats, hits):
    r"""
    Number of hits in $K$ at bats for each player has a Binomial
    distribution with a common probability of success, $\phi$.

    A single ``phi`` site is sampled from a Uniform prior on (0, 1) and is
    shared by every player; the observed hits are conditioned on it inside a
    plate whose size is the first dimension of ``at_bats``.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    # NOTE(review): scalar_like presumably builds the bounds with the same
    # dtype/device as ``at_bats`` - confirm against pyro.distributions.util.
    phi_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1))
    phi = pyro.sample("phi", phi_prior)
    num_players = at_bats.shape[0]
    with pyro.plate("num_players", num_players):
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def train_test_split(pd_dataframe):
    """
    Split the baseball table into training and validation tensors.

    Training data - 45 initial at-bats and hits for each player.
    Validation data - Full season at-bats and hits for each player.

    :param (pandas.DataFrame) pd_dataframe: Table providing the columns
        ``At-Bats``, ``Hits``, ``SeasonAt-Bats``, ``SeasonHits``,
        ``FirstName`` and ``LastName``.
    :return: Tuple ``(train_data, test_data, player_names)``. Both tensors
        are ``torch.float`` with one row per table row and two columns
        (at-bats, hits); ``player_names`` is a list of ``"First Last"``
        strings in the same row order.
    """
    # An empty tensor reports the current default device, so the data lands
    # wherever the configured default tensor type lives.
    device = torch.Tensor().device
    train_data = torch.tensor(
        pd_dataframe[["At-Bats", "Hits"]].values,
        dtype=torch.float,
        device=device,
    )
    test_data = torch.tensor(
        pd_dataframe[["SeasonAt-Bats", "SeasonHits"]].values,
        dtype=torch.float,
        device=device,
    )
    first_name = pd_dataframe["FirstName"].values
    last_name = pd_dataframe["LastName"].values
    player_names = [
        " ".join([first, last]) for first, last in zip(first_name, last_name)
    ]
    return train_data, test_data, player_names
def main(args):
    """
    Fit the fully pooled model with NUTS and print a summary of ``phi``.

    Downloads the dataset from ``DATA_URL``, logs it, runs either ``MCMC`` or
    ``StreamingMCMC`` (chosen by ``args.stream``) on the training at-bats and
    hits, and prints the resulting statistics.

    :param args: Parsed command-line namespace. Reads ``num_chains``,
        ``jit``, ``stream``, ``num_samples`` and ``warmup_steps``.
    """
    # Make torch.multiprocessing report a single CPU.
    # NOTE(review): presumably this stops Pyro from running the chains in
    # parallel worker processes - confirm against pyro.infer.mcmc.
    setattr(torch.multiprocessing, 'cpu_count', lambda: 1)

    # ``sep`` must be passed by keyword: pandas 2.0 made every read_csv
    # argument after the path keyword-only, so a positional "\t" raises.
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    # Only the training split is used here; the validation split and the
    # player names are discarded.
    train, _, _ = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled,
        model_args=(at_bats, hits),
        num_chains=args.num_chains,
        jit_compile=args.jit,
        skip_jit_warnings=True,
    )
    nuts_kernel = NUTS(potential_fn=potential_fn)

    mcmc_cls = StreamingMCMC if args.stream else MCMC
    mcmc = mcmc_cls(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run(at_bats, hits)

    if args.stream:
        # The streaming variant is queried for statistics rather than samples.
        stats = mcmc.get_statistics(group_by_chain=False)
        print(stats)
    else:
        samples = mcmc.get_samples(group_by_chain=False)
        mean = samples['phi'].mean()
        var = samples['phi'].var()
        print(f'mean: {mean}, var: {var}')
if __name__ == "__main__":
    # The example is pinned to the Pyro 1.6.0 release line.
    assert pyro.__version__.startswith('1.6.0')

    parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
    parser.add_argument("-n", "--num-samples", nargs="?", default=20000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=4, type=int)
    parser.add_argument("--warmup-steps", nargs='?', default=100, type=int)
    parser.add_argument("--rng_seed", nargs='?', default=0, type=int)
    parser.add_argument("--jit", action="store_true", default=False,
                        help="use PyTorch jit")
    parser.add_argument("--cuda", action="store_true", default=False,
                        help="run this example in GPU")
    # Selects StreamingMCMC instead of MCMC in main().
    parser.add_argument("--stream", action="store_true", default=False)
    args = parser.parse_args()

    # work around the error "CUDA error: initialization error"
    # see https://github.com/pytorch/pytorch/issues/2517
    torch.multiprocessing.set_start_method("spawn")

    pyro.set_rng_seed(args.rng_seed)

    # work around with the error "RuntimeError: received 0 items of ancdata"
    # see https://discuss.pytorch.org/t/received-0-items-of-ancdata-pytorch-0-4-0/19823
    torch.multiprocessing.set_sharing_strategy("file_system")

    if args.cuda:
        # Tensors created without an explicit device now default to CUDA floats.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    main(args)
# Run with:
#   (command lost when the page was extracted - presumably this script with ``--stream``)
# and:
#   (command lost when the page was extracted - presumably this script without ``--stream``)