Skip to content

Instantly share code, notes, and snippets.

@textbook
Last active August 29, 2015 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save textbook/f0560a4555ba3c6dfeaa to your computer and use it in GitHub Desktop.
Save textbook/f0560a4555ba3c6dfeaa to your computer and use it in GitHub Desktop.
Class-based user input validation (comment at http://codereview.stackexchange.com/q/80525/32391)
"""Functionality for validating user inputs."""
# pylint: disable=too-few-public-methods
from __future__ import print_function
import getpass
import re
import string
import sys
# Names exported by `from <module> import *`; the create_*_validator
# factory functions are not listed here.
__all__ = [
    'InputValidator',
    'PasswordValidator',
]
def create_choice_validator(choices):
    """Build a validator that only accepts members of ``choices``.

    Notes:
        The choices are copied into a set when possible so membership
        tests are fast; if building the set raises TypeError (e.g. the
        choices are unhashable), the original collection is used as-is.

    Arguments:
        choices (collection): The valid choices.

    Returns:
        callable: A validation function to apply to user input. It
        returns None for a valid choice and raises ValueError otherwise.

    Raises:
        ValueError: If ``choices`` is empty (falsy).
    """
    if not choices:
        raise ValueError('No valid choices defined.')

    try:
        allowed = set(choices)
    except TypeError:
        # Unhashable (or otherwise un-settable) choices: fall back to the
        # collection exactly as supplied.
        allowed = choices

    def validator(ui_):
        """Validate user input based on choices."""
        if ui_ in allowed:
            return
        raise ValueError('Input must be one of: {!r}'.format(allowed))

    return validator
def create_class_validator(classes, valid=None):
    """Create a validator requiring at least one character per class.

    Arguments:
        classes (seq): Names of the required character classes.
        valid (dict, optional): Mapping from class name to the collection
            of characters in that class. Defaults to
            InputValidator.VALID_CLASSES.

    Returns:
        callable: A validation function to apply to user input. It raises
        ValueError naming the first required class with no character
        present in the input.

    Raises:
        ValueError: If one of the class names is not a key of ``valid``.
    """
    if valid is None:
        valid = InputValidator.VALID_CLASSES

    # Resolve every name up front so bad configuration fails immediately.
    required = []
    for class_name in classes:
        if class_name not in valid:
            raise ValueError(
                'Invalid character class: {!r}'.format(class_name))
        required.append((class_name, valid[class_name]))

    def validator(ui_):
        """Validate user input based on character classes."""
        for class_name, members in required:
            if not any(char in members for char in ui_):
                raise ValueError(
                    'No characters in {} class.'.format(class_name))

    return validator
def create_len_validator(len_):
    """Create a validation function based on input length.

    Arguments:
        len_ (int or tuple): Either the exact acceptable length, or a
            two-tuple (min_len, max_len) of inclusive bounds. Either bound
            may be None to leave that side unbounded.

    Returns:
        callable: A validation function to apply to user input.

    Raises:
        ValueError: If ``len_`` is iterable but does not unpack into
            exactly two values, or if min_len is greater than max_len.
    """
    try:
        min_, max_ = len_
    except TypeError:
        # Not iterable: treat len_ as the single exact length required.
        def validator(ui_):
            """Validate user input based on length."""
            if len(ui_) != len_:
                msg = 'Input must contain {} elements.'
                raise ValueError(msg.format(len_))
    else:
        if min_ is not None and max_ is not None and min_ > max_:
            # An inverted range can never be satisfied, so get_input would
            # re-prompt forever; reject the configuration instead.
            raise ValueError(
                'Minimum length {!r} exceeds maximum length {!r}.'.format(
                    min_, max_))

        def validator(ui_):
            """Validate user input based on length."""
            length = len(ui_)
            if min_ is not None and length < min_:
                msg = 'Input must contain at least {} elements.'
                raise ValueError(msg.format(min_))
            if max_ is not None and length > max_:
                msg = 'Input must contain at most {} elements.'
                raise ValueError(msg.format(max_))
    return validator
def create_max_validator(max_):
    """Build a validator enforcing an inclusive upper bound.

    Arguments:
        max_: The largest permitted value; the input is compared to it
            with ``>`` after type conversion.

    Returns:
        callable: A validation function to apply to user input. It
        raises ValueError when the input exceeds ``max_``.
    """
    def validator(ui_):
        """Validate user input based on size."""
        if not ui_ > max_:
            return
        raise ValueError('Input must be at most {}.'.format(max_))

    return validator
def create_min_validator(min_):
    """Build a validator enforcing an inclusive lower bound.

    Arguments:
        min_: The smallest permitted value; the input is compared to it
            with ``<`` after type conversion.

    Returns:
        callable: A validation function to apply to user input. It
        raises ValueError when the input is below ``min_``.
    """
    def validator(ui_):
        """Validate user input based on size."""
        too_small = ui_ < min_
        if too_small:
            raise ValueError('Input must be at least {}.'.format(min_))

    return validator
def create_nonempty_validator(reject):
    """Create a validator that optionally rejects empty input.

    Arguments:
        reject (bool): If truthy, falsy input (such as '') is rejected;
            if falsy, every input is accepted.

    Returns:
        callable: A validation function to apply to user input.
    """
    def accept_anything(_):
        """Accept any input."""
        return None

    def validator(ui_):
        """Reject False-y input."""
        if ui_:
            return
        raise ValueError('Input must be present.')

    return validator if reject else accept_anything
def create_regex_validator(pattern):
    """Validate user input based on a regular expression.

    Notes:
        Sets re.VERBOSE flag, allowing additional whitespace/comments.
        The whole input must match the pattern; input whose prefix alone
        matches is rejected, so quantifier upper bounds are enforced.

    Arguments:
        pattern (str): The regular expression pattern.

    Returns:
        callable: A validation function to apply to user input.

    Raises:
        re.error: If ``pattern`` is not a valid regular expression.
    """
    expr = re.compile(pattern, re.VERBOSE)
    try:
        full_match = expr.fullmatch  # Python 3.4+
    except AttributeError:
        # Older Pythons lack fullmatch: anchor the pattern at the end of
        # the string instead. The newline ends any trailing verbose-mode
        # comment so the anchor is not swallowed by it.
        full_match = re.compile(
            '(?:' + pattern + '\n)' + r'\Z', re.VERBOSE).match

    def validator(ui_):
        """Validate user input based on a regular expression."""
        if full_match(ui_) is None:
            raise ValueError('Input must match {!r}.'.format(pattern))
    return validator
class InputValidator(object):
    """Create validators for user input.

    Notes:
        Regular expressions are validated first, on the raw input string,
        then the type - the argument to all other validation functions is
        the type-converted input.

        The following **config options are supported:
          - choices (collection): The valid choices for the input.
          - len_ (int or tuple): The exact length, or a (min_len, max_len)
            tuple of bounds.
          - min_: The minimum value permitted.
          - max_: The maximum value permitted.
          - nonempty (bool): Whether to reject '' (defaults to False).
          - prompt (str): The default prompt to use if not supplied to
            get_input (defaults to the class's PROMPT attribute).
          - regex (str): The regular expression the raw input must match.
          - require_class (seq): The required character classes (keys of
            VALID_CLASSES).
          - source (callable): The function to use to take user input
            (defaults to the class's INPUT attribute, i.e. [raw_]input).
          - type_ (callable): The type to attempt to convert the input to
            (defaults to str).
        Options whose value is None create no validator. Other keys are
        accepted and stored in ``config`` (so they take part in the cache
        conflict check and in repr) but have no effect on validation.

        To add a new configuration option, define a function that takes a
        single argument (the item from the config dictionary) and returns
        a validation function. The validation function should either raise
        a ValueError (invalid input) or return None (valid input). Then
        add a new two-tuple ('name', function) to the VALIDATORS list and
        update this docstring (and the tests!) accordingly.

    Arguments:
        name (str, optional): The name to store the validator under.
            Defaults to None (i.e. not stored). Constructing with a name
            that is already cached returns the cached instance.
        **config (dict): The configuration options for the validator.

    Raises:
        TypeError: If a config key is listed in INVALID_CONFIG, or if a
            cached name is requested with a different, non-empty config.

    Attributes:
        CACHE (dict): Storage for cached instances.
        INPUT (callable): The default function to use for input.
        INVALID_CONFIG (seq): Invalid config keys.
        PROMPT (str): The default prompt to use if not supplied in config
            or the call to get_input.
        VALID_CLASSES (dict): The valid character classes.
        VALIDATORS (list): (config key, validator factory) pairs.
    """

    CACHE = {}
    INPUT = (input if sys.version_info.major > 2
             else raw_input)  # pylint: disable=undefined-variable
    INVALID_CONFIG = []
    PROMPT = '> '
    VALID_CLASSES = {
        'digits': set(string.digits),
        'lowercase': set(string.ascii_lowercase),
        'punctuation': set(string.punctuation),
        'uppercase': set(string.ascii_uppercase),
    }
    VALIDATORS = [
        ('choices', create_choice_validator),
        ('len_', create_len_validator),
        ('min_', create_min_validator),
        ('max_', create_max_validator),
        ('nonempty', create_nonempty_validator),
        ('require_class', create_class_validator),
    ]

    def __new__(cls, name=None, **config):
        if name is None or name not in cls.CACHE:
            self = super(InputValidator, cls).__new__(cls)
            # Reject options this (sub)class explicitly forbids.
            for key in config:
                if key in cls.INVALID_CONFIG:
                    raise TypeError('Invalid config option: {!r}'.format(key))
            # Basic arguments
            self.config = config
            self.name = name
            # Select appropriate source for user input
            self.source = config.get('source', cls.INPUT)
            # Default configuration
            self.prompt = config.get('prompt', cls.PROMPT)
            self.type_ = config.get('type_', str)
            # The regex runs on the raw string before type conversion, so
            # it is kept apart from the other validators; None means "no
            # regex configured".
            if 'regex' in config:
                self.regex = create_regex_validator(config['regex'])
            else:
                self.regex = None
            # Validation functions, in VALIDATORS order.
            self.validators = [
                creator(config[key])
                for key, creator in cls.VALIDATORS
                if config.get(key) is not None
            ]
            if name is not None:
                cls.CACHE[name] = self
        else:
            self = cls.CACHE[name]
            if config and config != self.config:
                raise TypeError('Configuration conflict')
        return self

    def get_input(self, prompt=None):
        """Get validated input, re-prompting until every check passes.

        Each rejection prints a message and asks again.

        Arguments:
            prompt (str, optional): The prompt to use. Defaults to the
                instance's prompt attribute.

        Returns:
            The input after conversion by type_, once it has passed the
            regex (if any) and every other configured validator.
        """
        if prompt is None:
            prompt = self.prompt
        while True:
            ui_ = self.source(prompt)
            # Regular expression validation (on the raw string)
            if self.regex is not None:
                try:
                    self.regex(ui_)
                except ValueError as err:
                    print(err)
                    continue
            # Basic type validation. Converters such as ast.literal_eval
            # signal malformed text with SyntaxError or TypeError as well
            # as ValueError; all three mean "re-prompt", not "crash".
            try:
                ui_ = self.type_(ui_)
            except (ValueError, TypeError, SyntaxError):
                msg = 'Input must match: {!r}.'
                print(msg.format(self.type_))
                continue
            # Any other validation required
            for validate in self.validators:
                try:
                    validate(ui_)
                except ValueError as err:
                    print(err)
                    break
            else:
                return ui_

    def __call__(self, *args, **kwargs):
        """Shorthand for get_input."""
        return self.get_input(*args, **kwargs)

    def __repr__(self):
        args = []
        if self.name is not None:
            args.append(repr(self.name))
        if self.config:
            args.append('**{!r}'.format(self.config))
        return '{}({})'.format(self.__class__.__name__, ', '.join(args))
class PasswordValidator(InputValidator):
    """Create validators for passwords.

    Notes:
        Automatically sets the 'nonempty' configuration option. Unless a
        'regex' is supplied, also defaults 'len_' to LEN and
        'require_class' to every class in VALID_CLASSES (explicitly
        supplied values are kept).
        Does not allow 'choices', 'min_' or 'max_' options.
        Requesting an already-cached name with no config returns the
        cached instance, whatever options it was created with.

    Attributes:
        LEN (tuple): The default minimum and maximum length.
    """

    CACHE = {}
    # staticmethod keeps cls.INPUT a plain function on Python 2, where a
    # bare function stored on a class would be read back as an unbound
    # method and fail when called with just the prompt.
    INPUT = staticmethod(getpass.getpass)
    INVALID_CONFIG = ('choices', 'min_', 'max_')
    LEN = (8, 32)
    PROMPT = 'Password: '

    def __new__(cls, name=None, **config):
        if name is not None and name in cls.CACHE and not config:
            # Plain lookup by name: injecting the defaults here would
            # spuriously conflict with any custom options the cached
            # instance was created with.
            return super(PasswordValidator, cls).__new__(cls, name)
        config['nonempty'] = True
        if 'regex' not in config:
            config.setdefault('len_', cls.LEN)
            config.setdefault('require_class', list(cls.VALID_CLASSES))
        return super(PasswordValidator, cls).__new__(cls, name, **config)
if __name__ == '__main__':
    # Built-in testing: a self-check run when the module is executed
    # directly. Statements are order-dependent - the class-level caches
    # (e.g. the 'name' and 'cached' validators) and the input_test.args
    # queue carry state from one check to the next.
    # NOTE: the checks use assert, so running under `python -O` skips them.
    # pylint: disable=invalid-name
    from ast import literal_eval

    class SuppressStdOut(object):
        """Suppress the standard output for testing."""
        def flush(self, *_, **__):
            """Don't flush anything."""
            pass
        def write(self, *_, **__):
            """Don't write anything."""
            pass

    # Swap stdout for a no-op writer so the rejection messages printed by
    # get_input don't clutter the run; it is restored on the last line.
    out = sys.stdout
    sys.stdout = SuppressStdOut()

    def input_test(_):
        """Scripted input source: pop and return the next queued response.

        The prompt argument is ignored; responses come from the list in
        input_test.args, consumed front to back.
        """
        return input_test.args.pop(0)

    # Caching
    # Ensure caching isn't activated without name argument
    assert InputValidator() is not InputValidator()
    # Ensure caching is activated with positional name...
    assert InputValidator('name') is InputValidator('name')
    # ...and keyword name...
    assert InputValidator('name') is InputValidator(name='name')
    # ...and config...
    validator_ = InputValidator('cached', opt=True)
    assert InputValidator('cached') is validator_
    # ...and handles configuration conflicts
    try:
        _ = InputValidator('cached', opt=False)
    except TypeError:
        pass
    else:
        assert False, 'TypeError not thrown for configuration conflict'
    # Calling
    input_test.args = ['test', 'test']
    # Test both call forms return correct value
    validator_ = InputValidator(source=input_test)
    assert validator_.get_input() == 'test'
    assert validator_() == 'test'
    # Numerical validation
    # '-1' fails min_, '11' fails max_, 'foo' fails int(), then 5 passes.
    input_test.args = ['-1', '11', 'foo', '5']
    validator_ = InputValidator(source=input_test, type_=int, min_=0, max_=10)
    assert validator_() == 5
    # Empty string validation
    # Test empty not allowed... (the trailing '' is never consumed because
    # 'test' is accepted first)
    input_test.args = ['', 'test', '']
    validator_ = InputValidator(source=input_test, nonempty=True)
    assert validator_() == 'test'
    # ...and allowed
    input_test.args = ['']
    validator_ = InputValidator(source=input_test)
    assert validator_() == ''
    # Choice validation
    input_test.args = ['foo', 'bar']
    validator_ = InputValidator(source=input_test, choices=['bar'])
    assert validator_() == 'bar'
    # Length validation
    # Test exact length...
    CORRECT_LEN = 10
    input_test.args = [
        'a' * (CORRECT_LEN + 1),
        'a' * (CORRECT_LEN - 1),
        'a' * CORRECT_LEN
    ]
    validator_ = InputValidator(source=input_test, len_=CORRECT_LEN)
    assert validator_() == 'a' * CORRECT_LEN
    # ...and length range...
    MIN_LEN = 5
    MAX_LEN = 10
    input_test.args = [
        'a' * (MIN_LEN - 1),
        'a' * (MAX_LEN + 1),
        'a' * MAX_LEN
    ]
    validator_ = InputValidator(source=input_test, len_=(MIN_LEN, MAX_LEN))
    assert validator_() == 'a' * MAX_LEN
    # ...and errors ('foo' cannot be unpacked into (min_len, max_len))
    LEN = 'foo'
    try:
        _ = InputValidator(len_=LEN)
    except ValueError:
        pass
    else:
        assert False, 'ValueError not thrown for {!r}.'.format(LEN)
    # Class validation
    # Test successful validation... (only the mixed-case input has a
    # character from both required classes)
    OUTPUT = 'Test'
    input_test.args = ['', OUTPUT.lower(), OUTPUT.upper(), OUTPUT]
    validator_ = InputValidator(source=input_test,
                                require_class=('uppercase', 'lowercase'))
    assert validator_() == OUTPUT
    # ...and errors
    try:
        _ = InputValidator(require_class=('foo',))
    except ValueError:
        pass
    else:
        assert False, 'ValueError not thrown for invalid class.'
    # Regex validation (the pattern's text is kept at column 0 because it
    # is also echoed in rejection messages)
    regex = r'''
[A-Z] # Start with capital letter
[\w.]{5,9} # then between five and nine other characters
'''
    MATCH = 'Will_match'
    input_test.args = ["won't match", MATCH]
    validator_ = InputValidator(source=input_test, regex=regex)
    assert validator_() == MATCH
    # Something completely different: type_ parses Python literals, and
    # len_ then applies to the parsed list (lengths 0, 1, then 3).
    OUTPUT = ['foo', 'bar', 'baz']
    input_test.args = ['[]', "('foo',)", repr(OUTPUT)]
    validator_ = InputValidator(source=input_test, len_=3, type_=literal_eval)
    assert validator_() == OUTPUT
    # Test __repr__ (relies on the 'cached' validator created above)
    validator_ = InputValidator('cached')
    assert repr(validator_) == "InputValidator('cached', **{'opt': True})"
    assert eval(repr(validator_)) is validator_  # pylint: disable=eval-used
    assert repr(InputValidator('name')) == "InputValidator('name')"
    validator_ = InputValidator(named=False)
    assert repr(validator_) == "InputValidator(**{'named': False})"
    validator_ = InputValidator()
    assert repr(validator_) == 'InputValidator()'
    # Password validation
    # Check separate cache
    assert PasswordValidator('name') is not InputValidator('name')
    assert InputValidator('name') is InputValidator('name')
    # Test defaults: too short, too long, missing classes, then valid.
    VALID = '$Abc123!'
    input_test.args = ['short', 'toolong' * 5, 'nospecial', VALID]
    assert PasswordValidator(source=input_test)() == VALID
    # Restore the real stdout.
    sys.stdout = out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment