Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Created May 17, 2022 08:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/6af1ea3b9dd0c7b59e24a0760e893b16 to your computer and use it in GitHub Desktop.
Save PhilipVinc/6af1ea3b9dd0c7b59e24a0760e893b16 to your computer and use it in GitHub Desktop.
NetKet sum operator
import functools
from typing import Union, List, Optional

import jax
import jax.numpy as jnp
import netket as nk
from netket.utils.types import DType
class SumOperator(nk.operator.AbstractOperator):
    r"""Weighted sum :math:`\sum_i c_i \hat{O}_i` of operators on one Hilbert space.

    The object only stores its terms and their coefficients. The local
    estimator of the sum is built at expectation time by the
    ``get_local_kernel_arguments`` / ``get_local_kernel`` dispatch rules
    registered for this class, which reuse the kernels of the individual
    terms (e.g. ContinuousOperator objects).
    """

    def __init__(
        self,
        *operators: nk.operator.AbstractOperator,
        coefficients: Union[float, List[float]] = 1.0,
        dtype: Optional[DType] = None,
    ):
        r"""Build the sum :math:`\sum_i c_i \hat{O}_i`.

        Args:
            operators: The operators :math:`\hat{O}_i` to add. All of them must
                act on the same Hilbert space.
            coefficients: Either a single scalar, applied to every operator, or
                exactly one coefficient per operator. Defaults to 1.0 (a plain
                sum).
            dtype: Data type of the matrix elements. If None, it is the type
                promotion of ``float``, the coefficients' dtype and the dtype
                of every operator.

        Raises:
            ValueError: If no operator is given, or if the number of
                coefficients does not match the number of operators.
            NotImplementedError: If the operators act on different Hilbert
                spaces.
        """
        if len(operators) == 0:
            raise ValueError("SumOperator requires at least one operator")

        hil = [op.hilbert for op in operators]
        if not all(h == hil[0] for h in hil):
            raise NotImplementedError(
                "Cannot add operators on different hilbert spaces"
            )

        coeffs = jnp.asarray(coefficients)
        # The local kernel zips coefficients with the terms. A 0-d array
        # cannot be iterated, and a wrongly sized sequence would be silently
        # truncated by `zip`. So broadcast scalars to one entry per operator,
        # and reject any other shape up front.
        if coeffs.ndim == 0:
            coeffs = jnp.broadcast_to(coeffs, (len(operators),))
        elif coeffs.shape != (len(operators),):
            raise ValueError(
                f"Expected {len(operators)} coefficients (one per operator), "
                f"got an array of shape {coeffs.shape}"
            )

        self._ops = operators
        self._coeff = coeffs
        if dtype is None:
            # Complex coefficients make the matrix elements complex even when
            # every term is real, so the coefficients take part in promotion.
            dtype = functools.reduce(
                lambda dt, op: jnp.promote_types(dt, op.dtype),
                operators,
                jnp.promote_types(float, coeffs.dtype),
            )
        self._dtype = dtype
        super().__init__(hil[0])
        # A sum of Hermitian operators stays Hermitian only if every
        # coefficient is real (e.g. 1j * H is anti-Hermitian).
        self._is_hermitian = all(op.is_hermitian for op in operators) and bool(
            jnp.all(jnp.isreal(coeffs))
        )

    @property
    def is_hermitian(self):
        """Whether every term is Hermitian and every coefficient is real."""
        return self._is_hermitian

    @property
    def dtype(self) -> DType:
        """The dtype of the operator's matrix elements ⟨σ|Ô|σ'⟩."""
        return self._dtype

    def __repr__(self):
        return f"SumOperator({self.hilbert})"
@nk.vqs.get_local_kernel_arguments.dispatch
def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: SumOperator):
    """Gather the local-kernel arguments of every term of a SumOperator.

    Each term is passed to its own ``get_local_kernel_arguments`` rule. Only
    the second entry (index 1, the extra arguments) of each result is kept.
    The samples returned by the terms are dropped in favour of
    ``vstate.samples``.

    Returns:
        ``(samples, (per_term_args, coefficients))``, the layout unpacked by
        ``e_loc``.
    """
    samples = vstate.samples
    per_term_args = []
    for term in op._ops:
        # NOTE(review): assumes each term's kernel is evaluated on the same
        # samples as vstate.samples, since the term's own samples are ignored.
        term_result = nk.vqs.get_local_kernel_arguments(vstate, term)
        per_term_args.append(term_result[1])
    return samples, (per_term_args, op._coeff)
def e_loc(term_kernel_fun, logpsi, pars, sigma, extra_args):
    """Local estimator of a SumOperator: a weighted sum of the terms' kernels.

    Args:
        term_kernel_fun: Sequence of per-term local kernels. Each is called as
            ``kernel(logpsi, pars, sigma, args)``.
        logpsi: Forwarded unchanged to every term kernel.
        pars: Forwarded unchanged to every term kernel.
        sigma: Forwarded unchanged to every term kernel.
        extra_args: Pair ``(term_args, term_coeffs)``: the kernel arguments and
            the coefficient of each term, in the same order as
            ``term_kernel_fun``.

    Returns:
        ``0.0 + sum_i term_coeffs[i] * term_kernel_fun[i](logpsi, pars, sigma,
        term_args[i])``. An empty sum yields 0.0.
    """
    term_args, term_coeffs = extra_args
    # No host-side side effects in here: this function is presumably traced
    # by JAX, where a print would only fire at trace time and clutter output.
    accum = 0.0
    for coeff, kernel_fun, args in zip(term_coeffs, term_kernel_fun, term_args):
        accum = accum + coeff * kernel_fun(logpsi, pars, sigma, args)
    return accum
@nk.vqs.get_local_kernel.dispatch
def get_local_kernel(vstate: nk.vqs.MCState, op: SumOperator):
    """Build the local estimator kernel of a SumOperator.

    Each term's kernel is obtained from its own ``get_local_kernel`` rule. The
    kernels are frozen into a tuple and bound as the first argument of
    ``e_loc`` through ``nk.utils.HashablePartial``.
    """
    kernels = []
    for term in op._ops:
        kernels.append(nk.vqs.get_local_kernel(vstate, term))
    # Tuple rather than list, presumably so the partial stays hashable.
    return nk.utils.HashablePartial(e_loc, tuple(kernels))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment