Skip to content

Instantly share code, notes, and snippets.

@aaronspring
Last active April 27, 2024 19:05
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save aaronspring/36e112e992e36fba935f73404dbbd3cd to your computer and use it in GitHub Desktop.
Save aaronspring/36e112e992e36fba935f73404dbbd3cd to your computer and use it in GitHub Desktop.
vectorized `sklearn` with `xarray`
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vectorized `sklearn` with `xarray`\n",
"\n",
"Run a `sklearn` classifier independently at every point of a grid (longitude/X, latitude/Y, lead_time, ...) in a single `xr.apply_ufunc` call.\n",
"\n",
"This may be slow because `vectorize=True` loops over the grid points in Python, but the code is short.\n",
"\n",
"Inspired by and based on https://renkulab.io/gitlab/lluis.palma/s2s-ai-challenge-bsc/-/blob/submission-ML_models/notebooks/S2S_ML_models.ipynb\n",
"\n",
"This also answers https://discourse.pangeo.io/t/vectorized-sklearn/1444"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"import xarray as xr\n",
"xr.set_options(display_style='text')\n",
"\n",
"import numpy as np\n",
"\n",
"from sklearn.base import clone\n",
"from sklearn.linear_model import LogisticRegression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>&lt;xarray.Dataset&gt;\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n",
" tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n",
" msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536</pre>"
],
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 20)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2000 2001 2002 2003 2004 ... 2016 2017 2018 2019\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 0.885 0.61 ... 0.1928\n",
" tp (lead_time, year, week, X, Y) float64 0.0597 0.7052 ... 0.3623\n",
" msl (lead_time, year, week, X, Y) float64 0.5728 0.8126 ... 0.2536"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# synthetic data: competition on a 5x5 grid\n",
"# raw forecasts; fixed seed so Restart & Run All reproduces the same data\n",
"rng = np.random.default_rng(seed=0)\n",
"X_train = xr.DataArray(\n",
"    rng.random((2, 20, 53, 5, 5, 3)),\n",
"    dims=['lead_time', 'year', 'week', 'X', 'Y', 'variable'],\n",
"    coords={'lead_time': [1, 2], 'year': range(2000, 2020), 'week': range(53),\n",
"            'X': range(5), 'Y': range(5), 'variable': ['t2m', 'tp', 'msl']},\n",
").to_dataset(dim='variable')\n",
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>&lt;xarray.Dataset&gt;\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2018 2019\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n",
" tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n",
" msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536</pre>"
],
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 2)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2018 2019\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 0.8516 0.4321 ... 0.1928\n",
" tp (lead_time, year, week, X, Y) float64 0.9754 0.6478 ... 0.3623\n",
" msl (lead_time, year, week, X, Y) float64 0.9741 0.05569 ... 0.2536"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hold out the last two years for testing and train on the remaining years\n",
"# NOTE: this overwrites X_train, so re-running the cell without re-running\n",
"# the data cell splits off two more years\n",
"X_test, X_train = X_train.isel(year=[-2, -1]), X_train.isel(year=slice(None, -2))\n",
"X_test"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>&lt;xarray.Dataset&gt;\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n",
" tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n",
" msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0</pre>"
],
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (X: 5, Y: 5, lead_time: 2, week: 53, year: 18)\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * year (year) int64 2000 2001 2002 2003 2004 ... 2014 2015 2016 2017\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
"Data variables:\n",
" t2m (lead_time, year, week, X, Y) float64 2.0 1.0 0.0 ... 0.0 1.0 2.0\n",
" tp (lead_time, year, week, X, Y) float64 0.0 2.0 0.0 ... 2.0 0.0 2.0\n",
" msl (lead_time, year, week, X, Y) float64 1.0 2.0 1.0 ... 2.0 1.0 1.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# categorized observations: [0, 1/3) -> 0, [1/3, 2/3) -> 1, [2/3, inf) -> 2\n",
"# each comparison adds 1 once its threshold is reached, so values lying exactly\n",
"# on a threshold go to the upper bin instead of falling through to 0\n",
"y_train = 1.0 * (X_train >= 1/3) + 1.0 * (X_train >= 2/3)\n",
"y_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## config"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"sample_dims = ['year', 'week']  # dimensions used as samples\n",
"features = ['t2m', 'tp', 'msl']  # variables used as features\n",
"target_var = 't2m'  # var to predict\n",
"\n",
"# sklearn method; multi_class is left at its default ('auto') because passing\n",
"# it explicitly is deprecated in recent scikit-learn releases\n",
"# NOTE(review): liblinear handles the 3 categories one-vs-rest; newer\n",
"# scikit-learn may require wrapping in OneVsRestClassifier for this -- confirm\n",
"clf = LogisticRegression(penalty='l2',\n",
"                         solver='liblinear',\n",
"                         random_state=0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## train"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def atomic_function_training_LR(X_train, y_train, clf):\n",
"    \"\"\"Fit a fresh copy of `clf` on all samples of one grid point.\n",
"\n",
"    X_train : array (samples..., features); all leading dims are flattened\n",
"        into a single sample axis\n",
"    y_train : array (samples...,) matching the leading dims of X_train\n",
"    clf : sklearn estimator used as a template; it is cloned, so the object\n",
"        passed in is never fitted itself\n",
"\n",
"    Returns the fitted clone, or None if fitting raises (e.g. only one class\n",
"    is present at this grid point).\n",
"    \"\"\"\n",
"    feature_size = X_train.shape[-1]\n",
"    # sklearn expects (n_samples, n_features)\n",
"    X_flat = X_train.reshape(-1, feature_size)\n",
"    y_flat = y_train.reshape(X_flat.shape[0])\n",
"    try:\n",
"        # clone: np.vectorize hands the same clf object to every grid point and\n",
"        # fit() returns self, so fitting in place would make every output entry\n",
"        # the same model, trained on whichever grid point came last\n",
"        return clone(clf).fit(X_flat, y_flat)\n",
"    except Exception:  # best effort: grid points that cannot be fitted get None\n",
"        return None"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 93 ms, sys: 2.35 ms, total: 95.4 ms\n",
"Wall time: 103 ms\n"
]
}
],
"source": [
"%%time\n",
"# stack the feature variables into a trailing 'variable' dim, so each grid point\n",
"# hands the training function a (sample_dims..., variable) block\n",
"X_train_stacked = X_train[features].to_array().transpose(..., 'variable')\n",
"y_train_target = y_train[target_var]\n",
"\n",
"all_classifiers = xr.apply_ufunc(\n",
"    atomic_function_training_LR,\n",
"    X_train_stacked,\n",
"    y_train_target,\n",
"    clf,\n",
"    input_core_dims=[\n",
"        sample_dims + ['variable'],  # features: samples + variable\n",
"        sample_dims,  # target: samples only\n",
"        [],  # estimator: broadcast to every grid point\n",
"    ],\n",
"    vectorize=True,\n",
"    dask='parallelized',\n",
"    output_dtypes=[object],\n",
").compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## predict"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.54 s, sys: 29.9 ms, total: 1.57 s\n",
"Wall time: 1.68 s\n"
]
},
{
"data": {
"text/html": [
"<pre>&lt;xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)&gt;\n",
"array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n",
" [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n",
" [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n",
" ...,\n",
" [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n",
" [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n",
" [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n",
"\n",
" [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n",
" [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n",
" [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n",
" ...,\n",
" [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n",
" [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n",
" [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n",
"\n",
" [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n",
" [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n",
" [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n",
" ...,\n",
"...\n",
" ...,\n",
" [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n",
" [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n",
" [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n",
"\n",
" [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n",
" [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n",
" [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n",
" ...,\n",
" [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n",
" [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n",
" [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n",
"\n",
" [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n",
" [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n",
" [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n",
" ...,\n",
" [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n",
" [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n",
" [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * category (category) float64 0.0 1.0 2.0</pre>"
],
"text/plain": [
"<xarray.DataArray (lead_time: 2, X: 5, Y: 5, week: 53, category: 3)>\n",
"array([[[[[2.84430615e-03, 2.48019553e-01, 7.49136141e-01],\n",
" [1.59098678e-02, 4.24464371e-01, 5.59625761e-01],\n",
" [6.43132743e-01, 3.47622040e-01, 9.24521704e-03],\n",
" ...,\n",
" [2.88347752e-02, 2.97344648e-01, 6.73820577e-01],\n",
" [6.99026288e-01, 2.92805754e-01, 8.16795839e-03],\n",
" [6.84561350e-01, 3.09447865e-01, 5.99078411e-03]],\n",
"\n",
" [[3.14927916e-01, 5.73638100e-01, 1.11433984e-01],\n",
" [1.82388659e-03, 2.94081405e-01, 7.04094708e-01],\n",
" [3.03359651e-01, 5.03803615e-01, 1.92836733e-01],\n",
" ...,\n",
" [1.25277388e-03, 2.33315838e-01, 7.65431388e-01],\n",
" [9.91532734e-02, 5.16850748e-01, 3.83995978e-01],\n",
" [7.18120246e-01, 2.78537920e-01, 3.34183389e-03]],\n",
"\n",
" [[1.22389131e-02, 2.89980880e-01, 6.97780207e-01],\n",
" [7.48693273e-01, 2.49804366e-01, 1.50236071e-03],\n",
" [7.40388030e-01, 2.56002389e-01, 3.60958146e-03],\n",
" ...,\n",
"...\n",
" ...,\n",
" [1.59798246e-03, 2.40037927e-01, 7.58364091e-01],\n",
" [5.78522467e-03, 2.42150429e-01, 7.52064346e-01],\n",
" [1.93062439e-01, 4.20348759e-01, 3.86588802e-01]],\n",
"\n",
" [[4.88716204e-03, 2.38253500e-01, 7.56859338e-01],\n",
" [3.58386812e-01, 5.16826179e-01, 1.24787008e-01],\n",
" [6.64768310e-01, 3.30687396e-01, 4.54429314e-03],\n",
" ...,\n",
" [5.13090849e-02, 4.65301898e-01, 4.83389018e-01],\n",
" [6.93597202e-01, 3.05439592e-01, 9.63206169e-04],\n",
" [6.90890885e-01, 3.07015807e-01, 2.09330822e-03]],\n",
"\n",
" [[8.97167636e-03, 3.26590274e-01, 6.64438049e-01],\n",
" [7.34400789e-01, 2.62958955e-01, 2.64025675e-03],\n",
" [1.81597565e-03, 2.58011987e-01, 7.40172038e-01],\n",
" ...,\n",
" [4.70693953e-01, 4.48879066e-01, 8.04269809e-02],\n",
" [3.60043640e-03, 2.39024250e-01, 7.57375314e-01],\n",
" [6.55709623e-01, 3.37543667e-01, 6.74671012e-03]]]]])\n",
"Coordinates:\n",
" * lead_time (lead_time) int64 1 2\n",
" * X (X) int64 0 1 2 3 4\n",
" * Y (Y) int64 0 1 2 3 4\n",
" * week (week) int64 0 1 2 3 4 5 6 7 8 9 ... 44 45 46 47 48 49 50 51 52\n",
" * category (category) float64 0.0 1.0 2.0"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"def atomic_function_prediction_lr(classifiers, X_test, categories=None):\n",
"    \"\"\"Predict category probabilities for every sample of one grid point.\n",
"\n",
"    classifiers : fitted sklearn classifier for this grid point, or None if\n",
"        training failed there\n",
"    X_test : array (samples..., features); all leading dims are flattened\n",
"        into a single sample axis\n",
"    categories : full, ordered set of categories for the output columns;\n",
"        defaults to `classifiers.classes_`. Pass it explicitly so grid points\n",
"        whose classifier saw fewer classes still line up.\n",
"\n",
"    Returns an array (n_samples, n_categories) of probabilities. Falls back to\n",
"    climatology (equal probability for each category) when no classifier is\n",
"    available or prediction fails.\n",
"    \"\"\"\n",
"    if categories is None:\n",
"        if classifiers is None:\n",
"            raise ValueError('categories is required when classifiers is None')\n",
"        categories = classifiers.classes_\n",
"    categories = list(categories)\n",
"    X_flat = X_test.reshape(-1, X_test.shape[-1])\n",
"    n_samples, n_categories = X_flat.shape[0], len(categories)\n",
"    climatology = np.full((n_samples, n_categories), 1 / n_categories)\n",
"    if classifiers is None:\n",
"        return climatology\n",
"    try:\n",
"        # column j of predict_proba belongs to classifiers.classes_[j]\n",
"        columns = [categories.index(c) for c in classifiers.classes_]\n",
"        proba = np.zeros((n_samples, n_categories))\n",
"        proba[:, columns] = classifiers.predict_proba(X_flat)\n",
"        return proba\n",
"    except Exception as e:  # set climatology instead\n",
"        print(type(e).__name__, e)\n",
"        return climatology\n",
"\n",
"# every category seen in training; all grid points are mapped onto this set\n",
"categories = np.unique(y_train[target_var].values)\n",
"\n",
"predictions = xr.apply_ufunc(\n",
"    atomic_function_prediction_lr,\n",
"    all_classifiers,\n",
"    X_test[features].to_array().transpose(..., 'variable'),\n",
"    kwargs={'categories': categories},\n",
"    input_core_dims=[[], ['year', 'variable']],  # sample dims of X_test; keep in sync with output_core_dims\n",
"    output_core_dims=[['year', 'category']],  # one prediction per test sample and category\n",
"    vectorize=True,\n",
"    dask='parallelized',\n",
"    output_dtypes=[float],\n",
").compute()\n",
"\n",
"# apply_ufunc does not create coordinates for new dims, so add them manually\n",
"predictions = predictions.assign_coords(category=categories)\n",
"predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "xr",
"language": "python",
"name": "xr"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment