Skip to content

Instantly share code, notes, and snippets.

@agastidukare
Last active April 1, 2020 02:03
Show Gist options
  • Save agastidukare/9551283d4206122826c8394ccf087788 to your computer and use it in GitHub Desktop.
Save agastidukare/9551283d4206122826c8394ccf087788 to your computer and use it in GitHub Desktop.
# loss function for calculating predictions and accuracy before perturbation
def loss(params, batch, test=0):
    """Return the batch-averaged negative sum of ``targets * log_softmax(logits)``.

    The logits come from ``predict(params, inputs)``; the sum over all
    elements is divided by the size of the leading axis of the log-softmax
    output. With one-hot ``targets`` this is the mean cross-entropy
    (assumption: targets are one-hot and axis 0 is the batch axis -- confirm
    against the caller).

    Args:
        params: Parameters forwarded unchanged to ``predict``.
        batch: An ``(inputs, targets)`` pair.
        test: When equal to 1, print the raw logits and their log-softmax
            before returning. Defaults to 0 (no printing).

    Returns:
        The scalar loss value.
    """
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    if test == 1:
        print('Prediction Vector before softmax')
        print(logits)
        print("____________________________________________________________________________________")
        # NOTE: the label says "softmax", but the values printed are
        # log-softmax outputs (log-probabilities).
        print('Prediction Vector after softmax')
        print(preds)
        print("____________________________________________________________________________________")
    # The return must sit outside the debug branch so a value is produced
    # for every value of `test`.
    return -(1 / preds.shape[0]) * np.sum(targets * preds)
# loss function for calculating gradients of loss w.r.t. input image
def lo(batch, params):
    """Return the same loss value as ``loss``, with ``batch`` as first argument.

    The computation is the batch-averaged negative sum of
    ``targets * log_softmax(predict(params, inputs))``, without any debug
    printing. Taking ``batch`` first presumably lets a gradient transform
    that differentiates with respect to the first positional argument
    produce gradients for the input image -- confirm at the call site.

    Args:
        batch: An ``(inputs, targets)`` pair.
        params: Parameters forwarded unchanged to ``predict``.

    Returns:
        The scalar loss value.
    """
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -(1 / preds.shape[0]) * np.sum(targets * preds)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment