Skip to content

Instantly share code, notes, and snippets.

@nathanmargaglio
Created March 6, 2020 16:28
Show Gist options
  • Save nathanmargaglio/804799fc1cb4a40f43dd60cb5555bed4 to your computer and use it in GitHub Desktop.
Save nathanmargaglio/804799fc1cb4a40f43dd60cb5555bed4 to your computer and use it in GitHub Desktop.
Dash Drawing App
from data_processing import ds_data, image_df, model
from common import numpy_to_base64
import numpy as np
import json
import time
from flask import Flask, request
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
class App:
    """Dash "paint" app that matches a hand-drawn canvas to the nearest dataset sample.

    Builds a ``dash.Dash`` app mounted on an existing Flask ``server`` at
    ``url_base_pathname``. Materialize CSS/JS and fabric.js are loaded from
    CDNs. The constructor also registers a ``POST /post-image`` route on that
    same Flask server.

    The Send and Clear buttons have no Dash callbacks here. Presumably they are
    wired up by client-side JavaScript served elsewhere (TODO confirm).

    NOTE(review): ``/post-image`` is registered on every construction, so
    building a second ``App`` on the same Flask server would try to
    re-register that endpoint.
    """

    # Pixel size of the drawing canvas and of the result image. The posted
    # drawing is reshaped to (1, CANVAS_HEIGHT, CANVAS_WIDTH, 1) before
    # ``model.encode``, and dataset images are reshaped to
    # (CANVAS_HEIGHT, CANVAS_WIDTH). These presumably must match the shapes
    # ``model`` and ``ds_data`` were built with.
    CANVAS_HEIGHT = 100
    CANVAS_WIDTH = 400

    def __init__(self, server, name=None, url_base_pathname='/paint-app/'):
        """Create the Dash app and register the image-matching route.

        Args:
            server: Flask application that the Dash app is mounted on and on
                which ``/post-image`` is registered.
            name: Name passed to ``dash.Dash``; defaults to this module's
                ``__name__``.
            url_base_pathname: URL prefix under which the Dash UI is served.
        """
        if name is None:
            name = __name__
        height, width = self.CANVAS_HEIGHT, self.CANVAS_WIDTH

        external_stylesheets = ['https://cdnjs.cloudflare.com/ajax/libs/materialize/1.0.0/css/materialize.min.css']
        external_scripts = [
            'https://cdnjs.cloudflare.com/ajax/libs/materialize/1.0.0/js/materialize.min.js',
            'https://cdnjs.cloudflare.com/ajax/libs/fabric.js/3.5.0/fabric.js'
        ]
        app = dash.Dash(
            server=server,
            name=name,
            url_base_pathname=url_base_pathname,
            external_scripts=external_scripts,
            external_stylesheets=external_stylesheets
        )
        app.layout = html.Div(children=[
            html.Div(className="card-panel grey darken-2 center-align", children=[
                html.Canvas(id='canvas-image',
                            width=width,
                            height=height),
                html.Div(children=[
                    html.A(id='send-image',
                           children=["Send"],
                           style={'margin': '10px'},
                           className="waves-effect waves-light btn-small grey"),
                    html.A(id='clear-image',
                           children=["Clear"],
                           style={'margin': '10px'},
                           className="waves-effect waves-light btn-small grey"),
                ])
            ]),
            html.Div(className="card-panel grey darken-2 center-align", children=[
                # Blank (all-ones) placeholder until the first match is returned.
                html.Img(id='result-image',
                         src=numpy_to_base64(np.ones((height, width)))),
                html.P(id='result-tag'),
            ])
        ])

        @server.route('/post-image', methods=['POST'])
        def post_image():
            """Return the dataset sample whose encoding is nearest the posted drawing.

            The request body must be JSON holding exactly ``height * width``
            numeric pixel values, either flat or nested.

            On success, responds with JSON ``{"result_image": ..., "meta": ...}``.
            If the payload cannot be decoded into that many numbers, responds
            with HTTP 400 and an ``{"error": ...}`` JSON body.
            """
            try:
                canvas = self._parse_canvas(request.data, height, width)
            except (ValueError, TypeError):
                # Malformed client input is the client's fault: report 400
                # instead of letting it surface as an unhandled 500.
                return json.dumps({
                    "error": f"expected a JSON array of {height * width} numeric pixel values"
                }), 400
            encoding = model.encode(canvas.reshape((1, height, width, 1)))
            # First element of the model output (presumably the latent mean,
            # given the name) -- TODO confirm against model.encode.
            mus = encoding[0]
            # Euclidean (flattened 2-norm) distance to every stored encoding.
            dists = ds_data.encodings.map(lambda d: np.linalg.norm(d - mus))
            # argmin yields a position, so the row is looked up positionally.
            near_idx = np.argmin(np.array(dists))
            result = ds_data.iloc[near_idx]
            result_image = numpy_to_base64(result.image.reshape((height, width)))
            return json.dumps({
                "result_image": result_image,
                "meta": f'({result.name}): timestep: {result.timestep}, BR: {result.BR}, CHI: {result.CHI}'
            })

        self.app = app

    @staticmethod
    def _parse_canvas(raw, height, width):
        """Decode a JSON pixel payload into a float32 array of shape (height, width).

        Args:
            raw: JSON text (``str`` or ``bytes``) holding ``height * width``
                numbers, either flat or nested.
            height: Number of rows in the returned array.
            width: Number of columns in the returned array.

        Returns:
            A ``numpy.float32`` array of shape ``(height, width)``.

        Raises:
            ValueError, TypeError: if ``raw`` is not valid JSON, does not
                decode to numbers, or does not hold exactly
                ``height * width`` values.
        """
        pixels = np.array(json.loads(raw)).astype(np.float32)
        return pixels.reshape((height, width))

    def run_server(self, **kwargs):
        """Start the Dash app; ``kwargs`` are forwarded to ``dash.Dash.run_server``."""
        self.app.run_server(**kwargs)
if __name__ == '__main__':
    import os

    # Defaults reproduce the original hard-coded settings; each can be
    # overridden through the environment without editing the file.
    # WARNING: debug=True enables the Werkzeug interactive debugger, which
    # allows arbitrary code execution from the browser. Do not combine it with
    # host 0.0.0.0 on an untrusted network (set PAINT_APP_DEBUG=0 or
    # PAINT_APP_HOST=127.0.0.1).
    host = os.environ.get('PAINT_APP_HOST', '0.0.0.0')
    port = int(os.environ.get('PAINT_APP_PORT', '3001'))
    debug = os.environ.get('PAINT_APP_DEBUG', '1').strip().lower() not in ('0', 'false', 'no')

    server = Flask(__name__)
    app = App(server)
    app.run_server(debug=debug, host=host, port=port)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment