Last active
November 1, 2021 02:52
-
-
Save RensDimmendaal/f0d549219c1004dc146e5ed40de39bb9 to your computer and use it in GitHub Desktop.
Snippets for the Data-Centric AI Roman numerals competition
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
from pathlib import Path | |
from typing import Sequence, Union | |
import altair as alt | |
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
import umap | |
from tensorflow.python.keras.preprocessing.image_dataset import ( | |
load_image as tf_load_image, | |
) | |
from tqdm import tqdm | |
# The ten class labels, in ascending numeric order (i == 1 ... x == 10).
ROMAN_NUMERALS = ["i", "ii", "iii", "iv", "v", "vi", "vii", "viii", "ix", "x"]
# Numeral string -> integer value, e.g. "iv" -> 4.
REVERSED_NUMERALS = dict(zip(ROMAN_NUMERALS, range(1, len(ROMAN_NUMERALS) + 1)))
def load_img_32(fpath):
    """Load the image at ``fpath`` resized to 32x32 with 3 channels.

    Resizing is bilinear without aspect-preserving crop (``smart_resize=False``).
    The pixel values are cast to Python ``int`` dtype (fractional parts from the
    bilinear resize are truncated).

    NOTE(review): ``tf_load_image`` comes from TensorFlow's private
    ``tensorflow.python.keras`` namespace and may be missing in newer TF builds.
    """
    loaded = tf_load_image(
        path=str(fpath),
        image_size=(32, 32),
        num_channels=3,
        interpolation="bilinear",
        smart_resize=False,
    )
    pixels = loaded.numpy()
    return pixels.astype(int)
def base64_encode_png(fpath):
    """Return the file at ``fpath`` as a ``data:image/png;base64,...`` URI.

    The bytes are not validated; the ``image/png`` media type is assumed.
    """
    with open(fpath, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload).decode()
    return "data:image/png;base64," + encoded
def embed(image_paths: Sequence[Union[str, Path]]) -> np.ndarray:
    """Return one feature vector per image from an early ResNet50 stage.

    Each image is loaded at 32x32x3 (via ``load_img_32``) and passed through
    ``resnet.preprocess_input``. It then goes through ImageNet-pretrained
    ResNet50 truncated at layer ``conv2_block3_out``. Finally it is
    global-average-pooled over the spatial axes.

    Args:
        image_paths: Image file paths; must support ``len()`` (a list or a
            pandas Series both work).

    Returns:
        Float array of shape ``(len(image_paths), n_features)``, one row per
        image in input order. ``n_features`` is the channel count of
        ``conv2_block3_out`` (256 for the stock ResNet50).
    """
    # Truncate the pretrained backbone at conv2_block3_out.
    base_model = tf.keras.applications.ResNet50(
        input_shape=(32, 32, 3),
        include_top=False,
        weights="imagenet",
    )
    base_model = tf.keras.Model(
        base_model.inputs, outputs=[base_model.get_layer("conv2_block3_out").output]
    )

    # Wrap with ResNet preprocessing and spatial average pooling.
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.applications.resnet.preprocess_input(inputs)
    x = base_model(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    model = tf.keras.Model(inputs, x)

    # Take the feature width from the model rather than hard-coding 256, so the
    # output buffer always matches the chosen layer.
    n_features = int(x.shape[-1])
    embeddings = np.zeros((len(image_paths), n_features))

    # enumerate() has no len(), so tqdm needs an explicit total to show
    # progress and an ETA.
    for idx, img_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        img = load_img_32(img_path)
        embeddings[idx, :] = model(img.reshape((1, 32, 32, 3))).numpy()[0]
    return embeddings
def reduce_dimensionality(embeddings):
    """Fit a default-configured ``umap.UMAP`` on ``embeddings`` and return the projection.

    The ``__main__`` script stores the result in two columns, which matches
    UMAP's default of two output components.
    """
    reducer = umap.UMAP()
    projected = reducer.fit_transform(embeddings)
    return projected
def load_df(data_dir="data/raw"):
    """Index every ``*.png`` found recursively under ``data_dir`` into a DataFrame.

    Expected layout: ``.../<subset>/<label>/<file>.png``.

    Columns:
        fpath: ``Path`` of the image.
        base64_encoded_img32: the file as a PNG data URI.
        label: name of the image's parent directory (a Roman numeral).
        subset: name of the grandparent directory (e.g. the train/val split).
        arabic_label: integer value of ``label``, except that "x" (10) is
            stored as 0.

    Raises:
        KeyError: if a parent directory name is not in ``REVERSED_NUMERALS``.
    """
    paths = list(Path(data_dir).glob("**/*.png"))
    frame = pd.DataFrame({"fpath": paths})
    frame["base64_encoded_img32"] = frame["fpath"].apply(base64_encode_png)
    frame["label"] = frame["fpath"].apply(lambda p: p.parent.name)
    frame["subset"] = frame["fpath"].apply(lambda p: p.parent.parent.name)
    # __getitem__ (not .map) so unknown labels still raise KeyError.
    numeric = frame["label"].apply(REVERSED_NUMERALS.__getitem__)
    frame["arabic_label"] = numeric.replace(10, 0)
    return frame
def altair_plot(
    df,
    x_axis,
    y_axis,
    color,
    text_marker,
    tooltip,
    title,
    img_col="base64_encoded_img32",
):
    """Build a text scatter plot linked to a strip of the brushed images.

    Left panel: one text mark per row at (``x_axis``, ``y_axis``), showing
    ``text_marker`` and coloured by ``color``. An interval brush on it filters
    the right panel, which stacks the ``img_col`` images of the selected rows.
    Only rows whose window rank is below 20 are kept.

    Args:
        df: DataFrame containing every referenced column.
        x_axis, y_axis: Column names for the scatter coordinates.
        color: Column name for the colour encoding.
        text_marker: Column whose value is drawn as each point's text.
        tooltip: Column names to show on hover; ``[]`` or ``None`` for none.
        title: Title of the scatter panel.
        img_col: Column of image URLs / data URIs for the image panel.

    Returns:
        The scatter chart and the image strip concatenated horizontally.
    """
    tooltip = list(tooltip) if tooltip is not None else []
    # A column requested twice (e.g. tooltip=["subset"] with color="subset")
    # would give ddf duplicate column labels; de-duplicate, keeping order.
    columns = list(
        dict.fromkeys([x_axis, y_axis, color, text_marker, img_col, *tooltip])
    )
    ddf = df[columns]

    result = (
        alt.Chart(ddf)
        .mark_text(size=10, opacity=0.2)
        .encode(
            x=x_axis,
            y=y_axis,
            color=alt.Color(color, scale=alt.Scale(scheme="dark2")),
            tooltip=tooltip,
            text=text_marker,
        )
        .properties(title=title)
    )

    # selection_interval() replaces the deprecated alt.selection(type="interval").
    brush = alt.selection_interval()
    ranked_img = (
        alt.Chart(ddf)
        .mark_image(width=32, height=32)
        .encode(
            y=alt.Y("row_number:O", axis=None),
            url=img_col,
        )
        .transform_window(row_number="row_number()")
        .transform_filter(brush)
        .transform_window(rank="rank(row_number)")
        .transform_filter(alt.datum.rank < 20)
        .properties(width=50, title="Img Selection")
    )
    return result.add_selection(brush) | ranked_img
if __name__ == "__main__":
    # Embed every image, project it to 2-D and save an interactive HTML plot.
    df = load_df("./data/raw/")
    embeds = embed(df["fpath"])
    coords = reduce_dimensionality(embeds)
    df[["dim1", "dim2"]] = coords

    plot_options = dict(
        x_axis="dim1",
        y_axis="dim2",
        color="subset",
        text_marker="arabic_label",
        tooltip=[],  # no tooltip
        title="The standard Train/Validation split has style differences.",
    )
    fig = altair_plot(df, **plot_options)
    fig.save("my_figure.html")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from PIL import Image | |
import streamlit as st | |
from streamlit_drawable_canvas import st_canvas | |
import matplotlib.pyplot as plt | |
from roman_numerals.io import load_df | |
from pathlib import Path | |
# ------------------------------------------
# -- Use this app to make i, ii, and iii --
# ------------------------------------------

# Canvas drawing parameters.
stroke_width = 10
stroke_color = "#FFFFFF"
bg_color = "#eee"
drawing_mode = "freedraw"

realtime_update = st.sidebar.checkbox("Update in realtime", True)

save_dir = Path("./output/augmented")
viii_fpaths = list(Path("./data/").glob("**/viii/*.png"))
n = len(viii_fpaths)

# Progress counts the finished "i" images whose file name has no "_".
# NOTE(review): presumably "_" marks files created some other way; confirm.
n_done = sum(1 for done_path in save_dir.glob("**/i/*.png") if "_" not in done_path.name)
st.write(f"# original={n},done={n_done}/{n}")

save_button = st.button("save")
_ = st.button("next")
strategy_placholder = st.empty()
def get_next(load_dir, save_dir):
    """Find the next image to edit, the edit to make, and where to save it.

    Each original viii image is reduced by hand in three steps. Each step's
    result is saved under ``save_dir`` with the original file name:
    viii -> ``iii/``, then iii -> ``ii/``, then ii -> ``i/``.

    Args:
        load_dir: Iterable of paths to the original viii images (despite the
            name, not a directory).
        save_dir: ``Path`` of the folder holding the ``iii``, ``ii`` and ``i``
            subfolders. A subfolder that does not exist yet counts as empty.

    Returns:
        ``(source_path, instruction, target_path)`` for the first image with a
        missing step. Returns ``None`` when every image has all three versions.
    """

    def _names(subdir):
        # File names already present in save_dir/subdir (empty if it's missing).
        folder = save_dir / subdir
        return {f.name for f in folder.iterdir()} if folder.is_dir() else set()

    # List each folder once instead of re-scanning it for every image.
    done_iii, done_ii, done_i = _names("iii"), _names("ii"), _names("i")

    for img_path in load_dir:
        name = img_path.name
        if name not in done_iii:
            return img_path, "viii->iii", save_dir / "iii" / name
        if name not in done_ii:
            return save_dir / "iii" / name, "iii->ii", save_dir / "ii" / name
        if name not in done_i:
            return save_dir / "ii" / name, "ii->i", save_dir / "i" / name
    # Every image has been taken through all three steps.
    return None
next_item = get_next(viii_fpaths, save_dir)
if next_item is None:
    # get_next returns None once every viii has its iii, ii and i versions;
    # unpacking it would crash, so end the run cleanly instead.
    st.write("## all done")
    st.stop()
img_fpath, strategy, save_path = next_item
st.write(f"## {strategy}")

img = Image.open(img_fpath)
arr = np.asarray(img)
st.write(arr.shape)

# Drawing canvas the same size as the image, with the image as its background.
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    background_image=img,
    update_streamlit=realtime_update,
    height=arr.shape[0],
    width=arr.shape[1],
    drawing_mode=drawing_mode,
)

# Paint every canvas pixel with a non-zero mean value white (255) in a copy of the image.
out = arr.copy()
if canvas_result.image_data is not None:
    template = canvas_result.image_data.mean(axis=-1)
    out[template > 0] = 255
st.image(out)

if save_button:
    # Compare pixel arrays: the old `img != out` compared a PIL Image with an
    # ndarray, which is always True, so unedited images were saved too.
    if np.array_equal(arr, out):
        st.error("augment the data!")
    else:
        plt.imsave(save_path, out, cmap="gray")
        st.balloons()
# --------------------------------------- | |
# Use these functions to make iv,v,vi,vii | |
# --------------------------------------- | |
def combine(a, b):
    """'Add' two same-shaped images by taking their element-wise minimum.

    Darker (lower) values from either input win, so ink from both survives.
    """
    stacked = np.stack([a, b], axis=-1)
    return stacked.min(axis=-1)
def remove(a, b):
    """'Remove' image b from image a.

    Returns a copy of ``a`` in which every pixel where ``b`` is below 1 is set
    to 1.0. The input ``a`` is left unmodified.
    """
    ink_in_b = b < 1
    result = a.copy()
    result[ink_in_b] = 1.0
    return result
def _read_first_channel(fpath):
    """Read an image with ``plt.imread``; if it has a channel axis, keep channel 0."""
    image = plt.imread(fpath)
    return image[:, :, 0] if image.ndim == 3 else image


def create_counterfactuals_from_viii(viii_fpath: Path, augmented_root: Path):
    """Build images of i through viii from one viii and its hand-made reductions.

    Args:
        viii_fpath: Path of an original viii image.
        augmented_root: Folder containing ``i/``, ``ii/`` and ``iii/``
            subfolders. Each holds a hand-edited copy of the viii image under
            the same file name.

    Returns:
        Dict mapping "i", "ii", ..., "viii" to single-channel image arrays.
        Derivation:
        - v: viii with the strokes of iii whitened out.
        - vi, vii: v combined with i or ii.
        - iv: vi mirrored left-to-right.

    NOTE(review): ``remove``/``combine`` treat 1.0 as white background, which
    presumes the float-in-[0, 1] arrays matplotlib returns for PNGs; confirm.
    """
    fname = viii_fpath.name
    # All four inputs are reduced to one 2-D channel. Previously viii was not,
    # so a multi-channel viii made combine() fail on mismatched shapes.
    i = _read_first_channel(augmented_root / "i" / fname)
    ii = _read_first_channel(augmented_root / "ii" / fname)
    iii = _read_first_channel(augmented_root / "iii" / fname)
    viii = _read_first_channel(viii_fpath)

    v = remove(viii, iii)
    vi = combine(v, i)
    vii = combine(v, ii)
    iv = vi[:, ::-1]  # horizontal mirror: the i moves from the right of v to its left
    return dict(i=i, ii=ii, iii=iii, iv=iv, v=v, vi=vi, vii=vii, viii=viii)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment