Skip to content

Instantly share code, notes, and snippets.

@ahopkins
Last active November 15, 2021 12:10
Show Gist options
  • Save ahopkins/9b91b950c94c1bc3c24df741b7fb02ef to your computer and use it in GitHub Desktop.
Save ahopkins/9b91b950c94c1bc3c24df741b7fb02ef to your computer and use it in GitHub Desktop.
Experiment for AST router

Experiment for AST router

import json
import re
import sys
import uuid
from collections import defaultdict, namedtuple
from datetime import datetime
from functools import lru_cache, reduce
from operator import iconcat
# Indentation unit used when emitting the generated ``find_route`` source.
TAB = " "
# Iteration count for benchmarking; not referenced by the visible code.
LOOPS = 50000
# ``maxsize`` of the ``lru_cache`` wrapping ``Router.resolve``.
CACHE = 2 ** 10
def parse_date(d):
    """Parse a ``YYYY-MM-DD`` string and return only its ``date`` part."""
    moment = datetime.strptime(d, "%Y-%m-%d")
    return moment.date()


# type label -> (cast applied to the segment, compiled validation regex)
REGEX_TYPES = {
    label: (cast, re.compile(expr))
    for label, cast, expr in (
        ("string", str, r"[^/]+"),
        ("int", int, r"-?\d+"),
        ("number", float, r"-?(?:\d+(?:\.\d*)?|\.\d+)"),
        ("alpha", str, r"[A-Za-z]+"),
        ("path", str, r"[^/].*?"),
        (
            "ymd",
            parse_date,
            r"([12]\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01]))",
        ),
        (
            "uuid",
            uuid.UUID,
            r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}",
        ),
    )
}
# Metadata recorded for one dynamic path segment.
ParamInfo = namedtuple("ParamInfo", "name cast pattern")


def inf_dict():
    """Return an auto-vivifying dict: every missing key yields another one."""
    factory = inf_dict
    return defaultdict(factory)


class NotFound(Exception):
    """Default exception raised when no route matches."""


class NoMethod(Exception):
    """Passed as ``method_handler_exception`` by the demo router below."""
def handler1():
    """Demo handler: print its own name."""
    print("handler1")


def handler1post():
    """Demo handler registered for POST/PATCH on ``/foo/1``."""
    print("handler1post")


def handler2():
    """Demo handler shared by several static routes."""
    print("handler2")


def handler3():
    """Demo handler for the route guarded by a ``domain`` requirement."""
    print("handler3domain")


def handler4(**kwargs):
    """Demo handler for the dynamic route; echoes the parsed parameters."""
    print(f"handler4 {kwargs}")
class ExtraDict(dict):
    """A ``dict`` that can be hashed, so it can be passed as the ``extra``
    argument of the ``lru_cache``-wrapped ``Router.resolve``.

    The hash is computed from a ``frozenset`` of the items. Dict equality
    ignores insertion order, so two equal instances must also hash equally;
    a ``tuple`` of the items would not. All values must be hashable. Do not
    mutate an instance after it has been used as a cache key.
    """

    def __hash__(self):
        return hash(frozenset(self.items()))
class Router:
    """Experimental router that compiles its dynamic routes into Python source.

    Static paths (no ``<``) are served straight from ``static_routes``.
    Dynamic paths are arranged into a nested-dict tree and emitted as the
    body of a generated ``find_route(path, router, extra, basket)``
    function. Its source is kept in ``find_src`` for inspection.

    :param delimiter: path segment separator
    :param exception: raised when nothing matches or validation fails
    :param method_handler_exception: raised by ``Route.__call__`` for an
        unregistered method
    """

    def __init__(
        self, delimiter="/", exception=NotFound, method_handler_exception=NotFound
    ):
        self.static_routes = {}
        self.dynamic_routes = {}
        self.delimiter = delimiter
        self.exception = exception
        self.method_handler_exception = method_handler_exception
        self.ast = ""

    def add(self, path, handler, methods="base", name=None, requirements=None):
        """Register *handler* for *path* under each of *methods*.

        :param methods: one method or a list/tuple/set of methods
        :param name: optional alias stored in the same route table
        :param requirements: mapping that ``resolve``'s *extra* must satisfy
        :raises ValueError: on a parameter segment with an empty name
        :raises re.error: on an invalid custom parameter regex
        """
        if not isinstance(methods, (list, tuple, set)):
            methods = [methods]
        static = "<" not in path
        routes = self.static_routes if static else self.dynamic_routes
        if path in routes:
            route = routes[path]
        else:
            route = Route(self, path, name, requirements)
            routes[path] = route
            if name:
                routes[name] = route
        if not static:
            for idx, part in enumerate(path.split(self.delimiter)):
                if "<" in part:
                    if ":" in part:
                        param_name, _type, pattern = self.parse_parameter_string(
                            part[1:-1]
                        )
                        route.add_parameter(idx, param_name, _type, pattern)
                    else:
                        route.add_parameter(idx, part[1:-1], str)
        for method in methods:
            route.add_handler(handler, method)
        self._compile()
        # Results memoised by ``resolve`` came from the previous find_route;
        # drop them so a newly added route is actually seen.
        Router.resolve.cache_clear()

    @lru_cache(maxsize=CACHE)
    def resolve(self, path, extra=None):
        """Return ``(route, basket)`` for *path*.

        *basket* maps segment index -> raw segment text for dynamic routes,
        and is ``None`` for static ones. Results are memoised, so *extra*
        must be hashable (see ``ExtraDict``).

        :raises self.exception: when no route matches, a typed segment does
            not fully match its pattern, or a route requirement is not met
        """
        extra = extra or {}
        route, basket = self.find_route(path, self, extra, {})
        if route.params and ":" in route.raw_path:
            for level, param_info in route.params.items():
                for p in param_info:
                    # fullmatch: the whole segment must fit the type. With
                    # search(), "12abc" passed an int check and then crashed
                    # later in the cast.
                    if not p.pattern or re.fullmatch(p.pattern, basket[level]):
                        break
                else:
                    raise self.exception
        if route.requirements:
            for req, value in route.requirements.items():
                if req not in extra or extra[req] != value:
                    raise self.exception
        return route, basket

    def _compile(self):
        """Regenerate, exec and bind ``find_route`` for the current routes."""
        src = [
            "def find_route(path, router, extra, basket):",
        ]
        if self.static_routes:
            src += [
                TAB + "try:",
                TAB * 2 + "return router.static_routes[path], None",
                TAB + "except KeyError:",
                TAB * 2 + "pass",
            ]
        if self.dynamic_routes:
            src += [
                TAB + "parts = path.split(router.delimiter)",
                TAB + "num_parts = len(parts)",
            ] + self._compile_routes()
        if not self.dynamic_routes:
            src += [TAB + "raise NotFound"]
        self.find_src = "\n".join(src)
        compiled_find = compile(self.find_src, "", "exec")
        ctx = {}
        # globals=None: the generated code resolves exception names from
        # this module's globals.
        exec(compiled_find, None, ctx)
        self.find_route = ctx["find_route"]

    def _compile_routes(self):
        """Build a nested-dict tree of dynamic segments and return its source.

        Each node records ``__depth__``, the segment count of the last route
        that passed through it. The node for a route's final segment also
        holds ``__exit__``, the expression that returns that route.
        """
        tree = inf_dict()
        for route in self.dynamic_routes.values():
            current = tree
            for level, part in enumerate(route.path):
                current = current[part]
                current["__depth__"] = len(route.path)
                if level == len(route.path) - 1:
                    current[
                        "__exit__"
                    ] = f"router.dynamic_routes['{route.name or route.raw_path}']"
        return self._parse_tree(tree, 0)

    def _parse_tree(self, tree, level=0):
        """Emit the if/elif chain for every branch of *tree* at *level*."""
        branches = [branch for branch in tree.items() if branch[0] != "__depth__"]
        # Flatten the per-branch line lists into one list
        return reduce(
            iconcat,
            [
                self._make_exp(
                    index == 0, (index + 1) == len(branches), part, branch, level
                )
                for index, (part, branch) in enumerate(branches)
            ],
            [],
        )

    def _make_exp(self, is_first, is_last, part, branch, level):
        """Emit the source lines for a single branch.

        The lines are: a ``num_parts`` guard (first branch only), then the
        segment test (dynamic segments are stored in ``basket``), then either
        the route ``return`` or the recursively generated sub-branches, and
        finally the ``raise`` lines that close the chain after the last
        branch.
        """
        indent = (level + 1) * 2 - 2
        conditional = "if" if is_first else "elif"
        retval = [(indent + 1) * TAB + f"if num_parts > {level}:"] if is_first else []
        if "<" in part:
            eq = "==" if "__exit__" in branch else ">"
            retval += [
                (indent + 2) * TAB + f"basket[{level}] = parts[{level}]",
                (indent + 2) * TAB + f"{conditional} num_parts {eq} {level + 1}:",
            ]
        else:
            retval += [
                (indent + 2) * TAB + f"{conditional} parts[{level}] == '{part}':",
            ]
        if "__exit__" in branch:
            retval.append((indent + 3) * TAB + f"return {branch['__exit__']}, basket")
        else:
            retval += self._parse_tree(branch, level + 1)
        if is_last:
            retval += [(indent + 2) * TAB + f"raise {self.exception.__name__}"]
            if branch["__depth__"] == level + 1:
                retval += [(indent + 1) * TAB + f"raise {self.exception.__name__}"]
        return retval

    @staticmethod
    def parse_parameter_string(parameter_string):
        """Split ``name[:label-or-regex]`` into ``(name, cast, pattern)``.

        Adapted from Sanic. Labels found in ``REGEX_TYPES`` use that type's
        cast and regex. Any other label is treated as a custom regular
        expression with a ``str`` cast. It is compiled here, so *pattern* is
        always a compiled regex that callers can match against.

        :raises ValueError: if the name before ``:`` is empty
        :raises re.error: if a custom pattern is not a valid regex
        """
        name = parameter_string
        pattern = "string"
        if ":" in parameter_string:
            name, pattern = parameter_string.split(":", 1)
            if not name:
                raise ValueError(f"Invalid parameter syntax: {parameter_string}")
        if pattern in REGEX_TYPES:
            _type, compiled = REGEX_TYPES[pattern]
        else:
            _type, compiled = str, re.compile(pattern)
        return name, _type, compiled
class Route:
    """A registered path: its split segments, per-method handlers and the
    parameter info recorded for its dynamic segments."""

    def __init__(self, router, raw_path, name, requirements):
        self.router = router
        self.name = name
        self.raw_path = raw_path
        self.requirements = requirements
        self.handlers = {}
        self.path = []
        self._explode()
        # segment index -> list of ParamInfo registered for that segment
        self.params = defaultdict(list)

    def __repr__(self):
        display = (
            f"{self.name}|{self.raw_path}"
            if self.name and self.name != self.raw_path
            else self.raw_path
        )
        return f"<Route: {display}>"

    def __call__(self, method="base", **kwargs):
        """Invoke the handler registered for *method* with *kwargs*.

        Only the handler lookup is guarded. A ``KeyError`` raised inside the
        handler itself propagates unchanged instead of being misreported as
        a missing method.

        :raises router.method_handler_exception: if *method* has no handler
        """
        try:
            handler = self.handlers[method]
        except KeyError:
            raise self.router.method_handler_exception(
                f"Method '{method}' not found on {self}"
            ) from None
        return handler(**kwargs)

    def _explode(self):
        self.path = self.raw_path.split(self.router.delimiter)

    def add_handler(self, handler, method):
        # TODO:
        # - if already has a handler for this method, raise Exception
        self.handlers[method] = handler

    def add_parameter(self, idx, name, cast, pattern=None):
        # TODO:
        # - if already has a param for this idx, raise Exception
        self.params[idx].append(ParamInfo(name, cast, pattern))
def parse_parameter_basket(route, basket):
    """Convert the raw segment *basket* returned by ``find_route`` into
    handler keyword arguments.

    :param route: object whose ``params`` maps segment index -> iterable of
        ``ParamInfo(name, cast, pattern)``
    :param basket: segment index -> raw segment text, or ``None``/empty
        for static routes
    :return: dict of parameter name -> value. Untyped parameters become
        ``str``. Typed ones are cast only when the *whole* segment matches
        their pattern, so a partial match such as ``"4a"`` for ``int`` is
        never passed to the cast.
    """
    params = {}
    if not basket:
        return params
    for idx, value in basket.items():
        # .get() rather than [] so a defaultdict on the route is not mutated
        for p in route.params.get(idx, ()):
            if not p.pattern:
                params[p.name] = str(value)
            # re.fullmatch accepts both compiled patterns and pattern strings
            elif re.fullmatch(p.pattern, value):
                params[p.name] = p.cast(value)
    return params
# --- Demo: register sample routes, dump the generated source, resolve paths --
router = Router(method_handler_exception=NoMethod)
router.add(path="/foo/1", handler=handler1, methods="get")
router.add(path="/foo/1", handler=handler1post, methods=["post", "patch"])
router.add(path="/foo/2", handler=handler2)
router.add(path="/foo/3", handler=handler3, requirements={"domain": "foobar"})
router.add(path="/bar", handler=handler2)
router.add(path="/foobar/1", handler=handler1)
router.add(path="/foobar/2", handler=handler2, name="foobar2")
router.add(
    path="/<foo:string>/<id:uuid>/<num:int>", handler=handler4, methods="get"
)

# (path, method or None for the "base" handler, hashable extra or None)
paths = (
    ("/foo/1", "get", None),
    ("/foo/1", "post", None),
    ("/foo/1", "patch", None),
    ("/foo/2", None, None),
    ("/foo/3", None, ExtraDict({"domain": "foobar"})),
    ("/bar", None, None),
    ("/foobar/1", None, None),
    ("/foobar/2", None, None),
    ("/fizzbuzz/bc872f23-2e21-4b1c-b029-ca5dd9c727a0/123", "get", None),
)

print("\nSource:")
print(router.find_src)
for path, method, extra in paths:
    print(f"\n\n~~~ Matching {path} ~~~")
    route, param_basket = router.resolve(path, extra)
    params = parse_parameter_basket(route, param_basket)
    print(f"{route=}")
    print(f"{param_basket=}")
    print(f"{params=}")
    print("===\nExecuting")
    # With no method given, Route.__call__ falls back to its "base" default.
    args = [method] if method else []
    route(*args, **params)
import re
import typing as t
import uuid
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict, namedtuple
from datetime import datetime
from functools import lru_cache
def parse_date(d):
    """Return the ``date`` encoded by a ``YYYY-MM-DD`` string."""
    stamp = datetime.strptime(d, "%Y-%m-%d")
    return stamp.date()
def parts_to_path(parts):
    """Join path *parts* with ``"/"``, reducing each dynamic segment to ``<name>``.

    ``"<id:int>"`` becomes ``"<id>"`` and static segments are kept as-is.
    The name is everything between ``<`` and the first ``:`` (or the closing
    ``>``), which mirrors ``BaseRouter.parse_parameter_string``. Names
    containing digits or capitals (``foo2``, ``userId``) are therefore kept
    whole instead of being truncated or raising ``AttributeError``.
    """
    path = []
    for part in parts:
        if part.startswith("<"):
            name = part[1:-1].split(":", 1)[0]
            path.append(f"<{name}>")
        else:
            path.append(part)
    return "/".join(path)
# Captures a dynamic segment's parameter name: everything after "<" up to the
# first ":" (type/regex separator) or the closing ">". This matches how
# parse_parameter_string splits names, so "foo2" or "userId" are not
# truncated or rejected.
REGEX_PARAM_NAME = re.compile(r"^<([^:>]+)(?::.*)?>$")
# type label -> (cast applied to the segment, compiled validation regex)
REGEX_TYPES = {
    label: (cast, re.compile(expr))
    for label, cast, expr in (
        ("string", str, r"[^/]+"),
        ("int", int, r"-?\d+"),
        ("number", float, r"-?(?:\d+(?:\.\d*)?|\.\d+)"),
        ("alpha", str, r"[A-Za-z]+"),
        ("path", str, r"[^/].*?"),
        (
            "ymd",
            parse_date,
            r"([12]\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01]))",
        ),
        (
            "uuid",
            uuid.UUID,
            r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}",
        ),
    )
}
class NotFound(Exception):
    """No route matched the requested path."""


class NoMethod(Exception):
    """Intended for a route that lacks a handler for the requested method."""


class RouteExists(Exception):
    """A handler is already registered for this path and method."""


# Per-segment parameter metadata: name, type label, cast and pattern.
ParamInfo = namedtuple("ParamInfo", ["name", "label", "cast", "pattern"])


class Children(OrderedDict):
    """Ordered child mapping; not referenced by the visible code."""
class Node:
    """One path segment in the dynamic-route ``Tree``.

    A node renders itself and its children into ``Line`` objects that form
    the body of the generated ``find_route``. Annotations that mention
    ``Line`` are quoted because ``Line`` is defined after this class.
    Unquoted, they raise ``NameError`` when the class body executes.
    """

    def __init__(self, part: str = "", root: bool = False, parent=None) -> None:
        self.root = root
        self.part = part
        self.parent = parent
        self._children = {}
        self.level = None
        self.route = None
        self.dynamic = False
        self.first = False
        self.last = False

    def __repr__(self) -> str:
        internals = ", ".join(
            f"{prop}={getattr(self, prop)}"
            for prop in ["part", "level", "route", "dynamic"]
            if getattr(self, prop) or prop in ["level"]
        )
        return f"Node({internals})"

    @property
    def children(self):
        """Children sorted by ``__sorting``, with ``first``/``last`` flags set.

        Flags are reset on every child first, so a child that is no longer
        first or last (after new children were added) does not keep a stale
        flag.
        """
        # TODO:
        # - Optimize by making a callable and set this as static
        children = {k: v for k, v in sorted(self._children.items(), key=self.__sorting)}
        for child in children.values():
            child.first = False
            child.last = False
        if children:
            keys = list(children.keys())
            children[keys[0]].first = True
            children[keys[-1]].last = True
        return children

    def display(self, tab: int) -> None:
        """Print this subtree, indenting each node by ``tab * level`` spaces."""
        print(" " * tab * self.level, self)
        for child in self.children.values():
            child.display(tab)

    def render(self) -> t.List["Line"]:
        """Return this node's lines, then every child's, then delayed lines."""
        output, delayed = self.to_src()
        for child in self.children.values():
            output += child.render()
        output += delayed
        return output

    def to_src(self) -> t.Tuple[t.List["Line"], t.List["Line"]]:
        """Return ``(src, delayed)`` for this node alone.

        *src* precedes the children's lines. *delayed* follows them: it holds
        this node's route ``return`` when the node also has children.
        """
        addition = ""
        if self.route and self.level in self.route.router.do_regex:
            addition = f' and REGEX_TYPES["{self.route.params[self.level].label}"][1].match(parts[{self.level}])'
        indent = (self.level + 1) * 2 - 1
        delayed = []
        src = []
        if self.first or self.root:
            src = [Line(f"if num > {self.level}{addition}:", indent)]
        if self.dynamic:
            src.append(Line(f"basket[{self.level}] = parts[{self.level}]", indent + 1))
            # Hidden marker: BaseRouter.optimize shifts the following lines
            # left by one, since a dynamic segment opens no extra block.
            src.append(Line("...", 0, offset=-1, render=False))
        else:
            if_stmt = "if" if self.first or self.root else "elif"
            src.append(
                Line(f'{if_stmt} parts[{self.level}] == "{self.part}":', indent + 1)
            )
        if self.route:
            location = delayed if self.children else src
            location.append(
                Line(
                    f'return router.dynamic_routes["{self.route.raw_path}"], basket',
                    indent + 1 + bool(not self.children),
                )
            )
        return src, delayed

    @staticmethod
    def __sorting(item) -> t.Tuple[bool, int, str]:
        # Static before dynamic, then children with more sub-children first,
        # then alphabetical by segment text.
        key, child = item
        return child.dynamic, len(child.children) * -1, key
class Tree:
    """Prefix tree of dynamic routes, with one ``Node`` per path segment.

    ``Route`` and ``Line`` annotations are quoted because both classes are
    defined later in the module. Unquoted, they raise ``NameError`` when
    the class body executes.
    """

    def __init__(self) -> None:
        self.root = Node(root=True)
        self.root.level = 0

    def generate(self, routes: t.Dict[str, "Route"]) -> None:
        """Insert each route's segments (after the leading empty one) below
        the root, attaching the route to the node of its final segment."""
        for route in routes.values():
            current = self.root
            for level, part in enumerate(route.parts[1:]):
                if part not in current._children:
                    current._children[part] = Node(part=part, parent=current)
                current = current._children[part]
                current.level = level + 1
                # TODO:
                # - full evaluation to make sure that the part if it is dynamic
                #   is compliant and can be parsed by one of the known types
                current.dynamic = part.startswith("<")
            current.route = route

    def display(self, tab: int = 4) -> None:
        """Print the whole tree, indenting *tab* spaces per level."""
        self.root.display(tab=tab)

    def render(self) -> t.List["Line"]:
        """Return the generated source lines for the whole tree."""
        return self.root.render()
class Line:
    """One line of generated source: its text, indent depth, an indent
    ``offset`` applied to following lines, and whether it is emitted."""

    TAB = " "

    def __init__(
        self, src: str, indent: int, offset: int = 0, render: bool = True,
    ) -> None:
        self.src, self.indent, self.offset, self.render = src, indent, offset, render

    def __str__(self):
        return f"{self.TAB * self.indent}{self.src}"
class BaseRouter(ABC):
    """Abstract tree-based router.

    Static paths are resolved by dict lookup. Dynamic paths are resolved
    by a ``find_route`` function rendered from ``self.tree``. Subclasses
    implement :meth:`get`. Call :meth:`_generate_tree` and then
    :meth:`_render` after registering routes; ``resolve`` calls
    ``self.find_route``, which :meth:`_render` binds.
    """

    def __init__(
        self,
        delimiter: str = "/",
        exception: t.Type[Exception] = NotFound,
        method_handler_exception: t.Type[Exception] = NotFound,
    ) -> None:
        self.static_routes = {}
        self.dynamic_routes = {}
        # Normalised (see parts_to_path) forms of registered dynamic paths.
        self.dynamic_paths = set()
        # Segment indices whose values get a REGEX_TYPES check in find_route
        # because two distinct dynamic routes share a normalised path.
        self.do_regex = set()
        self.delimiter = delimiter
        self.exception = exception
        self.method_handler_exception = method_handler_exception
        self.tree = Tree()

    @abstractmethod
    def get(self):
        ...

    @lru_cache
    def resolve(self, path: str, *args) -> t.Tuple[t.Any, t.Optional[t.Dict[int, str]]]:
        """Return ``(route, basket)`` from the generated ``find_route``.

        Extra positional *args* only contribute to the ``lru_cache`` key and
        are not forwarded.
        """
        return self.find_route(path, self, {})

    def add(
        self,
        path: str,
        handler: t.Callable,
        methods: t.Union[t.List[str], str] = "base",
        name: t.Optional[str] = None,
        requirements: t.Optional[t.Dict[str, t.Any]] = None,
    ) -> None:
        """Register *handler* for *path* under each of *methods*.

        Parameter parsing and collision detection run only when *path* is
        new. Registering another method on an existing dynamic path is not a
        collision, so it no longer switches its segments into regex mode.

        :raises RouteExists: if *path* already has a handler for a method
        :raises ValueError: on a parameter segment with an empty name
        :raises re.error: on an invalid custom parameter regex
        """
        if not isinstance(methods, (list, tuple, set)):
            methods = [methods]
        static = "<" not in path
        routes = self.static_routes if static else self.dynamic_routes
        parts = path.split(self.delimiter)
        if path in routes:
            route = routes[path]
        else:
            route = Route(self, path, name, requirements)
            routes[path] = route
            if name:
                routes[name] = route
            if not static:
                self._register_parameters(route, parts)
        for method in methods:
            route.add_handler(handler, method)

    def _register_parameters(self, route, parts) -> None:
        """Record ParamInfo for each dynamic segment of a new *route*, and
        flag segment indices for regex checks when its normalised path
        collides with an already-registered dynamic route."""
        for idx, part in enumerate(parts):
            if "<" in part:
                if ":" in part:
                    param_name, label, _type, pattern = self.parse_parameter_string(
                        part[1:-1]
                    )
                    route.add_parameter(idx, param_name, label, _type, pattern)
                else:
                    route.add_parameter(idx, part[1:-1], "string", str, None)
        clean_path = parts_to_path(parts)
        if clean_path in self.dynamic_paths:
            for idx, part in enumerate(parts):
                if "<" in part:
                    self.do_regex.add(idx)
        else:
            self.dynamic_paths.add(clean_path)

    def _generate_tree(self) -> None:
        """Populate ``self.tree`` from the registered dynamic routes."""
        self.tree.generate(self.dynamic_routes)

    def _render(self, do_compile: bool = True) -> None:
        """Render ``find_route`` into ``self.find_route_src``. When
        *do_compile* is true, also exec it and bind ``self.find_route``."""
        src = [
            Line("def find_route(path, router, basket):", 0),
        ]
        if self.static_routes:
            src += [
                Line("try:", 1),
                Line("return router.static_routes[path], None", 2),
                Line("except KeyError:", 1),
                Line("pass", 2),
            ]
        if self.dynamic_routes:
            src += [Line("parts = path.split(router.delimiter)", 1)]
            src += [Line("num = len(parts)", 1)]
            src += self.tree.render()
        src += [Line("raise NotFound", 1)]
        self.optimize(src)
        self.find_route_src = "\n".join(map(str, filter(lambda x: x.render, src)))
        if do_compile:
            compiled_src = compile(self.find_route_src, "", "exec")
            ctx = {}
            # globals=None: the generated code resolves NotFound and
            # REGEX_TYPES from this module's globals.
            exec(compiled_src, None, ctx)
            self.find_route = ctx["find_route"]

    @staticmethod
    def parse_parameter_string(parameter_string: str):
        """Split the text inside a ``<...>`` segment into its components.

        For example::

            parse_parameter_string("id:int")
            -> ("id", "int", int, REGEX_TYPES["int"][1])
            parse_parameter_string("param_one:[A-z]")
            -> ("param_one", "[A-z]", str, re.compile("[A-z]"))

        Labels not in ``REGEX_TYPES`` are treated as custom regexes and
        compiled here, so the returned pattern is always a compiled regex.

        :param parameter_string: ``NAME`` or ``NAME:LABEL`` (no brackets)
        :return: ``(name, label, cast, pattern)``
        :raises ValueError: if the name before ``:`` is empty
        :raises re.error: if a custom pattern is not a valid regex
        """
        name = parameter_string
        label = "string"
        if ":" in parameter_string:
            name, label = parameter_string.split(":", 1)
            if not name:
                raise ValueError(f"Invalid parameter syntax: {parameter_string}")
        if label in REGEX_TYPES:
            _type, pattern = REGEX_TYPES[label]
        else:
            _type, pattern = str, re.compile(label)
        return name, label, _type, pattern

    @staticmethod
    def optimize(src: t.List[Line]) -> None:
        """Adjust the indentation of rendered lines in place.

        Each line's ``offset`` is added to a running offset, which is applied
        to that line and every following one. The running offset resets to 0
        whenever a line's raw indent is below the previous line's adjusted
        indent, unless that line's text starts with ``"."`` (the hidden
        marker a dynamic ``Node`` emits).
        """
        offset = 0
        current = 0
        insert_at = set()
        for line in src:
            if line.indent < current:
                if not line.src.startswith("."):
                    offset = 0
            # TODO: fill ``insert_at`` with (index, indent) pairs so that a
            # closing ``raise NotFound`` is inserted before dedenting
            # if/elif/return lines. Not implemented yet, so nothing is
            # inserted below.
            offset += line.offset
            line.indent += offset
            current = line.indent
        for num, indent in sorted(insert_at, key=lambda x: (x[0] * -1, x[1])):
            src.insert(num, Line("raise NotFound", indent))
class Route:
    """A registered path: raw text, split ``parts``, normalised ``path``,
    per-method handlers and per-segment ``ParamInfo``."""

    def __init__(self, router, raw_path, name, requirements):
        self.router = router
        self.name = name
        self.raw_path = raw_path
        self.requirements = requirements
        self.handlers = {}
        self.parts = tuple()
        # segment index -> ParamInfo
        self.params = {}
        self._explode()

    def __repr__(self):
        display = (
            f"{self.name}|{self.raw_path}"
            if self.name and self.name != self.raw_path
            else self.raw_path
        )
        return f"<Route: {display}>"

    def __call__(self, method="base", **kwargs):
        """Invoke the handler registered for *method* with *kwargs*.

        Only the lookup is guarded. A ``KeyError`` raised by the handler
        itself propagates unchanged rather than being reported as a missing
        method.

        :raises router.method_handler_exception: if *method* has no handler
        """
        try:
            handler = self.handlers[method]
        except KeyError:
            raise self.router.method_handler_exception(
                f"Method '{method}' not found on {self}"
            ) from None
        return handler(**kwargs)

    def _explode(self):
        self.parts = tuple(self.raw_path.split(self.router.delimiter))
        self.path = parts_to_path(self.parts)

    def add_handler(self, handler, method):
        """Register *handler* for *method*.

        :raises RouteExists: if *method* already has a handler
        """
        if method in self.handlers:
            raise RouteExists(f"Route already registered: {self.raw_path} [{method}]")
        self.handlers[method] = handler

    def add_parameter(
        self, idx: int, name: str, label: str, cast: t.Type, pattern=None
    ):
        # TODO:
        # - if already has a param for this idx, raise Exception
        self.params[idx] = ParamInfo(name, label, cast, pattern)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment