Last active June 15, 2024 08:33
sklearn-cuda-tests.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing `scikit-learn` compatibility with the [Array API standard](https://data-apis.org/array-api/latest/#) on CUDA GPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/gist/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c/sklearn-cuda-tests.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run this notebook on Google Colab (or a similar service) to test scikit-learn's Array API support on CUDA GPUs.\n",
    "\n",
    "<div class=\"alert alert-block alert-warning\">\n",
    "⚠️ Make sure to enable a GPU runtime before running the notebook (in Colab: <i>Runtime → Change runtime type</i>, then select a GPU).\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V_noc08vGyBs"
   },
   "outputs": [],
   "source": [
    "# If a GPU is enabled, this will print the GPU details\n",
    "! nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Set required environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Enable Array API support for SciPy: https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support\n",
    "# SciPy reads this variable when it is imported; the `!` shell commands below (including pytest) inherit it from this kernel.\n",
    "os.environ[\"SCIPY_ARRAY_API\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Clone the branch you want to test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lwsP29ZoUMXG"
   },
   "outputs": [],
   "source": [
    "# Add details of the branch you want to test, in the \"<username>:<branch_name>\" form shown on the GitHub pull request page\n",
    "GITHUB_USERNAME, BRANCH_NAME = \"<username>:<branch_name>\".split(\":\", 1)\n",
    "# or\n",
    "# GITHUB_USERNAME = \"\"\n",
    "# BRANCH_NAME = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IH6CzYKx8rur"
   },
   "outputs": [],
   "source": [
    "!git clone --single-branch -b $BRANCH_NAME https://github.com/$GITHUB_USERNAME/scikit-learn.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Build `scikit-learn` from source (see the [build instructions](https://scikit-learn.org/stable/developers/advanced_installation.html#building-from-source) for details)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0j3SKvFM9w9D"
   },
   "outputs": [],
   "source": [
    "! cd scikit-learn && pip install wheel numpy scipy cython meson-python ninja --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LQ2wiY-iBY4L"
   },
   "outputs": [],
   "source": [
    "! cd scikit-learn && pip install --editable . --verbose --no-build-isolation --config-settings editable-verbose=true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UBPrWpqjBfH1"
   },
   "outputs": [],
   "source": [
    "! python -c \"import sklearn; sklearn.show_versions()\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Install the extra dependencies required for Array API support (see the [Array API support guide](https://scikit-learn.org/stable/modules/array_api.html) for details)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fMc37beNIofE"
   },
   "outputs": [],
   "source": [
    "# (Optional) Check which version of CuPy is installed\n",
    "! pip freeze | grep cupy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "wpg-C_rSG2YR"
   },
   "outputs": [],
   "source": [
    "# If CuPy is not installed, follow the instructions here: https://docs.cupy.dev/en/stable/install.html\n",
    "# For example, to install CuPy for CUDA 12.x, run the following command:\n",
    "# ! pip install -U cupy-cuda12x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6r2H6OKKI0m9"
   },
   "outputs": [],
   "source": [
    "! pip install array-api-compat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "buf4_RZ79wJW"
   },
   "source": [
    "5. Run all Array API tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fERgtVhiHKVg"
   },
   "outputs": [],
   "source": [
    "! cd scikit-learn && pytest -k array_api -vl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J4uHntAy7V74"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyM6uL9WE+XxL75a0i2T1yEM",
   "gpuType": "T4",
   "include_colab_link": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
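The notebook checks its prerequisites with separate shell commands (`nvidia-smi`, `pip freeze | grep cupy`) and sets `SCIPY_ARRAY_API` in step 1. These checks could also be gathered into a single Python cell and run before the long pytest step. The sketch below is one way to do that. It uses only the standard library. The names `gpu_visible`, `installed_cupy_dists` and `preflight` are hypothetical helpers, not part of scikit-learn or the notebook. The sketch also assumes CuPy was installed under a distribution name starting with `cupy` (for example `cupy-cuda12x`).

```python
import importlib.metadata
import importlib.util
import os
import shutil
import subprocess

# Hypothetical helpers: not part of scikit-learn or of the original notebook.


def gpu_visible() -> bool:
    """Return True if `nvidia-smi` is on PATH and exits with status 0."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        result = subprocess.run([exe], capture_output=True, timeout=60)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


def installed_cupy_dists() -> list:
    """Return sorted 'name==version' strings for installed distributions named cupy*."""
    found = set()
    for dist in importlib.metadata.distributions():
        meta = dist.metadata  # guard against distributions with missing metadata
        name = (meta.get("Name") if meta is not None else None) or ""
        if name.lower().startswith("cupy"):
            found.add(f"{name}=={dist.version}")
    return sorted(found)


def preflight() -> dict:
    """Collect the checks that the notebook otherwise runs as separate shell commands."""
    return {
        "gpu_visible": gpu_visible(),
        "cupy": installed_cupy_dists(),
        "array_api_compat": importlib.util.find_spec("array_api_compat") is not None,
        "SCIPY_ARRAY_API": os.environ.get("SCIPY_ARRAY_API") == "1",
    }


if __name__ == "__main__":
    for check, value in preflight().items():
        print(f"{check}: {value}")
```

Every entry should be truthy (a non-empty list for `cupy`) before the pytest cell is worth running. A `False` for `gpu_visible` usually means the GPU runtime was not enabled.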
Thanks for the suggestion :) just updated it
Could you also please change the last cell to run all array_api tests in sklearn by default, e.g.:
pytest -k array_api -vl
Otherwise, naive users might not realize that they are not testing the things they are changing in their PR.
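The command suggested above could also be launched from Python. That keeps the full `array_api` selection as the default while still allowing a run narrowed to specific paths. The sketch below is one way to do it. `array_api_pytest_command` and `run_array_api_tests` are hypothetical helper names. The sketch assumes the repository was cloned into `scikit-learn/`, as in the notebook.

```python
import subprocess
import sys

# Hypothetical helpers: not part of scikit-learn or of the original notebook.


def array_api_pytest_command(paths=()):
    """Build the equivalent of `pytest -k array_api -vl`, optionally limited to `paths`.

    -k array_api  select tests matching the keyword expression "array_api"
                  (case-insensitive substring match on test names and their parents)
    -v            verbose output, one line per test
    -l            short for --showlocals: print local variables in tracebacks
    """
    return [sys.executable, "-m", "pytest", *paths, "-k", "array_api", "-v", "-l"]


def run_array_api_tests(repo_dir="scikit-learn", paths=()):
    """Run the selected tests from inside the cloned repository and return pytest's exit code."""
    return subprocess.run(array_api_pytest_command(paths), cwd=repo_dir).returncode
```

Invoking pytest as `sys.executable -m pytest` runs it under the same interpreter that executes this code. Unlike the bare `pytest` script, it also puts the current directory on `sys.path`. A return value of 5 means pytest collected no tests. That usually points to a `-k` expression or path that matches nothing.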
Thanks @EdAbati!
To make it more convenient to paste the branch spec from a PR on GitHub, the first cell can be changed to:
and the second cell to:
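The notebook's clone cell (step 2) already accepts the `<username>:<branch_name>` form through `"<username>:<branch_name>".split(":", 1)`. That line fails with a generic unpacking error when the colon is missing. It also hands an unedited placeholder to the shell, where `<` and `>` are treated as redirections. The sketch below is a slightly stricter variant. `PLACEHOLDER`, `parse_branch_spec` and `clone_command` are hypothetical names, not part of the notebook.

```python
# Hypothetical helpers: not part of the original notebook.
PLACEHOLDER = "<username>:<branch_name>"


def parse_branch_spec(spec: str) -> tuple:
    """Split a "<username>:<branch_name>" spec, as shown on a GitHub PR page, into two parts.

    Only the first ":" is used as separator (same as `spec.split(":", 1)`).
    Git does not allow ":" in branch names, so the branch part never contains one.
    """
    spec = spec.strip()
    if spec == PLACEHOLDER:
        raise ValueError("replace the placeholder with the spec copied from the PR page")
    username, sep, branch = spec.partition(":")
    if not (sep and username and branch):
        raise ValueError(f"expected '<username>:<branch_name>', got {spec!r}")
    if "/" in username or any(ch.isspace() for ch in spec):
        raise ValueError(f"unexpected '/' or whitespace in {spec!r}")
    return username, branch


def clone_command(spec: str) -> list:
    """Build the argument list used by the notebook's `git clone` cell."""
    username, branch = parse_branch_spec(spec)
    url = f"https://github.com/{username}/scikit-learn.git"
    return ["git", "clone", "--single-branch", "-b", branch, url]
```

Once the placeholder is replaced, `GITHUB_USERNAME, BRANCH_NAME = parse_branch_spec("...")` would be a drop-in replacement for the `split` line. Returning an argument list also lets the clone run via `subprocess.run(clone_command(spec))` without shell quoting.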