# Last active June 6, 2022 22:23
# Basic Streamlit widget for data labeling
from PIL import Image
import streamlit as st
import os
import pandas as pd
from datetime import datetime
def increase_idx(state):
    """Advance to the next entry of ``state.empty_indexes``, if there is one.

    Updates ``state.index`` (position inside ``empty_indexes``) and
    ``state.df_index`` (the dataframe index stored at that position).
    Does nothing when the current position is already the last one.
    """
    last_position = len(state.empty_indexes) - 1
    if state.index >= last_position:
        return
    state.index += 1
    state.df_index = state.empty_indexes[state.index]
def decrease_idx(state):
    """Step back to the previous entry of ``state.empty_indexes``, if any.

    Updates ``state.index`` and ``state.df_index`` together. Does nothing
    when the current position is already the first one.
    """
    if state.index <= 0:
        return
    state.index -= 1
    state.df_index = state.empty_indexes[state.index]
def set_idx(state):
    """Jump to the position typed into the 'go to Index' text input.

    Reads ``state.textinput`` and, when it holds an integer, moves
    ``state.index`` / ``state.df_index`` to that position inside
    ``state.empty_indexes``. Empty or non-numeric input is ignored, and
    out-of-range values are clamped to the first/last valid position so
    that ``state.index`` always stays a valid non-negative position.
    """
    raw_value = state.textinput
    if not raw_value:
        return
    try:
        requested = int(raw_value)
    except (TypeError, ValueError):
        # Free-text widget: anything that is not an integer is ignored.
        return
    last_position = len(state.empty_indexes) - 1
    if last_position < 0:
        # Nothing to navigate to.
        return
    # Clamp instead of raising IndexError; a negative value would otherwise
    # leave state.index negative and break the prev/next navigation.
    state.index = min(max(requested, 0), last_position)
    state.df_index = state.empty_indexes[state.index]
def save_resolved_label(state):
    """Store the label chosen in the selectbox and persist the dataframe.

    Selectbox options are built as ``'<column>: <label>'``; only the label
    part is written to the ``label_fix`` column of the current row
    (``state.df_index``). An option without the ``': '`` separator is
    stored unchanged. The whole dataframe is then written to
    ``state.path_to_target_csv``.
    """
    chosen_option = state.selectbox
    # Split on the first ': ' only, so labels that themselves contain
    # ': ' are kept intact.
    _, separator, label = chosen_option.partition(': ')
    state.df.loc[state.df_index, 'label_fix'] = label if separator else chosen_option
    state.df.to_csv(state.path_to_target_csv, index=False)
def check_img_path(path):
    """Return ``path`` if it points to an existing file, else a placeholder.

    Values that are not path-like (e.g. the NaN/None of an empty CSV cell)
    also yield the placeholder image path instead of raising.
    """
    # os.path.isfile() raises TypeError for floats/None and treats ints as
    # file descriptors, so only genuine path values are checked on disk.
    if isinstance(path, (str, bytes, os.PathLike)) and os.path.isfile(path):
        return path
    return './image/no_pict.jpeg'
def upload_csv():
    """Show a CSV upload widget and return the uploaded file as a dataframe.

    Returns:
        The parsed dataframe, or ``None`` (after showing a warning) while
        no file has been uploaded.
    """
    csv_file = st.file_uploader('Upload CSV')
    if csv_file is None:
        st.warning('Please input a valid csv.')
        return None
    return pd.read_csv(csv_file)
def initialize_state(original_csv):
    """Populate ``st.session_state`` on the first run and return it.

    On the first call this stores the dataframe, adds a ``label_fix``
    column when absent, records the indexes of rows whose ``label_fix`` is
    still empty, points the cursor at the first of them and writes an
    initial copy of the data to a timestamped CSV under ``./fixed_data/``.
    Later calls (Streamlit reruns) return the existing state untouched.

    Args:
        original_csv: dataframe with the data to label.

    Returns:
        The Streamlit session state object.
    """
    state = st.session_state
    if 'initialized' not in state:
        state['index'] = 0
        timestamp_suffix = datetime.now().strftime("%d%m%Y_%I%M%S%p")
        state['path_to_target_csv'] = f"./fixed_data/labels_cleaned_{timestamp_suffix}.csv"
        state['df'] = original_csv
        if 'label_fix' not in state.df.columns:
            state.df['label_fix'] = None
        state['empty_indexes'] = state.df[state.df.label_fix.isna()].index
        # NOTE(review): raises IndexError when no row has an empty
        # label_fix (nothing left to label) - decide how the UI should react.
        state['df_index'] = state.empty_indexes[state.index]
        # to_csv() does not create missing directories.
        os.makedirs(os.path.dirname(state.path_to_target_csv), exist_ok=True)
        state.df.to_csv(state.path_to_target_csv, index=False)
        state['initialized'] = True
    return state
def generate_available_labels(state):
    """Build the selectbox options for the current row.

    Every non-empty value among the label columns of the row at
    ``state.df_index`` becomes an option of the form
    ``'<column>: <value>'``, in column order. Columns that are absent from
    the dataframe are skipped.

    Returns:
        List of option strings.
    """
    # TODO: parametrize columns for models
    columns_to_labels = ['label_fix', 'label', 'model1', 'model2', 'model3']
    # df_index is an index *label* (taken from df.index), hence .loc.
    row = state.df.loc[state.df_index]
    # Pair each value with its own column name; slicing the column list by
    # the number of non-empty values mislabels options whenever a column
    # other than the leading ones is empty.
    labels_available = [
        f'{column}: {row[column]}'
        for column in columns_to_labels
        if column in row.index and pd.notna(row[column])
    ]
    # Workaround for a very strange bug: if the list of available labels is
    # exactly the same for two consecutive images, label_fix doesn't get
    # updated. The trailing space in the option below is intentional.
    if state.index % 2 == 0:
        labels_available.append('unknown_resolution ')
    return labels_available
def create_buttons(state):
    """Render the control row: label selectbox, navigation buttons, index input.

    Five columns are laid out with relative widths 2:1:1:1:1. The selectbox
    is rendered into a placeholder in the first column after the other
    widgets have been created. Every callback receives ``state`` as its
    only argument.
    """
    label_col, save_col, prev_col, next_col, goto_col = st.columns([2, 1, 1, 1, 1])
    callback_args = (state,)

    selectbox_placeholder = label_col.empty()
    save_col.button('Save current label, go to next',
                    on_click=save_resolved_label, args=callback_args)
    prev_col.button('Back to prev img', on_click=decrease_idx, args=callback_args)
    next_col.button('Skip to next img', on_click=increase_idx, args=callback_args)
    goto_col.text_input('go to Index', value=state.index,
                        on_change=set_idx, args=callback_args, key='textinput')

    options = generate_available_labels(state)
    selectbox_placeholder.selectbox('Label', options, index=0,
                                    on_change=save_resolved_label,
                                    args=callback_args, key='selectbox')
def display_img(state):
    """Show the image of the current row together with its labels.

    Renders a header with the original and the fixed label, the image file
    name, and the image itself. When the row's path is not an existing
    file, the placeholder returned by ``check_img_path`` is shown instead.
    """
    # df_index is an index *label* (taken from df.index), hence .loc.
    row = state.df.loc[state.df_index]
    original_label = row.label
    fixed_label = row.label_fix
    path = check_img_path(row.path)
    filename = os.path.basename(path)
    image = Image.open(path)
    # str() on both labels: an empty cell is NaN (a float) and has no .upper().
    st.subheader('original label {} - fixed label {}'.format(
        str(original_label).upper(), str(fixed_label).upper()), anchor=None)
    st.text('IMG {}'.format(filename))
    st.image(image, use_column_width=True)
def create_ui(state):
    """Render the labeling UI: the control row, then the current image."""
    # NOTE(review): the body was missing from the source; it is rebuilt
    # from the two UI builders defined in this file - confirm the order.
    create_buttons(state)
    display_img(state)
if __name__ == "__main__":
    original_csv = upload_csv()
    # upload_csv() returns None until the user has provided a file; there
    # is nothing to initialize or render before that.
    if original_csv is not None:
        state = initialize_state(original_csv)
        create_ui(state)
