Created
May 26, 2022 07:16
-
-
Save jackdreilly/1552a5e9a2d94e436712af9195cea369 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import functools
import json
from collections import defaultdict, deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Iterable, List, Union

import pytest
from bidict import FrozenOrderedBidict, frozenbidict
from devtools import debug
# Token name -> regex (in this module's own dialect, see gen) used by parse()
# to tokenize JSON text. Order matters: when several patterns match the same
# longest prefix, _parse picks the first one in this order.
json_patterns = {
    " ": r"\s+",
    # Quoted string. A backslash escapes the next character, so escaped
    # quotes (\") and other json.dumps escapes (\\, \n, \u00e9) stay inside
    # the token. "(\\.)" is grouped because "|" only joins the two adjacent
    # items in this dialect.
    "string": r'"([^"\\]|(\\.))*"',
    # Optional sign, integer part, optional fraction, optional exponent,
    # e.g. -2.5E-3 (the old pattern rejected any negative or exponent form).
    "number": r"-?\d+(\.\d+)?((e|E)(\+|-)?\d+)?",
    "true": "true",
    "false": "false",
    "null": "null",
    "[": r"\[",
    "]": r"\]",
    **{k: k for k in "{,}:"},
}
@dataclass(eq=True, frozen=True)
class Asterisk:
    """Kleene star node of the parse tree: ``child`` repeated zero or more times."""

    child: Any  # repeated sub-tree: a single item or a list (concatenation)
@dataclass(eq=True, frozen=True)
class Or:
    """Alternation node of the parse tree: matches any one of ``options``."""

    options: List[Any]  # sub-trees; an empty list [] option matches the empty string
class EscapeChar(Enum):
    """Special escape kinds that cannot be represented by a plain character."""

    # Produced by gen() for an unescaped "."; Escape.matches accepts any
    # character for it.
    dot = "."
@dataclass(frozen=True)
class Not:
    """Negated character class such as ``[^"]``.

    A character matches when none of ``options`` matches it.
    """

    # gen()'s clean step converts this to a tuple, which keeps instances
    # hashable as long as the options themselves are.
    options: List[Any]

    def matches(self, value: Any) -> bool:
        """Return True iff no option accepts *value*."""
        return all(not _matches(option, value) for option in self.options)
@dataclass(frozen=True)
class Escape:
    r"""A backslash escape (``\w``, ``\c``, ``\d``, ``\s``, ``\.``, ...) or the ``.`` wildcard.

    ``p`` is the character that followed the backslash, or ``EscapeChar.dot``
    for an unescaped ``.``. The letters w/c/d/s select a ``str`` predicate;
    any other escaped character matches only itself.
    """

    p: Union[EscapeChar, str]

    def matches(self, x: str) -> bool:
        # Class escapes: call the named str method on the candidate character.
        predicate_name = {
            "w": "isalnum",
            "c": "isalpha",
            "d": "isnumeric",
            "s": "isspace",
        }.get(self.p)
        if predicate_name is not None:
            return getattr(x, predicate_name)()
        if self.p == EscapeChar.dot:
            return True  # "." wildcard accepts any character
        return self.p == x  # literal escape such as \. or \[
# Labels that can sit directly on a graph Edge. graph() dispatches with
# isinstance(x, EdgeValue.__args__), so this must remain a typing.Union.
EdgeValue = Union[str, Escape, Not]
# Parse tree returned by gen(): a single item, or a list for a concatenation
# of two or more items.
Parsed = Union[Asterisk, Or, EdgeValue, List["Parsed"]]
def gen(s: str) -> Parsed:
    r"""Parse the regex *s* into a ``Parsed`` tree.

    Supported syntax: literal characters, ``(...)`` groups, ``[...]`` and
    ``[^...]`` classes, ``*``, ``+``, ``?``, ``|``, ``\x`` escapes and ``.``.
    ``|`` joins only the item right before it with the first item after it,
    so ``ab|cd`` parses as ``a``, ``(b|c)``, ``d``.

    Single-item sequences are unwrapped: ``gen("a") == "a"`` but
    ``gen("ab") == ["a", "b"]``.

    Raises:
        ValueError: on an unbalanced ``(`` or a ``[`` without a closing ``]``.
    """

    def _closing_paren(body: str) -> int:
        # Index in *body* of the ")" that closes a "(" consumed just before
        # it. A backslash escapes the next character, so \( and \) are
        # literals rather than grouping.
        depth = 1
        escaped = False
        for i, ch in enumerate(body):
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if not depth:
                    return i
        # Previously an unbalanced "(" either raised UnboundLocalError or
        # silently dropped the last character of the group.
        raise ValueError("unbalanced '(' in pattern: " + repr("(" + body))

    def _gen(s: str, previous: List[Parsed] | None = None) -> List[Parsed]:
        # ``previous`` holds the items parsed so far at this nesting level.
        # It defaults to None, not []: a shared mutable default was returned
        # as-is for empty input (gen(""), "()"), so mutating one result
        # corrupted every later parse.
        if previous is None:
            previous = []
        if not s:
            return previous
        c, s = s[0], s[1:]
        if c == "(":
            close = _closing_paren(s)
            return _gen(s[close + 1 :], [*previous, _gen(s[:close])])
        if c == "[":
            # Character class up to the first "]"; a leading "^" negates it.
            i = s.index("]")
            inside = s[:i]
            negate = False
            if inside and inside[0] == "^":
                inside = inside[1:]
                negate = True
            parsed = _gen(inside)
            return _gen(s[i + 1 :], [*previous, (Not if negate else Or)(parsed)])
        if c == "*":
            previous, a = previous[:-1], previous[-1]
            return _gen(s, [*previous, Asterisk(a)])
        if c == "+":
            # x+ is rewritten as x x*.
            return _gen("*" + s, [*previous, previous[-1]])
        if c == "?":
            # x? is rewritten as x|() (alternative with the empty sequence).
            return _gen("|()" + s, previous)
        if c == "|":
            previous, left = previous[:-1], previous[-1]
            a, *b = _gen(s)
            # Flatten only a right-hand Or (a|b|c chains). Not also has an
            # ``options`` attribute; flattening it turned x|[^y] into x|y.
            right = a.options if isinstance(a, Or) else [a]
            return [*previous, Or([left, *right]), *b]
        if c == "\\":
            return _gen(s[1:], [*previous, Escape(s[0])])
        if c == ".":
            return _gen(s, [*previous, Escape(EscapeChar.dot)])
        return _gen(s, [*previous, c])

    def clean(x: Parsed) -> Parsed:
        # Recursively unwrap single-item lists. Not options become a tuple
        # so Not stays hashable (to_dfs keys a dict by edge label).
        if not x:
            return x
        if isinstance(x, list):
            cleaned = [clean(item) for item in x]
            return cleaned if len(cleaned) > 1 else cleaned[0]
        if isinstance(x, Asterisk):
            return Asterisk(clean(x.child))
        if isinstance(x, Or):
            return Or(list(map(clean, x.options)))
        if isinstance(x, Not):
            return Not(tuple(map(clean, x.options)))
        return x

    return clean(_gen(s))
def node_list(*args) -> List[Node]:
    """Return *args* as a list without repeated objects.

    Objects are compared by identity (``id``), not equality, and each one is
    kept at the position of its first occurrence.
    """
    seen = set()
    unique = []
    for item in args:
        key = id(item)
        if key not in seen:
            seen.add(key)
            unique.append(item)
    return unique
@dataclass
class Edge:
    """A labelled transition from some Node to ``next``."""

    value: EdgeValue  # label: a char, Escape or Not, or an eps marker
    next: Node  # destination state

    @property
    def is_epsilon(self) -> bool:
        """True when the label is an eps marker, i.e. the edge consumes no input."""
        label = self.value
        return isinstance(label, eps)
@dataclass(eq=False)
class Node:
    """A state of the NFA/DFA graphs built in this module.

    ``eq=False`` gives identity-based ``==`` and ``hash``. The generated
    structural ``__eq__`` compared ``edges`` -> ``Edge.next`` -> ``Node``
    recursively, which can recurse until RecursionError on these cyclic
    graphs (``*`` loops back) and can equate two distinct states.
    """

    edges: List[Edge] = field(default_factory=list)  # outgoing transitions
    is_terminal: bool = False  # accepting state

    def __iter__(self) -> Iterable[Node]:
        """Yield every node reachable from this one (itself first), each once."""
        return self._iter()

    def _iter(self, history: set[int] | None = None) -> Iterable[Node]:
        # Depth-first pre-order walk with an explicit stack. The recursive
        # generator used before needed one frame per graph level and
        # exceeded the recursion limit on long patterns. Children are pushed
        # in reverse so they pop in ``edges`` order, which gives the same
        # visiting order as the recursive version.
        if history is None:
            history = set()
        stack = [self]
        while stack:
            node = stack.pop()
            if node.id in history:
                continue
            yield node
            history.add(node.id)
            stack.extend(edge.next for edge in reversed(node.edges))

    @property
    def terminals(self) -> Iterable[Node]:
        """Reachable nodes that are accepting or have no outgoing edges."""
        return (n for n in self if n.is_terminal or not n.edges)

    @property
    def size(self) -> int:
        """Number of reachable nodes, including this one."""
        return sum(1 for _ in self)

    @property
    def epsilon_closure(self) -> List[Node]:
        """This node plus every node reachable through epsilon edges only."""
        return node_list(*self._epsilon_closure())

    @property
    def id(self) -> int:
        """Identity key (``id(self)``) used for visited sets and lookups."""
        return id(self)

    def _epsilon_closure(self, visited: set[int] | None = None) -> Iterable[Node]:
        # Same explicit-stack walk as ``_iter``, following epsilon edges only.
        if visited is None:
            visited = set()
        stack = [self]
        while stack:
            node = stack.pop()
            if node.id in visited:
                continue
            visited.add(node.id)
            yield node
            stack.extend(
                edge.next for edge in reversed(node.edges) if edge.is_epsilon
            )

    @property
    def terminal(self) -> Node:
        """First node yielded by ``terminals``; raises StopIteration if none."""
        return next(self.terminals)

    @classmethod
    def make_terminal(cls) -> Node:
        """Create a fresh accepting node with no edges."""
        return cls(is_terminal=True)

    def __rshift__(self, edge: Edge) -> Node:
        """``node >> edge`` appends *edge* and returns ``node`` for chaining."""
        self.edges.append(edge)
        return self
class eps:
    """Label object marking an epsilon (empty-input) transition; prints as ``ε``.

    No ``__eq__`` is defined, so every eps label is distinct by identity.
    """

    @classmethod
    def edge(cls, node: Node) -> Edge:
        """Build an epsilon Edge pointing at *node*."""
        marker = eps()
        return Edge(marker, node)

    def __str__(self) -> str:
        return "ε"

    def __repr__(self) -> str:
        return self.__str__()
def graph(s: str) -> Node:
    """Parse regex *s* with ``gen`` and build an epsilon-NFA; return its start node.

    ``helper`` turns each parse-tree item into a fragment: the start node it
    returns, plus an accepting (``is_terminal``) node that later steps locate
    via ``Node.terminal``. Fragments are joined with ``eps`` edges, and
    ``is_terminal`` is cleared on the node being joined from.
    """
    parsed = gen(s)

    def helper(x):
        source, terminal = Node(), Node.make_terminal()
        if isinstance(x, Asterisk):
            # The child's accept node feeds our accept node. start -> accept
            # allows zero repetitions; accept -> start allows repeating.
            child = helper(x.child)
            child_terminal = child.terminal
            child_terminal.is_terminal = False
            source >> eps.edge(child)
            child_terminal >> eps.edge(terminal)
            source >> eps.edge(terminal)
            terminal >> eps.edge(source)
            return source
        if isinstance(x, Or):
            # Fan out from the shared start into every option, then fan each
            # option's accept node back into the shared accept node.
            for o in x.options:
                child = helper(o)
                source >> eps.edge(child)
                child_terminal = child.terminal
                child_terminal >> eps.edge(terminal)
                child_terminal.is_terminal = False
            return source
        if isinstance(x, EdgeValue.__args__):
            # A single matcher (char, Escape or Not): one labelled edge.
            source >> Edge(x, terminal)
            return source
        if isinstance(x, list):
            # Concatenation: chain each item after the current accept node.
            # On the first pass ``source`` has no edges yet, and
            # Node.terminals yields edgeless nodes, so it is ``source`` itself.
            for xx in x:
                terminal = source.terminal
                terminal >> eps.edge(helper(xx))
                terminal.is_terminal = False
            if not x:
                # Empty sequence (from "()" or "?"): matches the empty string.
                source >> eps.edge(terminal)
        return source

    # A trailing "clean" pass used to run here; it was removed for two reasons:
    # - Its visited check tested the root (id(node)) instead of the current
    #   node, so it never looked past the root's own edges.
    # - Its rewrite (A -v-> B -v-> C into A -v-> C) would have dropped one
    #   character and ignored B.is_terminal.
    # On graphs built by helper its condition never held, so removing it
    # changes no result.
    return helper(parsed)
@functools.lru_cache(maxsize=None)
def jet(regex: str, a: str) -> bool:
    """Return True iff the whole string *a* matches *regex*.

    Simulates the compiled graph on a set of current nodes, following every
    edge whose label accepts the next character. Results are cached per
    ``(regex, a)`` pair with no size bound.
    """
    # One loop iteration per character. The former recursive helper used one
    # stack frame per character, which raised RecursionError on inputs longer
    # than the recursion limit, and it copied the tail with x[1:] every step.
    nodes = compile(regex).epsilon_closure
    for char in a:
        if not nodes:
            return False  # no live states left, so nothing can match
        nodes = node_list(
            *(
                y
                for node in nodes
                for edge in node.edges
                if _matches(edge.value, char)
                for y in edge.next.epsilon_closure
            )
        )
    return any(node.is_terminal for node in nodes)
def to_dfs(nfs: Node):
    """Subset-construct a deterministic-style graph from the epsilon-NFA *nfs*.

    Each new node stands for one epsilon-closed set of NFA nodes, keyed by
    their sorted ids. A new node is accepting if any member is. Returns the
    new start node.

    Transitions are grouped by edge *label*, not by character. Overlapping
    labels (e.g. ``"a"`` and ``Escape(EscapeChar.dot)``) therefore still
    produce separate edges that can both match one character; ``jet``
    handles this by following every matching edge.
    """

    def nodeset(a):
        # Hashable, order-independent key for a set of NFA nodes.
        return tuple(sorted(aa.id for aa in a))

    closure = nfs.epsilon_closure
    source = Node(is_terminal=any(n.is_terminal for n in closure))
    processed = {nodeset(closure): source}
    # FIFO work queue; deque.popleft is O(1) where list.pop(0) is O(n).
    to_process = deque([(source, closure)])
    while to_process:
        d, nodes = to_process.popleft()
        # Successor NFA nodes per non-epsilon label (labels must be hashable).
        by_label = defaultdict(list)
        for node in nodes:
            for edge in node.edges:
                if not edge.is_epsilon:
                    by_label[edge.value].append(edge.next)
        for label, targets in by_label.items():
            closed = node_list(*(m for t in targets for m in t.epsilon_closure))
            key = nodeset(closed)
            dest = processed.get(key)
            if dest is None:
                dest = Node(is_terminal=any(n.is_terminal for n in closed))
                processed[key] = dest
                to_process.append((dest, closed))
            d >> Edge(label, dest)
    return source
def optimize(dfs: Node) -> Node:
    """Merge equivalent states of the graph built by ``to_dfs``.

    Uses partition refinement:
    1. Start with accepting vs. non-accepting states.
    2. Split any block whose members disagree, for some edge label, on which
       block that label leads to (or on whether they have such an edge).
    3. Repeat until stable.

    Returns the start node of a freshly built, merged graph; *dfs* itself is
    not modified.
    """

    def split_once(states):
        # Return ``states`` with the first unstable block split, or None if
        # every block is stable.
        block_of = {node.id: i for i, block in enumerate(states) for node in block}
        for pos, block in enumerate(states):
            for label in {edge.value for node in block for edge in node.edges}:
                by_target = defaultdict(list)
                for node in block:
                    # Block reached via this node's first ``label`` edge; None if absent.
                    target = next(
                        (block_of[e.next.id] for e in node.edges if e.value == label),
                        None,
                    )
                    by_target[target].append(node)
                if len(by_target) > 1:
                    # Drop the block by position. The old list.remove compared
                    # blocks with ==, i.e. structurally on (cyclic) Nodes.
                    return [*states[:pos], *states[pos + 1 :], *by_target.values()]
        return None

    # Initial split by acceptance. Node.terminals also yields edgeless nodes,
    # which would put a dead non-accepting state into the accepting block,
    # where refinement (based only on edges) could never split it off.
    accepting = [n for n in dfs if n.is_terminal]
    accepting_ids = {n.id for n in accepting}  # built once, not per node
    states = [accepting, [n for n in dfs if n.id not in accepting_ids]]
    # Loop instead of recursing once per split.
    while True:
        refined = split_once(states)
        if refined is None:
            break
        states = refined

    merged = [Node(is_terminal=any(n.is_terminal for n in block)) for block in states]
    block_of = {node.id: i for i, block in enumerate(states) for node in block}
    transitions = defaultdict(dict)
    for node in dfs:
        i = block_of[node.id]
        for edge in node.edges:
            transitions[i][edge.value] = block_of[edge.next.id]
    for i, labelled in transitions.items():
        for label, j in labelled.items():
            merged[i] >> Edge(label, merged[j])
    # Look up the start block by id; the old ``dfs in state`` used Node ==.
    return merged[block_of[dfs.id]]
# Parser table: each pattern paired with the tree gen() must return.
# Single-item sequences come back unwrapped (gen's clean step).
@pytest.mark.parametrize(
    ("pattern", "parsed"),
    (
        (('["]'), Or(['"'])),
        (('[^"]'), Not(tuple(['"']))),
        ((""), []),
        (("a"), "a"),
        (("ab"), ["a", "b"]),
        (("abc"), ["a", "b", "c"]),
        (("abc*"), ["a", "b", Asterisk("c")]),
        (("(abc)"), ["a", "b", "c"]),
        (
            ("(abc)*asdf*"),
            [
                Asterisk(["a", "b", "c"]),
                "a",
                "s",
                "d",
                Asterisk("f"),
            ],
        ),
        (("a|b"), Or(["a", "b"])),
        (("a|(b|c)"), Or(["a", Or(options=["b", "c"])])),
        (("(a|(b|c)|(c|d|e))"), Or(["a", Or(["b", "c"]), Or(["c", "d", "e"])])),
        # "+" is rewritten as one copy followed by a starred copy.
        (("a+"), ["a", Asterisk("a")]),
        (("(a|b)+"), [Or(["a", "b"]), Asterisk(Or(["a", "b"]))]),
        # "?" is rewritten as an alternation with the empty sequence [].
        (("a?"), Or(["a", []])),
        (("(ab)?"), Or([["a", "b"], []])),
        ((r"\w+"), [Escape("w"), Asterisk(Escape("w"))]),
        # An escaped dot is the literal "." character, not EscapeChar.dot.
        ((r"\."), Escape(".")),
    ),
)
def test_gen(pattern, parsed):
    """gen() turns each pattern into the expected parse tree."""
    assert gen(pattern) == parsed
@pytest.mark.parametrize(
    "patterns, input_string, parsed",
    [
        # One token per character when each pattern is a single character.
        ({1: "a", 2: "b"}, "ba", [(2, "b"), (1, "a")]),
        # Starred patterns take the longest possible run of their character.
        ({1: "a*", 2: "b*"}, "bbaaabb", [(2, "bb"), (1, "aaa"), (2, "bb")]),
    ],
)
def test_parse(patterns, input_string, parsed):
    """parse() splits the input into the expected (key, lexeme) tokens."""
    tokens = parse(patterns, input_string)
    assert tokens == parsed
def test_json():
    """A number inside a json.dumps() document comes back as a 'number' token."""
    document = dict(
        how=dict(
            did=[
                "yo asdf asdf f d432 324!!1 [][]u",
                dict(do="that", even=3.5, or_not=False),
            ]
        )
    )
    tokens = parse(json_patterns, json.dumps(document))
    assert ('number', '3.5') in tokens
# Acceptance table for jet(): (pattern, input, expected whole-string match).
@pytest.mark.parametrize(
    ("pattern", "value", "result"),
    [
        # An escaped dot matches only a literal ".".
        (r"\.", ".", True),
        (r"\.", "a", False),
        (r"\.", "..", False),
        ("", "", True),
        ("", "a", False),
        ("a", "a", True),
        ("a", "b", False),
        ("a*", "", True),
        ("a*", "a", True),
        ("a*", "aaa", True),
        ("a*", "ba", False),
        ("a*", "b", False),
        ("a|b", "a", True),
        ("a|b", "b", True),
        ("a|b", "c", False),
        ("(a|b)", "c", False),
        ("(a|b)", "a", True),
        ("(a|b)", "aa", False),
        ("(a|b)*", "aa", True),
        ("(a|b)*", "aababaaba", True),
        ("(a|b)*", "", True),
        ("(a|b|c)*", "abcccba", True),
        ("(a|b|c)*asdf", "abcccba", False),
        ("(a|b|c)*asdf", "abcccbasdf", True),
        ("(a|b|c)*asdf", "abcccbaasdf", True),
        ("(((a|b|c)*asdf)*)|(d|c)*", "abcccbaasdf", True),
        ("(((a|b|c)*asdf)*)|(r|t)*", "abcccbaasdfg", False),
        ("(((a|b|c)*asdf)*)|(r|t)*", "casdfasdf", True),
        ("(((a|b|c)*asdf)*)|(r|t)*", "casdfasdfbbaasdfrtttr", False),
        ("(((a|b|c)*asdf)*)|(r|t)*", "", True),
        ("c*|(r|t)", "", True),
        ("(((a|b|c)*asdf)*)|(r|t)", "", True),
        # In this dialect "|" only joins its two adjacent items (see gen), so
        # literal characters such as "l" and "i" after the first "|" are
        # mandatory, and "" cannot match.
        (
            "(((a|b|c)*asdf)*)*|(r*|t)*lias*df|asd*f(asdf|asdf*(snjakg*h|gsfda(asdf*as)*)*)*asd*f",
            "",
            False,
        ),
        ("(a|b|c)*" * 3, "a" * 3, True),
        ("(a|b)+", "a", True),
        ("(a|b)+", "b", True),
        ("(a|b)+", "bbaa", True),
        ("(a|b)+", "", False),
        ("a?", "", True),
        ("a?", "a", True),
        ("a?", "aa", False),
        # \c is alphabetic (str.isalpha), \d numeric (str.isnumeric).
        (r"\c\d", "a3", True),
        (r"\c\d", "b5", True),
        (r"\c\d", "55", False),
        (".*", "", True),
        (".*", "89023hironfasjfnd", True),
        (".+", "", False),
        # Expand (pattern, should-match, should-not-match) groups into rows:
        # i == 0 -> yeses (expected True), i == 1 -> nos (expected False).
        *[
            [pattern, value, not i]
            for pattern, yeses, nos in (
                (
                    "((1|(11)|(111)),)*(()|1|(11)|(111))",
                    ("1", "11", "111", "", "1,111"),
                    ("a", "1111", "1a", ",111"),
                ),
                (
                    r'"\w+(\s\w+)*" <\w(\.|\w)*@\w(\.|\w)*\.\w+>',
                    (
                        '"John Doe" <john.doe@gmail.com>',
                        '"j" <j@w.w>',
                    ),
                    (
                        '"" <john.doe@gmail.com>',
                        '"John Doe"<john.doe@gmail.com>',
                        '"John Doe" john.doe@gmail.com>',
                        '"John Doe" <.doe@gmail.com>',
                        '"John Doe" <john.doegmail.com>',
                        '"John Doe" <john.doe@.com>',
                        '"John Doe" <john.doe@gmail.>',
                        '"John Doe" <john.doe@gmail>',
                        '"John Doe" <john.doe@gmail',
                        '"John Doe" <john.doe@gmail.com>a',
                    ),
                ),
                (
                    "(X|(<(X|({X*}))*>)|({(X|(<X*>))*}))*",
                    (
                        "XXX<XX{X}XXX>X",
                        "X{X}X<X>X{X}X<X>X",
                    ),
                    (
                        "XXX<X<XX>>XX",
                        "XX<XX{XX>XX}XX",
                    ),
                ),
                (
                    r"as[^\c]+",
                    (
                        "as3",
                        "as321",
                    ),
                    (
                        "321",
                        "asd",
                        "as",
                        "as3a3",
                    ),
                ),
            )
            for i, x in enumerate([yeses, nos])
            for value in x
        ],
    ],
)
def test_regexes(pattern: str, value: str, result: bool):
    """jet() accepts exactly the inputs this table marks as True."""
    assert jet(pattern, value) == result
@functools.lru_cache(maxsize=None)
def compile(r: str) -> Node:
    """Compile regex *r* into a merged state graph, memoized per pattern string.

    Pipeline:
    1. ``graph`` parses the pattern and builds an epsilon-NFA.
    2. ``to_dfs`` runs the subset construction.
    3. ``optimize`` merges equivalent states.

    Note: this module-level name shadows the builtin ``compile``.
    """
    nfa = graph(r)
    dfa = to_dfs(nfa)
    return optimize(dfa)
def parse(
    patterns: FrozenOrderedBidict[Any, str], input_string: str
) -> List[tuple[Any, str]]:
    """Split *input_string* into ``(key, lexeme)`` tokens using *patterns*.

    *patterns* maps token keys to regexes in this module's dialect, in
    priority order. It is frozen into a ``frozenbidict`` so the cached
    ``_parse`` can hash it. bidict rejects duplicate values by default, so
    two keys cannot share one regex.

    Raises:
        RuntimeError: (from ``_parse``) when the input cannot be tokenized.
    """
    # _parse is lru_cached and returns its cached list object. Hand back a
    # copy so callers that mutate the result cannot corrupt later calls.
    return list(_parse(frozenbidict(patterns), input_string))
@functools.lru_cache(maxsize=None)
def _parse(
    patterns: FrozenOrderedBidict[Any, str], input_string: str
) -> List[tuple[Any, str]]:
    """Tokenize *input_string* with *patterns* (cached backend of ``parse``).

    Each token is the longest prefix that matches some pattern while leaving
    a remainder that the repeated union of all patterns can still match.
    When several patterns match that prefix, the first one in *patterns*
    order wins.

    Raises:
        RuntimeError: with the untokenizable remainder as its argument.
    """
    if not input_string:
        return []
    # "((p1)|(p2)|...)*": each pattern is grouped because "|" only joins
    # adjacent items in this dialect.
    regex = "(" + "|".join(f"({v})" for v in patterns.values()) + ")*"

    def next_token(text: str) -> tuple[Any, str]:
        # Try prefixes from longest to shortest.
        for i in range(len(text), 0, -1):
            if not jet(regex, text[i:]):
                continue
            for k, v in patterns.items():
                if jet(v, text[:i]):
                    return k, text[:i]
        raise RuntimeError(text)

    # Loop instead of one recursive call per token, which hit the recursion
    # limit on long token streams.
    tokens = []
    rest = input_string
    while rest:
        key, lexeme = next_token(rest)
        tokens.append((key, lexeme))
        rest = rest[len(lexeme):]
    return tokens
def _matches(edge_value: EdgeValue, string: str) -> bool: | |
if isinstance(edge_value, str): | |
return edge_value == string | |
elif hasattr(edge_value, "matches"): | |
return edge_value.matches(string) | |
if isinstance(edge_value, eps): | |
return False | |
raise RuntimeError("Unknown edge value: {}".format(edge_value)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment