Last active
October 10, 2023 12:56
-
-
Save omots/84f789abd055ab36bce47a9409c2e980 to your computer and use it in GitHub Desktop.
SNES unheadering tool
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# unheader.py - Portable command-line SNES ROM unheadering tool
# By oldmanoftheSEA (omotss@twitter, omots@github) | |
# THIS UTILITY COMES WITH NO WARRANTY EXPRESSED OR IMPLIED | |
# Backups are ALWAYS a good idea. | |
# Principles & support: | |
# This tool is generally conservative when it comes to making modifications. | |
# Input ROMs must pass sanity checks before any action is taken. Questionable | |
# dumps and poorly-made hacks will almost certainly fail, possibly along with | |
# a handful of legitimate specimens. Satellaview and BIOS files will notably be
# skipped. | |
# This tool also needs to be nudged to overwrite files: -y when unheadering | |
# files in-place, and -x when also using -d. | |
# Basic usage: | |
# * Unheader a single file: | |
# python3 unheader.py <file> -y # in-place | |
# python3 unheader.py <file> -o <out-file> # specify output file | |
# * Unheader multiple files | |
# python3 unheader.py <file1> <file2> <fileN..> -y # in-place | |
# python3 unheader.py <file1> <file2> <fileN..> -o <out-dir> # specify output directory | |
# * Unheader everything in the specified folder(s) | |
# python3 unheader.py <dir1> <dirN..> -d -y # in-place | |
# python3 unheader.py <dir1> <dirN..> -d -o <out-dir> # specify output directory | |
# python3 unheader.py <dir1> <dirN..> -d -o <out-dir> -x # overwrite files in output directory | |
# Other useful flags: | |
# -n No-op: Do not write any files. | |
# -f Force: Turn off sanity checks (not recommended for multiple inputs) | |
# -l Debug: Enable debugging output | |
import argparse | |
import codecs | |
import io | |
import logging | |
import os | |
import sys | |
from collections import namedtuple | |
# One processing outcome: the input path, whether it succeeded, the operation
# attempted (one of the operation constants below) and a human-readable reason.
record = namedtuple("record", "file ok oper reason")

# File-name suffixes (matched case-insensitively by process_dir) treated as ROMs.
extensions = ('.sfc', '.smc')

# Shared scratch buffer for stream copies and header probing, plus a
# zero-copy view over it for slicing.
buf = bytearray(4096)
bufv = memoryview(buf)

# Operation constants
COPY, UNHEADER, SKIP = "COPY", "UNHEADER", "SKIP"
def main():
    """Command-line entry point.

    Parses the arguments, configures logging, unheaders/copies every input
    and prints a summary to stderr.

    Returns:
        0 if every file was handled, 1 if any file failed, 2 on invalid usage.
    """
    args = _parse_args()
    if not args.out and not args.yes_i_am_sure:
        print("You must either supply -o or --yes-i-am-sure", file=sys.stderr)
        return 2

    _configure_logging(args)

    tracker = Tracker()
    sanity = not args.force
    if len(args.inputs) > 1 and not args.dir and args.out is not None:
        # Out must be a directory.
        if os.path.exists(args.out) and not os.path.isdir(args.out):
            logging.error("If multiple inputs are specified, the output must be a directory")
            return 2
        # A trial run (-n) must not touch the filesystem at all.
        if not os.path.exists(args.out) and not args.nop:
            os.makedirs(args.out)

    if args.dir:
        for inp in args.inputs:
            process_dir(tracker, inp, args.out, sanity=sanity, nop=args.nop, overwrite=args.overwrite)
    elif len(args.inputs) == 1:
        inp = args.inputs[0]
        if args.out is None:
            out = inp
        elif os.path.isdir(args.out):
            # -o names an existing folder: keep the input's file name inside it
            # instead of trying to open the folder itself as the output file.
            out = os.path.join(args.out, os.path.basename(inp))
        else:
            out = args.out
        process(tracker, inp, out, sanity=sanity, nop=args.nop)
    else:
        for inp in args.inputs:
            out = os.path.join(args.out, os.path.basename(inp)) if args.out is not None else inp
            process(tracker, inp, out, sanity=sanity, nop=args.nop)

    return _report(tracker)


def _parse_args():
    """Build the argument parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="SNES unheader script")
    parser.add_argument("inputs", nargs='+', help="ROM or folder of ROMs (if -d/--dir is supplied) to unheader")
    parser.add_argument("-o", "--out", help="Output file or folder. Must be a folder if multiple inputs are specified. If unspecified, the input file will be replaced iff -y is specified.")
    parser.add_argument("-d", "--dir", action="store_const", const=True, help="Input is a directory")
    parser.add_argument("-f", "--force", action="store_const", const=True, help="Disable safety checks")
    parser.add_argument("-l", "--debug", action="store_const", const=True, help="Enable debug logging")
    parser.add_argument("-x", "--overwrite", action="store_const", const=True, help="Overwrite existing files when using -d")
    parser.add_argument("-g", "--log", default=None, help="Write logger output to this file in addition to stderr")
    parser.add_argument("-n", "--nop", "--noop", "--no-op", action="store_const", const=True, help="Trial run, don't write any output")
    parser.add_argument("-y", "--yes-i-am-sure", action="store_const", const=True, help="Must be supplied when unheadering in-place (no -o specified)")
    return parser.parse_args()


def _configure_logging(args):
    """Log to stderr, and also to ``args.log`` when -g/--log was given."""
    loglevel = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=loglevel)
    if args.log is not None:
        handler = logging.FileHandler(args.log)
        handler.setLevel(loglevel)
        # Use basicConfig's default layout so the file matches the stderr log.
        handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
        logging.getLogger().addHandler(handler)


def _report(tracker):
    """Print the end-of-run summary to stderr; return 1 if anything failed, else 0."""
    print("Task completed.", file=sys.stderr)
    print(" * Processed {} file(s)".format(len(tracker.results)), file=sys.stderr)
    print(" * Skipped {} file(s)".format(len(tracker.skipped())), file=sys.stderr)
    print(" * Copied {} file(s)".format(len(tracker.copied())), file=sys.stderr)
    print(" * Unheadered {} file(s)".format(len(tracker.unheadered())), file=sys.stderr)
    print(file=sys.stderr)
    failures = tracker.failures()
    if not failures:
        return 0
    print("FAILED: {} file(s)".format(len(failures)), file=sys.stderr)
    for rec in failures:
        print(" * {} [{}]: {}".format(rec.file, rec.oper, rec.reason), file=sys.stderr)
    print(file=sys.stderr)
    return 1
def process_dir(tracker, inpath, outpath, sanity=True, nop=False, overwrite=False):
    """Process every file found (recursively) under the directory ``inpath``.

    Files whose lower-cased name does not end in one of ``extensions`` are
    recorded as skipped. When ``outpath`` is given, the directory tree of
    ``inpath`` is mirrored beneath it; when it is None, files are rewritten
    in place.

    Args:
        tracker: Tracker receiving one record per file.
        inpath: Directory to walk.
        outpath: Output root directory, or None for in-place processing.
        sanity: Run sanity checks before modifying a ROM (False for -f).
        nop: Trial run; no directories are created and no files written.
        overwrite: Replace files that already exist under ``outpath`` (-x).
    """
    out_real = os.path.realpath(outpath) if outpath is not None else None
    for root, dirs, files in os.walk(inpath):
        if outpath is not None:
            # If the output tree lives inside the input tree, don't descend
            # into it, or already-written outputs would be processed again.
            dirs[:] = [d for d in dirs if os.path.realpath(os.path.join(root, d)) != out_real]
            # normpath turns "<out>/." (for the walk root) into "<out>", which
            # can be created even when <out> does not exist yet.
            outdir = os.path.normpath(os.path.join(outpath, os.path.relpath(root, inpath)))
            if not os.path.exists(outdir):
                if not nop:
                    logging.info("Creating %s", outdir)
                    os.makedirs(outdir)
            elif not os.path.isdir(outdir):
                # Uh-oh! Nothing below this directory can be mirrored either,
                # so prune the walk instead of failing on its children.
                logging.warning("Skipping directory %s because %s exists and is not a directory", root, outdir)
                dirs[:] = []
                continue
        else:
            # In-place.
            outdir = root
        for file in files:
            infile = os.path.join(root, file)
            if not any(file.lower().endswith(ext) for ext in extensions):
                tracker.post(infile, True, SKIP, "Invalid extension")
                continue
            outfile = os.path.join(outdir, file)
            if outpath is not None and not overwrite and os.path.exists(outfile):
                tracker.post(infile, True, SKIP, "Skipping because output file exists and -x was not specified")
            else:
                process(tracker, infile, outfile, sanity=sanity, nop=nop)
def process(tracker, infile, outfile, sanity=True, nop=False):
    """Unheader or copy a single ROM and record the outcome in ``tracker``.

    The path taken depends on the file size modulo 16 KiB (0x4000):
      * remainder 0x200: a 512-byte header is assumed and stripped (after
        sanity checks, unless disabled). If the headered checks fail but the
        file passes as unheadered, it is copied as-is instead;
      * remainder 0: no header is assumed and the file is copied as-is
        (after sanity checks, unless disabled);
      * anything else: the file is skipped.

    Args:
        tracker: Tracker receiving exactly one record for ``infile``.
        infile: Path of the ROM to read.
        outfile: Destination path (may equal ``infile`` for in-place mode).
        sanity: Run sanity checks; False corresponds to -f/--force.
        nop: Trial run; report what would happen without writing.
    """
    if not os.path.isfile(infile):
        tracker.post(infile, False, SKIP, "File doesn't exist or isn't a regular file")
        return
    remainder = os.stat(infile).st_size & 0x3FFF
    if remainder == 0x200:
        # Header detected, sanity check first.
        if sanity and not sanity_check(infile, True):
            # The headered checks failed; maybe it is really an unheadered ROM.
            if sanity_check(infile, False):
                _run_op(tracker, copy, COPY, infile, outfile, nop, "No header detected, copied as-is")
            else:
                tracker.post(infile, False, UNHEADER, "Header detected, but failed sanity checks")
            return
        _run_op(tracker, unheader, UNHEADER, infile, outfile, nop, "Header removed")
    elif remainder == 0:
        # -f/--force must bypass the checks here too, as it does above.
        if not sanity or sanity_check(infile, False):
            _run_op(tracker, copy, COPY, infile, outfile, nop, "No header detected, copied as-is")
        else:
            tracker.post(infile, False, COPY, "No header detected, but failed sanity checks")
    else:
        tracker.post(infile, False, SKIP, "Unexpected file size")


def _run_op(tracker, func, oper, infile, outfile, nop, success_reason):
    """Call ``func(infile, outfile, nop)`` and post success or its OSError to ``tracker``."""
    try:
        func(infile, outfile, nop)
    except OSError as ex:
        tracker.post(infile, False, oper, ex.strerror)
    else:
        tracker.post(infile, True, oper, success_reason)
def unheader(infile, outfile, nop):
    """Write ``infile`` without its first 0x200 bytes to ``outfile``.

    Does nothing when ``nop`` is true. Raises OSError on I/O failure.
    """
    _transfer(infile, outfile, nop, 0x200)


def copy(infile, outfile, nop):
    """Copy ``infile`` verbatim to ``outfile``.

    Does nothing when ``nop`` is true. Raises OSError on I/O failure.
    """
    _transfer(infile, outfile, nop, 0)


def _transfer(infile, outfile, nop, offset):
    """Copy the bytes of ``infile`` from ``offset`` onwards into ``outfile``.

    If ``outfile`` already exists (including the in-place case, where it is
    ``infile`` itself), the data is first written to ``outfile + ".partial"``
    and then moved over ``outfile`` with ``replace``, so the existing file is
    not truncated while it is still being read. Whatever was written is
    deleted again if the copy fails.
    """
    if nop:
        return
    wfile = outfile + ".partial" if os.path.exists(outfile) else outfile
    try:
        with open(infile, 'rb') as fin, open(wfile, 'wb') as fout:
            fin.seek(offset, io.SEEK_SET)
            copystream(fin, fout)
    except BaseException:
        # Includes failures while closing the output and Ctrl-C: never leave
        # a truncated output file or stray .partial behind.
        if os.path.exists(wfile):
            os.unlink(wfile)
        raise
    if wfile != outfile:
        replace(wfile, outfile)
def copystream(fin, fout):
    """Pump all remaining bytes of ``fin`` into ``fout`` via the shared ``buf``.

    Raises OSError(-1, "Read error") if a read reports a negative count, and
    OSError(-1, "Write error") if a write stores fewer bytes than were read.
    """
    while True:
        nread = fin.readinto(buf)
        if nread == 0:
            return
        if nread < 0:
            raise OSError(-1, "Read error")
        nwritten = fout.write(bufv[:nread])
        if nwritten != nread:
            raise OSError(-1, "Write error")
def replace(tmpfile, outfile):
    """Move ``tmpfile`` over ``outfile``, replacing ``outfile`` if it exists.

    ``os.replace`` performs the swap as a single rename. There is never a moment
    when ``outfile`` is missing, and if the rename fails the original file is
    left untouched. Deleting first and renaming afterwards could lose an
    in-place ROM.
    """
    os.replace(tmpfile, outfile)
def sanity_check(filename, header):
    """Return True if ``filename`` appears to contain a valid SNES ROM.

    Each mapper in ``mappers`` is tried in order. When a mapper's ``hdrid``
    accepts the two bytes at $FFD5/$FFD6, it is used to read the internal
    header. The header is scored on two fixed-value fields, the destination
    code, the cart type, the checksum/inverse pair and the internal name.
    The first mapper that reaches a score of 4 or more makes the check pass.

    Args:
        filename: Path of the file to inspect.
        header: If true, ignore a 0x200-byte header at the start of the file.

    Returns:
        True if some mapper scored >= 4, else False. Files with no data
        after the (optional) header also return False.
    """
    hoffset = 0x200 if header else 0
    logging.debug("Testing sanity for %s ...", filename)
    # Determine size first (ROM data only, excluding any header)
    sz = os.stat(filename).st_size - hoffset
    if sz <= 0:
        # Nothing to inspect, and the mappers reduce offsets "% sz", which
        # would raise ZeroDivisionError for an empty body.
        logging.debug(" - No ROM data")
        return False

    with open(filename, 'rb') as f:
        def readaddr(mapper, addr, count=1):
            # Read `count` bytes at SNES address `addr` into buf; False if the
            # mapper has no file offset for it or the file is too short.
            offset = mapper.tofile(addr, sz)
            if offset is None:
                return False
            f.seek(offset + hoffset, io.SEEK_SET)
            return f.readinto(bufv[:count]) == count

        for mapper in mappers:
            mapper_name = type(mapper).__name__
            # Use mapper to access the internal header and validate the data we find there.
            logging.debug(" - Trying mapper %s", mapper_name)
            if not (readaddr(mapper, 0xFFD5, 2) and mapper.hdrid(buf[0], buf[1])):
                continue
            logging.debug(" - Mapper OK")
            # Potential candidate. Do the other tests..
            checks = 0
            # Fixed value pt 1
            if readaddr(mapper, 0xFFB6, 7) and all(buf[i] == 0 for i in range(7)):
                checks += 1
            else:
                logging.debug(" - Failed Fixed value 1 test")
            # Fixed value pt 2
            if readaddr(mapper, 0xFFDA) and buf[0] == 0x33:
                checks += 1
            else:
                logging.debug(" - Failed Fixed value 2 test %d", buf[0])
            # Destination code
            if readaddr(mapper, 0xFFD9) and buf[0] in (0, 1, 2):
                checks += 1
            else:
                logging.debug(" - Failed destination code test %d", buf[0])
            # Cart type
            if readaddr(mapper, 0xFFD6) and buf[0] in (0, 1, 2, 0x13, 0x14, 0x15, 0x1A, 0x33, 0x34, 0x35):
                checks += 1
            else:
                logging.debug(" - Failed cart type test %d", buf[0])
            # Checksum + inverse check (a mismatch actively counts against the mapper)
            if readaddr(mapper, 0xFFDC, 4) and buf[0] == (buf[2] ^ 0xff) and buf[1] == (buf[3] ^ 0xff):
                checks += 2
            else:
                checks -= 2
                logging.debug(" - Failed cart checksum test")
            # Internal name
            if readaddr(mapper, 0xFFC0, 21) and check_name(bufv[:21]):
                checks += 2
            else:
                logging.debug(" - Failed cart internal name test")
            if checks >= 4:
                logging.debug(" > Mapper %s succeeded with score = %d", mapper_name, checks)
                return True
            logging.debug(" > Mapper %s failed with score = %d", mapper_name, checks)
    return False
# Byte values of the characters that check_name() scores specially:
# upper-case range, lower-case range, space, digit range and dash.
c_A, c_Z, c_a, c_z, c_sp, c_0, c_9, c_dash = (ord(ch) for ch in "AZaz 09-")
def check_name(name):
    """Score an internal ROM name and return True if it looks plausible.

    Adapted from asar's setmapper() in main.cpp. Upper-case letters, spaces
    and half-width katakana (0xA1-0xDF) raise the score. Digits, lower-case
    letters and '-' raise it slightly. Other high-bit bytes, unprintable bytes
    and other punctuation lower it. NUL bytes are neutral, but every non-NUL
    byte after the first NUL costs an extra 4 points. A total of 20 or more
    passes.

    Args:
        name: Iterable of byte values (e.g. a memoryview over the name field).
    """
    foundnil = False
    score = 0
    for c in name:
        if foundnil and c != 0:
            # Data after a terminating NUL is suspicious.
            score -= 4
        if c == 0:
            # Must be tested before the unprintable range below (0 < 0x20);
            # otherwise NUL is penalised as unprintable and never recorded.
            foundnil = True
        elif 0xA1 <= c <= 0xDF:  # Katakana
            score += 3
        elif c >= 128:  # Invalid high bit code point
            score -= 6
        elif c < 0x20 or c == 0x7f:  # Unprintable
            score -= 6
        elif c_A <= c <= c_Z:
            score += 3
        elif c == c_sp:
            score += 2
        elif (c_0 <= c <= c_9) or (c_a <= c <= c_z) or c == c_dash:
            score += 1
        else:
            score -= 3
    # Only pay for decoding the name when the message will actually be shown.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(" - Internal name \"%s\", score = %d", decode_name(name), score)
    return score >= 20
def load_jis_x_0201():
    """Register a single-byte ``jis_x_0201`` text codec with :mod:`codecs`.

    The table reuses Python's Shift-JIS decoding for bytes 0x00-0x7F and
    0xA1-0xDF, then remaps 0x5C to U+00A5 (yen) and 0x7E to U+203E (overline).
    Every other byte maps to U+FFFE, which the charmap codec treats as
    undefined. The codec answers to the names jis_x0201, jis_x_0201, x0201
    and x_0201.
    """
    table = ['\ufffe'] * 256
    for code in [*range(0x80), *range(0xA1, 0xE0)]:
        table[code] = bytes([code]).decode('shift_jis')
    table[0x5c] = "\u00a5"  # Yen
    table[0x7e] = "\u203e"  # Overbar
    decoding_table = ''.join(table)
    encoding_table = codecs.charmap_build(decoding_table)

    class Codec(codecs.Codec):
        def encode(self, input, errors='strict'):
            return codecs.charmap_encode(input, errors, encoding_table)

        def decode(self, input, errors='strict'):
            return codecs.charmap_decode(input, errors, decoding_table)

    class IncrementalEncoder(codecs.IncrementalEncoder):
        def encode(self, input, final=False):
            return codecs.charmap_encode(input, self.errors, encoding_table)[0]

    class IncrementalDecoder(codecs.IncrementalDecoder):
        def decode(self, input, final=False):
            return codecs.charmap_decode(input, self.errors, decoding_table)[0]

    class StreamWriter(Codec, codecs.StreamWriter):
        pass

    class StreamReader(Codec, codecs.StreamReader):
        pass

    aliases = ('jis_x0201', 'jis_x_0201', 'x0201', 'x_0201')
    stateless = Codec()

    def search(name):
        # Codec search hook: answer only for our own aliases.
        if name not in aliases:
            return None
        return codecs.CodecInfo(
            name=name,
            encode=stateless.encode,
            decode=stateless.decode,
            incrementalencoder=IncrementalEncoder,
            incrementaldecoder=IncrementalDecoder,
            streamreader=StreamReader,
            streamwriter=StreamWriter)

    codecs.register(search)


load_jis_x_0201()
def decode_name(name):
    """Decode a raw internal-name byte sequence with the ``jis_x_0201`` codec.

    Undecodable bytes are replaced with U+FFFD rather than raising.
    """
    return str(bytes(name), 'jis_x_0201', 'replace')
class LoromMapper:
    """LoROM translation: the upper 32 KiB ($8000-$FFFF) of each bank is ROM."""

    def __init__(self, default=False):
        # When true, hdrid() accepts any header so this mapper can serve as a
        # catch-all candidate.
        self.default = default

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x20 or 0x30 (bit 4 ignored), or anything if default."""
        return self.default or (mapbyte & 0xEF) == 0x20

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset modulo ``sz``; None if not ROM."""
        bank = (addr >> 16) & 0xff
        if not addr & 0x8000 or bank in (0x7e, 0x7f):
            # Lower half of a bank, or the work RAM banks.
            return None
        bank_base = (addr & 0x7F0000) >> 1
        return ((addr & 0x7FFF) + bank_base) % sz
class HiromMapper:
    """HiROM translation: addresses map linearly onto a 4 MiB ROM window."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x21 or 0x31 (bit 4 ignored)."""
        return (mapbyte & 0xEF) == 0x21

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset modulo ``sz``; None if not ROM."""
        bank, offset = (addr >> 16) & 0xff, addr & 0xffff
        if bank in (0x7e, 0x7f):
            return None  # work RAM
        if (bank & 0x7f) < 0x40 and offset < 0x8000:
            return None  # low half of banks $00-$3F / $80-$BF
        return (addr & 0x3fffff) % sz
class ExhiromMapper:
    """ExHiROM translation: banks $80+ map to the first 4 MiB, lower banks to the next 4 MiB."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x25 or 0x35 (bit 4 ignored)."""
        return (mapbyte & 0xEF) == 0x25

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset modulo ``sz``; None if not ROM."""
        bank = (addr >> 16) & 0xff
        offset = addr & 0xffff
        if (bank & 0x7f) < 0x40 and offset < 0x8000:
            return None  # low half of banks $00-$3F / $80-$BF
        if bank in (0x7e, 0x7f):
            return None  # work RAM
        upper = 0 if addr >= 0x800000 else 0x400000
        return ((addr & 0x3fffff) | upper) % sz
class ExloromMapper:
    """ExLoROM translation: LoROM-style 32 KiB banks with bit 23 of the bank flipped."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x24 or 0x34 (bit 4 ignored)."""
        # NOTE: Made-up by pirates.
        return (mapbyte & 0xEF) == 0x24

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset modulo ``sz`` (never None)."""
        flipped_bank = (addr ^ 0x800000) & 0xff0000
        return ((flipped_bank >> 1) + (addr & 0x7fff)) % sz
class SfxMapper:
    """SuperFX translation, selected by map-mode 0x20 plus a SuperFX cart type."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x20 only together with cart type 0x13/0x14/0x15/0x1A."""
        if mapbyte != 0x20:
            return False
        return cartbyte in (0x13, 0x14, 0x15, 0x1A)

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset modulo ``sz``; None if not ROM."""
        unmapped = (
            (addr & 0x800000) == 0x800000
            or (addr & 0x600000) == 0x600000
            or (addr & 0x408000) == 0
        )
        if unmapped:
            # WRAM, SRAM, open bus, hardware registers, ram mirrors, rom mirrors, fastrom
            return None
        if not addr & 0x400000:
            # LoROM-style view: 32 KiB per bank.
            return (((addr & 0x7F0000) >> 1) | (addr & 0x7FFF)) % sz
        return (addr & 0x3FFFFF) % sz
class SA1Mapper:
    """SA-1 translation using a fixed assignment of 1 MiB ROM blocks."""

    def __init__(self):
        # Indexed by bits 21-23 of the address (slots 0, 1, 4, 5 are used);
        # each entry is the byte offset of a 1 MiB block. These are variable
        # on real hardware, but probably fine for our purposes (at least 0!).
        self.sa1banks = [0, 1 << 20, 0, 0, 2 << 20, 3 << 20]

    def hdrid(self, mapbyte, cartbyte):
        """Accept map-mode 0x23 with cart type 0x32, 0x34 or 0x35."""
        return mapbyte == 0x23 and cartbyte in (0x32, 0x34, 0x35)

    def tofile(self, addr, sz):
        """Map SNES ``addr`` to a file offset (``sz`` unused); None if not ROM."""
        banked = (addr & 0x408000) == 0x008000
        linear = (addr & 0xC00000) == 0xC00000
        if banked:
            slot = (addr & 0xE00000) >> 21
            within = ((addr & 0x1F0000) >> 1) | (addr & 0x7FFF)
        elif linear:
            slot = ((addr & 0x100000) >> 20) | ((addr & 0x200000) >> 19)
            within = addr & 0x0FFFFF
        else:
            return None
        return self.sa1banks[slot] | within
# Candidate address mappers, tried in this order by sanity_check(). The final
# LoromMapper(True) accepts any map-mode byte, so it acts as a fallback probe.
mappers = (
    LoromMapper(),
    HiromMapper(),
    ExhiromMapper(),
    ExloromMapper(),
    SfxMapper(),
    SA1Mapper(),
    LoromMapper(True),
)
class Tracker:
    """Accumulate one ``record`` per handled file and log each at INFO level."""

    def __init__(self):
        # Every posted record, in posting order.
        self.results = []

    def post(self, file, ok, oper, reason):
        """Store and log the outcome of operation ``oper`` on ``file``."""
        entry = record(file=file, ok=ok, oper=oper, reason=reason)
        self.results.append(entry)
        logging.info("Processed %s :: %s :: %s", file, oper, reason)

    def _matching(self, keep):
        # Recorded entries for which keep(entry) is true, in posting order.
        return [entry for entry in self.results if keep(entry)]

    def failures(self):
        """Records whose operation did not succeed."""
        return self._matching(lambda entry: not entry.ok)

    def copied(self):
        """Successful COPY records."""
        return self._matching(lambda entry: entry.ok and entry.oper == COPY)

    def unheadered(self):
        """Successful UNHEADER records."""
        return self._matching(lambda entry: entry.ok and entry.oper == UNHEADER)

    def skipped(self):
        """SKIP records, successful or not."""
        return self._matching(lambda entry: entry.oper == SKIP)
if __name__ == '__main__':
    # main() normally returns an int exit status. None means success, and any
    # other value is mapped by truthiness (so True -> 0 and False -> 1).
    res = main()
    if res is None:
        status = 0
    elif type(res) is int:
        status = res
    else:
        status = 0 if res else 1
    sys.exit(status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment