Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Last active May 26, 2022 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/6e7e804db6523b4febc68394f96af5ed to your computer and use it in GitHub Desktop.
Save PhilipVinc/6e7e804db6523b4febc68394f96af5ed to your computer and use it in GitHub Desktop.
Implement a fast classical Ising Hamiltonian
import netket as nk
from netket.operator import AbstractOperator
import numpy as np
import jax.numpy as jnp
class ClassicalIsingOperator(AbstractOperator):
    """Operator defined by a square coupling matrix ``H0``.

    The matrix is stored as a JAX array in ``self.H0``; the kernel that
    evaluates the operator on samples lives outside this class and reads
    ``H0`` from it.
    """

    def __init__(self, hilbert, H0):
        """Build the operator on ``hilbert`` with coupling matrix ``H0``.

        Args:
            hilbert: Hilbert space the operator acts on; its ``size`` fixes
                the expected shape of ``H0``.
            H0: Array-like of shape ``(hilbert.size, hilbert.size)``.

        Raises:
            AssertionError: If ``H0`` does not have the expected shape.
        """
        # Convert first so plain nested lists (no `.shape`) are accepted too.
        H0 = jnp.asarray(H0)
        expected_shape = (hilbert.size, hilbert.size)
        assert H0.shape == expected_shape, (
            f"H0 must have shape {expected_shape}, got {H0.shape}"
        )
        # Let the base class register the Hilbert space; without this call
        # the operator is only partially initialised.
        super().__init__(hilbert)
        self.H0 = H0

    @property
    def dtype(self):
        """Data type of the operator, taken from the coupling matrix."""
        return self.H0.dtype

    @property
    def is_hermitian(self):
        """Always ``True``; ``H0`` itself is not inspected."""
        return True
def local_estimator(logpsi, pars, sigma, extra_args):
    """Evaluate ``<x|H|psi>/<x|psi>`` for every sample in a batch.

    ``logpsi`` and ``pars`` are part of the kernel signature but are not
    used: the value depends only on the samples and the coupling matrix.

    Args:
        logpsi: Unused.
        pars: Unused.
        sigma: 2D array of samples, ``(Nsamples, Nsites)``.
        extra_args: Coupling matrix contracted against ``sigma`` on both
            sides.

    Returns:
        1D array with one entry per sample, equal to
        ``sum_{i,j} sigma[s, i] * H0[i, j] * sigma[s, j]``.
    """
    couplings = extra_args
    # The contraction below needs a flat (Nsamples, Nsites) batch.
    assert sigma.ndim == 2
    return jnp.einsum("si,ij,sj->s", sigma, couplings, sigma)
@nk.vqs.get_local_kernel.dispatch
def get_local_kernel(vstate: nk.vqs.MCState, op: ClassicalIsingOperator):
    """Pick the kernel for an ``MCState`` / ``ClassicalIsingOperator`` pair.

    Registered on ``nk.vqs.get_local_kernel`` through ``.dispatch``. Neither
    argument is inspected; the result is always ``local_estimator``.
    """
    kernel = local_estimator
    return kernel
@nk.vqs.get_local_kernel_arguments.dispatch
def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: ClassicalIsingOperator):
    """Collect the inputs handed to the local kernel.

    Registered on ``nk.vqs.get_local_kernel_arguments`` through ``.dispatch``.

    Returns:
        A ``(samples, extra_args)`` pair: the samples of ``vstate`` and the
        coupling matrix ``op.H0``.
    """
    # The samples are forwarded untouched (no reshape happens here).
    # NOTE(review): local_estimator asserts a 2D batch, so the caller is
    # presumably expected to flatten the samples first -- confirm.
    return vstate.samples, op.H0
# Demo: build a variational state on 100 spin-1/2 sites and evaluate the
# custom operator on it.
hi = nk.hilbert.Spin(0.5, 100)
sa = nk.sampler.MetropolisLocal(hi)
model = nk.models.RBM()
vs = nk.vqs.MCState(sa, model, n_samples=5000)

# Random (unseeded) coupling matrix, one row/column per site.
couplings = np.random.rand(hi.size, hi.size)
ha = ClassicalIsingOperator(hi, couplings)

# Expectation value and gradient of the operator; the result is discarded.
vs.expect_and_grad(ha)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment