jvmncs's GitHub gists
@jvmncs
jvmncs / jax-poetry.sh
Created January 25, 2024 19:44
spinning up cuda-enabled jax in a poetry project (2024 jan)
#!/bin/bash
# Add jaxlib source with priority explicit
poetry source add jaxlib https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --priority explicit
# Add jaxlib package with specified version and extras
poetry add jaxlib~=0.4.23 --extras="cuda12.cudnn89" --source=jaxlib
# Add jax source with priority explicit
poetry source add jax https://storage.googleapis.com/jax-releases/jax_releases.html --priority explicit
# Add jax package with specified version and extras
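# The gist preview cuts off after this comment; by symmetry with the
# jaxlib command above, the final line is presumably:
poetry add jax~=0.4.23 --source=jax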
@jvmncs
jvmncs / equinox_inaxes.py
Last active May 13, 2022 15:47
eqx.filter_vmap failing to respect in_axes kwarg
import equinox as eqx
import jax
import numpy as np
def func(x):
    return x + x
v_func_jax = jax.vmap(func, in_axes=0)
v_func_eqx = eqx.filter_vmap(func, in_axes=0)
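The preview ends here. A minimal repro sketch of the reported issue; the test input and prints below are my addition, not the gist's:
xs = np.arange(8.0).reshape(4, 2)
# With in_axes=0, both calls should map func over the leading axis and agree;
# per the gist title, filter_vmap did not respect in_axes at the time.
print(v_func_jax(xs))
print(v_func_eqx(xs))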
@jvmncs
jvmncs / paillier.py
Last active August 18, 2020 22:01
Paillier Aggregation in TensorFlow Federated
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from federated_aggregations import paillier
paillier_factory = paillier.local_paillier_executor_factory()
paillier_context = tff.framework.ExecutionContext(paillier_factory)
tff.framework.set_default_context(paillier_context)
# data from 5 clients
x = [np.array([i, i + 1], dtype=np.int32) for i in range(5)]
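The preview truncates here. A plausible continuation, modeled on the federated_aggregations README rather than the gist itself (the computation name and bitwidth are my assumptions):
x_type = tff.TensorType(tf.int32, [2])

@tff.federated_computation(tff.FederatedType(x_type, tff.CLIENTS))
def secure_paillier_addition(x):
    # Dispatched to the Paillier executor installed as the default context above
    return tff.federated_secure_sum(x, bitwidth=32)

result = secure_paillier_addition(x)
print(result)
# elementwise sum across the 5 clients: [10 15]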
# PySyft + TensorFlow: remote tensors and remote Keras training on a VirtualWorker
import tensorflow as tf
import syft

# Setup (assumed, not shown in the extract): hook TF so tensors gain .send()/.get()
hook = syft.TensorFlowHook(tf)
alice = syft.VirtualWorker(hook, "alice")
x = tf.constant([1., 2., 3., 4.])
x_ptr = x.send(alice)
print(x_ptr)
# ==> (Wrapper)>[PointerTensor | me:random_id1 -> alice:random_id2]
# Operations on pointers execute remotely on alice's worker
y_ptr = x_ptr + x_ptr
y = tf.reshape(y_ptr, shape=[2, 2])
id = tf.constant([[1., 0.], [0., 1.]]).send(alice)
z = tf.matmul(y, id).get()
print(z)
# ==> tf.Tensor([[2. 4.]
#                [6. 8.]], shape=(2, 2), dtype=float32)
x = tf.expand_dims(id[0], 0)
# Initialize the weight
w_init = tf.initializers.glorot_normal()
w = tf.Variable(w_init(shape=(2, 1), dtype=tf.float32)).send(alice)
z = tf.matmul(x, w)
# Manual differentiation & update
dzdx = tf.transpose(x)
w.assign_sub(dzdx)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert the data from numpy to tf.Tensor in order to have PySyft functionality.
x_train, y_train = tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train)
x_test, y_test = tf.convert_to_tensor(x_test), tf.convert_to_tensor(y_test)
# Send data to Alice (for demonstration purposes)
x_train_ptr = x_train.send(alice)
y_train_ptr = y_train.send(alice)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# Compile with optimizer, loss and metrics
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model_ptr = model.send(alice)
print(model_ptr)
# ==> (Wrapper)>[ObjectPointer | me:random_id1 -> alice:random_id2]
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2)
# ==> Train on 60000 samples
# Epoch 1/2
# 60000/60000 [==============================] - 2s 36us/sample - loss: 0.3008 - accuracy: 0.9129
# Epoch 2/2
# 60000/60000 [==============================] - 2s 32us/sample - loss: 0.1449 - accuracy: 0.9569
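A natural wrap-up, sketched here rather than taken from the extract (assumes model pointers support .get() the way tensor pointers do):
# Retrieve the trained model from alice and evaluate it locally
model = model_ptr.get()
model.evaluate(x_test, y_test, verbose=2)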