Created
July 22, 2019 01:03
-
-
Save slott56/851c9706a3f83ffdb04fb0623ab2a20f to your computer and use it in GitHub Desktop.
Generate strings from a regular expression (RE). Parsing the RE produces a tree of nodes; concatenating the fragments yielded by `node.generate()` produces a string that matches the RE.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
RE-based generator. | |
Overview | |
======== | |
See https://json-schema.org/understanding-json-schema/reference/regular_expressions.html | |
Also, this: http://www.ecma-international.org/ecma-262/9.0/index.html#sec-regular-expressions | |
See https://bitbucket.org/mrabarnett/mrab-regex for a reference implementation | |
The RE language we're going to handle is this: | |
:: | |
A single unicode character (other than the special characters below) matches itself. | |
``.``: Matches any character except line break characters. (Be aware that what constitutes a line break character is somewhat dependent on your platform and language environment, but in practice this rarely matters). | |
``^``: Matches only at the beginning of the string. | |
``$``: Matches only at the end of the string. | |
``(...)``: Group a series of regular expressions into a single regular expression. | |
``|``: Matches either the regular expression preceding or following the | symbol. | |
``[abc]``: Matches any of the characters inside the square brackets. | |
``[a-z]``: Matches the range of characters. | |
``[^abc]``: Matches any character not listed. | |
``[^a-z]``: Matches any character outside of the range. | |
``+``: Matches one or more repetitions of the preceding regular expression. | |
``*``: Matches zero or more repetitions of the preceding regular expression. | |
``?``: Matches zero or one repetitions of the preceding regular expression. | |
``+?``, ``*?``, ``??``: Lazy (non-greedy) versions of ``+``, ``*``, and ``?``. The plain qualifiers are greedy and match as much text as possible; the lazy forms match as few characters as possible.
``(?!...)``, ``(?=...)``: Negative and positive lookahead.
``{x}``: Match exactly x occurrences of the preceding regular expression. | |
``{x,y}``: Match at least x and at most y occurrences of the preceding regular expression. | |
``{x,}``: Match x occurrences or more of the preceding regular expression. | |
``{x}?``, ``{x,y}?``, ``{x,}?``: Lazy versions of the above expressions. | |
Also, these predefined sets::
``\\d``: digit | |
``\\D``: non-digit | |
``\\s``: whitespace | |
``\\S``: non-whitespace | |
``\\w``: word (a-z A-Z 0-9 _) | |
``\\W``: Non-word | |
Structure | |
========= | |
This module contains | |
- AST model. All subclasses of an abstract superclass, :class:`Node`. | |
- parser functions. The entry-point is :func:`compile`. | |
""" | |
from functools import lru_cache | |
import random | |
import string | |
from enum import Enum | |
from typing import Iterator, Optional, Set, cast, List, Union | |
# AST | |
# ========== | |
class Node:
    """Base class for every node of the regular-expression AST.

    Subclasses override :meth:`generate` to emit text for the part of the
    pattern they represent, and :meth:`domain` when they describe a set of
    single characters.
    """

    def __init__(self, item: 'Node') -> None:
        # Generic payload; its meaning depends on the concrete subclass.
        self.item = item

    def __eq__(self, other: 'Node') -> bool:
        # Nodes of different concrete classes are never equal.
        if self.__class__ != other.__class__:
            return False
        return self.item == other.item

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}({self.item!r})"

    def generate(self) -> Iterator[str]:
        """Yield text fragments matching this node (abstract here)."""
        raise NotImplementedError

    def domain(self) -> Set[str]:
        """Return the single characters this node can match; none by default."""
        return set()
class Literal(Node):
    """A literal character in a sequence; ``item`` is the wrapped Character."""

    def generate(self) -> Iterator[str]:
        # The wrapped Character knows its own text; just pass it through.
        inner = self.item
        yield from inner.generate()
class Sequence(Node):
    """Concatenation of sub-expressions, held in ``children`` in pattern order."""

    def __init__(self, *nodes: 'Node') -> None:
        # Node.__init__ is not called: a sequence has no single ``item``.
        self.children = [*nodes]

    def __eq__(self, other: 'Sequence') -> bool:
        same_class = self.__class__ == other.__class__
        return same_class and self.children == other.children

    def __repr__(self) -> str:
        children_text = repr(self.children)
        return f"{self.__class__.__name__}(*{children_text})"

    def generate(self) -> Iterator[str]:
        # Each child's fragments are emitted in order; together they match.
        for child in self.children:
            yield from child.generate()
class Repeated(Node):
    """A sub-expression with a repetition suffix.

    ``node`` is the repeated expression; ``quantifier`` (a :class:`Quantifier`)
    decides how many copies :meth:`generate` emits.
    """

    def __init__(self, node: 'Node', quantifier: 'Quantifier') -> None:
        self.node = node
        self.quantifier = quantifier

    def __eq__(self, other: 'Repeated') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.node == other.node and self.quantifier == other.quantifier

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.node!r}, {self.quantifier!r})"

    def generate(self) -> Iterator[str]:
        # Pick the count once, then generate an independent copy each time.
        # NOTE: the quantifier's ``greedy`` flag does not affect generation.
        count = self.quantifier.repeat()
        for _ in range(count):
            yield from self.node.generate()
class Quantifier(Node):
    """A repetition suffix: ``?``, ``*``, ``+``, ``{a}``, ``{a,}``, ``{a,b}`` or ``{,b}``.

    :param simple: The quantifier character: ``'?'``, ``'*'``, ``'+'`` or ``'{'``.
    :param start: For ``'{'``, the digits before the comma ('' when omitted).
    :param end: For ``'{'``, the digits after the comma; ``None`` when there is
        no comma (``{a}``), ``''`` when the upper bound is omitted (``{a,}``).
    :param greedy: ``False`` for the lazy ``?``-suffixed forms. Generation does
        not currently distinguish greedy from lazy.
    :raises ValueError: for an unknown quantifier, a brace form with no bounds,
        or a ``{a,b}`` form with ``a > b``.

    ``low`` and ``high`` are the inclusive bounds used by :meth:`repeat`.
    Open-ended forms (``*``, ``+``, ``{a,}``) are capped so output stays short.
    """

    def __init__(self,
                 simple: str,
                 start: Optional[str] = None,
                 end: Optional[str] = None,
                 greedy: bool = True
                 ) -> None:
        self.text = simple
        self.start = start
        self.end = end
        self.greedy = greedy
        self.low = self.high = 1
        if self.text == '?':
            self.low, self.high = 0, 1
        elif self.text == '+':
            self.low, self.high = 1, 8
        elif self.text == '*':
            self.low, self.high = 0, 7
        elif self.text == '{':
            if self.start:
                self.low = int(self.start)
            elif self.end is not None:
                # ``{,b}``: an omitted lower bound means zero (as in Python's re).
                self.low = 0
            else:
                raise ValueError(f"Bad {{{self.start}-{self.end}}}")
            if self.end is None:
                self.high = self.low
            elif self.end == '':
                # Open-ended ``{a,}``: cap at 8 copies, but never below ``low``;
                # a fixed 8 made randint(low, 8) fail whenever low > 8.
                self.high = max(self.low, 8)
            else:
                self.high = int(self.end)
            if self.low > self.high:
                # Reject ``{5,3}`` now instead of failing later in repeat().
                raise ValueError(
                    f"Bad {{{self.start},{self.end}}}: minimum exceeds maximum"
                )
        else:
            raise ValueError(f"Bad quantifier {self.text!r}")

    def __eq__(self, other: 'Quantifier') -> bool:
        return (self.__class__ == other.__class__
                and self.text == other.text
                and self.start == other.start
                and self.end == other.end
                and self.greedy == other.greedy
                )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.text!r}, {self.start!r}, {self.end!r}, "
            f"greedy={self.greedy})"
        )

    def repeat(self) -> int:
        """Return a random repetition count in ``[low, high]`` (inclusive).

        TODO: ideally this would consider the schema's overall minLength and
        maxLength and the length generated so far when sizing ``*`` and ``+``.
        """
        return random.randint(self.low, self.high)
class Alternation(Node):
    """``a|b|...``: one child per alternative; generation uses just one of them."""

    def __init__(self, *nodes: 'Node') -> None:
        self.children = list(nodes)

    def __eq__(self, other: 'Alternation') -> bool:
        same_class = self.__class__ == other.__class__
        return same_class and self.children == other.children

    def __repr__(self) -> str:
        return f"{type(self).__name__}(*{self.children!r})"

    def generate(self) -> Iterator[str]:
        # Choose one alternative uniformly at random and emit only its text.
        picked = random.choice(self.children)
        yield from picked.generate()
class Group(Node):
    """A parenthesized ``(...)`` group wrapping a single sub-pattern ``node``."""

    def __init__(self, node: 'Node') -> None:
        self.node = node

    def __eq__(self, other: 'Group') -> bool:
        return self.__class__ == other.__class__ and self.node == other.node

    def __repr__(self) -> str:
        # Use !r like every other node, so nested output is constructor syntax
        # even when the payload's str() and repr() differ.
        return f"{self.__class__.__name__}({self.node!r})"

    def generate(self) -> Iterator[str]:
        # Grouping only affects parsing (e.g. what a quantifier applies to);
        # the generated text is exactly the child's.
        yield from self.node.generate()
class Sense(str, Enum):
    """Polarity of a :class:`PatternSet`.

    The values reuse the ``(?!`` / ``(?=`` lookahead markers: ``"!"`` for a
    negated set such as ``[^...]`` and ``"="`` for a plain set.
    """
    NEGATIVE = "!"  # [^...]: everything except the members
    POSITIVE = "="  # [...]: exactly the members
class PatternSet(Node):
    """A character set: ``[...]``, ``[^...]``, or one of the predefined sets.

    :param items: Set members (Character, Range, or predefined PatternSets).
    :param sense: ``Sense.POSITIVE`` matches the union of the members' domains;
        ``Sense.NEGATIVE`` matches :data:`string.printable` minus that union.
    """

    def __init__(self, *items: 'Node', sense: Sense = Sense.POSITIVE) -> None:
        self.children = list(items)
        self.sense: Sense = sense
        # Lazily built by generate(). Ideally a Set[str], but a sorted list keeps
        # the random choice, and therefore the unit tests, predictable.
        self.concrete_domain: Optional[List[str]] = None

    def __eq__(self, other: 'PatternSet') -> bool:
        return (
            self.__class__ == other.__class__
            and self.children == other.children
            and self.sense == other.sense
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(*{self.children!r}, sense={self.sense!r})"

    def domain(self) -> Set[str]:
        """Return every character this set can match."""
        members: Set[str] = set()
        for item in self.children:
            members |= item.domain()
        if self.sense == Sense.POSITIVE:
            return members
        return set(string.printable) - members

    def generate(self) -> Iterator[str]:
        """Yield one randomly chosen character from :meth:`domain`.

        :raises ValueError: if the set matches no printable character
            (e.g. ``[^\\s\\S]``), rather than random.choice's bare IndexError.
        """
        if self.concrete_domain is None:
            # Sorted makes unit tests predictable.
            self.concrete_domain = sorted(self.domain())
        if not self.concrete_domain:
            raise ValueError(f"{self!r} matches no printable characters")
        yield random.choice(self.concrete_domain)
class SetAll(PatternSet):
    """The ``.`` set: any printable character except a line break."""

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def domain(self) -> Set[str]:
        # ``.`` does not match line terminators. Python's re excludes "\n" and
        # ECMA-262 also excludes "\r"; excluding both keeps generated text valid
        # under either engine. (string.printable contains both.)
        return set(string.printable) - {'\n', '\r'}
class SetDigit(PatternSet):
    """The \\d set: the ASCII digits 0 through 9."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        # Only the class name is shown, so SetNonDigit reprs as "SetNonDigit()".
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Fixed set; ``sense`` is not consulted (SetNonDigit overrides this).
        return {*string.digits}
class SetNonDigit(SetDigit):
    """The \\D set: every printable character that is not a digit."""

    def __init__(self) -> None:
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        digits = set(string.digits)
        return {ch for ch in string.printable if ch not in digits}
class SetWhitespace(PatternSet):
    """The \\s set: the characters in :data:`string.whitespace`."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Fixed set; ``sense`` is not consulted (SetNonWhitespace overrides this).
        return {*string.whitespace}
class SetNonWhitespace(SetWhitespace):
    """The \\S set: every printable character that is not whitespace."""

    def __init__(self) -> None:
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        blanks = set(string.whitespace)
        return {ch for ch in string.printable if ch not in blanks}
class SetWord(PatternSet):
    """The \\w set: ASCII letters, digits and underscore."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Equivalent to [A-Za-z0-9_]; ``sense`` is not consulted here.
        return set(string.ascii_letters + string.digits + "_")
class SetNonWord(SetWord):
    """The \\W set: every printable character outside \\w."""

    def __init__(self) -> None:
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        word_chars = SetWord().domain()
        return set(string.printable).difference(word_chars)
class Character(Node):
    """One literal character: wrapped by a Literal, or a member of a PatternSet."""

    def __init__(self, char: str) -> None:
        # Stored as ``char`` rather than Node's generic ``item``.
        self.char = char

    def __eq__(self, other: 'Character') -> bool:
        if type(self) != type(other):
            return False
        return self.char == other.char

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.char!r})"

    def generate(self) -> Iterator[str]:
        # A literal matches only itself.
        yield self.char

    def domain(self) -> Set[str]:
        return set([self.char])
class Range(Character):
    """A ``start-end`` member of a PatternSet, inclusive of both ends.

    Character.__init__ is not called, so a Range has no ``char`` attribute. The
    Character.generate it would otherwise inherit (``yield self.char``) would
    raise AttributeError, so Range provides its own :meth:`generate`.
    """

    def __init__(self, start: 'Character', end: 'Character') -> None:
        self.start = start
        self.end = end

    def __eq__(self, other: 'Range') -> bool:
        return (
            self.__class__ == other.__class__
            and self.start == other.start
            and self.end == other.end
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.start!r}, {self.end!r})"

    def domain(self) -> Set[str]:
        """Return every character whose code point lies in ``[start, end]``.

        The result is empty when ``start`` sorts after ``end``.
        """
        return {chr(x) for x in range(ord(self.start.char), ord(self.end.char) + 1)}

    def generate(self) -> Iterator[str]:
        """Yield one random character from the range.

        :raises ValueError: if the range is empty (``start`` after ``end``).
        """
        low, high = ord(self.start.char), ord(self.end.char)
        if low > high:
            raise ValueError(f"Empty range {self!r}")
        # Draw a code point directly; this avoids materializing large ranges.
        yield chr(random.randint(low, high))
class Start(Node):
    """``^`` anchor: matches only at the beginning of the string."""

    def generate(self) -> Iterator[str]:
        # Zero-width: an anchor matches a position, so it emits empty text.
        yield from ("",)
class End(Node):
    """``$`` anchor: matches only at the end of the string."""

    def generate(self) -> Iterator[str]:
        # Zero-width: an anchor matches a position, so it emits empty text.
        yield from ("",)
class PositiveLookahead(Node):
    """
    The ``(?=...)`` lookahead: the text that follows must match ``node``.

    TODO: generation ignores this constraint. It should steer the rest of the
    enclosing sequence so that it begins with text matching ``node``.
    """

    def __init__(self, node: 'Node') -> None:
        self.node = node
        # Not read by any code visible here; kept for compatibility.
        self.text = ''

    def __eq__(self, other: 'PositiveLookahead') -> bool:
        return self.__class__ == other.__class__ and self.node == other.node

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.node!r})"

    def generate(self) -> Iterator[str]:
        # Zero-width assertion: contributes no characters of its own.
        yield from ("",)
class NegativeLookahead(Node):
    """
    The ``(?!...)`` lookahead: the text that follows must NOT match ``node``.

    TODO: generation ignores this constraint. It should ensure the rest of the
    enclosing sequence does not begin with text matching ``node``; that is hard
    to do with a purely generative approach.
    """

    def __init__(self, node: 'Node') -> None:
        self.node = node
        # Not read by any code visible here; kept for compatibility.
        self.text = ''

    def __eq__(self, other: 'NegativeLookahead') -> bool:
        return self.__class__ == other.__class__ and self.node == other.node

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.node!r})"

    def generate(self) -> Iterator[str]:
        # Zero-width assertion: contributes no characters of its own.
        yield from ("",)
# Parser | |
# ========== | |
class Source:
    """
    Source of characters in the Regular Expression.

    Tracks a read position ``pos`` into ``text``. :meth:`unget` pushes back text
    that was just consumed, giving the parser 1-char lookahead inside ()'s.
    """

    def __init__(self, text: str) -> None:
        self.text = text
        self.pos = 0

    def get(self) -> str:
        """Consume and return the next character, or ``''`` at end of text."""
        try:
            character = self.text[self.pos]
        except IndexError:
            return ''
        self.pos += 1
        return character

    def unget(self, character: str) -> None:
        """Push back ``character``, which must be the text most recently consumed."""
        self.pos -= len(character)
        assert self.text.startswith(character, self.pos)

    def match(self, match_text: str) -> bool:
        """Consume ``match_text`` if the remaining text starts with it."""
        if self.text.startswith(match_text, self.pos):
            self.pos += len(match_text)
            return True
        return False

    def match_any(self, match_text: str) -> str:
        """Consume and return the next character if it is one of ``match_text``.

        Returns ``''`` without consuming anything when the next character is
        not in ``match_text`` or the text is exhausted. Previously this raised
        IndexError at end of text, e.g. for the pattern ``a{3``.
        """
        if self.pos < len(self.text) and self.text[self.pos] in match_text:
            character = self.text[self.pos]
            self.pos += 1
            return character
        return ''

    def expect(self, match_text: str) -> None:
        """Consume ``match_text`` or raise SyntaxError if it is not next."""
        if not self.match(match_text):
            raise SyntaxError(f"Missing {match_text}")
def parse_pattern(source: Source) -> Node:
    """Entry-point for pattern parsing.

    Parses a sequence, then one more sequence for each following ``|``. A lone
    sequence is returned unchanged; several are wrapped in an Alternation.
    """
    branches = [parse_sequence(source)]
    while source.match("|"):
        branches.append(parse_sequence(source))
    return branches[0] if len(branches) == 1 else Alternation(*branches)
def parse_sequence(source: Source) -> Node:
    """Parse a run of atoms into a :class:`Sequence`.

    Stops at end of text, at ``)`` or at ``|``:

    - ``)`` is consumed; it closes the enclosing group, which is why
      :func:`parse_paren` does not expect it.
    - ``|`` is pushed back so that ``source.match("|")`` in
      :func:`parse_pattern` sees it and parses the next alternative.
      Consuming it here silently dropped every alternative after the first.

    :raises SyntaxError: if a quantifier has nothing to repeat (``*a``, ``a|+b``).
    """
    pattern: List[Node] = []
    while True:
        c = source.get()
        if c == '':
            break
        elif c == ')':
            # End of ()'s: consumed here as the group terminator.
            break
        elif c == '|':
            # Separator between alternatives: leave it for parse_pattern.
            source.unget(c)
            break
        elif c == '\\':
            # Escaped character or predefined set
            pe = parse_escape(source)
            if isinstance(pe, Character):
                pattern.append(Literal(pe))
            else:
                pattern.append(pe)
        elif c == '(':
            # ()'d grouping
            pattern.append(parse_paren(source))
        elif c == '.':
            # the '.' set
            pattern.append(SetAll())
        elif c == '[':
            # A general set
            pattern.append(parse_set(source))
        elif c == '^':
            # The start-of-line marker
            pattern.append(Start(Character(c)))
        elif c == '$':
            # The end-of-line marker
            pattern.append(End(Character(c)))
        elif c in ('?', '*', '+', '{'):
            # A quantifier applies to the atom just parsed.
            if not pattern:
                raise SyntaxError(f"Nothing to repeat at {c!r}")
            source.unget(c)
            suffix = parse_quantifier(source)
            pattern[-1] = Repeated(pattern[-1], suffix)
        else:
            pattern.append(Literal(Character(c)))
    return Sequence(*pattern)
def parse_escape(source: Source) -> Node:
    """Parse what follows a backslash (the backslash itself is already consumed).

    - ``\\d \\D \\s \\S \\w \\W`` return the matching predefined set node.
    - ``\\a \\b \\f \\n \\r \\t \\v`` return a Character holding that control
      character.
    - Anything else returns the escaped character itself, so ``\\.`` is a
      literal dot and ``\\\\`` a literal backslash.

    :raises SyntaxError: if the pattern ends immediately after the backslash.

    NOTE(review): outside ``[...]``, ``\\b`` is a word boundary in both Python's
    re and ECMA-262, not a backspace; here it is always a backspace. Back-
    references (``\\1``) and hex/unicode escapes (``\\x41``, ``\\u0041``) fall
    through to the literal branch and yield the letter/digit itself.
    """
    c = source.get()
    if c == '':
        # A dangling backslash used to become Character(''), which silently
        # generated nothing; Python's re rejects it, so do the same.
        raise SyntaxError("Bad escape (end of pattern)")
    if c == 'd':
        return SetDigit()
    elif c == 'D':
        return SetNonDigit()
    elif c == 's':
        return SetWhitespace()
    elif c == 'S':
        return SetNonWhitespace()
    elif c == 'w':
        return SetWord()
    elif c == 'W':
        return SetNonWord()
    elif c in ('a', 'b', 'f', 'n', 'r', 't', 'v'):
        escapes = {
            "a": "\a",
            "b": "\b",
            "f": "\f",
            "n": "\n",
            "r": "\r",
            "t": "\t",
            "v": "\v",
        }
        return Character(escapes[c])
    else:
        return Character(c)
def parse_paren(source: Source) -> Node:
    """Parse a parenthesized construct; the opening ``(`` is already consumed.

    Supports plain groups ``(...)``, non-capturing groups ``(?:...)`` and the
    lookaheads ``(?=...)`` / ``(?!...)``. The closing ``)`` is consumed by
    :func:`parse_sequence` when it ends the group body, so it is not expected
    here.

    :raises SyntaxError: for any other ``(?...)`` extension.
    """
    c = source.get()
    if c == '?':
        c2 = source.get()
        if c2 in ('!', '='):
            # Lookahead
            return parse_lookahead(source, c2)
        if c2 == ':':
            # Non-capturing group: generates exactly like a plain group.
            return Group(parse_pattern(source))
        # Unsupported (?...) syntax, e.g. lookbehind or inline flags
        raise SyntaxError("Unsupported (?...)")
    source.unget(c)
    sub_pattern = parse_pattern(source)
    return Group(sub_pattern)
def parse_lookahead(source: Source, positive: str) -> Node:
    """Parse the body of ``(?=...)`` or ``(?!...)``; ``(?`` and the marker are consumed.

    :param positive: ``'='`` for a positive, ``'!'`` for a negative lookahead.
    :raises SyntaxError: for any other marker.
    """
    if positive not in ('=', '!'):
        raise SyntaxError("Unsupported (?...)")
    sub_pattern = parse_pattern(source)
    # No source.expect(")") here: parse_sequence already consumed the closing
    # ")" when it ended the body, so expecting another one rejected valid
    # patterns such as "(?=a)b" (or swallowed an enclosing group's ")").
    if positive == '=':
        return PositiveLookahead(sub_pattern)
    return NegativeLookahead(sub_pattern)
def parse_set(source: Source) -> Node:
    """Parse ``[...]`` or ``[^...]``; the opening ``[`` is already consumed.

    A ``]`` immediately after ``[`` (or ``[^``) is taken as a literal member.
    Ranges and the special trailing ``-`` are handled by :func:`parse_set_member`.

    :raises SyntaxError: if the text ends before the first member or before
        the closing ``]``. (An unterminated set was previously accepted silently.)
    """
    sense = Sense.NEGATIVE if source.match("^") else Sense.POSITIVE
    items: List[Node] = []
    members = parse_set_member(source)
    if members is None:
        raise SyntaxError("Invalid set")
    items.extend(members)
    while not source.match("]"):
        members = parse_set_member(source)
        if members is None:
            # None signals that the text ran out before the closing "]".
            raise SyntaxError("Missing ]")
        items.extend(members)
    return PatternSet(*items, sense=sense)
def parse_set_member(
    source: Source
) -> Union[List[Character], List[PatternSet], List[Range], List[Node], None]:
    """Parse one member of ``[...]``: ``x``, ``x-y``, or a trailing ``x-``.

    :returns: A list of member nodes, or ``None`` when the text is exhausted.
        A ``-`` directly before the closing ``]`` is a literal dash, so both
        ``[a-]`` and ``[\\d-]`` yield the item followed by ``Character('-')``.
    :raises SyntaxError: for an unterminated range (``[a-``) or a reversed
        range (``[z-a]``); these previously produced a crash in ``ord('')`` or
        a silently empty set.
    :raises Exception: when either end of a range is a predefined set (``[a-\\d]``).
    """
    start = parse_set_item(source)
    if isinstance(start, Character) and start.char == '':
        return None
    if not source.match("-"):
        return [start]
    if source.match("]"):
        # Special trailing "-" case: a literal dash. Push the "]" back so
        # parse_set's match("]") sees the end of the set.
        source.unget("]")
        return [start, Character("-")]
    end = parse_set_item(source)
    if not (isinstance(start, Character) and isinstance(end, Character)):
        raise Exception(f"Bad ranges {start!r}-{end!r}")
    if end.char == '':
        raise SyntaxError("Missing ]")
    if start.char > end.char:
        raise SyntaxError(f"Bad range {start.char!r}-{end.char!r}")
    return [Range(start, end)]
def parse_set_item(source: Source) -> Node:
    """Parse one atom inside ``[...]``: an escape sequence or a single character.

    At end of text the plain-character path yields ``Character('')``, which
    parse_set_member treats as "no more members".
    """
    if not source.match('\\'):
        return Character(source.get())
    return parse_escape(source)
def parse_quantifier(source: Source) -> Quantifier:
    """Parse a quantifier suffix: ``*``, ``?``, ``+`` or a brace form.

    Brace forms are ``{a}``, ``{a,}`` and ``{a,b}``; the digit strings (possibly
    empty) are passed to :class:`Quantifier`, which validates them. A trailing
    ``?`` marks the quantifier as lazy (non-greedy).
    """
    c = source.get()
    if c != "{":
        q = Quantifier(c)
    else:
        low_text = parse_count(source)
        if source.match(","):
            q = Quantifier("{", low_text, parse_count(source))
        else:
            q = Quantifier("{", low_text)
        source.expect("}")
    if source.match("?"):
        q.greedy = False
    return q
def parse_count(source: Source) -> str:
    """Consume and return a (possibly empty) run of decimal digits."""
    # iter(callable, sentinel) keeps calling match_any until it returns ''.
    return ''.join(iter(lambda: source.match_any('0123456789'), ''))
@lru_cache(maxsize=None)
def compile(text: str) -> Node:
    """Parse the RE ``text`` into a Node tree whose ``generate()`` yields matching text.

    Results are memoized per pattern string in an unbounded cache, so repeated
    calls with the same pattern return the same Node object.

    NOTE: this name shadows the built-in ``compile`` within this module.
    """
    source = Source(text)
    return parse_pattern(source)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment