Skip to content

Instantly share code, notes, and snippets.

@chriddyp
Created Jul 23, 2021
Embed
What would you like to do?
import redis
import io
import hashlib
import plotly
import pandas as pd
import json
def IDNotAllowedException(component_name, ids_class=None):
    """Build the exception raised when `id` is passed to a composite component.

    Params:
    - component_name: name of the composite component, used in the message.
    - ids_class: the component's `ids` class. Its public attribute names
      (those not starting with an underscore) are listed in the message as
      the available sub-component IDs. May be None or have no public
      attributes, in which case the listing is left out.

    Returns:
    An `Exception` instance (not raised here; callers `raise` it).
    """
    if ids_class is None:
        available_ids = []
    else:
        available_ids = [i for i in dir(ids_class) if not i.startswith('_')]
    ids_message = ', '.join(
        f'{component_name}.ids.{i}' for i in available_ids
    )

    parts = [
        '`id` is not an accepted argument.\n',
        f'{component_name} is a composite component which means that ',
        'it does not accept `id` as a property. ',
        'Composite components define their own pattern-matching IDs ',
        'internally as they create their own Python callbacks upon ',
        'instantiation. \n',
        'To access the ID, use `composite_id=your_id` instead of `id=your_id`',
    ]
    if available_ids:
        # Only give a concrete example when there is at least one ID to show.
        parts.append(
            f' and access the ID with e.g. `{component_name}.ids.{available_ids[0]}(your_id)`. '
        )
        parts.append(
            f'This composite component is composed of several components. These IDs are available: {ids_message}'
        )
    else:
        parts.append('.')
    return Exception(''.join(parts))
def omit(remove_keys, a_dict, a_dict_name='', component_name='', ids_class=None):
    """Return a copy of `a_dict` without the keys listed in `remove_keys`.

    Params:
    - remove_keys: keys to leave out of the result.
    - a_dict: source dictionary; None yields an empty dict.
    - a_dict_name: when non-empty, the presence of any key from
      `remove_keys` in `a_dict` is an error instead of being dropped
      silently. The name is used in the error message.
    - component_name, ids_class: forwarded to `IDNotAllowedException` when
      the offending key is `'id'`.
    """
    if a_dict is None:
        return {}

    if a_dict_name:
        for forbidden in remove_keys:
            if forbidden not in a_dict:
                continue
            if forbidden == 'id':
                raise IDNotAllowedException(component_name, ids_class)
            raise Exception(
                f'{forbidden}= is not an allowed key in {a_dict_name}.'
            )

    kept = {}
    for key, value in a_dict.items():
        if key not in remove_keys:
            kept[key] = value
    return kept
def create_id_class(composite_component_name, subcomponent_names):
    """Create an `ids` class with one ID-building function per sub-component.

    Each attribute named in `subcomponent_names` is a function taking an
    `instance` value and returning a dictionary ID with the keys
    `instance`, `component` and `subcomponent`.
    """
    def make_id_builder(subcomponent):
        # A factory per name so each builder closes over its own
        # sub-component rather than the loop variable.
        def build_id(instance):
            return {
                'instance': instance,
                'component': composite_component_name,
                'subcomponent': subcomponent,
            }
        return build_id

    class ids:
        pass

    for subcomponent in subcomponent_names:
        setattr(ids, subcomponent, make_id_builder(subcomponent))
    return ids
class composite_store:
    """Store for values shared between composite components and callbacks.

    Values are serialized, written to Redis, and addressed by the SHA-512
    hex digest of their serialized bytes. Two Redis keys are written per
    value: one holding the bytes and one holding a type tag that tells
    `load` how to deserialize them.
    """
    # TODO - Fake Redis
    r = redis.Redis(host='localhost', port=6379, db=0)

    @staticmethod
    def _hash(serialized_obj):
        """Return the SHA-512 hex digest of `serialized_obj` (bytes)."""
        return hashlib.sha512(serialized_obj).hexdigest()

    # TODO - Need to sign data? Otherwise malicious actors could try to
    # "guess" which data exists on the server.
    @staticmethod
    def save(value):
        """Serialize `value`, write it to Redis and return its hash key.

        DataFrames are written as gzip-compressed parquet; anything else is
        JSON-encoded with `plotly.utils.PlotlyJSONEncoder`.
        """
        if isinstance(value, pd.DataFrame):
            buffer = io.BytesIO()
            value.to_parquet(buffer, compression='gzip')
            serialized_value = buffer.getvalue()
            data_type = 'pd.DataFrame'
        else:
            serialized_value = json.dumps(
                value, cls=plotly.utils.PlotlyJSONEncoder
            ).encode('utf-8')
            data_type = 'json-serialized'

        hash_key = composite_store._hash(serialized_value)
        composite_store.r.set(
            f'_dash_composite_components_value_{hash_key}',
            serialized_value
        )
        composite_store.r.set(
            f'_dash_composite_components_type_{hash_key}',
            data_type
        )
        return hash_key

    @staticmethod
    def load(hash_key):
        """Read back and deserialize the value saved under `hash_key`.

        Raises:
        - KeyError: no value is stored under `hash_key`.
        - Any exception raised while deserializing; it is printed and
          re-raised.
        """
        data_type = composite_store.r.get(f'_dash_composite_components_type_{hash_key}')
        serialized_value = composite_store.r.get(f'_dash_composite_components_value_{hash_key}')
        if serialized_value is None:
            # Fail with a clear message rather than letting json.loads(None)
            # raise an unrelated-looking TypeError.
            raise KeyError(f'No composite store value found for key {hash_key}')
        try:
            # Redis returns bytes, hence the bytes literal for the type tag.
            if data_type == b'pd.DataFrame':
                value = pd.read_parquet(io.BytesIO(serialized_value))
            else:
                value = json.loads(serialized_value)
        except Exception as e:
            print(e)
            print(f'ERROR LOADING {data_type} - {hash_key}')
            raise
        print(f'LOADED {data_type} - {hash_key}')
        print(value)
        return value
import dash
from dash.dependencies import Input, Output, State, MATCH, ALL
import dash_table
import uuid
import dash_html_components as html
import dash_core_components as dcc
import composite
# Dash application instance; the composite component's callback below is
# registered on it.
app = dash.Dash(__name__)
_operators = [
['ge ', '>='],
['le ', '<='],
['lt ', '<'],
['gt ', '>'],
['ne ', '!='],
['eq ', '='],
['contains '],
['datestartswith ']]
def _split_filter_part(filter_part):
for operator_type in _operators:
for operator in operator_type:
if operator in filter_part:
name_part, value_part = filter_part.split(operator, 1)
name = name_part[name_part.find('{') + 1: name_part.rfind('}')]
value_part = value_part.strip()
v0 = value_part[0]
if (v0 == value_part[-1] and v0 in ("'", '"', '`')):
value = value_part[1: -1].replace('\\' + v0, v0)
else:
try:
value = float(value_part)
except ValueError:
value = value_part
# word _operators need spaces after them in the filter string,
# but we don't want these later
return name, operator_type[0].strip(), value
return [None] * 3
# DataTableComposite
# DataTableBackend
# DataTableServerSide
class DataTableComposite(html.Div):
    """Composite component: a `dash_table.DataTable` filtered, sorted and
    paged in Python.

    The dataframe is saved with `composite.composite_store` and only its
    key is kept in a `dcc.Store`. A pattern-matching callback reloads the
    dataframe, applies the table's `filter_query`, `sort_by`,
    `page_current` and `page_size`, and returns that page as records.
    """
    # TODO - Add module name too?
    _component_name = __qualname__
    ids = composite.create_id_class(_component_name, ['datatable', 'store'])

    def __init__(self, df=None, composite_id=None, **data_table_kwargs):
        """
        Params:
        - df: dataframe to display. Mutually exclusive with `data=`.
        - composite_id: value used as the `instance` part of the
          sub-component IDs; a random UUID string when omitted.
        - data_table_kwargs: extra keyword arguments passed on to
          `dash_table.DataTable`. `id` is not accepted. `data` (a list of
          dictionaries) may be given instead of `df`.

        Raises:
        - Exception: `id` was passed, both `df` and `data` were passed, or
          neither was passed.
        """
        if composite_id is None:
            composite_id = str(uuid.uuid4())
        if 'id' in data_table_kwargs:
            raise composite.IDNotAllowedException(self._component_name, self.ids)

        # Validate the data arguments before touching `df`, so that the
        # `data=` path and the "no data" error are actually reachable.
        has_data = 'data' in data_table_kwargs
        if df is not None and has_data:
            raise Exception('The `df` argument cannot be supplied with the data argument - it\'s ambiguous.')
        if df is None and not has_data:
            raise Exception('No data supplied. Pass in a dataframe as `df=` or a list of dictionaries as `data=`')

        # TODO - Check data length and skip callback
        # if data is <5000 rows?
        derived_kwargs = data_table_kwargs.copy()
        if df is None:
            stored_df = pd.DataFrame(data_table_kwargs['data'])
        else:
            stored_df, columns = self._prepare_df(df)
            if 'columns' not in data_table_kwargs:
                derived_kwargs['columns'] = columns
        store_data = {'df': composite.composite_store.save(stored_df)}

        super().__init__([
            dcc.Store(
                data=store_data,
                id=self.ids.store(composite_id)
            ),
            dash_table.DataTable(
                id=self.ids.datatable(composite_id),
                page_current=0,
                page_size=10,
                page_action='custom',
                filter_action='custom',
                filter_query='',
                sort_action='custom',
                sort_mode='multi',
                sort_by=[],
                **derived_kwargs
            )
        ])

    @staticmethod
    def _prepare_df(df):
        """Return `(df_copy, columns)` for a dataframe about to be stored.

        Mixed-type columns, and columns that are neither numeric, string
        nor datetime, are cast to strings in the copy. `columns` is the
        list of DataTable column definitions derived from the dtypes. The
        caller's dataframe is left untouched.
        """
        df = df.copy()
        columns = []
        columns_cast_to_string = []
        for c in df.columns:
            column = {'name': c, 'id': c}
            dtype = pd.api.types.infer_dtype(df[c])
            if dtype.startswith('mixed'):
                columns_cast_to_string.append(c)
                df[c] = df[c].astype(str)

            if pd.api.types.is_numeric_dtype(df[c]):
                column['type'] = 'numeric'
            elif pd.api.types.is_string_dtype(df[c]):
                column['type'] = 'text'
            elif pd.api.types.is_datetime64_any_dtype(df[c]):
                column['type'] = 'datetime'
            else:
                columns_cast_to_string.append(c)
                df[c] = df[c].astype(str)
                column['type'] = 'text'
            columns.append(column)

        if columns_cast_to_string:
            print(f'Warning: Converted the following mixed-type columns to strings so that they can be saved in Redis or JSON: {", ".join(columns_cast_to_string)}')
        return df, columns

    @staticmethod
    def filter_df(df, filter):
        """Return the rows of `df` matching the DataTable `filter` query.

        Clauses are separated by ` && ` and applied one after another. An
        empty or None query returns `df` unchanged.
        """
        if not filter:
            return df
        filtering_expressions = filter.split(' && ')
        # TODO - Case insensitive filtering
        for filter_part in filtering_expressions:
            col_name, operator, filter_value = _split_filter_part(filter_part)

            if operator in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
                # these _operators match pandas series operator method names
                df = df.loc[getattr(df[col_name], operator)(filter_value)]
            elif operator == 'contains':
                df = df.loc[df[col_name].str.contains(filter_value)]
            elif operator == 'datestartswith':
                # this is a simplification of the front-end filtering logic,
                # only works with complete fields in standard format
                df = df.loc[df[col_name].str.startswith(filter_value)]
        return df

    @staticmethod
    def sort_df(df, sort_by):
        """Return `df` sorted according to the DataTable `sort_by` list.

        Each entry provides `column_id` and `direction` (`'asc'` sorts
        ascending, anything else descending). An empty `sort_by` returns
        `df` unchanged.
        """
        if sort_by:
            df = df.sort_values(
                [col['column_id'] for col in sort_by],
                ascending=[
                    col['direction'] == 'asc'
                    for col in sort_by
                ],
                inplace=False
            )
        return df

    @staticmethod
    def page_df(df, page_current, page_size):
        """Return the rows of page `page_current` (zero-based)."""
        return df.iloc[page_current * page_size: (page_current + 1) * page_size]

    @app.callback(
        Output(ids.datatable(MATCH), 'data'),
        Input(ids.datatable(MATCH), 'page_current'),
        Input(ids.datatable(MATCH), 'page_size'),
        Input(ids.datatable(MATCH), 'sort_by'),
        Input(ids.datatable(MATCH), 'filter_query'),
        State(ids.store(MATCH), 'data')
    )
    def filter_sort_page(page_current, page_size, sort_by, filter, store):
        # Registered once for every instance through the MATCH wildcard.
        df = composite.composite_store.load(store['df'])
        df = DataTableComposite.filter_df(df, filter)
        df = DataTableComposite.sort_df(df, sort_by)
        df = DataTableComposite.page_df(df, page_current, page_size)
        return df.to_dict('records')
# Demo: a 100,000-row table with a summary panel that follows the table's
# filter.
if __name__ == '__main__':
    import plotly.express as px
    import numpy as np
    import datetime as dt
    import pandas as pd

    N = 100 * 1000

    df = pd.DataFrame({
        'numeric': np.random.randn(N),
        'date': [
            (dt.datetime.now() + dt.timedelta(days=int(100 * np.random.randn(1)[0]))).replace(microsecond=0)
            for _ in range(N)
        ],
        'string': ['red', 'green', 'yellow', 'blue', 'orange'] * int(N / 5),
        'mixed': ['rain', 'sun', 10, 3.1, dt.datetime.now().replace(microsecond=0)] * int(N / 5)
    })

    app.layout = html.Div([
        DataTableComposite(
            df,
            composite_id='my-table'
        ),
        html.Div(id='my-graphs')
    ])

    # Recompute data
    @app.callback(
        Output('my-graphs', 'children'),
        Input(DataTableComposite.ids.datatable('my-table'), 'filter_query'),
        State(DataTableComposite.ids.store('my-table'), 'data')
    )
    def update_graph(filter_query, store):
        # One parameter per Input/State declared above.
        # not strictly necessary to load df from redis in this example since `df` is
        # defined in global scope, but this won't be the case in dynamic examples
        # where `DataTableComposite` is returned from a callback
        dff = composite.composite_store.load(store['df'])
        # Filter the loaded dataframe, not the global one.
        dff = DataTableComposite.filter_df(dff, filter_query)

        return html.Div(
            [html.Div([
                html.B('describe'),
                html.Pre(str(dff.describe()))
            ], style={'display': 'inline-block'})] +
            [html.Div([html.B(c), html.Pre(
                str(dff[c].value_counts(sort=True))
            )], style={'display': 'inline-block'}) for c in dff.columns]
        )

    app.run_server(debug=True)
import dash
from dash.dependencies import Input, Output, State, MATCH, ALL
import dash_table
import uuid
import dash_html_components as html
import dash_core_components as dcc
import composite
# Dash application instance; the composite component's callback below is
# registered on it.
app = dash.Dash(__name__)
class TabbedGraphComposite(html.Div):
    """Composite component showing a figure and its data in two tabs.

    With `storage='redis'` the figure, dataframe and keyword arguments are
    saved with `composite.composite_store`, and the selected tab's content
    is rendered by a pattern-matching callback. With any other `storage`
    value the graph and the table are placed directly inside the tabs.
    """
    _component_name = __qualname__
    ids = composite.create_id_class(_component_name, ['table', 'tabs', 'graph', 'store', 'tabs_content'])

    def __init__(self,
                 figure=None,
                 df=None,
                 composite_id=None,
                 table_kwargs=None,
                 tabs_kwargs=None,
                 tab_kwargs=None,
                 graph_kwargs=None,
                 storage='client'):
        """
        Params:
        - figure: figure passed to `dcc.Graph`.
        - df: dataframe shown in the "Data" tab.
        - composite_id: value used as the `instance` part of the
          sub-component IDs; a random UUID string when omitted.
        - table_kwargs, tabs_kwargs, graph_kwargs: extra keyword arguments
          for the `DataTable`, `dcc.Tabs` and `dcc.Graph` sub-components.
          None of them may contain `id`.
        - tab_kwargs: accepted but not used by this implementation.
        - storage: `'redis'` to render tab content from the store; any
          other value renders it inline.

        Raises:
        - Exception: `id` was found in one of the `*_kwargs` dictionaries.
        """
        if composite_id is None:
            composite_id = str(uuid.uuid4())

        if any(
            kwargs and 'id' in kwargs
            for kwargs in (tabs_kwargs, graph_kwargs, table_kwargs)
        ):
            # `__qualname__` only exists in the class body, so use the name
            # captured there.
            raise composite.IDNotAllowedException(self._component_name, self.ids)

        if storage == 'redis':
            # Store data in Redis so that it can be retrieved later on the fly
            # when displaying the tab or downloading the data.
            # Use the data's hash as the unique key -
            # This will prevent multiple processes from writing duplicate data.
            store_data = {}
            for (name, obj) in [
                ('figure', figure),
                ('df', df), ('graph_kwargs', graph_kwargs),
                ('table_kwargs', table_kwargs),
                ('composite_id', composite_id),
            ]:
                store_data[name] = composite.composite_store.save(obj)
            store = dcc.Store(
                id=self.ids.store(composite_id),
                data=store_data
            )
            # Tab content is filled in by `display_tab`.
            graph_tab_children = None
            data_tab_children = None
            tabs_content = html.Div(id=self.ids.tabs_content(composite_id))
        else:
            store = None
            tabs_content = None
            graph_tab_children = dcc.Graph(
                figure=figure,
                id=self.ids.graph(composite_id),
                **composite.omit(['figure', 'id'], graph_kwargs, 'graph_kwargs', self._component_name, self.ids)
            )
            data_tab_children = html.Div([
                dcc.Download(),
                dash_table.DataTable(
                    columns=[{'name': i, 'id': i} for i in df.columns],
                    # 'records' spelled out: the 'r' shorthand is rejected
                    # by current pandas.
                    data=df.to_dict('records'),
                    **composite.omit(['id'], table_kwargs, 'table_kwargs', self._component_name, self.ids)
                )
            ])

        super().__init__(
            children=[
                store,
                dcc.Tabs(value='graph', id=self.ids.tabs(composite_id), children=[
                    dcc.Tab(
                        label='Graph',
                        value='graph',
                        children=graph_tab_children
                    ),
                    dcc.Tab(
                        label='Data',
                        value='data',
                        children=data_tab_children
                    )
                ], **composite.omit(['children', 'id'], tabs_kwargs, 'tabs_kwargs', self._component_name, self.ids)),
                tabs_content
            ],
        )

    @app.callback(
        Output(ids.tabs_content(MATCH), 'children'),
        Input(ids.tabs(MATCH), 'value'),
        State(ids.store(MATCH), 'data'),
    )
    def display_tab(tab, store):
        # Registered once for every instance through the MATCH wildcard.
        composite_id = composite.composite_store.load(store['composite_id'])
        if tab == 'graph':
            figure = composite.composite_store.load(store['figure'])
            graph_kwargs = composite.composite_store.load(store['graph_kwargs'])
            return html.Div([
                'Graph',
                dcc.Graph(
                    figure=figure,
                    # id=TabbedGraphComposite.ids.graph(composite_id),
                    **composite.omit(['figure', 'id'], graph_kwargs, 'graph_kwargs', TabbedGraphComposite._component_name, TabbedGraphComposite.ids)
                )
            ])
        else:
            df = composite.composite_store.load(store['df'])
            table_kwargs = composite.composite_store.load(store['table_kwargs'])
            return html.Div([
                'DataTable',
                dash_table.DataTable(
                    # id=TabbedGraphComposite.ids.table(composite_id),
                    columns=[{'name': i, 'id': i} for i in df.columns],
                    data=df.to_dict('records'),
                    **composite.omit(['id'], table_kwargs, 'table_kwargs', TabbedGraphComposite._component_name, TabbedGraphComposite.ids)
                )
            ])
# Demo: the iris dataset as a scatter plot with its data, served from Redis.
if __name__ == '__main__':
    import json
    import pandas as pd
    import plotly.express as px

    df = px.data.iris()
    x_column, y_column = df.columns[0], df.columns[1]
    scatter_figure = px.scatter(df, x=x_column, y=y_column)
    tabbed_graph = TabbedGraphComposite(
        figure=scatter_figure,
        df=df,
        storage='redis'
    )

    app.layout = html.Div([
        tabbed_graph,

        # DataTableComposite(df),
        #
        # DataTableComposite(df, composite_id='custom'),
        # html.Div(id='custom-output')
    ])

    # @app.callback(
    #     Output('custom-output', 'children'),
    #     Input(DataTableComposite.ids.datatable('custom'), 'data')
    # )
    # def update(data):
    #     return html.Pre(json.dumps(data, indent=2))

    app.run_server(debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment