Skip to content

Instantly share code, notes, and snippets.

@yonesuke
Created May 18, 2022 08:39
Show Gist options
  • Save yonesuke/a2393772669a620cba3cc54ce892f5af to your computer and use it in GitHub Desktop.
Save yonesuke/a2393772669a620cba3cc54ce892f5af to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [],
"source": [
"import jaxfss\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"from softclip import SoftClip\n",
"from jax.config import config; config.update(\"jax_enable_x64\", True)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"matplotlib.rc('text', usetex=True)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# initializing MLP\n",
"mlp = jaxfss.RationalMLP(features=[20, 20, 1])\n",
"mlp_params = mlp.init(jax.random.PRNGKey(0), jnp.array([[1]]))"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# creating data\n",
"dataset = jaxfss.CriticalData.from_file(fname=\"ising.dat\")\n",
"train_data = dataset.train_data"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# creating bijectors for stabilizing learning\n",
"bij_c1 = SoftClip(low=0.0)\n",
"bij_c2 = SoftClip(low=0.0)\n",
"bij_Tc = SoftClip(low=-1.0, high=1.0)"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [],
"source": [
"# initial parameters\n",
"init_params = {\n",
" \"mlp\": mlp_params,\n",
" \"fss\": [0.0, 0.0] # p1, pc\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [],
"source": [
"# helper: map the raw fss parameters (p1, pc) to c1 and scaled Tc through the SoftClip bijectors\n",
"def bijectors(params):\n",
" p1, pc = params[\"fss\"]\n",
" c1 = bij_c1.forward(p1)\n",
" # c2 = bij_c2.forward(p2)\n",
" scaled_Tc = bij_Tc.forward(pc)\n",
" return [c1, scaled_Tc]\n",
"\n",
"# loss: jaxfss.NLLLoss of the MLP on the finite-size-scaled data (c2 fixed at 0), plus the 0.5*log(2*pi) constant\n",
"def loss_fn(params):\n",
" c2 = 0.0\n",
" c1, scaled_Tc = bijectors(params)\n",
" Ls, Ts, As, vAs = train_data[\"system_size\"], train_data[\"temperature\"], train_data[\"observable\"], train_data[\"observable_var\"]\n",
" X = (Ts - scaled_Tc) * Ls ** c1\n",
" Y = As * Ls ** c2\n",
" E = vAs * Ls ** c2\n",
" return jaxfss.NLLLoss(mlp.apply(params[\"mlp\"], X), Y, E) + 0.5 * jnp.log(2.0 * jnp.pi)\n",
"\n",
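"# fit: run `steps` optimizer updates inside jax.lax.fori_loop; row i of logs holds the step-i loss followed by the updated fss parameters\n",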
"def fit(loss_fn, optimizer, init_params, steps):\n",
" opt_state = optimizer.init(init_params)\n",
" @jax.jit\n",
" def update_fn(i, val):\n",
" params, opt_state, logs = val\n",
" loss, grad = jax.value_and_grad(loss_fn)(params)\n",
" updates, opt_state = optimizer.update(grad, opt_state, params)\n",
" params = optax.apply_updates(params, updates)\n",
" # losses = losses.at[i].set(loss)\n",
" # criticals = criticals.at[i].set(params[\"fss\"])\n",
" logs = logs.at[i].set([loss, *params[\"fss\"]])\n",
" return [params, opt_state, logs]\n",
" # losses = jnp.zeros(steps)\n",
" # criticals = jnp.zeros((steps, len(init_params[\"fss\"])))\n",
" logs = jnp.zeros((steps, len(init_params[\"fss\"]) + 1))\n",
" init_val = [init_params, opt_state, logs]\n",
" params, _, logs = jax.lax.fori_loop(0, steps, update_fn, init_val)\n",
" return params, logs"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.47 s, sys: 261 ms, total: 2.73 s\n",
"Wall time: 2.13 s\n",
"0.9907766077862741 0.44070345563338675\n"
]
}
],
"source": [
"lr = 10**-2\n",
"optimizer = optax.adam(learning_rate=lr)\n",
"steps = 10**4\n",
"# run the optimization\n",
"# params, losses, fsses = jaxfss.fit(loss_fn, optimizer, init_params, steps)\n",
"%time params, logs = fit(loss_fn, optimizer, init_params, steps)\n",
"c1, scaled_Tc = bijectors(params)\n",
"Tc = dataset.bij_temperature.inverse(scaled_Tc)\n",
"print(c1, Tc)"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[8, 6])\n",
"plt.rcParams[\"font.size\"] = 20\n",
"plt.xlim(0, steps)\n",
"plt.xlabel(r\"$\\textrm{iteration}$\")\n",
"# plt.plot(bij_c1.forward_fn(logs[:,1]), label=\"$c_{1}$\", lw=3)\n",
"plt.plot(logs[:,1], label=\"$p_{1}$\", lw=3)\n",
"# plt.plot(bij_c2.forward_fn(logs[:,2]), label=\"$c_{2}$\", lw=3)\n",
"# plt.plot(dataset.bij_temperature.inverse(bij_Tc.forward_fn(logs[:,3])), label=\"$\\mathrm{rescaled}\\ T_{c}$\", lw=3)\n",
"# plt.plot(bij_Tc.forward_fn(logs[:,2]), label=r\"$\\textrm{rescaled }T_{c}$\", lw=3)\n",
"plt.plot(logs[:,2], label=r\"$p_{c}$\", lw=3)\n",
"plt.legend()\n",
"\n",
"plt.savefig(\"ising_params.pdf\", bbox_inches='tight', transparent=True)"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x720 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[8, 10])\n",
"plt.rcParams[\"font.size\"] = 20\n",
"\n",
"plt.subplot(2,1,1)\n",
"plt.title(\"$\\mathrm{(a)}$\", loc=\"left\")\n",
"nums = [64, 128, 256]\n",
"for num in nums:\n",
" idx = dataset.Ls == num\n",
" plt.scatter(dataset.Ts[idx], dataset.As[idx], label=rf\"$L={num}$\")\n",
"plt.legend()\n",
" \n",
"plt.subplot(2,1,2)\n",
"plt.title(\"$\\mathrm{(b)}$\", loc=\"left\")\n",
"for num in nums:\n",
" idx = dataset.Ls == num\n",
" Ls, Ts, As = train_data[\"system_size\"][idx], train_data[\"temperature\"][idx], train_data[\"observable\"][idx]\n",
" X = (Ts - scaled_Tc) * Ls ** c1\n",
" Y = As\n",
" plt.scatter(X, Y, label=rf\"$L={num}$\")\n",
"plt.legend()\n",
"\n",
"plt.savefig(\"ising_binder.pdf\", bbox_inches='tight', transparent=True)"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[ 2.98515077e+05, 1.00000000e-02, -1.00000000e-02],\n",
" [ 2.30951170e+05, 9.77260388e-03, -1.88504672e-02],\n",
" [ 1.80104530e+05, 4.33646888e-03, -2.79320548e-02],\n",
" ...,\n",
" [-5.26425070e+00, 5.26663112e-01, 1.40916808e-01],\n",
" [-5.26570592e+00, 5.26663547e-01, 1.40923429e-01],\n",
" [-5.26588023e+00, 5.26694239e-01, 1.40923888e-01]], dtype=float64)"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment