Skip to content

Instantly share code, notes, and snippets.

@aswild
Last active September 26, 2019 15:01
Show Gist options
  • Save aswild/c006956299552298b70a7c964a8bfda1 to your computer and use it in GitHub Desktop.
Save aswild/c006956299552298b70a7c964a8bfda1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import inspect
from io import StringIO
from itertools import product
import random
import re
import sys
class DiceSet:
    """A set of dice plus a constant modifier, e.g. "2d6+1d4+1".

    Dice are stored as a dict mapping die size (number of faces) to the number
    of dice of that size. The constant is added to every roll total.
    """

    # Shared RNG backed by os.urandom. Unlike reading /dev/urandom directly, it
    # works on every platform, keeps no file open, and randint() is unbiased
    # (the old `i % n` approach favoured low faces for most n).
    _sysrand = random.SystemRandom()

    @staticmethod
    def parse_dice_str(dice_str):
        """Parse a dice expression such as "2d6+1d4-1" into (dice, constant).

        The expression is a '+'-separated list of terms. Each term is one of:
          * "NdM" or "dM": N (default 1) dice with M faces. M must be >= 1.
          * a bare integer, added to the constant.
        Any term may end with one or more "-N" constants, e.g. "d20-2" or
        "2d6-1-1".

        Returns a tuple (dice, constant), where dice maps die size -> count.
        Raises ValueError if the string is empty or any term is malformed.
        """
        if not dice_str:
            raise ValueError('dice is empty')
        dice = {}
        constant = 0
        for s in (x.strip() for x in dice_str.split('+')):
            # A plain (possibly negative) integer is just a constant.
            try:
                constant += int(s)
                continue
            except ValueError:
                pass
            # We only split on '+', so negative constants stay attached to the
            # end of a term ("d20-2"). Peel them off one at a time so chains
            # like "2d6-1-1" work too.
            stripped_any = False
            while True:
                m = re.match(r'^.*(?P<minus>-)\s*(?P<constant>\d+)$', s)
                if not m:
                    break
                constant -= int(m.group('constant'))
                s = s[:m.start('minus')].strip()
                stripped_any = True
            if not s and stripped_any:
                # The term consisted only of negative constants, e.g. "-1-2".
                continue
            m = re.match(r'^(?P<count>\d*)(?:d(?P<die>\d+))?$', s)
            if not m:
                raise ValueError('Invalid roll "%s"'%s)
            count = m.group('count')
            die = m.group('die')
            if not die:
                # No die mentioned: a bare number adds to the constant.
                if not count:
                    raise ValueError('Invalid roll "%s"'%s)
                constant += int(count)
                continue
            die = int(die)
            if die < 1:
                # A zero-sided die cannot be rolled and would make the
                # histogram empty.
                raise ValueError('Invalid roll "%s"'%s)
            count = int(count) if count else 1
            dice[die] = dice.get(die, 0) + count
        return dice, constant

    @staticmethod
    def get_roll_total(roll):
        """Return the sum of the result values of a roll list from roll()."""
        return sum(result for _, result in roll)

    def __init__(self, dice, constant=0):
        """Create a dice set.

        dice: either a dice expression string (see parse_dice_str) or a dict
            mapping die size -> count (keys and values are converted to int).
        constant: extra value added to every roll. It is combined with any
            constant found in a dice string.
        Raises ValueError for a malformed string, or TypeError for any other
        type of `dice`.
        """
        self._constant = constant
        if isinstance(dice, str):
            self._dice, parsed_constant = self.parse_dice_str(dice)
            self._constant += parsed_constant
        elif isinstance(dice, dict):
            self._dice = {int(k): int(v) for k, v in dice.items()}
        else:
            raise TypeError('dice must be a str or dict, not %s' % type(dice).__name__)

    def __str__(self):
        """Render as e.g. "2d6+1d4+1", with the largest dice first."""
        text = '+'.join('%dd%d'%(self._dice[die], die)
                        for die in sorted(self._dice, reverse=True))
        if self._constant:
            # Only sign the constant when it follows a dice term, so a
            # constant-only set reads "5" rather than "+5".
            text += ('%+d' if text else '%d') % self._constant
        elif not text:
            text = '0'
        return text

    def __repr__(self):
        return '<DiceSet: %s>'%self.__str__()

    @property
    def dice(self):
        """The die size -> count mapping (the internal dict, not a copy)."""
        return self._dice

    @property
    def constant(self):
        """The constant added to every roll total."""
        return self._constant

    def _roll_one(self, n):
        """Roll one n-sided die and return an int in [1, n].

        Uses the OS randomness source via random.SystemRandom. Falls back to
        the module-level Mersenne Twister if no OS source is available
        (SystemRandom raises NotImplementedError in that case).
        """
        try:
            return self._sysrand.randint(1, n)
        except NotImplementedError:
            return random.randint(1, n)

    def roll(self):
        """ Roll the dice and return a list of roll results, where each result
        is a 2-tuple of the form (die, result). If a constant value is added,
        it's returned as the 2-tuple (0, constant). For example, a 2d6+1d4+1
        roll may return [(6, 3), (6, 6), (4, 2), (0, 1)]. """
        result = []
        for die in sorted(self._dice, reverse=True):
            for _ in range(self._dice[die]):
                result.append((die, self._roll_one(die)))
        if self._constant:
            result.append((0, self._constant))
        return result

    def roll_total(self):
        """ Roll the dice and return the total value. Convenience wrapper for
            roll = dice_set.roll()
            get_roll_total(roll)
        """
        return self.get_roll_total(self.roll())

    def histogram(self):
        """ Return a histogram of how many ways each total can be rolled.

        Data is returned as a dict of value:count pairs. The value is the roll
        total (including the constant). The count is the number of rolls, out
        of all possible rolls, that produce that value. For example, the
        histogram of 2d6+1 returns
        {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 5, 10: 4, 11: 3, 12: 2, 13: 1}.
        To get the probability of a particular value, divide its count by the
        total number of rolls (i.e. sum(hist.values())).
        On Python 3.7+, where dict insertion order is guaranteed to be
        preserved, the keys of the histogram dict will be in ascending order.
        """
        # Add one die at a time (a discrete convolution) instead of
        # enumerating every combination with itertools.product. This costs
        # O(dice * totals * faces) rather than O(faces ** dice), so sets like
        # 10d6 or 20d20 are cheap.
        hist = {self._constant: 1}
        for die in sorted(self._dice):
            for _ in range(self._dice[die]):
                stepped = {}
                for total, ways in hist.items():
                    for face in range(1, die + 1):
                        stepped[total + face] = stepped.get(total + face, 0) + ways
                hist = stepped
        return {value: hist[value] for value in sorted(hist)}

    def _count_rolls(self, low, high, normalize):
        """Count the rolls whose total lies in [low, high], inclusive.

        None for low or high means unbounded on that side. Unlike roll_prob,
        0 is a real bound here, so totals of 0 and below can be queried. If
        normalize is True, return the fraction of all rolls instead.
        Raises ValueError if both bounds are given and low > high.
        """
        if low is not None and high is not None and low > high:
            raise ValueError('low cannot be greater than high')
        hist = self.histogram()
        total = sum(count for value, count in hist.items()
                    if (low is None or value >= low) and (high is None or value <= high))
        return (total / sum(hist.values())) if normalize else total

    def roll_prob(self, low, high, normalize=False):
        """ Return the number of rolls in this DiceSet's histogram with a value
        between [low, high], inclusive. If low is 0, then consider all rolls
        <=high. If high is 0, consider all rolls >=low. If normalize is True,
        return the floating point probability in range [0.0, 1.0] rather than
        the number of rolls which would produce that value.
        Raises ValueError if both are 0, or if low > high. """
        if not (low or high):
            raise ValueError('At least low or high must be specified')
        # Translate the documented 0 = "unbounded" convention to None.
        return self._count_rolls(low or None, high or None, normalize)

    def roll_prob_str(self, spec, normalize=False):
        """ Parse spec and return the matching roll count (or probability).

        Accepted forms (numbers may be negative):
          "a"   exactly a
          "a-b" a through b, inclusive
          "a+"  a or more
          "a-"  a or less
        Raises ValueError for any other form, or if a > b.
        Zero is treated as a real value here (e.g. "0" or "0+"), not as the
        "unbounded" marker that roll_prob uses.
        """
        text = spec.strip()
        try:
            # A single number is easy; short-circuit the regex handling.
            x = int(text)
        except ValueError:
            pass
        else:
            return self._count_rolls(x, x, normalize)
        m = re.fullmatch(r'(-?\d+)\s*([+-])\s*(-?\d+)?', text)
        if m is None:
            raise ValueError(f'Invalid spec: "{spec}"')
        a = int(m.group(1))
        plus = m.group(2) == '+'
        b = int(m.group(3)) if m.group(3) else None
        if plus:
            if b is not None:
                raise ValueError(f'Invalid spec: "{spec}"')
            return self._count_rolls(a, None, normalize)
        if b is None:
            return self._count_rolls(None, a, normalize)
        return self._count_rolls(a, b, normalize)

    @staticmethod
    def _ndigits(n):
        """ Helper function for print_histogram.
        Return the number of characters needed to display integer n, including
        the minus sign for negative values (0 needs one character). """
        assert isinstance(n, int)
        return len(str(n))

    def print_histogram(self, bar_width=60):
        """Print a table of roll total, count, probability and a '#' bar.

        Each bar has one '#' per count. If the largest count exceeds
        bar_width, all bars are scaled so that the longest is bar_width long.
        Every non-zero count still shows at least one '#'. Pass bar_width=0
        or None to disable scaling.
        """
        hist = self.histogram()
        # Check the width of the max and min roll values, in case of negatives.
        value_width = max(self._ndigits(max(hist)), self._ndigits(min(hist)), len('Roll'))
        max_count = max(hist.values())
        # Value counts are always positive.
        count_width = max(self._ndigits(max_count), len('Count'))
        scale = None
        if bar_width and max_count > bar_width:
            scale = bar_width / max_count
        print('%*s %*s Probability'%(value_width, 'Roll', count_width, 'Count'))
        total_count = sum(hist.values())
        for value in sorted(hist):
            count = hist[value]
            prob = (count / total_count) * 100.0
            bar_len = count if scale is None else max(1, round(count * scale))
            print('%*d %*d %4.1f%% %s'%(value_width, value, count_width, count, prob, '#'*bar_len))
class DiceRollCLI:
    """Command-line interface for DiceSet.

    Each subcommand is a nested CmdBase subclass with a non-empty NAME, which
    get_cmds() discovers automatically. If the first argument parses as a dice
    expression, the "roll" command runs implicitly.
    """

    class CmdBase:
        """Abstract subcommand: set NAME/HELP and implement both classmethods."""
        # subclasses should set this to the command name and help text description
        NAME = ''
        HELP = ''

        @classmethod
        def populate_parser(cls, parser):
            """ Take the given argparse.ArgumentParser object and add arguments to it """
            raise NotImplementedError('unimplemented abstract method in %s'%cls)

        @classmethod
        def run(cls, args):
            """ Run the command with the given args Namespace from argparse """
            raise NotImplementedError('unimplemented abstract method in %s'%cls)

    class CmdRollBase(CmdBase):
        """Shared code for "roll" and "qroll". Its empty NAME keeps it out of
        the command list."""
        NAME = ''
        HELP = ''
        # When True, -q/--quiet is not offered and only totals are printed.
        # Defined here so the base class itself is usable, not only subclasses.
        _always_quiet = False

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('-c', '--count', type=int, default=1,
                                help='Number of times to roll this dice set')
            if not cls._always_quiet:
                parser.add_argument('-q', '--quiet', action='store_true',
                                    help='Quiet mode, print only the total, not the result of each die.')
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')

        @classmethod
        def run(cls, args):
            """Roll args.dice args.count times, printing each total (and the
            individual results unless quiet). Exits if count < 1."""
            if args.count < 1:
                sys.exit('Error: roll count must be positive')
            dice = DiceSet(args.dice)
            # Check _always_quiet first: qroll's namespace has no 'quiet'.
            quiet = cls._always_quiet or args.quiet
            for _ in range(args.count):
                roll = dice.roll()
                total = dice.get_roll_total(roll)
                if quiet:
                    print(total)
                else:
                    print('%s: %s'%(total, roll))

    class CmdRoll(CmdRollBase):
        NAME = 'roll'
        HELP = 'Roll some dice (the default)'
        _always_quiet = False

    class CmdRollQuiet(CmdRollBase):
        NAME = 'qroll'
        HELP = 'Roll some dice, only print the result (shortcut for "roll -q")'
        _always_quiet = True

    class CmdHistogram(CmdBase):
        NAME = 'histogram'
        HELP = 'Display a roll probability histogram'

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')

        @classmethod
        def run(cls, args):
            dice = DiceSet(args.dice)
            dice.print_histogram()

    class CmdProb(CmdBase):
        NAME = 'prob'
        HELP = 'Get the probability of a given roll.'

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')
            parser.add_argument('value', help='Target roll value, can be of the form "2" (exactly 2), '+
                                '"6-" (6 or less), "7-9" (inclusive range), or "10+" (at least 10)')

        @classmethod
        def run(cls, args):
            dice = DiceSet(args.dice)
            prob = dice.roll_prob_str(args.value, normalize=True)
            print('%.1f%%'%(prob * 100.0))

    @classmethod
    def get_cmds(cls):
        """Return (and cache) a dict mapping command name -> command class."""
        try:
            return cls._cmds
        except AttributeError:
            pass
        cmds = {}
        for _, cmdclass in inspect.getmembers(cls, inspect.isclass):
            # Filter on subclasses of CmdBase with a non-empty name (which
            # excludes the abstract bases).
            if issubclass(cmdclass, cls.CmdBase) and cmdclass.NAME:
                cmds[cmdclass.NAME] = cmdclass
        cls._cmds = cmds
        return cls._cmds

    @staticmethod
    def cmd_search(cmd_names, cmd):
        """ Search for cmd in cmd_names, allowing for partial (prefix) matches.
        If an exact match is found, or only one partial match, return it.
        If multiple partial matches found, return a list of them (in
        cmd_names order). If no matches found, return None. """
        if cmd in cmd_names:
            return cmd
        partial_matches = [c for c in cmd_names if c.startswith(cmd)]
        if len(partial_matches) == 1:
            return partial_matches[0]
        if partial_matches:
            return partial_matches
        return None

    @classmethod
    def main(cls, argv=None):
        """Run the CLI on argv (default: sys.argv[1:]).

        Returns None on success or an error string, suitable for sys.exit().
        """
        if argv is None:
            argv = sys.argv[1:]
        # Hack around argparse so that roll can be the default command: if the
        # first argument parses as a DiceSet, treat the call as "roll". Only the
        # parse is guarded, so errors raised while rolling are not mistaken
        # for "not a dice expression".
        if argv:
            try:
                DiceSet.parse_dice_str(argv[0])
            except ValueError:
                pass
            else:
                return cls.main_roll(argv)
        cmds = cls.get_cmds()
        parser = argparse.ArgumentParser(description='Roll some dice!',
                                         epilog='Valid commands are: %s'%', '.join(cmds.keys()))
        parser.add_argument('command', help='What to do.')
        parser.add_argument('command_args', nargs=argparse.REMAINDER,
                            help='Arguments for the command. Use "%s COMMAND -h" for command help'%parser.prog)
        # Parse the argv we were given (not implicitly sys.argv).
        args = parser.parse_args(argv)
        cmd_name = cls.cmd_search(cmds.keys(), args.command)
        if cmd_name is None:
            return 'Error: unknown command: ' + args.command
        if isinstance(cmd_name, list):
            return 'Error: multiple possible commands, be more specific: ' + ', '.join(cmd_name)
        cmdclass = cmds[cmd_name]
        cmdparser = argparse.ArgumentParser(prog='%s %s'%(parser.prog, cmdclass.NAME), description=cmdclass.HELP)
        cmdclass.populate_parser(cmdparser)
        cmdargs = cmdparser.parse_args(args.command_args)
        try:
            cmdclass.run(cmdargs)
        except Exception as e:
            return f'Error: {e}'
        return None

    @classmethod
    def main_roll(cls, argv):
        """Run the implicit "roll" command on argv.

        Returns None on success or an error string, matching main().
        """
        cmd = cls.get_cmds()['roll']
        parser = argparse.ArgumentParser(description='Roll some dice!')
        # Reuse argparse's default prog (the script basename), as main() does,
        # instead of the full sys.argv[0] path.
        parser.prog = '%s %s'%(parser.prog, cmd.NAME)
        cmd.populate_parser(parser)
        args = parser.parse_args(argv)
        try:
            cmd.run(args)
        except Exception as e:
            return f'Error: {e}'
        return None
if __name__ == '__main__':
    # main() returns None on success or an error message; SystemExit turns
    # that into exit status 0, or prints the message and exits with status 1.
    raise SystemExit(DiceRollCLI.main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment