Created
March 13, 2023 15:28
-
-
Save Natooz/8984128b55b2144b9af948698b2e8904 to your computer and use it in GitHub Desktop.
Test chord detection methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 python | |
"""Test chord detection methods | |
""" | |
from dataclasses import dataclass | |
from typing import List, Tuple, Dict, cast | |
from copy import copy | |
import numpy as np | |
from miditoolkit import MidiFile, Instrument, Note | |
from miditoolkit.pianoroll.parser import notes2pianoroll | |
from miditoolkit.pianoroll.utils import tochroma | |
import miditok | |
from miditok.constants import CHORD_MAPS | |
from miditok.utils import detect_chords | |
# Pitch-class names in chromatic order; the list index is the pitch class
# (pitch number modulo 12), which is how chord roots and bass notes are named.
_PITCH_CLASSES = [
    "C", "C#", "D", "D#", "E", "F",
    "F#", "G", "G#", "A", "A#", "B",
]
# Scoring tables for the rule-based chord recogniser. Every list holds
# intervals in semitones above the candidate root note.

# Intervals that belong to each chord quality (required); notes on these
# intervals are left out of the score.
_CHORD_MAPS = {
    "maj": [0, 4],
    "min": [0, 3],
    "dim": [0, 3, 6],
    "aug": [0, 4, 8],
    "dom": [0, 4, 7, 10],
}

# Extra intervals that support the quality: each one present scores +1.
_CHORD_INSIDERS = {
    "maj": [7],
    "min": [7],
    "dim": [9],
    "aug": [],
    "dom": [],
}

# Intervals foreign to the quality: each one present scores -1.
_CHORD_OUTSIDERS_1 = {
    "maj": [2, 5, 9],
    "min": [2, 5, 8],
    "dim": [2, 5, 10],
    "aug": [2, 5, 9],
    "dom": [2, 5, 9],
}

# Intervals foreign to the quality: each one present scores -2.
_CHORD_OUTSIDERS_2 = {
    "maj": [1, 3, 6, 8, 10],
    "min": [1, 4, 6, 9, 11],
    "dim": [1, 4, 7, 8, 11],
    "aug": [1, 3, 6, 7, 10],
    "dom": [1, 3, 6, 8, 11],
}
class REMIPlusChord:
    """Rule-based chord recognition over a piano roll.

    Originally implemented in the REMI original repository
    <https://github.com/YatingMusic/remi/blob/master/chord_recognition.py>

    :meth:`extract` is the only public entry point. It scores every 4-beat and
    2-beat window starting on a beat, then walks the piece greedily, keeping
    the best-scoring window at each position.
    """

    @classmethod
    def __get_candidates(cls, chroma: np.ndarray) -> Dict[int, List[int]]:
        """Express the active pitch classes relative to every possible root.

        :param chroma: binary vector with one entry per pitch class
            (truthy = the pitch class is present).
        :return: for each active index taken as root, the positions of the
            entries equal to 1 once ``chroma`` is rotated so that the root
            sits at position 0.
        """
        candidates: Dict[int, List[int]] = {}
        for root_note, is_active in enumerate(chroma):
            if is_active:
                rotated = np.roll(chroma, -root_note)
                candidates[root_note] = list(np.where(rotated == 1)[0])
        return candidates

    @classmethod
    def __get_score(
        cls, candidates: Dict[int, List[int]]
    ) -> Tuple[Dict[int, int], Dict[int, str]]:
        """Pick a chord quality for every candidate root and score the fit.

        A root needs exactly one third (interval 3 or 4); otherwise it gets
        the quality ``"None"`` and a score of -100. For the remaining roots,
        notes outside the quality's chord map add +1 / -1 / -2 according to
        ``_CHORD_INSIDERS`` / ``_CHORD_OUTSIDERS_1`` / ``_CHORD_OUTSIDERS_2``.

        :param candidates: intervals above each root, as returned by
            ``__get_candidates``.
        :return: ``(scores, qualities)``, both keyed by root pitch class.
        """
        scores: Dict[int, int] = {}
        qualities: Dict[int, str] = {}
        for root_note, sequence in candidates.items():
            has_minor_third = 3 in sequence
            has_major_third = 4 in sequence
            if has_minor_third == has_major_third:
                # Either no third at all or both thirds: quality is undecidable.
                scores[root_note] = -100
                qualities[root_note] = "None"
                continue

            if has_minor_third:
                quality = "dim" if 6 in sequence else "min"
            elif 8 in sequence:
                quality = "aug"
            elif 7 in sequence and 10 in sequence:
                quality = "dom"
            else:
                quality = "maj"

            chord_tones = _CHORD_MAPS[quality]
            score = 0
            for interval in sequence:
                if interval in chord_tones:
                    continue
                if interval in _CHORD_OUTSIDERS_1[quality]:
                    score -= 1
                elif interval in _CHORD_OUTSIDERS_2[quality]:
                    score -= 2
                elif interval in _CHORD_INSIDERS[quality]:
                    score += 1
            scores[root_note] = score
            qualities[root_note] = quality
        return scores, qualities

    @classmethod
    def __find_chord(cls, pianoroll: np.ndarray) -> Tuple[str, str, str, int]:
        """Recognise the single chord that best explains a piano-roll window.

        :param pianoroll: window of the piano roll; it is indexed as
            (tick, pitch) here. NOTE(review): ``tochroma`` is assumed to fold
            the pitch axis into 12 pitch classes -- confirm against
            miditoolkit.
        :return: ``(root name, quality, bass name, score)``. All three strings
            are ``"None"`` and the score is 0 when nothing sounds in the
            window or no root can be chosen.
        """
        chroma = np.sum(tochroma(pianoroll=pianoroll), axis=0)
        chroma = (chroma != 0).astype(int)  # presence flags, not intensities
        if np.sum(chroma) == 0:
            return "None", "None", "None", 0

        candidates = cls.__get_candidates(chroma=chroma)
        scores, qualities = cls.__get_score(candidates=candidates)

        # Pitch classes of the sounding pitches, lowest pitch first; the
        # first entry is therefore the bass note.
        sounding = [
            int(pitch % 12)
            for pitch, total in enumerate(np.sum(pianoroll, axis=0))
            if total > 0
        ]
        bass_note = sounding[0]

        # The root is the best-scoring candidate; ties are resolved in favour
        # of the lowest sounding pitch that is among the tied roots.
        best_score = max(scores.values())
        best_roots = [root for root, score in scores.items() if score == best_score]
        if len(best_roots) == 1:
            root_note = best_roots[0]
        else:
            root_note = next((n for n in sounding if n in best_roots), None)
        if root_note is None:
            return "None", "None", "None", 0  # no root found

        return (
            _PITCH_CLASSES[root_note],
            qualities[root_note],
            _PITCH_CLASSES[bass_note],
            scores[root_note],
        )

    @classmethod
    def __solve(
        cls,
        candidates: Dict[int, Dict[int, Tuple[str, str, str, int]]],
        max_tick: int,
    ) -> List[Tuple[int, int, str]]:
        """Chain the per-window candidates into one chord sequence.

        Starting at tick 0, the window with the highest score is kept (the
        longest one on equal scores) and the walk resumes at its end tick.
        Windows whose label contains ``":None"`` are then absorbed: leading
        ones are merged into the first recognised chord, later ones extend
        the chord that precedes them.

        :param candidates: ``{start_tick: {end_tick: (root, quality, bass,
            score)}}`` as built by :meth:`extract`.
        :param max_tick: tick at which the walk stops.
        :return: ``(start_tick, end_tick, label)`` tuples, where the label is
            ``"root:quality"`` or ``"root:quality/bass"``. Empty when no chord
            is recognised anywhere.
        """
        chords: List[Tuple[int, int, str]] = []
        start_tick = 0
        while start_tick < max_tick:
            ranked = sorted(
                candidates.get(start_tick, {}).items(),
                key=lambda item: (item[1][-1], item[0]),  # (score, end_tick)
            )
            end_tick, (root_note, quality, bass_note, _) = ranked[-1]
            if root_note == bass_note:
                label = f"{root_note}:{quality}"
            else:
                label = f"{root_note}:{quality}/{bass_note}"
            chords.append((start_tick, end_tick, label))
            start_tick = end_tick

        merged: List[Tuple[int, int, str]] = []
        leading_start = None  # start tick of unrecognised spans before any chord
        for start, end, label in chords:
            if ":None" in label:
                if merged:
                    prev_start, _, prev_label = merged[-1]
                    merged[-1] = (prev_start, end, prev_label)
                elif leading_start is None:
                    leading_start = start
            else:
                if not merged and leading_start is not None:
                    start = leading_start
                merged.append((start, end, label))
        return merged

    @classmethod
    def extract(
        cls, notes: List["Note"], ticks_per_beat: int
    ) -> List[Tuple[int, int, str]]:
        """Detect the chords played by a list of notes.

        :param notes: notes to analyse; each must expose an ``end`` tick.
        :param ticks_per_beat: time division, in ticks per beat.
        :return: ``(start_tick, end_tick, label)`` tuples covering the piece;
            an empty list when ``notes`` is empty or no chord is recognised.
        """
        if not notes:
            return []

        max_tick = int(max(n.end for n in notes))
        pianoroll = cast(
            np.ndarray,
            notes2pianoroll(
                note_stream_ori=notes, max_tick=max_tick, ticks_per_beat=ticks_per_beat
            ),
        )

        # One candidate per (start beat, window length); the shortest window
        # is 2 beats (1/2 bar in 4/4), the longest 4 beats (1 bar in 4/4).
        candidates: Dict[int, Dict[int, Tuple[str, str, str, int]]] = {}
        for interval in (4, 2):
            for start_tick in range(0, max_tick, ticks_per_beat):
                end_tick = min(int(ticks_per_beat * interval + start_tick), max_tick)
                spans = candidates.setdefault(start_tick, {})
                if end_tick in spans:
                    # Both window lengths were clipped to the same end tick.
                    continue
                spans[end_tick] = cls.__find_chord(
                    pianoroll=pianoroll[start_tick:end_tick, :]
                )

        return cls.__solve(candidates=candidates, max_tick=max_tick)
def test_chords():
    """Print the chords found by MidiTok and by the REMI method on a toy track.

    A single-instrument MIDI is built from a handful of known chords, run
    through the tokenizer's preprocessing, and then handed to both detectors
    so that their outputs can be compared by eye. Nothing is asserted.
    """

    @dataclass
    class Chord:
        # Notes making up the chord and the quality they were built from.
        notes: List[Note]
        quality: str

        def __post_init__(self):
            # The chord starts with its earliest note.
            self.start = min(n.start for n in self.notes)

        def __str__(self):
            return f"Chord {self.quality} at tick {self.start}"

    time_division = 480
    tokenizer = miditok.REMI()
    chords: List[Chord] = []
    all_notes: List[Note] = []

    # Chords to write, as (start tick, end tick, root pitch, quality).
    chord_args = [
        (0, 384, 50, "min"),
        (256, 512, 70, "maj"),
        (384, 512, 60, "dim"),
        (600, 700, 70, "aug"),
        (700, 1000, 70, "sus2"),
    ]
    for start, end, pitch, chord_quality in chord_args:
        # One note per interval of the quality, stacked on the root pitch.
        notes = [
            Note(120, pitch + offset, start, end)
            for offset in CHORD_MAPS[chord_quality]
        ]
        chords.append(Chord(notes, chord_quality))
        all_notes += notes

    # Create the MIDI and preprocess it.
    midi = MidiFile(ticks_per_beat=time_division)
    midi.instruments = [Instrument(0)]
    midi.instruments[-1].notes = all_notes.copy()
    tokenizer.preprocess_midi(midi)

    # Detect the chords with both methods.
    # NOTE(review): `_first_beat_res` is a private tokenizer attribute and the
    # third positional parameter of `detect_chords` is assumed to be the beat
    # resolution -- confirm against the installed MidiTok version.
    detected_chords = detect_chords(
        midi.instruments[0].notes, time_division, tokenizer._first_beat_res
    )
    detected_chords2 = REMIPlusChord.extract(midi.instruments[0].notes, time_division)

    print("\nChords detected with MidiTok method:")
    for chord in detected_chords:
        print(f"Chord {chord.value} - starting at tick: {chord.time} - {chord.desc}")

    print("\nChords detected with REMI repo method:")
    for chord in detected_chords2:
        print(chord)


if __name__ == "__main__":
    test_chords()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment