Created
April 3, 2023 15:21
-
-
Save voatsap/6420798fbde1d3e773a6e95ec39778a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import pandas as pd | |
from st_aggrid import AgGrid, GridUpdateMode | |
from st_aggrid.grid_options_builder import GridOptionsBuilder | |
import pickle | |
from st_on_hover_tabs import on_hover_tabs | |
import collections | |
# Streamlit requires set_page_config to be the first st.* command of the
# script; request the wide page layout.
st.set_page_config(layout="wide")
# Data-source prefix; load_wine_dataframe() builds the pickle file name from it.
merchant = "winemag"
# Load Dataframe | |
def load_wine_dataframe(merchant): | |
path_to_file = merchant+'_wine_dataframe_varieties_regions.txt' | |
with open(path_to_file,'rb') as file_object: | |
raw_data = file_object.read() | |
wine_dataframe = pickle.loads(raw_data) | |
return wine_dataframe | |
def _count_titles(frame, keys):
    """Count titles per group of *keys*, with the most frequent group first.

    Returns a flat DataFrame with the *keys* columns plus a ``title`` column
    that holds the count.
    """
    return (
        frame.groupby(keys)['title']
        .count()
        .sort_values(ascending=False)
        .reset_index()
    )


def _selectable_grid(frame):
    """Render *frame* as a multi-select checkbox AgGrid and return the ticked rows.

    Returns an empty DataFrame when nothing is selected.
    """
    builder = GridOptionsBuilder.from_dataframe(frame)
    builder.configure_selection(selection_mode='multiple', use_checkbox=True)
    response = AgGrid(frame, height=400, width=400, gridOptions=builder.build(),
                      fit_columns_on_grid_load=True,
                      update_mode=GridUpdateMode.SELECTION_CHANGED)
    selected = response["selected_rows"]
    # NOTE(review): older st_aggrid releases return a list of dicts here; newer
    # ones may return a DataFrame or None. Plain truthiness raises on a
    # DataFrame, so test for None/empty explicitly and normalise via pd.DataFrame.
    if selected is None or len(selected) == 0:
        return pd.DataFrame()
    return pd.DataFrame(selected)


def _common_descriptors(descriptor_lists, level3, top_n=20):
    """Summarise the most frequent allowed descriptors as a printable string.

    Args:
        descriptor_lists: Iterable of per-wine descriptor sequences. These are
            assumed to be lists of tag strings; TODO confirm against the pickle.
        level3: Set of allowed tags. Any other tag is ignored.
        top_n: How many of the most frequent tags to keep.

    Returns:
        ``str`` of a list of ``(tag, percent)`` tuples. *percent* is the tag's
        share of all kept tag occurrences, rounded to 2 decimals. Returns
        ``'[]'`` when no tag matches.
    """
    # A single flat comprehension replaces the quadratic list + list accumulation.
    tags = [tag for descriptors in descriptor_lists for tag in descriptors if tag in level3]
    total = len(tags)
    counter = collections.Counter(tags)
    return str([(tag, round(count / total * 100.0, 2)) for tag, count in counter.most_common(top_n)])


def wine_selector():
    """Render the country -> province -> variety drill-down and its tasting notes.

    Each step shows a checkbox grid of title counts. The next step appears only
    once at least one row of the previous grid is selected. After varieties are
    chosen, a "Tasting notes" table lists, for each selected (province, variety)
    pair, the 20 most common ``level_3`` descriptors from
    ``descriptor_mapping.csv`` with their percentage share.
    """
    wine_dataframe = load_wine_dataframe(merchant)
    wd = wine_dataframe[['country', 'region_1', 'province', 'variety', 'title', 'normalized_descriptors']]
    detail_columns = ['country', 'province', 'title', 'variety', 'normalized_descriptors']

    # Step 1: countries.
    selected_countries = _selectable_grid(_count_titles(wd, ['country']))
    if selected_countries.empty:
        return
    st.dataframe(selected_countries['country'])
    filtered_by_country = wd.loc[wd['country'].isin(selected_countries['country']), detail_columns]

    # Step 2: provinces within the chosen countries.
    st.write("### Now choose some provinces from selected countries")
    selected_provinces = _selectable_grid(_count_titles(filtered_by_country, ['country', 'province']))
    if selected_provinces.empty:
        return
    st.dataframe(selected_provinces['province'])
    filtered_by_province = filtered_by_country.loc[
        filtered_by_country['province'].isin(selected_provinces['province'])
    ]

    # Step 3: varieties within the chosen provinces.
    st.write("### Now select the varieties you prefer most:")
    selected_varieties = _selectable_grid(_count_titles(filtered_by_province, ['province', 'variety']))
    if selected_varieties.empty:
        # With no variety ticked there is no 'variety' column to filter on,
        # so stop here instead of raising KeyError.
        return

    # Build the set once so each tag membership test is O(1).
    level3 = set(pd.read_csv('descriptor_mapping.csv')['level_3'])
    # TODO: region_1 is loaded above but not yet used as a filter.
    rows = []
    for _, choice in selected_varieties.iterrows():
        # Match the row's own province as well as its variety, so the notes
        # really belong to the province shown next to them.
        matches = filtered_by_province.loc[
            (filtered_by_province['province'] == choice['province'])
            & (filtered_by_province['variety'] == choice['variety']),
            'normalized_descriptors',
        ]
        rows.append({
            'province': choice['province'],
            'variety': choice['variety'],
            'common_desc': _common_descriptors(matches, level3),
        })
    # Build the table in one go rather than calling pd.concat inside the loop.
    desc_common = pd.DataFrame(rows, columns=['province', 'variety', 'common_desc'])

    st.write("## Tasting notes")
    st.dataframe(desc_common, use_container_width=True)
### UI
# Inject the custom stylesheet. It presumably styles the on-hover sidebar
# tabs; TODO confirm. The context manager closes the file handle instead of
# leaking it, and the explicit encoding avoids depending on the platform locale.
with open('./style.css', encoding='utf-8') as css_file:
    st.markdown('<style>' + css_file.read() + '</style>', unsafe_allow_html=True)

with st.sidebar:
    tabs = on_hover_tabs(tabName=['My Profile', 'Suggested For Me', 'Explore'],
                         iconName=['dashboard', 'money', 'economy'], default_choice=0)

# Route to the page that matches the tab chosen in the sidebar.
if tabs == 'My Profile':
    st.title("My Profile")
    st.header("The wine from what countries do you like?")
    wine_selector()
elif tabs == 'Suggested For Me':
    # NOTE(review): title and body look like placeholders.
    st.title("Paper")
    st.write('Name of option is {}'.format(tabs))
elif tabs == 'Explore':
    # NOTE(review): title and body look like placeholders.
    st.title("Tom")
    st.write('Name of option is {}'.format(tabs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment