Skip to content

Instantly share code, notes, and snippets.

@pmdarrow
Created June 17, 2020 17:39
Show Gist options
  • Save pmdarrow/97e36ae996296a84906fcacb3d44740c to your computer and use it in GitHub Desktop.
Simple marshmallow enum field with support for apispec
# Loosely based on https://github.com/h/marshmallow_enum
from marshmallow.fields import Field
class EnumField(Field):
    """Marshmallow field that (de)serializes ``enum.Enum`` members by value.

    Serialization emits ``member.value``; deserialization looks the raw value
    up with ``enum_type(value)``. ``None`` passes through unchanged in both
    directions.

    On construction the field also populates ``self.metadata`` for apispec:

    * ``'type'`` - an OpenAPI primitive type (``'boolean'``, ``'integer'``,
      ``'number'`` or ``'string'``) when all non-None member values share
      one; left unset otherwise.
    * ``'enum'`` - every member value, sorted when the values are mutually
      orderable (``None`` sorts last), otherwise in declaration order.

    :param enum_type: The ``Enum`` subclass whose members this field accepts.
    :param args: Forwarded to ``marshmallow.fields.Field``.
    :param kwargs: Forwarded to ``marshmallow.fields.Field``.
    """

    default_error_messages = {
        'invalid': 'Invalid enum value {input}',
    }

    def __init__(self, enum_type, *args, **kwargs):
        self.enum = enum_type
        super(EnumField, self).__init__(*args, **kwargs)

        # Detect the type of the enum's values and make it available to
        # apispec. None members do not determine the type, so skip them.
        values = [e.value for e in self.enum if e.value is not None]
        schema_type = self._infer_schema_type(values)
        if schema_type is not None:
            self.metadata['type'] = schema_type

        # Ensure all enum values are made available to apispec.
        self.metadata['enum'] = self._ordered_values(
            [e.value for e in self.enum])

    @staticmethod
    def _infer_schema_type(values):
        """Return the OpenAPI type shared by all *values*, or None if none fits."""
        if not values:
            # all() is vacuously True on an empty list; without this guard a
            # value-less (or all-None) enum would be labelled 'integer'.
            return None

        def is_int(v):
            # bool is a subclass of int, but JSON/OpenAPI treat it separately.
            return isinstance(v, int) and not isinstance(v, bool)

        # Check bool first: with the plain isinstance(v, int) test used
        # previously, an all-bool enum matched 'integer' and the 'boolean'
        # branch could never be reached.
        if all(isinstance(v, bool) for v in values):
            return 'boolean'
        if all(is_int(v) for v in values):
            return 'integer'
        if all(is_int(v) or isinstance(v, float) for v in values):
            return 'number'
        if all(isinstance(v, str) for v in values):
            return 'string'
        return None

    @staticmethod
    def _ordered_values(values):
        """Return *values* sorted (None last), or in declaration order if unorderable."""
        try:
            # The (is None, v) key keeps None from being compared against
            # other values; for None-free input the order equals sorted(values).
            return sorted(values, key=lambda v: (v is None, v))
        except TypeError:
            # Mixed, mutually unorderable values (e.g. str and int): fall back
            # to declaration order instead of making the field unconstructible.
            return list(values)

    def _serialize(self, value, attr, obj, **kwargs):
        """Return the member's ``.value``, or None for a None input."""
        if value is None:
            return None
        return value.value

    def _deserialize(self, value, attr, data, **kwargs):
        """Return the enum member whose value equals *value*.

        :raises marshmallow.ValidationError: (via ``fail``) with the
            ``'invalid'`` message when no member has that value.
        """
        if value is None:
            return None
        try:
            return self.enum(value)
        except ValueError:
            # fail() raises; it never returns.
            # NOTE(review): Field.fail is deprecated in marshmallow 3 in
            # favour of `raise self.make_error(...)`; kept as in the original.
            self.fail('invalid', input=value, value=value)

    def fail(self, key, **kwargs):
        """Raise the error for *key*, adding a ``values`` format argument.

        ``values`` is a comma-separated list of all member values, so custom
        error messages may reference ``{values}``.
        """
        kwargs['values'] = ', '.join(str(mem.value) for mem in self.enum)
        super(EnumField, self).fail(key, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment