# Training replacement
import numpy as np
import tensorflow as tf

# loss_object and BATCH_SIZE are assumed to be defined elsewhere in the notebook.
d = 0  # from algorithm 2, used in computing loss
# alpha = .33  # learning rate
# def nt_grad(curr_d):
#     return -1 * alpha * (curr_d / BATCH_SIZE)
def compute_loss(model, x, y, training):
    # Forward pass, then evaluate the global loss_object on the predictions.
    out = model(x, training=training)
    loss = loss_object(y_true=y, y_pred=out)
    return loss

def get_grad(model, x, y):
    # Record the forward pass on a tape so the loss can be differentiated
    # with respect to the model's trainable variables.
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, training=True)
    return loss, tape.gradient(loss, model.trainable_variables)
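
# --- Usage sketch (an assumption, not part of the original gist) ---
# Shows one way get_grad could drive an optimizer step. The toy model,
# data, optimizer, and loss_object below are illustrative stand-ins for
# whatever the surrounding notebook actually defines.
def _demo_get_grad_step():
    global loss_object
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.33)  # cf. alpha above
    x = tf.random.normal((4, 8))
    y = tf.constant([0, 1, 2, 0])
    loss, grads = get_grad(model, x, y)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    print("demo loss:", float(loss))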
def model_2_compute_loss(model, x, y, loss_value):
    print("model_2: x shape in compute_loss: " + str(np.shape(x)))
    print("model_2: y shape in compute_loss: " + str(np.shape(y)))
    print("model_2: loss_value shape in compute_loss: " + str(np.shape(loss_value)))
    # Run the forward pass so it is recorded on the tape, but return the
    # externally supplied loss value instead of loss_object(y, out).
    out = model(x, training=True)
    loss = loss_value  # model_2_loss_function(sl, L1, L2, L3)
    # return tf.convert_to_tensor(loss)  <- what Jonathan saw
    return loss
def model_2_get_grad(model, x, y, loss_value_tensor):
    print("model_2: loss_value_tensor shape in get_grad: " + str(np.shape(loss_value_tensor)))
    with tf.GradientTape(persistent=True) as tape:
        # Sum the per-example losses over the batch; the original loop
        # overwrote `loss` on every iteration ("returning only the first
        # argument??"), keeping only the last example's value.
        loss = 0.0
        for count in range(BATCH_SIZE):
            print("loss value tensor at " + str(count) + ": " + str(loss_value_tensor[count]))
            loss = loss + model_2_compute_loss(model, x, y, loss_value_tensor[count])
    # NOTE: this gradient is None unless loss_value_tensor was itself computed
    # from the model's trainable variables while this tape was recording.
    return loss, tape.gradient(loss, model.trainable_variables)
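
# --- Sketch (an assumption): making the model_2 gradient well-defined ---
# tape.gradient in model_2_get_grad returns None when loss_value_tensor is
# precomputed outside the tape, since it carries no dependency on the model's
# variables. One way to keep per-example losses AND gradient flow is to
# compute them from the model's output inside the tape with a non-reducing
# loss, as below; the particular loss function is illustrative.
def model_2_get_grad_connected(model, x, y):
    per_example_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    with tf.GradientTape() as tape:
        out = model(x, training=True)
        per_example = per_example_loss_fn(y, out)  # shape (BATCH_SIZE,)
        loss = tf.reduce_sum(per_example)
    return loss, tape.gradient(loss, model.trainable_variables)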