Skip to content

Instantly share code, notes, and snippets.

@JiaweiZhuang
Created July 12, 2018 06:03
Show Gist options
  • Save JiaweiZhuang/b097f17bf85f9a0d34763d5cc9ce9b07 to your computer and use it in GitHub Desktop.
Save JiaweiZhuang/b097f17bf85f9a0d34763d5cc9ce9b07 to your computer and use it in GitHub Desktop.
Compare the performance of xarray's built-in interp() with xESMF
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import xarray as xr\n",
"import xesmf as xe"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# xESMF timing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
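"Mostly copied from the xESMF [Reuse regridder](https://xesmf.readthedocs.io/en/latest/Reuse_regridder.html) tutorial. A 4D field (10 time steps x 50 levels) is regridded from a 0.4 x 0.3 degree grid to a 0.6 x 0.4 degree grid, first with xESMF's bilinear method and then with xarray's built-in `interp()`, and the timings are compared."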
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (x: 600, x_b: 601, y: 400, y_b: 401)\n",
"Coordinates:\n",
" lon (y, x) float64 -119.8 -119.4 -119.0 -118.6 -118.2 -117.8 -117.4 ...\n",
" lat (y, x) float64 -59.85 -59.85 -59.85 -59.85 -59.85 -59.85 -59.85 ...\n",
" lon_b (y_b, x_b) float64 -120.0 -119.6 -119.2 -118.8 -118.4 -118.0 ...\n",
" lat_b (y_b, x_b) float64 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 ...\n",
"Dimensions without coordinates: x, x_b, y, y_b\n",
"Data variables:\n",
" *empty*"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Source grid: 0.4 x 0.3 degree cells over lon -120..120, lat -60..60\n",
"# (lon range and resolution, then lat range and resolution)\n",
"ds_in = xe.util.grid_2d(-120, 120, 0.4,\n",
"                        -60, 60, 0.3)\n",
"ds_in"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (x: 400, x_b: 401, y: 300, y_b: 301)\n",
"Coordinates:\n",
" lon (y, x) float64 -119.7 -119.1 -118.5 -117.9 -117.3 -116.7 -116.1 ...\n",
" lat (y, x) float64 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 ...\n",
" lon_b (y_b, x_b) float64 -120.0 -119.4 -118.8 -118.2 -117.6 -117.0 ...\n",
" lat_b (y_b, x_b) float64 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 ...\n",
"Dimensions without coordinates: x, x_b, y, y_b\n",
"Data variables:\n",
" *empty*"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Target grid: same domain as ds_in, but coarser (0.6 x 0.4 degree cells)\n",
"ds_out = xe.util.grid_2d(-120, 120, 0.6,  # longitude range and resolution\n",
"                         -60, 60, 0.4)    # latitude range and resolution\n",
"ds_out"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (lev: 50, time: 10, x: 600, x_b: 601, y: 400, y_b: 401)\n",
"Coordinates:\n",
" lon (y, x) float64 -119.8 -119.4 -119.0 -118.6 -118.2 -117.8 -117.4 ...\n",
" lat (y, x) float64 -59.85 -59.85 -59.85 -59.85 -59.85 -59.85 -59.85 ...\n",
" lon_b (y_b, x_b) float64 -120.0 -119.6 -119.2 -118.8 -118.4 -118.0 ...\n",
" lat_b (y_b, x_b) float64 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 -60.0 ...\n",
" * time (time) int64 1 2 3 4 5 6 7 8 9 10\n",
" * lev (lev) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 ...\n",
"Dimensions without coordinates: x, x_b, y, y_b\n",
"Data variables:\n",
" data2D (y, x) float64 1.872 1.869 1.866 1.863 1.86 1.857 1.855 1.852 ...\n",
" data4D (time, lev, y, x) float64 1.872 1.869 1.866 1.863 1.86 1.857 ..."
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extra time/level dimensions so there are many horizontal slices to regrid\n",
"n_time, n_lev = 10, 50\n",
"ds_in.coords['time'] = np.arange(1, n_time + 1)\n",
"ds_in.coords['lev'] = np.arange(1, n_lev + 1)\n",
"\n",
"# 2D analytic test field from xESMF, scaled by time and level into a 4D field\n",
"ds_in['data2D'] = xe.data.wave_smooth(ds_in['lon'], ds_in['lat'])\n",
"ds_in['data4D'] = ds_in['time'] * ds_in['lev'] * ds_in['data2D']\n",
"ds_in"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwrite existing file: bilinear_400x600_300x400.nc \n",
" You can set reuse_weights=True to save computing time.\n",
"CPU times: user 5.7 s, sys: 409 ms, total: 6.11 s\n",
"Wall time: 6.23 s\n"
]
}
],
"source": [
"%%time\n",
"# Building the regridder computes the interpolation weights (and writes them\n",
"# to a weight file, per the message below): expensive, but a one-off cost.\n",
"regridder = xe.Regridder(ds_in, ds_out, 'bilinear')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 417 ms, sys: 152 ms, total: 569 ms\n",
"Wall time: 585 ms\n"
]
}
],
"source": [
"%%time\n",
"# Applying the pre-computed weights to every time/level slice is cheap\n",
"dr_out = regridder(ds_in['data4D'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray 'data4D' (time: 10, lev: 50, y: 300, x: 400)>\n",
"array([[[[ 1.871198, ..., 1.871198],\n",
" ...,\n",
" [ 1.871198, ..., 1.871198]],\n",
"\n",
" ...,\n",
"\n",
" [[ 93.559913, ..., 93.559913],\n",
" ...,\n",
" [ 93.559913, ..., 93.559913]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[ 18.711983, ..., 18.711983],\n",
" ...,\n",
" [ 18.711983, ..., 18.711983]],\n",
"\n",
" ...,\n",
"\n",
" [[935.599126, ..., 935.599126],\n",
" ...,\n",
" [935.599126, ..., 935.599126]]]])\n",
"Coordinates:\n",
" lon (y, x) float64 -119.7 -119.1 -118.5 -117.9 -117.3 -116.7 -116.1 ...\n",
" lat (y, x) float64 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 -59.8 ...\n",
" * time (time) int64 1 2 3 4 5 6 7 8 9 10\n",
" * lev (lev) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 ...\n",
"Dimensions without coordinates: y, x\n",
"Attributes:\n",
" regrid_method: bilinear"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Leading (time, lev) dims are kept; horizontal dims now match the ds_out grid\n",
"assert dr_out.shape == ds_in['data4D'].shape[:2] + ds_out['lon'].shape\n",
"dr_out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# xarray interp() timing"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# interp doesn't like 2D coordinates. Reduce to 1D. (https://github.com/pydata/xarray/issues/2281)\n",
"# Collapsing is only valid for a rectilinear grid (lon constant along y, lat\n",
"# constant along x), so check that first. ds_in is modified in place; the ndim\n",
"# guards keep this cell safe to re-run (isel(y=0) fails on an already-1D lon).\n",
"if ds_in['lon'].ndim == 2:\n",
"    assert np.allclose(ds_in['lon'].values, ds_in['lon'].values[0]), 'lon varies along y'\n",
"    ds_in['lon'] = ds_in['lon'].isel(y=0)\n",
"if ds_in['lat'].ndim == 2:\n",
"    assert np.allclose(ds_in['lat'].values, ds_in['lat'].values[:, :1]), 'lat varies along x'\n",
"    ds_in['lat'] = ds_in['lat'].isel(x=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray 'data4D' (time: 10, lev: 50, lat: 400, lon: 600)>\n",
"array([[[[ 1.872343, ..., 1.872343],\n",
" ...,\n",
" [ 1.872343, ..., 1.872343]],\n",
"\n",
" ...,\n",
"\n",
" [[ 93.617126, ..., 93.617126],\n",
" ...,\n",
" [ 93.617126, ..., 93.617126]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[ 18.723425, ..., 18.723425],\n",
" ...,\n",
" [ 18.723425, ..., 18.723425]],\n",
"\n",
" ...,\n",
"\n",
" [[936.171263, ..., 936.171263],\n",
" ...,\n",
" [936.171263, ..., 936.171263]]]])\n",
"Coordinates:\n",
" * lon (lon) float64 -119.8 -119.4 -119.0 -118.6 -118.2 -117.8 -117.4 ...\n",
" * lat (lat) float64 -59.85 -59.55 -59.25 -58.95 -58.65 -58.35 -58.05 ...\n",
" * time (time) int64 1 2 3 4 5 6 7 8 9 10\n",
" * lev (lev) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 ..."
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Name the dimensions after their (now 1D) coordinates so that interp() can be\n",
"# called with lon=/lat= keywords\n",
"dr = ds_in['data4D'].rename({'x': 'lon', 'y': 'lat'})\n",
"dr"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# output grid as pure numpy arrays\n",
"lon_out = ds_out['lon'].isel(y=0).values\n",
"lat_out = ds_out['lat'].isel(x=0).values"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 8.76 s, sys: 2.55 s, total: 11.3 s\n",
"Wall time: 9.96 s\n"
]
}
],
"source": [
"%%time\n",
"# In the recorded run this took ~17x the wall time of applying the xESMF\n",
"# regridder (9.96 s vs 0.585 s). Note the xESMF figure excludes its one-off\n",
"# ~6 s weight computation, which only pays off when the regridder is reused.\n",
"dr_out_interp = dr.interp(lon=lon_out, lat=lat_out)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.DataArray 'data4D' (time: 10, lev: 50, lat: 300, lon: 400)>\n",
"array([[[[ 1.871199, ..., 1.871199],\n",
" ...,\n",
" [ 1.871199, ..., 1.871199]],\n",
"\n",
" ...,\n",
"\n",
" [[ 93.559957, ..., 93.559957],\n",
" ...,\n",
" [ 93.559957, ..., 93.559957]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[ 18.711991, ..., 18.711991],\n",
" ...,\n",
" [ 18.711991, ..., 18.711991]],\n",
"\n",
" ...,\n",
"\n",
" [[935.599568, ..., 935.599568],\n",
" ...,\n",
" [935.599568, ..., 935.599568]]]])\n",
"Coordinates:\n",
" * time (time) int64 1 2 3 4 5 6 7 8 9 10\n",
" * lev (lev) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 ...\n",
" * lon (lon) float64 -119.7 -119.1 -118.5 -117.9 -117.3 -116.7 -116.1 ...\n",
" * lat (lat) float64 -59.8 -59.4 -59.0 -58.6 -58.2 -57.8 -57.4 -57.0 ..."
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Leading (time, lev) dims are kept; (lat, lon) sizes match the requested output grid\n",
"assert dr_out_interp.shape == dr.shape[:2] + (lat_out.size, lon_out.size)\n",
"dr_out_interp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sanity check"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.QuadMesh at 0x355211f28>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x144 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# First (time, lev) slice from both methods, side by side\n",
"fig, axes = plt.subplots(1, 2, figsize=[10, 3])\n",
"dr_out[0, 0].plot(ax=axes[0])\n",
"axes[0].set_title('xESMF bilinear')\n",
"dr_out_interp[0, 0].plot(ax=axes[1])\n",
"axes[1].set_title('xarray interp()')\n",
"fig.tight_layout()\n",
"\n",
"# Largest pointwise difference between the two methods on the plotted slice\n",
"float(np.abs(dr_out[0, 0].values - dr_out_interp[0, 0].values).max())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment