Skip to content

Instantly share code, notes, and snippets.

@AlexApps99
Created April 7, 2023 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AlexApps99/974d45d71f94a9187209cf4e31780a21 to your computer and use it in GitHub Desktop.
Save AlexApps99/974d45d71f94a9187209cf4e31780a21 to your computer and use it in GitHub Desktop.
Convert a .WAV music file into MIDI using a bunch of numbers and stuff
'''
Converts a .WAV music file into MIDI using a bunch of numbers and stuff
The gist of the program is:
- Load a mono WAV file
- Move a rolling window along the audio
- Multiply that window by a "hanning taper window" (this step is optional, it’s basically multiplying it by a bell-curve-like shape so the middle of the window has the most influence)
- Run an FFT on that window
- Set the velocity of each piano note to the average magnitude of its closest frequencies
- Copy the piano notes from each window into a MIDI file, holding down or skipping notes where needed
If you wanna run it, you'll need numpy and matplotlib (but you probably have those already), and you'll need to install mido (a Python MIDI library) from pip.
Just chuck a mono WAV file at "in.wav", and a MIDI file will be saved to "out.mid".
'''
import wave
import numpy as np
from mido import MetaMessage, Message, MidiFile, MidiTrack
def hz2mid(hz):
    '''
    Maps a frequency in Hz onto the (fractional) MIDI note scale, where 440 Hz is note 69
    and every doubling of frequency adds 12 notes.
    A tiny epsilon is added to the frequency ratio so that 0 Hz never reaches log2(0).
    '''
    ratio_to_a440 = hz / 440.0 + 1e-300
    semitone_offset = 12.0 * np.log2(ratio_to_a440)
    return semitone_offset + 69.0
def get_audio_data(path):
    '''
    Reads a mono, 16-bit WAV file from path.
    Returns (samples, framerate): samples is a float64 array scaled into [-1, 1),
    framerate is the sample rate stored in the file header.
    Fails an assertion if the file is not mono or not 16-bit.
    '''
    with wave.open(path, 'rb') as wav:
        params = wav.getparams()
        assert params.nchannels == 1, "WAV must be mono"
        assert params.sampwidth == 2, "WAV must be 16-bit"
        raw_bytes = wav.readframes(params.nframes)
    # WAV PCM samples are stored as little-endian signed 16-bit integers
    samples = np.frombuffer(raw_bytes, dtype='<i2')
    # Dividing by 2**15 maps the int16 range onto [-1, 1)
    return np.float64(samples) / 32768, params.framerate
def fft_rolling_windows(frames, framerate, steps_per_sec, wins_per_sec):
    '''
    Runs a Hann-tapered real FFT over overlapping windows of the audio.

    frames        -- 1-D array of audio samples
    framerate     -- sample rate of frames, in Hz
    steps_per_sec -- windows started per second (hop is framerate // steps_per_sec samples)
    wins_per_sec  -- reciprocal window length (window is framerate // wins_per_sec samples)

    Returns (rffts, freq): rffts holds one row of complex FFT bins per window,
    and freq holds the frequency in Hz of each bin.
    '''
    win_size = framerate // wins_per_sec
    half_win_size = win_size // 2
    step_size = framerate // steps_per_sec
    taper = np.hanning(win_size)
    # Zero-pad both ends by (at least) half a window, so window i is centred on sample i * step_size
    frames = np.pad(frames, max(half_win_size, win_size - half_win_size))
    n_windows = frames.size // step_size
    # The last window starts at (n_windows - 1) * step_size and is win_size long.
    # When the window is longer than the step it would run off the end of the buffer,
    # so extend the tail with zeros instead of reading whatever memory lies past it.
    needed = (n_windows - 1) * step_size + win_size
    if needed > frames.size:
        frames = np.pad(frames, (0, needed - frames.size))
    # sliding_window_view derives strides from the array itself, so any sample dtype works
    every_window = np.lib.stride_tricks.sliding_window_view(frames, win_size)
    rolling_windows = every_window[::step_size][:n_windows] * taper
    rffts = np.fft.rfft(rolling_windows)
    freq = np.fft.rfftfreq(win_size, 1 / framerate)
    return rffts, freq
def note_velocities(rffts, freq):
    '''
    Lumps the FFT frequencies into individual note velocities.

    rffts -- complex FFT results, one row per window, one column per entry of freq
    freq  -- frequency in Hz of each FFT bin

    Returns an array with one row per window and 128 columns (one per MIDI note);
    each value is the mean magnitude of the FFT bins whose nearest note is that column.
    Notes that no bin maps to get 0.

    Side effect: the bin-to-note weighting matrix is stored in the module-level
    midi_output_matrix (the script section plots it).
    '''
    global midi_output_matrix
    mid_notes = hz2mid(freq)
    # A matrix column corresponds to a MIDI note, and a row to the weighting of a given RFFT frequency,
    # so multiplying the magnitudes by the matrix leaves one velocity per MIDI note.
    midi_output_matrix = np.zeros((len(freq), 128), dtype=np.float64)
    # Only bins whose fractional note lies inside the MIDI range are candidates ...
    candidate_bins = np.flatnonzero((mid_notes > -1) & (mid_notes < 128))
    # ... and of those, only the ones whose nearest note is a valid MIDI note (0..127)
    nearest_note = np.rint(mid_notes[candidate_bins]).astype(int)
    valid = (nearest_note >= 0) & (nearest_note <= 127)
    midi_output_matrix[candidate_bins[valid], nearest_note[valid]] = 1
    # Rather than each frequency having a weight of 1, give it 1/(number of frequencies
    # contributing to the note), so the result is an average.
    column_sum = midi_output_matrix.sum(axis=0)
    # out= must be pre-filled: with where= alone, entries for empty notes are left uninitialised
    # and could inject NaN/inf into the matrix.
    weights = np.divide(1.0, column_sum, out=np.zeros_like(column_sum), where=column_sum != 0)
    midi_output_matrix *= weights
    notes_list = np.matmul(np.abs(rffts), midi_output_matrix)
    return notes_list
def make_midi(notes_list, steps_per_sec, vol_multiply=1, max_note=None):
    '''
    Creates a mido MidiFile object with the provided note data.

    notes_list    -- one sequence of per-note magnitudes per time step (index = MIDI note number)
    steps_per_sec -- time steps per second; the tempo is set so that one tick lasts one step
    vol_multiply  -- gain applied after normalising to the loudest magnitude (result is clamped to 0..127)
    max_note      -- if given, only notes below this MIDI note number are written

    Returns the MidiFile (type 0, single track). Nothing is written to disk here.
    Silent or empty input produces a track without any notes instead of raising.
    '''
    # default=0 keeps an empty notes_list from raising in max()
    max_vel = max((max(notes) for notes in notes_list), default=0)
    print("Max vel:", max_vel)
    mid = MidiFile(type=0, ticks_per_beat=1)
    track = mid.add_track()
    track.append(MetaMessage('set_tempo', tempo=1000000//steps_per_sec, time=0))
    # Velocity each note is currently sounding at (0 = not held)
    note_vel = [0] * 128
    for notes in notes_list:
        for i, magnitude in enumerate(notes):
            if max_note is not None and i >= max_note:
                continue
            if max_vel > 0:
                # outliers seem to make a lot of things too quiet, hence the extra gain
                vel = round((magnitude / max_vel) * 127 * vol_multiply)
                vel = max(0, min(vel, 127))
            else:
                # all-silent input: nothing to normalise against
                vel = 0
            # clip quiet notes
            if vel <= 4:
                vel = 0
            held = note_vel[i]
            # Hold the note while its velocity stays similar, but always release a held note
            # once it has gone quiet, so it cannot get stuck on.
            changed = abs(held - vel) >= 8 or (vel == 0 and held != 0)
            if not changed:
                continue
            if held != 0:
                track.append(Message('note_off', note=i, velocity=held, time=0))
            if vel != 0:
                track.append(Message('note_on', note=i, velocity=vel, time=0))
            note_vel[i] = vel
        # Empty sysex message carrying a delta time of one tick: it advances the track to the next step
        track.append(Message('sysex', data=[], time=1))
    # Release only the notes that are still held at the end of the song
    for i, v in enumerate(note_vel):
        if v != 0:
            track.append(Message('note_off', note=i, velocity=v, time=0))
            note_vel[i] = 0
    return mid
if __name__ == "__main__":
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    # Conversion settings
    STEPS_PER_SEC = 10
    WINS_PER_SEC = 5
    VOL_MULTIPLY = 2.0
    MAX_NOTE = None # was 84

    # WAV -> FFT windows -> per-note velocities -> MIDI file
    frames, framerate = get_audio_data("in.wav")
    print(f"Loaded WAV at {framerate} Hz sample rate ({frames.size} samples)")

    rffts, freq = fft_rolling_windows(
        frames,
        framerate,
        steps_per_sec=STEPS_PER_SEC,
        wins_per_sec=WINS_PER_SEC,
    )
    print("Calculated FFT")

    notes_list = note_velocities(rffts, freq)
    print("Generated note velocities")

    midi_file = make_midi(
        notes_list,
        steps_per_sec=STEPS_PER_SEC,
        vol_multiply=VOL_MULTIPLY,
        max_note=MAX_NOTE,
    )
    midi_file.save('out.mid')
    print("Saved MIDI")

    # One panel per intermediate result: (title, matrix to show, (xlabel, ylabel) or None).
    # midi_output_matrix is the module-level matrix set by note_velocities() above.
    panels = [
        ("FFT", np.abs(rffts.T), ('steps', 'frequencies (TODO make the key clearer)')),
        ("MIDI note velocity", notes_list.T, ('steps', 'MIDI notes')),
        ("Translation table", midi_output_matrix, None),
    ]
    fig, axes = plt.subplots(3)
    for axis, (title, matrix, labels) in zip(axes, panels):
        axis.title.set_text(title)
        image = axis.matshow(matrix, aspect='auto')
        if labels is not None:
            axis.set_xlabel(labels[0])
            axis.set_ylabel(labels[1])
        fig.colorbar(image, ax=axis)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment