Skip to content

Instantly share code, notes, and snippets.

@omaskery
Last active July 7, 2019 21:46
Show Gist options
  • Save omaskery/298b5364aa90db6ababe3371635ffc73 to your computer and use it in GitHub Desktop.
Save omaskery/298b5364aa90db6ababe3371635ffc73 to your computer and use it in GitHub Desktop.
Playing with writing a parser combinator in Python, with a Java version alongside
package uk.co.maskery.parsercombinator;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Unchecked exception thrown when a parser cannot match the token stream.
 *
 * <p>{@code TokenStream.withBacktrack} catches this type to rewind the stream,
 * so it doubles as the "this alternative did not match" signal.
 */
class ParserException extends RuntimeException {

    /**
     * Creates a parse failure carrying a human-readable explanation.
     *
     * @param message description of what went wrong
     */
    ParserException(String message) {
        super(message);
    }
}
/**
 * A forward cursor over a fixed list of tokens, with support for rewinding
 * after a failed parse attempt.
 *
 * @param <T> the token type
 */
class TokenStream<T> {
    private final List<T> tokens;
    private int index;

    /**
     * Drains the given stream into an internal list and positions the cursor
     * on the first token.
     *
     * @param tokens source of tokens; fully consumed by this constructor
     */
    TokenStream(Stream<T> tokens) {
        this.tokens = tokens.collect(Collectors.toList());
        this.index = 0;
    }

    /**
     * Reports whether every token has been consumed.
     *
     * @return {@code true} when the cursor is at or past the last token
     */
    boolean isEnd() {
        return index >= tokens.size();
    }

    /**
     * Consumes and returns the token under the cursor.
     *
     * @return the token that was at the cursor before it advanced
     * @throws ParserException if no tokens remain
     */
    T get() {
        if (isEnd()) {
            throw new ParserException("unexpected end of stream");
        }
        // isEnd() was false, so the index is in range and the lookup cannot fail.
        return tokens.get(index++);
    }

    /**
     * Returns the token under the cursor without consuming it.
     *
     * @return the next token, or an empty {@code Optional} at end of stream
     */
    Optional<T> peek() {
        if (isEnd()) {
            return Optional.empty();
        }
        return Optional.of(tokens.get(index));
    }

    /**
     * Runs {@code action} against this stream; if it throws a
     * {@link ParserException} the cursor is restored to where it was before
     * the call and an empty result is returned.
     *
     * <p>Note: the result is wrapped with {@code Optional.of}, so an action
     * that succeeds with {@code null} causes a {@code NullPointerException}.
     *
     * @param action the parse step to attempt
     * @param <U> the type produced by the action
     * @return the action's result, or empty if it failed with a {@code ParserException}
     */
    <U> Optional<U> withBacktrack(Function<TokenStream<T>, U> action) {
        final int checkpoint = index;
        try {
            U outcome = action.apply(this);
            return Optional.of(outcome);
        } catch (ParserException failure) {
            // Undo whatever the failed attempt consumed.
            index = checkpoint;
            return Optional.empty();
        }
    }
}
/**
 * A small parser-combinator toolkit plus a demo grammar.
 *
 * <p>Every parser is a {@code Function} from a {@link TokenStream} to a result.
 * A parser reports failure by throwing {@link ParserException}; combinators
 * that try alternatives rely on {@code TokenStream.withBacktrack} to rewind
 * the stream after such a failure.
 */
public class ParserCombinator {

    /**
     * Builds a parser that consumes one token and requires it to equal {@code value}.
     *
     * @param value the token to match, compared with {@code equals}
     * @param <T> the token type
     * @return a parser yielding the matched token
     */
    private static <T> Function<TokenStream<T>, T> literal(T value) {
        return stream -> {
            if (stream.isEnd()) {
                throw new ParserException(
                        String.format("unexpected end of stream, expected literal '%s'", value));
            }
            var next = stream.get();
            if (!next.equals(value)) {
                throw new ParserException(
                        String.format("expected literal '%s', got '%s'", value, next));
            }
            return next;
        };
    }

    /**
     * Builds a parser that runs every given parser in order and collects their results.
     *
     * <p>A failure in any step propagates as a {@link ParserException}; this
     * combinator does not rewind the stream itself.
     *
     * @param parsers the steps to run, in order
     * @param <T> the token type
     * @return a parser yielding one list entry per step
     */
    @SafeVarargs
    private static <T> Function<TokenStream<T>, List<?>> sequence(
            Function<TokenStream<T>, ?>... parsers) {
        return stream -> {
            List<Object> results = new ArrayList<>(parsers.length);
            for (var parser : parsers) {
                results.add(parser.apply(stream));
            }
            return results;
        };
    }

    /**
     * Builds a parser that tries each alternative in order and yields the
     * result of the first one that succeeds. The stream is rewound between
     * failed attempts.
     *
     * @param alternatives candidate parsers, tried first to last
     * @param <T> the token type
     * @return a parser yielding the first successful alternative's result
     */
    @SafeVarargs
    private static <T> Function<TokenStream<T>, ?> oneOf(
            Function<TokenStream<T>, ?>... alternatives) {
        return stream -> {
            for (var parser : alternatives) {
                var result = stream.withBacktrack(parser);
                if (result.isPresent()) {
                    return result.get();
                }
            }
            throw new ParserException("none of the possible alternatives matched");
        };
    }

    /**
     * Builds a parser that consumes and yields whatever token comes next.
     *
     * @param <T> the token type
     * @return a parser that fails only at end of stream
     */
    private static <T> Function<TokenStream<T>, T> any() {
        return TokenStream::get;
    }

    /**
     * Builds a parser that attempts {@code parser} and yields an empty
     * {@code Optional}, with the stream rewound, when it fails.
     *
     * @param parser the parser to attempt
     * @param <T> the token type
     * @param <U> the wrapped parser's result type
     * @return a parser that never fails with {@link ParserException}
     */
    private static <T, U> Function<TokenStream<T>, Optional<U>> optional(
            Function<TokenStream<T>, U> parser) {
        return stream -> stream.withBacktrack(parser);
    }

    /**
     * Builds a parser that requires one match of {@code parser} and then keeps
     * matching until the first failure.
     *
     * <p>Note: repetition stops only when {@code parser} fails, so a parser
     * that can succeed without consuming a token (for example one built with
     * {@link #optional}) makes this loop run forever.
     *
     * @param parser the parser to repeat
     * @param <T> the token type
     * @param <U> the element type
     * @return a parser yielding a list with at least one element
     */
    private static <T, U> Function<TokenStream<T>, List<U>> atLeastOne(
            Function<TokenStream<T>, U> parser) {
        return stream -> {
            List<U> results = new ArrayList<>();
            // The first match is mandatory, so let its failure propagate.
            results.add(parser.apply(stream));
            while (true) {
                var additional = stream.withBacktrack(parser);
                if (additional.isEmpty()) {
                    break;
                }
                results.add(additional.get());
            }
            return results;
        };
    }

    /**
     * Builds a parser that succeeds, yielding {@code null}, only when no tokens remain.
     *
     * @param <T> the token type
     * @return a parser that throws {@link ParserException} if tokens are left over
     */
    private static <T> Function<TokenStream<T>, Void> endOfInput() {
        return stream -> {
            if (!stream.isEnd()) {
                // Unwrap the Optional so the message shows the token itself
                // rather than "Optional[token]".
                throw new ParserException(String.format(
                        "expected end of input, but got '%s'", stream.peek().orElseThrow()));
            }
            return null;
        };
    }

    /**
     * Parses a sample {@code snmp-server host} line and prints the input and
     * the raw parse result.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        var optionParser = oneOf(
                literal("informs"),
                literal("traps"),
                sequence(
                        literal("version"),
                        oneOf(
                                sequence(
                                        literal("3"),
                                        optional(oneOf(
                                                literal("auth"),
                                                literal("noauth"),
                                                literal("priv")
                                        ))
                                ),
                                any()
                        )
                )
        );
        var notificationTypeParser = oneOf(
                sequence(literal("udp-port"), any()),
                any()
        );
        var snmpServerHostParser = sequence(
                literal("snmp-server"),
                literal("host"),
                any(),
                optional(atLeastOne(optionParser)),
                any(),
                optional(atLeastOne(notificationTypeParser)),
                endOfInput()
        );
        var input = "snmp-server host some.host.name informs version 2c community-string udp-port 20 snmp";
        var tokens = new TokenStream<>(Stream.of(input.split(" ")));
        var parsed = snmpServerHostParser.apply(tokens);
        System.out.println("input: " + input);
        System.out.println("output: " + parsed);
    }
}
from dataclasses import dataclass, field
import contextlib
import typing
class ParserError(Exception):
    """Raised when a parser cannot match the tokens it was given."""
class EndOfStream(ParserError):
    """Raised by ``Stream.get`` when there are no tokens left to consume."""
class Stream:
    """A cursor over an indexable collection of tokens.

    ``index`` is the position of the next token to be consumed; other code
    saves and restores it to rewind the stream.
    """

    def __init__(self, tokens):
        """Wrap ``tokens`` and place the cursor on the first one."""
        self.tokens = tokens
        self.index = 0

    def is_end(self):
        """Return True once every token has been consumed."""
        return self.index >= len(self.tokens)

    def peek(self):
        """Return the next token without consuming it, or None at the end."""
        if self.index >= len(self.tokens):
            return None
        return self.tokens[self.index]

    def get(self):
        """Consume and return the next token.

        Raises:
            EndOfStream: if no tokens remain.
        """
        if self.index >= len(self.tokens):
            raise EndOfStream()
        token = self.tokens[self.index]
        self.index += 1
        return token
@contextlib.contextmanager
def backtrack(stream):
    """Context manager that rewinds ``stream`` when the body fails to parse.

    The stream's ``index`` is recorded on entry. If the body raises a
    ``ParserError`` the index is restored and the same error is re-raised;
    any other outcome leaves the stream where the body left it.
    """
    checkpoint = stream.index
    try:
        yield
    except ParserError:
        # Undo whatever the failed attempt consumed, then let the caller
        # decide how to handle the failure.
        stream.index = checkpoint
        raise
def literal(value):
    """Build a parser that consumes one token and requires it to equal ``value``.

    The returned parser yields the matched token and raises ``ParserError``
    on a mismatch or when the stream is already exhausted.
    """
    def _parser(stream):
        try:
            token = stream.get()
        except EndOfStream:
            raise ParserError(f"unexpected end of stream, expected '{value}'")
        else:
            if token != value:
                raise ParserError(f"unexpected token '{token}', expected '{value}'")
            return token
    return _parser
def one_of(*parsers):
    """Build a parser that yields the result of the first alternative that matches.

    Alternatives are tried in the order given; the stream is rewound after
    each failed attempt. Raises ``ParserError`` when none of them match.
    """
    def _parser(stream):
        for candidate in parsers:
            try:
                with backtrack(stream):
                    outcome = candidate(stream)
            except ParserError:
                # This alternative did not match; try the next one.
                continue
            return outcome
        raise ParserError("no alternatives matched")
    return _parser
def sequence(*parsers):
    """Build a parser that runs each given parser in order.

    The returned parser yields a list with one result per step. An error
    raised by any step propagates unchanged.
    """
    def _parser(stream):
        results = []
        for step in parsers:
            results.append(step(stream))
        return results
    return _parser
def optional(parser):
    """Build a parser that attempts ``parser`` and yields None if it fails.

    On failure the stream is rewound to where the attempt started.
    """
    def _parser(stream):
        try:
            with backtrack(stream):
                outcome = parser(stream)
        except ParserError:
            outcome = None
        return outcome
    return _parser
def any_token():
    """Build a parser that consumes and yields whatever token comes next.

    Any error raised by the stream's ``get`` propagates to the caller.
    """
    def _parser(stream):
        token = stream.get()
        return token
    return _parser
def reaches_end_of_input(parser):
    """Build a parser that runs ``parser`` and then requires the stream to be exhausted.

    The wrapped parser's result is passed through; ``ParserError`` is raised
    if any tokens are left over.
    """
    def _parser(stream):
        outcome = parser(stream)
        if stream.is_end():
            return outcome
        raise ParserError(f"unexpected token '{stream.peek()}', expected end of stream")
    return _parser
def at_least_one(parser):
    """Build a parser that matches ``parser`` one or more times.

    The first match is mandatory, so its failure propagates. After that the
    parser is re-applied until it fails, the stream is exhausted, or it
    succeeds without consuming any token. A failed repetition is rewound and
    not treated as an error.

    Returns a parser yielding the list of collected results, with at least
    one element.
    """
    def _parser(stream):
        results = [parser(stream)]
        while not stream.is_end():
            checkpoint = stream.index
            try:
                additional = parser(stream)
            except ParserError:
                # Rewind the partial attempt and stop repeating.
                stream.index = checkpoint
                break
            if stream.index == checkpoint:
                # A zero-width success would match again forever at the same
                # position, so treat it as the end of the repetition.
                break
            results.append(additional)
        return results
    return _parser
@dataclass(frozen=True)
class SnmpVersion:
    """Immutable record of a parsed ``version`` option."""

    # The token that followed the "version" keyword, kept as a string (e.g. '2c').
    version: str
    # Optional second token stored alongside the version; None when absent.
    auth: typing.Optional[str] = None


# A host option is either a plain string (e.g. 'informs') or an SnmpVersion.
SnmpHostOption = typing.Union[str, SnmpVersion]
@dataclass(frozen=True)
class SnmpServerHostCommand:
    """Immutable, structured form of one ``snmp-server host`` line."""

    # Token that follows "snmp-server host".
    hostname: str
    # Token that follows the (possibly empty) run of options.
    community_string: str
    # Options in the order they appeared; empty when none were given.
    options: typing.List[SnmpHostOption] = field(default_factory=list)
    # Port from a "udp-port <n>" entry; None when the line has no such entry.
    # Annotated Optional because the default is None, not an int.
    udp_port: typing.Optional[int] = None
    # Remaining trailing tokens, in order; empty when none were given.
    notification_types: typing.List[str] = field(default_factory=list)
def parse_snmp_server_host(config_line):
    """Parse one ``snmp-server host`` line into an ``SnmpServerHostCommand``.

    The line is split on whitespace and must have the shape::

        snmp-server host <hostname> [option ...] <community> [udp-port <n> | <type> ...]

    where an option is ``informs``, ``traps`` or ``version <v>`` (version
    ``3`` may be followed by ``auth``, ``noauth`` or ``priv``).

    Raises:
        ParserError: if the tokens do not fit the grammar or tokens are left over.
        ValueError: if the token after ``udp-port`` is not an integer.
    """
    # --- grammar -----------------------------------------------------------
    version_three = sequence(
        literal("3"),
        optional(one_of(literal("auth"), literal("noauth"), literal("priv"))),
    )
    version_clause = sequence(literal("version"), one_of(version_three, any_token()))
    option_parser = one_of(literal("informs"), literal("traps"), version_clause)
    notification_type_parser = one_of(
        sequence(literal("udp-port"), any_token()),
        any_token(),
    )
    snmp_server_host_parser = reaches_end_of_input(sequence(
        literal("snmp-server"),
        literal("host"),
        any_token(),
        optional(at_least_one(option_parser)),
        any_token(),
        optional(at_least_one(notification_type_parser)),
    ))

    # --- parse -------------------------------------------------------------
    tokens = config_line.strip().split()
    _, _, hostname, raw_options, community_string, raw_trailing = (
        snmp_server_host_parser(Stream(tokens))
    )

    # --- convert raw parse results into the typed command -------------------
    def _to_option(raw):
        # Plain keywords come through as strings; a version clause is a
        # two-element list of ["version", <version-part>].
        if isinstance(raw, str):
            return raw
        _keyword, version = raw
        if isinstance(version, str):
            return SnmpVersion(version, None)
        # Otherwise the version part is the ["3", <auth-or-None>] pair.
        return SnmpVersion(*version)

    # The optional groups yield None when absent, so fall back to empty lists.
    options = [_to_option(raw) for raw in (raw_options or [])]

    udp_ports = []
    notification_types = []
    for entry in (raw_trailing or []):
        if isinstance(entry, str):
            notification_types.append(entry)
        elif len(entry) > 1 and entry[0] == 'udp-port':
            udp_ports.append(int(entry[1]))

    return SnmpServerHostCommand(
        hostname=hostname,
        options=options,
        community_string=community_string,
        # Only the first udp-port entry is kept.
        udp_port=udp_ports[0] if udp_ports else None,
        notification_types=notification_types,
    )
def main():
    """Run the built-in parser checks and print a per-case and overall report."""

    def _expected(**overrides):
        # Every case shares the same hostname and, unless overridden, the
        # same community string.
        fields = {"hostname": "some.host.name", "community_string": "wow"}
        fields.update(overrides)
        return SnmpServerHostCommand(**fields)

    test_cases = [
        # Mandatory parts only.
        ("snmp-server host some.host.name wow",
         _expected()),
        # Trailing udp-port and notification types, in either order.
        ("snmp-server host some.host.name wow udp-port 40",
         _expected(udp_port=40)),
        ("snmp-server host some.host.name wow udp-port 40 snmp",
         _expected(udp_port=40, notification_types=['snmp'])),
        ("snmp-server host some.host.name wow snmp udp-port 40",
         _expected(udp_port=40, notification_types=['snmp'])),
        # Keyword options.
        ("snmp-server host some.host.name informs wow",
         _expected(options=['informs'])),
        ("snmp-server host some.host.name traps wow",
         _expected(options=['traps'])),
        ("snmp-server host some.host.name informs traps wow",
         _expected(options=['informs', 'traps'])),
        # Version options.
        ("snmp-server host some.host.name version 1 wow",
         _expected(options=[SnmpVersion('1')])),
        ("snmp-server host some.host.name version 2 wow",
         _expected(options=[SnmpVersion('2')])),
        ("snmp-server host some.host.name version 2c wow",
         _expected(options=[SnmpVersion('2c')])),
        ("snmp-server host some.host.name version 3 wow",
         _expected(options=[SnmpVersion('3')])),
        ("snmp-server host some.host.name version 3 auth wow",
         _expected(options=[SnmpVersion('3', 'auth')])),
        ("snmp-server host some.host.name version 3 noauth wow",
         _expected(options=[SnmpVersion('3', 'noauth')])),
        ("snmp-server host some.host.name version 3 priv wow",
         _expected(options=[SnmpVersion('3', 'priv')])),
        # Everything combined.
        ("snmp-server host some.host.name informs version 2c community-string udp-port 20 snmp",
         _expected(
             options=['informs', SnmpVersion('2c')],
             community_string='community-string',
             udp_port=20,
             notification_types=['snmp'],
         )),
    ]

    failures = 0
    for config_line, expected_result in test_cases:
        actual_result = parse_snmp_server_host(config_line)
        passed = actual_result == expected_result
        verdict = 'success' if passed else 'failure'
        print(f"parsing '{config_line}'")
        print(f" result: {actual_result}")
        print(f" {verdict}")
        if passed:
            continue
        print(f" expected: {expected_result}")
        failures += 1
    print(f"{len(test_cases)} tests, {failures} failures")


# Run the checks only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment