Created May 2, 2022 14:15
-
-
Save PhilipVinc/71d0ce0d5b4555e9659b30a30b839140 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union

import jax
from flax import linen as nn
from jax import numpy as jnp

import netket as nk
from netket.hilbert import ContinuousHilbert
from netket.utils import mpi, wrap_afun
from netket.utils import struct
from netket.utils.deprecation import deprecated, warn_deprecation
from netket.utils.types import PyTree, PRNGKeyT
@struct.dataclass
class MetropolisSmearedSamplerState(nk.sampler.MetropolisSamplerState):
    """State of the smeared Metropolis sampler.

    Extends the standard Metropolis sampler state with ``epsilon``, a
    log-space smearing offset. The sampler combines it, via
    ``jnp.logaddexp``, with the log-probability of both the current and
    the proposed configuration when computing the acceptance ratio.
    """

    # Log-space smearing offset.
    # A default is required: the parent state has defaulted fields
    # (n_steps_proc / n_accepted_proc are not passed when the state is
    # constructed). A dataclass field without a default may not follow them;
    # that raises TypeError at class creation.
    # 0.0 matches the value passed by MetropolisSmearedSampler._init_state.
    epsilon: jnp.ndarray = 0.0
@struct.dataclass
class MetropolisSmearedSampler(nk.sampler.MetropolisSampler):
    """Metropolis sampler whose acceptance test uses a *smeared* probability.

    For a configuration σ let ``log_p(σ) = machine_pow * Re[machine(σ)]``
    and ``ε = state.epsilon``. The chains are driven by
    ``exp(logaddexp(log_p(σ), ε)) = p(σ) + exp(ε)`` instead of ``p(σ)``.
    As a result, configurations whose probability is far below ``exp(ε)``
    are weighted roughly equally.

    NOTE(review): samples are therefore drawn from the smeared distribution,
    not from ``|ψ|^machine_pow``. Downstream estimators presumably need
    reweighting; confirm with the caller.
    """

    def _init_state(sampler, machine, params, key):
        """Create the initial sampler state.

        Builds a ``MetropolisSmearedSamplerState`` containing:

        - all-zero configurations of shape
          ``(n_chains_per_rank, hilbert.size)``;
        - the transition rule's initial state;
        - ``epsilon=0.0``.

        If ``sampler.reset_chains`` is False, the configurations are replaced
        right away by a random state drawn from the rule, because they will
        not be reset before each sampling call.
        """
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)
        σ = jnp.zeros(
            (sampler.n_chains_per_rank, sampler.hilbert.size), dtype=sampler.dtype
        )
        state = MetropolisSmearedSamplerState(
            σ=σ, rng=key_state, rule_state=rule_state, epsilon=0.0
        )

        # If we don't reset the chain at every sampling iteration, then reset
        # it now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)
        return state

    def _sample_next(sampler, machine, parameters, state):
        """
        Implementation of `sample_next` for subclasses of `MetropolisSampler`.

        If you subclass `MetropolisSampler`, you should override this and not
        `sample_next` itself, because `sample_next` contains some common logic.

        Runs ``sampler.n_sweeps`` Metropolis steps on every chain. A proposal
        σ' is accepted when a uniform draw is below::

            exp(logaddexp(log_p(σ'), ε) - logaddexp(log_p(σ), ε)
                [+ log_prob_correction])

        Here ``ε = state.epsilon``, and the correction term is added only
        when the transition rule returns one.

        Returns:
            A tuple ``(new_state, σ)``. ``new_state`` carries the advanced
            rng, the new configurations and the updated acceptance and step
            counters. ``σ`` is the new configurations.
        """
        smearing_e = state.epsilon

        def loop_body(i, s):
            # One key to propagate to the next iteration, one for the
            # transition kernel and one for the uniform acceptance draw.
            s["key"], key1, key2 = jax.random.split(s["key"], 3)

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"]
            )
            proposal_log_prob = (
                sampler.machine_pow * machine.apply(parameters, σp).real
            )

            # Smeared log-acceptance ratio. The proposed and the current
            # log-probabilities must both be combined with ε through the
            # two-argument logaddexp.
            log_ratio = jnp.logaddexp(proposal_log_prob, smearing_e) - jnp.logaddexp(
                s["log_prob"], smearing_e
            )
            if log_prob_correction is not None:
                log_ratio = log_ratio + log_prob_correction

            uniform = jax.random.uniform(key2, shape=(sampler.n_chains_per_rank,))
            do_accept = uniform < jnp.exp(log_ratio)

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            s["accepted"] += do_accept.sum()
            s["log_prob"] = jnp.where(
                do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
            )
            return s

        new_rng, rng = jax.random.split(state.rng)
        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, s)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc
            + sampler.n_sweeps * sampler.n_chains_per_rank,
        )
        return new_state, new_state.σ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.