Skip to content

Instantly share code, notes, and snippets.

@charasyn
Last active December 19, 2021 14:53
Show Gist options
  • Save charasyn/d083d651121eb5dd260bb99bdc589dad to your computer and use it in GitHub Desktop.
Save charasyn/d083d651121eb5dd260bb99bdc589dad to your computer and use it in GitHub Desktop.
# CoilSnake code :) pls don't hurt me
import array
import copy
import os
from zlib import crc32
# Error names used by the Block code. Both are plain aliases of ValueError,
# so `except ValueError` (or either name) catches errors raised under either.
InvalidArgumentError = ValueError
OutOfBoundsError = ValueError
def check_range_validity(range, size):
    """Check that the inclusive byte range ``(begin, end)`` fits a buffer of ``size`` bytes.

    Returns None when the range is valid.

    Raises:
        InvalidArgumentError: ``end`` is smaller than ``begin``.
        OutOfBoundsError: ``begin`` is negative or ``end`` is ``size`` or larger.
    """
    begin, end = range
    if begin > end:
        failure = InvalidArgumentError
    elif begin < 0 or end >= size:
        failure = OutOfBoundsError
    else:
        return
    raise failure("Invalid range[(%#x,%#x)] provided" % (begin, end))
def fix_slice(key, size):
    """Return a copy of slice ``key`` with explicit start/stop for a buffer of ``size``.

    ``None`` bounds default to ``0`` (start) and ``size`` (stop). A negative
    bound has ``size`` added to it once; no clamping is performed. The
    returned slice has no step.

    Raises:
        InvalidArgumentError: the slice step is anything other than 1 or None.
    """
    if not (key.step is None or key.step == 1):
        raise InvalidArgumentError("Slice step must be 1 or None, but is {}".format(key.step))

    def resolve(bound, default):
        if bound is None:
            return default
        return bound + size if bound < 0 else bound

    return slice(resolve(key.start, 0), resolve(key.stop, size))
class Block(object):
    """Fixed-size mutable byte buffer with bounds-checked access.

    The bytes live in ``self.data`` (an ``array.array('B')``), and ``self.size``
    mirrors its length. Integer indexing reads or writes single bytes. Slice
    indexing copies out a new Block, or writes a whole range in. The methods
    ``read_multi`` and ``write_multi`` handle little-endian multi-byte integers.

    When used as a context manager, the block deletes ``self.data`` on exit,
    so it must not be used after the ``with`` statement.
    """

    def __init__(self, size=0):
        self.reset(size)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Drop the backing buffer; the block is unusable afterwards.
        del self.data

    def reset(self, size=0):
        """Replace the contents with ``size`` zero bytes."""
        # Sequence repetition avoids building a temporary Python list of zeros.
        self.data = array.array('B', [0]) * size
        self.size = size

    def from_list(self, data_list):
        """Replace the contents with the byte values (0-255) in ``data_list``."""
        # Build the new buffer first, so a bad value leaves this block untouched
        # instead of emptied with a stale size.
        new_data = array.array('B', data_list)
        self.data = new_data
        self.size = len(new_data)

    def from_array(self, data_array):
        """Replace the contents with a shallow copy of ``data_array``."""
        self.data = copy.copy(data_array)
        self.size = len(self.data)

    def from_block(self, block, offset=0, size=None):
        """Replace the contents with ``size`` bytes of ``block`` starting at ``offset``.

        ``size`` defaults to everything from ``offset`` to the end of ``block``.
        """
        if size is None:
            size = block.size - offset
        with block[offset:offset + size] as sub_block:
            self.size = sub_block.size
            # Keep the array; __exit__ only removes sub_block's reference to it.
            self.data = sub_block.data

    def to_file(self, filename):
        """Write the raw bytes to ``filename``, replacing any existing file."""
        with open(filename, 'wb') as f:
            self.data.tofile(f)

    def to_list(self):
        """Return the contents as a new list of ints."""
        return self.data.tolist()

    def to_array(self):
        """Return the backing array itself (not a copy)."""
        return self.data

    def to_block(self, block, offset=0):
        """Copy all of ``block`` into this block at ``offset``."""
        self[offset:offset + block.size] = block

    def read_multi(self, key, size):
        """Return the unsigned little-endian integer stored in ``size`` bytes at ``key``.

        A ``size`` of 0 returns 0 without any bounds check.

        Raises:
            InvalidArgumentError: ``size`` is negative.
            OutOfBoundsError: the byte range does not lie inside the block.
        """
        if size < 0:
            raise InvalidArgumentError("Attempted to read data of negative length[%d]" % size)
        elif size == 0:
            return 0
        elif (key < 0) or (key >= self.size) or (key + size > self.size):
            raise OutOfBoundsError("Attempted to read size[%d] bytes from offset[%#x], which is out of bounds in this "
                                   "block of size[%#x]" % (size, key, self.size))
        out = 0
        bit_offset = 0
        for byte in self.data[key:key + size]:
            out |= byte << bit_offset
            bit_offset += 8
        return out

    def write_multi(self, key, item, size):
        """Store ``item`` little-endian in ``size`` bytes at ``key``.

        Only the low ``size`` bytes of ``item`` are written. A ``size`` of 0
        is a no-op without any bounds check, matching ``read_multi``.

        Raises:
            InvalidArgumentError: ``size`` is negative.
            OutOfBoundsError: the byte range does not lie inside the block.
        """
        if size < 0:
            raise InvalidArgumentError("Attempted to write data of negative length[%d]" % size)
        elif size == 0:
            return
        elif (key < 0) or (key >= self.size) or (key + size > self.size):
            raise OutOfBoundsError("Attempted to write size[%d] bytes to offset[%#x], which is out of bounds in this "
                                   "block of size[%#x]" % (size, key, self.size))
        for i in range(key, key + size):
            self.data[i] = item & 0xff
            item >>= 8

    def __getitem__(self, key):
        """Return one byte (int key) or a new Block copying a range (slice key)."""
        if isinstance(key, slice):
            key = fix_slice(key, self.size)
            if key.start > key.stop:
                raise InvalidArgumentError("Second argument of slice %s must be greater than the first" % key)
            elif (key.start < 0) or (key.stop - 1 >= self.size):
                raise OutOfBoundsError("Attempted to read from range (%#x,%#x) which is out of bounds" % (key.start,
                                                                                                           key.stop - 1))
            out = Block()
            # Slicing an array already produces a fresh copy; no second copy needed.
            out.data = self.data[key]
            out.size = len(out.data)
            return out
        elif isinstance(key, int):
            if key >= self.size:
                raise OutOfBoundsError("Attempted to read at offset[%#x] which is out of bounds" % key)
            return self.data[key]
        else:
            raise TypeError("Argument \"key\" had invalid type of %s" % type(key).__name__)

    def __setitem__(self, key, item):
        """Write one byte (int key, int item) or a same-length range (slice key).

        For a slice key, ``item`` may be a list, an ``array.array`` or a Block.
        """
        if isinstance(key, int) and isinstance(item, int):
            if item < 0 or item > 0xff:
                raise InvalidArgumentError("Could not write invalid value[%d] as a single byte" % item)
            if key >= self.size:
                raise OutOfBoundsError("Attempted to write to offset[%#x] which is out of bounds" % key)
            self.data[key] = item
        elif isinstance(key, slice) and isinstance(item, (list, array.array, Block)):
            key = fix_slice(key, self.size)
            if key.start > key.stop:
                raise InvalidArgumentError("Second argument of slice %s must be greater than the first" % key)
            elif (key.start < 0) or (key.stop - 1 >= self.size):
                raise OutOfBoundsError("Attempted to write to range (%#x,%#x) which is out of bounds" % (key.start,
                                                                                                          key.stop - 1))
            elif len(item) != (key.stop - key.start):
                raise InvalidArgumentError("Attempted to write data of size %d to range of length %d" % (
                    len(item), key.stop - key.start))
            elif (key.stop - key.start) == 0:
                raise InvalidArgumentError("Attempted to write data of size 0")
            if isinstance(item, list):
                self.data[key] = array.array('B', item)
            elif isinstance(item, array.array):
                self.data[key] = item
            elif isinstance(item, Block):
                self.data[key] = item.data
            else:
                raise InvalidArgumentError("Can not write value of type[{}]".format(type(item)))
        else:
            raise TypeError("Arguments \"key\" and \"item\" had invalid types of %s and %s" % (type(key).__name__,
                                                                                                 type(item).__name__))

    def __len__(self):
        return self.size

    def __eq__(self, other):
        return (isinstance(other, type(self))) and (self.data == other.data)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Content hash, consistent with __eq__ comparing the byte contents.
        return crc32(self.data)
#!/usr/bin/env python3
# v3 - implement portamento
# v2 - ok I guess we can search ARAM for the song address table pointer
import sys
from dataclasses import dataclass
from typing import Dict, List, Union
import re
from Block import Block
# Sentinel values used in place of an EB opcode in mapAmkCommands.
# No real EB target in the table is below $C8, so these cannot collide.
# ERR: the AMK command has no EB equivalent; fix_track raises ValueError.
# IGN: the AMK command has no EB equivalent; fix_track prints a warning and
#      deletes the command together with its parameter bytes.
ERR = 0x00
IGN = 0x01
# Translation table from AMK track opcodes to EarthBound N-SPC opcodes.
# Each entry is  amk_opcode: (parameter_byte_count, eb_opcode),  where
# eb_opcode may be the ERR or IGN sentinel. Bytes below $C6 (notes,
# durations, parameters) are passed through unchanged by fix_track, so they
# are not listed. Opcodes $FD-$FF are absent, and looking them up raises KeyError.
mapAmkCommands = {
    0xC6: (0,0xC8), # Tie (note)
    0xC7: (0,0xC9), # Rest (note)
    0xC8: (0,0xC9), # Rest (note)
    0xC9: (0,0xC9), # Rest (note)
    0xCA: (0,0xC9), # Rest (note)
    0xCB: (0,0xC9), # Rest (note)
    0xCC: (0,0xC9), # Rest (note)
    0xCD: (0,0xC9), # Rest (note)
    0xCE: (0,0xC9), # Rest (note)
    0xCF: (0,0xC9), # Rest (note)
    0xD0: (0,0xCA), # Percussion
    0xD1: (0,0xCB), # Percussion
    0xD2: (0,0xCC), # Percussion
    0xD3: (0,0xCD), # Percussion
    0xD4: (0,0xCE), # Percussion
    0xD5: (0,0xCF), # Percussion
    0xD6: (0,0xD0), # Percussion
    0xD7: (0,0xD1), # Percussion
    0xD8: (0,0xD2), # Percussion
    0xD9: (0,0xD3), # Percussion
    0xDA: (1,0xE0), # Inst
    0xDB: (1,0xE1), # Pan
    0xDC: (2,0xE2), # PanFade
    0xDD: (3,0xF9), # Portamento
    0xDE: (3,0xE3), # Vibrato
    0xDF: (0,0xE4), # VibratoOff
    0xE0: (1,0xE5), # VolGlobal
    0xE1: (2,0xE6), # VolGlobalFade
    0xE2: (1,0xE7), # Tempo
    0xE3: (2,0xE8), # TempoFade
    0xE4: (1,0xE9), # TrspGlobal
    0xE5: (3,0xEB), # Tremolo
    0xE6: (1,ERR), # Subloop (fix_track unrolls these before any table lookup)
    0xE7: (1,0xED), # Vol
    0xE8: (2,0xEE), # VolFade
    0xE9: (3,0xEF), # Loop (subroutine call: 2-byte address + count; see extract_ebm)
    0xEA: (1,0xF0), # VibratoFade
    0xEB: (3,0xF1), # BendAway
    0xEC: (3,0xF2), # BendTo
    0xED: (2,ERR), # Envelope
    0xEE: (1,0xF4), # Detune
    0xEF: (3,0xF5), # EchoVol
    0xF0: (0,0xF6), # EchoOff
    0xF1: (3,0xF7), # EchoParams
    0xF2: (3,0xF8), # EchoFade
    0xF3: (2,IGN), # SampleLoad
    0xF4: (1,IGN), # ExtF4
    0xF5: (8,ERR), # FIR
    0xF6: (2,ERR), # DSP
    0xF7: (3,ERR), # ARAM
    0xF8: (1,ERR), # Noise
    0xF9: (2,ERR), # DataSend
    0xFA: (2,IGN), # ExtFA
    0xFB: (3,ERR), # Arpeggio
    0xFC: (4,ERR), # Callback
}
# Reverse lookup for converted data: eb_opcode -> (parameter_byte_count, amk_opcode).
# Only EB opcodes >= $E0 are kept. This drops the ERR/IGN sentinels and the
# tie/rest/percussion codes, which take no parameters. cmd_len_eb_fn uses it
# to step over commands in already-converted tracks.
mapEbCommands = {cmd_tup[1]: (cmd_tup[0], cmd_amk) for cmd_amk, cmd_tup in mapAmkCommands.items() if cmd_tup[1] >= 0xE0}
# @dataclass
# class Track:
# pass
# @dataclass
# class Pattern:
# tracks: Dict[int, Track]
# @dataclass
# class RepeatCmd:
# pattern_seq_offset: int
# loop_val: int
# @dataclass
# class Song:
# pattern_seq: List[Union[Pattern,RepeatCmd]]
def cmd_len_fn(cmd):
    """Return the total size in bytes (opcode plus parameters) of AMK track byte ``cmd``.

    Bytes below $DA always occupy one byte. Larger opcodes take their
    parameter count from ``mapAmkCommands`` (KeyError if unlisted).
    """
    if cmd >= 0xda:
        param_count = mapAmkCommands[cmd][0]
        return param_count + 1
    return 1
def cmd_len_eb_fn(cmd):
    """Return the total size in bytes (opcode plus parameters) of EB N-SPC track byte ``cmd``.

    Bytes below $E0 always occupy one byte. Larger opcodes take their
    parameter count from ``mapEbCommands`` (KeyError if unlisted).
    """
    if cmd >= 0xe0:
        param_count = mapEbCommands[cmd][0]
        return param_count + 1
    return 1
def fix_track(track_ptr: int, track_data: List[int]):
    """Convert one AMK track's bytes to EarthBound N-SPC commands, in place.

    Bytes below $C6 (notes, durations, parameters) are left unchanged. Each
    command opcode is rewritten through ``mapAmkCommands``, and its parameter
    bytes are skipped. Commands mapped to IGN are deleted, with a warning.
    AMK subloops (``$E6 $00 <body> $E6 <count>``) are unrolled: the body is
    emitted ``count + 1`` times, and the result is then converted like any
    other data. Nested subloops are not supported.

    Args:
        track_ptr: ARAM address of the track. It is only used in log messages.
        track_data: The track bytes, without the terminating $00. Mutated in place.

    Raises:
        ValueError: a command with no EB equivalent (mapped to ERR).
        AssertionError: a malformed subloop.
        KeyError: an opcode missing from ``mapAmkCommands``.
        IndexError: data ends in the middle of a subloop.
    """
    data_idx = 0

    def consume_byte() -> int:
        nonlocal data_idx
        val = track_data[data_idx]
        data_idx += 1
        assert isinstance(val, int), f"val is not int, is {val}"
        return val

    def convert_command(cmd):
        """Translate or drop the non-subloop command whose opcode was just consumed."""
        nonlocal data_idx
        cmd_len, cmd_eb = mapAmkCommands[cmd]
        if cmd_eb == ERR:
            raise ValueError(f"Can't convert to EBM due to bad cmd ${cmd:02X}")
        elif cmd_eb == IGN:
            # data_idx already points past the opcode; report the opcode itself.
            # This is an offset into the (possibly already edited) track list.
            print(f"${track_ptr + data_idx - 1:04X}: ignoring command ${cmd:02X} with no EB equivalent")
            del track_data[data_idx - 1 : data_idx + cmd_len]
            data_idx -= 1
        else:
            track_data[data_idx - 1] = cmd_eb
            data_idx += cmd_len

    while data_idx < len(track_data):
        cmd = consume_byte()
        if cmd < 0xc6:
            # 1-byte note duration / parameter / note command
            continue
        if cmd != 0xE6:
            convert_command(cmd)
            continue

        # AMK subloop: mapAmkCommands marks $E6 as ERR, so unroll it here instead.
        subloop_start = data_idx - 1
        # Consume outside the assert so parsing is unchanged under `python -O`.
        marker = consume_byte()
        assert marker == 0x00, "Unexpected subloop end command"
        while True:
            cmd = consume_byte()
            assert cmd != 0, "Unexpected track end in subloop"
            if cmd == 0xE6:
                subloop_count = consume_byte()
                assert subloop_count != 0, "Loop count can't be 0"
                subloop_end = data_idx
                break
            # Skip this command's parameter bytes (its opcode is already consumed).
            data_idx += cmd_len_fn(cmd) - 1
        # Body sits between the opening "$E6 $00" and the closing "$E6 count".
        subloop_data = track_data[subloop_start + 2 : subloop_end - 2]
        track_data[subloop_start:subloop_end] = subloop_data * (subloop_count + 1)
        # Re-scan from the start of the unrolled data, as if no subloop had existed.
        data_idx = subloop_start
def hexlist(l):
    """Format an iterable of byte values as space-separated, two-digit uppercase hex."""
    return ' '.join(f'{x:02X}' for x in l)
def find_nspc_addr_table(aram: "Block"):
    """Locate the N-SPC song address table by matching a driver code signature.

    Scans the first 0x2000 bytes of ``aram`` for the byte sequence
    ``1C FD F6 aa aa 2D C4 40 F6 bb bb 2D C4 41``. Its two 16-bit
    little-endian operands are presumably the low-byte and high-byte reads of
    the table (TODO confirm against a driver disassembly). A match is accepted
    only when the second operand equals the first plus one.

    Args:
        aram: A Block (or anything whose ``[:0x2000].data`` is an ``array('B')``).

    Returns:
        The table address (the first operand), or None if no match is found.
    """
    def le_to_int(x: bytes) -> int:
        return int.from_bytes(x, 'little')

    # DOTALL is required: without it `.` won't match 0x0A, so any table whose
    # address contains a $0A byte would never be found.
    signature = re.compile(br'\x1C\xFD\xF6(..)\x2D\xC4\x40\xF6(..)\x2D\xC4\x41', re.DOTALL)
    for m in signature.finditer(aram[:0x2000].data.tobytes()):
        a1 = le_to_int(m.group(1))
        a2 = le_to_int(m.group(2))
        if a2 == a1 + 1:
            return a1
    return None
def extract_ebm(spc_data):
    """Convert the song that an SPC dump is playing into EarthBound (EBM) song data.

    Steps:
      1. Load the 64 KiB ARAM image (``spc_data[0x100:0x10100]``).
      2. Find the song table, and read the address of song number ``ARAM[$F6]``.
      3. Walk the phrase list to collect pattern tables (8 track pointers each).
      4. Extract every track, plus any subroutine called via AMK opcode $E9.
      5. Convert the tracks with ``fix_track``.
      6. Repack the tracks contiguously from the lowest track address.
      7. Fix up subroutine calls (EB opcode $EF) and the pattern tables'
         track pointers.

    Args:
        spc_data: Raw .spc file contents. Files longer than 0x10200 bytes are
            accepted; the trailing bytes (e.g. extended ID666 tags) are ignored.

    Returns:
        ``(song_ptr, block)``: the ARAM address of the phrase list, and a Block
        holding the converted song from that address to the end of the
        relocated tracks.

    Raises:
        AssertionError: the data doesn't look like a supported song.
        ValueError: the song table can't be found, a command can't be
            converted, or the converted song doesn't fit below $10000.
    """
    assert len(spc_data) >= 0x10200, "unknown spc format"
    aram_part = Block(0x10000)
    aram_part.from_list(list(spc_data[0x100:0x10100]))
    # The real ARAM fills the first half. The zeroed upper half lets the last
    # track's scan (which may run up to $10000) read past $FFFF without bounds errors.
    aram = Block(0x20000)
    aram[:0x10000] = aram_part

    song_num = aram[0xF6]
    song_table_ptr = find_nspc_addr_table(aram)
    if song_table_ptr is None:
        raise ValueError("Couldn't find the N-SPC song address table in ARAM")
    song_ptr = aram.read_multi(song_table_ptr + song_num * 2, 2)
    assert song_ptr > 0x100

    data_ptr = song_ptr

    def consume_word() -> int:
        nonlocal data_ptr
        val = aram.read_multi(data_ptr, 2)
        data_ptr += 2
        return val

    # --- Phrase list ("block list"): 16-bit entries ending with $0000. ---
    print('Parsing phrases...')
    phrases_done = set()
    pattern_set = set()
    while True:
        phrases_done.add(data_ptr)
        pattern_ptr = consume_word()
        if pattern_ptr & 0xff00:
            # A real pattern (phrase) address.
            pattern_set.add(pattern_ptr)
            continue
        # Entries below $0100 are control codes, not addresses.
        if pattern_ptr == 0:
            # End of phrases.
            break
        if pattern_ptr in (0x80, 0x81):
            # Debug fast-forward on/off: skip it.
            continue
        # Loop (0x01-0x7f) / jump (0x82-0xff), followed by a destination word.
        jump_dest = consume_word()
        if jump_dest not in phrases_done:
            # Follow each destination only once, so looping lists terminate.
            data_ptr = jump_dest
        elif pattern_ptr >= 0x80:
            # Unconditional jump back into parsed phrases: nothing more to read.
            break
    print('Found patterns:', ' '.join(f'${x:04X}' for x in pattern_set))
    assert len(pattern_set) > 0
    pattern_start = min(pattern_set)
    for pattern in pattern_set:
        # Pattern tables are 8 x 16-bit pointers = 16 bytes, laid out back to back.
        assert pattern >= pattern_start and (pattern - pattern_start) % 16 == 0

    # --- Pattern tables: collect the distinct tracks and where each is referenced. ---
    print('Parsing patterns...')
    track_list = []
    track_refs: Dict[int, List[int]] = {}
    for pattern_table_addr in pattern_set:
        for pattern_idx in range(8):
            ref_ptr = pattern_table_addr + pattern_idx * 2
            track_ptr = aram.read_multi(ref_ptr, 2)
            if track_ptr == 0:
                continue
            track_refs.setdefault(track_ptr, []).append(ref_ptr)
            if track_ptr not in track_list:
                track_list.append(track_ptr)
    print('Found tracks:', ' '.join(f'${x:04X}' for x in track_list))

    # --- Track extraction. track_list grows during this loop as subroutines are
    # discovered, and the appended entries are extracted too. ---
    print('Extracting tracks...')
    tracks: Dict[int, List[int]] = {}
    track_ends_in_zero = set()
    sub_ptr_to_idx = {}
    sub_idx_to_ptr = {}
    for track_ptr in track_list:
        # A track runs until a $00 terminator or the start of the next known track.
        next_track = min((tp for tp in track_list if tp > track_ptr), default=0x10000)
        track_end = track_ptr
        while True:
            cmd = aram[track_end]
            assert cmd != 0, f"Track cannot have zero length, start=${track_ptr:04X} addr=${track_end:04X}"
            if cmd == 0xE9:  # Call subroutine
                sub_ptr = aram.read_multi(track_end + 1, 2)
                if sub_ptr not in sub_ptr_to_idx:
                    if sub_ptr in track_list:
                        # Already a pattern track: reuse its slot. Queuing it again
                        # would re-parse it after its own calls were rewritten to
                        # indices, treating those indices as addresses.
                        sub_idx = track_list.index(sub_ptr)
                    else:
                        sub_idx = len(track_list)
                        track_list.append(sub_ptr)
                    sub_ptr_to_idx[sub_ptr] = sub_idx
                    sub_idx_to_ptr[sub_idx] = sub_ptr
                else:
                    sub_idx = sub_ptr_to_idx[sub_ptr]
                # Park the subroutine's index in place of its address; it is
                # resolved to the relocated address after insertion.
                aram.write_multi(track_end + 1, sub_idx, 2)
            track_end += cmd_len_fn(cmd)
            if aram[track_end] == 0 or track_end >= next_track:
                break
        assert aram[track_end] == 0 or track_end == next_track, \
            f"Coding error, track {track_ptr:04X} end {track_end:04X} overlaps with next {next_track:04X}"
        if aram[track_end] == 0:
            track_ends_in_zero.add(track_ptr)
        tracks[track_ptr] = aram[track_ptr:track_end].to_list()
    first_track_addr = min(track_list)
    print('Tracks:')
    for track_ptr, track_data in tracks.items():
        print(f'${track_ptr:04X}: {hexlist(track_data)}')

    print('Fixing tracks...')
    for track_ptr, track_data in tracks.items():
        fix_track(track_ptr, track_data)

    # --- Layout: pack the converted tracks from the lowest original track address. ---
    print('Computing track addresses...')
    output_ptr = first_track_addr
    output_track_ptrs: Dict[int, int] = {}
    for track_ptr in track_list:
        output_track_ptrs[track_ptr] = output_ptr
        track_data = tracks[track_ptr]
        if track_ptr in track_ends_in_zero:
            track_data.append(0)
        output_ptr += len(track_data)
    song_end = output_ptr
    if song_end > 0x10000:
        # Pointers are 16-bit; write_multi would silently truncate them.
        raise ValueError(f"Converted song ends at ${song_end:X}, past the end of ARAM")
    print('Tracks:')
    for track_ptr, track_data in tracks.items():
        if track_ptr in sub_ptr_to_idx:
            print(f'sub{sub_ptr_to_idx[track_ptr]:02X} ', end='')
        print(f'${track_ptr:04X}->${output_track_ptrs[track_ptr]:04X}: {hexlist(track_data)}')

    # --- Insert converted tracks and resolve subroutine indices to addresses. ---
    print('Inserting tracks...')
    for track_ptr in track_list:
        output_ptr = output_track_ptrs[track_ptr]
        track_data = tracks[track_ptr]
        next_track = output_ptr + len(track_data)
        aram[output_ptr:next_track] = track_data
        track_end = output_ptr
        while True:
            cmd = aram[track_end]
            assert cmd != 0, "Track cannot have zero length"
            if cmd == 0xEF:  # Call subroutine (EB opcode; converted from AMK $E9)
                sub_idx = aram.read_multi(track_end + 1, 2)
                output_sub_ptr = output_track_ptrs[sub_idx_to_ptr[sub_idx]]
                aram.write_multi(track_end + 1, output_sub_ptr, 2)
                print(f"Overwrote pointer at ${track_end+1:04X} with ${output_sub_ptr:04X}")
            track_end += cmd_len_eb_fn(cmd)
            if aram[track_end] == 0 or track_end >= next_track:
                break

    print('Correcting track references...')
    for track_ptr, refs in track_refs.items():
        output_track_ptr = output_track_ptrs[track_ptr]
        for ref in refs:
            aram.write_multi(ref, output_track_ptr, 2)

    return song_ptr, aram[song_ptr:song_end]
def extract_ebm_wrapper(fp_spc, fp_ebm):
    """Convert the song in SPC file ``fp_spc`` and write it to ``fp_ebm``.

    The output is a 4-byte header followed by the song bytes. The header holds
    the 16-bit data size, then the 16-bit ARAM load address, each written
    low byte first by ``Block.write_multi``.
    """
    with open(fp_spc, 'rb') as spc_file:
        raw_spc = spc_file.read()
    load_addr, song_block = extract_ebm(raw_spc)
    payload_len = len(song_block)
    header_len = 4
    out = Block(header_len + payload_len)
    out.write_multi(0, payload_len, 2)
    out.write_multi(2, load_addr, 2)
    out[header_len:] = song_block
    out.to_file(fp_ebm)
if __name__ == "__main__":
    # CLI: extract.py <SPC file in> [EBM file out]
    args = sys.argv[1:]
    if not 1 <= len(args) <= 2:
        # Usage errors go to stderr so they don't mix with redirected output.
        print('Usage: python extract.py <SPC file in> [EBM file out]', file=sys.stderr)
        print('Default output EBM is <SPC file>.ebm', file=sys.stderr)
        sys.exit(1)
    fp_spcin = args[0]
    if len(args) >= 2:
        fp_romout = args[1]
    elif fp_spcin.lower().endswith('.spc'):
        # Swap the extension case-insensitively: "song.SPC" -> "song.ebm".
        fp_romout = fp_spcin[:-3] + 'ebm'
    else:
        fp_romout = fp_spcin + '.ebm'
    extract_ebm_wrapper(fp_spcin, fp_romout)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment