Skip to content

Instantly share code, notes, and snippets.

@amitu
Last active April 10, 2018 21:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amitu/f6241ed802afa3312713efd03ef10509 to your computer and use it in GitHub Desktop.
Save amitu/f6241ed802afa3312713efd03ef10509 to your computer and use it in GitHub Desktop.
Acko's api framework
# -*- coding: utf-8 -*-
# License: BSD
import json
import os
import re
import time
from django import forms
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model, login
from django.contrib.postgres.forms import JSONField, SimpleArrayField
from django.core.exceptions import ObjectDoesNotExist
from django.core.paginator import EmptyPage, Paginator
from django.db.models import Model, Q
from django.http import Http404, JsonResponse
from django.shortcuts import render
from django.template.loader import render_to_string
import jsonschema
import structlog
from encrypted_id.models import EncryptedIDDecodeError, EncryptedIDModel
import r2d2.utils as r2d2_utils
from acko import constants, helpers, utils
from acko.models import Quote as QuoteModel
from acko.models import Asset, Policy
from acko.utils import JSONEncoder, get_point_from_latlong
from masters.models import RTO, Pincode, Variant
from pas import pas_cloud as pas
from pas.api import exceptions as pas_exceptions
from users.models import Phone, User, UserProfile
from users.utils import sanitize_phone
# Module-level structured logger (structlog).
logger = structlog.get_logger()
class ApiJsonEncoder(JSONEncoder):
    """
    JSON encoder used for API payloads.

    Extends the project's ``JSONEncoder`` so that model instances are
    written as identifiers: ``EncryptedIDModel`` instances become their
    ``ekey``, any other Django ``Model`` becomes its ``pk``. Everything
    else is delegated to the parent encoder.
    """

    def default(self, obj):
        # The encrypted-id check comes first so an object matching both
        # checks is serialized by its public ekey, not its raw pk.
        if isinstance(obj, EncryptedIDModel):
            return obj.ekey
        if not isinstance(obj, Model):
            return super().default(obj)
        return obj.pk
class IntegerField(forms.IntegerField):
    """
    Integer form field that substitutes a default for missing values.

    If this field is used, and some integer field exists in schema, but is
    not passed by client in request, Django's conversion yields None and we
    pick a default value instead (0 unless configured). If you want to
    override the default value for some field, add the following to your
    API subclass::

        class MyAPI(api.API):
            FIELD_MAPPING = {
                'answer': (api.IntegerField, {"default": 42})
            }

    NOTE(review): the substitution happens in to_python(), i.e. before
    Django's required check, so a missing value is presumably never
    reported as "required" unless the default itself is empty — confirm
    this is intended.
    """

    def __init__(self, default=0, *args, **kw):
        # Remember the fallback, then let Django handle everything else.
        self._default = default
        super().__init__(*args, **kw)

    def to_python(self, value):
        converted = super().to_python(value=value)
        return self._default if converted is None else converted
class CharField(forms.CharField):
    """
    Text form field whose converted value is never None.

    When Django's CharField conversion returns None (e.g. when the field is
    configured with ``empty_value=None``), this subclass returns "" instead.
    """

    def to_python(self, value):
        converted = super().to_python(value=value)
        # TODO: this will screw with django's form validation framework a
        # bit; have to verify what exactly is the problem and what the
        # solution is.
        return "" if converted is None else converted
class EkeyField(forms.ModelChoiceField):
    """
    Model choice field whose client-facing value is an encrypted id (ekey).

    The submitted string is resolved with
    ``<model>.objects.get_by_ekey(value)``. ``<model>`` is the model of the
    queryset given at construction time.

    NOTE(review): the lookup goes through the model's ``objects`` manager,
    so any filtering applied to ``queryset`` itself is not enforced. Confirm
    that this is intended.
    """

    def __init__(self, queryset, *args, **kw):
        # Captured before Django stores the queryset; used for lookups and
        # for the DoesNotExist class we catch.
        self._model = queryset.model
        super().__init__(queryset, *args, **kw)

    def to_python(self, value):
        """
        Return the model instance for ``value``, or None for empty input.

        :raises ValidationError: this module's APIError-based
            ValidationError, with code ``invalid_choice``, when the ekey
            cannot be decoded or matches no row.
        """
        if value in self.empty_values:
            return None
        try:
            return self._model.objects.get_by_ekey(value)
        # Catch the DoesNotExist of the model we actually queried. Using
        # self.queryset.model here could name a different model if the
        # queryset was reassigned after construction.
        except (self._model.DoesNotExist, EncryptedIDDecodeError):
            raise ValidationError(
                str(self.error_messages['invalid_choice']),
                code='invalid_choice'
            )
class LatLongField(forms.CharField):
    """
    Text field converted with ``acko.utils.get_point_from_latlong``.

    The cleaned value is whatever that helper returns for the submitted
    string. If the helper raises ValueError, the cleaned value is None.
    """

    def to_python(self, value):
        raw = super().to_python(value)
        try:
            point = get_point_from_latlong(raw)
        except ValueError:
            point = None
        return point
# Register this module's field classes as project-wide defaults.
# API._get_field consults settings.FIELD_MAPPING after the API class's own
# FIELD_MAPPING, trying the schema property name first and then the
# property's JSON schema "type". Note that "string" maps to this module's
# CharField, not forms.CharField. "latlong" and "user_id" are not JSON
# schema types, so they presumably match property names.
# NOTE(review): this mutates settings at import time, and get_user_model()
# is also called at import time. That presumably requires the Django app
# registry to be ready when this module is imported.
settings.FIELD_MAPPING.update({
    "string": (CharField, {}),
    "integer": (IntegerField, {}),
    "object": (JSONField, {}),
    "boolean": (forms.BooleanField, {}),
    "latlong": (LatLongField, {}),
    "user_id": (EkeyField, {
        'queryset': get_user_model().objects.all()
    })
})
# Maps each error class's code (assigned by ErrorMeta) to the class itself.
error_registry = {}


# noinspection PyInitNewSignature,PyMethodParameters
class ErrorMeta(type):
    """
    Metaclass that gives every error class a ``_code`` and registers it.

    The code is ``<module prefix>.<ClassName>``. The prefix is the defining
    module's dotted name with its last four characters sliced off.

    NOTE(review): the slice presumably assumes modules named ``<pkg>.api``
    (it drops ".api"). Any other module name is truncated arbitrarily.
    A later class producing the same code replaces the earlier entry.
    """

    def __new__(cls, clsname, bases, clsdict):
        error_cls = super().__new__(cls, clsname, bases, clsdict)
        module_prefix = clsdict["__module__"][:-4]
        error_cls._code = module_prefix + "." + clsname
        error_registry[error_cls._code] = error_cls
        return error_cls
# noinspection PyMethodMayBeStatic
class APIError(forms.ValidationError, metaclass=ErrorMeta):
    """
    Base class of every API error.

    Each subclass receives a registry code from ErrorMeta. An instance
    carries:

    - the human-readable message;
    - an optional target form field (``field``);
    - arbitrary keyword context.

    ``error()`` serializes these into the API error-entry dict.

    Add ``template = "name of template"`` to a subclass if you want to
    override the template used for this error. Consider overriding
    ``context()`` if it makes sense.
    """

    def __init__(self, message, code=None, params=None, field=None, **context):
        super().__init__(message, code, params)
        self.human = message
        self._field = field
        self._context = context

    def error(self):
        """Return ``{"human", "code", "context"}`` for this error."""
        return dict(human=self.human, code=self._code, context=self._context)

    @classmethod
    def context(cls, request, human, code, context):
        """
        This method can be used to embed extra information in the
        context. Say we want to fetch some data from database to
        better help our user understand what went wrong, or what other
        values they can try etc.
        """
        return dict(human=human, code=code, context=context, request=request)
class APIGone:
    """
    Marker base class for retired APIs.

    APIRegistry._find_spec looks for the newest registered version at or
    below the requested one. When that version's class subclasses
    APIGone, it returns None, and the request is answered with a 404.
    """
class SchemaError(APIError):
    """
    Used by APIRegistry.handle_one when the validated form data fails the
    API's JSON input schema.

    We are treating schema errors as "code bugs": it is not the customer's
    fault, it is the developer's.
    """
class DjValidationError(APIError):
    """
    Wraps plain Django validation messages (errors that are not APIError
    instances) so they are reported in the API error format.

    Normally we will not use this. It only appears when some stock Django
    field produces the error; all .clean_xxx() methods should raise more
    specialised errors.
    """
class ValidationError(APIError):
    """
    Generic form validation error. All our .clean_xxx() methods should
    raise this error.

    Note: this name shadows Django's ValidationError inside this module;
    for example, EkeyField.to_python raises this class.
    """
def camel_to_snake(camel):
    """
    Convert a CamelCase identifier to snake_case.

    >>> camel_to_snake("UserInfo")
    'user_info'
    >>> camel_to_snake("HTTPResponse")
    'http_response'

    FIXME: an acronym followed by a single lowercase letter splits oddly,
    e.g. CamelCaseURLx -> camel_case_ur_lx.
    """
    # Pass 1: insert "_" before each Capitalized word that follows any char.
    with_word_breaks = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    # Pass 2: insert "_" between a lowercase letter/digit and an uppercase one.
    with_all_breaks = re.sub('([a-z0-9])([A-Z])', r'\1_\2', with_word_breaks)
    return with_all_breaks.lower()
def snake_to_camel(snake):
    """
    Convert snake_case to CamelCase by title-casing each "_"-separated part.

    Uses ``str.title``, so a letter after a digit is capitalized as well
    (``"user2fa"`` -> ``"User2Fa"``). Empty parts vanish.
    """
    return "".join(map(str.title, snake.split('_')))
class APIRegistry(object):
    """
    Registry of versioned APIs, also used directly as the Django view.

    ``register()`` stores one APISpec per ``(name, version)``. Calling the
    instance does the following:

    1. Resolves the spec for the requested name and version.
    2. Runs the API form.
    3. Checks the outcome against the spec's JSON schemas.
    4. Returns a JsonResponse, or the plain result dict when
       ``internal=True``.
    """

    # Presumably read by Django's CSRF middleware, which skips views whose
    # callable carries a truthy ``csrf_exempt`` attribute.
    csrf_exempt = True

    def __init__(self):
        self._registry = {}

    def register(self, mod, cls, name, v):
        """Build and store the APISpec for API ``name`` at version ``v``."""
        self._registry[(name, v)] = APISpec(mod, cls, name, v)

    def handle_one(self, request, method, name, version, data):
        """
        Run one API call and return ``(result, cookies)``.

        ``result`` is one of:

        - the success/error dict built by APISpec;
        - for a GET carrying ``__doc__``, the rendered documentation
          HttpResponse.

        ``cookies`` is the list of ``(args, kwargs)`` pairs the form queued
        through ``set_cookie``.

        :raises Http404: no live spec exists for ``name``/``version``.
        :raises jsonschema.ValidationError: our own error or success payload
            does not match its schema, which is a server-side bug.
        """
        # request MUST be only used for auth etc, not for data
        spec = self._find_spec(name, version)
        if spec is None:
            raise Http404
        form = spec.cls(request, spec, data)

        if method == "GET" and "__doc__" in data:
            # TODO
            # Returned as a tuple like every other path; __call__ passes the
            # HttpResponse through instead of JSON-encoding it.
            return (
                render(request, "acko/api.html", {"form": form}),
                form.cookies,
            )

        # We consider this request an API request if either a. it is POST
        # or b. it is GET but __doc__ is not passed.
        if not form.is_valid():
            errors = self._form_errors(form, method, name, version, data)
            try:
                return spec.error(**errors), form.cookies
            except jsonschema.ValidationError:
                # NOTE: if validator fails here it is a BUG
                logger.error(
                    "api.error_payload_invalid", api=name, version=version,
                    errors=errors, schema=spec._error_schema, exc_info=True,
                )
                raise

        try:
            # TODO remove the round-trip
            # Don't modify the data in clean_xxx (except basic type
            # casting).
            # If any change is required, add it as a property.
            # Directly validate the cleaned_data.
            # Example: self.cleaned_pincode should be the model instance,
            # and self.cleaned_data['pincode'] should be the pincode
            # string.
            form.serialized_data = json.loads(
                json.dumps(form.cleaned_data, cls=ApiJsonEncoder)
            )
            spec.input(form.serialized_data)
        except jsonschema.ValidationError as e:
            errors = self._schema_errors(
                e, form, spec, method, name, version, data
            )
            return spec.error(**errors), form.cookies

        if "validate_only" in data:
            # Cookies are returned here too: __call__ always unpacks
            # (result, cookies).
            return {"success": True, "result": None}, form.cookies

        try:
            saved = form.save()
        except APIError as e:
            # APIError always sets _field (None unless given), so a None
            # value must also fall back to "__all__". A None key cannot be
            # passed as a keyword argument to spec.error().
            field = getattr(e, '_field', None) or '__all__'
            logger.info("api.save_failed", api=name, field=field,
                        error=e.human)
            return spec.error(**{field: [e.error()]}), form.cookies

        j = form.json(saved)
        try:
            return spec.success(j), form.cookies
        except jsonschema.ValidationError:
            # NOTE: if validator fails here it is a BUG
            logger.error(
                "api.success_payload_invalid", api=name, version=version,
                data=j, schema=spec._success_schema, exc_info=True,
            )
            raise

    @staticmethod
    def _form_errors(form, method, name, version, data):
        """Map each invalid form field to its list of API error dicts."""
        errors = {}
        for key, field_errors in form.errors.as_data().items():
            entries = []
            for e in field_errors:
                if isinstance(e, APIError):
                    # TODO: convert it to new style error
                    entries.append(e.error())
                    continue
                # Plain Django error: wrap every message so it has the same
                # shape as our own errors.
                for message in e.messages:
                    entries.append(
                        DjValidationError(
                            str(message),
                            key=key,
                            cleaned_data=form.cleaned_data,
                            method=method,
                            name=name,
                            version=version,
                            data=data,
                            error_class=e.__class__.__name__,
                        ).error()
                    )
            errors[key] = entries
        return errors

    @staticmethod
    def _schema_errors(e, form, spec, method, name, version, data):
        """Turn a jsonschema input failure into ``{field: [error dict]}``."""
        message = e.message
        value = e.validator_value
        if e.path:
            # The first path element is the top-level input property. Any
            # deeper path is prefixed to the message.
            field = e.path.popleft()
            if e.path:
                message = "%s: %s" % (
                    ' -> '.join(map(str, e.path)), e.message
                )
        elif (isinstance(value, (list, tuple)) and len(value) > 1
              and value[0] == 'name'):
            # NOTE(review): special case kept from the original. It
            # presumably targets a root-level "required" list starting with
            # "name", and attributes the error to the list's second entry.
            # Confirm that this is still needed.
            field = value[1]
        else:
            field = '__all__'
        return {
            field: [
                SchemaError(
                    message,
                    cleaned_data=form.cleaned_data,
                    schema=spec.input_schema,
                    method=method,
                    name=name,
                    version=version,
                    data=data,
                ).error()
            ]
        }

    def __call__(self, request, api, v=settings.API_VERSION, internal=False):
        """
        Django view entry point.

        :param request: Django request. The payload is read from
            ``request.data``, which is assumed to be populated upstream
            (e.g. by middleware). TODO confirm.
        :param api: API name from the URL, or ``"bulk"``.
        :param v: requested API version (cast to int).
        :param internal: Flag if API is called internally by some other API.
        :return: dict if internal else JsonResponse. A ``__doc__`` request
            returns the rendered HttpResponse instead.
        """
        from django.http import HttpResponse

        v = int(v)
        if api == "bulk":
            # In case of bulk, data looks like this:
            #
            # {
            #     "one": {"name": "api", "method": "GET/POST", "data": {}}
            #     "two": {"name": "api", "method": "GET/POST", "data": {}}
            # }
            #
            # and the response must look like:
            #
            # {
            #     "one": {"success": True, "result": "foo"}
            #     "two": {
            #         "success": False, "errors": {"__all__": ["yo"]}
            #     }
            # }
            logger.debug("api.bulk", data=request.data)
            result = {}
            cookies = []
            for key, payload in request.data.items():
                result[key], api_cookies = self.handle_one(
                    request, payload["method"], payload["name"], v,
                    payload["data"])
                cookies.extend(api_cookies)
        else:
            result, cookies = self.handle_one(
                request, request.method, api, v, request.data)

        if internal:
            return result
        if isinstance(result, HttpResponse):
            response = result
        else:
            response = JsonResponse(result, encoder=ApiJsonEncoder)
        for args, kw in cookies:
            response.set_cookie(*args, **kw)
        return response

    def _find_spec(self, name, v):
        """
        Return the newest spec for ``name`` at or below version ``v``.

        ``v`` is first capped at settings.API_VERSION. Returns None if no
        such spec exists, or if the newest one is an APIGone placeholder.
        """
        v = min(v, settings.API_VERSION)
        for version in range(v, 0, -1):
            spec = self._registry.get((name, version))
            if spec is None:
                continue
            return None if issubclass(spec.cls, APIGone) else spec
        return None
# Process-wide API registry. register_api() adds specs to it, and the
# instance itself is mounted as the Django view for API URLs (see below).
registry = APIRegistry()
# in urls.py we have: surl('/api/v<int:v>/<something:api>/', api.registry),
class APISpec(object):
    """
    One versioned API: its form class plus three JSON schemas.

    The input, success and error schemas are read at construction time
    from ``<BASE_DIR>/../schema/v<version>/<name>_<kind>.json``, where
    ``<kind>`` is ``input``, ``success`` or ``error``. An empty (or
    whitespace-only) schema file yields ``{}``.
    """

    def __init__(self, mod, cls, name, v):
        self.api = name
        self.cls = cls
        self.module = mod
        self.version = v
        self._input_schema = self._read_schema("%s_input" % name, v)
        self._success_schema = self._read_schema("%s_success" % name, v)
        self._error_schema = self._read_schema("%s_error" % name, v)

    @property
    def input_schema(self):
        return self._input_schema

    def input(self, data):
        """
        Validate ``data`` against the input schema and return it unchanged.

        :raises jsonschema.ValidationError: if ``data`` does not conform.
        """
        return self._validate(data, self._input_schema)

    def error(self, **errors):
        """
        Build ``{"success": False, "errors": errors}`` and validate it.

        A bare string value is wrapped in a one-element list.

        :raises jsonschema.ValidationError: if it violates the error schema.
        """
        errors = {
            k: [v] if isinstance(v, str) else v for k, v in errors.items()
        }
        return self._validate({
            "success": False,
            "errors": errors
        }, self._error_schema)

    def success(self, result):
        """
        Build ``{"success": True, "result": result}`` and validate it.

        :raises jsonschema.ValidationError: if it violates the success schema.
        """
        return self._validate({
            "success": True,
            "result": result
        }, self._success_schema)

    @classmethod
    def _read_schema(cls, name, version):
        """Load schema ``name`` for ``version``; empty files give {}."""
        path = os.path.join(settings.BASE_DIR,
                            "../schema/v%d/%s.json" % (version, name))
        # JSON text is UTF-8. Do not depend on the process locale.
        with open(path, encoding="utf-8") as f:
            text = f.read()
        if not text.strip():
            return {}
        schema = json.loads(text)
        cls._remove_empty_required(schema)
        return schema

    @staticmethod
    def _validate(obj, schema):
        jsonschema.validate(
            obj, schema, format_checker=jsonschema.FormatChecker()
        )
        return obj

    @classmethod
    def _remove_empty_required(cls, schema):
        """
        Remove every empty "required" array from ``schema``, in place.

        If all the fields in an object are optional, the required array in
        the schema will be empty, and jsonschema then raises an error.

        The walk covers these sub-schemas:

        - ``properties``, ``patternProperties`` and ``definitions``;
        - ``additionalProperties`` and ``not``;
        - ``items`` (a single schema or a list of schemas);
        - all of ``anyOf``/``oneOf``/``allOf``.

        The walk does not depend on the schema's ``type``. Non-dict
        schemas (e.g. boolean schemas) are ignored.
        """
        if not isinstance(schema, dict):
            return
        if 'required' in schema and not schema['required']:
            del schema['required']
        # Keywords whose value maps names to sub-schemas.
        for keyword in ('properties', 'patternProperties', 'definitions'):
            for sub_schema in schema.get(keyword, {}).values():
                cls._remove_empty_required(sub_schema)
        # Keywords holding a single sub-schema (missing -> None, ignored).
        for keyword in ('additionalProperties', 'not'):
            cls._remove_empty_required(schema.get(keyword))
        items = schema.get('items')
        for sub_schema in items if isinstance(items, list) else [items]:
            cls._remove_empty_required(sub_schema)
        for keyword in ('anyOf', 'oneOf', 'allOf'):
            for sub_schema in schema.get(keyword, ()):
                cls._remove_empty_required(sub_schema)

    def __repr__(self):
        return str(self.__dict__)
# noinspection PyMethodMayBeStatic
class API(forms.Form):
    """
    Base form for every API.

    One form field is built per property of the spec's input schema (see
    ``_get_field``). APIRegistry.handle_one validates the form, then calls
    ``save()`` and passes its return value to ``json()``.
    """

    # Per-API overrides: {property name or schema type: (field class, kwargs)}.
    FIELD_MAPPING = {}
    # if perms is not empty, we check if current user has those perms
    # NOTE(review): no check against PERMS is visible in this chunk;
    # confirm it is enforced elsewhere.
    PERMS = []

    def __init__(self, request, spec, *args, **kw):
        self.request = request
        self.spec = spec
        # (args, kwargs) pairs for response.set_cookie(), applied by the registry.
        self.cookies = []
        super().__init__(*args, **kw)
        schema = spec.input_schema
        required = set(schema.get("required", []))
        # An empty schema file loads as {}. Treat it as "no input fields".
        for name, field in schema.get('properties', {}).items():
            self.fields[name] = self._get_field(name, field, name in required)

    def set_cookie(self, *args, **kw):
        """Queue a cookie; arguments are forwarded to response.set_cookie()."""
        self.cookies.append((args, kw))

    def d(self, name, default=None):
        """Return cleaned_data[name], or ``default`` if missing or None."""
        val = self.cleaned_data.get(name)
        if val is None:
            return default
        return val

    def i(self, name, default=0):
        """Return cleaned_data[name] as int; falsy values (incl. 0) give ``default``."""
        return int(self.cleaned_data.get(name) or default)

    @staticmethod
    def _field_type(field):
        """Return the hashable JSON schema type used to pick a form field."""
        # TODO
        # Current assumptions:
        # * it's a oneOf / allOf / anyOf field if "type" is not a key
        # * first field of the oneOf / allOf /anyOf is the required field,
        # * oneOf / anyOf: remaining are placeholders for null / blank
        if "type" in field:
            tipe = field["type"]
        else:
            for combinator in ('anyOf', 'oneOf', 'allOf'):
                if combinator in field:
                    tipe = field[combinator][0]["type"]
                    break
            else:
                raise Exception("Unknown field: %s" % field)
        if isinstance(tipe, list):
            # JSON schema allows a list of types, e.g. ["string", "null"].
            # A list is unhashable, so use the first non-"null" entry.
            tipe = next((t for t in tipe if t != "null"), "null")
        return tipe

    def _get_field(self, name, field, required):
        """
        Build the Django form field for schema property ``name``.

        The ``(field class, kwargs)`` pair is looked up in this order:

        1. this class's FIELD_MAPPING by property name;
        2. this class's FIELD_MAPPING by schema type;
        3. settings.FIELD_MAPPING by property name;
        4. settings.FIELD_MAPPING by schema type.

        If none matches, ``forms.CharField`` is used.

        :raises Exception: if the property has no type and no
            anyOf/oneOf/allOf.
        """
        tipe = self._field_type(field)
        for mapping, key in (
                (self.FIELD_MAPPING, name),
                (self.FIELD_MAPPING, tipe),
                (settings.FIELD_MAPPING, name),
                (settings.FIELD_MAPPING, tipe),
        ):
            if key in mapping:
                cls, kw = mapping[key]
                break
        else:
            cls, kw = forms.CharField, {}
        # Copy so the shared mapping entry is never mutated.
        kw = kw.copy()
        kw["required"] = required
        if "minLength" in field:
            kw["min_length"] = field["minLength"]
        if "maxLength" in field:
            kw["max_length"] = field["maxLength"]
        # TODO: "pattern" is not enforced at form level yet (a regex
        # validator would go here). spec.input() still checks it against
        # the JSON schema.
        return cls(**kw)
class GETyAPI(API):
    """
    Base class for read-only APIs.

    ``save()`` does nothing and returns None, so ``json()`` receives None.
    """

    def save(self):
        return None
class AnonymousUser(APIError):
    """
    Raised from AGETyAPI/SGETyAPI.clean() when the request's user may not
    use the API (not logged in, or not staff).
    """
    pass
class AGETyAPI(GETyAPI):
    """
    Read-only API that requires an authenticated user.

    ``clean()`` raises AnonymousUser otherwise. Django's form validation
    records it as a non-field ("__all__") error.
    """

    def clean(self):
        super().clean()
        is_authenticated = self.request.user.is_authenticated
        # Django < 1.10 exposes is_authenticated as a method, 1.10/1.11 as a
        # callable bool, and 2.0+ as a plain bool. Calling a bool raises
        # TypeError, so call it only when it is callable.
        if callable(is_authenticated):
            is_authenticated = is_authenticated()
        if not is_authenticated:
            raise AnonymousUser(message="No user is logged in.")
class SGETyAPI(AGETyAPI):
    """
    Read-only API restricted to staff users.

    The login check inherited from AGETyAPI runs first. An authenticated
    non-staff user then also gets AnonymousUser.
    """

    def clean(self):
        super().clean()
        user = self.request.user
        if user.is_staff:
            return
        raise AnonymousUser(message="Only staff can access this API.")
class ListAPI(GETyAPI):
    """
    Read-only API whose result is a JSON list.

    Subclasses set ``model`` (or override ``object_list()``) and implement
    ``obj2json(obj)``, which is applied to each object in order.
    """

    def object_list(self):
        return self.model.objects.all()

    def json(self, _):
        objects = self.object_list()
        return [self.obj2json(obj) for obj in objects]
# noinspection PyMethodMayBeStatic
class PaginatedAPI(ListAPI):
    """
    List API returning one page of results plus pagination metadata.

    Reads these values from cleaned_data:

    - ``page`` (default 1);
    - ``per_page`` (default 25, at least 1);
    - ``orphans`` (default 0, at least 0), passed to Django's Paginator.

    An out-of-range page yields the last page. Override ``extra()`` to
    attach additional data to the response.
    """

    def extra(self):
        return None

    def json(self, _):
        object_list = self.object_list()
        # i() already maps missing/0 to the default. Clamp the rest so a
        # negative client value never reaches Paginator: a negative
        # per_page makes it slice the object list with a negative index,
        # which Django querysets reject.
        per_page = max(self.i("per_page", 25), 1)
        orphans = max(self.i("orphans", 0), 0)
        pager = Paginator(
            object_list, per_page,
            orphans=orphans,
            allow_empty_first_page=True,
        )
        try:
            page = pager.page(self.i("page", 1))
        except EmptyPage:
            # If page is out of range (e.g. 9999), deliver last page of
            # results
            page = pager.page(pager.num_pages)
        return {
            "object_list": [self.obj2json(o) for o in page.object_list],
            "num_pages": pager.num_pages,
            "page": page.number,
            "has_next": page.has_next(),
            "has_previous": page.has_previous(),
            "has_other_pages": page.has_other_pages(),
            "next_page_number": (
                page.next_page_number() if page.has_next() else 0
            ),
            "previous_page_number": (
                page.previous_page_number() if page.has_previous() else 0
            ),
            "start_index": page.start_index(),
            "end_index": page.end_index(),
            "extra": self.extra(),
        }
class FlatAPIError(APIError):
    """
    Raised by FlatAPI.call_api when the wrapped API reports errors.

    The message is the JSON-encoded ``{field: [human messages]}`` map, and
    the error is attached to "__all__".
    """
    pass
class FlatAPI(GETyAPI):
    """
    Adapter exposing another API (``api``) through flat field names.

    Cleaned keys containing "__" are expanded into nested dicts
    (``{"a__b": 1}`` -> ``{"a": {"b": 1}}``). The tree is sent to
    ``helpers.call_api`` on save. Errors reported by the wrapped API are
    re-raised as a single FlatAPIError.
    """
    api = None
    tree_data = None
    result = None

    def clean(self):
        super().clean()
        tree = {}
        for flat_key, value in self.cleaned_data.items():
            *parents, leaf = flat_key.split('__')
            node = tree
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        self.tree_data = tree

    def save(self):
        self.call_api(self.tree_data)

    def call_api(self, data):
        out = helpers.call_api(self.request, self.api, data)
        if not out['success']:
            errors = {
                f: [e['human'] for e in err]
                for f, err in out['errors'].items()
            }
            raise FlatAPIError(field='__all__', message=json.dumps(errors))
        self.result = out['result']

    def json(self, _):
        return self.result
def register_api(cls):
    """
    Class decorator that registers an API form class with ``registry``.

    The API name is the snake_case form of the class name. If class name
    ends with Vn, then n is assumed to be the version number, and the
    suffix is not part of the name (``UserInfoV2`` -> ``user_info``, v2).
    Otherwise, the class is assumed to implement a version 1 API.

    The name of the directory containing the caller's source file is
    recorded as the spec's ``module``.

    :return: ``cls``, unchanged.
    """
    import inspect

    version_split = re.split(r'(V\d+)$', cls.__name__)
    name = camel_to_snake(version_split[0])
    v = int(version_split[1][1:]) if len(version_split) == 3 else 1

    # f_back is the frame applying the decorator. Unlike inspect.stack(),
    # this does not build FrameInfo records (with source context lines)
    # for every frame on the stack at import time.
    frame = inspect.currentframe().f_back
    try:
        filename = frame.f_code.co_filename
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame
    # The containing directory's name. This is "" instead of an IndexError
    # when the filename has no directory part.
    mod = os.path.basename(os.path.dirname(filename))
    registry.register(mod, cls, name, v)
    return cls
class UserDoesNotExist(APIError):
    """Raised by UserInfo.clean_ekey when no user matches the given ekey."""
    pass
@register_api
class UserInfo(GETyAPI):
    """
    Look up a user by encrypted id and return basic profile information.

    The ekey comes from the ``ekey`` input field. The response includes
    the user's names, phone, and staff/superuser flags.
    """

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        # Set by clean_ekey() once the ekey resolves to a user.
        self.user = None

    def clean_ekey(self):
        """
        Resolve the ``ekey`` input to a User and keep it on ``self.user``.

        :raises UserDoesNotExist: the ekey is undecodable or matches no user.
        """
        # Django forms expose cleaned values only through cleaned_data;
        # there is no ``self.<field>`` attribute on the form.
        ekey = self.cleaned_data["ekey"]
        try:
            self.user = User.objects.get_by_ekey(ekey)
        except (User.DoesNotExist, EncryptedIDDecodeError):
            raise UserDoesNotExist(message="UserDoesNotExist", ekey=ekey)
        return ekey

    def json(self, _=None):
        # handle_one calls form.json(<return value of save()>). For this
        # read-only API that value is None and is ignored.
        return {
            "id": self.user.ekey,
            "informal": self.user.informal,
            "formal": self.user.formal,
            "phone": self.user.phone,
            "is_staff": self.user.is_staff,
            "is_superuser": self.user.is_superuser,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment