Skip to content

Instantly share code, notes, and snippets.

@shreyb
Last active March 9, 2020 17:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shreyb/350addec3f7d383729e6430a65400411 to your computer and use it in GitHub Desktop.
Save shreyb/350addec3f7d383729e6430a65400411 to your computer and use it in GitHub Desktop.
jobsub_client argument parser (next version) that supports unique flags: options that may not be given multiple times with conflicting values.
#!/usr/bin/python
import argparse
from collections import defaultdict
import pytest
# How many times an option handled by UniqueStore may be assigned before a
# conflicting (different) value is rejected. UniqueStore compares this against
# JobsubClientNamespace.count_parse(dest), allowing one extra assignment when
# dest has an entry in the namespace's default_values.
num_changes_limit = 1
class JobsubClientNamespace(argparse.Namespace):
    """
    argparse Namespace that records how many times each attribute is assigned.

    Pass an instance as ``namespace=`` to ``ArgumentParser.parse_args`` so
    that custom actions (such as UniqueStore) can consult the counts.

    :param default_values: dict of default values to use; stored unchanged
        as the ``default_values`` attribute
    """

    def __init__(self, default_values=None):
        # The counter must exist before anything else is assigned, because
        # __setattr__ below updates it on every assignment (this one included).
        self.__count_parse = defaultdict(int)
        super(JobsubClientNamespace, self).__init__(default_values=default_values)

    def __setattr__(self, name, value):
        """Store ``value`` normally, then bump the assignment count for ``name``."""
        super(JobsubClientNamespace, self).__setattr__(name, value)
        tally = self.__count_parse
        tally[name] = tally[name] + 1

    def count_parse(self, name):
        """Return how many times attribute ``name`` has been assigned (0 if never)."""
        tally = self.__count_parse
        return tally[name]
class UniqueStore(argparse.Action):
    """
    argparse "store" action that refuses conflicting repeated values.

    Repeating an option with the value already stored is always accepted.
    A different value is rejected once the namespace's assignment count for
    the option (``namespace.count_parse(dest)``) reaches the module-level
    ``num_changes_limit``, or ``num_changes_limit + 1`` when the option's dest
    has an entry in ``namespace.default_values`` (so a default can be
    overridden). Namespaces without a ``default_values`` attribute, such as a
    plain argparse.Namespace, are not checked and the value is simply stored.

    :raises Exception: when a conflicting value exceeds the allowed changes
    """

    def __call__(self, parser, namespace, values, option_string=None):
        print("Trying to set {0} to {1}".format(self.dest, values))
        if self.__is_changed_multiple_from_default(namespace, self.dest, values):
            error = ("{0} was given multiple times, with different values. "
                     "Please check your command, json file config, and environment.")
            raise Exception(error.format(self.dest))
        setattr(namespace, self.dest, values)
        print("{0} is set to {1}".format(self.dest, getattr(namespace, self.dest)))

    @staticmethod
    def __is_changed_multiple_from_default(namespace, dest, values):
        """
        Return True if storing ``values`` into ``namespace.<dest>`` would
        exceed the allowed number of changes, False otherwise.

        :param namespace: namespace being populated by the parser
        :param dest: attribute name of the option
        :param values: newly parsed value(s) for the option
        """
        old_value = getattr(namespace, dest, None)
        # Nothing stored yet, or the same value repeated: never a conflict.
        if old_value is None or old_value == values:
            return False
        try:
            default_values = namespace.default_values
        except AttributeError:
            # Namespace has no default_values attribute, so it is not a
            # JobsubClientNamespace; stop here and let the caller proceed.
            return False
        # A dest with a stored default gets one extra change, so the default
        # itself can be overridden num_changes_limit times: with default
        # foo:bar and num_changes_limit=1, --foo=baz is allowed but a later
        # --foo=blah is not. default_values may be None (JobsubClientNamespace's
        # default), which is treated as "no defaults" instead of crashing on
        # ``dest in None``.
        limit = num_changes_limit
        if default_values and dest in default_values:
            limit += 1
        return namespace.count_parse(dest) >= limit
class TestClass(object):
    """
    pytest checks for UniqueStore used together with JobsubClientNamespace.

    Every entry of __pass_tests__ must parse cleanly (and the last --foo
    given must win); every entry of __fail_tests__ must raise its
    "expected_exception".
    """
    __pass_tests__ = [
        ['--foo', 'BAR'],
        ['--foo', 'BAR', '--foo', 'BAR'],
        ['--foo', 'BAR', '--spam', 'BAR', '--foo', 'BAR', '--foo', 'BAR'],
        ['--foo', 'BAZ', '--spam', 'BAZ', '--foo', 'BAZ', '--foo', 'BAZ'],
        ['--foo', 'BAZ', '--spam', 'BAZ', '--foo', 'BAZ', '--foo', 'BAZ', '--spam', 'BAR'],
    ]
    __fail_tests__ = {
        1: {
            "test_args": ['--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
        2: {
            "test_args": ['--foo', 'BAZ', '--foo', 'BAZ', '--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
        3: {
            "test_args": ['--foo', 'BAR', '--foo', 'BAR', '--foo', 'BAR', '--foo', 'BAZ'],
            "expected_exception": Exception,
        },
    }

    def setup_method(self, method=None):
        """
        Build a fresh parser and its default values before every test.

        parse_args gets an explicit empty list: without one argparse reads
        sys.argv[1:], which under pytest holds pytest's own command line and
        makes the parser exit with "unrecognized arguments".
        """
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument('--foo', action=UniqueStore, default='BAR')
        self.parser.add_argument('--spam', action='store')
        self.test_defaults = vars(self.parser.parse_args([]))

    # nose-style name kept as an alias; pytest 8 only calls setup_method.
    setup = setup_method

    def pass_checker(self, args):
        """Parse ``args`` into a fresh JobsubClientNamespace; return its attribute dict."""
        print(args)
        j = JobsubClientNamespace(default_values=self.test_defaults)
        print(j)
        self.parser.parse_args(args, namespace=j)
        return j.__dict__

    def fail_checker(self, args, exc):
        """Assert that parsing ``args`` into a fresh JobsubClientNamespace raises ``exc``."""
        print(args)
        j = JobsubClientNamespace(default_values=self.test_defaults)
        print(j)
        # Parse exactly once, like pass_checker: an earlier parse_args() on
        # the same namespace re-assigns the string default, which raises
        # count_parse('foo') and would let a case fail on its first override
        # rather than on the conflict it was written to test.
        with pytest.raises(exc) as excinfo:
            self.parser.parse_args(args, namespace=j)
        print(excinfo.value)

    def test_pass(self):
        print("These should pass")
        for test_args in self.__pass_tests__:
            parsed = self.pass_checker(test_args)
            # The last --foo value on the command line must be the one stored.
            foo_values = [test_args[i + 1]
                          for i, arg in enumerate(test_args) if arg == '--foo']
            assert parsed['foo'] == foo_values[-1]

    def test_failures(self):
        print("These should fail")
        for test_dict in self.__fail_tests__.values():
            self.fail_checker(test_dict["test_args"],
                              test_dict["expected_exception"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment