Skip to content

Instantly share code, notes, and snippets.

@Astroneko404
Last active January 12, 2024 06:53
Show Gist options
  • Save Astroneko404/1dcde11576e510e964882bbbafaeb050 to your computer and use it in GitHub Desktop.
Save Astroneko404/1dcde11576e510e964882bbbafaeb050 to your computer and use it in GitHub Desktop.
Musical key-determination algorithm in Python using Krumhansl-Kessler weights; blog: https://astroneko404.github.io//Key-Determination/
from math import sqrt
from mido import MidiFile
# Key names in chromatic order starting at C. Every tonic contributes two
# consecutive entries: its major key first, then its minor key.
MODE_TABLE = [
    tonic + " " + mode
    for tonic in ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    for mode in ("Major", "Minor")
]
# Weight arrays retrieved from music21
# Each profile holds 12 weights. DeterminateKey correlates a profile against
# durations indexed by note % 12, using it unrotated for the tonic at pitch
# class 0 and rotated for every other tonic, so index i is the weight of the
# pitch class i semitones above the tonic.
KRUMHANSL_SCHMUCKLER_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
KRUMHANSL_SCHMUCKLER_MINOR = [6.33, 2.68, 3.52, 5.38, 2.6, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17]
# NOTE: the Bellman-Budge profiles are not referenced by any code visible in
# this file; only the Krumhansl-Schmuckler pair is used by DeterminateKey.
BELLMAN_BUDGE_MAJOR = [16.8, 0.86, 12.95, 1.41, 13.49, 11.93, 1.25, 20.28, 1.8, 8.04, 0.62, 10.57]
BELLMAN_BUDGE_MINOR = [18.16, 0.69, 12.99, 13.34, 1.07, 11.15, 1.38, 21.07, 7.49, 1.53, 0.92, 10.21]
def argmax(lst: list):
    """
    Find the position of the largest element
    :param lst: Indexable sequence of comparable values
    :return: Index of the largest value (the first such index on ties)
    """
    # Rank the indices by the value stored at each of them.
    return max(range(len(lst)), key=lst.__getitem__)
def getAllMidiNotes(mid: "MidiFile"):
    """
    Retrieve every pitched note of the file together with how long it sounds

    The time stored on a track message is the delta since the previous message,
    so a note's length is measured from its note-on to the matching release
    (a note_off, or a note_on with velocity 0) on the same channel and note.
    Notes that are never released are held until the last message of the track.
    :param mid: mido MidiFile
    :return: A list of [midi note number, duration in ticks] pairs, in order of
        note start within each track
    """
    midi_notes = []
    for track in mid.tracks:
        now = 0  # Absolute tick position, accumulated from the delta times
        sounding = {}  # (channel, note) -> entries still waiting for a release
        for msg in track:
            msg_dict = msg.dict()
            now += msg_dict["time"]
            kind = msg_dict["type"]
            if kind not in ("note_on", "note_off"):
                continue
            if msg_dict["channel"] == 9:  # Remove the percussion channel
                continue
            key = (msg_dict["channel"], msg_dict["note"])
            if kind == "note_on" and msg_dict["velocity"] > 0:
                # Slot 1 holds the start tick until the release is seen
                entry = [msg_dict["note"], now]
                midi_notes.append(entry)
                sounding.setdefault(key, []).append(entry)
            elif sounding.get(key):
                # Release the oldest note still sounding on this key
                entry = sounding[key].pop(0)
                entry[1] = now - entry[1]
        # Close whatever is still sounding at the end of the track
        for entries in sounding.values():
            for entry in entries:
                entry[1] = now - entry[1]
    return midi_notes
def getPitchDuration(note_list):
    """
    Sum the note lengths of each of the 12 pitch classes
    :param note_list: The list returned by getAllMidiNotes, i.e. (note, length) pairs
    :return: 12-element list whose entry k is the total length of all notes
        with note % 12 == k, used for the Pearson correlation calculation
    """
    totals = [0] * 12
    for entry in note_list:
        pitch, span = entry
        # Octave is irrelevant: fold every note onto its pitch class
        totals[pitch % 12] += span
    return totals
def DeterminateKey(midi_file):
    """
    Score all 24 major/minor keys against the pitch content of a MIDI file
    :param midi_file: mido MidiFile
    :return: List of 24 Pearson correlations; for each tonic from pitch class 0
        upwards it holds the major-key score followed by the minor-key score
    """
    durations = getPitchDuration(getAllMidiNotes(midi_file))
    scores = []
    for tonic in range(12):
        for profile in (KRUMHANSL_SCHMUCKLER_MAJOR, KRUMHANSL_SCHMUCKLER_MINOR):
            # Rotate right by `tonic` so the profile's first weight lines up
            # with the tonic's pitch class (tonic 0 leaves it unchanged)
            rotated = profile[-tonic:] + profile[:-tonic]
            scores.append(pearsonCorrelation(durations, rotated))
    return scores
def pearsonCorrelation(x, y):
    """
    Calculate the Pearson correlation of two arrays
    :param x: List of numbers
    :param y: List of numbers, same length as x
    :return: Pearson correlation, or 0.0 when it is undefined because either
        input has no variance (constant or empty)
    :raises ValueError: If x and y differ in length
    """
    # Explicit check instead of assert, which is stripped when running with -O
    if len(x) != len(y):
        raise ValueError(
            "x and y must have the same length (got %d and %d)" % (len(x), len(y))
        )
    n = len(x)
    sum_x = sum(x)
    sum_y = sum(y)
    sum_x2 = sum(item ** 2 for item in x)
    sum_y2 = sum(item ** 2 for item in y)
    sum_xy = sum(a * b for a, b in zip(x, y))
    # n^2 times the variance of each input; zero for constant or empty data
    spread_x = n * sum_x2 - sum_x ** 2
    spread_y = n * sum_y2 - sum_y ** 2
    # <= also catches tiny negative values produced by float rounding, which
    # would otherwise make sqrt() fail
    if spread_x <= 0 or spread_y <= 0:
        return 0.0
    return (n * sum_xy - sum_x * sum_y) / sqrt(spread_x * spread_y)
if __name__ == "__main__":
    # Score every key for the sample file and print the best-matching name
    correlations = DeterminateKey(MidiFile("[midi]/ff8_blue_fields.mid"))
    best = argmax(correlations)
    print(MODE_TABLE[best])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment