Skip to content

Instantly share code, notes, and snippets.

@erick016
Created September 18, 2021 23:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save erick016/898ebb309ec9cfaae314aaef8ff0e385 to your computer and use it in GitHub Desktop.
Save erick016/898ebb309ec9cfaae314aaef8ff0e385 to your computer and use it in GitHub Desktop.
# Training replacement
# Module-level value `d`; per the original author's note it comes from
# "algorithm 2" and is meant to be used when computing the loss.
# NOTE(review): `d` is not referenced by any function visible here - confirm usage.
d = 0 # from algorithm 2, used in computing loss
# Disabled draft of a manual gradient step, kept for reference:
# alpha = .33  # learning rate
# def nt_grad(curr_d):
#     return -1*alpha*(curr_d/BATCH_SIZE)
def compute_loss(model, x, y, training):
    """Run ``model`` on ``x`` and score its output against ``y``.

    Args:
        model: Callable invoked as ``model(x, training=training)``.
        x: Input batch, handed to the model unchanged.
        y: Targets, passed to ``loss_object`` as ``y_true``.
        training: Forwarded to the model's ``training`` keyword.

    Returns:
        Whatever ``loss_object(y_true=y, y_pred=<model output>)`` returns.
    """
    # `loss_object` is a module-level name defined outside this function.
    predictions = model(x, training=training)
    return loss_object(y_true=y, y_pred=predictions)
def get_grad(model, x, y):
    """Compute the training loss and its gradients for one batch.

    The forward pass runs through ``compute_loss`` with ``training=True``
    while a ``tf.GradientTape`` is recording.

    Args:
        model: Model exposing ``trainable_variables``; passed to
            ``compute_loss``.
        x: Input batch.
        y: Targets for the batch.

    Returns:
        A ``(loss, gradients)`` tuple, where ``gradients`` is the result of
        ``tape.gradient(loss, model.trainable_variables)``.
    """
    with tf.GradientTape() as recorder:
        batch_loss = compute_loss(model, x, y, training=True)
    gradients = recorder.gradient(batch_loss, model.trainable_variables)
    return batch_loss, gradients
def model_2_compute_loss(model, x, y, loss_value):
    """Trace argument shapes, run a forward pass, and return ``loss_value``.

    Args:
        model: Callable invoked as ``model(x, training=True)``; its result
            is discarded.
        x: Input batch passed to the model.
        y: Only used for the shape trace.
        loss_value: Externally supplied loss; returned unchanged.

    Returns:
        ``loss_value`` exactly as received.
    """
    # Debug trace: one line per argument, printed in the order x, y, loss_value.
    traced = (
        ("model_2: print x shape in compute loss f'n", x),
        ("model_2: print y shape in compute loss f'n", y),
        ("model_2: print loss_value shape in compute loss f'n", loss_value),
    )
    for prefix, value in traced:
        print(prefix + str(np.shape(value)))

    # The forward pass is still executed, but its output does not feed into
    # the returned value.
    # NOTE(review): the returned loss therefore does not depend on this call -
    # confirm that is intended.
    model(x, training=True)
    return loss_value
def model_2_get_grad(model, x, y, loss_value_tensor):
    """Return a loss taken from ``loss_value_tensor`` and its gradients.

    For each index in ``range(BATCH_SIZE)`` the element
    ``loss_value_tensor[index]`` is printed and passed through
    ``model_2_compute_loss`` (which runs a forward pass of ``model`` on ``x``)
    while a ``tf.GradientTape`` is recording.

    Args:
        model: Model exposing ``trainable_variables``.
        x: Input batch forwarded to ``model_2_compute_loss``.
        y: Targets forwarded to ``model_2_compute_loss``.
        loss_value_tensor: Indexable collection of precomputed loss values;
            indices ``0 .. BATCH_SIZE - 1`` are read.

    Returns:
        A ``(loss, gradients)`` tuple, where ``loss`` is the value produced by
        the last loop iteration and ``gradients`` is
        ``tape.gradient(loss, model.trainable_variables)``.

    Raises:
        ValueError: If ``BATCH_SIZE`` is smaller than 1, because no loss would
            be produced.
    """
    print("model_2: print loss_value_tensor in get grad f'n" + str(np.shape(loss_value_tensor)))

    # Without at least one iteration `loss` would never be bound; fail with a
    # clear message instead of an UnboundLocalError at the return statement.
    if BATCH_SIZE < 1:
        raise ValueError("BATCH_SIZE must be at least 1 to compute a loss")

    # `gradient` is called exactly once, so a non-persistent tape suffices and
    # releases its resources as soon as that call completes.
    with tf.GradientTape() as tape:
        for count in range(BATCH_SIZE):
            entry = loss_value_tensor[count]
            print("loss value tensor at" + " " + str(count) + ": " + str(entry))
            # NOTE(review): every iteration overwrites `loss`, so only the
            # element at index BATCH_SIZE - 1 reaches the gradient call below -
            # confirm whether a reduction over the batch was intended.
            loss = model_2_compute_loss(model, x, y, entry)

    # NOTE(review): `loss` is an element of the caller-supplied tensor and is
    # not derived from the forward pass recorded above. If it is not connected
    # to `model.trainable_variables`, the gradients come back as None - verify
    # against the caller.
    return loss, tape.gradient(loss, model.trainable_variables)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment