#!/usr/bin/env python3
# -*- coding: ascii -*-
# MIT License
#
# Copyright (c) 2021 Dylan Robert Ashley and Vincent Herrmann
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from spotipy.oauth2 import SpotifyClientCredentials
from typing import Dict, List

import argparse
import collections
import google_drive_downloader
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import scipy.sparse
import scipy.sparse.csgraph  # needed for the bipartite matching routines below
import seaborn as sns
import spotipy
import sys
import tempfile
import torch
import torch.nn.functional as F
import zipfile


class AlbumsDataset(torch.utils.data.Dataset):
    """Wrapper for the album dataset for learning the predictive scalar
    representation of songs.
    """

    def __init__(self, path, pad_both_ends=False, mode="train", feature_mask=None):
        super().__init__()
        with open(path, "rb") as handle:
            self.data = pickle.load(handle)
        # the last 2000 albums are held out: 1000 for validation, 1000 for test
        self.mode = mode
        if self.mode == "valid":
            self.data = self.data[-2000:-1000]
        elif self.mode == "test":
            self.data = self.data[-1000:]
        else:
            self.data = self.data[:-2000]
        self.feature_mask = feature_mask
        self.pad_both_ends = pad_both_ends

    def __getitem__(self, item):
        album_dict = self.data[item]
        feature_vectors = [
            torch.tensor(t["feature_vector"]) for t in album_dict["track dicts"]
        ]
        if self.pad_both_ends:
            feature_vectors = (
                [torch.zeros_like(feature_vectors[0])]
                + feature_vectors
                + [torch.zeros_like(feature_vectors[0])]
            )
        feature_vectors = torch.stack(feature_vectors, dim=0)
        if self.feature_mask is not None:
            feature_vectors = feature_vectors * self.feature_mask.unsqueeze(0)
        return feature_vectors

    def __len__(self):
        return len(self.data)
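
# A minimal usage sketch (hypothetical; assumes the pickle downloaded in
# main() already sits in the working directory, and uses the album_collate
# helper defined further below):
#
#     dataset = AlbumsDataset("albums_with_track_features_normalized.pickle",
#                             mode="valid")
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, collate_fn=album_collate
#     )
#     features, lengths = next(iter(loader))  # batch x max_tracks x 13, lengths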


class BidirectionalLSTMEncoder(torch.nn.Module):
    """Model for learning the predictive scalar representation of
    songs.
    """

    def __init__(self, input_size, hidden_size, mlp_hidden_size, encoder_mode="train"):
        super().__init__()
        # encoder_mode selects how the scalar encoding is produced: "train"
        # learns a weighting over all input features, "mean" averages them,
        # and an integer selects a single feature by index.
        if encoder_mode == "train":
            self.input_selection_func = lambda x: x
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(input_size, 1), torch.nn.Sigmoid()
            )
        else:
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(1, 1), torch.nn.Sigmoid()
            )
            if encoder_mode == "mean":
                self.input_selection_func = lambda x: x.mean(dim=-1, keepdim=True)
            elif type(encoder_mode) is int:
                self.input_selection_func = lambda x: x[..., encoder_mode].unsqueeze(-1)
            else:
                raise ValueError("unknown encoder_mode: {}".format(encoder_mode))
        self.forward_lstm = torch.nn.LSTM(
            input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True
        )
        self.backward_lstm = torch.nn.LSTM(
            input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, mlp_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_hidden_size, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x, lengths):
        x = self.input_selection_func(x)
        encodings = self.encoder(x)  # batch x max_length x 1
        # add item to start and end
        encodings = F.pad(encodings, pad=(0, 0, 1, 1))
        encodings_forward = encodings[:, :-2]
        h_forward, _ = self.forward_lstm(encodings_forward)
        encodings_backward = encodings.flip(dims=[1])
        encodings_backward_list = [
            encodings_backward[i, -l - 2 :] for i, l in enumerate(lengths)
        ]
        encodings_backward = torch.nn.utils.rnn.pad_sequence(
            encodings_backward_list, batch_first=True
        )
        encodings_backward = encodings_backward[:, :-2]
        h_backward, _ = self.backward_lstm(encodings_backward)
        h_backward = h_backward.flip(dims=[1])
        h_backward_list = [h_backward[i, -l:] for i, l in enumerate(lengths)]
        h_backward = torch.nn.utils.rnn.pad_sequence(h_backward_list, batch_first=True)
        bidirectional_features = h_forward + h_backward
        predictions = self.mlp(bidirectional_features)
        return predictions, encodings[:, 1:-1], lengths
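
# A quick shape check with dummy data (hypothetical): two "albums" of 5 and 3
# tracks, padded to a common length via album_collate from below, each track a
# 13-dimensional feature vector as in the dataset above.
#
#     model = BidirectionalLSTMEncoder(13, 16, 16, encoder_mode="train")
#     x, lengths = album_collate([torch.randn(5, 13), torch.randn(3, 13)])
#     predictions, encodings, _ = model(x, lengths)
#     # predictions: 2 x 5 x 1, encodings: 2 x 5 x 1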


def album_collate(data_list):
    """Helper function for learning the predictive scalar
    representation of songs.
    """
    lengths = torch.LongTensor([len(l) for l in data_list])
    padded_seq = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True)
    return padded_seq, lengths
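
# For example, collating a 5-track and a 3-track album yields a 2 x 5 x 13
# tensor (the shorter album zero-padded at the end) and the lengths [5, 3]:
#
#     padded, lengths = album_collate([torch.randn(5, 13), torch.randn(3, 13)])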


def contrastive_loss_function(predictions, targets):
    """Loss function for learning the predictive scalar representation
    of songs.
    """
    mse_target = (predictions - targets) ** 2
    mse_noise = (predictions.unsqueeze(0) - targets.detach().unsqueeze(1)) ** 2
    return mse_target - mse_noise.mean()
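
# Intuition: the first term pulls each prediction toward its own target, while
# the subtracted mean over all prediction/target pairs keeps the encodings
# from collapsing to a single constant. The result is per-element; callers
# reduce it with .mean(). A toy check (hypothetical values):
#
#     p = torch.tensor([0.2, 0.8])
#     t = torch.tensor([0.2, 0.8])
#     contrastive_loss_function(p, t).mean()  # -0.18: matched, yet spread out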


def fit_scores(scores: Dict[str, float]) -> List[str]:
    """Fits a set of scores to an ideal narrative curve."""
    scores = list(scores.items())
    np.random.shuffle(scores)  # just in case
    scores = collections.OrderedDict(scores)
    distances = np.zeros((len(scores), len(scores)), dtype=float)
    for i, y in enumerate(scores.values()):
        for j in range(len(scores)):
            x = scale(j, 0, len(scores), 0, 1)
            distances[i, j] = abs(y - narrative_curve(x))
    # binary search to find smallest deviation matching
    candidates = np.sort(distances.flatten())
    min_idx = 0
    max_idx = len(candidates) - 1
    while min_idx != max_idx:
        pivot = min_idx + (max_idx - min_idx) // 2
        graph = scipy.sparse.csr_matrix(distances <= candidates[pivot])
        matching = scipy.sparse.csgraph.maximum_bipartite_matching(graph)
        if -1 in matching:
            min_idx = pivot + 1
        else:
            max_idx = pivot
    # clean up smallest deviation matching; the +1 keeps zero-distance edges
    # from being dropped by the sparse representation
    filenames = list(scores.keys())
    graph = scipy.sparse.csgraph.csgraph_from_masked(
        np.ma.masked_greater(distances, candidates[max_idx]) + 1
    )
    matching = scipy.sparse.csgraph.min_weight_full_bipartite_matching(graph)[1]
    return [filenames[int(np.where(matching == i)[0])] for i in range(len(matching))]
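
# A small usage sketch (hypothetical scores in [0, 1]): the returned list
# contains the same keys, reordered so their scores best follow
# narrative_curve from left to right.
#
#     ordered = fit_scores({"a": 0.5, "b": 0.9, "c": 0.1})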


def get_track_features_of_album(album_title, album_artist, spotify):
    """Retrieves album information from spotify."""
    query = album_title.replace(" ", "+") + "+" + album_artist.replace(" ", "+")
    search_result = spotify.search(q=query, type="album")
    # get the album ID
    album_id = None
    for album_description in search_result["albums"]["items"]:
        if album_title.lower() == album_description["name"].lower():
            album_id = album_description["id"]
            break
    if album_id is None:
        print("exact album title", album_title, "not found")
        album_description = search_result["albums"]["items"][0]
        album_id = album_description["id"]
        print("using", album_description["name"], "instead")
    # get audio features for each track of the album
    track_ids = []
    track_names = []
    for track in spotify.album_tracks(album_id)["items"]:
        track_names.append(track["name"])
        track_ids.append(track["id"])
    track_features = spotify.audio_features(track_ids)
    return track_names, track_features
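
# Hypothetical usage, given an authenticated spotipy client (see main()):
#
#     names, features = get_track_features_of_album(
#         "Thriller", "Michael Jackson", spotify
#     )
#     # names[i] is the i-th track title; features[i] is its audio-feature
#     # dict, e.g. features[i]["tempo"] or features[i]["valence"]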


def main(args):
    # setup spotify
    client_credentials_manager = SpotifyClientCredentials(
        client_id=args["cid"], client_secret=args["secret"]
    )
    spotify = spotipy.Spotify(client_credentials_manager=client_credentials_manager)
    # generate the three fitting plots
    plot_playlist_tempo(spotify)
    plot_playlist_valence(spotify)
    zeta = {
        "Wanna Be Startin' Somethin'": 0.08055852353572845,
        "Baby Be Mine": 0.14294381439685822,
        "The Girl Is Mine (with Paul McCartney)": 0.31416481733322144,
        "Thriller": 0.10454542934894562,
        "Beat It": 0.07758291065692902,
        "Billie Jean": 0.18676850199699402,
        "Human Nature": 0.7985778450965881,
        "P.Y.T. (Pretty Young Thing)": 0.14572905004024506,
        "The Lady in My Life": 0.43962377309799194,
    }  # previously obtained zeta values
    plot_playlist_zeta(spotify, zeta)
    # load dataset
    dataset_path = os.path.join(
        os.getcwd(), "albums_with_track_features_normalized.pickle"
    )
    if not os.path.exists(dataset_path):
        google_drive_downloader.GoogleDriveDownloader.download_file_from_google_drive(
            file_id="1EHMNv1DlKL4YIeVokkkhlZvM-WYzEiFZ",
            dest_path=dataset_path,
            unzip=True,
        )
    # train model and plot weights
    model = train_model(dataset_path, encoder_mode="train", num_epochs=50)
    plot_feature_weights(model)
    # get loss of single features
    song_features = [
        "Danceability",
        "Energy",
        "Key",
        "Loudness",
        "Mode",
        "Speechiness",
        "Acousticness",
        "Instrumentalness",
        "Liveness",
        "Valence",
        "Tempo",
        "Duration",
        "Time Signature",
    ]
    for only_feature in [1, 9, 10, 11]:
        print("train model with only feature", song_features[only_feature])
        model = train_model(dataset_path, encoder_mode=only_feature, num_epochs=50)


def narrative_curve(x: float, order: int = 2) -> float:
    """Returns the ideal narrative score for a time value in the range
    [0, 1].
    """
    assert 0 <= x <= 1
    assert order in [1, 2]
    if order == 1:
        if x <= 0.2:
            return 1 / 2 + 5 / 4 * x
        elif x <= 0.5:
            return 5 / 4 - 5 / 2 * x
        elif x <= 0.8:
            return -5 / 3 + 10 / 3 * x
        else:
            return 2 - 5 / 4 * x
    else:
        if x <= 0.2:
            return 1 / 2 + 5 / 2 * x - 25 / 4 * x ** 2
        elif x <= 0.3:
            return -1 / 4 + 10 * x - 25 * x ** 2
        elif x <= 0.5:
            return 25 / 8 - 25 / 2 * x + 25 / 2 * x ** 2
        elif x <= 0.65:
            return 50 / 9 - 200 / 9 * x + 200 / 9 * x ** 2
        elif x <= 0.8:
            return -119 / 9 + 320 / 9 * x - 200 / 9 * x ** 2
        else:
            return -3 + 10 * x - 25 / 4 * x ** 2
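
# Spot checks for the default quadratic curve (it is continuous at every
# breakpoint):
#
#     narrative_curve(0.0)  # 0.5
#     narrative_curve(0.2)  # 0.75, the first local peak
#     narrative_curve(0.5)  # 0.0, the trough
#     narrative_curve(1.0)  # 0.75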


def parse_args():
    """Reads command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "cid",
        help="client id for spotipy client credentials manager",
    )
    parser.add_argument(
        "secret",
        help="client secret for spotipy client credentials manager",
    )
    return vars(parser.parse_args())


def plot_feature_weights(model):
    """Sketches Figure 3."""
    weighting = model.encoder[0].weight.squeeze().detach()
    sns.set_style("ticks")
    sns.set_context("paper", font_scale=0.8)
    colors = sns.color_palette("colorblind", 4)
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=300)
    ax.bar(torch.arange(13), weighting.detach().numpy())
    ax.set_yticks([-3, 0, 3])
    ax.set_xticks(torch.arange(13))
    ax.set_xticklabels(
        [
            "Danceability",
            "Energy",
            "Key",
            "Loudness",
            "Mode",
            "Speechiness",
            "Acousticness",
            "Instrumentalness",
            "Liveness",
            "Valence",
            "Tempo",
            "Duration",
            "Time Signature",
        ]
    )
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel("Learned Weights")
    fig.subplots_adjust(hspace=0)
    fig.savefig("weights.pdf", bbox_inches="tight")


def plot_playlist_tempo(spotify):
    """Sketches Figure 1."""
    names, features = get_track_features_of_album(
        "Thriller", "Michael Jackson", spotify
    )
    scores = dict()
    for k, v in zip(names, features):
        scores[k] = v["tempo"]
    min_score = min(scores.values())
    max_score = max(scores.values())
    for k, v in scores.items():
        scores[k] = scale(v, min_score, max_score, 0, 1)
    playlist = fit_scores(scores)
    sns.set_style("ticks")
    sns.set_context("paper", font_scale=0.8)
    colors = sns.color_palette("colorblind", 5)
    fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
    x = np.linspace(0, 1, num=1000)
    axarr[0].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in names],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[0],
        label="Original Album",
    )
    axarr[1].plot(
        [scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
        [scale(i, 0, 1, 1, len(playlist)) for i in x],
        color=colors[1],
        label="Narrative Template Curve",
    )
    axarr[1].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[2],
        label="Fitted Playlist",
    )
    axarr[0].set_xticks([])
    axarr[0].set_ylim(0, len(playlist) + 1)
    axarr[0].set_yticks(np.arange(len(playlist)) + 1)
    axarr[0].set_yticklabels(names)
    axarr[0].yaxis.set_label_position("right")
    axarr[0].yaxis.tick_right()
    axarr[0].invert_yaxis()
    axarr[1].set_xticks([90, 115, 140])
    axarr[1].set_xlabel("Tempo (bpm)", labelpad=5)
    axarr[1].set_ylim(0, len(playlist) + 1)
    axarr[1].set_yticks(np.arange(len(playlist)) + 1)
    axarr[1].set_yticklabels(playlist)
    axarr[1].yaxis.set_label_position("right")
    axarr[1].yaxis.tick_right()
    axarr[1].invert_yaxis()
    handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
    labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
    fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
    fig.subplots_adjust(hspace=0)
    fig.savefig("playlist_tempo.pdf", bbox_inches="tight")


def plot_playlist_valence(spotify):
    """Sketches Figure 5."""
    names, features = get_track_features_of_album(
        "Thriller", "Michael Jackson", spotify
    )
    scores = dict()
    for k, v in zip(names, features):
        scores[k] = v["valence"]
    min_score = min(scores.values())
    max_score = max(scores.values())
    for k, v in scores.items():
        scores[k] = scale(v, min_score, max_score, 0, 1)
    playlist = fit_scores(scores)
    sns.set_style("ticks")
    sns.set_context("paper", font_scale=0.8)
    colors = sns.color_palette("colorblind", 5)
    fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
    x = np.linspace(0, 1, num=1000)
    axarr[0].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in names],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[0],
        label="Original Album",
    )
    axarr[1].plot(
        [scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
        [scale(i, 0, 1, 1, len(playlist)) for i in x],
        color=colors[1],
        label="Narrative Template Curve",
    )
    axarr[1].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[3],
        label="Fitted Playlist",
    )
    axarr[0].set_xticks([])
    axarr[0].set_ylim(0, len(playlist) + 1)
    axarr[0].set_yticks(np.arange(len(playlist)) + 1)
    axarr[0].set_yticklabels(names)
    axarr[0].yaxis.set_label_position("right")
    axarr[0].yaxis.tick_right()
    axarr[0].invert_yaxis()
    axarr[1].set_xticks([0.5, 0.7, 0.9])
    axarr[1].set_xlabel("Valence", labelpad=5)
    axarr[1].set_ylim(0, len(playlist) + 1)
    axarr[1].set_yticks(np.arange(len(playlist)) + 1)
    axarr[1].set_yticklabels(playlist)
    axarr[1].yaxis.set_label_position("right")
    axarr[1].yaxis.tick_right()
    axarr[1].invert_yaxis()
    handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
    labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
    fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
    fig.subplots_adjust(hspace=0)
    fig.savefig("playlist_valence.pdf", bbox_inches="tight")


def plot_playlist_zeta(spotify, zeta):
    """Sketches Figure 4."""
    names, features = get_track_features_of_album(
        "Thriller", "Michael Jackson", spotify
    )
    scores = dict()
    for k, v in zip(names, features):
        scores[k] = zeta[k]
    min_score = min(scores.values())
    max_score = max(scores.values())
    for k, v in scores.items():
        scores[k] = scale(v, min_score, max_score, 0, 1)
    playlist = fit_scores(scores)
    sns.set_style("ticks")
    sns.set_context("paper", font_scale=0.8)
    colors = sns.color_palette("colorblind", 5)
    fig, axarr = plt.subplots(2, 1, figsize=(1.5, 3), dpi=300)
    x = np.linspace(0, 1, num=1000)
    axarr[0].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in names],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[0],
        label="Original Album",
    )
    axarr[1].plot(
        [scale(narrative_curve(i), 0, 1, min_score, max_score) for i in x],
        [scale(i, 0, 1, 1, len(playlist)) for i in x],
        color=colors[1],
        label="Narrative Template Curve",
    )
    axarr[1].plot(
        [scale(scores[song], 0, 1, min_score, max_score) for song in playlist],
        np.arange(len(playlist)) + 1,
        "o-",
        color=colors[4],
        label="Fitted Playlist",
    )
    axarr[0].set_xticks([])
    axarr[0].set_ylim(0, len(playlist) + 1)
    axarr[0].set_yticks(np.arange(len(playlist)) + 1)
    axarr[0].set_yticklabels(names)
    axarr[0].yaxis.set_label_position("right")
    axarr[0].yaxis.tick_right()
    axarr[0].invert_yaxis()
    axarr[1].set_xticks([0.15, 0.4, 0.65])
    axarr[1].set_xlabel("\u03B6", labelpad=5)
    axarr[1].set_ylim(0, len(playlist) + 1)
    axarr[1].set_yticks(np.arange(len(playlist)) + 1)
    axarr[1].set_yticklabels(playlist)
    axarr[1].yaxis.set_label_position("right")
    axarr[1].yaxis.tick_right()
    axarr[1].invert_yaxis()
    handles = sum([ax.get_legend_handles_labels()[0] for ax in axarr], [])
    labels = sum([ax.get_legend_handles_labels()[1] for ax in axarr], [])
    fig.legend(handles, labels, loc=(0.05, 0.86), frameon=False)
    fig.subplots_adjust(hspace=0)
    fig.savefig("playlist_zeta.pdf", bbox_inches="tight")


def scale(
    value: float, start_min: float, start_max: float, end_min: float, end_max: float
) -> float:
    """Returns the result of scaling value from the range
    [start_min, start_max] to [end_min, end_max].
    """
    return end_min + (end_max - end_min) * (value - start_min) / (start_max - start_min)
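
# For example, scale(120, 60, 180, 0, 1) == 0.5 maps a tempo of 120 bpm from
# the range [60, 180] onto [0, 1].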


def train_model(dataset_path, encoder_mode="train", num_epochs=50):
    """Trains the model for learning the predictive scalar
    representation of songs.
    """
    model = BidirectionalLSTMEncoder(
        input_size=13, hidden_size=16, mlp_hidden_size=16, encoder_mode=encoder_mode
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    train_dataset = AlbumsDataset(
        dataset_path, mode="train", pad_both_ends=False, feature_mask=None
    )
    eval_dataset = AlbumsDataset(
        dataset_path, mode="valid", pad_both_ends=False, feature_mask=None
    )

    def evaluate():
        model.eval()
        dataloader = torch.utils.data.DataLoader(
            dataset=eval_dataset, batch_size=32, collate_fn=album_collate
        )
        all_losses = []
        for batch in iter(dataloader):
            predictions, encodings, sequence_lengths = model(batch[0], batch[1])
            encodings_list = [encodings[i, :l] for i, l in enumerate(sequence_lengths)]
            predictions_list = [
                predictions[i, :l] for i, l in enumerate(sequence_lengths)
            ]
            encodings_flat = torch.cat(encodings_list, dim=0).view(-1)
            predictions_flat = torch.cat(predictions_list, dim=0).view(-1)
            loss = contrastive_loss_function(predictions_flat, encodings_flat)
            all_losses.append(loss)
        all_losses = torch.cat(all_losses, dim=0)
        model.train()
        return all_losses.mean().item()

    validation_loss = evaluate()
    print("validation loss:", validation_loss)
    for epoch in range(num_epochs):
        print("epoch:", epoch)
        dataloader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=32, shuffle=True, collate_fn=album_collate
        )
        for i, batch in enumerate(iter(dataloader)):
            predictions, encodings, sequence_lengths = model(batch[0], batch[1])
            encodings_list = [encodings[i, :l] for i, l in enumerate(sequence_lengths)]
            predictions_list = [
                predictions[i, :l] for i, l in enumerate(sequence_lengths)
            ]
            encodings_flat = torch.cat(encodings_list, dim=0).view(-1)
            predictions_flat = torch.cat(predictions_list, dim=0).view(-1)
            loss = contrastive_loss_function(predictions_flat, encodings_flat).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        validation_loss = evaluate()
        print("validation loss:", validation_loss)
    return model


if __name__ == "__main__":
    main(parse_args())