Skip to content

Instantly share code, notes, and snippets.

@nerodono
Last active April 20, 2024 19:09
Show Gist options
  • Save nerodono/60b999913605ec4b805bbdf0a936f075 to your computer and use it in GitHub Desktop.
Save nerodono/60b999913605ec4b805bbdf0a936f075 to your computer and use it in GitHub Desktop.
I wrote that on my phone because I was bored
from dataclasses import dataclass
from typing import (
Self,
Callable,
Sequence,
TypeAlias,
TypeVar,
Protocol,
Any,
)
from pprint import pprint
from enum import IntEnum, auto
import itertools
import functools
import string
# Generic type variables used by the combinator signatures in this module.
T = TypeVar("T")
K = TypeVar("K")
R = TypeVar("R")
E = TypeVar("E")
class ParseFail(Exception):
    """Signal that a parser could not accept the input it was given."""
class Skip:
    """Data-less marker type for parser output that carries no token."""

    __slots__ = ()

    def __bool__(self) -> bool:
        # The marker is always truthy even though it holds no state.
        return True


# Single shared marker instance.
SKIP = Skip()
# An "alterer" takes a value and returns its replacement of the same type.
Alterer: TypeAlias = Callable[[T], T]
def try_alter(f: Alterer[T] | None, value: T) -> T:
    """Apply ``f`` to ``value``; return ``value`` unchanged if ``f`` is None."""
    if f is None:
        return value
    return f(value)
class IndentsStack:
    """Stack of indentation widths together with their running sum.

    ``push``, ``pop`` and ``alter`` return new instances rather than
    modifying the receiver.
    """

    stack: tuple[int, ...]
    total: int

    def __init__(
        self,
        stack: tuple[int, ...] = (),
        total: int | None = None,
    ) -> None:
        """Store ``stack``; ``total`` defaults to the sum of ``stack``."""
        if total is None:
            total = sum(stack)
        self.stack = stack
        self.total = total

    def alter(
        self,
        stack: Alterer[tuple[int, ...]] | None = None,
        total: Alterer[int] | None = None,
    ) -> 'IndentsStack':
        """Return a copy with each given alterer applied to its field."""
        new_stack = try_alter(stack, self.stack)
        new_total = try_alter(total, self.total)
        return IndentsStack(stack=new_stack, total=new_total)

    def peek(self) -> int | None:
        """Return the most recently pushed level, or None when empty."""
        return self.stack[-1] if self.stack else None

    def push(self, level: int) -> 'IndentsStack':
        """Return a new stack with ``level`` on top and added to the total."""
        return IndentsStack(
            stack=(*self.stack, level),
            total=self.total + level,
        )

    def pop(self) -> tuple[int, 'IndentsStack']:
        """Return the top level and a new stack without it.

        Raises IndexError when the stack is empty.
        """
        last = self.stack[-1]
        remaining = IndentsStack(
            stack=self.stack[:-1],
            total=self.total - last,
        )
        return (last, remaining)
@dataclass
class State:
start: bool
text: str
stack: IndentsStack
def alter(
self,
start: Alterer[bool] | None = None,
text: Alterer[str] | None = None,
stack: Alterer[IndentsStack] | None = None,
) -> 'State':
return State(
start=try_alter(start, self.start),
text=try_alter(text, self.text),
stack=try_alter(stack, self.stack)
)
@classmethod
def new(cls, text: str) -> Self:
return cls(True, text, IndentsStack())
@dataclass
class Number:
    """Token holding an integer value."""

    value: int
@dataclass
class Newline:
    """Token marking a line break; it has no fields."""
@dataclass
class Operator:
    """Token holding the text of an operator."""

    op: str
@dataclass
class Name:
    """Token holding the text of a name."""

    value: str
@dataclass
class Indentation:
    """Token for a change of indentation level."""

    # True when the level goes up (indent), False when it goes down (dedent).
    increase: bool
class Keyword(IntEnum):
    """Keyword kinds; ``auto()`` numbers them consecutively starting at 1."""

    IF = auto()
    ELSE = auto()
    ELIF = auto()
    DEF = auto()
    CLASS = auto()
# Union of the token kinds that can appear in tokenizer output.
# ``Name`` is part of the union: name parsing builds ``Name`` objects, and
# those are collected into the ``list[Token]`` that ``tokenize`` returns.
Token: TypeAlias = (
    Number
    | Name
    | Operator
    | Keyword
    | Indentation
    | Newline
)
# Kept for backward compatibility; the protocol below no longer uses it.
T_contra = TypeVar("T_contra", contravariant=True)
K_co = TypeVar("K_co", covariant=True)
# The source type is both consumed and returned, so it must be invariant.
_Src = TypeVar("_Src")


class Parser(Protocol[K_co, _Src]):
    """Callable that consumes part of a source and yields a result.

    Type parameters are ``Parser[Result, Source]``, which is the order used
    by the combinator signatures in this module (for example
    ``Parser[K, str]`` is a parser over ``str`` producing ``K``).

    Calling a parser returns ``(result, remaining_source)``.
    """

    def __call__(self, src: _Src, /) -> tuple[K_co, _Src]:
        ...
# A predicate maps a value of type T to a bool.
Predicate: TypeAlias = Callable[[T], bool]
# A function from a tokenizer State to a (result, next State) pair.
Stateful: TypeAlias = Callable[[State], tuple[T, State]]
def singleton(x: T) -> tuple[T]:
    """Wrap ``x`` in a one-element tuple."""
    wrapped = (x,)
    return wrapped
def singleton_of(parser: Parser[K, T]) -> Parser[tuple[K, ...], T]:
    """Adapt ``parser`` so its result comes back inside a one-element tuple.

    The remaining source returned by ``parser`` is passed through untouched.
    """

    @functools.wraps(parser)
    def inner(arg: T) -> tuple[tuple[K, ...], T]:
        outcome = parser(arg)
        wrapped = singleton(outcome[0])
        return (wrapped, outcome[1])

    return inner
def stateless(parser: Parser[K, str]) -> Parser[K, State]:
    """Lift a parser over plain text into a parser over ``State``.

    The lifted parser runs ``parser`` on ``state.text`` and returns its
    result together with a new state whose text is what ``parser`` left
    unconsumed.
    """

    def stateless_to_stateful(state: State) -> tuple[K, State]:
        parsed = parser(state.text)
        result, left = parsed
        next_state = state.alter(text=const(left))
        return (result, next_state)

    return stateless_to_stateful
def any_of(x: Sequence[str]) -> Predicate[str]:
    """Build a predicate that tests membership in ``x`` with ``in``.

    When ``x`` is a string this is a substring test, so callers are
    expected to feed the predicate single characters.
    """

    def is_member(src: str) -> bool:
        return src in x

    return is_member
# Characters accepted as part of an operator.
is_operator = any_of("+-*&%$@!:<>=|")
# Characters accepted as part of a name: ASCII letters and digits.
is_name = any_of(string.ascii_letters + string.digits)
def const(val: T) -> Callable[[Any], T]:
    """Return a one-argument function that ignores its input and yields ``val``."""

    def always(_: Any) -> T:
        return val

    return always
def take_while1(pred: Predicate[str]) -> Parser[str, str]:
    """Build a parser that consumes the longest prefix accepted by ``pred``.

    The parser never fails: when the first character is rejected (or the
    input is empty) it returns an empty match and the input unchanged.
    """

    def inner(x: str) -> tuple[str, str]:
        matched: list[str] = []
        for ch in x:
            if not pred(ch):
                break
            matched.append(ch)
        prefix = ''.join(matched)
        return prefix, x[len(prefix):]

    return inner
def success_if(pred: Predicate[K], parser: Parser[K, T]) -> Parser[K, T]:
    """Turn ``parser`` into one that fails unless its result satisfies ``pred``.

    The wrapped parser returns exactly what ``parser`` returned when
    ``pred`` accepts the parsed value, and raises ``ParseFail`` otherwise.
    """

    def inner(x: T) -> tuple[K, T]:
        result = parser(x)
        if not pred(result[0]):
            raise ParseFail(f"{parser} failed to parse {x} ({result})")
        return result

    return inner
def curry2(src: Callable[[T, E], R]) -> Callable[[T], Callable[[E], R]]:
    """Curry a two-argument callable: ``curry2(f)(a)(b) == f(a, b)``.

    The outer function carries ``src``'s metadata via ``functools.wraps``.
    """

    @functools.wraps(src)
    def take_first(x):
        def take_second(y):
            return src(x, y)

        return take_second

    return take_first
def first(src_p: Parser[K, T], f: Callable[[K], R]) -> Parser[R, T]:
    """Map ``f`` over the result of ``src_p``, leaving the remainder alone."""

    def inner(src: T) -> tuple[R, T]:
        result = src_p(src)
        mapped = f(result[0])
        return (mapped, result[1])

    return inner
def compose(f: Callable[[R], K], g: Callable[[T], R]) -> Callable[[T], K]:
    """Return the composition ``f . g``: apply ``g`` first, then ``f``."""

    def composed(x):
        return f(g(x))

    return composed
def alt(parsers: tuple[Parser[K, T], ...]) -> Parser[K, T]:
    """Build a parser that returns the first success among ``parsers``.

    Alternatives are tried in order on the same input. A ``ParseFail`` from
    one alternative moves on to the next; any other exception propagates.
    If every alternative fails (or there are none), ``ParseFail`` is raised.
    """

    def alt_impl(x: T) -> tuple[K, T]:
        for candidate in parsers:
            try:
                outcome = candidate(x)
            except ParseFail:
                pass
            else:
                return outcome
        raise ParseFail(f"Failed to parse using {parsers} ({x})")

    return alt_impl
def indentation(state: State) -> tuple[tuple[Token, ...], State]:
    """Parse a line start and turn its leading spaces into layout tokens.

    Unless ``state.start`` is set, a newline is parsed first, so this parser
    fails whenever the newline parser fails. The run of spaces that follows
    is then compared with ``state.stack.total``:

    * same width: yields ``SKIP`` when ``state.start`` is set, otherwise a
      single ``Newline()``;
    * wider: pushes the extra width as a new level and yields one
      ``Indentation(True)``;
    * narrower: pops levels until the difference is used up and yields one
      ``Indentation(False)`` per popped level.

    Raises:
        ValueError: a dedent stops part-way through a recorded level.
    """
    if not state.start:
        _, state = stateless(newline)(state)

    spaces, state = stateless(take_while1(any_of(" ")))(state)
    indents_size = len(spaces)
    current = state.stack.total

    # Same depth: nothing to open or close.
    if current == indents_size:
        marker = SKIP if state.start else Newline()
        return (singleton(marker), state)

    # Deeper than before: open exactly one new level.
    if current < indents_size:
        state = state.alter(stack=lambda x: x.push(indents_size - x.total))
        return (singleton(Indentation(True)), state)

    # Shallower than before: close levels until the widths agree.
    stack = state.stack
    diff = current - indents_size
    dedents: list[Indentation] = []
    while diff != 0:
        level, stack = stack.pop()
        if diff < level:
            raise ValueError(f"Unexpected dedent, expected {level}, found {diff}")
        diff -= level
        dedents.append(Indentation(False))
    state = state.alter(stack=const(stack))
    return (tuple(dedents), state)
# Like take_while1, but the resulting parser raises ParseFail when the matched
# prefix is empty (success_if applies bool to the parsed string).
take_while = compose(curry2(success_if)(bool), take_while1)
# One or more "\n" characters, replaced by a Newline token.
newline = first(take_while(any_of("\n")), const(Newline()))
# One or more spaces, returned as the matched string.
indents = take_while(any_of(" "))
# One or more name characters, wrapped in Name.
parse_name = first(take_while(is_name), Name)
# One or more ASCII digits, converted with int and wrapped in Number.
parse_number = first(take_while(any_of(string.digits)), compose(Number, int))
# One or more operator characters, wrapped in Operator.
parse_operator = first(take_while(is_operator), Operator)
# One or more spaces or tabs, replaced by the SKIP marker.
parse_skip = first(take_while(any_of(" \t")), const(SKIP))
# One tokenizing step over a State: alternatives are tried in the order listed
# and the first one that does not raise ParseFail wins. The text parsers are
# lifted to State parsers with stateless(), and singleton_of() wraps their
# single result in a 1-tuple so every alternative yields a tuple.
# parse_number is listed before parse_name because is_name also accepts digits.
step = alt((
indentation,
stateless(singleton_of(parse_number)),
stateless(singleton_of(parse_name)),
stateless(singleton_of(parse_operator)),
stateless(singleton_of(parse_skip))
))
def tokenize(state: State) -> list[Token]:
    """Run ``step`` repeatedly until ``state.text`` is empty.

    Args:
        state: Starting state. It is not modified; the ``start`` flag is
            cleared on a copy after the first step.

    Returns:
        The tokens produced by every step, in order, with ``SKIP`` markers
        removed.

    Raises:
        ParseFail: no alternative in ``step`` accepts the remaining text.
    """
    tokens: list[Token] = []
    while state.text:
        cur_tokens, state = step(state)
        if state.start:
            # A step can return the very State object it was given, which on
            # the first iteration is the caller's argument. Clear the flag on
            # a copy instead of assigning to the attribute.
            state = state.alter(start=lambda _: False)
        tokens.extend(n for n in cur_tokens if n is not SKIP)
    return tokens
# Demo input: an empty first line followed by two identical assignments.
src = """
x = 1
x = 1
"""
# Runs at module level: tokenize the demo input and pretty-print the tokens.
pprint(tokenize(State.new(src)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment