Skip to content

Instantly share code, notes, and snippets.

@reynoldscem
Created December 5, 2018 19:10
Show Gist options
  • Save reynoldscem/f03af8fe44bbcea5030ecb98bc0d9b0f to your computer and use it in GitHub Desktop.
Save reynoldscem/f03af8fe44bbcea5030ecb98bc0d9b0f to your computer and use it in GitHub Desktop.
Specify and check for mutually exclusive arguments in argparse
from argparse import ArgumentParser, Action
from enum import Enum, auto
class ArgumentsEnum(Enum):
    """Tags identifying the options that take part in conflict checks.

    Members are recorded into a parser's ``args_user_set`` and grouped with
    ``set_conflict``; ``str()`` renders a member as its quoted name so that
    error messages read naturally.
    """

    opt1 = auto()
    opt2 = auto()

    def __str__(self):
        # Quoted member name, e.g. "'opt1'".
        quoted = "'{}'"
        return quoted.format(self.name)
def RecordUserSetAction_(arg_enum):
    """Return an ``argparse.Action`` subclass that tags ``arg_enum`` when invoked.

    The generated action first adds ``arg_enum`` to ``parser.args_user_set``
    (so the parser must expose a set under that name). It then stores the
    parsed values on the namespace, as the default store action would.

    NOTE(review): argparse may also invoke the actions of positionals with
    ``nargs='?'`` or ``'*'`` when they are absent. This factory therefore
    presumably distinguishes user input reliably only for optional
    (``--flag``) arguments. Confirm before using it with positionals.
    """
    tag = arg_enum

    class RecordUserSetAction(Action):
        def __call__(self, parser, namespace, values, option_string=None):
            user_set = parser.args_user_set
            user_set.add(tag)
            setattr(namespace, self.dest, values)

    return RecordUserSetAction
class ArgumentParserWithConflicts(ArgumentParser):
    """``ArgumentParser`` that rejects user-supplied combinations of conflicting options.

    Actions built with ``RecordUserSetAction_`` add their tag to
    ``args_user_set`` when argparse invokes them. After parsing, every group
    registered with ``set_conflict`` is checked. If all tags in a group were
    recorded, a ``ValueError`` is raised.

    Attributes:
        conflicts: list of sets of tags that must not all be set together.
        args_user_set: tags recorded during the most recent ``parse_args``.
    """

    def __init__(self, *args, **kwargs):
        self.conflicts = []
        self.args_user_set = set()
        super().__init__(*args, **kwargs)

    def set_conflict(self, *args):
        """Register ``args`` as a group of mutually exclusive argument tags.

        Raises:
            ValueError: if fewer than two distinct tags are given. An empty
                group is a subset of every set and would fail every parse. A
                single-tag group would forbid that option outright rather
                than express a conflict.
        """
        conflict = set(args)
        if len(conflict) < 2:
            raise ValueError(
                'A conflict needs at least two distinct arguments, got: [{}]'
                ''.format(', '.join(sorted(map(str, conflict))))
            )
        self.conflicts.append(conflict)

    def check_conflicts(self):
        """Raise if every tag of any registered conflict group was user-set.

        Raises:
            ValueError: naming the tags of the first fully-set group.
        """
        for conflict in self.conflicts:
            if conflict <= self.args_user_set:
                raise ValueError(
                    'These arguments are mutually exclusive: [{}]'
                    ''.format(', '.join(sorted(map(str, conflict))))
                )

    def parse_args(self, args=None, namespace=None):
        """Parse ``args`` like ``ArgumentParser.parse_args``, then check conflicts.

        ``args_user_set`` is replaced with a fresh set first. Tags recorded by
        an earlier call therefore cannot trigger a spurious conflict. After
        return, the attribute reflects only this call.

        Raises:
            ValueError: if the user supplied every member of a conflict group.
        """
        # Fresh set rather than .clear(): callers may still hold the set
        # returned from a previous parse.
        self.args_user_set = set()
        parsed = super().parse_args(args=args, namespace=namespace)
        self.check_conflicts()
        return parsed
def build_parser():
    """Create the demo parser with two mutually exclusive, user-set-tracked options.

    Returns:
        ArgumentParserWithConflicts: a parser on which ``--optional1`` and
        ``--optional2`` may each be given, but not both.
    """
    parser = ArgumentParserWithConflicts()
    # nargs=1 makes argparse produce a one-element list. The default is a
    # one-element list too, so args.optional1 has the same shape whether or
    # not the option was given.
    parser.add_argument(
        '--optional1', type=str, default=['blah'], nargs=1,
        action=RecordUserSetAction_(ArgumentsEnum.opt1)
    )
    # nargs=2 yields a list. type=int makes user-supplied values match the
    # integer default, which is likewise a list.
    parser.add_argument(
        '--optional2', type=int, default=[3, 3], nargs=2,
        action=RecordUserSetAction_(ArgumentsEnum.opt2)
    )
    # All arguments here should be mutually exclusive.
    parser.set_conflict(ArgumentsEnum.opt1, ArgumentsEnum.opt2)
    return parser
def main():
    """Parse ``sys.argv``, then print the namespace and the user-set argument tags."""
    cli = build_parser()
    parsed = cli.parse_args()
    report = (
        ('Survived, args: {}', parsed),
        ('None-defaults: {}', cli.args_user_set),
    )
    for template, value in report:
        print(template.format(value))
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment