Last active
August 29, 2015 14:15
-
-
Save textbook/f0560a4555ba3c6dfeaa to your computer and use it in GitHub Desktop.
Class-based user input validation (comment at http://codereview.stackexchange.com/q/80525/32391)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Functionality for validating user inputs.""" | |
# pylint: disable=too-few-public-methods | |
from __future__ import print_function | |
import sys | |
__all__ = ['InputValidator'] | |
def create_choice_validator(choices):
    """Create a validator function based on defined choices.

    Notes:
        Iterable choices are copied into a list once, so a one-shot
        iterable (e.g. a generator) is not consumed by a failed set()
        attempt, and error messages list the choices in the caller's
        order. When the choices are hashable, a set is also built to
        speed up membership tests.

    Arguments:
        choices (collection): The valid choices.

    Returns:
        callable: A validation function to apply to user input.
    """
    try:
        choices = list(choices)
    except TypeError:
        pass  # Not iterable: rely on the object's own membership test.
    try:
        lookup = set(choices)
    except TypeError:
        lookup = choices  # Unhashable (or non-iterable) choices: test directly.

    def validator(ui_):
        """Validate user input based on choices."""
        try:
            valid = ui_ in lookup
        except TypeError:
            # Unhashable input (e.g. a list produced by ast.literal_eval)
            # cannot be looked up in a set; compare against each choice.
            valid = ui_ in choices
        if not valid:
            msg = 'Input must be one of {!r}'
            raise ValueError(msg.format(choices))
    return validator
def create_empty_validator(allow_empty):
    """Create a validation function based on input presence.

    Notes:
        The check runs on the type-converted input. Numbers (including
        bools) always count as present, so a legitimate answer of 0 is
        not rejected when converting with int. Other False-y values such
        as '', [] or None count as empty.

    Arguments:
        allow_empty (bool): Whether to allow empty input.

    Returns:
        callable: A validation function to apply to user input.
    """
    import numbers

    def allow_all(ui_):  # pylint: disable=unused-argument
        """Accept any input, empty or not."""

    def reject_empty(ui_):
        """Reject False-y, non-numeric input."""
        if not ui_ and not isinstance(ui_, numbers.Number):
            raise ValueError('Input must be present.')

    return allow_all if allow_empty else reject_empty
def create_len_validator(len_):
    """Create a validation function based on input length.

    Arguments:
        len_ (int or tuple): Either the exact acceptable length, or a
            tuple (min_len, max_len). Either bound of the tuple may be
            None to leave that side unbounded.

    Returns:
        callable: A validation function to apply to user input.

    Raises:
        ValueError: If len_ is iterable but does not unpack into exactly
            two bounds, or if min_len is greater than max_len (a range
            that no input could ever satisfy).
    """
    try:
        min_, max_ = len_
    except TypeError:
        # Not unpackable (e.g. an int): require exactly this length.
        def exact_validator(ui_):
            """Validate user input based on length."""
            if len(ui_) != len_:
                msg = 'Input must contain {} elements.'
                raise ValueError(msg.format(len_))
        return exact_validator

    # An impossible range would make every input fail and the prompt loop forever.
    if min_ is not None and max_ is not None and min_ > max_:
        msg = 'Minimum length {} exceeds maximum length {}.'
        raise ValueError(msg.format(min_, max_))

    def range_validator(ui_):
        """Validate user input based on length."""
        if min_ is not None and len(ui_) < min_:
            msg = 'Input must contain at least {} elements.'
            raise ValueError(msg.format(min_))
        if max_ is not None and len(ui_) > max_:
            msg = 'Input must contain at most {} elements.'
            raise ValueError(msg.format(max_))
    return range_validator
def create_max_validator(max_):
    """Build an upper-bound check for user input.

    Arguments:
        max_: The largest value the input may take (inclusive).

    Returns:
        callable: Takes the (type-converted) input and raises ValueError
        if it compares greater than ``max_``; returns None otherwise.
    """
    def validator(ui_):
        """Validate user input based on size."""
        if not ui_ > max_:
            return
        raise ValueError('Input must be at most {}.'.format(max_))
    return validator
def create_min_validator(min_):
    """Build a lower-bound check for user input.

    Arguments:
        min_: The smallest value the input may take (inclusive).

    Returns:
        callable: Takes the (type-converted) input and raises ValueError
        if it compares less than ``min_``; returns None otherwise.
    """
    def validator(ui_):
        """Validate user input based on size."""
        if not ui_ < min_:
            return
        raise ValueError('Input must be at least {}.'.format(min_))
    return validator
class Cached(object):
    """Cache instances of a class by their positional arguments.

    Calling a subclass with positional arguments returns the instance
    previously created for those same arguments, if any. Calling it with
    no positional arguments always creates a new instance. Keyword
    arguments are ignored for cache lookup. Each class keeps its own
    cache, so a subclass never receives an instance cached by its parent.

    Notes:
        The positional arguments are used as a dict key, so they must be
        hashable. __init__ still runs on every call, including when a
        cached instance is returned.
    """
    # pylint: disable=no-member

    def __new__(cls, *args, **_):
        # Check this class's own namespace rather than using hasattr().
        # hasattr() would find a parent's inherited ``cache`` and share it,
        # handing back parent-class instances from subclass calls.
        if 'cache' not in vars(cls):
            cls.cache = {}
        if not args:
            return super(Cached, cls).__new__(cls)
        if args not in cls.cache:
            cls.cache[args] = super(Cached, cls).__new__(cls)
        return cls.cache[args]
class InputValidator(Cached):
    """Create validators for user input.

    Notes:
        Type is validated first - the argument to all other validation
        functions is the type-converted input.

        The following **config options are supported:

        - allow_empty (bool): Whether to allow empty input such as ''
          (defaults to False).
        - choices (collection): The valid choices for the input.
        - len_ (int or tuple): The exact length of the input, or a
          (min_len, max_len) tuple.
        - max_: The maximum value permitted.
        - min_: The minimum value permitted.
        - prompt (str): The default prompt to use if not supplied to
          get_input (defaults to InputValidator.DEFAULT_PROMPT).
        - source (callable): The function to use to take user input
          (defaults to [raw_]input).
        - type_ (callable): The type to attempt to convert the input to
          (defaults to str).

    Arguments:
        name (str, optional): The name to store the validator under.
            Defaults to None (i.e. not stored). Re-using a name returns
            the stored instance; passing a different configuration under
            an already-used name raises TypeError.
        **config (dict): The configuration options for the validator.

    Attributes:
        DEFAULT_PROMPT (str): The default prompt to use if not supplied
            in config or the call to get_input.
        VALIDATORS (list): (config.get arguments, factory) pairs. Each
            factory whose config value is not None contributes one
            validation function; they run in list order.
    """

    DEFAULT_PROMPT = '> '

    VALIDATORS = [
        (('choices',), create_choice_validator),
        # allow_empty defaults to False, so the presence check is always built.
        (('allow_empty', False), create_empty_validator),
        (('len_',), create_len_validator),
        (('min_',), create_min_validator),
        (('max_',), create_max_validator),
    ]

    def __new__(cls, name=None, **config):
        if name is None:
            # Unnamed validators bypass the cache (no positional args).
            self = super(InputValidator, cls).__new__(cls)
        else:
            self = super(InputValidator, cls).__new__(cls, name)
            # A cached instance already has a config; refuse to silently
            # reconfigure it with different options.
            if hasattr(self, 'config') and self.config != config:
                raise TypeError('Configuration conflict')
        return self

    def __init__(self, name=None, **config):
        # Basic arguments
        self.config = config
        self.name = name

        # Select appropriate source for user input
        source = config.get('source')
        if source is None:
            if sys.version_info.major < 3:
                source = raw_input  # pylint: disable=undefined-variable
            else:
                source = input
        self.source = source

        # Default configuration. self.empty mirrors the documented
        # allow_empty option; the check itself is one of self.validators.
        self.empty = config.get('allow_empty', False)
        self.prompt = config.get('prompt', self.DEFAULT_PROMPT)
        self.type_ = config.get('type_', str)

        # Validation functions
        self.validators = []
        for get_args, creator in self.VALIDATORS:
            item = config.get(*get_args)  # pylint: disable=star-args
            if item is not None:
                self.validators.append(creator(item))

    def get_input(self, prompt=None):
        """Get validated input, re-prompting until it passes every check.

        Arguments:
            prompt (str, optional): The prompt to use. Defaults to the
                instance's prompt attribute.

        Returns:
            The type-converted input that passed all validators.
        """
        if prompt is None:
            prompt = self.prompt
        while True:
            ui_ = self.source(prompt)
            # Basic type validation. Converters such as ast.literal_eval
            # (used in this module's self-test) reject malformed text with
            # SyntaxError or TypeError as well as ValueError.
            try:
                ui_ = self.type_(ui_)
            except (ValueError, TypeError, SyntaxError):
                msg = 'Input must be {!r}.'
                print(msg.format(self.type_))
                continue
            # Any other validation required; the first failure re-prompts.
            for validate in self.validators:
                try:
                    validate(ui_)
                except ValueError as err:
                    print(err)
                    break
            else:
                return ui_

    def __call__(self, *args, **kwargs):
        """Allow direct call, invoking get_input."""
        return self.get_input(*args, **kwargs)
if __name__ == '__main__':
    # Built-in self-test: run this file directly. Validation messages are
    # swallowed, so a silent exit means every check passed.
    from ast import literal_eval

    class SuppressStdOut(object):
        """Stand-in for sys.stdout that throws all output away."""

        def flush(self, *_, **__):
            """Ignore flush requests."""

        def write(self, *_, **__):
            """Ignore written text."""

    sys.stdout = SuppressStdOut()

    def input_test(_):
        """Fake input source: pop and return the next queued reply."""
        return input_test.args.pop(0)

    def expect_raise(exc_type, action, failure):
        """Assert that calling ``action()`` raises ``exc_type``."""
        try:
            action()
        except exc_type:
            return
        assert False, failure

    # 1. Caching: disabled without a name...
    assert InputValidator() is not InputValidator()
    # ...enabled for a positional or keyword name...
    assert InputValidator('name') is InputValidator('name')
    assert InputValidator('name') is InputValidator(name='name')
    # ...and a clashing configuration under that name is refused.
    expect_raise(
        TypeError,
        lambda: InputValidator('name', option='other'),
        'TypeError not thrown for configuration conflict',
    )

    # 2. Calling: get_input() and the call shortcut both read the source.
    input_test.args = ['test', 'test']
    VALIDATOR = InputValidator(source=input_test)
    assert VALIDATOR.get_input() == VALIDATOR() == 'test'

    # 3.-7. Each case queues replies, builds an unnamed validator reading
    # from them, and checks the first reply the validator accepts.
    CORRECT_LEN = 10
    MIN_LEN = 5
    MAX_LEN = 10
    OUTPUT = ['foo', 'bar', 'baz']
    CASES = (
        # Numerical validation: below min, above max, not an int, valid.
        (['-1', '11', 'foo', '5'], dict(type_=int, min_=0, max_=10), 5),
        # Empty input rejected by default...
        (['', 'test', ''], {}, 'test'),
        # ...but accepted when allowed.
        ([''], dict(allow_empty=True), ''),
        # Choice validation.
        (['foo', 'bar'], dict(choices=['bar']), 'bar'),
        # Exact length.
        (['a' * (CORRECT_LEN + 1), 'a' * (CORRECT_LEN - 1), 'a' * CORRECT_LEN],
         dict(len_=CORRECT_LEN),
         'a' * CORRECT_LEN),
        # Length range.
        (['a' * (MIN_LEN - 1), 'a' * (MAX_LEN + 1), 'a' * MAX_LEN],
         dict(len_=(MIN_LEN, MAX_LEN)),
         'a' * MAX_LEN),
        # Something completely different: evaluated literals of length 3.
        (['[]', '["foo"]', repr(OUTPUT)],
         dict(len_=3, type_=literal_eval),
         OUTPUT),
    )
    for replies, options, expected in CASES:
        input_test.args = replies
        assert InputValidator(source=input_test, **options)() == expected

    # Invalid length specifications are rejected when the validator is built.
    LEN = 'foo'
    expect_raise(
        ValueError,
        lambda: InputValidator(len_=LEN),
        'ValueError not thrown for {!r}.'.format(LEN),
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment