Last active
March 9, 2020 17:06
-
-
Save shreyb/350addec3f7d383729e6430a65400411 to your computer and use it in GitHub Desktop.
Prototype jobsub_client argument parser (next version) that supports unique flags: options that may not be given again with a conflicting value.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import argparse | |
from collections import defaultdict | |
import pytest | |
# Number of times a UniqueStore option may be changed to a *different* value.
# Options that have an entry in the namespace's default_values get one extra
# assignment on top of this, so that the default itself can be overridden
# num_changes_limit times.
num_changes_limit = 1
class JobsubClientNamespace(argparse.Namespace):
    """
    Namespace to be used with argparsers that want to use custom actions.

    Behaves like argparse.Namespace, but also records how many times each
    attribute has been assigned. Actions such as UniqueStore can use this to
    tell whether an option has already been set.

    NOTE: the counter lives in the instance __dict__, so it also shows up in
    vars()/repr() of the namespace.

    :param default_values: dict of default values to use (stored as the
        ``default_values`` attribute; may be None)
    :param kwargs: further initial attributes, as accepted by argparse.Namespace
    """

    def __init__(self, default_values=None, **kwargs):
        # Create the bookkeeping dict via the base-class __setattr__. This way
        # it exists before our own __setattr__ first uses it, and it is not
        # counted as an assignment itself.
        super(JobsubClientNamespace, self).__setattr__(
            '_JobsubClientNamespace__count_parse', defaultdict(int))
        super(JobsubClientNamespace, self).__init__(
            default_values=default_values, **kwargs)

    def __setattr__(self, name, value):
        """Set attribute ``name`` to ``value`` and bump its assignment count."""
        super(JobsubClientNamespace, self).__setattr__(name, value)
        self.__count_parse[name] += 1

    def count_parse(self, name):
        """
        Return how many times attribute ``name`` has been assigned.

        :param name: attribute name
        :return: assignment count; 0 if the attribute was never assigned
        """
        # .get() so that merely asking about a name does not insert it.
        return self.__count_parse.get(name, 0)
class UniqueStore(argparse.Action):
    """
    argparse action that stores a value but rejects conflicting repeats.

    Works like the built-in ``store`` action. The check only applies when the
    namespace provides ``default_values`` and ``count_parse()``, as
    JobsubClientNamespace does. In that case the action raises if the
    destination already holds a different value and has been assigned too
    many times. Repeating the currently stored value is always accepted.
    With any other namespace the value is simply stored.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Store ``values`` on ``namespace`` unless that would be a conflict.

        :raises Exception: if the option was already given with a different
            value more often than the module-level ``num_changes_limit`` allows
        """
        print("Trying to set {0} to {1}".format(self.dest, values))
        if self.__is_changed_multiple_from_default(namespace, self.dest, values):
            error = ("{0} was given multiple times, with different values. "
                     "Please check your command, json file config, and environment.")
            raise Exception(error.format(self.dest))
        setattr(namespace, self.dest, values)
        print("{0} is set to {1}".format(self.dest, getattr(namespace, self.dest)))

    @staticmethod
    def __is_changed_multiple_from_default(namespace, dest, values):
        """
        Return True if storing ``values`` in ``dest`` would be a forbidden change.

        A change is forbidden when both of these hold:

        - a non-None value different from ``values`` is already stored;
        - ``namespace.count_parse(dest)`` has reached the limit.

        The limit is ``num_changes_limit``, plus one when ``dest`` has an entry
        in ``namespace.default_values``. So with a default foo:bar and
        num_changes_limit=1, ``--foo=baz`` is accepted, but a following
        ``--foo=blah`` is rejected.
        """
        old_value = getattr(namespace, dest, None)
        if old_value is None or old_value == values:
            # Nothing stored yet, or a repeat of the stored value: no conflict.
            return False
        try:
            default_values = namespace.default_values
            count_parse = namespace.count_parse
        except AttributeError:
            # Namespace doesn't track assignments (not a JobsubClientNamespace).
            # Don't enforce uniqueness; let the caller proceed.
            return False
        limit = num_changes_limit
        # default_values may be None (JobsubClientNamespace's default), so
        # guard the membership test instead of letting it raise TypeError.
        if default_values and dest in default_values:
            limit += 1
        return count_parse(dest) >= limit
class TestClass(object):
    """
    pytest tests for UniqueStore used together with JobsubClientNamespace.

    ``--foo`` uses UniqueStore (default ``'BAR'``). ``--spam`` is a plain
    ``store`` option and is never checked for conflicts.
    """

    # Argument lists that must parse without raising.
    __pass_tests__ = [
        ['--foo', 'BAR'],
        ['--foo', 'BAR', '--foo', 'BAR'],
        ['--foo', 'BAR', '--spam', 'BAR', '--foo', 'BAR', '--foo', 'BAR'],
        ['--foo', 'BAZ', '--spam', 'BAZ', '--foo', 'BAZ', '--foo', 'BAZ'],
        ['--foo', 'BAZ', '--spam', 'BAZ', '--foo', 'BAZ', '--foo', 'BAZ', '--spam', 'BAR'],
    ]

    # Argument lists that must make parsing raise "expected_exception".
    __fail_tests__ = {
        1: {
            "test_args": ['--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
        2: {
            "test_args": ['--foo', 'BAZ', '--foo', 'BAZ', '--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
        3: {
            "test_args": ['--foo', 'BAR', '--foo', 'BAR', '--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
    }

    def setup(self):
        """Build the parser under test and record its default values."""
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument('--foo', action=UniqueStore, default='BAR')
        self.parser.add_argument('--spam', action='store')
        # Pass an explicit empty list. With no argument, parse_args() would
        # parse sys.argv, i.e. pytest's own command line.
        self.test_defaults = vars(self.parser.parse_args([]))

    def setup_method(self, method=None):
        """
        pytest per-test setup hook; delegates to setup().

        A bare setup() method is a nose-style hook that pytest 8 no longer
        calls, so without this the tests would run without a parser.
        """
        self.setup()

    def pass_checker(self, args):
        """
        Parse ``args`` into a fresh JobsubClientNamespace.

        :param args: argument list to parse
        :return: the namespace's attribute dict
        """
        print(args)
        j = JobsubClientNamespace(default_values=self.test_defaults)
        print(j)
        self.parser.parse_args(args, namespace=j)
        return j.__dict__

    def fail_checker(self, args, exc):
        """
        Assert that parsing ``args`` raises ``exc``.

        The parser defaults are applied to the namespace first.

        :param args: argument list to parse
        :param exc: expected exception type
        """
        print(args)
        j = JobsubClientNamespace(default_values=self.test_defaults)
        print(j)
        # Pre-populate defaults; explicit [] so sys.argv is not parsed.
        self.parser.parse_args([], namespace=j)
        with pytest.raises(exc):
            try:
                self.parser.parse_args(args, namespace=j)
            except exc as e:
                print(e)
                raise

    def test_pass(self):
        """Every entry in __pass_tests__ parses and keeps the last --foo value."""
        print("These should pass")
        for test_args in self.__pass_tests__:
            d = self.pass_checker(test_args)
            # The last --foo given wins; without one the default remains.
            expected_foo = self.test_defaults['foo']
            for flag, value in zip(test_args[::2], test_args[1::2]):
                if flag == '--foo':
                    expected_foo = value
            assert d['foo'] == expected_foo

    def test_failures(self):
        """Every entry in __fail_tests__ raises its expected exception."""
        print("These should fail")
        for test_dict in self.__fail_tests__.values():
            self.fail_checker(test_dict["test_args"],
                              test_dict["expected_exception"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment