Basic Streamlit widget for data labeling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
import streamlit as st | |
import os | |
import pandas as pd | |
from datetime import datetime | |
def increase_idx(state):
    """Move forward to the next row that still needs a label.

    ``state.index`` is a position inside ``state.empty_indexes``; it is only
    advanced while a later position exists, so the last row is sticky.
    ``state.df_index`` is always re-synced to the row label at that position.
    """
    last_position = len(state.empty_indexes) - 1
    if not state.index >= last_position:
        state.index = state.index + 1
    state.df_index = state.empty_indexes[state.index]
def decrease_idx(state):
    """Step back to the previous row that still needs a label.

    ``state.index`` never goes below position 0; ``state.df_index`` is
    re-synced to the row label stored at the resulting position.
    """
    can_go_back = state.index > 0
    if can_go_back:
        state.index = state.index - 1
    state.df_index = state.empty_indexes[state.index]
def set_idx(state):
    """Jump to the position typed into the 'go to Index' text input.

    ``state.textinput`` is read as a 0-based position inside
    ``state.empty_indexes``. Blank, non-integer, negative or out-of-range
    input is ignored and the current position is kept, so a typo in the
    widget cannot crash the callback or wrap around to the end of the list.
    ``state.df_index`` is always re-synced to the row label at the resulting
    position.
    """
    text = str(state.textinput or '').strip()
    if text:
        try:
            requested = int(text)
        except ValueError:
            requested = None  # not an integer: keep the current position
        # Negative values would be accepted by Python indexing (wrap-around)
        # and too-large values would raise IndexError; accept only valid slots.
        if requested is not None and 0 <= requested < len(state.empty_indexes):
            state.index = requested
    state.df_index = state.empty_indexes[state.index]
def save_resolved_label(state):
    """Store the label chosen in the selectbox, persist the CSV, then advance.

    Selectbox options have the form ``'<column>: <label>'``, or are the bare
    ``'unknown_resolution'`` choice. The text after the first colon (or the
    whole option when there is no colon) is stripped and written to the
    ``label_fix`` column of row ``state.df_index``. The full DataFrame is
    then rewritten to ``state.path_to_target_csv``, and the view moves to
    the next unlabeled row.
    """
    choice = state.selectbox
    # Split on the FIRST colon only: the prefix is a column name, but the
    # label value itself may contain ':' and must be saved intact.
    resolved = choice.split(':', 1)[-1].strip()
    state.df.loc[state.df_index, 'label_fix'] = resolved
    state.df.to_csv(state.path_to_target_csv, index=False)
    increase_idx(state)
def check_img_path(path, fallback='./image/no_pict.jpeg'):
    """Return ``path`` if it names an existing regular file, else ``fallback``.

    Args:
        path: candidate image path taken from the CSV. Non-path values such
            as ``None`` or a NaN from an empty cell are treated as missing,
            instead of letting ``os.path.isfile`` raise ``TypeError``.
        fallback: placeholder image path to use when ``path`` is unusable.

    Returns:
        ``path`` unchanged when it is a file on disk, otherwise ``fallback``.
    """
    if isinstance(path, (str, bytes, os.PathLike)) and os.path.isfile(path):
        return path
    return fallback
def upload_csv():
    """Ask the user for a CSV file and return it as a DataFrame.

    When nothing has been uploaded yet, or the upload cannot be parsed as
    CSV, a warning is shown and ``st.stop()`` is called, so the rest of the
    page is built only once usable data is available.
    """
    uploaded_file = st.file_uploader('Upload CSV')
    if uploaded_file is None:
        st.warning('Please input a valid csv.')
        st.stop()
        return None  # defensive: only reached if st.stop() does not raise
    try:
        return pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        # A malformed upload is a user error: report it like a missing file
        # instead of surfacing a traceback in the app.
        st.warning(f'Please input a valid csv. ({err})')
        st.stop()
        return None
def initialize_state(original_csv):
    """Populate ``st.session_state`` on the first run and return it.

    On the first run of a session this:
      * starts at position 0 of the rows that still lack a ``label_fix``;
      * chooses a timestamped output path under ``./fixed_data/`` and
        creates that directory if needed;
      * adds an empty ``label_fix`` column when the CSV has none;
      * writes an initial copy of the data to the output path.

    If every row already has a ``label_fix``, a message is shown and the run
    is stopped, because there is nothing to navigate. Later reruns find
    ``'initialized'`` in the session state and return it unchanged.
    """
    state = st.session_state
    if 'initialized' not in state:
        state['index'] = 0
        timestamp_suffix = datetime.now().strftime("%d%m%Y_%I%M%S%p")
        state['path_to_target_csv'] = f"./fixed_data/labels_cleaned_{timestamp_suffix}.csv"
        state['df'] = original_csv
        if 'label_fix' not in state.df.columns:
            state.df['label_fix'] = None
        # Row labels (not positions) of the rows still needing a label.
        state['empty_indexes'] = state.df.index[state.df['label_fix'].isna()]
        if len(state.empty_indexes) == 0:
            # Indexing position 0 below would raise IndexError, and the
            # progress bar would divide by zero.
            st.success('Every row already has a label_fix; nothing left to label.')
            st.stop()
        state['df_index'] = state.empty_indexes[state.index]
        # to_csv does not create missing parent directories.
        os.makedirs(os.path.dirname(state.path_to_target_csv), exist_ok=True)
        state.df.to_csv(state.path_to_target_csv, index=False)
        state['initialized'] = True
    return state
def generate_available_labels(state, model_columns=('model1', 'model2', 'model3')):
    """Build the selectbox options for the row currently being labeled.

    Every non-missing value among ``label_fix``, ``label`` and the model
    prediction columns becomes one option ``'<column>: <value>'``, in that
    column order. Columns absent from the DataFrame are skipped. A final
    ``unknown_resolution`` option is always appended.

    Args:
        state: session state holding ``df``, ``df_index`` (row label of the
            current row) and ``index`` (position among unlabeled rows).
        model_columns: names of the model-prediction columns to offer.

    Returns:
        list[str]: the option strings.
    """
    candidate_columns = ['label_fix', 'label', *model_columns]
    # .loc: df_index is a row label (it comes from the DataFrame's index).
    row = state.df.loc[state.df_index]
    present = [col for col in candidate_columns if col in row.index]
    # Pair each surviving value with its own column name, so a gap in the
    # middle (e.g. only model2 missing) cannot shift the prefixes.
    labels_available = [f'{col}: {value}' for col, value in row[present].dropna().items()]
    # workaround for a very strange bug:
    # if the options list is exactly the same for two consecutive images,
    # label_fix doesn't get updated. Alternating a trailing space keeps
    # consecutive option lists different.
    if state.index % 2 == 0:
        labels_available.append('unknown_resolution')
    else:
        labels_available.append('unknown_resolution ')
    return labels_available
def create_buttons(state):
    """Render the progress bar, the label picker and the navigation controls.

    A placeholder is reserved in the wide first column, and the label
    selectbox is filled into it after the buttons and text input are
    created, so it still appears to their left. Every widget callback
    receives the session ``state`` as its only argument.
    """
    done_fraction = (state.index + 1) / len(state.empty_indexes)
    st.progress(done_fraction)
    picker_col, save_col, back_col, skip_col, jump_col = st.columns([2, 1, 1, 1, 1])
    with picker_col:
        picker_slot = st.empty()
    callback_args = (state,)
    with save_col:
        st.button('Save current label, go to next',
                  on_click=save_resolved_label, args=callback_args)
    with back_col:
        st.button('Back to prev img', on_click=decrease_idx, args=callback_args)
    with skip_col:
        st.button('Skip to next img', on_click=increase_idx, args=callback_args)
    with jump_col:
        st.text_input('go to Index', value=state.index,
                      on_change=set_idx, args=callback_args, key='textinput')
    options = generate_available_labels(state)
    picker_slot.selectbox('Label', options, index=0,
                          on_change=save_resolved_label, args=callback_args,
                          key='selectbox')
def display_img(state):
    """Show the current row's original/fixed labels, file name and image.

    The image path comes from the row's ``path`` column via
    ``check_img_path``, which substitutes a placeholder when the file is
    missing.
    """
    # .loc: df_index is a row label taken from the DataFrame's index.
    row = state.df.loc[state.df_index]
    # str() first: a missing label (NaN/None) has no .upper().
    original_label = str(row['label']).upper()
    fixed_label = str(row['label_fix']).upper()
    path = check_img_path(row['path'])
    filename = os.path.basename(path)
    # Context manager closes the underlying file handle after rendering.
    with Image.open(path) as image:
        st.subheader('original label {} - fixed label {}'.format(
            original_label, fixed_label), anchor=None)
        st.text('IMG {}'.format(filename))
        st.image(image, use_column_width=True)
def create_ui(state):
    """Render the full labeling page: the controls first, then the current image."""
    for render_section in (create_buttons, display_img):
        render_section(state)
if __name__ == "__main__":
    # Pipeline per script run: uploaded CSV -> initialized session state -> rendered page.
    create_ui(initialize_state(upload_csv()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment