Skip to content

Instantly share code, notes, and snippets.

@hwayne
Created March 11, 2024 20:51
Show Gist options
  • Save hwayne/dd5c33e41a94d7f3d242f9728bf7e47d to your computer and use it in GitHub Desktop.
Save hwayne/dd5c33e41a94d7f3d242f9728bf7e47d to your computer and use it in GitHub Desktop.
expand_version.py
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
from copy import deepcopy
from argparse import ArgumentParser
from dataclasses import dataclass
from string import Template
from pathlib import Path
import typing as t
# I often need several slightly different versions of the same spec; this helper generates them all from a single XML file.
# TODO rename spec to 'file'
# TODO let 'files' define strings for specific numbers for easy control
def parse_args(argv=None):
    """Parse the command-line options for the spec expander.

    Args:
        argv: Argument list to parse, without the program name. When None
            (the default), argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``file`` (str), ``spec`` (str or None),
        ``version`` (str or None) and ``dryrun`` (bool).
    """
    parser = ArgumentParser()
    parser.add_argument("file", help="xml file to convert")
    parser.add_argument("--spec", required=False, help="spec in the file to convert. Default is all.")
    parser.add_argument("--version", required=False, help="which version of the spec. Should only be used if --spec is also used. Default is all. TODO find how argparse works for better documentation of flag limitations")
    parser.add_argument("-d", "--dryrun", action="store_true", help="print the expansion to STDOUT instead of writing files.")
    # TODO: arguments to control whether we update just the file or also the state spaces.
    return parser.parse_args(argv)
@dataclass
class VersionRange:
start: int
finish: t.Optional[int]
def __init__(self, start: str | int, finish):
if start == '': # These are actually strings rn, not Nones, of form ''
start = 1
if finish:
finish = int(finish)
self.start = int(start)
self.finish = finish
if self.start and self.finish:
assert self.start <= self.finish
def max_version(self) -> int:
if not self.finish:
return self.start # guaranteed only maximal for n- switches
return self.finish
def contains(self, i: int) -> bool:
if self.start == 0: # Only happens intentionally
return False
if not self.finish: # Missing!
return i >= self.start
return self.start <= i <= self.finish
def expand_on_attrib(on_str: str) -> VersionRange:
    """Parse the text of an ``on`` attribute into a VersionRange.

    A bare number ``"n"`` means exactly version n. A dashed form ``"a-b"``
    is passed to VersionRange as two strings; either side may be empty
    (``"-b"``, ``"a-"``), and VersionRange decides what an empty side means.
    """
    if "-" not in on_str:
        only = int(on_str)
        return VersionRange(start=only, finish=only)
    lo, hi = on_str.split("-")
    return VersionRange(lo, hi)
def get_on(s: Element) -> VersionRange:
    """Return the version range declared by switch element *s*.

    Reads the mandatory ``on`` attribute; raises KeyError if it is absent.
    """
    # NOTE(review): the original note here says this "could also be _";
    # the intended alternative is unclear.
    on_attr = s.attrib["on"]
    return expand_on_attrib(on_attr)
def tree_to_text(tree) -> str:
    """Concatenate every text fragment inside *tree*.

    This is the element's own text plus the text and tail of each
    descendant, in document order. The element's own tail is not included.
    """
    pieces = list(tree.itertext())
    return "".join(pieces)
@dataclass
class SpecVersion:
    """One expanded version of a spec, ready to be rendered and written.

    Fields: ``name`` is the spec's name attribute, ``version`` is the number
    used in the output file name, ``text`` is the expanded body (it may
    contain a ``$name`` placeholder), and ``ext`` is the output extension.
    """

    name: str
    version: int
    text: str
    ext: str

    def filename(self):
        """Return the file stem: the name, "__", and the version zero-padded to width 2."""
        return "{}__{:0=2}".format(self.name, self.version)

    def __str__(self):
        # Replace $name with the versioned file stem (the original notes this
        # is "for TLA stuff"). Template.substitute raises on any other
        # $placeholder in the text.
        return Template(self.text).substitute(name=self.filename())
class Metafile(Element):
    """Placeholder Element subclass that currently adds no behavior.

    The author's note says this should eventually become "totally its own
    thing". NOTE(review): nothing else visible in this file uses it.
    """
def create_spec_version(spec_root: Element, version: int) -> SpecVersion:
    """Render one numbered version of a ``<spec>`` element.

    The work is done on a deep copy, so *spec_root* itself is not modified.
    Each ``<s>`` switch whose ``on`` range excludes *version* has its
    contents blanked: the switch's own text, plus the text and tail of
    everything nested inside it. The switch's own tail (the text after its
    closing tag) belongs to the surrounding content and is kept.

    The returned version number is shifted by the spec's optional
    ``start-from`` attribute (default 1). For example, version 1 of a spec
    with ``start-from="5"`` is numbered 5. The extension comes from the
    spec's ``ext`` attribute and defaults to "tla".
    """
    pruned = deepcopy(spec_root)
    for switch in pruned.findall('.//s'):
        if get_on(switch).contains(version):
            continue  # active in this version: keep its contents
        # In ElementTree, .text is the text before an element's first child
        # and .tail is the text after its closing tag. Clearing both across
        # the subtree, except the switch's own tail, removes exactly what
        # lies inside the switch.
        for node in switch.iter():  # iter() yields the switch itself first
            node.text = ""
            if node is not switch:
                node.tail = ""
    return SpecVersion(
        name=pruned.attrib["name"],
        version=version + int(pruned.get("start-from", 1)) - 1,
        text=tree_to_text(pruned),
        ext=pruned.attrib.get("ext", "tla"),
    )
def create_all_spec_versions(spec_root: Element) -> list[SpecVersion]:
    """Expand *spec_root* into every version its switches mention.

    The version count is the largest ``max_version()`` over all ``<s>``
    switches in the spec. Versions 1 through that count are rendered in
    order. A spec with no switches has exactly one version and yields a
    single SpecVersion.
    """
    ranges = [get_on(switch) for switch in spec_root.findall('.//s')]
    # default=1: without it, a switch-less spec would yield no versions at
    # all and be silently skipped by the caller.
    num_versions = max((r.max_version() for r in ranges), default=1)
    return [create_spec_version(spec_root, i) for i in range(1, num_versions + 1)]
def _find_spec(tree, name):
    """Return the first top-level ``<spec>`` whose ``name`` attribute equals *name*, or None."""
    # Compare attributes directly instead of interpolating *name* into an
    # XPath predicate, where a quote character in the name would break it.
    for spec_root in tree.findall("spec"):
        if spec_root.get("name") == name:
            return spec_root
    return None


def _expand_specs(tree, args) -> list[SpecVersion]:
    """Build the SpecVersions selected by ``args.spec`` and ``args.version``."""
    if not args.spec:
        # NOTE: args.version is ignored when no spec is named.
        out: list[SpecVersion] = []
        for spec_root in tree.findall("spec"):
            out += create_all_spec_versions(spec_root)
        return out
    spec_root = _find_spec(tree, args.spec)
    if spec_root is None:
        raise ValueError(f"no <spec> named {args.spec!r} in {args.file}")
    if args.version:
        return [create_spec_version(spec_root, int(args.version))]
    return create_all_spec_versions(spec_root)


def _write_specs(specs, folder: Path, default_ext: str | None) -> None:
    """Write each SpecVersion into *folder*, keeping any header before "!!!"."""
    if not specs:
        return
    folder.mkdir(exist_ok=True, parents=True)
    for spec in specs:
        to_write = str(spec)
        # The spec's own ext wins. The root-level ext is the legacy
        # fallback, and "tla" avoids writing a ".None" suffix.
        ext = spec.ext or default_ext or "tla"
        out_path = folder / f"{spec.filename()}.{ext}"
        if out_path.exists():
            # Preserve metadata at the top of an existing file: everything
            # up to the last "!!!" is kept, and what follows is replaced.
            # NOTE(review): if generated text itself contains "!!!", the old
            # output is kept as "metadata" and the file grows on every run.
            parts = out_path.read_text().split("!!!")
            parts[-1] = to_write
            to_write = "!!!".join(parts)
        out_path.write_text(to_write)


def expand_version(args):
    """Expand the versioned specs in ``args.file``, then print-ready return or write them.

    If ``args.spec`` is set, only the ``<spec>`` with that name is expanded:
    just ``args.version`` when that is also set, otherwise all of its
    versions. Without ``args.spec``, every top-level ``<spec>`` is expanded.

    Returns:
        In dry-run mode (``args.dryrun``), the rendered text of each version.
        Otherwise an empty list, after writing the files into the directory
        named by the root element's ``folder`` attribute.

    Raises:
        KeyError: if the root element has no ``folder`` attribute.
        ValueError: if ``args.spec`` names a spec not present in the file.
    """
    tree = ET.parse(args.file)
    root = tree.getroot()
    folder = root.attrib["folder"]
    default_ext = root.attrib.get("ext")  # backwards compatibility
    out = _expand_specs(tree, args)
    if args.dryrun:
        return [str(spec) for spec in out]
    _write_specs(out, Path(folder), default_ext)
    return []
def main():
    """Command-line entry point: expand the specs; in dry-run mode, print each one."""
    options = parse_args()
    rendered = expand_version(options)
    if not options.dryrun:
        return
    for text in rendered:
        print(text)
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment