Skip to content

Instantly share code, notes, and snippets.

@bkyle
Forked from omots/unheader.py
Created October 10, 2023 12:56
Show Gist options
  • Save bkyle/683b91850ede159cbaedb00e3826d63a to your computer and use it in GitHub Desktop.
Save bkyle/683b91850ede159cbaedb00e3826d63a to your computer and use it in GitHub Desktop.
SNES unheadering tool
#!/usr/bin/env python3
# unheader.py - Portable command-line SNES ROM unheadering tool
# By oldmanoftheSEA (omotss@twitter, omots@github)
# THIS UTILITY COMES WITH NO WARRANTY EXPRESSED OR IMPLIED
# Backups are ALWAYS a good idea.
# Principles & support:
# This tool is generally conservative when it comes to making modifications.
# Input ROMs must pass sanity checks before any action is taken. Questionable
# dumps and poorly-made hacks will almost certainly fail, possibly along with
# a handful of legitimate specimens. Satellaview and BIOS files will notably be
# skipped.
# This tool also needs to be nudged to overwrite files: -y when unheadering
# files in-place, and -x when also using -d.
# Basic usage:
# * Unheader a single file:
# python3 unheader.py <file> -y # in-place
# python3 unheader.py <file> -o <out-file> # specify output file
# * Unheader multiple files
# python3 unheader.py <file1> <file2> <fileN..> -y # in-place
# python3 unheader.py <file1> <file2> <fileN..> -o <out-dir> # specify output directory
# * Unheader everything in the specified folder(s)
# python3 unheader.py <dir1> <dirN..> -d -y # in-place
# python3 unheader.py <dir1> <dirN..> -d -o <out-dir> # specify output directory
# python3 unheader.py <dir1> <dirN..> -d -o <out-dir> -x # overwrite files in output directory
# Other useful flags:
# -n No-op: Do not write any files.
# -f Force: Turn off sanity checks (not recommended for multiple inputs)
# -l Debug: Enable debugging output
import argparse
import codecs
import io
import logging
import os
import sys
from collections import namedtuple
# One entry per processed file: its path, a success flag, one of the operation
# constants below, and a human-readable reason.
record = namedtuple("record", ["file", "ok", "oper", "reason"])
# File extensions (compared lower-case) that process_dir() will pick up.
extensions = ('.sfc', '.smc')
# 4096-byte scratch buffer shared by copystream() and sanity_check().
buf = bytearray(4096)
# View over buf so slices of it can be read into / written from without copying.
bufv = memoryview(buf)
# Operation constants
COPY = "COPY"
UNHEADER = "UNHEADER"
SKIP = "SKIP"
def main():
    """Command-line entry point.

    Parses the arguments, hands every input to process() or process_dir()
    and finally prints a summary (and the list of failures) to stderr.

    Returns:
        0 when nothing failed, 1 when at least one file failed, and 2 on a
        usage error (no -o/-y, or -o is not a directory for multiple inputs).
    """
    parser = argparse.ArgumentParser(description="SNES unheader script")
    parser.add_argument("inputs", nargs='+', help="ROM or folder of ROMs (if -d/--dir is supplied) to unheader")
    parser.add_argument("-o", "--out", help="Output file or folder. Must be a folder if multiple inputs are specified. If unspecified, the input file will be replaced iff -y is specified.")
    parser.add_argument("-d", "--dir", action="store_const", const=True, help="Input is a directory")
    parser.add_argument("-f", "--force", action="store_const", const=True, help="Disable safety checks")
    parser.add_argument("-l", "--debug", action="store_const", const=True, help="Enable debug logging")
    parser.add_argument("-x", "--overwrite", action="store_const", const=True, help="Overwrite existing files when using -d")
    parser.add_argument("-g", "--log", default=None, help="Write logger output to this file in addition to stderr")
    parser.add_argument("-n", "--nop", "--noop", "--no-op", action="store_const", const=True, help="Trial run, don't write any output")
    parser.add_argument("-y", "--yes-i-am-sure", action="store_const", const=True, help="Must be supplied when unheadering in-place (no -o specified)")
    args = parser.parse_args()

    # An empty -o value is treated exactly like a missing one, so the in-place
    # safety check and the dispatch below always agree with each other.
    out = args.out or None
    if out is None and not args.yes_i_am_sure:
        print("You must either supply -o or --yes-i-am-sure", file=sys.stderr)
        return 2

    loglevel = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=loglevel)
    if args.log is not None:
        handler = logging.FileHandler(args.log)
        handler.setLevel(loglevel)
        logging.getLogger().addHandler(handler)

    tracker = Tracker()
    sanity = not args.force

    if len(args.inputs) > 1 and not args.dir and out is not None:
        # Out must be a directory.
        if os.path.exists(out):
            if not os.path.isdir(out):
                logging.error("If multiple inputs are specified, the output must be a directory")
                return 2
        elif not args.nop:
            # A trial run must not touch the file system, matching process_dir().
            os.makedirs(out)

    if args.dir:
        for inp in args.inputs:
            process_dir(tracker, inp, out, sanity=sanity, nop=args.nop, overwrite=args.overwrite)
    elif len(args.inputs) == 1:
        inp = args.inputs[0]
        if out is None:
            target = inp
        elif os.path.isdir(out):
            # -o names an existing folder: write into it instead of trying to
            # replace the folder itself with the ROM.
            target = os.path.join(out, os.path.basename(inp))
        else:
            target = out
        process(tracker, inp, target, sanity=sanity, nop=args.nop)
    else:
        for inp in args.inputs:
            target = os.path.join(out, os.path.basename(inp)) if out is not None else inp
            process(tracker, inp, target, sanity=sanity, nop=args.nop)

    print("Task completed.", file=sys.stderr)
    print(" * Processed {} file(s)".format(len(tracker.results)), file=sys.stderr)
    print(" * Skipped {} file(s)".format(len(tracker.skipped())), file=sys.stderr)
    print(" * Copied {} file(s)".format(len(tracker.copied())), file=sys.stderr)
    print(" * Unheadered {} file(s)".format(len(tracker.unheadered())), file=sys.stderr)
    print(file=sys.stderr)

    failures = tracker.failures()
    if failures:
        print("FAILED: {} file(s)".format(len(failures)), file=sys.stderr)
        for rec in failures:
            print(" * {} [{}]: {}".format(rec.file, rec.oper, rec.reason), file=sys.stderr)
        print(file=sys.stderr)
        return 1
    return 0
def process_dir(tracker, inpath, outpath, sanity=True, nop=False, overwrite=False):
    """Walk inpath recursively and process every ROM file found.

    Args:
        tracker: Tracker that receives one record per file encountered.
        inpath: Directory to scan.
        outpath: Root of the mirrored output tree, or None to work in-place.
        sanity: Forwarded to process(); False disables the sanity checks.
        nop: When true, no directories are created and no files are written.
        overwrite: When false, files that already exist below outpath are
            skipped instead of being replaced.
    """
    for root, dirs, files in os.walk(inpath):
        if outpath is not None:
            # normpath removes the "." that relpath yields for the top level;
            # a path such as "out/." cannot be created when "out" is missing.
            outdir = os.path.normpath(os.path.join(outpath, os.path.relpath(root, inpath)))
            if not os.path.exists(outdir):
                if not nop:
                    logging.info("Creating %s", outdir)
                    # makedirs so that a nested, not-yet-existing -o works too.
                    os.makedirs(outdir)
            elif not os.path.isdir(outdir):
                # Uh-oh!
                logging.warning("Skipping directory %s because %s exists and is not a directory", root, outdir)
                # Prune the walk: nothing below this point can be mirrored
                # underneath a regular file either.
                dirs[:] = []
                continue
        else:
            # In-place.
            outdir = root

        for file in files:
            infile = os.path.join(root, file)
            if not file.lower().endswith(extensions):
                tracker.post(infile, True, SKIP, "Invalid extension")
                continue
            outfile = os.path.join(outdir, file)
            if outpath is not None and not overwrite and os.path.exists(outfile):
                tracker.post(infile, True, SKIP, "Skipping because output file exists and -x was not specified")
            else:
                process(tracker, infile, outfile, sanity=sanity, nop=nop)
def _attempt(tracker, oper, action, infile, outfile, nop, success):
    """Run action(infile, outfile, nop) and post the outcome under oper."""
    try:
        action(infile, outfile, nop)
    except OSError as ex:
        # strerror is None for OSErrors built without an errno/message pair.
        tracker.post(infile, False, oper, ex.strerror or str(ex))
        return
    tracker.post(infile, True, oper, success)


def process(tracker, infile, outfile, sanity=True, nop=False):
    """Unheader or copy a single ROM file and record the result.

    The decision is driven by the file size modulo 0x4000: a remainder of
    0x200 is taken as a copier header, a remainder of 0 as a headerless
    image, anything else is rejected.

    Args:
        tracker: Tracker that receives exactly one record for this file.
        infile: Path of the ROM to read.
        outfile: Path to write; may equal infile for in-place operation.
        sanity: When false, the sanity checks are skipped entirely and the
            size-based decision is trusted.
        nop: When true, nothing is written.
    """
    if not os.path.isfile(infile):
        tracker.post(infile, False, SKIP, "File doesn't exist or isn't a regular file")
        return

    remainder = os.stat(infile).st_size & 0x3FFF

    if remainder == 0x200:
        # Header detected, sanity check first.
        if sanity and not sanity_check(infile, True):
            # Try no header?
            if sanity_check(infile, False):
                _attempt(tracker, COPY, copy, infile, outfile, nop,
                         "No header detected, copied as-is")
            else:
                tracker.post(infile, False, UNHEADER, "Header detected, but failed sanity checks")
            return
        # OK, unheader.
        _attempt(tracker, UNHEADER, unheader, infile, outfile, nop, "Header removed")
    elif remainder == 0:
        # Honour -f here as well: previously headerless files were always
        # sanity checked, even with the checks disabled.
        if not sanity or sanity_check(infile, False):
            _attempt(tracker, COPY, copy, infile, outfile, nop,
                     "No header detected, copied as-is")
        else:
            tracker.post(infile, False, COPY, "No header detected, but failed sanity checks")
    else:
        tracker.post(infile, False, SKIP, "Unexpected file size")
def unheader(infile, outfile, nop):
    """Write infile to outfile without its first 0x200 bytes.

    When outfile already exists the data is staged in "<outfile>.partial"
    and moved over the target afterwards. A staging file left behind by a
    failed copy is removed before the error propagates. Does nothing when
    nop is true.
    """
    if nop:
        return

    staging = outfile
    if os.path.exists(outfile):
        staging = outfile + ".partial"

    try:
        with open(infile, 'rb') as src, open(staging, 'wb') as dst:
            # Skip the copier header.
            src.seek(0x200, io.SEEK_SET)
            copystream(src, dst)
    except BaseException:
        if os.path.exists(staging):
            os.unlink(staging)
        raise

    if staging != outfile:
        replace(staging, outfile)
def copy(infile, outfile, nop):
    """Copy infile to outfile byte for byte.

    When outfile already exists the data is staged in "<outfile>.partial"
    and moved over the target afterwards. A staging file left behind by a
    failed copy is removed before the error propagates. Does nothing when
    nop is true.
    """
    if nop:
        return

    staging = outfile
    if os.path.exists(outfile):
        staging = outfile + ".partial"

    try:
        with open(infile, 'rb') as src, open(staging, 'wb') as dst:
            copystream(src, dst)
    except BaseException:
        if os.path.exists(staging):
            os.unlink(staging)
        raise

    if staging != outfile:
        replace(staging, outfile)
def copystream(fin, fout):
    """Copy everything from fin's current position to fout.

    Data moves through the shared module-level buffer one chunk at a time.

    Raises:
        OSError: if a read reports a negative count or a write is short.
    """
    # readinto() returning 0 marks end of input and ends the iteration.
    for count in iter(lambda: fin.readinto(buf), 0):
        if count < 0:
            raise OSError(-1, "Read error")
        if fout.write(bufv[:count]) != count:
            raise OSError(-1, "Write error")
def replace(tmpfile, outfile):
    """Move tmpfile onto outfile, replacing outfile if it already exists.

    os.replace() overwrites an existing destination on every platform, so the
    old output is never deleted before the new one is in place (the previous
    unlink-then-rename sequence lost the file if the rename failed).

    Raises:
        OSError: if the move fails, e.g. because outfile is a directory.
    """
    os.replace(tmpfile, outfile)
def sanity_check(filename, header):
    """Decide whether a file plausibly contains a SNES ROM image.

    Every mapper in the module-level ``mappers`` tuple is tried in order. A
    mapper whose hdrid() accepts the bytes found at $FFD5/$FFD6 is scored
    against the other internal-header fields; the first mapper that reaches
    a score of 4 or more makes the file pass.

    Args:
        filename: Path of the file to inspect.
        header: True to assume a 0x200-byte copier header precedes the ROM.

    Returns:
        True if some mapper scored high enough, False otherwise (including
        when the file holds no ROM data at all).
    """
    hoffset = 0x200 if header else 0
    logging.debug("Testing sanity for %s ...", filename)

    # Determine size first
    sz = os.stat(filename).st_size - hoffset
    if sz <= 0:
        # The mappers reduce offsets modulo sz; an empty or header-only file
        # would otherwise end in a ZeroDivisionError.
        logging.debug(" - No ROM data to inspect")
        return False

    with open(filename, 'rb') as f:

        def readaddr(addr, count=1):
            # Reads count bytes at SNES address addr into the start of buf,
            # using the mapper currently being tried by the loop below.
            offset = mapper.tofile(addr, sz)
            if offset is None:
                # Address is not backed by ROM for this mapper.
                return False
            f.seek(offset + hoffset, io.SEEK_SET)
            return f.readinto(bufv[:count]) == count

        for mapper in mappers:
            # Use mapper to access the internal header and validate the data we find there.
            logging.debug(" - Trying mapper %s", type(mapper).__name__)
            if not (readaddr(0xFFD5, 2) and mapper.hdrid(buf[0], buf[1])):
                continue
            logging.debug(" - Mapper OK")
            # Potential candidate. Do the other tests..
            checks = 0
            # Fixed value pt 1
            if readaddr(0xFFB6, 7) and all(buf[i] == 0 for i in range(7)):
                checks += 1
            else:
                logging.debug(" - Failed Fixed value 1 test")
            # Fixed value pt 2
            if readaddr(0xFFDA) and buf[0] == 0x33:
                checks += 1
            else:
                logging.debug(" - Failed Fixed value 2 test %d", buf[0])
            # Destination code
            if readaddr(0xFFD9) and buf[0] in (0, 1, 2):
                checks += 1
            else:
                logging.debug(" - Failed destination code test %d", buf[0])
            # Cart type
            if readaddr(0xFFD6) and buf[0] in (0, 1, 2, 0x13, 0x14, 0x15, 0x1A, 0x33, 0x34, 0x35):
                checks += 1
            else:
                logging.debug(" - Failed cart type test %d", buf[0])
            # Checksum + inverse check: a mismatch counts against the file
            # instead of merely not counting for it.
            if readaddr(0xFFDC, 4) and buf[0] == (buf[2] ^ 0xff) and buf[1] == (buf[3] ^ 0xff):
                checks += 2
            else:
                checks -= 2
                logging.debug(" - Failed cart checksum test")
            # Internal name
            if readaddr(0xFFC0, 21) and check_name(bufv[:21]):
                checks += 2
            else:
                logging.debug(" - Failed cart internal name test")
            if checks >= 4:
                logging.debug(" > Mapper %s succeeded with score = %d", type(mapper).__name__, checks)
                return True
            logging.debug(" > Mapper %s failed with score = %d", type(mapper).__name__, checks)
    return False
# Byte values used by check_name() to classify title characters.
c_A = ord('A')
c_Z = ord('Z')
c_a = ord('a')
c_z = ord('z')
c_sp = ord(' ')
c_0 = ord('0')
c_9 = ord('9')
c_dash = ord('-')


def check_name(name):
    """Score an internal cartridge title and say whether it looks genuine.

    Args:
        name: Iterable of byte values (bytes, bytearray or memoryview).

    Returns:
        True when the accumulated score is at least 20.

    Scoring per byte: upper-case letters and bytes 0xA1-0xDF (katakana) +3,
    space +2, digits / lower-case letters / '-' +1, other printable ASCII -3,
    control characters, 0x7F and other high bytes -6. A NUL byte is neutral
    and starts the padding; every non-NUL byte after it costs an extra 4.
    """
    # Stolen from asar's setmapper() in main.cpp
    foundnil = False
    score = 0
    for c in name:
        if foundnil and c != 0:
            score -= 4
        if c == 0:
            # Must be tested before the "unprintable" range below, otherwise
            # NUL padding is penalised and this branch can never be reached.
            foundnil = True
        elif 0xA1 <= c <= 0xDF:  # Katakana
            score += 3
        elif c >= 128:  # Invalid high bit code point
            score -= 6
        elif c < 0x20 or c == 0x7f:  # Unprintable
            score -= 6
        elif c_A <= c <= c_Z:
            score += 3
        elif c == c_sp:
            score += 2
        elif (c_0 <= c <= c_9) or (c_a <= c <= c_z) or c == c_dash:
            score += 1
        else:
            score -= 3
    # Decoding the title is only worth doing when the message will be emitted.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(" - Internal name \"%s\", score = %d", decode_name(name), score)
    return score >= 20
def load_jis_x_0201():
    """Build a single-byte JIS X 0201 charmap codec and register it.

    The mapping is derived from Python's shift_jis codec for bytes 0x00-0x7F
    and 0xA1-0xDF, with 0x5C and 0x7E overridden. Once registered, the codec
    can be looked up as 'jis_x0201', 'jis_x_0201', 'x0201' or 'x_0201'.
    """
    def from_sjis(code):
        return bytes([code]).decode('shift_jis')

    # Bytes outside the two ranges below keep the '\ufffe' placeholder.
    table = ['\ufffe'] * 256
    # Load relevant parts of Shift-JIS
    for code in list(range(128)) + list(range(0xA1, 0xE0)):
        table[code] = from_sjis(code)
    # Fixes
    table[0x5c] = "\u00a5"  # Yen
    table[0x7e] = "\u203e"  # Overbar

    decoding_table = ''.join(table)
    encoding_table = codecs.charmap_build(decoding_table)
    aliases = ('jis_x0201', 'jis_x_0201', 'x0201', 'x_0201')

    class Codec(codecs.Codec):
        def encode(self, input, errors='strict'):
            return codecs.charmap_encode(input, errors, encoding_table)

        def decode(self, input, errors='strict'):
            return codecs.charmap_decode(input, errors, decoding_table)

    class IncrementalEncoder(codecs.IncrementalEncoder):
        def encode(self, input, final=False):
            encoded, _ = codecs.charmap_encode(input, self.errors, encoding_table)
            return encoded

    class IncrementalDecoder(codecs.IncrementalDecoder):
        def decode(self, input, final=False):
            decoded, _ = codecs.charmap_decode(input, self.errors, decoding_table)
            return decoded

    class StreamWriter(Codec, codecs.StreamWriter):
        pass

    class StreamReader(Codec, codecs.StreamReader):
        pass

    def search(name):
        # Codec search function: answer only for our aliases.
        if name not in aliases:
            return None
        return codecs.CodecInfo(
            name=name,
            encode=Codec().encode,
            decode=Codec().decode,
            incrementalencoder=IncrementalEncoder,
            incrementaldecoder=IncrementalDecoder,
            streamreader=StreamReader,
            streamwriter=StreamWriter)

    codecs.register(search)


load_jis_x_0201()
def decode_name(name):
    """Decode an internal ROM title to text using the 'jis_x_0201' codec.

    Bytes the codec cannot map are substituted by the 'replace' error handler
    instead of raising.
    """
    raw = bytes(name)
    return raw.decode('jis_x_0201', 'replace')
class LoromMapper:
    """Address translation and header recognition for LoROM images."""

    def __init__(self, default=False):
        # Always try LoRom if nothing else worked: with default set, hdrid()
        # accepts any map byte.
        self.default = default

    def hdrid(self, mapbyte, cartbyte):
        """Accept map bytes 0x20 and 0x30 (bit 4 is ignored)."""
        return self.default or (mapbyte & 0xEF) == 0x20

    def tofile(self, addr, sz):
        """Return the file offset for addr, or None if it is not ROM."""
        bank = (addr >> 16) & 0xff
        if not addr & 0x8000:
            # Out of bounds
            return None
        if bank in (0x7e, 0x7f):
            # Work ram
            return None
        # 32 KiB per bank; the top bank bit does not take part.
        return (((bank & 0x7f) << 15) + (addr & 0x7FFF)) % sz
class HiromMapper:
    """Address translation and header recognition for HiROM images."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map bytes 0x21 and 0x31 (bit 4 is ignored)."""
        return (mapbyte & 0xEF) == 0x21

    def tofile(self, addr, sz):
        """Return the file offset for addr, or None if it is not ROM."""
        bank = (addr >> 16) & 0xff
        offset_in_bank = addr & 0xffff
        # bank address < $8000 and bank < $40
        if (bank & 0x7f) < 0x40 and offset_in_bank < 0x8000:
            return None
        # Work ram
        if bank in (0x7e, 0x7f):
            return None
        return (addr & 0x3fffff) % sz
class ExhiromMapper:
    """Address translation and header recognition for ExHiROM images."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map bytes 0x25 and 0x35 (bit 4 is ignored)."""
        return (mapbyte & 0xEF) == 0x25

    def tofile(self, addr, sz):
        """Return the file offset for addr, or None if it is not ROM."""
        bank = (addr >> 16) & 0xff
        offset_in_bank = addr & 0xffff
        # bank address < $8000 and bank < $40
        if (bank & 0x7f) < 0x40 and offset_in_bank < 0x8000:
            return None
        # Work ram
        if bank in (0x7e, 0x7f):
            return None
        linear = addr & 0x3fffff
        if addr < 0x800000:
            # Addresses below $800000 land in the second 4 MiB of the file.
            linear |= 0x400000
        return linear % sz
class ExloromMapper:
    """Address translation and header recognition for ExLoROM images."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map bytes 0x24 and 0x34 (bit 4 is ignored)."""
        # NOTE: Made-up by pirates.
        return (mapbyte & 0xEF) == 0x24

    def tofile(self, addr, sz):
        """Return the file offset for addr; every address is mapped."""
        # Flip the top address bit, then pack the bank number above a
        # 15-bit in-bank offset.
        bank_part = ((addr ^ 0x800000) & 0xff0000) >> 1
        return (bank_part + (addr & 0x7fff)) % sz
class SfxMapper:
    """Address translation and header recognition for the Sfx mapper."""

    def hdrid(self, mapbyte, cartbyte):
        """Accept map byte 0x20 combined with cart type 0x13, 0x14, 0x15 or 0x1A."""
        if mapbyte != 0x20:
            return False
        return cartbyte in (0x13, 0x14, 0x15, 0x1A)

    def tofile(self, addr, sz):
        """Return the file offset for addr, or None if it is not ROM."""
        # WRAM, SRAM, open bus, hardware registers, ram mirrors, rom mirrors, fastrom
        unmapped = (
            (addr & 0x600000) == 0x600000
            or (addr & 0x408000) == 0
            or (addr & 0x800000) == 0x800000
        )
        if unmapped:
            return None
        if addr & 0x400000:
            linear = addr & 0x3FFFFF
        else:
            linear = ((addr & 0x7F0000) >> 1) | (addr & 0x7FFF)
        return linear % sz
class SA1Mapper:
    """Address translation and header recognition for the SA1 mapper."""

    def __init__(self):
        # These are variable -- but probably fine for our purposes (at least 0!)
        self.sa1banks = [0, 1 << 20, 0, 0, 2 << 20, 3 << 20]

    def hdrid(self, mapbyte, cartbyte):
        """Accept map byte 0x23 combined with cart type 0x32, 0x34 or 0x35."""
        if mapbyte != 0x23:
            return False
        return cartbyte in (0x32, 0x34, 0x35)

    def tofile(self, addr, sz):
        """Return the file offset for addr, or None if it is not ROM.

        Unlike the other mappers the result is not reduced modulo sz.
        """
        if (addr & 0x408000) == 0x008000:
            base = self.sa1banks[(addr & 0xE00000) >> 21]
            return base | ((addr & 0x1F0000) >> 1) | (addr & 0x7FFF)
        if (addr & 0xC00000) == 0xC00000:
            slot = ((addr & 0x100000) >> 20) | ((addr & 0x200000) >> 19)
            return self.sa1banks[slot] | (addr & 0x0FFFFF)
        return None
# Mappers in the order sanity_check() tries them. The trailing LoromMapper(True)
# accepts any map byte, so LoROM is attempted once more as a last resort.
mappers = (LoromMapper(), HiromMapper(), ExhiromMapper(), ExloromMapper(), SfxMapper(), SA1Mapper(), LoromMapper(True))
class Tracker:
    """Collects one record per handled file and answers summary queries."""

    def __init__(self):
        # Records in the order they were posted.
        self.results = []

    def post(self, file, ok, oper, reason):
        """Store the outcome for one file and log it at INFO level."""
        entry = record(file=file, ok=ok, oper=oper, reason=reason)
        self.results.append(entry)
        logging.info("Processed %s :: %s :: %s", file, oper, reason)

    def _select(self, predicate):
        # Private helper: the records satisfying predicate, in posting order.
        return list(filter(predicate, self.results))

    def failures(self):
        """Records whose ok flag is false, whatever the operation."""
        return self._select(lambda rec: not rec.ok)

    def copied(self):
        """Successful COPY records."""
        return self._select(lambda rec: rec.oper == COPY and rec.ok)

    def unheadered(self):
        """Successful UNHEADER records."""
        return self._select(lambda rec: rec.oper == UNHEADER and rec.ok)

    def skipped(self):
        """All SKIP records, successful or not."""
        return self._select(lambda rec: rec.oper == SKIP)
if __name__ == '__main__':
    # Turn main()'s result into a process exit status: None means success, a
    # plain int is used as-is, anything else exits 0 when truthy and 1 otherwise.
    status = main()
    if status is None:
        exit_code = 0
    elif type(status) is int:
        exit_code = status
    else:
        exit_code = 0 if status else 1
    sys.exit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment