Skip to content

Instantly share code, notes, and snippets.

@itsuncheng
Created January 11, 2022 19:10
Show Gist options
  • Save itsuncheng/48e0fba62b38c14ca81d2dd73f16083e to your computer and use it in GitHub Desktop.
Save itsuncheng/48e0fba62b38c14ca81d2dd73f16083e to your computer and use it in GitHub Desktop.
import streamlit as st
st.set_page_config(page_title="Song Recommendation", layout="wide")
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import plotly.express as px
import streamlit.components.v1 as components
@st.cache(allow_output_mutation=True)
def load_data():
    """Load the track table and return it with one row per (track, genre).

    Reads ``data/filtered_track_df.csv``, converts the text in its ``genres``
    column into a list of names, and explodes that list so each genre of a
    track ends up on its own row.

    Returns:
        pandas.DataFrame: the exploded track table.
    """

    def parse_genres(raw):
        # Strip the first and last character of the cell text, split the
        # remainder on ", ", then strip the first and last character of each
        # piece. Presumably cells look like "['pop', 'rock']" -- TODO confirm
        # against the CSV.
        pieces = str(raw)[1:-1].split(", ")
        return [piece[1:-1] for piece in pieces]

    track_table = pd.read_csv("data/filtered_track_df.csv")
    track_table['genres'] = track_table.genres.apply(parse_genres)
    return track_table.explode("genres")
# Genres offered in the radio picker; the chosen name is lower-cased before
# it is compared with the ``genres`` column of the track table.
genre_names = [
    'Dance Pop',
    'Electronic',
    'Electropop',
    'Hip Hop',
    'Jazz',
    'K-pop',
    'Latin',
    'Pop',
    'Pop Rap',
    'R&B',
    'Rock',
]

# Feature columns used for the nearest-neighbour search. The query vector
# built from the sliders must list its values in this same order.
audio_feats = [
    "acousticness",
    "danceability",
    "energy",
    "instrumentalness",
    "valence",
    "tempo",
]

# Track table with one row per (track, genre), shared by the whole script.
exploded_track_df = load_data()
def n_neighbors_uri_audio(genre, start_year, end_year, test_feat):
    """Rank a genre's most popular tracks by closeness to the requested features.

    Rows are taken from ``exploded_track_df`` where ``genres`` equals the
    lower-cased ``genre`` and ``release_year`` lies in
    ``[start_year, end_year]`` (both ends inclusive). Only the 500 rows with
    the highest ``popularity`` are kept, and those are ordered by
    nearest-neighbour distance between their ``audio_feats`` values and
    ``test_feat``.

    Args:
        genre: Genre name; lower-cased before being compared with the
            ``genres`` column.
        start_year: First release year to include.
        end_year: Last release year to include.
        test_feat: Query values, one per entry of ``audio_feats`` and in the
            same order.

    Returns:
        tuple: ``(uris, audios)``. ``uris`` is a list of the ``uri`` values
        and ``audios`` is a 2-D array of the matching ``audio_feats`` rows,
        both ordered nearest first. Both are empty when no track matches the
        filters.
    """
    wanted_genre = genre.lower()
    genre_data = exploded_track_df[
        (exploded_track_df["genres"] == wanted_genre)
        & (exploded_track_df["release_year"] >= start_year)
        & (exploded_track_df["release_year"] <= end_year)
    ]
    genre_data = genre_data.sort_values(by='popularity', ascending=False)[:500]

    if genre_data.empty:
        # Fitting NearestNeighbors on zero rows raises ValueError. Return an
        # empty result instead so the caller can report that nothing matched.
        return genre_data["uri"].tolist(), genre_data[audio_feats].to_numpy()

    # NOTE(review): features are used unscaled. The tempo slider spans 0-244
    # while the other sliders span 0-1, so tempo probably dominates the
    # distance -- confirm whether that is intended.
    neigh = NearestNeighbors()
    neigh.fit(genre_data[audio_feats].to_numpy())

    # Ask for every candidate so the result is a full ranking, not a top-k.
    order = neigh.kneighbors(
        [test_feat],
        n_neighbors=len(genre_data),
        return_distance=False,
    )[0]

    ranked = genre_data.iloc[order]
    uris = ranked["uri"].tolist()
    audios = ranked[audio_feats].to_numpy()
    return uris, audios
# Page heading followed by a short welcome blurb.
title = "Song Recommendation Engine"
st.title(title)
st.write(
    "First of all, welcome! This is the place where you can customize what "
    "you want to listen to based on genre and several key audio features. "
    "Try playing around with different settings and listen to the songs "
    "recommended by our system!"
)
# A heading marker with no text -- presumably used as a vertical spacer.
st.markdown("##")
# Input area: the genre picker and the feature sliders share one row.
with st.container():
    # Four columns are created but only the first (sliders) and the third
    # (genre picker) receive widgets.
    col1, col2, col3, col4 = st.columns((2, 0.5, 0.5, 0.5))

    with col3:
        st.markdown("***Choose your genre:***")
        genre = st.radio(
            label="",
            options=genre_names,
            index=genre_names.index("Pop"),
        )

    with col1:
        st.markdown("***Choose features to customize:***")
        # A tuple default makes this a range slider returning (low, high).
        start_year, end_year = st.slider(
            label='Select the year range',
            min_value=1990,
            max_value=2019,
            value=(2015, 2019),
        )
        acousticness = st.slider(
            label='Acousticness', min_value=0.0, max_value=1.0, value=0.5
        )
        danceability = st.slider(
            label='Danceability', min_value=0.0, max_value=1.0, value=0.5
        )
        energy = st.slider(
            label='Energy', min_value=0.0, max_value=1.0, value=0.5
        )
        instrumentalness = st.slider(
            label='Instrumentalness', min_value=0.0, max_value=1.0, value=0.0
        )
        valence = st.slider(
            label='Valence', min_value=0.0, max_value=1.0, value=0.45
        )
        tempo = st.slider(
            label='Tempo', min_value=0.0, max_value=244.0, value=118.0
        )
# Number of embedded players shown per page of results.
tracks_per_page = 6

# Query vector; the order matches ``audio_feats``.
test_feat = [acousticness, danceability, energy, instrumentalness, valence, tempo]
uris, audios = n_neighbors_uri_audio(genre, start_year, end_year, test_feat)

# One Spotify embed iframe per recommended track, in ranking order.
embed_template = """<iframe src="https://open.spotify.com/embed/track/{}" width="260" height="380" frameborder="0" allowtransparency="true" allow="encrypted-media"></iframe>"""
tracks = [embed_template.format(uri) for uri in uris]
# Paging state: remember the inputs of the previous run and go back to the
# first page whenever any of them changed.
session = st.session_state
current_inputs = [genre, start_year, end_year] + test_feat

if 'previous_inputs' not in session:
    session['previous_inputs'] = current_inputs

inputs_changed = session['previous_inputs'] != current_inputs
if inputs_changed:
    session['previous_inputs'] = current_inputs

# The offset starts at 0 on the first run and is reset to 0 on every change.
if inputs_changed or 'start_track_i' not in session:
    session['start_track_i'] = 0
# Results area: one page of embedded players, alternating between the left
# and right columns, each with an expandable radar chart of its features.
with st.container():
    col1, col2, col3 = st.columns([2, 1, 2])

    # Advance to the next page unless the offset is already past the end.
    if st.button("Recommend More Songs"):
        if st.session_state['start_track_i'] < len(tracks):
            st.session_state['start_track_i'] += tracks_per_page

    page_start = st.session_state['start_track_i']
    page_end = page_start + tracks_per_page
    current_tracks = tracks[page_start:page_end]
    current_audios = audios[page_start:page_end]

    if page_start < len(tracks):
        for i, (track, audio) in enumerate(zip(current_tracks, current_audios)):
            # Even positions go to the left column, odd ones to the right;
            # the middle column is left empty.
            column = col1 if i % 2 == 0 else col3
            with column:
                components.html(
                    track,
                    height=400,
                )
                with st.expander("See more details"):
                    # Only the first five features are plotted, which leaves
                    # out tempo (the last entry of ``audio_feats``).
                    df = pd.DataFrame(dict(
                        r=audio[:5],
                        theta=audio_feats[:5]))
                    fig = px.line_polar(df, r='r', theta='theta', line_close=True)
                    fig.update_layout(height=400, width=340)
                    st.plotly_chart(fig)
    else:
        st.write("No songs left to recommend")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment