Last active
May 31, 2022 20:37
-
-
Save fabianp/6c3250512778e0612cbc49a60e32ff00 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "a9fe6a29", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy\n", | |
"import jaxopt\n", | |
"\n", | |
"# for automatic differentiation\n", | |
"from jax import numpy as jnp\n", | |
"from jax import grad, hessian\n", | |
"from jax import jacfwd, jacrev\n", | |
"from jax import random\n", | |
"import jax\n", | |
"import numpy as np\n", | |
"from scipy.special import gamma as gamma_fct\n", | |
"\n", | |
"# enforce double precision\n", | |
"from jax.config import config\n", | |
"config.update(\"jax_enable_x64\", True)\n", | |
"\n", | |
"# seed for random data generation\n", | |
"key = random.PRNGKey(42)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "b256f2f9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"n_samples, n_features = 100, 200\n", | |
"A = random.normal(key, shape=(n_samples, n_features))\n", | |
"A /= A.max()\n", | |
"b = random.normal(key, shape=(n_samples,))\n", | |
"lmbda = 1e-4 * (jnp.linalg.norm(A, 2) ** 2)\n", | |
"\n", | |
"\n", | |
"def loss(x, A, b, lmbda):\n", | |
" A, b, lmbda\n", | |
" z = A @ x - b\n", | |
" return 0.5 * (jnp.vdot(z, z)) + 0.5 * lmbda * jnp.vdot(x, x)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "12bbb25c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Running experiments for dataset synthetic\n" | |
] | |
} | |
], | |
"source": [ | |
"\n", | |
"max_iter = 300\n", | |
"all_max_iter = jnp.arange(1, max_iter, 1)\n", | |
"\n", | |
"dataset = 'synthetic'\n", | |
"print('Running experiments for dataset ', dataset)\n", | |
"\n", | |
"n_samples, n_features = A.shape\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "d4a11769", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OptStep(params=DeviceArray([-5.56486650e-02, -3.30363170e-01, 1.76317885e-01,\n", | |
" -1.59527912e-01, -7.84478294e-02, 3.02712915e-01,\n", | |
" -5.47730120e-02, 2.25249034e-01, 2.49474351e-01,\n", | |
" -4.39778003e-02, 1.67724994e-01, 1.44724183e-01,\n", | |
" 4.15387431e-04, -9.87611300e-02, 1.85985813e-02,\n", | |
" -4.28141203e-01, -4.57414206e-02, -8.49005408e-02,\n", | |
" -8.92551810e-02, 5.64902954e-02, 9.73533821e-02,\n", | |
" -1.11933882e-01, 9.50161621e-02, 6.74602899e-02,\n", | |
" 1.24753848e-01, -6.82230733e-02, -1.63003820e-01,\n", | |
" -5.26133230e-02, 1.17648054e-01, -6.21795547e-02,\n", | |
" -8.82708372e-02, 1.59027257e-01, 1.85936780e-02,\n", | |
" -5.49654228e-02, -2.28894844e-01, 2.92985412e-01,\n", | |
" 1.83952343e-01, 1.58209258e-01, -1.10766676e-01,\n", | |
" -2.16695040e-01, 1.29207705e-01, -1.53319470e-01,\n", | |
" 1.00466661e-01, 1.26225154e-01, -3.38084431e-02,\n", | |
" 1.72545818e-02, 7.79419307e-02, 7.28775001e-02,\n", | |
" 1.70079966e-02, 2.19853321e-01, -5.42848920e-01,\n", | |
" -2.07430302e-01, -1.44251755e-01, 2.01036430e-01,\n", | |
" 1.37890430e-01, -5.35801019e-02, 1.35305391e-01,\n", | |
" 1.31798421e-01, 2.48309332e-01, 2.05080311e-01,\n", | |
" -3.45272564e-02, -2.86509744e-02, -3.72516348e-02,\n", | |
" -2.17809410e-01, -2.27468689e-01, -4.38619558e-04,\n", | |
" 1.81890883e-01, -5.51496365e-02, -1.43599552e-01,\n", | |
" 1.18041672e-01, 8.11086175e-02, 1.86604777e-01,\n", | |
" 4.03801019e-02, -4.95936820e-02, 4.50478812e-01,\n", | |
" -7.14686052e-02, -5.52294987e-02, 1.04372477e-01,\n", | |
" -3.40250920e-01, 2.88827982e-01, -2.63497403e-01,\n", | |
" -2.32445876e-01, -1.66722006e-01, -1.73271567e-01,\n", | |
" 1.22581456e-01, -1.16980820e-02, -1.61869616e-03,\n", | |
" -1.18780956e-01, -3.50540034e-01, 8.97913284e-02,\n", | |
" -7.61979897e-02, 7.29592123e-02, -4.32490231e-01,\n", | |
" 6.10167907e-03, 6.91155145e-02, 4.11303824e-01,\n", | |
" 8.68429544e-02, -4.46346022e-01, 2.26845073e-01,\n", | |
" -4.24012759e-02, 1.38030852e-01, 1.15063190e-01,\n", | |
" 5.10586070e-01, -7.14633172e-02, 5.30342517e-01,\n", | |
" 4.13048572e-01, 3.64264180e-01, 2.39869940e-01,\n", | |
" -4.42589741e-02, -1.74146967e-01, 3.30045026e-02,\n", | |
" 4.23266382e-01, -2.67721232e-01, 2.61825872e-01,\n", | |
" -2.39166354e-01, 5.31260754e-02, 1.22381049e-01,\n", | |
" -1.81057630e-01, 3.01673051e-02, 2.05792984e-01,\n", | |
" -4.43103912e-02, 2.59567109e-01, 1.24828531e-01,\n", | |
" 2.48344142e-02, 6.61292131e-02, -2.15595605e-01,\n", | |
" -2.17146328e-01, -4.32532035e-02, -1.84800925e-01,\n", | |
" 7.23164663e-02, 8.77612450e-02, 1.31270321e-02,\n", | |
" 3.74900658e-01, 1.48689905e-02, -7.54011853e-02,\n", | |
" -2.29019029e-01, 2.21091516e-01, 2.72291801e-01,\n", | |
" 6.98253976e-03, -3.35335329e-02, -2.38709301e-01,\n", | |
" 9.05404310e-02, 8.63905799e-02, 2.16739339e-01,\n", | |
" -3.42600976e-01, -1.68438518e-01, -1.76773410e-01,\n", | |
" 1.08067629e-01, -1.94609717e-02, -1.16744506e-01,\n", | |
" 1.02268451e-01, -1.56100896e-01, 1.04218029e-01,\n", | |
" -3.14483466e-01, -9.98113670e-03, 2.99466507e-01,\n", | |
" 9.39993831e-02, -8.74415484e-02, 3.34126679e-01,\n", | |
" 3.81101684e-01, -1.16455008e-01, 2.15668622e-02,\n", | |
" 2.13510393e-01, 3.03286885e-01, 2.54904244e-01,\n", | |
" 1.78367477e-01, -3.74130418e-02, -2.73458252e-01,\n", | |
" 4.08133013e-02, 1.35996689e-01, -1.31633515e-01,\n", | |
" -5.62645372e-02, -1.57828209e-01, 1.06801540e-01,\n", | |
" -3.81939883e-01, -1.62944996e-02, -1.32606331e-01,\n", | |
" -3.12490913e-01, -5.09885605e-01, -1.73578537e-01,\n", | |
" 1.14861704e-01, 2.58727189e-01, 8.06011109e-02,\n", | |
" 3.80615835e-02, 4.88992258e-01, -6.39860773e-02,\n", | |
" -3.77617846e-02, 2.52385531e-01, -1.06498009e-01,\n", | |
" 9.47125137e-02, -3.64161174e-01, 2.88317109e-02,\n", | |
" 9.08336005e-02, 2.46983136e-01, 5.80316211e-02,\n", | |
" 2.26372182e-01, -2.08512619e-01, 1.91416935e-01,\n", | |
" -1.38191562e-02, 3.12551018e-01], dtype=float64), state=ProxGradState(iter_num=DeviceArray(105, dtype=int64, weak_type=True), stepsize=DeviceArray(0.0625, dtype=float64, weak_type=True), error=DeviceArray(0.00086155, dtype=float64), aux=None, velocity=DeviceArray([-5.56365708e-02, -3.30370960e-01, 1.76324450e-01,\n", | |
" -1.59513769e-01, -7.84570317e-02, 3.02711873e-01,\n", | |
" -5.47660589e-02, 2.25243400e-01, 2.49482245e-01,\n", | |
" -4.39835217e-02, 1.67719389e-01, 1.44717301e-01,\n", | |
" 4.29480500e-04, -9.87580066e-02, 1.86057974e-02,\n", | |
" -4.28120465e-01, -4.57527805e-02, -8.49148170e-02,\n", | |
" -8.92561869e-02, 5.65000125e-02, 9.73655860e-02,\n", | |
" -1.11936646e-01, 9.50226306e-02, 6.74559041e-02,\n", | |
" 1.24752308e-01, -6.82214629e-02, -1.63014642e-01,\n", | |
" -5.25980008e-02, 1.17636468e-01, -6.21721361e-02,\n", | |
" -8.82840133e-02, 1.59023867e-01, 1.85811490e-02,\n", | |
" -5.49643385e-02, -2.28883662e-01, 2.92989066e-01,\n", | |
" 1.83948499e-01, 1.58231350e-01, -1.10761997e-01,\n", | |
" -2.16693963e-01, 1.29229235e-01, -1.53328191e-01,\n", | |
" 1.00460910e-01, 1.26238221e-01, -3.37998188e-02,\n", | |
" 1.72531762e-02, 7.79684852e-02, 7.28757459e-02,\n", | |
" 1.69893039e-02, 2.19860950e-01, -5.42834691e-01,\n", | |
" -2.07424208e-01, -1.44247503e-01, 2.01023907e-01,\n", | |
" 1.37899838e-01, -5.35589469e-02, 1.35297616e-01,\n", | |
" 1.31811190e-01, 2.48317556e-01, 2.05100632e-01,\n", | |
" -3.45215471e-02, -2.86365501e-02, -3.72467023e-02,\n", | |
" -2.17800236e-01, -2.27456097e-01, -4.22492155e-04,\n", | |
" 1.81887728e-01, -5.51479196e-02, -1.43605535e-01,\n", | |
" 1.18033057e-01, 8.11158435e-02, 1.86598680e-01,\n", | |
" 4.03832521e-02, -4.95991850e-02, 4.50467048e-01,\n", | |
" -7.14764433e-02, -5.52333379e-02, 1.04365858e-01,\n", | |
" -3.40253105e-01, 2.88818071e-01, -2.63496862e-01,\n", | |
" -2.32457231e-01, -1.66722145e-01, -1.73281612e-01,\n", | |
" 1.22590516e-01, -1.17013349e-02, -1.61774222e-03,\n", | |
" -1.18782141e-01, -3.50535704e-01, 8.98014510e-02,\n", | |
" -7.62146703e-02, 7.29366407e-02, -4.32499208e-01,\n", | |
" 6.10795595e-03, 6.91258407e-02, 4.11301176e-01,\n", | |
" 8.68362439e-02, -4.46333054e-01, 2.26840506e-01,\n", | |
" -4.24024401e-02, 1.38013444e-01, 1.15066948e-01,\n", | |
" 5.10591468e-01, -7.14560053e-02, 5.30322559e-01,\n", | |
" 4.13041697e-01, 3.64259546e-01, 2.39880462e-01,\n", | |
" -4.42563475e-02, -1.74149648e-01, 3.30168014e-02,\n", | |
" 4.23266403e-01, -2.67697013e-01, 2.61826372e-01,\n", | |
" -2.39150168e-01, 5.31110872e-02, 1.22383181e-01,\n", | |
" -1.81046194e-01, 3.01729755e-02, 2.05782080e-01,\n", | |
" -4.43180631e-02, 2.59578306e-01, 1.24833167e-01,\n", | |
" 2.48342648e-02, 6.61267853e-02, -2.15585951e-01,\n", | |
" -2.17135354e-01, -4.32690774e-02, -1.84796523e-01,\n", | |
" 7.23170858e-02, 8.77555028e-02, 1.31201431e-02,\n", | |
" 3.74914522e-01, 1.48794502e-02, -7.54039100e-02,\n", | |
" -2.29028960e-01, 2.21096161e-01, 2.72309764e-01,\n", | |
" 6.96390189e-03, -3.35404800e-02, -2.38686149e-01,\n", | |
" 9.05449307e-02, 8.63931951e-02, 2.16720633e-01,\n", | |
" -3.42589200e-01, -1.68447098e-01, -1.76778135e-01,\n", | |
" 1.08076862e-01, -1.94777350e-02, -1.16762877e-01,\n", | |
" 1.02269717e-01, -1.56103172e-01, 1.04242726e-01,\n", | |
" -3.14514302e-01, -9.97684888e-03, 2.99454758e-01,\n", | |
" 9.39988060e-02, -8.74642702e-02, 3.34122507e-01,\n", | |
" 3.81098341e-01, -1.16460167e-01, 2.15851953e-02,\n", | |
" 2.13511201e-01, 3.03302048e-01, 2.54892107e-01,\n", | |
" 1.78369235e-01, -3.74041682e-02, -2.73461976e-01,\n", | |
" 4.08019476e-02, 1.35984097e-01, -1.31634318e-01,\n", | |
" -5.62757131e-02, -1.57806639e-01, 1.06804030e-01,\n", | |
" -3.81937848e-01, -1.63021326e-02, -1.32606228e-01,\n", | |
" -3.12501254e-01, -5.09890110e-01, -1.73590024e-01,\n", | |
" 1.14881467e-01, 2.58734843e-01, 8.05958669e-02,\n", | |
" 3.80562044e-02, 4.88971275e-01, -6.39961953e-02,\n", | |
" -3.77588711e-02, 2.52373881e-01, -1.06489185e-01,\n", | |
" 9.47132596e-02, -3.64166697e-01, 2.88364831e-02,\n", | |
" 9.08045243e-02, 2.46979324e-01, 5.80289317e-02,\n", | |
" 2.26380337e-01, -2.08506862e-01, 1.91406167e-01,\n", | |
" -1.38343929e-02, 3.12547289e-01], dtype=float64), t=DeviceArray(54.49605521, dtype=float64, weak_type=True)))" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# this works: GradientDescent accepts the data (A, b, lmbda) as extra positional arguments\n", | |
"x0 = jnp.zeros(n_features)\n", | |
"solver = jaxopt.GradientDescent(loss)\n", | |
"solver.run(x0, A, b, lmbda)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "29b2a066", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "run() got multiple values for argument 'init_stepsize'", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/tmp/ipykernel_537956/1690247864.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msolver\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjaxopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLBFGS\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlmbda\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0mrun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/implicit_diff.py\u001b[0m in \u001b[0;36mwrapped_solver_fun\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_signature_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msolver_fun_signature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 251\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmake_custom_vjp_solver_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msolver_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 252\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped_solver_fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
" \u001b[0;31m[... skipping hidden 5 frame]\u001b[0m\n", | |
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/implicit_diff.py\u001b[0m in \u001b[0;36msolver_fun_flat\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msolver_fun_flat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_extract_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwarg_keys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msolver_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msolver_fun_fwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/base.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0mzero_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_zero_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 177\u001b[0;31m \u001b[0mopt_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 178\u001b[0m \u001b[0minit_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mopt_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/lbfgs.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, params, state, *args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m unroll=self.unroll)\n\u001b[1;32m 265\u001b[0m \u001b[0minit_stepsize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstepsize\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrease_factor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 266\u001b[0;31m new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0mdescent_direction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdescent_direction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mTypeError\u001b[0m: run() got multiple values for argument 'init_stepsize'" | |
] | |
} | |
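], | |
"source": [ | |
"# this doesn't work: LBFGS raises a TypeError (multiple values for 'init_stepsize') when the data is passed positionally\n", | |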
"x0 = jnp.zeros(n_features)\n", | |
"solver = jaxopt.LBFGS(loss)\n", | |
"solver.run(x0, A, b, lmbda)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0e3475f6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment