Skip to content

Instantly share code, notes, and snippets.

@fabianp
Last active May 31, 2022 20:37
Show Gist options
  • Save fabianp/6c3250512778e0612cbc49a60e32ff00 to your computer and use it in GitHub Desktop.
Save fabianp/6c3250512778e0612cbc49a60e32ff00 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "a9fe6a29",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import jaxopt\n",
"\n",
"# for automatic differentiation\n",
"from jax import numpy as jnp\n",
"from jax import grad, hessian\n",
"from jax import jacfwd, jacrev\n",
"from jax import random\n",
"import jax\n",
"import numpy as np\n",
"from scipy.special import gamma as gamma_fct\n",
"\n",
"# enforce double precision\n",
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
"# seed for random data generation\n",
"key = random.PRNGKey(42)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b256f2f9",
"metadata": {},
"outputs": [],
"source": [
"\n",
"n_samples, n_features = 100, 200\n",
"A = random.normal(key, shape=(n_samples, n_features))\n",
"A /= A.max()\n",
"b = random.normal(key, shape=(n_samples,))\n",
"lmbda = 1e-4 * (jnp.linalg.norm(A, 2) ** 2)\n",
"\n",
"\n",
"def loss(x, A, b, lmbda):\n",
" A, b, lmbda\n",
" z = A @ x - b\n",
" return 0.5 * (jnp.vdot(z, z)) + 0.5 * lmbda * jnp.vdot(x, x)\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "12bbb25c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running experiments for dataset synthetic\n"
]
}
],
"source": [
"\n",
"max_iter = 300\n",
"all_max_iter = jnp.arange(1, max_iter, 1)\n",
"\n",
"dataset = 'synthetic'\n",
"print('Running experiments for dataset ', dataset)\n",
"\n",
"n_samples, n_features = A.shape\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d4a11769",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OptStep(params=DeviceArray([-5.56486650e-02, -3.30363170e-01, 1.76317885e-01,\n",
" -1.59527912e-01, -7.84478294e-02, 3.02712915e-01,\n",
" -5.47730120e-02, 2.25249034e-01, 2.49474351e-01,\n",
" -4.39778003e-02, 1.67724994e-01, 1.44724183e-01,\n",
" 4.15387431e-04, -9.87611300e-02, 1.85985813e-02,\n",
" -4.28141203e-01, -4.57414206e-02, -8.49005408e-02,\n",
" -8.92551810e-02, 5.64902954e-02, 9.73533821e-02,\n",
" -1.11933882e-01, 9.50161621e-02, 6.74602899e-02,\n",
" 1.24753848e-01, -6.82230733e-02, -1.63003820e-01,\n",
" -5.26133230e-02, 1.17648054e-01, -6.21795547e-02,\n",
" -8.82708372e-02, 1.59027257e-01, 1.85936780e-02,\n",
" -5.49654228e-02, -2.28894844e-01, 2.92985412e-01,\n",
" 1.83952343e-01, 1.58209258e-01, -1.10766676e-01,\n",
" -2.16695040e-01, 1.29207705e-01, -1.53319470e-01,\n",
" 1.00466661e-01, 1.26225154e-01, -3.38084431e-02,\n",
" 1.72545818e-02, 7.79419307e-02, 7.28775001e-02,\n",
" 1.70079966e-02, 2.19853321e-01, -5.42848920e-01,\n",
" -2.07430302e-01, -1.44251755e-01, 2.01036430e-01,\n",
" 1.37890430e-01, -5.35801019e-02, 1.35305391e-01,\n",
" 1.31798421e-01, 2.48309332e-01, 2.05080311e-01,\n",
" -3.45272564e-02, -2.86509744e-02, -3.72516348e-02,\n",
" -2.17809410e-01, -2.27468689e-01, -4.38619558e-04,\n",
" 1.81890883e-01, -5.51496365e-02, -1.43599552e-01,\n",
" 1.18041672e-01, 8.11086175e-02, 1.86604777e-01,\n",
" 4.03801019e-02, -4.95936820e-02, 4.50478812e-01,\n",
" -7.14686052e-02, -5.52294987e-02, 1.04372477e-01,\n",
" -3.40250920e-01, 2.88827982e-01, -2.63497403e-01,\n",
" -2.32445876e-01, -1.66722006e-01, -1.73271567e-01,\n",
" 1.22581456e-01, -1.16980820e-02, -1.61869616e-03,\n",
" -1.18780956e-01, -3.50540034e-01, 8.97913284e-02,\n",
" -7.61979897e-02, 7.29592123e-02, -4.32490231e-01,\n",
" 6.10167907e-03, 6.91155145e-02, 4.11303824e-01,\n",
" 8.68429544e-02, -4.46346022e-01, 2.26845073e-01,\n",
" -4.24012759e-02, 1.38030852e-01, 1.15063190e-01,\n",
" 5.10586070e-01, -7.14633172e-02, 5.30342517e-01,\n",
" 4.13048572e-01, 3.64264180e-01, 2.39869940e-01,\n",
" -4.42589741e-02, -1.74146967e-01, 3.30045026e-02,\n",
" 4.23266382e-01, -2.67721232e-01, 2.61825872e-01,\n",
" -2.39166354e-01, 5.31260754e-02, 1.22381049e-01,\n",
" -1.81057630e-01, 3.01673051e-02, 2.05792984e-01,\n",
" -4.43103912e-02, 2.59567109e-01, 1.24828531e-01,\n",
" 2.48344142e-02, 6.61292131e-02, -2.15595605e-01,\n",
" -2.17146328e-01, -4.32532035e-02, -1.84800925e-01,\n",
" 7.23164663e-02, 8.77612450e-02, 1.31270321e-02,\n",
" 3.74900658e-01, 1.48689905e-02, -7.54011853e-02,\n",
" -2.29019029e-01, 2.21091516e-01, 2.72291801e-01,\n",
" 6.98253976e-03, -3.35335329e-02, -2.38709301e-01,\n",
" 9.05404310e-02, 8.63905799e-02, 2.16739339e-01,\n",
" -3.42600976e-01, -1.68438518e-01, -1.76773410e-01,\n",
" 1.08067629e-01, -1.94609717e-02, -1.16744506e-01,\n",
" 1.02268451e-01, -1.56100896e-01, 1.04218029e-01,\n",
" -3.14483466e-01, -9.98113670e-03, 2.99466507e-01,\n",
" 9.39993831e-02, -8.74415484e-02, 3.34126679e-01,\n",
" 3.81101684e-01, -1.16455008e-01, 2.15668622e-02,\n",
" 2.13510393e-01, 3.03286885e-01, 2.54904244e-01,\n",
" 1.78367477e-01, -3.74130418e-02, -2.73458252e-01,\n",
" 4.08133013e-02, 1.35996689e-01, -1.31633515e-01,\n",
" -5.62645372e-02, -1.57828209e-01, 1.06801540e-01,\n",
" -3.81939883e-01, -1.62944996e-02, -1.32606331e-01,\n",
" -3.12490913e-01, -5.09885605e-01, -1.73578537e-01,\n",
" 1.14861704e-01, 2.58727189e-01, 8.06011109e-02,\n",
" 3.80615835e-02, 4.88992258e-01, -6.39860773e-02,\n",
" -3.77617846e-02, 2.52385531e-01, -1.06498009e-01,\n",
" 9.47125137e-02, -3.64161174e-01, 2.88317109e-02,\n",
" 9.08336005e-02, 2.46983136e-01, 5.80316211e-02,\n",
" 2.26372182e-01, -2.08512619e-01, 1.91416935e-01,\n",
" -1.38191562e-02, 3.12551018e-01], dtype=float64), state=ProxGradState(iter_num=DeviceArray(105, dtype=int64, weak_type=True), stepsize=DeviceArray(0.0625, dtype=float64, weak_type=True), error=DeviceArray(0.00086155, dtype=float64), aux=None, velocity=DeviceArray([-5.56365708e-02, -3.30370960e-01, 1.76324450e-01,\n",
" -1.59513769e-01, -7.84570317e-02, 3.02711873e-01,\n",
" -5.47660589e-02, 2.25243400e-01, 2.49482245e-01,\n",
" -4.39835217e-02, 1.67719389e-01, 1.44717301e-01,\n",
" 4.29480500e-04, -9.87580066e-02, 1.86057974e-02,\n",
" -4.28120465e-01, -4.57527805e-02, -8.49148170e-02,\n",
" -8.92561869e-02, 5.65000125e-02, 9.73655860e-02,\n",
" -1.11936646e-01, 9.50226306e-02, 6.74559041e-02,\n",
" 1.24752308e-01, -6.82214629e-02, -1.63014642e-01,\n",
" -5.25980008e-02, 1.17636468e-01, -6.21721361e-02,\n",
" -8.82840133e-02, 1.59023867e-01, 1.85811490e-02,\n",
" -5.49643385e-02, -2.28883662e-01, 2.92989066e-01,\n",
" 1.83948499e-01, 1.58231350e-01, -1.10761997e-01,\n",
" -2.16693963e-01, 1.29229235e-01, -1.53328191e-01,\n",
" 1.00460910e-01, 1.26238221e-01, -3.37998188e-02,\n",
" 1.72531762e-02, 7.79684852e-02, 7.28757459e-02,\n",
" 1.69893039e-02, 2.19860950e-01, -5.42834691e-01,\n",
" -2.07424208e-01, -1.44247503e-01, 2.01023907e-01,\n",
" 1.37899838e-01, -5.35589469e-02, 1.35297616e-01,\n",
" 1.31811190e-01, 2.48317556e-01, 2.05100632e-01,\n",
" -3.45215471e-02, -2.86365501e-02, -3.72467023e-02,\n",
" -2.17800236e-01, -2.27456097e-01, -4.22492155e-04,\n",
" 1.81887728e-01, -5.51479196e-02, -1.43605535e-01,\n",
" 1.18033057e-01, 8.11158435e-02, 1.86598680e-01,\n",
" 4.03832521e-02, -4.95991850e-02, 4.50467048e-01,\n",
" -7.14764433e-02, -5.52333379e-02, 1.04365858e-01,\n",
" -3.40253105e-01, 2.88818071e-01, -2.63496862e-01,\n",
" -2.32457231e-01, -1.66722145e-01, -1.73281612e-01,\n",
" 1.22590516e-01, -1.17013349e-02, -1.61774222e-03,\n",
" -1.18782141e-01, -3.50535704e-01, 8.98014510e-02,\n",
" -7.62146703e-02, 7.29366407e-02, -4.32499208e-01,\n",
" 6.10795595e-03, 6.91258407e-02, 4.11301176e-01,\n",
" 8.68362439e-02, -4.46333054e-01, 2.26840506e-01,\n",
" -4.24024401e-02, 1.38013444e-01, 1.15066948e-01,\n",
" 5.10591468e-01, -7.14560053e-02, 5.30322559e-01,\n",
" 4.13041697e-01, 3.64259546e-01, 2.39880462e-01,\n",
" -4.42563475e-02, -1.74149648e-01, 3.30168014e-02,\n",
" 4.23266403e-01, -2.67697013e-01, 2.61826372e-01,\n",
" -2.39150168e-01, 5.31110872e-02, 1.22383181e-01,\n",
" -1.81046194e-01, 3.01729755e-02, 2.05782080e-01,\n",
" -4.43180631e-02, 2.59578306e-01, 1.24833167e-01,\n",
" 2.48342648e-02, 6.61267853e-02, -2.15585951e-01,\n",
" -2.17135354e-01, -4.32690774e-02, -1.84796523e-01,\n",
" 7.23170858e-02, 8.77555028e-02, 1.31201431e-02,\n",
" 3.74914522e-01, 1.48794502e-02, -7.54039100e-02,\n",
" -2.29028960e-01, 2.21096161e-01, 2.72309764e-01,\n",
" 6.96390189e-03, -3.35404800e-02, -2.38686149e-01,\n",
" 9.05449307e-02, 8.63931951e-02, 2.16720633e-01,\n",
" -3.42589200e-01, -1.68447098e-01, -1.76778135e-01,\n",
" 1.08076862e-01, -1.94777350e-02, -1.16762877e-01,\n",
" 1.02269717e-01, -1.56103172e-01, 1.04242726e-01,\n",
" -3.14514302e-01, -9.97684888e-03, 2.99454758e-01,\n",
" 9.39988060e-02, -8.74642702e-02, 3.34122507e-01,\n",
" 3.81098341e-01, -1.16460167e-01, 2.15851953e-02,\n",
" 2.13511201e-01, 3.03302048e-01, 2.54892107e-01,\n",
" 1.78369235e-01, -3.74041682e-02, -2.73461976e-01,\n",
" 4.08019476e-02, 1.35984097e-01, -1.31634318e-01,\n",
" -5.62757131e-02, -1.57806639e-01, 1.06804030e-01,\n",
" -3.81937848e-01, -1.63021326e-02, -1.32606228e-01,\n",
" -3.12501254e-01, -5.09890110e-01, -1.73590024e-01,\n",
" 1.14881467e-01, 2.58734843e-01, 8.05958669e-02,\n",
" 3.80562044e-02, 4.88971275e-01, -6.39961953e-02,\n",
" -3.77588711e-02, 2.52373881e-01, -1.06489185e-01,\n",
" 9.47132596e-02, -3.64166697e-01, 2.88364831e-02,\n",
" 9.08045243e-02, 2.46979324e-01, 5.80289317e-02,\n",
" 2.26380337e-01, -2.08506862e-01, 1.91406167e-01,\n",
" -1.38343929e-02, 3.12547289e-01], dtype=float64), t=DeviceArray(54.49605521, dtype=float64, weak_type=True)))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# this works\n",
"x0 = jnp.zeros(n_features)\n",
"solver = jaxopt.GradientDescent(loss)\n",
"solver.run(x0, A, b, lmbda)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "29b2a066",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "run() got multiple values for argument 'init_stepsize'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_537956/1690247864.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msolver\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjaxopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLBFGS\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlmbda\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0mrun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/implicit_diff.py\u001b[0m in \u001b[0;36mwrapped_solver_fun\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_signature_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msolver_fun_signature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 251\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmake_custom_vjp_solver_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msolver_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 252\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped_solver_fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 5 frame]\u001b[0m\n",
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/implicit_diff.py\u001b[0m in \u001b[0;36msolver_fun_flat\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msolver_fun_flat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_extract_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwarg_keys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msolver_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msolver_fun_fwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/base.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0mzero_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_zero_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 177\u001b[0;31m \u001b[0mopt_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 178\u001b[0m \u001b[0minit_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mopt_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/dev/jaxopt/jaxopt/_src/lbfgs.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, params, state, *args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m unroll=self.unroll)\n\u001b[1;32m 265\u001b[0m \u001b[0minit_stepsize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstepsize\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrease_factor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 266\u001b[0;31m new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0mdescent_direction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdescent_direction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: run() got multiple values for argument 'init_stepsize'"
]
}
],
"source": [
"# this doesn't works\n",
"x0 = jnp.zeros(n_features)\n",
"solver = jaxopt.LBFGS(loss)\n",
"solver.run(x0, A, b, lmbda)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e3475f6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment