Last active
March 24, 2022 16:04
-
-
Save thomasjpfan/effd04ef0909f0c59dd8c95fa64bbfa3 to your computer and use it in GitHub Desktop.
Array API example for GMM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "91d24835-d13b-4ac1-b4f2-47925ddb4a2a", | |
"metadata": {}, | |
"source": [ | |
"# Array API example for `GaussianMixture` (GMM)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "a849ca35-90d9-4783-810d-18250bc021ee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import make_blobs\n", | |
"from sklearn.mixture import GaussianMixture\n", | |
"\n", | |
"X_np, _ = make_blobs(random_state=0, n_samples=200_000, n_features=30)\n", | |
"gm_np = GaussianMixture(n_components=10, random_state=0, init_params=\"random\")\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8eefb285-86e8-4472-ac98-bd292e7e0546", | |
"metadata": {}, | |
"source": [ | |
"## Runtime for a NumPy array" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "a1831332-4c8e-4dcf-abdc-f202e747dad1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 30.6 s, sys: 57 s, total: 1min 27s\n", | |
"Wall time: 3.41 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"GaussianMixture(init_params='random', n_components=10, random_state=0)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"gm_np.fit(X_np)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9f8e7614-3b48-4955-8c02-4883c91fd327", | |
"metadata": {}, | |
"source": [ | |
"## Runtime for a CuPy array with `array_api` dispatch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "ea84e268-4ca3-46fe-b5b8-cb5b90799bbe", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn import set_config\n", | |
"import cupy.array_api as xp\n", | |
"\n", | |
"set_config(array_api_dispatch=True)\n", | |
"\n", | |
"X_cu = xp.asarray(X_np, copy=True)\n", | |
"gm_cu = GaussianMixture(n_components=10, random_state=0, init_params=\"random\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "f828fcfc-ffc2-4ded-960f-d8151b624def", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 462 ms, sys: 136 ms, total: 598 ms\n", | |
"Wall time: 597 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"GaussianMixture(init_params='random', n_components=10, random_state=0)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"gm_cu.fit(X_cu)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "c4ab0e4f-e5d9-472f-964a-4ad1107e9df8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"cupy.array_api._array_object.Array" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"type(gm_cu.means_)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8cc882ff-65aa-44b3-82f6-5bade630fec9", | |
"metadata": {}, | |
"source": [ | |
"### Check that the NumPy and CuPy means match" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "114a5700-2421-493c-99d8-3fffb6474356", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"np_means_ = np.asarray(gm_np.means_)\n", | |
"\n", | |
"# Is there a better (public) way to convert a cupy.array_api Array to an np.ndarray than `._array.get()`?\n", | |
"cu_means_ = gm_cu.means_._array.get()\n", | |
"\n", | |
"np.testing.assert_allclose(np_means_, cu_means_)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "sk1-array_api (python3)", | |
"language": "python", | |
"name": "conda-env-sk1-array_api-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment