Last active
March 23, 2023 18:54
-
-
Save thomasjpfan/11950c8a961f5bfec452fc2e0fbc3ed8 to your computer and use it in GitHub Desktop.
Scikit-learn running on PyTorch using ArrayAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "91d24835-d13b-4ac1-b4f2-47925ddb4a2a", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"# Array API example for LinearDiscriminantAnalysis\n", | |
"\n", | |
"### This benchmark was run on an Nvidia RTX 3090 and an AMD 5950X." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "a849ca35-90d9-4783-810d-18250bc021ee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import make_classification\n", | |
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", | |
"import numpy as np\n", | |
"\n", | |
"X_np, y_np = make_classification(random_state=0,\n", | |
" n_samples=500_000, n_features=300)\n", | |
"X_np, y_np = X_np.astype(np.float32), y_np.astype(np.float32)\n", | |
"lda_np = LinearDiscriminantAnalysis()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8eefb285-86e8-4472-ac98-bd292e7e0546", | |
"metadata": {}, | |
"source": [ | |
"## Fit runtime for NumPy array" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "a1831332-4c8e-4dcf-abdc-f202e747dad1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 58s, sys: 38.2 s, total: 2min 36s\n", | |
"Wall time: 15.7 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"_ = lda_np.fit(X_np, y_np)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "fbbcd4c2-0622-452c-b1e8-d6e210b70821", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"103 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"_ = lda_np.predict(X_np)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d1049837-e4b0-4494-84f7-b5e13f08d979", | |
"metadata": {}, | |
"source": [ | |
"## Runtime for PyTorch Tensor on CPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "533eb684-e901-4f02-a69a-cde47f9194db", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sklearn\n", | |
"\n", | |
"# Enable ArrayAPI dispatching\n", | |
"sklearn.set_config(array_api_dispatch=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "4abbfdc0-1462-46a7-b010-67d3541dba14", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"X_torch_cpu = torch.asarray(X_np)\n", | |
"y_torch_cpu = torch.asarray(y_np)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "e26bb8a1-1351-44c2-bb44-134b503a6bdd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lda_torch_cpu = LinearDiscriminantAnalysis()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "2874e1f6-b3d8-43df-87eb-c86ea00bc00f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.09 s ± 29.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"lda_torch_cpu.fit(X_torch_cpu, y_torch_cpu)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "a84c01b8-31d1-41ee-a569-aafa17b978dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"44.6 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"_ = lda_torch_cpu.predict(X_torch_cpu)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "63084e19-21cd-4069-8f71-1ecabd29081b", | |
"metadata": {}, | |
"source": [ | |
"### On CPU, PyTorch is more than 7x faster during training and 2x faster during prediction." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9f8e7614-3b48-4955-8c02-4883c91fd327", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"## Runtime for PyTorch Tensor on CUDA" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "56eae76f-1837-48f4-a8f7-ba29c96b7d34", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"X_torch_cuda = torch.asarray(X_np, device=\"cuda\")\n", | |
"y_torch_cuda = torch.asarray(y_np, device=\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "ea84e268-4ca3-46fe-b5b8-cb5b90799bbe", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lda_torch_cuda = LinearDiscriminantAnalysis()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "f828fcfc-ffc2-4ded-960f-d8151b624def", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"145 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"_ = lda_torch_cuda.fit(X_torch_cuda, y_torch_cuda)\n", | |
"torch.cuda.synchronize(device=\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "ba0ba5c8-4c16-4627-8e01-f31141e43d2a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.57 ms ± 817 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"_ = lda_torch_cuda.predict(X_torch_cuda)\n", | |
"torch.cuda.synchronize(device=\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "612bd70d-3a90-41d0-b423-2bbb87b17628", | |
"metadata": {}, | |
"source": [ | |
"### On GPU, PyTorch is 100x faster during training and 60x faster during prediction." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3102ddf8-608e-4351-8db5-1cd94881631f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c0980ae9-9c39-4c81-99eb-b7424464cc46", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ba1b9b46-c558-4710-af1b-adbaf1aacba0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "24b73a36-a142-48c0-967c-f2e40f14df83", | |
"metadata": {}, | |
"source": [ | |
"## Check coefs are the same" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "92674019-00ae-4cd7-96ab-1ada7f8dd0c0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"np_coef_ = np.asarray(lda_np.coef_)\n", | |
"\n", | |
"# Is there a better way to convert a CUDA torch.Tensor to a np.ndarray?\n", | |
"cu_coef_ = lda_torch_cuda.coef_.cpu().numpy()\n", | |
"\n", | |
"np.testing.assert_allclose(np_coef_, cu_coef_, atol=1e-4)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "sk1-pytorch (python3)", | |
"language": "python", | |
"name": "conda-env-sk1-pytorch-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment