Skip to content

Instantly share code, notes, and snippets.

@erip
Created April 19, 2021 11:44
Show Gist options
  • Save erip/83989a630932303a2e283bbf4cff7d79 to your computer and use it in GitHub Desktop.
Save erip/83989a630932303a2e283bbf4cff7d79 to your computer and use it in GitHub Desktop.
Streamlit image dataset clustering viewer
#!/usr/bin/env python3
import base64
from io import BytesIO
from typing import List
from pathlib import Path
from argparse import ArgumentParser
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import streamlit as st
from PIL import Image
from umap import UMAP
from skimage import io
from torchvision import models, transforms
from bokeh.plotting import figure, show
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10
def setup_argparse():
    """Create the command-line parser for the clustering viewer.

    Returns:
        An ``ArgumentParser`` whose program name is ``'Data Diagnostics'``,
        accepting ``-m/--model-path`` (optional, no default) and
        ``-g/--glob`` (defaults to ``'*'``).
    """
    cli = ArgumentParser('Data Diagnostics')
    option_specs = (
        (('-m', '--model-path'),
         {'required': False, 'help': 'Path to resnet50 pretrained model.'}),
        (('-g', '--glob'),
         {'help': 'The glob pattern to match', 'default': '*'}),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli
# allow_output_mutation=True stops the legacy st.cache from re-hashing the
# returned model (tens of millions of parameters) on every script rerun to
# detect mutation; the model is only used read-only for inference below.
@st.cache(allow_output_mutation=True)
def get_model(model_path=None):
    """Build a ResNet-50 feature extractor with its last child module removed.

    Args:
        model_path: Optional path to a ``state_dict`` for
            ``torchvision.models.resnet50``; it is loaded onto the CPU.

    Returns:
        An ``nn.Sequential`` made of every child of ResNet-50 except the last
        one (the ``fc`` classifier), switched to eval mode.
    """
    # NOTE(review): when model_path is falsy, resnet50() is built with
    # torchvision defaults, which do not load pretrained weights, so the
    # embeddings come from a randomly initialised network. Confirm intended.
    backbone = models.resnet50()
    if model_path:
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code. Only point --model-path at trusted checkpoints.
        state_dict = torch.load(model_path, map_location='cpu')
        backbone.load_state_dict(state_dict)
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor.eval()
    return feature_extractor
def embeddable_image(data, size=(128, 128)):
    """Encode an image array as a PNG ``data:`` URI for use in an HTML tooltip.

    Args:
        data: Pixel array accepted by ``PIL.Image.fromarray`` (e.g. the
            arrays yielded by skimage's image collection).
        size: ``(width, height)`` of the thumbnail; defaults to 128x128.

    Returns:
        A string of the form ``'data:image/png;base64,<payload>'``.
    """
    # Pillow 9.1 moved resampling filters to Image.Resampling, and Pillow 10
    # removed the old module-level aliases such as Image.BICUBIC. Fall back
    # to the module attribute on older Pillow versions.
    bicubic = getattr(Image, 'Resampling', Image).BICUBIC
    thumbnail = Image.fromarray(data).resize(tuple(size), bicubic)
    with BytesIO() as buffer:
        thumbnail.save(buffer, format='png')
        png_bytes = buffer.getvalue()
    return 'data:image/png;base64,' + base64.b64encode(png_bytes).decode()
# Per-channel normalisation statistics applied after ToTensor(); these are the
# standard ImageNet values used with torchvision's ResNet models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Evaluation-time transform: resize the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
args = setup_argparse().parse_args()
model = get_model(args.model_path)

# Images per forward pass through the backbone; bounds peak activation memory
# when the glob matches a large dataset.
batch_size = 64

images = io.imread_collection(args.glob)
# Load the pixels once. They are needed for both the embeddings and the
# tooltip thumbnails, and iterating the collection twice may re-read files.
frames = list(images)
if not frames:
    # torch.stack([]) would otherwise fail with an opaque RuntimeError.
    st.error(f'No images matched the glob pattern {args.glob!r}.')
    st.stop()

# convert('RGB') lets grayscale and RGBA files through: `preprocess` normalises
# exactly three channels and the ResNet stem expects 3-channel input.
preprocessed = torch.stack(
    [preprocess(Image.fromarray(frame).convert('RGB')) for frame in frames]
)
with torch.no_grad():
    # The backbone output carries trailing 1x1 spatial dims (the original code
    # squeezed them away); flatten(1) yields one feature vector per image.
    embeddings = torch.cat(
        [model(chunk) for chunk in torch.split(preprocessed, batch_size)]
    ).flatten(1)

reduced_dim = UMAP().fit_transform(embeddings.numpy())
df = pd.DataFrame(reduced_dim, columns=('x', 'y'))
df['image'] = [embeddable_image(frame) for frame in frames]
df['file'] = images.files
datasource = ColumnDataSource(df)

# width/height work on Bokeh 2.x and 3.x; plot_width/plot_height were
# deprecated in 2.4 and removed in 3.0.
plot_figure = figure(
    title='UMAP projection of image dataset',
    width=600,
    height=600,
    tools='pan, wheel_zoom, reset',
)
plot_figure.add_tools(HoverTool(tooltips="""
<div>
<div>
<img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
</div>
<div>
<span style='font-size: 16px; color: #224499'>File:</span>
<span style='font-size: 18px'>@file</span>
</div>
</div>
"""))
# scatter() draws circle markers by default and avoids the deprecated
# circle(size=...) form in newer Bokeh releases.
plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    line_alpha=0.6,
    fill_alpha=0.6,
    size=10,
)
st.bokeh_chart(plot_figure, use_container_width=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment