# Gist slott56/851c9706a3f83ffdb04fb0623ab2a20f by @slott56, created July 22, 2019.
# Generate strings based on an RE. The resulting parse tree is made of nodes;
# `node.generate()` should create strings that match the RE.
"""
RE-based generator.
Overview
========
See https://json-schema.org/understanding-json-schema/reference/regular_expressions.html
Also, this: http://www.ecma-international.org/ecma-262/9.0/index.html#sec-regular-expressions
See https://bitbucket.org/mrabarnett/mrab-regex for a reference implementation
The RE language we're going to handle is this:
::
A single unicode character (other than the special characters below) matches itself.
``.``: Matches any character except line break characters. (Be aware that what constitutes a line break character is somewhat dependent on your platform and language environment, but in practice this rarely matters).
``^``: Matches only at the beginning of the string.
``$``: Matches only at the end of the string.
``(...)``: Group a series of regular expressions into a single regular expression.
``|``: Matches either the regular expression preceding or following the | symbol.
``[abc]``: Matches any of the characters inside the square brackets.
``[a-z]``: Matches the range of characters.
``[^abc]``: Matches any character not listed.
``[^a-z]``: Matches any character outside of the range.
``+``: Matches one or more repetitions of the preceding regular expression.
``*``: Matches zero or more repetitions of the preceding regular expression.
``?``: Matches zero or one repetitions of the preceding regular expression.
``+?``, ``*?``, ``??``: Lazy (non-greedy) versions of ``+``, ``*`` and ``?``. The plain quantifiers are greedy and match as much text as possible; the lazy versions match as little text as possible.
``(?!...)``, ``(?=...)``: Negative and positive lookahead.
``{x}``: Match exactly x occurrences of the preceding regular expression.
``{x,y}``: Match at least x and at most y occurrences of the preceding regular expression.
``{x,}``: Match x occurrences or more of the preceding regular expression.
``{x}?``, ``{x,y}?``, ``{x,}?``: Lazy versions of the above expressions.
Also, these predefined sets::
``\\d``: digit
``\\D``: non-digit
``\\s``: whitespace
``\\S``: non-whitespace
``\\w``: word (a-z A-Z 0-9 _)
``\\W``: Non-word
Structure
=========
This module contains
- AST model. All subclasses of an abstract superclass, :class:`Node`.
- parser functions. The entry-point is :func:`compile`.
"""
from functools import lru_cache
import random
import string
from enum import Enum
from typing import Iterator, Optional, Set, cast, List, Union
# AST
# ==========
class Node:
    """
    Abstract base for every node in the RE abstract syntax tree.

    The base implementation wraps a single ``item``; subclasses that keep
    different state override ``__init__``, ``__eq__`` and ``__repr__``.
    """

    def __init__(self, item: 'Node') -> None:
        self.item = item

    def __eq__(self, other: 'Node') -> bool:
        # Equal only when both the concrete class and the wrapped item agree.
        if self.__class__ != other.__class__:
            return False
        return self.item == other.item

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}({self.item!r})"

    def generate(self) -> Iterator[str]:
        """Produce text fragments for this node; concrete subclasses must override."""
        raise NotImplementedError

    def domain(self) -> Set[str]:
        """Return the single characters this node can stand for; empty by default."""
        return set()
class Literal(Node):
    """A character that must appear verbatim; ``item`` is a :class:`Character`."""

    def generate(self) -> Iterator[str]:
        # All the work is delegated to the wrapped Character.
        for fragment in self.item.generate():
            yield fragment
class Sequence(Node):
    """A plain concatenation of RE nodes, stored in order in ``children``."""

    def __init__(self, *nodes: 'Node') -> None:
        # No single ``item`` here, so Node.__init__ is intentionally not used.
        self.children = [node for node in nodes]

    def __eq__(self, other: 'Sequence') -> bool:
        same_kind = self.__class__ == other.__class__
        return same_kind and self.children == other.children

    def __repr__(self) -> str:
        return "{}(*{!r})".format(self.__class__.__name__, self.children)

    def generate(self) -> Iterator[str]:
        """Yield every child's fragments, one child after another."""
        for child in self.children:
            yield from child.generate()
class Repeated(Node):
    """A node followed by a repetition suffix; ``quantifier`` decides how many copies."""

    def __init__(self, node: 'Node', quantifier: 'Quantifier') -> None:
        self.node = node
        self.quantifier = quantifier

    def __eq__(self, other: 'Repeated') -> bool:
        same_kind = self.__class__ == other.__class__
        return (same_kind
                and self.node == other.node
                and self.quantifier == other.quantifier)

    def __repr__(self) -> str:
        return "{}({!r}, {!r})".format(
            type(self).__name__, self.node, self.quantifier
        )

    def generate(self) -> Iterator[str]:
        """
        Yield the wrapped node's output a random number of times.

        The count comes from ``quantifier.repeat()``; each repetition calls
        ``node.generate()`` afresh, so repeated sets may pick different characters.
        """
        count = self.quantifier.repeat()
        for _ in range(count):
            yield from self.node.generate()
class Quantifier(Node):
    """
    A repetition suffix: ``?``, ``*``, ``+`` or a braced ``{...}`` count.

    :param simple: the quantifier character, one of ``?``, ``*``, ``+``, ``{``.
    :param start: for ``{``, the digits before the comma ('' if none).
    :param end: for ``{``, ``None`` when there is no comma (``{m}``),
        '' when nothing follows the comma (``{m,}``), else the digits after it.
    :param greedy: ``False`` for the lazy forms; generation does not use it.
    :raises Exception: for an unknown quantifier, a ``{...}`` with no counts,
        or a ``{m,n}`` whose minimum exceeds its maximum.

    ``low``/``high`` are the inclusive bounds used by :meth:`repeat`.
    Unbounded forms are capped to keep generated text short: ``*`` at 7,
    ``+`` at 8, and ``{m,}`` at 8 (or at ``m`` itself when ``m`` is larger).
    """

    def __init__(self,
                 simple: str,
                 start: Optional[str] = None,
                 end: Optional[str] = None,
                 greedy: bool = True
                 ):
        self.text = simple
        self.start = start
        self.end = end
        self.greedy = greedy
        self.low = self.high = 1
        if self.text == '?':
            self.low, self.high = 0, 1
        elif self.text == '+':
            self.low, self.high = 1, 8
        elif self.text == '*':
            self.low, self.high = 0, 7
        elif self.text == '{':
            self._braced_bounds()
        else:
            raise Exception(f"Bad quantifier {self.text!r}")

    def _braced_bounds(self) -> None:
        """Set ``low``/``high`` for the ``{m}``, ``{m,}``, ``{,n}`` and ``{m,n}`` forms."""
        if self.start:
            self.low = int(self.start)
        elif self.end:
            # ``{,n}``: zero up to n (Python's re treats an omitted minimum as 0).
            self.low = 0
        else:
            # ``{}`` or ``{,}``: no count was given at all.
            raise Exception(f"Bad {{{self.start}-{self.end}}}")
        if self.end is None:
            # ``{m}``: exactly m.
            self.high = self.low
        elif self.end == '':
            # ``{m,}``: cap at 8, but never below m, otherwise
            # random.randint(low, high) raises ValueError for m > 8.
            self.high = max(self.low, 8)
        else:
            self.high = int(self.end)
        if self.high < self.low:
            raise Exception(
                f"Bad {{{self.start},{self.end}}}: minimum exceeds maximum"
            )

    def __eq__(self, other: 'Quantifier') -> bool:
        return (self.__class__ == other.__class__
                and self.text == other.text
                and self.start == other.start
                and self.end == other.end
                and self.greedy == other.greedy
                )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.text!r}, {self.start!r}, {self.end!r}, "
            f"greedy={self.greedy})"
        )

    def repeat(self) -> int:
        """
        Return a random repetition count between ``low`` and ``high`` inclusive.

        NOTE: ideally this would consider the schema's overall minLength and
        maxLength and the text generated so far, partitioning the ``*`` and
        ``+`` budgets accordingly; it currently does not.
        """
        return random.randint(self.low, self.high)
class Alternation(Node):
    """Alternatives separated by ``|``; each generation uses one of ``children``."""

    def __init__(self, *nodes: 'Node') -> None:
        self.children = [node for node in nodes]

    def __eq__(self, other: 'Alternation') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.children == other.children

    def __repr__(self) -> str:
        return "{}(*{!r})".format(type(self).__name__, self.children)

    def generate(self) -> Iterator[str]:
        # Pick one alternative uniformly at random and yield only its output.
        yield from random.choice(self.children).generate()
class Group(Node):
    """A parenthesised sub-pattern; it generates exactly what ``node`` generates."""

    def __init__(self, node: 'Node') -> None:
        self.node = node

    def __eq__(self, other: 'Group') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.node == other.node

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.node)

    def generate(self) -> Iterator[str]:
        for fragment in self.node.generate():
            yield fragment
class Sense(str, Enum):
    """Whether a PatternSet includes its members (POSITIVE) or excludes them (NEGATIVE)."""
    NEGATIVE = "!"  # complement of the members, e.g. [^...]
    POSITIVE = "="  # the members themselves, e.g. [...]
class PatternSet(Node):
    """
    A character set: ``[...]``, ``[^...]`` or one of the predefined sets.

    ``children`` holds the set's members (characters, ranges or nested sets).
    With ``Sense.POSITIVE`` the domain is the union of the members' domains;
    otherwise it is ``string.printable`` minus that union.
    """

    def __init__(self, *items: 'Node', sense: Sense = Sense.POSITIVE) -> None:
        self.children = list(items)
        self.sense: Sense = sense
        # Lazily-built sorted candidate list. A sorted list (rather than a
        # set) keeps generation predictable for unit tests.
        self.concrete_domain: Optional[List[str]] = None

    def __eq__(self, other: 'PatternSet') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.children == other.children and self.sense == other.sense

    def __repr__(self) -> str:
        return "{}(*{!r}, sense={!r})".format(
            self.__class__.__name__, self.children, self.sense
        )

    def domain(self) -> Set[str]:
        """Return the characters this set can produce."""
        members: Set[str] = set()
        for item in self.children:
            members |= item.domain()
        if self.sense == Sense.POSITIVE:
            return members
        # Negated set: everything printable that is not a member.
        return set(string.printable) - members

    def generate(self) -> Iterator[str]:
        """Yield one character chosen at random from the domain."""
        if self.concrete_domain is None:
            self.concrete_domain = sorted(self.domain())
        yield random.choice(self.concrete_domain)
class SetAll(PatternSet):
    """
    The ``.`` set: any printable character except a line break.

    ``string.printable`` includes newline and carriage return, which ``.``
    does not match (ECMA-262 treats both as line terminators; Python's re
    excludes newline). Both are removed so generated text matches ``.``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def domain(self) -> Set[str]:
        # "." never matches a line break, so never generate one.
        return set(string.printable) - {'\n', '\r'}
class SetDigit(PatternSet):
    """The ``\\d`` set: the ASCII digits."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Always the digits regardless of ``sense``; SetNonDigit overrides this.
        return {*string.digits}
class SetNonDigit(SetDigit):
    """The ``\\D`` set: printable characters other than digits."""

    def __init__(self) -> None:
        # Marked NEGATIVE; domain() below computes the complement directly.
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        return set(string.printable).difference(string.digits)
class SetWhitespace(PatternSet):
    """The ``\\s`` set: the characters in ``string.whitespace``."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Independent of ``sense``; SetNonWhitespace overrides this.
        return {*string.whitespace}
class SetNonWhitespace(SetWhitespace):
    """The ``\\S`` set: printable characters other than whitespace."""

    def __init__(self) -> None:
        # Marked NEGATIVE; domain() below computes the complement directly.
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        return set(string.printable).difference(string.whitespace)
class SetWord(PatternSet):
    """The ``\\w`` set: ASCII letters, digits and underscore."""

    def __init__(self, sense=Sense.POSITIVE) -> None:
        super().__init__(sense=sense)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"

    def domain(self) -> Set[str]:
        # Independent of ``sense``; SetNonWord overrides this.
        return set(string.ascii_letters + string.digits + '_')
class SetNonWord(SetWord):
    """The ``\\W`` set: printable characters that are not word characters."""

    def __init__(self) -> None:
        # Marked NEGATIVE; domain() below computes the complement directly.
        super().__init__(sense=Sense.NEGATIVE)

    def domain(self) -> Set[str]:
        return set(string.printable).difference(SetWord().domain())
class Character(Node):
    """A single character: the content of a Literal, or a member of a PatternSet."""

    def __init__(self, char: str) -> None:
        self.char = char

    def __eq__(self, other: 'Character') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.char == other.char

    def __repr__(self) -> str:
        return "{}({!r})".format(self.__class__.__name__, self.char)

    def generate(self) -> Iterator[str]:
        """Yield the character itself."""
        yield self.char

    def domain(self) -> Set[str]:
        """Return a one-element set holding the character."""
        only = {self.char}
        return only
class Range(Character):
    """
    An ``a-z`` style range inside a PatternSet.

    It covers every code point from ``start.char`` to ``end.char`` inclusive.
    Unlike Character it has no ``char`` attribute, so :meth:`generate` is
    overridden here; the inherited version would raise AttributeError.
    """

    def __init__(self, start: 'Character', end: 'Character') -> None:
        self.start = start
        self.end = end

    def __eq__(self, other: 'Range') -> bool:
        return (
            self.__class__ == other.__class__
            and self.start == other.start
            and self.end == other.end
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.start!r}, {self.end!r})"

    def domain(self) -> Set[str]:
        return set(chr(x) for x in range(ord(self.start.char), ord(self.end.char)+1))

    def generate(self) -> Iterator[str]:
        """Yield one character chosen at random from the range."""
        # Sorted so the choice is predictable under a fixed random seed.
        yield random.choice(sorted(self.domain()))
class Start(Node):
    """The ``^`` anchor; it marks a position, so it contributes no text."""

    def generate(self) -> Iterator[str]:
        # Emit a single empty fragment.
        empty = ''
        yield empty
class End(Node):
    """The ``$`` anchor; it marks a position, so it contributes no text."""

    def generate(self) -> Iterator[str]:
        # Emit a single empty fragment.
        empty = ''
        yield empty
class PositiveLookahead(Node):
    """
    The ``(?=...)`` positive lookahead.

    TODO: the text that follows should be constrained to match ``node``
    (in effect this is part of the surrounding sequence). For now it
    generates nothing and the constraint is ignored.
    """

    def __init__(self, node: 'Node') -> None:
        self.text = ''  # not read anywhere in this module
        self.node = node

    def __eq__(self, other: 'PositiveLookahead') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.node == other.node

    def __repr__(self) -> str:
        return "{}({!r})".format(type(self).__name__, self.node)

    def generate(self) -> Iterator[str]:
        empty = ''
        yield empty
class NegativeLookahead(Node):
    """
    The ``(?!...)`` negative lookahead.

    TODO: the text that follows should be constrained NOT to match ``node``
    (in effect this is part of the surrounding sequence, which is hard to
    honour while generating). For now it generates nothing and the
    constraint is ignored.
    """

    def __init__(self, node: 'Node') -> None:
        self.text = ''  # not read anywhere in this module
        self.node = node

    def __eq__(self, other: 'NegativeLookahead') -> bool:
        if self.__class__ != other.__class__:
            return False
        return self.node == other.node

    def __repr__(self) -> str:
        return "{}({!r})".format(type(self).__name__, self.node)

    def generate(self) -> Iterator[str]:
        empty = ''
        yield empty
# Parser
# ==========
class Source:
    """
    Cursor over the characters of a regular-expression string.

    ``pos`` indexes the next unread character. :meth:`unget` gives the
    one-character lookahead needed inside ()'s. Reading past the end never
    raises: :meth:`get` and :meth:`match_any` return '' there instead.
    """

    def __init__(self, text: str) -> None:
        self.text = text
        self.pos = 0

    def get(self) -> str:
        """Consume and return the next character, or '' at end of text."""
        try:
            character = self.text[self.pos]
            self.pos += 1
            return character
        except IndexError:
            return ''

    def unget(self, character: str) -> None:
        """Push back *character*, which must be the text just consumed."""
        self.pos -= len(character)
        assert self.text.startswith(character, self.pos)

    def match(self, match_text: str) -> bool:
        """Consume *match_text* if the input continues with it; report whether it did."""
        if self.text.startswith(match_text, self.pos):
            self.pos += len(match_text)
            return True
        return False

    def match_any(self, match_text: str) -> str:
        """
        Consume and return the next character if it is one of *match_text*.

        Otherwise, including at end of text, consume nothing and return ''.
        """
        # Bounds check first: indexing at end of text used to raise IndexError
        # (e.g. for a pattern ending in "{3").
        if self.pos < len(self.text) and self.text[self.pos] in match_text:
            character = self.text[self.pos]
            self.pos += 1
            return character
        return ''

    def expect(self, match_text: str) -> None:
        """Consume *match_text*, or raise SyntaxError if it is not next."""
        if not self.match(match_text):
            raise SyntaxError(f"Missing {match_text}")
def parse_pattern(source: Source) -> Node:
    """
    Entry-point for pattern parsing.

    Parse one sequence, then one more for each ``|`` that follows.
    Returns the lone sequence when there is only one; otherwise returns an
    Alternation of all of them.
    """
    alternatives: List[Node] = [parse_sequence(source)]
    while source.match("|"):
        alternatives.append(parse_sequence(source))
    return alternatives[0] if len(alternatives) == 1 else Alternation(*alternatives)
def parse_sequence(source: Source) -> Node:
    """
    Parse a run of RE items up to end of text, a ``)`` or a ``|``.

    The closing ``)`` is consumed because it ends the enclosing group;
    parse_paren relies on this and does not expect a ``)`` itself.
    A ``|`` is pushed back instead, so that parse_pattern sees it and parses
    the next alternative. Consuming it here would silently drop every
    alternative after the first.

    :param source: the RE text being parsed.
    :returns: a Sequence of the parsed items, possibly empty.
    :raises SyntaxError: if a quantifier has nothing before it to repeat.
    """
    pattern: List[Node] = []
    while True:
        c = source.get()
        if c == '':
            # End of the RE text.
            break
        elif c == ')':
            # End of the enclosing ()'s.
            break
        elif c == '|':
            # Between alternatives: leave the | for parse_pattern.
            source.unget(c)
            break
        elif c == '\\':
            # Escaped character, or a predefined set such as \d.
            pe = parse_escape(source)
            if isinstance(pe, Character):
                pattern.append(Literal(pe))
            else:
                pattern.append(pe)
        elif c == '(':
            # ()'d grouping
            pattern.append(parse_paren(source))
        elif c == '.':
            # the '.' set
            pattern.append(SetAll())
        elif c == '[':
            # A general set
            pattern.append(parse_set(source))
        elif c == '^':
            # The start-of-line marker
            pattern.append(Start(Character(c)))
        elif c == '$':
            # The end-of-line marker
            pattern.append(End(Character(c)))
        elif c in ('?', '*', '+', '{'):
            # A quantifier applies to the item just parsed.
            if not pattern:
                raise SyntaxError(f"Nothing to repeat before {c}")
            source.unget(c)
            suffix = parse_quantifier(source)
            pattern[-1] = Repeated(pattern[-1], suffix)
        else:
            pattern.append(Literal(Character(c)))
    return Sequence(*pattern)
def parse_escape(source: Source) -> Node:
    """
    Parse the character after a backslash.

    ``\\d \\D \\s \\S \\w \\W`` become the predefined sets. ``\\a \\b \\f \\n
    \\r \\t \\v`` become the matching control character, so ``\\b`` means
    backspace, not a word boundary. Any other character is taken literally.
    At end of text the result is ``Character('')``.
    """
    escaped = source.get()
    set_factories = {
        'd': SetDigit,
        'D': SetNonDigit,
        's': SetWhitespace,
        'S': SetNonWhitespace,
        'w': SetWord,
        'W': SetNonWord,
    }
    if escaped in set_factories:
        return set_factories[escaped]()
    control_chars = {
        "a": "\a",
        "b": "\b",
        "f": "\f",
        "n": "\n",
        "r": "\r",
        "t": "\t",
        "v": "\v",
    }
    return Character(control_chars.get(escaped, escaped))
def parse_paren(source: Source) -> Node:
    """
    Parse the rest of a parenthesised construct; the ``(`` is already consumed.

    Supported forms are plain ``(...)`` groups, ``(?:...)`` non-capturing
    groups, and the lookaheads ``(?=...)`` and ``(?!...)``. Generation does
    not distinguish capturing from non-capturing, so both become a Group.
    The closing ``)`` is consumed by parse_sequence when it ends the
    sub-pattern, so it is not expected again here.

    :raises SyntaxError: for any other ``(?...)`` extension.
    """
    c = source.get()
    if c == '?':
        c2 = source.get()
        if c2 in ('!', '='):
            # Lookahead
            return parse_lookahead(source, c2)
        if c2 == ':':
            # Non-capturing group: generates exactly like a plain group.
            return Group(parse_pattern(source))
        # Unsupported (?...) syntax
        raise SyntaxError("Unsupported (?...)")
    source.unget(c)
    return Group(parse_pattern(source))
def parse_lookahead(source: Source, positive: str) -> Node:
    """
    Parse the body of ``(?=...)`` or ``(?!...)``.

    ``(?`` and the sense character are already consumed.

    :param positive: '=' for a positive lookahead, '!' for a negative one.
    :raises SyntaxError: for any other sense character.
    """
    # The closing ")" was already consumed by parse_sequence when
    # parse_pattern returned (just as for plain groups in parse_paren).
    # Expecting another ")" here made every well-formed lookahead fail.
    sub_pattern = parse_pattern(source)
    if positive == '=':
        return PositiveLookahead(sub_pattern)
    if positive == '!':
        return NegativeLookahead(sub_pattern)
    raise SyntaxError("Unsupported (?...)")
def parse_set(source: Source) -> Node:
    """
    Parse ``[...]`` or ``[^...]``; the ``[`` is already consumed.

    The first member is read before any ``]`` check, so a ``]`` right after
    ``[`` or ``[^`` is a literal member. Escapes such as ``\\]`` are handled
    by parse_set_member.

    :returns: a PatternSet with the matching Sense.
    :raises Exception: if the text ends right after ``[`` / ``[^``.
    :raises SyntaxError: if the text ends before the closing ``]``.
    """
    sense = Sense.NEGATIVE if source.match("^") else Sense.POSITIVE
    items: List[Node] = []
    members = parse_set_member(source)
    if members is None:
        raise Exception("Invalid set")
    items.extend(members)
    while not source.match("]"):
        members = parse_set_member(source)
        if members is None:
            # End of text before the closing ]: the set is unterminated.
            raise SyntaxError("Missing ]")
        items.extend(members)
    return PatternSet(*items, sense=sense)
def parse_set_member(
    source: Source
) -> Union[List[Character], List[PatternSet], List[Range], List[Node], None]:
    """
    Parse one set member.

    A member is ``x``, ``x-y``, an escape such as ``\\d``, or ``x-``
    immediately before the closing ``]`` (a literal ``-``).

    :returns: a list of the parsed items, or None at end of text.
    :raises Exception: for a range whose ends are not single characters, a
        range left unfinished at end of text, or a reversed range such as ``z-a``.
    """
    start = parse_set_item(source)
    if isinstance(start, Character) and start.char == '':
        return None
    if not source.match("-"):
        return [start]
    if source.match("]"):
        # Special trailing "-" case: push the ] back for parse_set.
        source.unget("]")
        if isinstance(start, Character):
            return [start, Character("-")]
        raise Exception(f"Bad range start {start!r}")
    end = parse_set_item(source)
    # end.char is only read when both ends are Characters (short-circuit).
    if not (isinstance(start, Character) and isinstance(end, Character)) or end.char == '':
        # A missing end (text ran out) would later crash in ord('').
        raise Exception(f"Bad ranges {start!r}-{end!r}")
    if ord(start.char) > ord(end.char):
        raise Exception(f"Bad ranges {start!r}-{end!r}: start is after end")
    return [Range(start, end)]
def parse_set_item(source: Source) -> Node:
    """
    Read one item inside []'s.

    A backslash starts an escape, handled by parse_escape. Anything else is
    a single literal character. Returns ``Character('')`` at end of text.
    """
    c = source.get()
    if c == '\\':
        return parse_escape(source)
    return Character(c)
def parse_quantifier(source: Source) -> Quantifier:
    """
    Parse a quantifier: ``*``, ``?``, ``+`` or a braced ``{...}`` count.

    Braced counts are read as digit strings (either side of the comma may be
    empty) and handed to Quantifier for validation. A trailing ``?`` marks
    the quantifier as lazy.
    """
    kind = source.get()
    if kind != "{":
        quantifier = Quantifier(kind)
    else:
        low = parse_count(source)
        if source.match(","):
            quantifier = Quantifier("{", low, parse_count(source))
        else:
            quantifier = Quantifier("{", low)
        source.expect("}")
    if source.match("?"):
        quantifier.greedy = False
    return quantifier
def parse_count(source: Source) -> str:
    """Consume a run of decimal digits and return it ('' if none are next)."""
    digits: List[str] = []
    while True:
        d = source.match_any(string.digits)
        if not d:
            break
        digits.append(d)
    return ''.join(digits)
@lru_cache(maxsize=None)
def compile(text: str) -> Node:
    """
    Parse RE *text* into its AST.

    Results are memoised without bound, so repeated calls with the same text
    return the same shared tree. (This name shadows the built-in compile
    within this module.)
    """
    source = Source(text)
    return parse_pattern(source)
# End of module.