Created
May 17, 2022 08:59
-
-
Save PhilipVinc/6af1ea3b9dd0c7b59e24a0760e893b16 to your computer and use it in GitHub Desktop.
NetKet sum operator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
from typing import Union, List, Optional

import jax
import jax.numpy as jnp
import netket as nk
from netket.utils.types import DType
class SumOperator(nk.operator.AbstractOperator):
    r"""Weighted sum :math:`\sum_i c_i \hat{O}_i` of operators.

    Implements the action of the ``_expect_kernel()``-method of
    ContinuousOperator for a sum of ContinuousOperator objects: local values
    are computed term by term with each term's own local kernel and combined
    with the stored coefficients.
    """

    def __init__(
        self,
        *operators: nk.operator.AbstractOperator,
        coefficients: Union[float, List[float]] = 1.0,
        dtype: Optional[DType] = None,
    ):
        r"""Construct the sum of ``operators`` weighted by ``coefficients``.

        Args:
            operators: The operators to sum (e.g. ContinuousOperator objects).
                All of them must act on the same Hilbert space.
            coefficients: One coefficient per operator, or a single scalar
                that weights every operator (default ``1.0``).
            dtype: Data type of the matrix elements. If ``None`` (default) it
                is obtained by promoting ``float`` with the dtype of every
                operator and with the dtype of the coefficients.

        Raises:
            ValueError: If no operator is given, or if the number of
                coefficients does not match the number of operators.
            NotImplementedError: If the operators act on different Hilbert
                spaces.
        """
        if not operators:
            raise ValueError("SumOperator requires at least one operator.")

        hilbert = operators[0].hilbert
        if not all(op.hilbert == hilbert for op in operators):
            raise NotImplementedError(
                "Cannot add operators on different hilbert spaces"
            )

        n_terms = len(operators)
        coeffs = jnp.asarray(coefficients)
        if coeffs.ndim == 0:
            # A scalar weight applies to every term. Expand it so that the
            # local kernel can pair exactly one coefficient with each term.
            coeffs = jnp.broadcast_to(coeffs, (n_terms,))
        if coeffs.shape != (n_terms,):
            # zip() in the local kernel would otherwise silently drop terms.
            raise ValueError(
                f"Expected {n_terms} coefficients (one per operator), "
                f"got an array of shape {coeffs.shape}."
            )

        if dtype is None:
            dtype = functools.reduce(
                lambda dt, op: jnp.promote_types(dt, op.dtype), operators, float
            )
            # Complex coefficients make the matrix elements complex as well.
            dtype = jnp.promote_types(dtype, coeffs.dtype)

        super().__init__(hilbert)
        self._ops = operators
        self._coeff = coeffs
        self._dtype = dtype
        # A real combination of Hermitian operators is Hermitian. A complex
        # coefficient generally breaks this, so report False conservatively.
        self._is_hermitian = all(op.is_hermitian for op in operators) and bool(
            jnp.all(jnp.isreal(coeffs))
        )

    @property
    def is_hermitian(self):
        """Whether the sum is known to be Hermitian."""
        return self._is_hermitian

    @property
    def dtype(self) -> DType:
        """The dtype of the operator's matrix elements ⟨σ|Ô|σ'⟩."""
        return self._dtype

    def __repr__(self):
        return f"SumOperator({self.hilbert})"
@nk.vqs.get_local_kernel_arguments.dispatch
def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: SumOperator):
    """Collect the samples and per-term kernel arguments of a SumOperator.

    Returns:
        ``(samples, (term_args, coefficients))``. ``samples`` is
        ``vstate.samples``. ``term_args[i]`` is the second element returned
        by ``get_local_kernel_arguments(vstate, op._ops[i])``; the first
        element of each per-term result is discarded. ``coefficients`` is
        ``op._coeff``.
    """
    samples = vstate.samples
    per_term_args = []
    for term in op._ops:
        term_result = nk.vqs.get_local_kernel_arguments(vstate, term)
        per_term_args.append(term_result[1])
    return samples, (per_term_args, op._coeff)
def e_loc(term_kernel_fun, logpsi, pars, sigma, extra_args):
    """Local value of a weighted sum of operators, evaluated term by term.

    Args:
        term_kernel_fun: Sequence holding the local-kernel function of each
            term, each called as ``kernel(logpsi, pars, sigma, args)``.
        logpsi: Forwarded unchanged to every term kernel.
        pars: Forwarded unchanged to every term kernel.
        sigma: Samples at which the local values are evaluated; forwarded
            unchanged to every term kernel.
        extra_args: Pair ``(term_args, term_coeffs)`` with the extra kernel
            arguments of each term and the coefficient of each term. A scalar
            ``term_coeffs`` weights every term.

    Returns:
        ``sum_i term_coeffs[i] * term_kernel_fun[i](logpsi, pars, sigma,
        term_args[i])``.

    Raises:
        ValueError: If the number of term arguments or of coefficients does
            not match the number of kernels.
    """
    term_args, term_coeffs = extra_args
    n_terms = len(term_kernel_fun)

    # zip() would silently drop trailing terms on a length mismatch.
    if len(term_args) != n_terms:
        raise ValueError(
            f"Got {len(term_args)} term arguments for {n_terms} kernels."
        )

    coeffs = jnp.asarray(term_coeffs)
    if coeffs.ndim == 0:
        # A 0-d array cannot be iterated: expand a scalar weight to all terms.
        coeffs = jnp.broadcast_to(coeffs, (n_terms,))
    if coeffs.shape != (n_terms,):
        raise ValueError(
            f"Expected {n_terms} coefficients, got shape {coeffs.shape}."
        )

    accum = 0.0
    for coeff, kernel_fun, args in zip(coeffs, term_kernel_fun, term_args):
        accum = accum + coeff * kernel_fun(logpsi, pars, sigma, args)
    return accum
@nk.vqs.get_local_kernel.dispatch
def get_local_kernel(vstate: nk.vqs.MCState, op: SumOperator):
    """Build the local kernel of a SumOperator from its terms' kernels.

    The kernel of every term is looked up, in term order, through the same
    ``get_local_kernel`` dispatch. The tuple of kernels is then bound as the
    first argument of ``e_loc``. ``nk.utils.HashablePartial`` is used instead
    of ``functools.partial``, presumably so the kernel is hashable; confirm
    against NetKet's docs.
    """
    kernels = []
    for term in op._ops:
        kernels.append(nk.vqs.get_local_kernel(vstate, term))
    return nk.utils.HashablePartial(e_loc, tuple(kernels))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment