Last active
January 12, 2024 06:53
-
-
Save Astroneko404/1dcde11576e510e964882bbbafaeb050 to your computer and use it in GitHub Desktop.
Music key determination algorithm in Python using Krumhansl-Kessler weights, blog: https://astroneko404.github.io//Key-Determination/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import sqrt | |
from mido import MidiFile | |
# Names of the 24 candidate keys. Entry 2*t is the major key and entry 2*t + 1
# the minor key whose tonic is pitch class t (0 = C, 1 = C#, ..., 11 = B); this
# is the order of the correlations produced by DeterminateKey.
MODE_TABLE = [
    "C Major", "C Minor",
    "C# Major", "C# Minor",
    "D Major", "D Minor",
    "D# Major", "D# Minor",
    "E Major", "E Minor",
    "F Major", "F Minor",
    "F# Major", "F# Minor",
    "G Major", "G Minor",
    "G# Major", "G# Minor",
    "A Major", "A Minor",
    "A# Major", "A# Minor",
    "B Major", "B Minor",
]

# Weight arrays retrieved from music21.
# Each profile holds 12 weights, one per semitone above the key's tonic
# (index 0 = tonic). DeterminateKey rotates them so index 0 lands on each tonic.
KRUMHANSL_SCHMUCKLER_MAJOR = [
    6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
    2.52, 5.19, 2.39, 3.66, 2.29, 2.88,
]
KRUMHANSL_SCHMUCKLER_MINOR = [
    6.33, 2.68, 3.52, 5.38, 2.6, 3.53,
    2.54, 4.75, 3.98, 2.69, 3.34, 3.17,
]
BELLMAN_BUDGE_MAJOR = [
    16.8, 0.86, 12.95, 1.41, 13.49, 11.93,
    1.25, 20.28, 1.8, 8.04, 0.62, 10.57,
]
BELLMAN_BUDGE_MINOR = [
    18.16, 0.69, 12.99, 13.34, 1.07, 11.15,
    1.38, 21.07, 7.49, 1.53, 0.92, 10.21,
]
def argmax(lst: list):
    """
    Return the position of the largest element of ``lst``.

    When several elements tie for the maximum, the earliest position wins,
    because ``max`` keeps the first maximal item it sees.
    :param lst: Input array
    :return: The index of the maximum value
    """
    return max(range(len(lst)), key=lst.__getitem__)
def getAllMidiNotes(mid: "MidiFile"):
    """
    Collect every non-percussion note in the file together with how long it sounds.

    Within a track, each message's ``time`` is a delta in ticks since the
    previous message, so it is accumulated into an absolute clock. Each
    ``note_on`` (velocity > 0) is paired with the ``note_off`` -- or
    ``note_on`` with velocity 0, which MIDI treats as a note-off -- that ends
    it on the same channel and pitch. Repeated overlapping starts of one pitch
    are matched first-in, first-out. Notes that never receive a note-off are
    closed at the track's last event. Channel 9 (percussion) is skipped.

    :param mid: mido MidiFile
    :return: A list of ``[midi_note_number, duration_in_ticks]`` pairs
    """
    midi_notes = []
    for track in mid.tracks:
        now = 0  # absolute tick position within this track
        # (channel, note) -> start ticks of currently sounding notes, oldest first
        sounding = {}
        for msg in track:
            now += msg.time
            # Skip non-note events, and remove the percussion channel.
            if msg.type not in ("note_on", "note_off") or msg.channel == 9:
                continue
            key = (msg.channel, msg.note)
            if msg.type == "note_on" and msg.velocity > 0:
                sounding.setdefault(key, []).append(now)
            elif sounding.get(key):
                # A note-off without a matching start is ignored.
                start = sounding[key].pop(0)
                midi_notes.append([msg.note, now - start])
        # Close notes still sounding when the track ended.
        for (_, note), starts in sounding.items():
            midi_notes.extend([note, now - start] for start in starts)
    return midi_notes
def getPitchDuration(note_list):
    """
    Accumulate note lengths into 12 pitch-class bins.

    Bin ``k`` receives the length of every note whose MIDI number is
    congruent to ``k`` modulo 12 (0 = C, 1 = C#, ..., 11 = B).

    :param note_list: Iterable of ``[note_number, length]`` pairs, as
                      returned by getAllMidiNotes
    :return: A 12-element list of summed lengths, used as the input vector
             for the Pearson correlation
    """
    bins = [0] * 12
    for midi_number, span in note_list:
        bins[midi_number % 12] += span
    return bins
def DeterminateKey(midi_file, *, major_profile=None, minor_profile=None):
    """
    Correlate the file's pitch-class durations with all 24 rotated key profiles.

    :param midi_file: mido MidiFile
    :param major_profile: 12 major-key weights with the tonic at index 0.
                          Defaults to KRUMHANSL_SCHMUCKLER_MAJOR; e.g.
                          BELLMAN_BUDGE_MAJOR may be passed instead.
    :param minor_profile: 12 minor-key weights with the tonic at index 0.
                          Defaults to KRUMHANSL_SCHMUCKLER_MINOR.
    :return: Pearson correlation result (in list) of 24 values, ordered like
             MODE_TABLE: entry ``idx`` is the major (even idx) or minor
             (odd idx) key whose tonic is pitch class ``idx // 2``.
    :raises ValueError: If a supplied profile does not have 12 weights.
    """
    if major_profile is None:
        major_profile = KRUMHANSL_SCHMUCKLER_MAJOR
    if minor_profile is None:
        minor_profile = KRUMHANSL_SCHMUCKLER_MINOR
    # Copy to lists so tuples or other sequences concatenate below.
    major_profile = list(major_profile)
    minor_profile = list(minor_profile)
    if len(major_profile) != 12 or len(minor_profile) != 12:
        raise ValueError("key profiles must contain exactly 12 weights")

    note_length_vector = getPitchDuration(getAllMidiNotes(midi_file))
    corr_res = []
    for idx in range(24):
        tonic = idx // 2
        profile = major_profile if idx % 2 == 0 else minor_profile
        # Rotate right by `tonic` so the tonic weight (profile[0]) lands on
        # pitch class `tonic`.
        shift_idx = 12 - tonic
        weight = profile[shift_idx:] + profile[:shift_idx]
        corr_res.append(pearsonCorrelation(note_length_vector, weight))
    return corr_res
def pearsonCorrelation(x, y):
    """
    Calculate the Pearson correlation coefficient of two equal-length arrays.

    Uses the mean-centred form, which avoids the cancellation error of the
    ``n*sum(xy) - sum(x)*sum(y)`` formula when values are large.

    :param x: List of numbers
    :param y: List of numbers
    :return: Pearson correlation (nominally in [-1, 1]), or ``nan`` when it is
             undefined: empty input, or either array has zero variance (e.g.
             a MIDI file with no notes yields an all-zero duration vector)
    :raises ValueError: If ``x`` and ``y`` differ in length
    """
    if len(x) != len(y):
        raise ValueError(
            "pearsonCorrelation needs equal-length inputs, got %d and %d" % (len(x), len(y))
        )
    n = len(x)
    if n == 0:
        return float("nan")
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    dev_x = [value - mean_x for value in x]
    dev_y = [value - mean_y for value in y]
    covariance = sum(a * b for a, b in zip(dev_x, dev_y))
    var_x = sum(a * a for a in dev_x)
    var_y = sum(b * b for b in dev_y)
    if var_x == 0 or var_y == 0:
        # Correlation is undefined for a constant series.
        return float("nan")
    return covariance / sqrt(var_x * var_y)
if __name__ == "__main__":
    import sys

    # Optional first CLI argument: path to the MIDI file to analyse.
    # Without one, the original sample file is used.
    path = sys.argv[1] if len(sys.argv) > 1 else "[midi]/ff8_blue_fields.mid"
    midi = MidiFile(path)
    res = DeterminateKey(midi)
    print(MODE_TABLE[argmax(res)])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment