@cbur24
Last active September 22, 2020 07:29
import xarray as xr
import numpy as np
import joblib
import dask.array as da
from datacube.utils.geometry import assign_crs
from dask_ml.wrappers import ParallelPostFit


def predict_xr(model,
               input_xr,
               chunk_size=None,
               persist=True,
               proba=False,
               clean=False,
               return_input=False):
    """
    Using dask-ml ParallelPostFit(), runs the parallel
    predict and predict_proba methods of sklearn
    estimators. Useful for running predictions
    on larger-than-RAM datasets.

    Last modified: September 2020

    Parameters
    ----------
    model : scikit-learn model or compatible object
        Must have a .predict() method that takes numpy arrays.
    input_xr : xarray.Dataset
        Must have dimensions 'x' and 'y'. (The function iterates
        over `data_vars`, so a Dataset rather than a bare
        DataArray is required.)
    chunk_size : int
        The dask chunk size to use on the flattened array. If this
        is left as None, the chunk size is inferred from the
        .chunks attribute on `input_xr`.
    persist : bool
        If True, and proba=True, then `input_xr` data will be
        loaded into distributed memory. This ensures the data
        is not loaded twice for the prediction of probabilities,
        but it will only work if the data is not larger than RAM.
    proba : bool
        If True, predict probabilities.
    clean : bool
        If True, remove Infs and NaNs from the input and output arrays.
    return_input : bool
        If True, the data variables in the `input_xr` dataset will
        be appended to the output xarray dataset.

    Returns
    -------
    output_xr : xarray.Dataset
        An xarray.Dataset containing the prediction output from model
        with input_xr as input. If proba=True, the dataset will also
        contain the prediction probabilities. Has the same
        spatiotemporal structure as input_xr.
    """
    if chunk_size is None:
        chunk_size = int(input_xr.chunks['x'][0]) * int(input_xr.chunks['y'][0])

    # wrap the model so predict/predict_proba run lazily over dask blocks
    model = ParallelPostFit(model)

    with joblib.parallel_backend('dask'):
        x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs

        input_data = []
        for var_name in input_xr.data_vars:
            input_data.append(input_xr[var_name])

        # flatten each 2D variable into a 1D dask array
        input_data_flattened = []
        for data in input_data:
            data = data.data.flatten().rechunk(chunk_size)
            input_data_flattened.append(data)

        # reshape for prediction: (n_pixels, n_features)
        input_data_flattened = da.array(input_data_flattened).transpose()

        if clean:
            input_data_flattened = da.where(da.isfinite(input_data_flattened),
                                            input_data_flattened, 0)

        if proba and persist:
            # persist the data so we don't load it twice
            # (once for predict, once for predict_proba)
            input_data_flattened = input_data_flattened.persist()

        # apply the classification
        print('   predicting...')
        out_class = model.predict(input_data_flattened)

        # mask out NaN or Inf values in the results
        if clean:
            out_class = da.where(da.isfinite(out_class), out_class, 0)

        # reshape back onto the original 2D (y, x) grid
        out_class = out_class.reshape(len(y), len(x))

        # stack back into xarray
        output_xr = xr.DataArray(out_class,
                                 coords={"x": x, "y": y},
                                 dims=["y", "x"])
        output_xr = output_xr.to_dataset(name='Predictions')

        if proba:
            print("   probabilities...")
            out_proba = model.predict_proba(input_data_flattened)

            # convert the highest class probability to a percentage
            out_proba = da.max(out_proba, axis=1) * 100.0

            if clean:
                out_proba = da.where(da.isfinite(out_proba), out_proba, 0)

            out_proba = out_proba.reshape(len(y), len(x))
            out_proba = xr.DataArray(out_proba,
                                     coords={"x": x, "y": y},
                                     dims=["y", "x"])
            output_xr['Probabilities'] = out_proba

        if return_input:
            print("   input features...")
            # unflatten the input_data_flattened array and append
            # it to the output_xr containing the predictions
            arr = input_xr.to_array()
            stacked = arr.stack(z=['x', 'y'])

            # handle multivariable output
            output_px_shape = ()
            if len(input_data_flattened.shape[1:]):
                output_px_shape = input_data_flattened.shape[1:]

            output_features = input_data_flattened.reshape(
                (len(stacked.z), *output_px_shape))

            # set the stacked coordinate to match the input
            output_features = xr.DataArray(
                output_features,
                coords={'z': stacked['z']},
                dims=['z', *['output_dim_' + str(idx)
                             for idx in range(len(output_px_shape))]]).unstack()

            # convert to dataset and rename the arrays
            output_features = output_features.to_dataset(dim='output_dim_0')
            data_vars = list(input_xr.data_vars)
            output_features = output_features.rename(
                {i: j for i, j in zip(output_features.data_vars, data_vars)})

            # merge the input features with the predictions
            output_xr = xr.merge([output_xr, output_features], compat='override')

    return assign_crs(output_xr, str(crs))
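
The core pattern above is estimator-agnostic: each (y, x) variable is flattened to 1D, the variables are stacked and transposed into the (n_pixels, n_features) matrix that sklearn expects, and the 1D predictions are reshaped back onto the grid. A minimal round-trip sketch of that pattern with synthetic numpy data (all names here are illustrative):

import numpy as np

ny, nx = 4, 5                               # toy grid size
band1 = np.random.rand(ny, nx)              # two fake feature layers
band2 = np.random.rand(ny, nx)

# flatten each layer, then stack and transpose -> (n_pixels, n_features)
X = np.array([band1.flatten(), band2.flatten()]).transpose()
assert X.shape == (ny * nx, 2)

# a trivial per-pixel 'prediction', reshaped back onto the (y, x) grid
pred = (X.sum(axis=1) > 1.0).astype(int)
grid = pred.reshape(ny, nx)
assert grid.shape == (ny, nx)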
cbur24 commented Sep 17, 2020

To run this function:
predicted = predict_xr(model, features, proba=True, persist=True, clean=True).compute()
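
For context, `model` can be any fitted sklearn-style classifier, and `features` must be a dask-backed xarray.Dataset with 'x' and 'y' dimensions and a geobox, e.g. loaded through datacube with dask_chunks. A sketch of that setup, assuming hypothetical training arrays X_train/y_train and a placeholder product name:

import datacube
from sklearn.ensemble import RandomForestClassifier

# X_train (n_samples, n_features) and y_train (n_samples,) are assumed
# to exist already, e.g. extracted from labelled training sites
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

# load the features lazily (dask-chunked) so predict_xr can stream them;
# the product name below is a placeholder
dc = datacube.Datacube(app='prediction')
features = dc.load(product='ls8_nbart_geomedian_annual',
                   time='2018',
                   dask_chunks={'x': 2000, 'y': 2000})
features = features.squeeze('time')  # predict_xr expects 2D (y, x) variables

predicted = predict_xr(model, features, proba=True, persist=True, clean=True).compute()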
