Skip to content

Instantly share code, notes, and snippets.

View agastidukare's full-sized avatar

Agasti Kishor Dukare agastidukare

View GitHub Profile
# This function calculates the loss, predictions and gradients on the test set.
def covnet(t, params):
    """Evaluate ``params`` on the test set (via ``shape_as_image``) and return gradients.

    Args:
        t: Verbosity flag. It is forwarded to ``loss`` as ``test``; when it
            equals 1 the test loss/accuracy and the predicted and target
            classes are also printed.
        params: Model parameters handed to ``accuracy``, ``loss`` and ``lo``.

    Returns:
        ``(grads, test_acc)``: ``grads`` is ``grad(lo)`` evaluated at
        ``(test_batch, params)``, i.e. differentiated with respect to its
        first argument (the test batch); ``test_acc`` is the fraction of test
        examples whose predicted class matches the target class.
    """
    # Build the test batch once and reuse it for every evaluation below.
    test_batch = shape_as_image(test_images, test_labels)
    test_acc, target_class, predicted_class = accuracy(params, test_batch)
    test_loss = loss(params, test_batch, test=t)
    # NOTE(review): ``lo`` is not defined in the visible code; confirm it takes
    # (batch, params) so these are gradients w.r.t. the input batch.
    grads = grad(lo)(test_batch, params)
    if t == 1:
        print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(test_loss, 100 * test_acc))
        print('predicted_class,target_class', predicted_class, target_class)
    return grads, test_acc
# Training driver: initialise the parameters, then run one pass of optimizer
# updates over num_batches mini-batches.
# NOTE(review): this relies on opt_init, update, init_random_params,
# num_batches and batches already being defined when it runs. In this listing
# most of them appear further down, and ``batches`` is not defined in the
# visible code (presumably ``batches = data_stream()``). Confirm the
# execution order.
num_epochs = 1  # NOTE(review): not referenced below; an outer epoch loop may be missing.
key = random.PRNGKey(123)  # fixed seed, so initialisation is reproducible
# (-1, 28, 28, 1) is the input shape; -1 leaves the batch dimension open.
# The returned output shape is discarded.
_, init_params = init_random_params(key, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
itercount = itertools.count()  # global step counter, passed to update() as ``i``
# One optimizer step per mini-batch. The same ``key`` is passed every step;
# update() ignores its first argument.
for _ in range(num_batches):
    opt_state= update(key, next(itercount), opt_state, shape_as_image(*next(batches)))
# Extract the trained parameters from the final optimizer state.
params = get_params(opt_state)
# Plain SGD with a constant step size. optimizers.sgd returns the
# (init, update, get_params) triple used by the training code.
learning_rate = 0.14
opt_init, opt_update, get_params = optimizers.sgd(step_size=learning_rate)
@jit
def update(_, i, opt_state, batch):
    """Apply one optimizer step driven by the gradient of ``loss``.

    Args:
        _: Unused; the training loop passes a PRNG key in this slot.
        i: Step index forwarded to ``opt_update``.
        opt_state: Current optimizer state.
        batch: Batch forwarded to ``loss`` as its second argument.

    Returns:
        The new optimizer state produced by ``opt_update``.
    """
    current_params = get_params(opt_state)
    # Gradient of the loss with respect to the parameters (grad's first argument).
    loss_grad = grad(loss)(current_params, batch)
    return opt_update(i, loss_grad, opt_state)
# CNN classifier assembled with stax.serial, which returns an
# (init_fun, apply_fun) pair, bound here to init_random_params / predict.
# NOTE(review): the closing parenthesis of this call (and any layers after
# Dense(10)) is not visible in this listing. Confirm the full definition.
init_random_params, predict = stax.serial(
    stax.Conv(64, (7,7), padding='SAME'),   # 64 output channels, 7x7 filter
    stax.Relu,
    stax.Conv(32, (4, 4), padding='SAME'),  # 32 output channels, 4x4 filter
    stax.Relu,
    stax.MaxPool((3, 3)),                   # 3x3 pooling window
    stax.Flatten,                           # flatten feature maps for the Dense layers
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),                         # 10 outputs, presumably one per class
batch_size = 128
# Mini-batches per pass over the data: all full batches, plus one partial
# batch when num_train is not a multiple of batch_size
# (i.e. ceil(num_train / batch_size)).
num_complete_batches = num_train // batch_size
leftover = num_train % batch_size
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Endlessly reshuffle the training indices and slice them into mini-batches.

    Each pass over the data draws a fresh permutation of ``range(num_train)``
    from a RandomState seeded with 0, so the shuffle order is reproducible.
    The permutation is cut into ``num_batches`` slices of at most
    ``batch_size`` indices each.

    NOTE(review): the visible body computes ``batch_idx`` but never uses it and
    contains no ``yield``/``return``, so as shown it loops forever without
    producing anything. The rest of the loop body (presumably yielding the
    training images/labels indexed by ``batch_idx``) appears to be missing
    from this listing. Confirm.
    """
    rng = npr.RandomState(0)
    while True:  # infinite stream: one iteration per pass over the data
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            # Slicing past the end is safe, so the final slice may be shorter.
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
def accuracy(params, batch):
    """Score ``predict`` on a labelled batch by comparing argmax classes.

    Args:
        params: Model parameters passed to ``predict``.
        batch: ``(inputs, targets)`` pair. ``targets`` are turned into class
            indices by argmax along axis 1 (e.g. one-hot labels).

    Returns:
        ``(accuracy, target_class, predicted_class)``. Both class arrays are
        argmax indices along axis 1, and ``accuracy`` is the mean of their
        elementwise equality.
    """
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    scores = predict(params, inputs)
    predicted_class = np.argmax(scores, axis=1)
    matches = predicted_class == target_class
    return np.mean(matches), target_class, predicted_class
# Loss function for calculating predictions and accuracy before perturbation.
def loss(params, batch, test=0):
    """Compute log-softmax predictions for ``batch``, optionally printing them.

    Args:
        params: Model parameters passed to ``predict``.
        batch: ``(inputs, targets)`` pair.
        test: When equal to 1, the raw logits are printed for inspection.

    NOTE(review): the visible body ends after the "after softmax" header is
    printed. ``preds`` is never printed or returned and ``targets`` is unused,
    so as shown the function returns None. The remainder (presumably printing
    ``preds`` and returning a loss built from ``preds`` and ``targets``)
    appears to be missing from this listing. Confirm.
    """
    inputs, targets = batch
    logits = predict(params, inputs)
    # Log-softmax of the logits: the header printed below says "after softmax",
    # but these are log-probabilities, not probabilities.
    preds = stax.logsoftmax(logits)
    if(test==1):
        print('Prediction Vector before softmax')
        print(logits)
        print("____________________________________________________________________________________")
        print('Prediction Vector after softmax')
def predict(params, inputs):
    """Run a tanh MLP forward pass and return log-probabilities.

    Args:
        params: Sequence of ``(w, b)`` pairs, one per layer. Every pair except
            the last is a hidden layer followed by ``tanh``; the last layer is
            linear.
        inputs: Input batch; the final normalization runs over axis 1.

    Returns:
        The final layer's output minus its ``logsumexp`` along axis 1, i.e.
        its log-softmax along axis 1.
    """
    hidden_layers, (final_w, final_b) = params[:-1], params[-1]
    activations = inputs
    for weight, bias in hidden_layers:
        activations = np.tanh(np.dot(activations, weight) + bias)
    logits = np.dot(activations, final_w) + final_b
    log_norm = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_norm
_DATA = "/tmp/"
def _download(url, filename):
"""Download a url to a file in the JAX data temp directory."""
if not path.exists(_DATA):
os.makedirs(_DATA)
out_file = path.join(_DATA, filename)
if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)
print("downloaded {} to {}".format(url, _DATA))
import array
import gzip
import itertools
import numpy
import numpy.random as npr
import os
import struct
import time
from os import path
import urllib.request