Skip to content

Instantly share code, notes, and snippets.

@ychalier
Created June 20, 2021 18:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ychalier/a5951a532e82ee0c6d5764279420e839 to your computer and use it in GitHub Desktop.
Save ychalier/a5951a532e82ee0c6d5764279420e839 to your computer and use it in GitHub Desktop.
Sha-zoom, an implementation of the audio search algorithm described in the Computerphile video 'How Shazam Works (Probably!)' published on YouTube on March 15, 2021 by David Domminney Fowler. Requires FFMPEG.
"""
Sha-zoom, an implementation of the audio search algorithm described in the
Computerphile video 'How Shazam Works (Probably!)' published on YouTube on
March 15, 2021 by David Domminney Fowler. Requires FFMPEG.
"""
import os
import pickle
import logging
import argparse
import subprocess
import wave
import numpy
import matplotlib.pyplot
import simplejson
# Executable run by ShazoomTrack.from_mp3 to extract audio into WAV files.
# Use an absolute path here if ffmpeg is not found on PATH.
FFMPEG = "ffmpeg"
def format_timestamp(seconds):
    """
    Format a duration in seconds as an FFmpeg time string "hh:mm:ss[.mmm]".

    Whole seconds give "hh:mm:ss" (e.g. 3725 -> "01:02:05"). A fractional
    part is kept as milliseconds (e.g. 7.5 -> "00:00:07.500") instead of
    being silently truncated, so fractional seek/duration values are honoured.

    Args:
        seconds: non-negative number of seconds (int or float).

    Returns:
        The formatted string.

    Raises:
        ValueError: if `seconds` is negative.
    """
    if seconds < 0:
        raise ValueError("Timestamp must be non-negative, got %r" % (seconds,))
    # Work in whole milliseconds so float inputs split cleanly.
    total_ms = int(round(seconds * 1000))
    whole_seconds, millis = divmod(total_ms, 1000)
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    text = "%02d:%02d:%02d" % (hours, minutes, secs)
    if millis:
        text += ".%03d" % millis
    return text
def strip_filename(path):
    """Return the file name of `path` without its directory and last extension."""
    base_name = os.path.basename(path)
    stem, _extension = os.path.splitext(base_name)
    return stem
class ShazoomOptions:
    """
    Tuning parameters shared by every track of a database.

    Attributes (same names as the constructor arguments):
        chunk_length_ms: duration in milliseconds of each waveform chunk
            sent to the FFT.
        fft_size: number of points passed as `n` to `numpy.fft.fft`.
        fft_bin_start: first FFT bin kept (inclusive).
        fft_bin_end: FFT bin where the kept range stops (exclusive).
        fft_bin_step: stride used when selecting the kept FFT bins.
        partition_size: number of frequency bands the kept bins are split
            into when building the fingerprint.
    """

    def __init__(
        self,
        chunk_length_ms=100,
        fft_size=2048,
        fft_bin_start=4,
        fft_bin_end=214,
        fft_bin_step=1,
        partition_size=6,
    ) -> None:
        # Time-domain chunking.
        self.chunk_length_ms = chunk_length_ms
        # Spectrum size and the slice of bins that is kept.
        self.fft_size = fft_size
        self.fft_bin_start, self.fft_bin_end, self.fft_bin_step = (
            fft_bin_start,
            fft_bin_end,
            fft_bin_step,
        )
        # Banding of the kept bins into the fingerprint.
        self.partition_size = partition_size
class ShazoomDatabase:
    """
    Reference tracks with their labels, queried with `predict`.

    Build it with `fit`, persist it with `save_checkpoint` and restore it
    with `from_checkpoint`.
    """

    def __init__(self, options):
        # Processing parameters shared by the stored and queried tracks.
        self.options = options
        # Parallel lists: self.labels[i] names self.tracks[i]; None until fitted.
        self.tracks = None
        self.labels = None

    def fit(self, tracks, labels, preprocess=False):
        """
        Store reference tracks and their labels.

        Args:
            tracks: list of reference tracks. Each one is passed as the
                argument of the query track's `match` in `predict`.
            labels: one label per track, in the same order.
            preprocess: if True, call `preprocess()` on every track.

        Raises:
            ValueError: if `tracks` and `labels` have different lengths.
        """
        if tracks is not None and labels is not None and len(tracks) != len(labels):
            raise ValueError(
                "Got %d tracks but %d labels" % (len(tracks), len(labels)))
        self.tracks = tracks
        if preprocess:
            for track in self.tracks:
                track.preprocess()
        self.labels = labels

    def save_checkpoint(self, path):
        """Pickle the options, tracks and labels to the file at `path`."""
        logging.info("Saving checkpoint to '%s'", path)
        with open(path, "wb") as file:
            pickle.dump({
                "options": self.options,
                "tracks": self.tracks,
                "labels": self.labels
            }, file)

    @classmethod
    def from_checkpoint(cls, path):
        """
        Rebuild a database from a file written by `save_checkpoint`.

        SECURITY: unpickling can execute arbitrary code. Only load
        checkpoints from trusted sources.
        """
        with open(path, "rb") as file:
            checkpoint = pickle.load(file)
        database = cls(checkpoint["options"])
        database.fit(checkpoint["tracks"], checkpoint["labels"])
        return database

    def predict(self, track):
        """
        Find the reference track that best matches `track`.

        Args:
            track: query track; `track.match(ref)` is called for every
                reference and must return a number.

        Returns:
            A dict with:
                "label": label of the best-scoring reference (the first one
                    on ties),
                "score": its match score,
                "confidence": best score minus the runner-up score (the
                    runner-up counts as 0.0 when there is a single reference),
                "details": mapping label -> score for every reference.

        Raises:
            TypeError: if the database has not been fitted.
            ValueError: if the database holds no tracks.
        """
        if self.tracks is None or self.labels is None:
            raise TypeError("Database has not been fitted")
        if len(self.tracks) == 0:
            raise ValueError("Database contains no tracks")
        scores = [track.match(ref) for ref in self.tracks]
        # max() returns the first maximal index, so ties favour earlier tracks.
        best = max(range(len(scores)), key=scores.__getitem__)
        # default= keeps single-track databases from crashing on max([]).
        runner_up = max(scores[:best] + scores[best + 1:], default=0.0)
        return {
            "label": self.labels[best],
            "score": scores[best],
            "confidence": scores[best] - runner_up,
            "details": {
                self.labels[i]: score
                for i, score in enumerate(scores)
            }
        }

    def predict_on_the_fly(self, path, folder, seek=0, duration=15):
        """
        Extract, fingerprint and identify the audio file at `path`.

        The temporary WAV written to `folder` is deleted after preprocessing.
        `seek` and `duration` select the extract, in seconds.
        """
        track = ShazoomTrack.from_mp3(
            self.options,
            path,
            folder,
            seek=seek,
            duration=duration
        )
        track.preprocess(clear_after=True)
        return self.predict(track)
class ShazoomTrack:
    """
    One audio extract and the data derived from it.

    `preprocess` runs the pipeline:
        WAV file -> waveform (samples in [-1, 1))
                 -> magnitude spectrogram `fft` (one column per chunk)
                 -> fingerprint `print` (for each band and chunk, the row
                    index of the loudest bin).
    Two fingerprints are compared with `match`.
    """

    def __init__(self, options, path):
        # Processing parameters (chunk length, FFT size/bins, band count).
        self.options = options
        # Path of the WAV file read by _compute_waveform.
        self.path = path
        # Filled in by _compute_waveform / _compute_fft / _compute_print.
        self.framerate = None
        self.waveform = None
        self.fft = None
        self.print = None

    @classmethod
    def from_mp3(cls, options, path, folder, seek=0, duration=15):
        """
        Extract part of an audio file into a mono 8-bit WAV file with FFmpeg.

        Args:
            options: processing parameters stored on the returned track.
            path: input audio file given to FFmpeg.
            folder: directory receiving "<input file stem>.wav"; created if
                needed. An existing file with that name is overwritten.
            seek: start of the extract, in seconds.
            duration: length of the extract, in seconds.

        Returns:
            A track pointing at the WAV file; call `preprocess` to analyse it.

        Raises:
            subprocess.CalledProcessError: if FFmpeg exits with an error.
        """
        os.makedirs(folder, exist_ok=True)
        wav_path = os.path.join(folder, strip_filename(path) + ".wav")
        # check=True surfaces FFmpeg failures instead of analysing a missing or
        # stale WAV later. stdin is closed so FFmpeg cannot wait for
        # interactive input, and the unused stdout is discarded rather than
        # piped (an unread pipe can fill up and block the child).
        subprocess.run(
            [
                FFMPEG,
                "-y",
                "-hide_banner",
                "-loglevel", "error",
                "-ss", format_timestamp(seek),
                "-i", path,
                "-t", format_timestamp(duration),
                "-ac", "1",
                "-acodec", "pcm_u8",
                wav_path,
            ],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            check=True,
        )
        return cls(options, wav_path)

    def _compute_waveform(self):
        """
        Read the WAV file at `self.path` into `self.waveform`.

        The file must be mono with 8-bit unsigned samples (what `from_mp3`
        produces). Each sample s is mapped to (s - 128) / 128.

        Returns:
            (waveform, framerate): samples as a list of floats and the frame
            rate reported by the file.

        Raises:
            ValueError: if the file is not mono 8-bit PCM.
        """
        with wave.open(self.path, "rb") as spf:
            # Any other layout has multi-byte frames, which cannot be decoded
            # as one unsigned byte per sample.
            if spf.getnchannels() != 1 or spf.getsampwidth() != 1:
                raise ValueError(
                    "'%s' must be mono 8-bit PCM (got %d channel(s), %d byte(s) per sample)"
                    % (self.path, spf.getnchannels(), spf.getsampwidth()))
            self.framerate = spf.getframerate()
            # One bulk read instead of one readframes(1) call per sample.
            raw = spf.readframes(spf.getnframes())
        samples = numpy.frombuffer(raw, dtype=numpy.uint8).astype(numpy.float64)
        # Stored as a list of Python floats, the type existing code expects.
        self.waveform = ((samples - 128) / 128).tolist()
        return self.waveform, self.framerate

    def _compute_fft(self):
        """
        Compute the magnitude spectrum of consecutive waveform chunks.

        The waveform is cut into chunks of
        chunk_length = framerate * chunk_length_ms // 1000 samples. Each chunk
        goes through `numpy.fft.fft(..., n=fft_size)`. Only rows
        `fft_bin_start:fft_bin_end:fft_bin_step` are kept.

        Returns:
            `self.fft`, of shape
            (number of kept bins, ceil(len(waveform) / chunk_length)).

        Raises:
            TypeError: if the waveform has not been computed.
            ValueError: if the waveform is empty or a chunk would hold no
                sample.
        """
        if self.waveform is None:
            raise TypeError("Waveform is None")
        # Integer arithmetic avoids float rounding in the chunk length.
        chunk_length = int(self.framerate * self.options.chunk_length_ms // 1000)
        if chunk_length <= 0:
            raise ValueError(
                "chunk_length_ms is too short for a frame rate of %r" % (self.framerate,))
        samples = numpy.asarray(self.waveform, dtype=numpy.float64)
        if samples.size == 0:
            raise ValueError("Waveform is empty")
        n_chunks = -(-samples.size // chunk_length)  # ceiling division
        # Zero-pad the last, shorter chunk. A ragged list of chunks cannot form
        # the 2-D array fft needs. The padded chunk has the same spectrum as
        # fft(unpadded_chunk, n=fft_size), which zero-pads (or truncates) too.
        padded = numpy.zeros(n_chunks * chunk_length)
        padded[:samples.size] = samples
        chunks = padded.reshape(n_chunks, chunk_length)
        spectrum = numpy.absolute(numpy.fft.fft(chunks, n=self.options.fft_size)).T
        self.fft = spectrum[
            self.options.fft_bin_start:self.options.fft_bin_end:self.options.fft_bin_step,
            :
        ]
        return self.fft

    def _compute_print(self):
        """
        Compute the audio fingerprint from the spectrogram.

        The rows of `self.fft` are split into `partition_size` consecutive
        bands of `rows // partition_size` rows; leftover top rows are ignored.
        For every band and chunk the fingerprint stores, as a float, the row
        index in `self.fft` of the loudest bin (first one on ties).

        Returns:
            `self.print`, of shape (partition_size, number of chunks).

        Raises:
            TypeError: if the FFT has not been computed.
            ValueError: if there are fewer FFT rows than bands.
        """
        if self.fft is None:
            raise TypeError("FFT is None")
        partition_size = self.options.partition_size
        n_rows, n_chunks = self.fft.shape
        partition_length = n_rows // partition_size if partition_size > 0 else 0
        if partition_length <= 0:
            raise ValueError(
                "Cannot split %d FFT rows into %r bands" % (n_rows, partition_size))
        # bands[i, k, j] == fft[i * partition_length + k, j]
        bands = self.fft[:partition_size * partition_length].reshape(
            partition_size, partition_length, n_chunks)
        offsets = numpy.arange(partition_size)[:, numpy.newaxis] * partition_length
        # Kept as float64 like the zero-initialised array this used to fill.
        self.print = (numpy.argmax(bands, axis=1) + offsets).astype(numpy.float64)
        return self.print

    def preprocess(self, clear_after=False):
        """
        Compute the waveform, spectrogram and fingerprint from the WAV file.

        Args:
            clear_after: delete the WAV file once it has been processed.
        """
        logging.info("Preprocessing '%s'", self.path)
        self._compute_waveform()
        self._compute_fft()
        self._compute_print()
        if clear_after and os.path.isfile(self.path):
            os.remove(self.path)

    def plot_waveform(self) -> None:
        """
        Display the waveform against time in seconds.

        NOTE: sets matplotlib's global font family to "Consolas".
        """
        matplotlib.pyplot.rcParams["font.family"] = "Consolas"
        matplotlib.pyplot.figure(figsize=(14, 7))
        timeticks = numpy.arange(len(self.waveform)) / self.framerate
        matplotlib.pyplot.plot(timeticks, self.waveform)
        matplotlib.pyplot.xlabel("Time (s)")
        matplotlib.pyplot.ylabel("Amplitude")
        matplotlib.pyplot.suptitle(strip_filename(self.path))
        matplotlib.pyplot.show()

    def plot_print(self) -> None:
        """
        Display the spectrogram with the fingerprint points drawn on top.

        Each fingerprint point (the loudest row of a band in a chunk) is
        painted with the spectrogram's maximum value so it stands out.
        NOTE: sets matplotlib's global font family to "Consolas".
        """
        max_val = self.fft.max()
        overlay = numpy.zeros(self.fft.shape)
        # self.print holds, per band and chunk, a row index into self.fft.
        # Marking those cells covers every band, whatever partition_size is.
        rows = self.print.astype(int)
        cols = numpy.broadcast_to(numpy.arange(self.fft.shape[1]), rows.shape)
        overlay[rows, cols] = max_val
        matplotlib.pyplot.rcParams["font.family"] = "Consolas"
        matplotlib.pyplot.figure(figsize=(14, 7))
        matplotlib.pyplot.imshow(
            numpy.maximum(self.fft, overlay),
            interpolation="nearest",
            aspect="auto",
            origin="lower"
        )
        chunk_ms = self.options.chunk_length_ms
        # One x tick every 2 seconds; max() keeps the step positive for
        # chunks longer than 2 s.
        x_ticks_pos = list(range(
            0,
            self.fft.shape[1] + 1,
            max(1, int(2000 // chunk_ms))
        ))
        x_ticks_labels = [i * chunk_ms // 1000 for i in x_ticks_pos]
        matplotlib.pyplot.xticks(ticks=x_ticks_pos, labels=x_ticks_labels)
        matplotlib.pyplot.xlabel("Time (s)")
        n_rows = self.fft.shape[0]
        y_ticks_pos = list(range(0, n_rows + 1, max(1, (n_rows + 1) // 10)))
        # Row y holds FFT bin (fft_bin_start + y * fft_bin_step); bin k sits
        # at k * framerate / fft_size Hz.
        y_ticks_labels = [
            round(self.framerate / self.options.fft_size
                  * (self.options.fft_bin_start + y * self.options.fft_bin_step))
            for y in y_ticks_pos
        ]
        matplotlib.pyplot.yticks(ticks=y_ticks_pos, labels=y_ticks_labels)
        matplotlib.pyplot.ylabel("Frequency (Hz)")
        matplotlib.pyplot.suptitle(strip_filename(self.path))
        matplotlib.pyplot.tight_layout()
        matplotlib.pyplot.show()

    def match(self, other, match_length=4, look_forward=1):
        """
        Score how much of this track's fingerprint is found in `other`'s.

        Anchors are the points (band t_i, chunk t_j) of this fingerprint
        whose chunk leaves `look_forward` chunks after it. Each anchor has a
        group of `match_length` points: the anchor and the following bands,
        wrapping to band 0 of the next chunk.

        An anchor is matched if some chunk r_j of `other` (also leaving
        `look_forward` chunks after it) meets both conditions:
          - it has the anchor's value in band t_i;
          - the group scores at least `match_length` hits there. Each pair
            (group point, shift in 0..look_forward) whose value equals
            other's value in the same band at r_j + shift counts as one hit.

        Assumes match_length <= bands * look_forward + 1, so a group never
        reaches past the last chunk.

        Returns:
            Fraction of anchors matched, in [0, 1]. Returns 0.0 when either
            fingerprint is too short to hold an anchor.

        Raises:
            TypeError: if either fingerprint has not been computed.
        """
        if self.print is None:
            raise TypeError("Self.print is None")
        if other.print is None:
            raise TypeError("Other.print is None")
        n_bands, n_chunks = self.print.shape
        anchors = n_chunks - look_forward
        if anchors <= 0 or n_bands == 0:
            return 0.0
        ref = other.print
        windows = ref.shape[1] - look_forward  # candidate r_j: 0..windows-1
        if windows <= 0:
            return 0.0
        match_score = 0
        for t_j in range(anchors):
            for t_i in range(n_bands):
                candidates = ref[t_i, :windows] == self.print[t_i, t_j]
                if not candidates.any():
                    continue
                # hits[r_j] = number of (point, shift) pairs that agree.
                hits = numpy.zeros(windows, dtype=int)
                for offset in range(match_length):
                    wraps, band = divmod(t_i + offset, n_bands)
                    value = self.print[band, t_j + wraps]
                    for shift in range(look_forward + 1):
                        hits += ref[band, shift:shift + windows] == value
                if numpy.any(candidates & (hits >= match_length)):
                    match_score += 1
        return match_score / (n_bands * anchors)
def main():
    """
    Command-line entry point.

    Actions:
        create_model: fingerprint every --input file and save them, labelled
            by file name without extension, to the --model checkpoint.
        analyze_songs: load the --model checkpoint, fingerprint every --input
            file and print the predictions as JSON on stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "action", type=str, choices=["create_model", "analyze_songs"],
        help="create_model: build a model from the inputs; "
             "analyze_songs: identify the inputs with an existing model")
    parser.add_argument(
        "-m", "--model", type=str, required=True,
        help="model checkpoint to write (create_model) or read (analyze_songs)")
    parser.add_argument(
        "-i", "--input", type=str, nargs="+", required=True,
        help="input audio files")
    parser.add_argument(
        "-p", "--plot", action="store_true",
        help="plot the fingerprint of every input")
    parser.add_argument(
        "-f", "--folder", type=str, default=".",
        help="folder receiving the temporary WAV extracts")
    parser.add_argument(
        "--chunk-length", type=int, default=100,
        help="chunk length in milliseconds (create_model only)")
    parser.add_argument(
        "--fft-size", type=int, default=2048,
        help="FFT size (create_model only)")
    parser.add_argument(
        "--fft-bin-start", type=int, default=4,
        help="first FFT bin kept (create_model only)")
    parser.add_argument(
        "--fft-bin-end", type=int, default=214,
        help="FFT bin where the kept range stops, exclusive (create_model only)")
    parser.add_argument(
        "--fft-bin-step", type=int, default=1,
        help="stride between kept FFT bins (create_model only)")
    parser.add_argument(
        "--partition-size", type=int, default=6,
        help="number of frequency bands in the fingerprint (create_model only)")
    parser.add_argument(
        "--extract-seek", type=float, default=0,
        help="start of the extract, in seconds")
    parser.add_argument(
        "--extract-duration", type=float, default=15,
        help="length of the extract, in seconds")
    parser.add_argument(
        "-q", "--quiet", action="store_true",
        help="do not configure INFO-level logging")
    args = parser.parse_args()
    if not args.quiet:
        logging.basicConfig(
            format="%(asctime)s\t%(levelname)s\t%(message)s",
            level=logging.INFO
        )

    def load_track(options, path):
        # Extract, fingerprint (and optionally plot) one input at a time.
        # The temporary WAV is named after the input's file name stem, so
        # processing each input right after extracting it keeps inputs that
        # share a stem from overwriting (or deleting) each other's WAV.
        track = ShazoomTrack.from_mp3(
            options,
            path,
            args.folder,
            args.extract_seek,
            args.extract_duration
        )
        track.preprocess(clear_after=True)
        if args.plot:
            track.plot_print()
        return track

    if args.action == "create_model":
        options = ShazoomOptions(
            chunk_length_ms=args.chunk_length,
            fft_size=args.fft_size,
            fft_bin_start=args.fft_bin_start,
            fft_bin_end=args.fft_bin_end,
            fft_bin_step=args.fft_bin_step,
            partition_size=args.partition_size
        )
        model = ShazoomDatabase(options)
        tracks = [load_track(options, path) for path in args.input]
        model.fit(tracks, [strip_filename(path) for path in args.input])
        model.save_checkpoint(args.model)
    elif args.action == "analyze_songs":
        model = ShazoomDatabase.from_checkpoint(args.model)
        results = []
        for path in args.input:
            prediction = model.predict(load_track(model.options, path))
            prediction["path"] = path
            results.append(prediction)
        print(simplejson.dumps({"results": results}, indent=4, sort_keys=True))
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment