Created April 19, 2021 11:44.
Save erip/83989a630932303a2e283bbf4cff7d79 to your computer and use it in GitHub Desktop.
Streamlit image dataset clustering viewer.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
#!/usr/bin/env python3 | |
import base64 | |
from io import BytesIO | |
from typing import List | |
from pathlib import Path | |
from argparse import ArgumentParser | |
import torch | |
import numpy as np | |
import pandas as pd | |
import torch.nn as nn | |
import streamlit as st | |
from PIL import Image | |
from umap import UMAP | |
from skimage import io | |
from torchvision import models, transforms | |
from bokeh.plotting import figure, show | |
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper | |
from bokeh.palettes import Spectral10 | |
def setup_argparse():
    """Create the command-line parser for the image clustering viewer.

    Options:
        -m / --model-path: optional path to a ResNet-50 weights file; the
            script body hands it to ``get_model``.
        -g / --glob: file pattern selecting the images to load (default ``'*'``).

    Returns:
        A configured ``ArgumentParser``; call ``parse_args()`` on it.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the name
    # shown in the usage line), not the description, so spell that out.
    cli = ArgumentParser(prog='Data Diagnostics')
    cli.add_argument(
        '-m', '--model-path',
        required=False,
        help='Path to resnet50 pretrained model.',
    )
    cli.add_argument(
        '-g', '--glob',
        default='*',
        help="The glob pattern to match",
    )
    return cli
@st.cache
def get_model(model_path=None):
    """Build a ResNet-50 feature extractor with its last child module removed.

    Args:
        model_path: optional path to a saved ``state_dict`` for
            ``torchvision.models.resnet50``. It is loaded onto the CPU. When
            falsy, no weights are loaded and the network keeps torchvision's
            default initialisation (i.e. it is NOT pretrained).

    Returns:
        An ``nn.Sequential`` of every child of the ResNet-50 except the final
        one, already switched to eval mode.
    """
    backbone = models.resnet50()
    if model_path:
        # NOTE(review): torch.load unpickles the file; only point this at
        # weights you trust.
        weights = torch.load(model_path, map_location='cpu')
        backbone.load_state_dict(weights)
    # Drop the trailing child (the classification head in torchvision's
    # ResNet) so the network outputs pooled features instead of class logits.
    feature_layers = list(backbone.children())[:-1]
    extractor = nn.Sequential(*feature_layers)
    extractor.eval()
    return extractor
def embeddable_image(data):
    """Encode an image array as a PNG ``data:`` URI for embedding in HTML.

    The image is resized to a fixed 128x128 thumbnail using bicubic resampling
    (the aspect ratio is not preserved).

    Args:
        data: an array accepted by ``PIL.Image.fromarray``. Presumably the
            uint8 arrays produced by ``skimage.io``; confirm for other dtypes.

    Returns:
        A string of the form ``'data:image/png;base64,<payload>'``.
    """
    thumbnail = Image.fromarray(data).resize((128, 128), Image.BICUBIC)
    with BytesIO() as png_bytes:
        thumbnail.save(png_bytes, format='png')
        payload = base64.b64encode(png_bytes.getvalue()).decode()
    return f'data:image/png;base64,{payload}'
# Per-channel normalisation statistics commonly used with ImageNet-trained
# torchvision models. They have three entries, so inputs must be 3-channel.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Model-input pipeline for a PIL image:
# resize (short side to 256) -> center-crop 224x224 -> tensor -> normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Script body. Streamlit executes it top to bottom on every rerun.
# Pass CLI options after a `--` separator, e.g.
#   streamlit run viewer.py -- -g 'images/*.png'
args = setup_argparse().parse_args()
model = get_model(args.model_path)
images = io.imread_collection(args.glob)

if not images.files:
    # Without this guard, torch.stack([]) below fails with an opaque error.
    st.error(f'No images matched the glob pattern {args.glob!r}.')
    st.stop()

# Decode every file exactly once and derive both the model input and the
# tooltip thumbnail from it. By default, imread_collection keeps at most one
# image in memory and re-reads files on each access, so iterating the
# collection twice would double the disk I/O.
model_inputs = []
thumbnails = []
for img in images:
    # Force 3 channels. Grayscale, palette or RGBA files would otherwise give
    # tensors whose channel count does not match Normalize or the first conv.
    model_inputs.append(preprocess(Image.fromarray(img).convert('RGB')))
    thumbnails.append(embeddable_image(img))
preprocessed = torch.stack(model_inputs)

# Run inference in fixed-size batches so peak activation memory does not
# scale with the number of images. The pooled output is assumed to be
# (N, C, 1, 1), as the original squeeze(-1) calls also assumed.
# flatten(1) turns it into (N, C).
BATCH_SIZE = 64
with torch.no_grad():
    embeddings = torch.cat(
        [model(batch).flatten(1) for batch in preprocessed.split(BATCH_SIZE)]
    )

# Project to 2-D for plotting.
# NOTE(review): UMAP needs more than a handful of samples. Very small globs
# may warn or fail; confirm the minimum for the installed umap-learn version.
reduced_dim = UMAP().fit_transform(embeddings.numpy())

df = pd.DataFrame(reduced_dim, columns=('x', 'y'))
df['image'] = thumbnails
df['file'] = images.files
datasource = ColumnDataSource(df)

plot_figure = figure(
    title='UMAP projection of image dataset',
    # `width`/`height` replace `plot_width`/`plot_height`, which Bokeh 3.0 removed.
    width=600,
    height=600,
    tools='pan, wheel_zoom, reset',
)

# Hovering a point shows its thumbnail (the `image` column) and file name.
plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 16px; color: #224499'>File:</span>
        <span style='font-size: 18px'>@file</span>
    </div>
</div>
"""))

plot_figure.circle(
    'x',
    'y',
    source=datasource,
    line_alpha=0.6,
    fill_alpha=0.6,
    size=10,
)

st.bokeh_chart(plot_figure, use_container_width=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.