Skip to content

Instantly share code, notes, and snippets.

@Natooz
Created March 13, 2023 15:28
Show Gist options
  • Save Natooz/8984128b55b2144b9af948698b2e8904 to your computer and use it in GitHub Desktop.
Save Natooz/8984128b55b2144b9af948698b2e8904 to your computer and use it in GitHub Desktop.
Test chord detection methods
#!/usr/bin/env python3
"""Test chord detection methods
"""
from dataclasses import dataclass
from typing import List, Tuple, Dict, cast
from copy import copy
import numpy as np
from miditoolkit import MidiFile, Instrument, Note
from miditoolkit.pianoroll.parser import notes2pianoroll
from miditoolkit.pianoroll.utils import tochroma
import miditok
from miditok.constants import CHORD_MAPS
from miditok.utils import detect_chords
# Pitch-class names in ascending semitone order: index 0 is "C", index 11 is
# "B". Used to turn a pitch-class index into a printable note name.
_PITCH_CLASSES = "C C# D D# E F F# G G# A A# B".split()
# All tables below are keyed by chord quality and hold intervals expressed in
# semitones above the candidate root.

# Chord tones that define each quality; they never change a candidate's score.
_CHORD_MAPS = dict(
    maj=[0, 4],
    min=[0, 3],
    dim=[0, 3, 6],
    aug=[0, 4, 8],
    dom=[0, 4, 7, 10],
)

# "Insider" intervals: each one present adds 1 to the candidate's score.
_CHORD_INSIDERS = dict(maj=[7], min=[7], dim=[9], aug=[], dom=[])

# Mild "outsider" intervals: each one present subtracts 1 from the score.
_CHORD_OUTSIDERS_1 = dict(
    maj=[2, 5, 9],
    min=[2, 5, 8],
    dim=[2, 5, 10],
    aug=[2, 5, 9],
    dom=[2, 5, 9],
)

# Strong "outsider" intervals: each one present subtracts 2 from the score.
_CHORD_OUTSIDERS_2 = dict(
    maj=[1, 3, 6, 8, 10],
    min=[1, 4, 6, 9, 11],
    dim=[1, 4, 7, 8, 11],
    aug=[1, 3, 6, 7, 10],
    dom=[1, 3, 6, 8, 11],
)
class REMIPlusChord:
    """Rule-based chord detection over a piano roll.

    Originally implemented in the REMI original repository
    <https://github.com/YatingMusic/remi/blob/master/chord_recognition.py>

    The only public entry point is :meth:`extract`. The other methods are
    name-mangled helpers that implement the individual steps.
    """

    @classmethod
    def __get_candidates(cls, chroma: np.ndarray) -> Dict[int, List[int]]:
        """List, for every active pitch class, the intervals sounding above it.

        :param chroma: 1-D binary vector; a truthy entry marks an active pitch
            class (the caller passes a 12-entry vector).
        :return: mapping ``root index -> sorted interval list``, where each
            interval is the offset of an active entry relative to that root.
        """
        candidates: Dict[int, List[int]] = {}
        for root_note, active in enumerate(chroma):
            if active:
                # Rotate so the candidate root lands on index 0; the indices
                # of the remaining ones are then intervals above the root.
                rolled = np.roll(chroma, -root_note)
                candidates[root_note] = np.where(rolled == 1)[0].tolist()
        return candidates

    @classmethod
    def __get_score(
        cls, candidates: Dict[int, List[int]]
    ) -> Tuple[Dict[int, int], Dict[int, str]]:
        """Give every candidate root a chord quality and a fitness score.

        :param candidates: mapping ``root -> intervals`` as produced by
            ``__get_candidates``.
        :return: ``(scores, qualities)``, both keyed by root. A root without
            exactly one kind of third gets score ``-100`` and quality
            ``"None"``.
        """
        scores: Dict[int, int] = {}
        qualities: Dict[int, str] = {}
        for root_note, sequence in candidates.items():
            has_minor_third = 3 in sequence
            has_major_third = 4 in sequence
            if has_minor_third == has_major_third:
                # Neither third, or both at once: no quality can be named.
                scores[root_note] = -100
                qualities[root_note] = "None"
                continue

            # Exactly one third is present, so one of these branches applies.
            if has_minor_third:
                quality = "dim" if 6 in sequence else "min"
            elif 8 in sequence:
                quality = "aug"
            elif 7 in sequence and 10 in sequence:
                quality = "dom"
            else:
                quality = "maj"

            # Chord tones are neutral; every other interval is looked up in
            # the outsider / insider tables, first match wins.
            chord_tones = _CHORD_MAPS[quality]
            score = 0
            for interval in sequence:
                if interval in chord_tones:
                    continue
                if interval in _CHORD_OUTSIDERS_1[quality]:
                    score -= 1
                elif interval in _CHORD_OUTSIDERS_2[quality]:
                    score -= 2
                elif interval in _CHORD_INSIDERS[quality]:
                    score += 1
            scores[root_note] = score
            qualities[root_note] = quality
        return scores, qualities

    @classmethod
    def __find_chord(cls, pianoroll: np.ndarray) -> Tuple[str, str, str, int]:
        """Name the best-scoring chord in one piano-roll window.

        :param pianoroll: 2-D window; it is summed over its first axis and the
            index along the second axis is treated as a pitch number.
        :return: ``(root name, quality, bass name, score)``, or
            ``("None", "None", "None", 0)`` when the window is silent.
        """
        chroma = np.sum(tochroma(pianoroll=pianoroll), axis=0)
        # Only the presence of a pitch class matters, not how much of it.
        chroma = (chroma != 0).astype(int)
        if not chroma.any():
            return "None", "None", "None", 0

        candidates = cls.__get_candidates(chroma=chroma)
        scores, qualities = cls.__get_score(candidates=candidates)

        # Pitch classes of the sounding pitches, lowest pitch first, so the
        # first entry is the bass note.
        sorted_notes = [
            int(pitch % 12)
            for pitch, amount in enumerate(np.sum(pianoroll, axis=0))
            if amount > 0
        ]
        bass_note = sorted_notes[0]

        # Root: the best-scoring candidate; ties go to the lowest-sounding one.
        best_score = max(scores.values())
        best_roots = [root for root, score in scores.items() if score == best_score]
        if len(best_roots) == 1:
            root_note = best_roots[0]
        else:
            root_note = next((n for n in sorted_notes if n in best_roots), None)
        if root_note is None:
            return "None", "None", "None", 0  # no root found

        return (
            _PITCH_CLASSES[root_note],
            qualities.get(root_note, "None"),
            _PITCH_CLASSES[bass_note],
            scores.get(root_note, 0),
        )

    @classmethod
    def __solve(
        cls,
        candidates: Dict[int, Dict[int, Tuple[str, str, str, int]]],
        max_tick: int,
    ) -> List[Tuple[int, int, str]]:
        """Chain windows into a chord sequence and drop unlabelled spans.

        :param candidates: ``start_tick -> end_tick -> (root, quality, bass,
            score)``.
        :param max_tick: tick at which chaining stops.
        :return: list of ``(start_tick, end_tick, label)``; the label is
            ``"root:quality"``, or ``"root:quality/bass"`` when the bass
            differs from the root. Empty when no window carries a quality.
        """
        chords: List[Tuple[int, int, str]] = []
        start_tick = 0
        while start_tick < max_tick:
            options = candidates.get(start_tick)
            if not options:
                # No window starts here, so the chain cannot be continued.
                break
            # Highest score wins; on equal scores prefer the longer window.
            end_tick, (root_note, quality, bass_note, _) = max(
                options.items(), key=lambda item: (item[1][-1], item[0])
            )
            if root_note == bass_note:
                chord = f"{root_note}:{quality}"
            else:
                chord = f"{root_note}:{quality}/{bass_note}"
            chords.append((start_tick, end_tick, chord))
            start_tick = end_tick

        # Remove ":None" entries: leading ones are absorbed by the first
        # labelled chord (which then starts where the sequence starts), later
        # ones extend the end of the previous labelled chord.
        merged: List[Tuple[int, int, str]] = []
        for start, end, label in chords:
            if ":None" not in label:
                if not merged:
                    start = chords[0][0]
                merged.append((start, end, label))
            elif merged:
                previous_start, _, previous_label = merged[-1]
                merged[-1] = (previous_start, end, previous_label)
        return merged

    @classmethod
    def extract(
        cls, notes: "List[Note]", ticks_per_beat: int
    ) -> List[Tuple[int, int, str]]:
        """Detect the chords played by ``notes``.

        Every beat starts a 4-beat and a 2-beat window (clamped to the end of
        the last note); each window is analysed and the best ones are chained.

        :param notes: notes to analyse; their ``end`` attribute is read here.
        :param ticks_per_beat: number of ticks in one beat.
        :return: list of ``(start_tick, end_tick, label)`` tuples; empty when
            ``notes`` is empty or no chord is recognised.
        """
        if not notes:
            # max() of an empty sequence would raise; no notes, no chords.
            return []

        max_tick = int(max(n.end for n in notes))
        pianoroll = cast(
            np.ndarray,
            notes2pianoroll(
                note_stream_ori=notes, max_tick=max_tick, ticks_per_beat=ticks_per_beat
            ),
        )

        # get lots of candidates
        # the shortest: 2 beat (1/2 bar in 4/4), longest: 4 beat (1bar in 4/4)
        candidates: Dict[int, Dict[int, Tuple[str, str, str, int]]] = {}
        for interval in (4, 2):
            for start_tick in range(0, max_tick, ticks_per_beat):
                end_tick = min(int(ticks_per_beat * interval + start_tick), max_tick)
                windows = candidates.setdefault(start_tick, {})
                if end_tick in windows:
                    # Near the end both lengths clamp to the same window;
                    # keep the first result instead of analysing it again.
                    continue
                windows[end_tick] = cls.__find_chord(
                    pianoroll=pianoroll[start_tick:end_tick, :]
                )
        return cls.__solve(candidates=candidates, max_tick=max_tick)
def test_chords():
    """Print the chords found by MidiTok and by the REMI method side by side.

    Builds a short single-track MIDI from a few known chords, preprocesses it
    with a ``miditok.REMI`` tokenizer, runs both detectors on the resulting
    notes and prints their output. Nothing is asserted; the comparison is
    meant to be read by a human.
    """

    @dataclass
    class Chord:
        # Notes making up the chord and the quality name used to build them.
        notes: List[Note]
        quality: str

        def __post_init__(self):
            # The chord starts with its earliest note.
            self.start = min(n.start for n in self.notes)

        def __str__(self):
            return f"Chord {self.quality} at tick {self.start}"

    time_division = 480
    tokenizer = miditok.REMI()
    chords: List[Chord] = []
    all_notes: List[Note] = []

    # Adds chords: (start tick, end tick, lowest pitch, quality)
    chord_args = [
        (0, 384, 50, "min"),
        (256, 512, 70, "maj"),
        (384, 512, 60, "dim"),
        (600, 700, 70, "aug"),
        (700, 1000, 70, "sus2"),
    ]
    for start, end, pitch, chord_quality in chord_args:
        # One note per interval of the quality, all with velocity 120.
        notes = [
            Note(120, pitch + offset, start, end)
            for offset in CHORD_MAPS[chord_quality]
        ]
        chords.append(Chord(notes, chord_quality))
        all_notes += notes

    # Creates MIDI and preprocess it
    midi = MidiFile(ticks_per_beat=time_division)
    midi.instruments = [Instrument(0)]
    midi.instruments[-1].notes = all_notes.copy()
    # NOTE(review): presumably adjusts the notes in place the way the
    # tokenizer would before tokenizing - confirm against the MidiTok docs.
    tokenizer.preprocess_midi(midi)

    # Detect_chords
    detected_chords = detect_chords(
        midi.instruments[0].notes, time_division, tokenizer._first_beat_res
    )
    detected_chords2 = REMIPlusChord.extract(midi.instruments[0].notes, time_division)

    print("\nChords detected with MidiTok method:")
    for chord in detected_chords:
        print(f"Chord {chord.value} - starting at tick: {chord.time} - {chord.desc}")
    print("\nChords detected with REMI repo method:")
    for chord in detected_chords2:
        print(chord)
# Run the comparison only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    test_chords()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment