Skip to content

Instantly share code, notes, and snippets.

@bricksdont
Last active September 3, 2020 08:46
Show Gist options
  • Save bricksdont/e0304b9cc6ecbcc868368353d6e6262c to your computer and use it in GitHub Desktop.
#! /bin/env python3
import numpy as np
import mxnet as mx
import argparse
import logging
def parse_args(argv=None):
    """Parse and validate the command-line options for the sampling experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with the attributes ``context``, ``num_trials``,
        ``seed``, ``dist_size``, ``simulate_intermediate_neglogprobs``,
        ``malformed`` and ``verbose``.

    Raises:
        SystemExit: (via ``parser.error``) when a required option is missing
            or a value is invalid, including a ``--dist-size`` below 1 or a
            negative ``--num-trials``.
    """
    parser = argparse.ArgumentParser(
        description="Draw samples with mx.random.multinomial and count how "
                    "often an out-of-range category index is returned.")
    parser.add_argument("--context", type=str, required=True, choices=["cpu", "gpu"],
                        help="device to run the experiment on")
    parser.add_argument("--num-trials", type=int, required=True,
                        help="number of samples to draw")
    parser.add_argument("--seed", type=int, required=True,
                        help="seed passed to mx.random.seed")
    parser.add_argument("--dist-size", type=int, required=True,
                        help="number of categories in the distribution (>= 1)")
    # store_true flags already default to False and are never required.
    parser.add_argument("--simulate-intermediate-neglogprobs", action="store_true",
                        help="normalize via exp(-(-log_softmax(x))) instead of softmax")
    parser.add_argument("--malformed", action="store_true",
                        help="skip normalization and sample from the raw randn values")
    parser.add_argument("--verbose", action="store_true",
                        help="print the distribution shape/context and all samples")
    args = parser.parse_args(argv)

    # Reject nonsensical sizes up front with a usage message. Otherwise a
    # negative trial count silently runs zero trials and reports "0 out of -N",
    # and a non-positive size is handed straight to mx.nd.random.randn.
    if args.dist_size < 1:
        parser.error("--dist-size must be at least 1, got %d" % args.dist_size)
    if args.num_trials < 0:
        parser.error("--num-trials must be non-negative, got %d" % args.num_trials)
    return args
def main():
    """Run the out-of-range multinomial sampling experiment.

    Builds a random vector of ``--dist-size`` entries on the chosen device,
    optionally normalizes it into a probability distribution, draws
    ``--num-trials`` samples from it with ``mx.random.multinomial``, and
    prints how many of those samples equal ``--dist-size``. That index is
    one past the last valid category, so a correct sampler should never
    return it.
    """
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    context = mx.cpu(0) if args.context == "cpu" else mx.gpu(0)
    mx.random.seed(args.seed, ctx=context)

    dist = mx.nd.random.randn(args.dist_size, ctx=context)

    # Normalize into probabilities unless --malformed asks to sample from the
    # raw (unnormalized, possibly negative) randn values.
    if not args.malformed:
        if args.simulate_intermediate_neglogprobs:
            # Round-trip through negative log-probabilities. In floating
            # point, exp(-(-log_softmax(x))) may not sum to exactly 1.
            neglog_dist = - mx.nd.log_softmax(dist)
            dist = mx.nd.exp(- neglog_dist)
        else:
            dist = mx.nd.softmax(dist)

    # dist has dist_size entries, so valid category indices are
    # 0 .. dist_size - 1; dist_size itself should be impossible to draw.
    impossible_index = args.dist_size
    impossible_events = mx.nd.array([0.], ctx=context, dtype=np.int32)

    if args.verbose:
        print("Shape:")
        print(dist.shape)
        print("Context:")
        print(dist.context)

    samples = []
    for _ in range(args.num_trials):
        sample = mx.random.multinomial(dist, get_prob=False)
        samples.append(sample)
        # Keep the running count as an NDArray on `context`; it is only
        # copied to the host for the final report below.
        impossible_events = impossible_events + mx.nd.sum(sample == impossible_index)

    if args.verbose:
        print("Samples:")
        print(samples)

    # impossible_events has shape (1,). Extract its single value explicitly:
    # feeding the array itself to %d relies on NumPy's deprecated
    # conversion of size-1, ndim>0 arrays to scalars.
    impossible_count = int(impossible_events.asnumpy().item())
    print("Impossible events: %d out of %d" % (impossible_count, args.num_trials))
# Script entry point: run the experiment only when executed directly,
# never as a side effect of importing this module.
if __name__ == '__main__':
    main()
# Example invocation: on the CPU with seed 1, draw 10 samples from a
# softmax-normalized 20-entry distribution and report how many hit index 20.
python sample_mxnet.py --seed 1 --context cpu --num-trials 10 --dist-size 20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment