Skip to content

Instantly share code, notes, and snippets.

@nhumrich
Last active March 14, 2023 19:53
Show Gist options
  • Save nhumrich/9faffb66c368f03971abf90d6e9cb042 to your computer and use it in GitHub Desktop.
Save nhumrich/9faffb66c368f03971abf90d6e9cb042 to your computer and use it in GitHub Desktop.
PEP 501 proposal with tests
class TemplateLiteral:
    """Template object modelled on the PEP 501 interpolation-template proposal.

    Instances pair the original template text with its parsed form and with
    the values and format specifiers captured for each replacement field.

    Attributes:
        raw_template: The template source text, e.g. ``"cat {filename}"``.
        parsed_template: Non-empty tuple of ``(leading_text, field_expr)``
            nodes. ``field_expr`` is ``None`` for a node that carries only
            text; the concatenation and repetition operators below are
            written on the assumption that only the final node can be such a
            text-only node.
        field_values: Tuple with one captured value per replacement field,
            in the same order as the fields appear in ``parsed_template``.
        format_specifiers: Tuple with one format specifier string per
            replacement field, in the same order as ``field_values``.
    """

    __slots__ = ("raw_template", "parsed_template", "field_values", "format_specifiers")

    def __new__(cls, raw_template, parsed_template, field_values, format_specifiers):
        """Create a template.

        Raises:
            ValueError: if ``parsed_template`` is empty.
        """
        self = super().__new__(cls)
        self.raw_template = raw_template
        if len(parsed_template) == 0:
            raise ValueError("'parsed_template' must contain at least one value")
        self.parsed_template = parsed_template
        self.field_values = field_values
        self.format_specifiers = format_specifiers
        return self

    def __bool__(self):
        """A template is truthy exactly when its raw text is non-empty."""
        return bool(self.raw_template)

    def __add__(self, other):
        """Concatenate with another ``TemplateLiteral`` or a plain ``str``.

        If ``self`` ends in a text-only node, that text is merged into the
        first node of the right-hand side, so no two text runs end up in
        adjacent nodes. Unsupported operand types return ``NotImplemented``,
        letting Python try the reflected operation and raise the standard
        ``TypeError`` if that also fails.
        """
        if isinstance(other, TemplateLiteral):
            if (
                self.parsed_template
                and self.parsed_template[-1][1] is None
                and other.parsed_template
            ):
                # Merge the trailing text of self into the first node of other.
                content = self.parsed_template[-1][0]
                first_text, first_field = other.parsed_template[0]
                new_parsed_template = (
                    self.parsed_template[:-1]
                    + ((content + first_text, first_field),)
                    + other.parsed_template[1:]
                )
            else:
                new_parsed_template = self.parsed_template + other.parsed_template
            return TemplateLiteral(
                self.raw_template + other.raw_template,
                new_parsed_template,
                self.field_values + other.field_values,
                self.format_specifiers + other.format_specifiers,
            )
        if isinstance(other, str):
            if self.parsed_template and self.parsed_template[-1][1] is None:
                # Append the string to the existing trailing text-only node.
                new_parsed_template = self.parsed_template[:-1] + (
                    (self.parsed_template[-1][0] + other, None),
                )
            else:
                new_parsed_template = self.parsed_template + ((other, None),)
            return TemplateLiteral(
                self.raw_template + other,
                new_parsed_template,
                self.field_values,
                self.format_specifiers,
            )
        return NotImplemented

    def __radd__(self, other):
        """Handle ``str + TemplateLiteral`` by prefixing the first node's text."""
        if not isinstance(other, str):
            return NotImplemented
        if self.parsed_template:
            first_text, first_field = self.parsed_template[0]
            new_parsed_template = (
                (other + first_text, first_field),
            ) + self.parsed_template[1:]
        else:
            new_parsed_template = ((other, None),)
        return TemplateLiteral(
            other + self.raw_template,
            new_parsed_template,
            self.field_values,
            self.format_specifiers,
        )

    def __mul__(self, other):
        """Repeat the template ``other`` times, like ``str * int``.

        A count below 1 yields an empty template. When the template ends in
        a text-only node, that trailing text is merged with the leading text
        of the next repetition, matching what repeated ``+`` would produce.
        """
        if not isinstance(other, int):
            return NotImplemented
        if not self.raw_template or other == 1:
            return self
        if other < 1:
            # parsed_template must be a tuple of nodes, hence the 1-tuple.
            return TemplateLiteral("", (("", None),), (), ())
        parsed_template = self.parsed_template
        first_node = parsed_template[0]
        last_node = parsed_template[-1]
        if last_node[1] is not None:
            # With a trailing field, every node can simply be repeated.
            new_parsed_template = parsed_template * other
        elif len(parsed_template) == 1:
            # A single text-only node: the repetitions collapse into one node.
            new_parsed_template = ((last_node[0] * other, None),)
        else:
            # Each repetition after the first starts with a node whose text is
            # the previous repetition's trailing text plus the first node's
            # text, followed by the unchanged middle nodes. The final trailing
            # text node is appended once at the very end.
            merged_node = (last_node[0] + first_node[0], first_node[1])
            repeating_pattern = (merged_node,) + parsed_template[1:-1]
            new_parsed_template = (
                parsed_template[:-1]
                + repeating_pattern * (other - 1)
                + (last_node,)
            )
        return TemplateLiteral(
            self.raw_template * other,
            new_parsed_template,
            self.field_values * other,
            self.format_specifiers * other,
        )

    def __rmul__(self, other):
        """Handle ``int * TemplateLiteral`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    # Defining __eq__ without __hash__ leaves instances unhashable
    # (Python sets __hash__ to None for such classes).
    def __eq__(self, other):
        """Compare all four stored components; non-templates are not handled."""
        if not isinstance(other, TemplateLiteral):
            return NotImplemented
        return (
            self.raw_template == other.raw_template
            and self.parsed_template == other.parsed_template
            and self.field_values == other.field_values
            and self.format_specifiers == other.format_specifiers
        )

    def __repr__(self):
        return (
            f"<{type(self).__qualname__} {repr(self.raw_template)} "
            f"at {id(self):#x}>"
        )

    def __format__(self, format_specifier):
        # When formatted, render to a string, and use string formatting.
        return format(self.render(), format_specifier)

    def render(self, *, render_template="".join, render_field=format):
        """Render the template.

        Each node's leading text is emitted unchanged. Each replacement field
        is rendered with ``render_field(value, specifier)``, using the
        matching entries of ``field_values`` and ``format_specifiers``. The
        resulting list of parts is passed to ``render_template``, whose
        result is returned. With the defaults, this returns a ``str``.
        """
        parts = []
        # Track fields separately from node positions so text-only nodes do
        # not consume a value/specifier slot.
        field_index = 0
        for leading_text, field_expr in self.parsed_template:
            parts.append(leading_text)
            if field_expr is not None:
                parts.append(
                    render_field(
                        self.field_values[field_index],
                        self.format_specifiers[field_index],
                    )
                )
                field_index += 1
        return render_template(parts)
"""********** TESTS ********"""
filename = "my file"
flag = "a=b"
template_a = TemplateLiteral(
"cat {filename}", (("cat", "filename"),), (filename,), (f"",)
)
template_b = TemplateLiteral("--flag {flag}", (("--flag", "flag"),), (flag,), (f"",))
assert template_a + " " + template_b == TemplateLiteral(
"cat {filename} --flag {flag}",
(("cat", "filename"), (" --flag", "flag")),
(filename, flag),
(f"", f""),
)
# if last LHS node is empty
name = "bob"
template_c = TemplateLiteral(
"My {name} is:", (("My ", "name"), (" is:", None)), (name,), (f"",)
)
template_d = TemplateLiteral("Joe", (("Joe", None),), (), ())
assert template_c + template_d == TemplateLiteral(
"My {name} is:Joe", (("My ", "name"), (" is:Joe", None)), (name,), (f"",)
)
# if first RHS node has no leading text
template_e = TemplateLiteral("file:", (("file:", None),), (), ())
template_f = TemplateLiteral("{filename}", (("", "filename"),), (filename,), (f"",))
assert template_e + template_f == TemplateLiteral(
"file:{filename}", (("file:", "filename"),), (filename,), (f"",)
)
assert template_e + "normal string" == TemplateLiteral(
"file:normal string", (("file:normal string", None),), (), ()
)
print(template_a * 5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment