Skip to content

Instantly share code, notes, and snippets.

@textbook
Last active August 29, 2015 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save textbook/f0560a4555ba3c6dfeaa to your computer and use it in GitHub Desktop.
Save textbook/f0560a4555ba3c6dfeaa to your computer and use it in GitHub Desktop.
Class-based user input validation (comment at http://codereview.stackexchange.com/q/80525/32391)
"""Functionality for validating user inputs."""
# pylint: disable=too-few-public-methods
from __future__ import print_function
import sys
# Only InputValidator is exported by ``from <module> import *``; the
# create_*_validator factories and Cached remain importable by name.
__all__ = ['InputValidator']
def create_choice_validator(choices):
    """Create a validator function based on defined choices.

    Notes:
      Attempts to create a set from choices to speed up membership tests
      with hashable choices. If the input being validated is itself
      unhashable (e.g. a list produced by a ``literal_eval`` type_), a set
      membership test raises TypeError; the validator then falls back to
      an equality scan so the rejection is reported as an ordinary
      ValueError instead of escaping the input loop.

    Arguments:
      choices (collection): The valid choices.

    Returns:
      callable: A validation function to apply to user input.

    """
    try:
        choices = set(choices)
    except TypeError:
        # Unhashable choices: keep the original collection and rely on
        # its own membership test.
        pass

    def validator(ui_):
        """Validate user input based on choices."""
        try:
            valid = ui_ in choices
        except TypeError:
            # e.g. ``[1, 2] in {(1, 2)}`` -> unhashable type: 'list'
            valid = any(ui_ == choice for choice in choices)
        if not valid:
            msg = 'Input must be one of {!r}'
            raise ValueError(msg.format(choices))
    return validator
def create_empty_validator(allow_empty):
    """Create a validation function based on input presence.

    Notes:
      Validation runs on the type-converted input, so "empty" means None
      or zero length (e.g. '' or []), not general False-iness. Testing
      False-iness would wrongly reject legitimate values such as 0 when
      type_=int.

    Arguments:
      allow_empty (bool): Whether to allow empty input.

    Returns:
      callable: A validation function to apply to user input.

    """
    if allow_empty:
        def accept_any(ui_):  # pylint: disable=unused-argument
            """Accept any input, including empty input."""
            return None
        return accept_any

    def reject_empty(ui_):
        """Reject None and zero-length input."""
        if ui_ is None:
            empty = True
        else:
            try:
                empty = len(ui_) == 0
            except TypeError:
                # Unsized values (numbers, booleans, ...) are never empty.
                empty = False
        if empty:
            raise ValueError('Input must be present.')
    return reject_empty
def create_len_validator(len_):
    """Create a validation function based on input length.

    Arguments:
      len_ (int or tuple): Either the acceptable length, or a tuple
        (min_len, max_len).

    Returns:
      callable: A validation function to apply to user input.

    Raises:
      ValueError: If len_ is a string, cannot be unpacked into exactly
        two bounds, or has a minimum greater than its maximum.

    """
    if isinstance(len_, (str, bytes)):
        # Strings are iterable, so a two-character string would otherwise
        # unpack "successfully" into two nonsensical bounds.
        msg = 'Invalid length specification {!r}.'
        raise ValueError(msg.format(len_))
    try:
        min_, max_ = len_
    except TypeError:
        # Not iterable: treat len_ as the exact required length.
        def exact_validator(ui_):
            """Validate user input based on length."""
            if len(ui_) != len_:
                msg = 'Input must contain {} elements.'
                raise ValueError(msg.format(len_))
        return exact_validator
    if min_ > max_:
        # No input could ever satisfy this range, so get_input would loop
        # forever; fail fast at construction time instead.
        msg = 'Minimum length {} exceeds maximum length {}.'
        raise ValueError(msg.format(min_, max_))

    def range_validator(ui_):
        """Validate user input based on length."""
        size = len(ui_)
        if size < min_:
            msg = 'Input must contain at least {} elements.'
            raise ValueError(msg.format(min_))
        elif size > max_:
            msg = 'Input must contain at most {} elements.'
            raise ValueError(msg.format(max_))
    return range_validator
def create_max_validator(max_):
    """Build a validator that enforces an inclusive upper bound.

    Arguments:
      max_: The largest value the (type-converted) input may take.

    Returns:
      callable: A function that raises ValueError when its argument is
        greater than max_ and returns None otherwise.

    """
    template = 'Input must be at most {}.'

    def validator(ui_):
        """Validate user input based on size."""
        exceeds = ui_ > max_
        if exceeds:
            raise ValueError(template.format(max_))

    return validator
def create_min_validator(min_):
    """Build a validator that enforces an inclusive lower bound.

    Arguments:
      min_: The smallest value the (type-converted) input may take.

    Returns:
      callable: A function that raises ValueError when its argument is
        less than min_ and returns None otherwise.

    """
    def validator(ui_):
        """Validate user input based on size."""
        if not ui_ < min_:
            return
        failure = 'Input must be at least {}.'.format(min_)
        raise ValueError(failure)

    return validator
class Cached(object):
    """Cache instances by positional arguments.

    Each class gets its own ``cache`` dict, created lazily on first
    instantiation, mapping the tuple of positional arguments to the
    instance created for them. Calls without positional arguments always
    produce a fresh, uncached instance. Keyword arguments are ignored for
    caching purposes.

    """
    # pylint: disable=no-member

    def __new__(cls, *args, **_):
        # Look in the class's own namespace rather than using hasattr():
        # hasattr() also sees a parent's cache, which would let a subclass
        # share (and return instances from) its base class's cache.
        if 'cache' not in vars(cls):
            cls.cache = {}
        if not args:
            return super(Cached, cls).__new__(cls)
        if args not in cls.cache:
            cls.cache[args] = super(Cached, cls).__new__(cls)
        return cls.cache[args]
class InputValidator(Cached):
    """Create validators for user input.

    Notes:
      Type is validated first - the argument to all other validation
      functions is the type-converted input.

      The following **config options are supported:

        - allow_empty (bool): Whether to allow empty input such as ''
          (defaults to False).
        - choices (collection): The valid choices for the input.
        - len_ (int or tuple): The exact length, or a (min_len, max_len)
          pair, permitted for the input.
        - min_: The minimum value permitted.
        - max_: The maximum value permitted.
        - prompt (str): The default prompt to use if not supplied to
          get_input (defaults to InputValidator.DEFAULT_PROMPT).
        - source (callable): The function to use to take user input
          (defaults to [raw_]input).
        - type_ (callable): The type to attempt to convert the input to
          (defaults to str).

    Arguments:
      name (str, optional): The name to store the validator under.
        Defaults to None (i.e. not stored).
      **config (dict): The configuration options for the validator.

    Raises:
      TypeError: If a validator is requested under an existing name with
        a different configuration.

    Attributes:
      DEFAULT_PROMPT (str): The default prompt to use if not supplied
        in config or the call to get_input.
      VALIDATORS (list): Pairs of (arguments for config.get, factory)
        used to build the validation functions, applied in this order.

    """

    DEFAULT_PROMPT = '> '

    VALIDATORS = [
        (('choices',), create_choice_validator),
        (('allow_empty', False), create_empty_validator),
        (('len_',), create_len_validator),
        (('min_',), create_min_validator),
        (('max_',), create_max_validator),
    ]

    def __new__(cls, name=None, **config):
        if name is None:
            self = super(InputValidator, cls).__new__(cls)
        else:
            # Pass the name positionally so Cached keys the instance on it
            # whether the caller supplied it positionally or by keyword.
            self = super(InputValidator, cls).__new__(cls, name)
        # A cached instance already carries a config; refuse to silently
        # reconfigure it.
        if hasattr(self, 'config') and self.config != config:
            raise TypeError('Configuration conflict')
        return self

    def __init__(self, name=None, **config):
        # Basic arguments
        self.config = config
        self.name = name

        # Select appropriate source for user input
        source = config.get('source')
        if source is None:
            if sys.version_info.major < 3:
                source = raw_input  # pylint: disable=undefined-variable
            else:
                source = input
        self.source = source

        # Default configuration
        self.allow_empty = config.get('allow_empty', False)
        # Backwards-compatible alias; this previously read the undocumented
        # 'empty' key rather than the documented 'allow_empty' option.
        self.empty = self.allow_empty
        self.prompt = config.get('prompt', self.DEFAULT_PROMPT)
        self.type_ = config.get('type_', str)

        # Validation functions, built in VALIDATORS order for every option
        # that resolves to a non-None value.
        self.validators = []
        for get_args, creator in self.VALIDATORS:
            item = config.get(*get_args)  # pylint: disable=star-args
            if item is not None:
                self.validators.append(creator(item))

    def get_input(self, prompt=None):
        """Get validated input.

        Repeatedly reads from the configured source until a value passes
        type conversion and every validator, printing the reason for each
        rejection.

        Arguments:
          prompt (str, optional): The prompt to use. Defaults to the
            instance's prompt attribute.

        Returns:
          The type-converted input that passed every validator.

        """
        if prompt is None:
            prompt = self.prompt
        while True:
            ui_ = self.source(prompt)
            # Basic type validation. SyntaxError and TypeError are caught
            # alongside ValueError because converters such as
            # ast.literal_eval can raise them for malformed input.
            try:
                ui_ = self.type_(ui_)
            except (TypeError, ValueError, SyntaxError):
                msg = 'Input must be {!r}.'
                print(msg.format(self.type_))
                continue
            # Any other validation required
            for validate in self.validators:
                try:
                    validate(ui_)
                except ValueError as err:
                    print(err)
                    break
            else:
                return ui_

    def __call__(self, *args, **kwargs):
        """Allow direct call, invoking get_input."""
        return self.get_input(*args, **kwargs)
if __name__ == '__main__':
    # Built-in testing
    from ast import literal_eval

    class SuppressStdOut(object):
        """Suppress the standard output for testing."""

        def flush(self, *_, **__):
            """Don't flush anything."""
            pass

        def write(self, *_, **__):
            """Don't write anything."""
            pass

    def input_test(_):
        """Return whatever is first in args."""
        return input_test.args.pop(0)

    # get_input prints a message for every rejected input; silence those
    # while the checks run, and always restore the real stream afterwards,
    # even if a check fails.
    ORIGINAL_STDOUT = sys.stdout
    sys.stdout = SuppressStdOut()
    try:
        # 1. Caching
        # Ensure caching isn't activated without name argument
        assert InputValidator() is not InputValidator()
        # Ensure caching is activated with positional name...
        assert InputValidator('name') is InputValidator('name')
        # ...and keyword name...
        assert InputValidator('name') is InputValidator(name='name')
        # ...and handles configuration conflicts
        try:
            _ = InputValidator('name', option='other')
        except TypeError:
            pass
        else:
            # Raise explicitly: ``assert False`` is stripped under -O.
            raise AssertionError(
                'TypeError not thrown for configuration conflict')

        # 2. Calling
        input_test.args = ['test', 'test']
        # Test both call forms return correct value
        VALIDATOR = InputValidator(source=input_test)
        assert VALIDATOR.get_input() == VALIDATOR() == 'test'

        # 3. Numerical validation
        input_test.args = ['-1', '11', 'foo', '5']
        VALIDATOR = InputValidator(source=input_test, type_=int,
                                   min_=0, max_=10)
        assert VALIDATOR() == 5

        # 4. Empty string validation
        # Test empty not allowed...
        input_test.args = ['', 'test', '']
        VALIDATOR = InputValidator(source=input_test)
        assert VALIDATOR() == 'test'
        # ...and allowed
        input_test.args = ['']
        VALIDATOR = InputValidator(source=input_test, allow_empty=True)
        assert VALIDATOR() == ''

        # 5. Choice validation
        input_test.args = ['foo', 'bar']
        VALIDATOR = InputValidator(source=input_test, choices=['bar'])
        assert VALIDATOR() == 'bar'

        # 6. Length validation
        # Test exact length...
        CORRECT_LEN = 10
        input_test.args = [
            'a' * (CORRECT_LEN + 1),
            'a' * (CORRECT_LEN - 1),
            'a' * CORRECT_LEN,
        ]
        VALIDATOR = InputValidator(source=input_test, len_=CORRECT_LEN)
        assert VALIDATOR() == 'a' * CORRECT_LEN
        # ...and length range...
        MIN_LEN = 5
        MAX_LEN = 10
        input_test.args = [
            'a' * (MIN_LEN - 1),
            'a' * (MAX_LEN + 1),
            'a' * MAX_LEN,
        ]
        VALIDATOR = InputValidator(source=input_test,
                                   len_=(MIN_LEN, MAX_LEN))
        assert VALIDATOR() == 'a' * MAX_LEN
        # ...and errors
        LEN = 'foo'
        try:
            _ = InputValidator(len_=LEN)
        except ValueError:
            pass
        else:
            # Raise explicitly: ``assert False`` is stripped under -O.
            raise AssertionError(
                'ValueError not thrown for {!r}.'.format(LEN))

        # 7. Something completely different
        OUTPUT = ['foo', 'bar', 'baz']
        input_test.args = ['[]', '["foo"]', repr(OUTPUT)]
        VALIDATOR = InputValidator(source=input_test, len_=3,
                                   type_=literal_eval)
        assert VALIDATOR() == OUTPUT
    finally:
        sys.stdout = ORIGINAL_STDOUT
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment