Last active
September 22, 2020 07:29
-
-
Save cbur24/9a52c14698b6a9324a62f5449972cf7f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xarray as xr | |
import numpy as np | |
import joblib | |
import dask.array as da | |
from datacube.utils.geometry import assign_crs | |
from dask_ml.wrappers import ParallelPostFit | |
def predict_xr(model,
               input_xr,
               chunk_size=None,
               persist=True,
               proba=False,
               clean=False,
               return_input=False):
    """
    Using dask-ml ParallelPostFit(), runs the parallel
    predict and predict_proba methods of sklearn
    estimators. Useful for running predictions
    on larger-than-RAM datasets.

    Last modified: September 2020

    Parameters
    ----------
    model : scikit-learn model or compatible object
        Must have a .predict() method that takes numpy arrays, and a
        .predict_proba() method if proba=True.
    input_xr : xarray.Dataset
        Each data variable is used as one feature. Every variable must
        have the dimensions 'y' and 'x' (in either order); any other
        dimensions must have length 1. The dataset must also expose a
        `.geobox` attribute, from which the output CRS is read.
    chunk_size : int
        The dask chunk size to use on the flattened array. If this is
        left as None, it is inferred as the product of the first 'x'
        and 'y' chunk sizes in `input_xr.chunks`. Inference requires
        `input_xr` to be dask-chunked along 'x' and 'y'.
    persist : bool
        If True, and proba=True, then the flattened 'input_xr' data will
        be persisted into (distributed) memory. This ensures the data
        is not loaded twice for the prediction of probabilities, but it
        will only work if the data is not larger than RAM.
    proba : bool
        If True, also predict probabilities. For each pixel, the highest
        class probability, expressed as a percentage, is stored in the
        'Probabilities' variable.
    clean : bool
        If True, replace Infs and NaNs with 0 in the input features and
        in the numeric output arrays. Non-numeric outputs (e.g. string
        class labels) are left unchanged.
    return_input : bool
        If True, then the data variables in the 'input_xr' dataset
        (cleaned, if clean=True) will be appended to the output xarray
        dataset. An input variable named 'Predictions' or 'Probabilities'
        does not overwrite the model outputs.

    Returns
    ----------
    output_xr : xarray.Dataset
        An xarray.Dataset containing the prediction output from model
        with input_xr as input ('Predictions'). If proba=True, the
        dataset also contains the prediction probabilities
        ('Probabilities'). All variables share the 'y' and 'x'
        coordinates of input_xr and are backed by dask arrays, so call
        .compute() to evaluate them.

    Raises
    ------
    ValueError
        If input_xr has no data variables; if chunk_size is None and
        input_xr is not dask-chunked along 'x' and 'y'; or if a data
        variable does not hold exactly one value per (y, x) pixel.
    """
    feature_names = list(input_xr.data_vars)
    if not feature_names:
        raise ValueError("input_xr has no data variables to use as features")

    if chunk_size is None:
        chunks = input_xr.chunks
        if 'x' not in chunks or 'y' not in chunks:
            raise ValueError(
                "chunk_size could not be inferred because input_xr is not "
                "dask-chunked along 'x' and 'y'; pass chunk_size explicitly")
        chunk_size = int(chunks['x'][0]) * int(chunks['y'][0])

    # convert model to dask predict
    parallel_model = ParallelPostFit(model)

    with joblib.parallel_backend('dask'):
        x, y, crs = input_xr.x, input_xr.y, input_xr.geobox.crs
        n_pixels = len(y) * len(x)

        columns = []
        for name in feature_names:
            band = input_xr[name]
            if band.size != n_pixels:
                raise ValueError(
                    "data variable {!r} has dimensions {}; expected 'y' and "
                    "'x' plus only length-1 dimensions".format(name, band.dims))
            # Move 'y' and 'x' to the end so the row-major flatten matches
            # the reshape(len(y), len(x)) used to rebuild every output.
            # Length-1 dimensions stay in front and vanish in the flatten.
            # da.asarray also lets numpy-backed data be rechunked.
            band = band.transpose(..., 'y', 'x')
            columns.append(da.asarray(band.data).ravel().rechunk(chunk_size))

        # one row per pixel, one column per feature (same order as feature_names)
        input_data_flattened = da.stack(columns, axis=1)

        if clean:
            input_data_flattened = _mask_non_finite(input_data_flattened)

        if proba and persist:
            # persisting data so we don't require loading all the data twice
            input_data_flattened = input_data_flattened.persist()

        # apply the classification
        print(' predicting...')
        out_class = parallel_model.predict(input_data_flattened)
        # Mask out NaN or Inf values in results
        if clean:
            out_class = _mask_non_finite(out_class)
        output_xr = _to_yx_dataarray(out_class, x, y).to_dataset(name='Predictions')

        if proba:
            print(" probabilities...")
            out_proba = parallel_model.predict_proba(input_data_flattened)
            # keep only the highest class probability, converted to %
            out_proba = da.max(out_proba, axis=1) * 100.0
            if clean:
                out_proba = _mask_non_finite(out_proba)
            output_xr['Probabilities'] = _to_yx_dataarray(out_proba, x, y)

        if return_input:
            print(" input features...")
            # Column i of the flattened array is feature_names[i]. Rebuild each
            # column with the same (y, x) mapping as the predictions so that
            # values stay attached to the correct pixel coordinates.
            output_features = xr.Dataset({
                name: _to_yx_dataarray(input_data_flattened[:, i], x, y)
                for i, name in enumerate(feature_names)
            })
            # compat='override' keeps the model outputs if an input variable
            # shares their name
            output_xr = xr.merge([output_xr, output_features], compat='override')

        return assign_crs(output_xr, str(crs))


def _mask_non_finite(arr):
    """Return the dask array `arr` with NaN/Inf replaced by 0; non-numeric arrays are returned unchanged."""
    # isfinite is undefined for string/object dtypes (e.g. string class
    # labels), so only boolean, integer, float and complex arrays are cleaned.
    if arr.dtype.kind not in 'biufc':
        return arr
    return da.where(da.isfinite(arr), arr, 0)


def _to_yx_dataarray(flat, x, y):
    """Reshape a flattened, (y, x)-ordered pixel array into a ('y', 'x') DataArray."""
    return xr.DataArray(flat.reshape(len(y), len(x)),
                        coords={"x": x, "y": y},
                        dims=["y", "x"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
To run this function (the returned Dataset is backed by lazy dask arrays, so call `.compute()` to evaluate it):
predicted = predict_xr(model, features, proba=True, persist=True, clean=True).compute()