Last active
July 7, 2019 21:46
-
-
Save omaskery/298b5364aa90db6ababe3371635ffc73 to your computer and use it in GitHub Desktop.
Playing with writing a parser combinator in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package uk.co.maskery.parsercombinator; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Optional; | |
import java.util.function.Function; | |
import java.util.stream.Collectors; | |
import java.util.stream.Stream; | |
/**
 * Unchecked exception thrown by a parser when the input does not match what it expects,
 * including running out of tokens.
 */
class ParserException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * Creates an exception with a human-readable description of the mismatch.
     *
     * @param message what was expected and what was found
     */
    ParserException(String message) {
        super(message);
    }

    /**
     * Creates an exception that wraps an underlying failure, preserving it as the cause.
     *
     * @param message what was expected and what was found
     * @param cause the lower-level failure that triggered this one
     */
    ParserException(String message, Throwable cause) {
        super(message, cause);
    }
}
class TokenStream<T> { | |
private final List<T> tokens; | |
private int index; | |
TokenStream(Stream<T> tokens) { | |
this.tokens = tokens.collect(Collectors.toList()); | |
this.index = 0; | |
} | |
boolean isEnd() { | |
return index >= tokens.size(); | |
} | |
T get() { | |
if (isEnd()) { | |
throw new ParserException("unexpected end of stream"); | |
} | |
var result = tokens.get(index); | |
index++; | |
return result; | |
} | |
Optional<T> peek() { | |
return isEnd() ? Optional.empty() : Optional.of(tokens.get(index)); | |
} | |
<U> Optional<U> withBacktrack(Function<TokenStream<T>, U> action) { | |
var startingIndex = index; | |
try { | |
return Optional.of(action.apply(this)); | |
} catch (ParserException exc) { | |
index = startingIndex; | |
return Optional.empty(); | |
} | |
} | |
} | |
public class ParserCombinator { | |
private static <T> Function<TokenStream<T>, T> literal(T value) { | |
return stream -> { | |
if (stream.isEnd()) { | |
throw new ParserException(String.format("unexpected end of stream, expected literal '%s'", value)); | |
} | |
var next = stream.get(); | |
if (!next.equals(value)) { | |
throw new ParserException(String.format("expected literal '%s', got '%s'", value, next)); | |
} | |
return next; | |
}; | |
} | |
@SafeVarargs | |
private static <T> Function<TokenStream<T>, List<?>> sequence(Function<TokenStream<T>, ?>... parsers) { | |
return stream -> { | |
var result = new ArrayList<>(); | |
for (var parser : parsers) { | |
result.add(parser.apply(stream)); | |
} | |
return result; | |
}; | |
} | |
@SafeVarargs | |
private static <T> Function<TokenStream<T>, ?> one_of(Function<TokenStream<T>, ?>... alternatives) { | |
return stream -> { | |
for (var parser : alternatives) { | |
var result = stream.withBacktrack(parser); | |
if (result.isPresent()) { | |
return result.get(); | |
} | |
} | |
throw new ParserException("none of the possible alternatives matched"); | |
}; | |
} | |
private static <T> Function<TokenStream<T>, T> any() { | |
return TokenStream::get; | |
} | |
private static <T, U> Function<TokenStream<T>, Optional<U>> optional(Function<TokenStream<T>, U> parser) { | |
return stream -> stream.withBacktrack(parser); | |
} | |
private static <T, U> Function<TokenStream<T>, List<U>> at_least_one(Function<TokenStream<T>, U> parser) { | |
return stream -> { | |
var result = new ArrayList<U>(); | |
result.add(parser.apply(stream)); | |
while (true) { | |
var aditional = stream.withBacktrack(parser); | |
aditional.ifPresent(result::add); | |
if (aditional.isEmpty()) { | |
break; | |
} | |
} | |
return result; | |
}; | |
} | |
private static <T> Function<TokenStream<T>, Void> end_of_input() { | |
return stream -> { | |
if (!stream.isEnd()) { | |
throw new ParserException(String.format("expected end of input, but got '%s'", stream.peek())); | |
} | |
return null; | |
}; | |
} | |
public static void main(String[] args) { | |
var option_parser = one_of( | |
literal("informs"), | |
literal("traps"), | |
sequence( | |
literal("version"), | |
one_of( | |
sequence( | |
literal("3"), | |
optional(one_of( | |
literal("auth"), | |
literal("noauth"), | |
literal("priv") | |
)) | |
), | |
any() | |
) | |
) | |
); | |
var notification_type_parser = one_of( | |
sequence(literal("udp-port"), any()), | |
any() | |
); | |
var snmp_server_host_parser = sequence( | |
literal("snmp-server"), | |
literal("host"), | |
any(), | |
optional(at_least_one(option_parser)), | |
any(), | |
optional(at_least_one(notification_type_parser)), | |
end_of_input() | |
); | |
var input = "snmp-server host some.host.name informs version 2c community-string udp-port 20 snmp"; | |
var tokens = new TokenStream<>(Stream.of(input.split(" "))); | |
var parsed = snmp_server_host_parser.apply(tokens); | |
System.out.println("input: " + input); | |
System.out.println("output: " + parsed); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass, field | |
import contextlib | |
import typing | |
class ParserError(Exception):
    """Raised by a parser when the input does not match what it expects."""


class EndOfStream(ParserError):
    """Raised by ``Stream.get`` when no tokens remain; a kind of ``ParserError``."""
class Stream:
    """Cursor over a list of tokens.

    ``index`` is the position of the next token to consume. It is a plain
    public attribute so callers can save and restore the position.
    """

    def __init__(self, tokens):
        self.index = 0
        self.tokens = tokens

    def is_end(self):
        """Return True once every token has been consumed."""
        return not self.index < len(self.tokens)

    def peek(self):
        """Return the next token without consuming it, or None at end of input."""
        return None if self.is_end() else self.tokens[self.index]

    def get(self):
        """Consume and return the next token.

        Raises:
            EndOfStream: if no tokens remain.
        """
        if self.is_end():
            raise EndOfStream()
        token = self.tokens[self.index]
        self.index += 1
        return token
@contextlib.contextmanager | |
def backtrack(stream): | |
index = stream.index | |
try: | |
yield | |
except ParserError: | |
stream.index = index | |
raise | |
def literal(value):
    """Build a parser that consumes one token and requires it to equal ``value``.

    The returned parser returns the matched token.

    Raises (from the returned parser):
        ParserError: on a different token, or at end of stream. The
            end-of-stream case chains the original ``EndOfStream`` as its
            ``__cause__``.
    """
    def _parser(stream):
        # Only the read can hit end of stream; keep the try that narrow.
        try:
            token = stream.get()
        except EndOfStream as exc:
            raise ParserError(f"unexpected end of stream, expected '{value}'") from exc
        if token != value:
            raise ParserError(f"unexpected token '{token}', expected '{value}'")
        return token
    return _parser
def one_of(*parsers):
    """Build a parser that tries each of ``parsers`` in order.

    Each attempt runs under ``backtrack``, so a failed alternative leaves the
    stream where it started. The first successful result is returned.

    Raises (from the returned parser):
        ParserError: if every alternative fails.
    """
    def _first_match(stream):
        for candidate in parsers:
            try:
                with backtrack(stream):
                    return candidate(stream)
            except ParserError:
                continue
        raise ParserError("no alternatives matched")
    return _first_match
def sequence(*parsers):
    """Build a parser that runs ``parsers`` one after another.

    Returns (from the returned parser) a list of each parser's result, in
    order. Any failure propagates immediately; this combinator does not
    rewind by itself.
    """
    def _all_in_order(stream):
        results = []
        for parser in parsers:
            results.append(parser(stream))
        return results
    return _all_in_order
def optional(parser):
    """Build a parser that makes ``parser`` optional.

    On success the wrapped parser's result is returned. On ``ParserError``
    the stream is rewound (via ``backtrack``) and ``None`` is returned.
    """
    def _maybe(stream):
        try:
            with backtrack(stream):
                outcome = parser(stream)
        except ParserError:
            return None
        return outcome
    return _maybe
def any_token():
    """Build a parser that consumes and returns the next token, whatever it is.

    Fails only at end of input, with whatever ``stream.get()`` raises.
    """
    def _take_next(stream):
        return stream.get()
    return _take_next
def reaches_end_of_input(parser):
    """Build a parser that runs ``parser`` and then requires the stream to be exhausted.

    Returns (from the returned parser) ``parser``'s result.

    Raises (from the returned parser):
        ParserError: naming the first leftover token if input remains.
    """
    def _whole_input(stream):
        result = parser(stream)
        if stream.is_end():
            return result
        raise ParserError(f"unexpected token '{stream.peek()}', expected end of stream")
    return _whole_input
def at_least_one(parser):
    """Build a parser that applies ``parser`` one or more times.

    The first application must succeed; its errors propagate. After that,
    ``parser`` is retried under ``backtrack`` until it fails, the input runs
    out, or it succeeds without consuming any tokens. The last condition
    would otherwise repeat forever; that zero-width result is discarded.

    Returns (from the returned parser) a non-empty list of results.
    """
    def _parser(stream):
        results = [parser(stream)]
        while not stream.is_end():
            start = stream.index
            try:
                with backtrack(stream):
                    additional = parser(stream)
            except ParserError:
                break
            if stream.index == start:
                # No progress: repeating the same parser here would never terminate.
                break
            results.append(additional)
        return results
    return _parser
@dataclass(frozen=True)
class SnmpVersion:
    """A ``version`` option from an ``snmp-server host`` line.

    ``parse_snmp_server_host`` only fills ``auth`` after version ``"3"``, with
    one of ``"auth"``, ``"noauth"`` or ``"priv"``; otherwise it is ``None``.
    """
    # version token as written on the line, e.g. "1", "2c", "3"
    version: str
    # optional security-level keyword that may follow version "3"
    auth: typing.Optional[str] = None
# One entry of SnmpServerHostCommand.options: a bare keyword such as
# "informs"/"traps", or a parsed ``version`` option.
SnmpHostOption = typing.Union[str, SnmpVersion]
@dataclass(frozen=True)
class SnmpServerHostCommand:
    """Structured form of an ``snmp-server host`` configuration line.

    Built by ``parse_snmp_server_host``. Being a dataclass, instances compare
    field-by-field, which is how ``main`` checks expected results.
    """
    # token immediately after "snmp-server host"
    hostname: str
    # first token after the (optional) options
    community_string: str
    # "informs"/"traps" keywords and SnmpVersion entries, in input order
    options: typing.List[SnmpHostOption] = field(default_factory=list)
    # port from the first "udp-port N" pair, or None when absent
    udp_port: typing.Optional[int] = None
    # trailing tokens that were not part of a "udp-port N" pair
    notification_types: typing.List[str] = field(default_factory=list)
def _build_snmp_server_host_parser():
    """Build the combinator grammar for one ``snmp-server host`` line.

    The resulting parser yields a 6-element list:
    ``['snmp-server', 'host', hostname, options-or-None, community_string,
    trailing-entries-or-None]``.
    """
    option_parser = one_of(
        literal("informs"),
        literal("traps"),
        sequence(
            literal("version"),
            one_of(
                # Version 3 may be followed by a security-level keyword.
                sequence(
                    literal("3"),
                    optional(one_of(
                        literal("auth"),
                        literal("noauth"),
                        literal("priv"),
                    )),
                ),
                any_token(),
            ),
        ),
    )
    # Trailing entries: "udp-port N" parses as a 2-element list, anything
    # else as a bare token string.
    notification_type_parser = one_of(
        sequence(literal("udp-port"), any_token()),
        any_token(),
    )
    return reaches_end_of_input(sequence(
        literal("snmp-server"),
        literal("host"),
        any_token(),
        optional(at_least_one(option_parser)),
        any_token(),
        optional(at_least_one(notification_type_parser)),
    ))


def _to_snmp_option(option):
    """Convert one parsed option into a ``SnmpHostOption``.

    ``informs``/``traps`` arrive as plain strings. ``version`` arrives as
    ``['version', v]``, where ``v`` is either a version string or, for
    version 3, ``['3', security_level_or_None]``.
    """
    if isinstance(option, str):
        return option
    _, version = option
    if isinstance(version, str):
        return SnmpVersion(version, None)
    return SnmpVersion(*version)


def parse_snmp_server_host(config_line):
    """Parse an ``snmp-server host`` configuration line.

    Args:
        config_line: the line; tokens are separated by whitespace.

    Returns:
        The parsed ``SnmpServerHostCommand``. Only the first ``udp-port`` pair
        is kept as ``udp_port``.

    Raises:
        ParserError: if the line does not match the grammar.
        ValueError: if a ``udp-port`` value is not an integer.
    """
    stream = Stream(config_line.split())
    parsed = _build_snmp_server_host_parser()(stream)
    _, _, hostname, raw_options, community_string, raw_extras = parsed
    raw_extras = raw_extras or []

    options = [_to_snmp_option(option) for option in raw_options or []]
    udp_ports = [
        int(entry[1])
        for entry in raw_extras
        if isinstance(entry, list) and entry[0] == "udp-port"
    ]
    notification_types = [entry for entry in raw_extras if isinstance(entry, str)]

    return SnmpServerHostCommand(
        hostname=hostname,
        options=options,
        community_string=community_string,
        udp_port=udp_ports[0] if udp_ports else None,
        notification_types=notification_types,
    )
def main():
    """Run the built-in ``parse_snmp_server_host`` test cases and report results.

    A case passes when the parsed command equals the expected one. A line
    rejected with ``ParserError`` is reported and counted as a failure
    instead of aborting the remaining cases.

    Returns:
        The number of failed cases (0 when everything passed).
    """
    test_cases = [
        (
            "snmp-server host some.host.name wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow'
            )
        ),
        (
            "snmp-server host some.host.name wow udp-port 40",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                udp_port=40
            )
        ),
        (
            "snmp-server host some.host.name wow udp-port 40 snmp",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                udp_port=40,
                notification_types=['snmp']
            )
        ),
        (
            "snmp-server host some.host.name wow snmp udp-port 40",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                udp_port=40,
                notification_types=['snmp']
            )
        ),
        (
            "snmp-server host some.host.name informs wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=['informs']
            )
        ),
        (
            "snmp-server host some.host.name traps wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=['traps']
            )
        ),
        (
            "snmp-server host some.host.name informs traps wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=['informs', 'traps']
            )
        ),
        (
            "snmp-server host some.host.name version 1 wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('1')]
            )
        ),
        (
            "snmp-server host some.host.name version 2 wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('2')]
            )
        ),
        (
            "snmp-server host some.host.name version 2c wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('2c')]
            )
        ),
        (
            "snmp-server host some.host.name version 3 wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('3')]
            )
        ),
        (
            "snmp-server host some.host.name version 3 auth wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('3', 'auth')]
            )
        ),
        (
            "snmp-server host some.host.name version 3 noauth wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('3', 'noauth')]
            )
        ),
        (
            "snmp-server host some.host.name version 3 priv wow",
            SnmpServerHostCommand(
                hostname='some.host.name',
                community_string='wow',
                options=[SnmpVersion('3', 'priv')]
            )
        ),
        (
            "snmp-server host some.host.name informs version 2c community-string udp-port 20 snmp",
            SnmpServerHostCommand(
                hostname="some.host.name",
                options=['informs', SnmpVersion('2c')],
                community_string='community-string',
                udp_port=20,
                notification_types=['snmp']
            )
        )
    ]
    failures = 0
    for config_line, expected_result in test_cases:
        print(f"parsing '{config_line}'")
        try:
            actual_result = parse_snmp_server_host(config_line)
        except ParserError as exc:
            # A rejected line is a failed case, not a reason to stop the run.
            print(f" error: {exc}")
            success = False
        else:
            success = actual_result == expected_result
            print(f" result: {actual_result}")
        print(f" {'success' if success else 'failure'}")
        if not success:
            print(f" expected: {expected_result}")
            failures += 1
    print(f"{len(test_cases)} tests, {failures} failures")
    return failures
if __name__ == "__main__":
    # Exit with status 1 when main() returns a truthy value (reported failures), else 0.
    raise SystemExit(1 if main() else 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment