This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2021 The NetKet Authors - All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
filippovicentini in mpi4jax at cqslpc1 on mlir [$] via python-3.10.6 via 🐍 3.10.6 took 2s | |
➜ JAX_PLATFORMS="cpu" pytest --tb=short | |
=========================================================================== test session starts ============================================================================ | |
platform linux -- Python 3.10.6, pytest-7.2.0, pluggy-1.0.0 | |
MPI vendor: ('Open MPI', (4, 1, 2)) | |
MPI rank: 0 | |
MPI size: 1 | |
rootdir: /home/filippovicentini/Dropbox/Ricerca/Codes/Python/mpi4jax, configfile: pyproject.toml, testpaths: tests | |
plugins: cov-4.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install jax jaxlib netket | |
"""Module for the common control flow utilities.""" | |
import os | |
from functools import partial | |
from typing import Callable, Optional, Sequence, Set | |
from jax import core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax._src.scipy.sparse.linalg import (_vdot_real_tree, _identity, _normalize_matvec, _shapes, _sub, _add, _mul, _vdot_tree) | |
from functools import partial | |
import operator | |
import numpy as np | |
import jax.numpy as jnp | |
from jax import device_put | |
from jax import lax | |
from jax import scipy as jsp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import pennylane as qml | |
phys_qubits = 2 | |
pars_q = np.random.rand(2) | |
def minimal_circ(params, prng_key=None): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tree_log(tree, root, data, *, iter=None): | |
if tree is None: | |
return | |
elif isinstance(tree, list): | |
for (i, val) in enumerate(tree): | |
tree_log(val, f"{root}/{i}", data, iter=iter) | |
# handle namedtuples | |
elif isinstance(tree, list) and hasattr(tree, "_fields"): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import netket as nk | |
from netket.operator import AbstractOperator | |
import numpy as np | |
import jax.numpy as jnp | |
class ClassicalIsingOperator(AbstractOperator): | |
def __init__(self, hilbert, H0): | |
assert H0.shape == (hilbert.size, hilbert.size) | |
self.H0 = jnp.asarray(H0) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union, List, Optional | |
from netket.utils.types import DType | |
import functools | |
class SumOperator(nk.operator.AbstractOperator): | |
r"""This class implements the action of the _expect_kernel()-method of | |
ContinuousOperator for a sum of ContinuousOperator objects. | |
""" | |
def __init__( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from typing import Any, Callable, Optional, Tuple, Union | |
import jax | |
from flax import linen as nn | |
from jax import numpy as jnp | |
from netket.hilbert import ContinuousHilbert | |
from netket.utils import mpi, wrap_afun |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax.config import config | |
config.update("jax_enable_x64", True) | |
from jax import random | |
from jax import numpy as jnp | |
import flax.linen as nn | |
key = random.PRNGKey(0) |
NewerOlder