Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Last active June 30, 2023 10:38
Show Gist options
  • Save ogrisel/6a4304e1831051203a98118875ead2d4 to your computer and use it in GitHub Desktop.
Save ogrisel/6a4304e1831051203a98118875ead2d4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2e839fa7-a2ec-4b10-af93-8d698977b110",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"backend = \"pytorch\"\n",
"noop_compile = lambda f: f\n",
"\n",
"if backend == \"jax\":\n",
"    import jax\n",
"    jax.config.update(\"jax_enable_x64\", True)\n",
"    xp = jax.numpy\n",
"    func_compile = jax.jit\n",
"elif backend == \"dask\":\n",
"    import dask.array as xp\n",
"    func_compile = noop_compile\n",
"elif backend == \"pytorch\":\n",
"    import torch\n",
"    import array_api_compat\n",
"    xp = array_api_compat.get_namespace(torch.zeros(1))\n",
"    func_compile = torch.compile\n",
"\n",
"\n",
"def compute_one_step(step_idx, data, params, rate=0.9):\n",
"    print(f\"Computing step {step_idx}\")\n",
"    return rate * xp.mean(data, axis=0) + (1 - rate) * params\n",
"\n",
"\n",
"def metric_a(data, params):\n",
"    return xp.sum(xp.abs(data.mean(axis=0) - params))\n",
"\n",
"\n",
"def metric_b(data, params):\n",
"    return xp.linalg.norm(params)\n",
"\n",
"\n",
"stopping_criterion = metric_a"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "268706dd-f08e-4db0-ba07-c0ac3fbca631",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.1874, 0.0761, -0.9102, ..., -0.3720, 0.4877, -1.2558],\n",
" [ 0.5432, 0.6453, 0.7448, ..., 0.2203, 0.2694, -0.2880],\n",
" [ 1.1852, -0.3813, -0.4278, ..., -1.5407, 0.7995, 0.9147],\n",
" ...,\n",
" [-1.4059, -0.5850, 0.9484, ..., 1.8234, 0.5075, -0.5918],\n",
" [-0.0380, 0.5716, -0.3331, ..., 2.5488, 0.3956, 1.8654],\n",
" [ 1.3412, -0.9067, 0.1906, ..., 0.5824, -0.6630, 0.8663]],\n",
" dtype=torch.float64)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = xp.asarray(np.random.normal(size=(1000, 10)))\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c632c28c-de99-4e8b-86c8-34aee61a2b2f",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import pandas as pd\n",
"\n",
"\n",
"@func_compile\n",
"def iterative_solver(data, params, tol=1e-4, maxiter=1_000):\n",
"    record = defaultdict(list)\n",
"    for iter_idx in range(maxiter):\n",
"        params = compute_one_step(iter_idx, data, params)\n",
"        record[\"iter\"].append(iter_idx)\n",
"        record[\"a\"].append(float(metric_a(data, params)))\n",
"        record[\"b\"].append(float(metric_b(data, params)))\n",
"\n",
"        if stopping_criterion(data, params) < tol: # calls bool() implicitly\n",
"            break\n",
" \n",
"    return params, pd.DataFrame(record)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2fe48eeb-b70f-47a9-bc7d-13787b348725",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computing step 0\n",
"Computing step 1\n",
"Computing step 1\n",
"Computing step 1\n"
]
},
{
"data": {
"text/plain": [
"tensor([-0.0469, 0.0287, 0.0177, 0.0226, 0.0090, -0.0191, 0.0326, -0.0173,\n",
" -0.0005, 0.0344], dtype=torch.float64)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_params = xp.zeros(shape=data.shape[1])\n",
"params, record = iterative_solver(data, init_params)\n",
"params"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0ddde34f-5be2-4188-9ee9-208a210f2562",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"record.plot(x=\"iter\", y=[\"a\", \"b\"]);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment