Skip to content

Instantly share code, notes, and snippets.

@jackdreilly
Created May 26, 2022 07:16
Show Gist options
  • Save jackdreilly/1552a5e9a2d94e436712af9195cea369 to your computer and use it in GitHub Desktop.
Save jackdreilly/1552a5e9a2d94e436712af9195cea369 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import functools
import json
from devtools import debug
from bidict import FrozenOrderedBidict, frozenbidict
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Iterable, List, Union
import pytest
# Token name -> regex (in this module's own regex dialect) used by ``parse``
# to tokenise JSON text.  Order matters: ``_parse`` walks the patterns in
# mapping order to break ties, and builds its combined "any token" regex
# from the values in this order.
#
# NOTE: in this dialect ``|`` binds only the single elements on either side,
# so every multi-character alternative below is wrapped in parentheses.
json_patterns = {
    " ": r"\s+",
    # Double-quoted string; a backslash escapes the next character, so an
    # escaped quote (as emitted by json.dumps) does not end the string.
    "string": r'"([^"\\]|(\\.))*"',
    # Optional minus sign, integer part, optional fraction, optional exponent
    # (json.dumps emits e.g. "-0.5" and "1e+20").
    "number": r"-?\d+(\.\d+)?([eE](\+|-)?\d+)?",
    "true": "true",
    "false": "false",
    "null": "null",
    "[": r"\[",
    "]": r"\]",
    "{": "{",
    ",": ",",
    "}": "}",
    ":": ":",
}
@dataclass(frozen=True, eq=True)
class Asterisk:
    """Parse node for ``x*``: zero or more repetitions of ``child``."""

    child: Any  # the repeated sub-pattern (any parse node)
@dataclass(frozen=True, eq=True)
class Or:
    """Parse node for alternation: matches when any of ``options`` matches.

    Frozen, but hashing fails while ``options`` is a list.
    """

    options: List[Any]  # the alternative sub-patterns
class EscapeChar(Enum):
    """Special atoms that are not a literal character."""

    # The unescaped ``.`` wildcard; an Enum member, so it never equals the
    # plain string "." produced by an escaped dot.
    dot = "."
@dataclass(frozen=True, eq=True)
class Not:
    """Negated set (``[^...]``): accepts a value that no option accepts."""

    options: List[Any]

    def matches(self, value: Any) -> bool:
        """Return True when none of ``options`` matches ``value``."""
        return all(not _matches(option, value) for option in self.options)
@dataclass(frozen=True)
class Escape:
    """A backslash-escaped atom, or the unescaped ``.`` wildcard.

    ``p`` is the character that followed the backslash, or
    ``EscapeChar.dot`` for a bare ``.``.  After a backslash the letters
    w, c, d and s name character classes; any other character matches only
    itself (so an escaped dot is a literal dot).
    """

    p: Union[EscapeChar, str]

    def matches(self, x: str) -> bool:
        """Return whether the single character ``x`` is accepted."""
        if self.p == "w":
            # Word character, as in standard regex: letters, digits and "_".
            return x.isalnum() or x == "_"
        if self.p == "c":
            # Letters only (this dialect's own class).
            return x.isalpha()
        if self.p == "d":
            # Decimal digits; isnumeric() would also accept e.g. "½" or "²".
            return x.isdecimal()
        if self.p == "s":
            return x.isspace()
        if self.p == EscapeChar.dot:
            return True
        return self.p == x
# Types that can label a graph edge.  ``graph`` passes ``EdgeValue.__args__``
# straight to ``isinstance``, so every member must stay a concrete class
# (no string forward references here).
EdgeValue = Union[str, Escape, Not]
# Any node of the tree returned by ``gen``: a single node, or a (possibly
# nested) list, which ``graph`` treats as concatenation.
Parsed = Union[Asterisk, Or, EdgeValue, List["Parsed"]]
def gen(s: str) -> Parsed:
    """Parse the regex ``s`` (this module's dialect) into a parse tree.

    Supported syntax: literal characters, ``(...)`` groups, ``[...]`` and
    ``[^...]`` sets (their contents are parsed as a sequence of atoms and
    become an ``Or`` / ``Not``), postfix ``*``, ``+`` and ``?``, ``|``,
    backslash escapes (``Escape``) and ``.`` (``Escape(EscapeChar.dot)``).

    ``|`` binds only the single elements on either side of it, so ``ab|cd``
    parses as ``a(b|c)d``; parenthesise multi-character alternatives.

    Returns:
        A single node when the pattern has one top-level element, a list
        (concatenation) when it has several, and ``[]`` for ``""``.

    Raises:
        ValueError: if a ``(`` has no matching ``)``, a ``[`` has no ``]``,
            or ``|`` has nothing after it.
        IndexError: if ``*``, ``+``, ``?`` or ``|`` has nothing before it,
            or the pattern ends with a lone backslash.
    """

    def _gen(s: str, previous: List[Parsed] | None = None) -> List[Parsed]:
        # ``previous`` holds the already-parsed elements of the current
        # sequence; it is only ever rebuilt, never mutated.
        if previous is None:
            previous = []
        if not s:
            return previous
        c, s = s[0], s[1:]
        if c == "(":
            depth = 1
            for i, ss in enumerate(s):
                depth += {"(": 1, ")": -1}.get(ss, 0)
                if not depth:
                    break
            else:
                raise ValueError(f"unmatched '(' in pattern near {s!r}")
            return _gen(s[i + 1 :], [*previous, _gen(s[:i])])
        if c == "[":
            i = s.index("]")
            inside = s[:i]
            negate = inside.startswith("^")
            if negate:
                inside = inside[1:]
            parsed = _gen(inside)
            return _gen(s[i + 1 :], [*previous, (Not if negate else Or)(parsed)])
        if c == "*":
            previous, a = previous[:-1], previous[-1]
            return _gen(s, [*previous, Asterisk(a)])
        if c == "+":
            # x+ is rewritten as x x*
            return _gen("*" + s, [*previous, previous[-1]])
        if c == "?":
            # x? is rewritten as x|() -- alternation with the empty sequence.
            return _gen("|()" + s, previous)
        if c == "|":
            previous, left = previous[:-1], previous[-1]
            right, *rest = _gen(s)
            # Flatten chained alternation (a|b|c) into a single Or.  Only an
            # Or is flattened: a negated set also has ``options`` but must
            # keep its negation.
            alternatives = right.options if isinstance(right, Or) else [right]
            return [*previous, Or([left, *alternatives]), *rest]
        if c == "\\":
            return _gen(s[1:], [*previous, Escape(s[0])])
        if c == ".":
            return _gen(s, [*previous, Escape(EscapeChar.dot)])
        return _gen(s, [*previous, c])

    def clean(x: Parsed) -> Parsed:
        # Collapse single-element lists, and turn Not options into a tuple so
        # Not stays hashable (to_dfs uses edge labels as dict keys).
        if not x:
            return x
        if isinstance(x, list):
            cleaned = [clean(item) for item in x]
            return cleaned if len(cleaned) > 1 else cleaned[0]
        if isinstance(x, Asterisk):
            return Asterisk(clean(x.child))
        if isinstance(x, Or):
            return Or([clean(option) for option in x.options])
        if isinstance(x, Not):
            return Not(tuple(clean(option) for option in x.options))
        return x

    return clean(_gen(s))
def node_list(*args) -> List[Node]:
    """Return ``args`` as a list without duplicates (by identity), first occurrence wins."""
    seen = set()
    unique = []
    for candidate in args:
        key = id(candidate)
        if key in seen:
            continue
        seen.add(key)
        unique.append(candidate)
    return unique
@dataclass(init=True, repr=True, eq=True)
class Edge:
    """A labelled transition to the node ``next``."""

    value: EdgeValue  # label: a character, an Escape/Not matcher, or an eps marker
    next: Node  # destination node

    @property
    def is_epsilon(self) -> bool:
        """True when this edge consumes no input (its label is an ``eps``)."""
        label = self.value
        return isinstance(label, eps)
@dataclass(eq=False)
class Node:
    """A state of the automata built in this module (ε-NFA or DFA).

    ``eq=False`` makes nodes compare and hash by identity.  The graphs are
    cyclic (``*`` adds back-edges), so the structural ``__eq__`` that
    ``@dataclass`` generates by default can recurse without bound when two
    distinct nodes are compared; the rest of the module already tells nodes
    apart by ``id``.
    """

    edges: List[Edge] = field(default_factory=list)  # outgoing transitions, in insertion order
    is_terminal: bool = False  # True for an accepting state

    def __iter__(self) -> Iterable[Node]:
        """Iterate over every node reachable from this one, in DFS pre-order."""
        return self._iter()

    def _iter(self, history: set[int] | None = None) -> Iterable[Node]:
        """Like ``__iter__``, skipping and recording node ids in ``history``."""
        return self._walk(set() if history is None else history, epsilon_only=False)

    @property
    def terminals(self) -> Iterable[Node]:
        """Reachable nodes that are accepting or have no outgoing edges."""
        return (n for n in self if n.is_terminal or not n.edges)

    @property
    def size(self) -> int:
        """Number of reachable nodes, this one included."""
        return sum(1 for _ in self)

    @property
    def epsilon_closure(self) -> List[Node]:
        """This node plus every node reachable from it through ε-edges only."""
        return node_list(*self._epsilon_closure())

    @property
    def id(self) -> int:
        """Identity key used throughout the module (``id(self)``)."""
        return id(self)

    def _epsilon_closure(self, visited: set[int] | None = None) -> Iterable[Node]:
        """Yield the ε-closure, skipping and recording node ids in ``visited``."""
        return self._walk(set() if visited is None else visited, epsilon_only=True)

    def _walk(self, seen: set[int], epsilon_only: bool) -> Iterable[Node]:
        """Yield reachable nodes in DFS pre-order, following edges in order.

        An explicit stack of edge iterators replaces recursive ``yield from``,
        so long chains of nodes cannot exceed the recursion limit.  Nodes
        whose id is already in ``seen`` are skipped; each yielded node's id is
        added to it.  With ``epsilon_only`` only ε-edges are followed.
        """
        if self.id in seen:
            return
        seen.add(self.id)
        yield self
        pending = [iter(self.edges)]
        while pending:
            edge = next(pending[-1], None)
            if edge is None:
                pending.pop()
                continue
            if epsilon_only and not edge.is_epsilon:
                continue
            child = edge.next
            if child.id in seen:
                continue
            seen.add(child.id)
            yield child
            # Created after the yield, so edges added meanwhile are still seen.
            pending.append(iter(child.edges))

    @property
    def terminal(self) -> Node:
        """First node of ``terminals``; raises StopIteration if there is none."""
        return next(self.terminals)

    @classmethod
    def make_terminal(cls) -> Node:
        """Create a new accepting node with no edges."""
        return cls(is_terminal=True)

    def __rshift__(self, edge: Edge) -> Node:
        """Append ``edge`` to this node and return the node (``node >> edge``)."""
        self.edges.append(edge)
        return self
class eps:
    """Label object marking an ε (empty-input) edge.

    Every ε-edge gets its own instance; instances compare by identity.
    """

    def __str__(self) -> str:
        return "ε"

    def __repr__(self) -> str:
        return str(self)

    @classmethod
    def edge(cls, node: Node) -> Edge:
        """Build an ε-edge leading to ``node``."""
        label = eps()
        return Edge(label, node)
def graph(s: str) -> Node:
    """Build a Thompson-style ε-NFA for the regex ``s`` and return its start node.

    Sub-automata are glued together with ε-edges.  The accepting node(s)
    are the ones with ``is_terminal`` set.

    Raises:
        TypeError: if ``gen`` produced a parse node of an unknown type.
    """

    def helper(x: Parsed) -> Node:
        # Build the sub-automaton for ``x`` and return its entry node.  Its exit
        # is the node ``Node.terminal`` finds first (the ``is_terminal`` node),
        # which callers un-mark once they attach something after it.
        source, terminal = Node(), Node.make_terminal()
        if isinstance(x, Asterisk):
            child = helper(x.child)
            child_terminal = child.terminal
            child_terminal.is_terminal = False
            source >> eps.edge(child)
            child_terminal >> eps.edge(terminal)
            source >> eps.edge(terminal)  # zero repetitions
            terminal >> eps.edge(source)  # repeat
            return source
        if isinstance(x, Or):
            for option in x.options:
                child = helper(option)
                source >> eps.edge(child)
                child_terminal = child.terminal
                child_terminal >> eps.edge(terminal)
                child_terminal.is_terminal = False
            return source
        if isinstance(x, EdgeValue.__args__):
            source >> Edge(x, terminal)
            return source
        if isinstance(x, list):
            # Concatenation: hang each element off the current exit node
            # (initially ``source`` itself, which has no edges yet).
            for item in x:
                exit_node = source.terminal
                exit_node >> eps.edge(helper(item))
                exit_node.is_terminal = False
            if not x:
                source >> eps.edge(terminal)
            return source
        raise TypeError(f"cannot build a graph for parse node {x!r}")

    return helper(gen(s))
@functools.lru_cache(maxsize=1 << 16)
def jet(regex: str, a: str) -> bool:
    """Return whether the whole string ``a`` matches ``regex``.

    Runs the compiled automaton on a set of current states.  For each
    character, it follows every edge whose label accepts that character and
    takes the ε-closure of the targets.  A set is needed even after
    determinisation, because ``to_dfs`` groups edges by label and two
    different labels (e.g. a literal and ``.``) can accept the same
    character.

    The loop is iterative (not one recursion per character), so long inputs
    cannot hit the recursion limit.  Results are memoised in a bounded LRU
    cache keyed on ``(regex, a)``.
    """
    states = compile(regex).epsilon_closure
    for char in a:
        if not states:
            return False
        states = node_list(
            *(
                target
                for state in states
                for edge in state.edges
                if _matches(edge.value, char)
                for target in edge.next.epsilon_closure
            )
        )
    return any(state.is_terminal for state in states)
def to_dfs(nfs: Node):
    """Determinise the ε-NFA rooted at ``nfs`` by subset construction.

    Each new node stands for the ε-closure of a set of NFA nodes, keyed by
    their sorted ids.  Transitions are grouped by edge *label* (not by input
    character), and a new node is terminal if any member is.  Returns the
    start node of the new graph, which contains no ε-edges.
    """

    def key_of(members):
        return tuple(sorted(member.id for member in members))

    start_members = nfs.epsilon_closure
    start = Node(is_terminal=any(member.is_terminal for member in start_members))
    built = {key_of(start_members): start}
    # FIFO work queue; an advancing head index replaces list.pop(0).
    queue = [(start, start_members)]
    head = 0
    while head < len(queue):
        current, members = queue[head]
        head += 1
        by_label = defaultdict(list)
        for member in members:
            for edge in member.edges:
                if not edge.is_epsilon:
                    by_label[edge.value].append(edge.next)
        for label, targets in by_label.items():
            closure = node_list(
                *(node for target in targets for node in target.epsilon_closure)
            )
            key = key_of(closure)
            if key in built:
                current >> Edge(label, built[key])
            else:
                fresh = Node(is_terminal=any(node.is_terminal for node in closure))
                built[key] = fresh
                current >> Edge(label, fresh)
                queue.append((fresh, closure))
    return start
def optimize(dfs: Node) -> Node:
    """Merge equivalent states of the automaton rooted at ``dfs``.

    Moore-style partition refinement: start from {accepting, non-accepting}
    and keep splitting any block whose members disagree, for some edge
    label, on which block that label leads to (or on whether they have such
    an edge at all), until nothing splits.  Returns the start node of a new
    graph with one node per block; ``dfs`` itself is not modified.

    Labels are compared as labels, so merging is sound but may not be
    minimal when labels overlap (e.g. a literal and ``.``).
    """
    all_nodes = list(dfs)

    def split_once(blocks):
        # Return ``blocks`` with one block split, or None if the partition is stable.
        block_of = {node.id: b for b, block in enumerate(blocks) for node in block}
        for b, block in enumerate(blocks):
            # dict.fromkeys: deterministic first-seen order (a set's order is not).
            labels = dict.fromkeys(edge.value for node in block for edge in node.edges)
            for label in labels:
                groups = defaultdict(list)
                for node in block:
                    target = next(
                        (block_of[edge.next.id] for edge in node.edges if edge.value == label),
                        None,
                    )
                    groups[target].append(node)
                if len(groups) > 1:
                    # Rebuild by position: list.remove would compare blocks with ==.
                    return [*blocks[:b], *blocks[b + 1 :], *groups.values()]
        return None

    accepting = [node for node in all_nodes if node.is_terminal]
    rejecting = [node for node in all_nodes if not node.is_terminal]
    blocks = [block for block in (accepting, rejecting) if block]
    while True:
        refined = split_once(blocks)
        if refined is None:
            break
        blocks = refined

    block_of = {node.id: b for b, block in enumerate(blocks) for node in block}
    merged = [Node(is_terminal=any(node.is_terminal for node in block)) for block in blocks]
    transitions = defaultdict(dict)
    for node in all_nodes:
        source_block = block_of[node.id]
        for edge in node.edges:
            transitions[source_block][edge.value] = block_of[edge.next.id]
    for source_block, row in transitions.items():
        for label, target_block in row.items():
            merged[source_block] >> Edge(label, merged[target_block])
    # Look the start node up by identity rather than with ``in`` / ==.
    return merged[block_of[dfs.id]]
@pytest.mark.parametrize(
    ("pattern", "parsed"),
    [
        ('["]', Or(['"'])),
        ('[^"]', Not(('"',))),
        ("", []),
        ("a", "a"),
        ("ab", ["a", "b"]),
        ("abc", ["a", "b", "c"]),
        ("abc*", ["a", "b", Asterisk("c")]),
        ("(abc)", ["a", "b", "c"]),
        ("(abc)*asdf*", [Asterisk(["a", "b", "c"]), "a", "s", "d", Asterisk("f")]),
        ("a|b", Or(["a", "b"])),
        ("a|(b|c)", Or(["a", Or(options=["b", "c"])])),
        ("(a|(b|c)|(c|d|e))", Or(["a", Or(["b", "c"]), Or(["c", "d", "e"])])),
        ("a+", ["a", Asterisk("a")]),
        ("(a|b)+", [Or(["a", "b"]), Asterisk(Or(["a", "b"]))]),
        ("a?", Or(["a", []])),
        ("(ab)?", Or([["a", "b"], []])),
        (r"\w+", [Escape("w"), Asterisk(Escape("w"))]),
        (r"\.", Escape(".")),
    ],
)
def test_gen(pattern, parsed):
    """``gen`` turns each pattern into the expected parse tree."""
    expected = parsed
    assert gen(pattern) == expected
@pytest.mark.parametrize(
    ("patterns", "input_string", "parsed"),
    [
        ({1: "a", 2: "b"}, "ba", [(2, "b"), (1, "a")]),
        ({1: "a*", 2: "b*"}, "bbaaabb", [(2, "bb"), (1, "aaa"), (2, "bb")]),
    ],
)
def test_parse(patterns, input_string, parsed):
    """``parse`` splits the input into longest-first tokens labelled by pattern key."""
    tokens = parse(patterns, input_string)
    assert tokens == parsed
def test_json():
    """Tokenising ``json.dumps`` output yields the expected tokens and loses nothing."""
    text = json.dumps(
        dict(how=dict(did=["yo asdf asdf f d432 324!!1 [][]u", dict(do="that", even=3.5, or_not=False)]))
    )
    tokens = parse(json_patterns, text)
    assert ("number", "3.5") in tokens
    assert ("false", "false") in tokens
    assert ("string", '"that"') in tokens
    assert ("string", '"yo asdf asdf f d432 324!!1 [][]u"') in tokens
    # Tokens are consecutive slices of the input, so together they rebuild it.
    assert "".join(value for _, value in tokens) == text
@pytest.mark.parametrize(
    ("pattern", "value", "result"),
    [
        (r"\.", ".", True),
        (r"\.", "a", False),
        (r"\.", "..", False),
        ("", "", True),
        ("", "a", False),
        ("a", "a", True),
        ("a", "b", False),
        ("a*", "", True),
        ("a*", "a", True),
        ("a*", "aaa", True),
        ("a*", "ba", False),
        ("a*", "b", False),
        ("a|b", "a", True),
        ("a|b", "b", True),
        ("a|b", "c", False),
        ("(a|b)", "c", False),
        ("(a|b)", "a", True),
        ("(a|b)", "aa", False),
        ("(a|b)*", "aa", True),
        ("(a|b)*", "aababaaba", True),
        ("(a|b)*", "", True),
        ("(a|b|c)*", "abcccba", True),
        ("(a|b|c)*asdf", "abcccba", False),
        ("(a|b|c)*asdf", "abcccbasdf", True),
        ("(a|b|c)*asdf", "abcccbaasdf", True),
        ("(((a|b|c)*asdf)*)|(d|c)*", "abcccbaasdf", True),
        ("(((a|b|c)*asdf)*)|(r|t)*", "abcccbaasdfg", False),
        ("(((a|b|c)*asdf)*)|(r|t)*", "casdfasdf", True),
        ("(((a|b|c)*asdf)*)|(r|t)*", "casdfasdfbbaasdfrtttr", False),
        ("(((a|b|c)*asdf)*)|(r|t)*", "", True),
        ("c*|(r|t)", "", True),
        ("(((a|b|c)*asdf)*)|(r|t)", "", True),
        (
            "(((a|b|c)*asdf)*)*|(r*|t)*lias*df|asd*f(asdf|asdf*(snjakg*h|gsfda(asdf*as)*)*)*asd*f",
            "",
            False,
        ),
        ("(a|b|c)*" * 3, "a" * 3, True),
        ("(a|b)+", "a", True),
        ("(a|b)+", "b", True),
        ("(a|b)+", "bbaa", True),
        ("(a|b)+", "", False),
        ("a?", "", True),
        ("a?", "a", True),
        ("a?", "aa", False),
        (r"\c\d", "a3", True),
        (r"\c\d", "b5", True),
        (r"\c\d", "55", False),
        (".*", "", True),
        (".*", "89023hironfasjfnd", True),
        (".+", "", False),
        # Grouped cases: every value in the first tuple must match the
        # pattern, and every value in the second must not.
        *[
            (pattern, value, expected)
            for pattern, accepted, rejected in (
                (
                    "((1|(11)|(111)),)*(()|1|(11)|(111))",
                    ("1", "11", "111", "", "1,111"),
                    ("a", "1111", "1a", ",111"),
                ),
                (
                    r'"\w+(\s\w+)*" <\w(\.|\w)*@\w(\.|\w)*\.\w+>',
                    (
                        '"John Doe" <john.doe@gmail.com>',
                        '"j" <j@w.w>',
                    ),
                    (
                        '"" <john.doe@gmail.com>',
                        '"John Doe"<john.doe@gmail.com>',
                        '"John Doe" john.doe@gmail.com>',
                        '"John Doe" <.doe@gmail.com>',
                        '"John Doe" <john.doegmail.com>',
                        '"John Doe" <john.doe@.com>',
                        '"John Doe" <john.doe@gmail.>',
                        '"John Doe" <john.doe@gmail>',
                        '"John Doe" <john.doe@gmail',
                        '"John Doe" <john.doe@gmail.com>a',
                    ),
                ),
                (
                    "(X|(<(X|({X*}))*>)|({(X|(<X*>))*}))*",
                    (
                        "XXX<XX{X}XXX>X",
                        "X{X}X<X>X{X}X<X>X",
                    ),
                    (
                        "XXX<X<XX>>XX",
                        "XX<XX{XX>XX}XX",
                    ),
                ),
                (
                    r"as[^\c]+",
                    (
                        "as3",
                        "as321",
                    ),
                    (
                        "321",
                        "asd",
                        "as",
                        "as3a3",
                    ),
                ),
            )
            for expected, values in ((True, accepted), (False, rejected))
            for value in values
        ],
    ],
)
def test_regexes(pattern: str, value: str, result: bool):
    """``jet`` accepts exactly the strings each pattern should match."""
    assert jet(pattern, value) == result
@functools.lru_cache(maxsize=None)
def compile(r: str) -> Node:
    """Compile the regex ``r`` into its state-merged automaton.

    Pipeline: ε-NFA (``graph``), then subset construction (``to_dfs``), then
    state merging (``optimize``).  Results are cached per pattern, so the
    returned graph is shared between callers and must not be mutated.

    NOTE: shadows the builtin ``compile`` inside this module.
    """
    nfa = graph(r)
    dfa = to_dfs(nfa)
    return optimize(dfa)
def parse(
    patterns: FrozenOrderedBidict[Any, str], input_string: str
) -> List[Any | str]:
    """Tokenise ``input_string`` using ``patterns`` (token key -> regex).

    ``patterns`` may be any mapping (e.g. a plain dict).  It is frozen into
    a ``frozenbidict`` so it can serve as a cache key for ``_parse``; since
    bidicts reject duplicate values, each regex must be distinct.  Returns
    ``(key, text)`` pairs in input order.
    """
    frozen_patterns = frozenbidict(patterns)
    return _parse(frozen_patterns, input_string)
def _parse(
    patterns: FrozenOrderedBidict[Any, str], input_string: str
) -> list[tuple[Any, str]]:
    """Tokenise ``input_string`` (see ``parse``) and return a fresh list.

    The memoised result is an immutable tuple, so a caller that mutates the
    returned list cannot corrupt the cache.

    Raises:
        RuntimeError: with the untokenisable remainder as its message.
    """
    return list(_tokenize(patterns, input_string))


@functools.lru_cache(maxsize=1024)
def _tokenize(
    patterns: FrozenOrderedBidict[Any, str], input_string: str
) -> tuple[tuple[Any, str], ...]:
    """Greedy longest-match tokenisation, memoised per ``(patterns, input)``.

    At each position, take the longest prefix that some pattern matches
    (patterns tried in mapping order) *and* whose remainder can still be
    split into tokens.  The loop is iterative, so the token count is not
    limited by the recursion limit.
    """
    any_token = "(" + "|".join(f"({v})" for v in patterns.values()) + ")*"
    tokens = []
    rest = input_string
    while rest:
        token = None
        for end in range(len(rest), 0, -1):
            # Only split here if the remainder is itself a sequence of tokens.
            if not jet(any_token, rest[end:]):
                continue
            for key, regex in patterns.items():
                if jet(regex, rest[:end]):
                    token = (key, rest[:end])
                    break
            if token is not None:
                break
        if token is None:
            raise RuntimeError(rest)
        tokens.append(token)
        rest = rest[len(token[1]) :]
    return tuple(tokens)
def _matches(edge_value: EdgeValue, string: str) -> bool:
if isinstance(edge_value, str):
return edge_value == string
elif hasattr(edge_value, "matches"):
return edge_value.matches(string)
if isinstance(edge_value, eps):
return False
raise RuntimeError("Unknown edge value: {}".format(edge_value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment