Last active
May 26, 2022 17:27
-
-
Save PhilipVinc/6e7e804db6523b4febc68394f96af5ed to your computer and use it in GitHub Desktop.
Implement a fast classical Ising Hamiltonian
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import netket as nk | |
from netket.operator import AbstractOperator | |
import numpy as np | |
import jax.numpy as jnp | |
class ClassicalIsingOperator(AbstractOperator):
    """Diagonal (classical) Ising-type operator defined by a coupling matrix.

    On a configuration ``sigma`` the operator's diagonal element is the
    quadratic form ``sum_jk sigma_j H0_jk sigma_k``.

    Args:
        hilbert: Hilbert space the operator acts on.
        H0: Square coupling matrix of shape ``(hilbert.size, hilbert.size)``;
            anything accepted by ``jnp.asarray``.

    Raises:
        ValueError: if ``H0`` does not have shape ``(hilbert.size, hilbert.size)``.
    """

    def __init__(self, hilbert, H0):
        # AbstractOperator.__init__ stores ``hilbert``; skipping it leaves the
        # inherited ``hilbert`` property without a backing value.
        super().__init__(hilbert)
        H0 = jnp.asarray(H0)
        expected_shape = (hilbert.size, hilbert.size)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if H0.shape != expected_shape:
            raise ValueError(
                f"H0 must have shape {expected_shape}, got {H0.shape}."
            )
        self.H0 = H0

    @property
    def dtype(self):
        """Dtype of the operator, taken from the coupling matrix ``H0``."""
        return self.H0.dtype

    @property
    def is_hermitian(self):
        """Always reported as Hermitian.

        NOTE(review): the diagonal elements ``sigma^T H0 sigma`` are real only
        when ``H0`` is real; a complex ``H0`` would make this claim wrong.
        """
        return True
def local_estimator(logpsi, pars, sigma, extra_args):
    """Local kernel for ClassicalIsingOperator: ``<x|H|psi>/<x|psi>`` per sample.

    The operator is diagonal, so the ratio reduces to the classical energy
    ``E(sigma) = sum_jk sigma_j H0_jk sigma_k``. ``logpsi`` and ``pars`` are
    therefore unused; they are only part of the local-kernel signature.

    Args:
        logpsi: Log-amplitude function (unused).
        pars: Variational parameters (unused).
        sigma: Configurations with sites on the last axis, e.g.
            ``(n_samples, n_sites)``. Extra leading batch axes are allowed.
        extra_args: The coupling matrix ``H0`` of shape ``(n_sites, n_sites)``.

    Returns:
        Array of local energies with shape ``sigma.shape[:-1]``.
    """
    H0 = extra_args
    # Ellipsis einsum: identical to 'ij,jk,ik->i' for 2-D input, but it also
    # works when the samples still carry a chain axis.
    return jnp.einsum("...j,jk,...k->...", sigma, H0, sigma)
@nk.vqs.get_local_kernel.dispatch
def get_local_kernel(vstate: nk.vqs.MCState, op: ClassicalIsingOperator):
    """Select the local kernel for an MCState paired with a ClassicalIsingOperator.

    Neither argument is inspected; their annotated types only drive the
    multiple-dispatch lookup. The kernel is the module-level
    ``local_estimator``.
    """
    kernel = local_estimator
    return kernel
@nk.vqs.get_local_kernel_arguments.dispatch
def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: ClassicalIsingOperator):
    """Return ``(samples, extra_args)`` fed to the ClassicalIsingOperator kernel.

    The operator is diagonal, so no connected configurations are computed.
    The only extra argument is the coupling matrix ``op.H0``. The samples are
    passed through without reshaping (presumably NetKet flattens them to 2-D
    before calling the kernel; confirm for the NetKet version in use).

    Raises:
        ValueError: if ``op.H0`` does not match the number of sites in the
            samples.
    """
    sigma = vstate.samples
    H0 = op.H0
    # Fail early with a clear message rather than an opaque einsum shape
    # error inside the compiled kernel.
    n_sites = sigma.shape[-1]
    if H0.shape != (n_sites, n_sites):
        raise ValueError(
            f"ClassicalIsingOperator.H0 has shape {H0.shape}, but the samples "
            f"have {n_sites} sites."
        )
    return sigma, H0
# Demo: 100 spin-1/2 sites with random all-to-all couplings, sampled with a
# local Metropolis sampler and an RBM variational state.
hi = nk.hilbert.Spin(0.5, 100)
sa = nk.sampler.MetropolisLocal(hi)
vs = nk.vqs.MCState(sa, nk.models.RBM(), n_samples=5000)
ha = ClassicalIsingOperator(hi, np.random.rand(hi.size, hi.size))
# expect_and_grad returns (energy statistics, gradient). Print the energy so
# the script produces visible output when run non-interactively.
energy, energy_grad = vs.expect_and_grad(ha)
print(energy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment