#!/usr/bin/env python3
# -*- coding: ascii -*-
# MIT License
#
# Copyright (c) 2021 Dylan Robert Ashley and Vincent Herrmann
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from spotipy.oauth2 import SpotifyClientCredentials
from typing import Dict, List
import argparse
import collections
import google_drive_downloader
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import scipy.sparse
import scipy.sparse.csgraph
import seaborn as sns
import spotipy
import torch
import torch.nn.functional as F
class AlbumsDataset(torch.utils.data.Dataset):
"""Wrapper for the album dataset for learning the predictive scalar
representation of songs.
"""
def __init__(self, path, pad_both_ends=False, mode="train", feature_mask=None):
super().__init__()
with open(path, "rb") as handle:
self.data = pickle.load(handle)
self.mode = mode
if self.mode == "valid":
self.data = self.data[-2000:-1000]
elif self.mode == "test":
self.data = self.data[-1000:]
else:
self.data = self.data[:-2000]
self.feature_mask = feature_mask
self.pad_both_ends = pad_both_ends
def __getitem__(self, item):
album_dict = self.data[item]
feature_vectors = [
torch.tensor(t["feature_vector"]) for t in album_dict["track dicts"]
]
if self.pad_both_ends:
feature_vectors = (
[torch.zeros_like(feature_vectors[0])]
+ feature_vectors
+ [torch.zeros_like(feature_vectors[0])]
)
feature_vectors = torch.stack(feature_vectors, dim=0)
if self.feature_mask is not None:
feature_vectors = feature_vectors * self.feature_mask.unsqueeze(0)
return feature_vectors
def __len__(self):
return len(self.data)
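# A minimal usage sketch for AlbumsDataset (assumes the pickle file below has
# already been downloaded, e.g. by main() further down):
#
#   dataset = AlbumsDataset("albums_with_track_features_normalized.pickle",
#                           mode="train")
#   features = dataset[0]  # tensor of shape (num_tracks, 13)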
class BidirectionalLSTMEncoder(torch.nn.Module):
"""Model for learning the predictive scalar representation of
songs.
"""
    def __init__(self, input_size, hidden_size, mlp_hidden_size, encoder_mode="train"):
super().__init__()
if encoder_mode == "train":
self.input_selection_func = lambda x: x
self.encoder = torch.nn.Sequential(
torch.nn.Linear(input_size, 1), torch.nn.Sigmoid()
)
else:
self.encoder = torch.nn.Sequential(
torch.nn.Linear(1, 1), torch.nn.Sigmoid()
)
if encoder_mode == "mean":
self.input_selection_func = lambda x: x.mean(dim=-1, keepdim=True)
        elif type(encoder_mode) is int:
            self.input_selection_func = lambda x: x[..., encoder_mode].unsqueeze(-1)
        else:
            raise ValueError(f"unknown encoder_mode: {encoder_mode}")
self.forward_lstm = torch.nn.LSTM(
input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True
)
self.backward_lstm = torch.nn.LSTM(
input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True
)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_size, mlp_hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(mlp_hidden_size, 1),
torch.nn.Sigmoid(),
)
def forward(self, x, lengths):
x = self.input_selection_func(x)
encodings = self.encoder(x) # batch x max_length x 1
        # pad each sequence with a zero encoding at the start and end
encodings = F.pad(encodings, pad=(0, 0, 1, 1))
encodings_forward = encodings[:, :-2]
h_forward, _ = self.forward_lstm(encodings_forward)
encodings_backward = encodings.flip(dims=[1])
encodings_backward_list = [
encodings_backward[i, -l - 2 :] for i, l in enumerate(lengths)
]
encodings_backward = torch.nn.utils.rnn.pad_sequence(
encodings_backward_list, batch_first=True
)
encodings_backward = encodings_backward[:, :-2]
h_backward, _ = self.backward_lstm(encodings_backward)
h_backward = h_backward.flip(dims=[1])
h_backward_list = [h_backward[i, -l:] for i, l in enumerate(lengths)]
h_backward = torch.nn.utils.rnn.pad_sequence(h_backward_list, batch_first=True)
bidirectional_features = h_forward + h_backward
predictions = self.mlp(bidirectional_features)
return predictions, encodings[:, 1:-1], lengths
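# Shape sketch for BidirectionalLSTMEncoder.forward: given a padded batch of
# shape (batch, max_len, 13) and the true sequence lengths, it returns
# per-position predictions of shape (batch, max_len, 1), the scalar encodings
# of shape (batch, max_len, 1), and the lengths unchanged.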
def album_collate(data_list):
"""Helper function for learning the predictive scalar
representation of songs.
"""
lengths = torch.LongTensor([len(l) for l in data_list])
padded_seq = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True)
return padded_seq, lengths
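# Shape sketch for album_collate: given albums with 9 and 12 tracks, it
# returns a zero-padded (2, 12, 13) batch plus the original lengths:
#
#   batch, lengths = album_collate([torch.rand(9, 13), torch.rand(12, 13)])
#   # batch.shape == (2, 12, 13); lengths.tolist() == [9, 12]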
def contrastive_loss_function(predictions, targets):
"""Loss function for learning the predictive scalar representation
of songs.
"""
mse_target = (predictions - targets) ** 2
mse_noise = (predictions.unsqueeze(0) - targets.detach().unsqueeze(1)) ** 2
return mse_target - mse_noise.mean()
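# Sketch of what the loss computes: the first term pulls each prediction
# towards its own target, while the subtracted term is the mean squared
# distance between every prediction and every detached target in the batch,
# pushing predictions away from the targets of other songs:
#
#   predictions = torch.tensor([0.2, 0.8])
#   targets = torch.tensor([0.3, 0.7])
#   loss = contrastive_loss_function(predictions, targets)  # shape (2,)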
def fit_scores(scores: Dict[str, float]) -> List[str]:
"""Fits a set of scores to an ideal narrative curve."""
scores = list(scores.items())
    np.random.shuffle(scores)  # randomize order so ties are broken arbitrarily
scores = collections.OrderedDict(scores)
distances = np.zeros((len(scores), len(scores)), dtype=float)
for i, y in enumerate(scores.values()):
for j in range(len(scores)):
x = scale(j, 0, len(scores), 0, 1)
distances[i, j] = abs(y - narrative_curve(x))
# binary search to find smallest deviation matching
candidates = np.sort(distances.flatten())
min_idx = 0
max_idx = len(candidates) - 1
while min_idx != max_idx:
pivot = min_idx + (max_idx - min_idx) // 2
graph = scipy.sparse.csr_matrix(distances <= candidates[pivot])
matching = scipy.sparse.csgraph.maximum_bipartite_matching(graph)
if -1 in matching:
min_idx = pivot + 1
else:
max_idx = pivot
    # recover the minimum-weight matching at the smallest feasible deviation
filenames = list(scores.keys())
    graph = scipy.sparse.csgraph.csgraph_from_masked(
        # shifting weights by 1 keeps zero distances as explicit edges; a
        # constant shift does not change which full matching is minimal
        np.ma.masked_greater(distances, candidates[max_idx]) + 1
    )
matching = scipy.sparse.csgraph.min_weight_full_bipartite_matching(graph)[1]
return [filenames[int(np.where(matching == i)[0])] for i in range(len(matching))]
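# A usage sketch for fit_scores with hypothetical scores in [0, 1]; the
# returned list orders the keys so their scores best follow narrative_curve:
#
#   playlist = fit_scores({"a": 0.1, "b": 0.5, "c": 0.9, "d": 0.3})
#   # playlist is a permutation of ["a", "b", "c", "d"]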
def get_track_features_of_album(album_title, album_artist, spotify):
"""Retrieves album information from spotify."""
query = album_title.replace(" ", "+") + "+" + album_artist.replace(" ", "+")
search_result = spotify.search(q=query, type="album")
# get the album ID
album_id = None
for album_description in search_result["albums"]["items"]:
if album_title.lower() == album_description["name"].lower():
album_id = album_description["id"]
break
if album_id is None:
print("exact album title", album_title, "not found")
album_description = search_result["albums"]["items"][0]
album_id = album_description["id"]
print("using", album_description["name"], "instead")
# get audio features for each track of the album
track_ids = []
track_names = []
for track in spotify.album_tracks(album_id)["items"]:
track_names.append(track["name"])
track_ids.append(track["id"])
track_features = spotify.audio_features(track_ids)
return track_names, track_features
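# A usage sketch (assumes `spotify` is an authenticated spotipy.Spotify
# client, as constructed in main()):
#
#   names, features = get_track_features_of_album(
#       "Thriller", "Michael Jackson", spotify
#   )
#   # names[i] is a track title; features[i] is its audio-feature dict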
def main(args):
    """Generates the fitting plots and trains the models."""
    # set up spotify
client_credentials_manager = SpotifyClientCredentials(
client_id=args["cid"], client_secret=args["secret"]
)
spotify = spotipy.Spotify(client_credentials_manager=client_credentials_manager)
# generate the three fitting plots
plot_playlist_tempo(spotify)
plot_playlist_valence(spotify)
zeta = {
"Wanna Be Startin' Somethin'": 0.08055852353572845,
"Baby Be Mine": 0.14294381439685822,
"The Girl Is Mine (with Paul McCartney)": 0.31416481733322144,
"Thriller": 0.10454542934894562,
"Beat It": 0.07758291065692902,
"Billie Jean": 0.18676850199699402,
"Human Nature": 0.7985778450965881,
"P.Y.T. (Pretty Young Thing)": 0.14572905004024506,
"The Lady in My Life": 0.43962377309799194,
} # previously obtained zeta values
plot_playlist_zeta(spotify, zeta)
# load dataset
dataset_path = os.path.join(
os.getcwd(), "albums_with_track_features_normalized.pickle"
)
if not os.path.exists(dataset_path):
google_drive_downloader.GoogleDriveDownloader.download_file_from_google_drive(
file_id="1EHMNv1DlKL4YIeVokkkhlZvM-WYzEiFZ",
dest_path=dataset_path,
unzip=True,
)
# train model and plot weights
model = train_model(dataset_path, encoder_mode="train", num_epochs=50)
plot_feature_weights(model)
# get loss of single features
song_features = [
"Danceability",
"Energy",
"Key",
"Loudness",
"Mode",
"Speechiness",
"Acousticness",
"Instrumentalness",
"Liveness",
"Valence",
"Tempo",
"Duration",
"Time Signature",
]
for only_feature in [1, 9, 10, 11]:
print("train model with only feature", song_features[only_feature])
model = train_model(dataset_path, encoder_mode=only_feature, num_epochs=50)
def narrative_curve(x: float, order: int = 2) -> float:
"""Returns the ideal narrative score for a time value in the range
[0, 1].
"""
assert 0 <= x <= 1
assert order in [1, 2]
if order == 1:
if x <= 0.2:
return 1 / 2 + 5 / 4 * x
elif x <= 0.5:
return 5 / 4 - 5 / 2 * x
elif x <= 0.8:
return -5 / 3 + 10 / 3 * x
else:
return 2 - 5 / 4 * x
else:
if x <= 0.2:
return 1 / 2 + 5 / 2 * x - 25 / 4 * x ** 2
elif x <= 0.3:
return -1 / 4 + 10 * x - 25 * x ** 2
elif x <= 0.5:
return 25 / 8 - 25 / 2 * x + 25 / 2 * x ** 2
elif x <= 0.65:
return 50 / 9 - 200 / 9 * x + 200 / 9 * x ** 2
elif x <= 0.8:
return -119 / 9 + 320 / 9 * x - 200 / 9 * x ** 2
else:
return -3 + 10 * x - 25 / 4 * x ** 2
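# Spot checks for narrative_curve (default order=2): the curve starts at 0.5,
# peaks at 0.75 around x = 0.2, falls to 0 at x = 0.5, climbs to 1 at x = 0.8,
# and settles back to 0.75 at x = 1:
#
#   assert narrative_curve(0.0) == 0.5
#   assert narrative_curve(0.5) == 0.0
#   assert abs(narrative_curve(0.8) - 1.0) < 1e-9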
def parse_args():
"""Reads command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"cid",
help="client id for spotipy client credentials manager",
)
parser.add_argument(
"secret",
help="client secret for spotipy client credentials manager",
)
return vars(parser.parse_args())
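# Example invocation (a sketch; "fit_playlist.py" stands in for whatever
# filename this script is saved under, and the two arguments are your Spotify
# API client credentials):
#
#   python fit_playlist.py CLIENT_ID CLIENT_SECRET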
def plot_feature_weights(model):
"""Sketches Figure 3."""
weighting = model.encoder[0].weight.squeeze().detach()
sns.set_style("ticks")
sns.set_context("paper", font_scale=0.8)
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=300)
    ax.bar(torch.arange(13), weighting.numpy())
ax.set_yticks([-3, 0, 3])
ax.set_xticks(torch.arange(13))
ax.set_xticklabels(
[
"Danceability",
"Energy",
"Key",
"Loudness",
"Mode",
"Speechiness",
"Acousticness",
"Instrumentalness",
"Liveness",
"Valence",
"Tempo",
"Duration",
"Time Signature",
]
)
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylabel("Learned Weights")
fig.subplots_adjust(hspace=0)
fig.savefig("weights.pdf", bbox_inches="tight")
def plot_playlist_tempo(spotify):
"""Sketches Figure 1."""
names, features = get_track_features_of_album(
"Thriller", "Michael Jackson", spotify
)
scores = dict()
for k, v in zip(names, features):
scores[k] = v["tempo"]
min_score = min(scores.values())
max_score = max(scores.values())
for k, v in scores.items():
scores[k] = scale(v, min_score, max_score, 0, 1)
playlist = fit_scores(scores)
sns.set_style("ticks")
sns.set_context("paper", font_scale=0.8)
colors = sns.color_palette("colorblind", 5)
fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
x = np.linspace(0, 1, num=1000)
axarr[0].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in names],
np.arange(len(playlist)) + 1,
"o-",
color=colors[0],
label="Original Album",
)
axarr[1].plot(
[scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
[scale(i, 0, 1, 1, len(playlist)) for i in x],
color=colors[1],
label="Narrative Template Curve",
)
axarr[1].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
np.arange(len(playlist)) + 1,
"o-",
color=colors[2],
label="Fitted Playlist",
)
axarr[0].set_xticks([])
axarr[0].set_ylim(0, len(playlist) + 1)
axarr[0].set_yticks(np.arange(len(playlist)) + 1)
axarr[0].set_yticklabels(names)
axarr[0].yaxis.set_label_position("right")
axarr[0].yaxis.tick_right()
axarr[0].invert_yaxis()
axarr[1].set_xticks([90, 115, 140])
axarr[1].set_xlabel("Tempo (bpm)", labelpad=5)
axarr[1].set_ylim(0, len(playlist) + 1)
axarr[1].set_yticks(np.arange(len(playlist)) + 1)
axarr[1].set_yticklabels(playlist)
axarr[1].yaxis.set_label_position("right")
axarr[1].yaxis.tick_right()
axarr[1].invert_yaxis()
handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
fig.subplots_adjust(hspace=0)
fig.savefig("playlist_tempo.pdf", bbox_inches="tight")
def plot_playlist_valence(spotify):
"""Sketches Figure 5."""
names, features = get_track_features_of_album(
"Thriller", "Michael Jackson", spotify
)
scores = dict()
for k, v in zip(names, features):
scores[k] = v["valence"]
min_score = min(scores.values())
max_score = max(scores.values())
for k, v in scores.items():
scores[k] = scale(v, min_score, max_score, 0, 1)
playlist = fit_scores(scores)
sns.set_style("ticks")
sns.set_context("paper", font_scale=0.8)
colors = sns.color_palette("colorblind", 5)
fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
x = np.linspace(0, 1, num=1000)
axarr[0].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in names],
np.arange(len(playlist)) + 1,
"o-",
color=colors[0],
label="Original Album",
)
axarr[1].plot(
[scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
[scale(i, 0, 1, 1, len(playlist)) for i in x],
color=colors[1],
label="Narrative Template Curve",
)
axarr[1].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
np.arange(len(playlist)) + 1,
"o-",
color=colors[3],
label="Fitted Playlist",
)
axarr[0].set_xticks([])
axarr[0].set_ylim(0, len(playlist) + 1)
axarr[0].set_yticks(np.arange(len(playlist)) + 1)
axarr[0].set_yticklabels(names)
axarr[0].yaxis.set_label_position("right")
axarr[0].yaxis.tick_right()
axarr[0].invert_yaxis()
axarr[1].set_xticks([0.5, 0.7, 0.9])
axarr[1].set_xlabel("Valence", labelpad=5)
axarr[1].set_ylim(0, len(playlist) + 1)
axarr[1].set_yticks(np.arange(len(playlist)) + 1)
axarr[1].set_yticklabels(playlist)
axarr[1].yaxis.set_label_position("right")
axarr[1].yaxis.tick_right()
axarr[1].invert_yaxis()
handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
fig.subplots_adjust(hspace=0)
fig.savefig("playlist_valence.pdf", bbox_inches="tight")
def plot_playlist_zeta(spotify, zeta):
"""Sketches Figure 4."""
names, features = get_track_features_of_album(
"Thriller", "Michael Jackson", spotify
)
scores = dict()
for k, v in zip(names, features):
scores[k] = zeta[k]
min_score = min(scores.values())
max_score = max(scores.values())
for k, v in scores.items():
scores[k] = scale(v, min_score, max_score, 0, 1)
playlist = fit_scores(scores)
sns.set_style("ticks")
sns.set_context("paper", font_scale=0.8)
colors = sns.color_palette("colorblind", 5)
fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
x = np.linspace(0, 1, num=1000)
axarr[0].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in names],
np.arange(len(playlist)) + 1,
"o-",
color=colors[0],
label="Original Album",
)
axarr[1].plot(
[scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
[scale(i, 0, 1, 1, len(playlist)) for i in x],
color=colors[1],
label="Narrative Template Curve",
)
axarr[1].plot(
[scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
np.arange(len(playlist)) + 1,
"o-",
color=colors[4],
label="Fitted Playlist",
)
axarr[0].set_xticks([])
axarr[0].set_ylim(0, len(playlist) + 1)
axarr[0].set_yticks(np.arange(len(playlist)) + 1)
axarr[0].set_yticklabels(names)
axarr[0].yaxis.set_label_position("right")
axarr[0].yaxis.tick_right()
axarr[0].invert_yaxis()
axarr[1].set_xticks([0.15, 0.4, 0.65])
axarr[1].set_xlabel("\u03B6", labelpad=5)
axarr[1].set_ylim(0, len(playlist) + 1)
axarr[1].set_yticks(np.arange(len(playlist)) + 1)
axarr[1].set_yticklabels(playlist)
axarr[1].yaxis.set_label_position("right")
axarr[1].yaxis.tick_right()
axarr[1].invert_yaxis()
handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
fig.subplots_adjust(hspace=0)
fig.savefig("playlist_zeta.pdf", bbox_inches="tight")
def scale(
value: float, start_min: float, start_max: float, end_min: float, end_max: float
) -> float:
"""Returns the result of scaling value from the range
[start_min, start_max] to [end_min, end_max].
"""
return end_min + (end_max - end_min) * (value - start_min) / (start_max - start_min)
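# Worked example for scale: mapping 0.5 from [0, 1] onto [100, 200] gives
# 100 + (200 - 100) * 0.5 = 150:
#
#   assert scale(0.5, 0, 1, 100, 200) == 150.0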
def train_model(dataset_path, encoder_mode="train", num_epochs=50):
"""Trains the model for learning the predictive scalar
representation of songs.
"""
model = BidirectionalLSTMEncoder(
input_size=13, hidden_size=16, mlp_hidden_size=16, encoder_mode=encoder_mode
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_dataset = AlbumsDataset(
dataset_path, mode="train", pad_both_ends=False, feature_mask=None
)
eval_dataset = AlbumsDataset(
dataset_path, mode="valid", pad_both_ends=False, feature_mask=None
)
def evaluate():
model.eval()
dataloader = torch.utils.data.DataLoader(
dataset=eval_dataset, batch_size=32, collate_fn=album_collate
)
all_losses = []
        for batch in dataloader:
            predictions, encodings, sequence_lengths = model(batch[0], batch[1])
encodings_list = [encodings[i, :l] for i, l in enumerate(sequence_lengths)]
predictions_list = [
predictions[i, :l] for i, l in enumerate(sequence_lengths)
]
encodings_flat = torch.cat(encodings_list, dim=0).view(-1)
predictions_flat = torch.cat(predictions_list, dim=0).view(-1)
loss = contrastive_loss_function(predictions_flat, encodings_flat)
all_losses.append(loss)
all_losses = torch.cat(all_losses, dim=0)
model.train()
return all_losses.mean().item()
validation_loss = evaluate()
print("validation loss:", validation_loss)
for epoch in range(num_epochs):
print("epoch:", epoch)
dataloader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=32, shuffle=True, collate_fn=album_collate
)
        for batch in dataloader:
predictions, encodings, sequence_lengths = model(batch[0], batch[1])
encodings_list = [encodings[i, :l] for i, l in enumerate(sequence_lengths)]
predictions_list = [
predictions[i, :l] for i, l in enumerate(sequence_lengths)
]
encodings_flat = torch.cat(encodings_list, dim=0).view(-1)
predictions_flat = torch.cat(predictions_list, dim=0).view(-1)
loss = contrastive_loss_function(predictions_flat, encodings_flat).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
validation_loss = evaluate()
print("validation loss:", validation_loss)
return model
if __name__ == "__main__":
main(parse_args())