Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Created May 2, 2022 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/71d0ce0d5b4555e9659b30a30b839140 to your computer and use it in GitHub Desktop.
Save PhilipVinc/71d0ce0d5b4555e9659b30a30b839140 to your computer and use it in GitHub Desktop.
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union

import jax
from flax import linen as nn
from jax import numpy as jnp

import netket as nk
from netket.hilbert import ContinuousHilbert
from netket.utils import mpi, wrap_afun
from netket.utils.types import PyTree, PRNGKeyT
from netket.utils.deprecation import deprecated, warn_deprecation
from netket.utils import struct
@struct.dataclass
class MetropolisSmearedSamplerState(nk.sampler.MetropolisSamplerState):
    """State of a Metropolis sampler that uses a smeared acceptance ratio.

    Extends the standard Metropolis sampler state with a single extra field,
    ``epsilon``. The smeared sampler combines it with each configuration's
    log sampling weight via ``logaddexp`` when forming acceptance ratios.
    """

    # Smearing offset, in the same (log) units as the sampling weight.
    # NOTE(review): presumably the log of an additive probability floor — confirm.
    # A default is required: in NetKet the parent state's counter fields
    # (n_steps_proc / n_accepted_proc) carry defaults, and dataclasses reject a
    # non-default field declared after defaulted ones.
    epsilon: jnp.ndarray = 0.0
@struct.dataclass
class MetropolisSmearedSampler(nk.sampler.MetropolisSampler):
    """Metropolis sampler whose acceptance test uses a smeared sampling weight.

    Behaves like ``nk.sampler.MetropolisSampler``, except that the log-weight
    ``machine_pow * Re[machine.apply(parameters, σ)]`` of both the current and
    the proposed configuration is replaced by
    ``logaddexp(log_weight, state.epsilon)`` before the acceptance ratio is
    formed.
    """

    def _init_state(sampler, machine, params, key):
        """Build the initial :class:`MetropolisSmearedSamplerState`.

        Args:
            machine: the model; passed through to the transition rule.
            params: the model parameters; passed through to the transition rule.
            key: a jax PRNG key, split between the sampler and the rule.

        Returns:
            A state with ``epsilon=0.0`` and configurations of shape
            ``(n_chains_per_rank, hilbert.size)``. These are zeros when
            ``sampler.reset_chains`` is True; otherwise they are drawn from
            ``sampler.rule.random_state``.
        """
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)
        σ = jnp.zeros(
            (sampler.n_chains_per_rank, sampler.hilbert.size), dtype=sampler.dtype
        )
        state = MetropolisSmearedSamplerState(
            σ=σ, rng=key_state, rule_state=rule_state, epsilon=0.0
        )

        # If we don't reset the chain at every sampling iteration, then reset it
        # now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)

        return state

    def _sample_next(sampler, machine, parameters, state):
        """
        Implementation of `sample_next` for subclasses of `MetropolisSampler`.
        If you subclass `MetropolisSampler`, you should override this and not `sample_next`
        itself, because `sample_next` contains some common logic.

        Runs ``sampler.n_sweeps`` Metropolis steps on every chain. A proposal
        ``σp`` is accepted when ``u < exp(smeared(σp) - smeared(σ) + corr)``,
        where ``smeared(x) = logaddexp(machine_pow * Re[log ψ(x)], epsilon)``
        and ``corr`` is the rule's optional log-probability correction.

        Returns:
            ``(new_state, new_state.σ)``: the updated state (new rng, new
            configurations, updated acceptance/step counters) and the new
            configurations.
        """
        smearing_e = state.epsilon

        def smeared(log_prob):
            # Numerically stable log(exp(log_prob) + exp(smearing_e)).
            return jnp.logaddexp(log_prob, smearing_e)

        def loop_body(i, s):
            # One key carried to the next iteration, one for the transition
            # rule and one for the uniform acceptance draws.
            s["key"], key1, key2 = jax.random.split(s["key"], 3)

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"]
            )
            proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_chains_per_rank,))

            # Both branches previously duplicated this formula; the corrected
            # branch called logaddexp with a single argument, which raises.
            log_ratio = smeared(proposal_log_prob) - smeared(s["log_prob"])
            if log_prob_correction is not None:
                log_ratio = log_ratio + log_prob_correction
            do_accept = uniform < jnp.exp(log_ratio)

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            s["accepted"] += do_accept.sum()
            s["log_prob"] = jnp.where(
                do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
            )
            return s

        new_rng, rng = jax.random.split(state.rng)
        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, s)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc
            + sampler.n_sweeps * sampler.n_chains_per_rank,
        )

        return new_state, new_state.σ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment