Skip to content

Instantly share code, notes, and snippets.

View PhilipVinc's full-sized avatar
🎯
Solving problems...

Filippo Vicentini PhilipVinc

🎯
Solving problems...
View GitHub Profile
# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
filippovicentini in mpi4jax at cqslpc1 on  mlir [$] via python-3.10.6 via 🐍 3.10.6 took 2s
➜ JAX_PLATFORMS="cpu" pytest --tb=short
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.10.6, pytest-7.2.0, pluggy-1.0.0
MPI vendor: ('Open MPI', (4, 1, 2))
MPI rank: 0
MPI size: 1
rootdir: /home/filippovicentini/Dropbox/Ricerca/Codes/Python/mpi4jax, configfile: pyproject.toml, testpaths: tests
plugins: cov-4.0.0
@PhilipVinc
PhilipVinc / sparse.py
Created November 3, 2022 18:14
bug in jax
# pip install jax jaxlib netket
"""Module for the common control flow utilities."""
import os
from functools import partial
from typing import Callable, Optional, Sequence, Set
from jax import core
from jax._src.scipy.sparse.linalg import (_vdot_real_tree, _identity, _normalize_matvec, _shapes, _sub, _add, _mul, _vdot_tree)
from functools import partial
import operator
import numpy as np
import jax.numpy as jnp
from jax import device_put
from jax import lax
from jax import scipy as jsp
@PhilipVinc
PhilipVinc / grad.py
Last active October 25, 2022 15:10
crash pennylane
import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
# Qubit count for the circuit defined below — presumably the number of device
# wires; NOTE(review): confirm against the (truncated) circuit body.
phys_qubits = 2
# Two circuit parameters drawn uniformly from [0, 1) using NumPy's global RNG.
# No seed is set in this snippet, so the values differ between runs.
pars_q = np.random.rand(2)
def minimal_circ(params, prng_key=None):
@PhilipVinc
PhilipVinc / tree_log.py
Last active June 14, 2022 17:18
tree log
def tree_log(tree, root, data, *, iter=None):
if tree is None:
return
elif isinstance(tree, list):
for (i, val) in enumerate(tree):
tree_log(val, f"{root}/{i}", data, iter=iter)
# handle namedtuples
elif isinstance(tree, list) and hasattr(tree, "_fields"):
@PhilipVinc
PhilipVinc / classical_ising.py
Last active May 26, 2022 17:27
Implement a fast classical Ising hamiltonian
import netket as nk
from netket.operator import AbstractOperator
import numpy as np
import jax.numpy as jnp
class ClassicalIsingOperator(AbstractOperator):
def __init__(self, hilbert, H0):
assert H0.shape == (hilbert.size, hilbert.size)
self.H0 = jnp.asarray(H0)
@PhilipVinc
PhilipVinc / sumoperator.py
Created May 17, 2022 08:59
NetKet sum operator
from typing import Union, List, Optional
from netket.utils.types import DType
import functools
class SumOperator(nk.operator.AbstractOperator):
r"""This class implements the action of the _expect_kernel()-method of
ContinuousOperator for a sum of ContinuousOperator objects.
"""
def __init__(
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union
import jax
from flax import linen as nn
from jax import numpy as jnp
from netket.hilbert import ContinuousHilbert
from netket.utils import mpi, wrap_afun
@PhilipVinc
PhilipVinc / complex-flax.py
Created November 29, 2021 09:47
complex-flax.py
from jax.config import config
# Turn on JAX's 64-bit mode (the `jax_enable_x64` flag). This is done right
# after importing the config and before the other JAX imports below, so that
# the arrays created later in the script can use 64-bit dtypes.
config.update("jax_enable_x64", True)
from jax import random
from jax import numpy as jnp
import flax.linen as nn
# Explicit JAX PRNG key with a fixed seed (0), so random draws made from it
# are reproducible across runs.
key = random.PRNGKey(0)