Created
January 24, 2022 22:51
-
-
Save thomasjpfan/45f8bd908f56f6ab5107c0bdde04b7e7 to your computer and use it in GitHub Desktop.
array_api with LDA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
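{ | |
"cell_type": "markdown", | |
"id": "91d24835-d13b-4ac1-b4f2-47925ddb4a2a", | |
"metadata": {}, | |
"source": [ | |
"# Array API example for LinearDiscriminantAnalysis\n", | |
"\n", | |
"Compares the fit time of `LinearDiscriminantAnalysis` on a NumPy array (CPU) with a CuPy array passed through scikit-learn's experimental array API dispatch (GPU), then checks that both fits produce the same coefficients." | |
] | |
}, | |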
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "a849ca35-90d9-4783-810d-18250bc021ee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import make_classification\n", | |
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", | |
"\n", | |
"# Synthetic benchmark problem shared by the CPU fit and the CuPy copy made later;\n", | |
"# the fixed random_state makes it identical on every run.\n", | |
"# NOTE(review): 500_000 x 300 is roughly 1.2 GB if make_classification returns float64;\n", | |
"# confirm it fits in both host and GPU memory.\n", | |
"X_np, y_np = make_classification(random_state=0, n_samples=500_000, n_features=300)\n", | |
"# Unfitted CPU estimator with default hyperparameters.\n", | |
"lda_np = LinearDiscriminantAnalysis()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8eefb285-86e8-4472-ac98-bd292e7e0546", | |
"metadata": {}, | |
"source": [ | |
"## Fit runtime for NumPy array" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "a1831332-4c8e-4dcf-abdc-f202e747dad1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 59s, sys: 33.9 s, total: 2min 33s\n", | |
"Wall time: 14.7 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"LinearDiscriminantAnalysis()" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# CPU baseline: fit the estimator on the NumPy arrays.\n", | |
"# NOTE(review): the saved output (execution_count 6) was recorded after the CuPy setup\n", | |
"# cell (execution_count 2) had already called set_config(array_api_dispatch=True);\n", | |
"# restart the kernel and run all cells top to bottom to reproduce these numbers in order.\n", | |
"lda_np.fit(X_np, y_np)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9f8e7614-3b48-4955-8c02-4883c91fd327", | |
"metadata": {}, | |
"source": [ | |
"## Fit runtime for CuPy array with array_api" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "ea84e268-4ca3-46fe-b5b8-cb5b90799bbe", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/tmp/ipykernel_127754/3365037576.py:2: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.\n", | |
" import cupy.array_api as xp\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn import set_config\n", | |
"import cupy.array_api as xp  # experimental namespace; emits the UserWarning in this cell's output\n", | |
"\n", | |
"# Turn on scikit-learn's array API dispatch globally: it stays enabled for every cell\n", | |
"# executed afterwards, not only for the CuPy fit below.\n", | |
"set_config(array_api_dispatch=True)\n", | |
"\n", | |
"# Copy the host NumPy data into cupy.array_api arrays (GPU memory).\n", | |
"X_cu = xp.asarray(X_np, copy=True)\n", | |
"y_cu = xp.asarray(y_np, copy=True)\n", | |
"# Separate estimator so the CPU-fitted lda_np is left untouched for the comparison.\n", | |
"lda_cu = LinearDiscriminantAnalysis()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "f828fcfc-ffc2-4ded-960f-d8151b624def", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 720 ms, sys: 364 ms, total: 1.08 s\n", | |
"Wall time: 1.08 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"LinearDiscriminantAnalysis()" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# GPU run: fit on the cupy.array_api arrays via array API dispatch.\n", | |
"# CuPy launches kernels asynchronously, so block until the device is idle before %%time\n", | |
"# stops the clock; otherwise still-queued GPU work is missing from the wall time.\n", | |
"# cupy is already loaded by `import cupy.array_api`, so this import is just a lookup.\n", | |
"# NOTE(review): a first fit may also include one-off warm-up costs (e.g. kernel compilation).\n", | |
"import cupy\n", | |
"\n", | |
"lda_cu.fit(X_cu, y_cu)\n", | |
"cupy.cuda.Device().synchronize()\n", | |
"lda_cu" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "24b73a36-a142-48c0-967c-f2e40f14df83", | |
"metadata": {}, | |
"source": [ | |
"## Check that the NumPy and CuPy coefficients match" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "92674019-00ae-4cd7-96ab-1ada7f8dd0c0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"np_coef_ = np.asarray(lda_np.coef_)\n", | |
"\n", | |
"# Is there a better way to convert a cupy.array_api array to an np.ndarray?\n", | |
"# NOTE(review): `_array` is a private attribute, presumably the wrapped cupy.ndarray;\n", | |
"# `.get()` copies it from the GPU into a host NumPy array.\n", | |
"cu_coef_ = lda_cu.coef_._array.get()\n", | |
"\n", | |
"# Uses assert_allclose defaults (rtol=1e-7, atol=0). CPU and GPU linear-algebra backends\n", | |
"# can differ in the last digits, so loosen rtol if this check becomes flaky.\n", | |
"np.testing.assert_allclose(np_coef_, cu_coef_)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "sk1-array_api (python3)", | |
"language": "python", | |
"name": "conda-env-sk1-array_api-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment