-
-
Save erick016/fbf91402b5b743a9d3f1467d081c165b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Combine the three per-batch loss terms into one overall loss row per batch:
#     overall = small_lambda * L1 + (1 - small_lambda) * (L2 + L3)
# Rows are laid out epoch-major, i.e. row index = epoch * NUM_BATCHES + batch_idx.
# (The index used to be (epoch + 1) * batch_idx in all the brackets.)
# NOTE(review): indentation was reconstructed from a flattened paste; `epoch`
# comes from an enclosing loop that is not part of this fragment.
print("Reached Big L.")
for batch_idx in range(NUM_BATCHES):
    print("batch_idx:" + str(batch_idx))
    epoch_overall_loss_per_batch[epoch * NUM_BATCHES + batch_idx, :] = (
        small_lambda * tlpb1_copy[batch_idx]
        + (1 - small_lambda) * (tlpb2_copy[batch_idx] + tlpb3_copy[batch_idx])
    )
    if batch_idx >= 99:
        break  # for testing with smaller sizes: only the first 100 batches are filled
# d is supposed to be the accumulation of the changes in the loss function.
# It is seeded with 1 on the first epoch; on any later epoch it carries on
# from whatever value it currently holds.
if epoch == 0:
    d = 1
# NOTE(review): this iterates range(BATCH_SIZE) although the row index advances
# in NUM_BATCHES strides (original note: "doing per batch because of foreach
# loop size(A2)") -- confirm BATCH_SIZE, not NUM_BATCHES, is the intended bound.
for batch_idx in range(BATCH_SIZE):
    if batch_idx == 0:
        continue  # the first row of the epoch is left out of the accumulation
    d = d + epoch_overall_loss_per_batch[epoch * NUM_BATCHES + batch_idx]
# v Call Training here v (commented out optimizers (trainable vars) for step 11)
# Train model_2 with the new d handed to model_2_optimizer, keeping the
# trainable parameters from last time (alternative idea: make a slot for d).
# NOTE(review): the nesting below was reconstructed from a flattened paste --
# confirm it against the original script.
model_2_train_count = 0
for image in noisified_images:  # image, num_batches, rect. size
    # TODO: is it batches or not?
    # NOTE(review): the counter is incremented before it is used as an index
    # below, so the tlpb*_copy lookups start at 1, not 0 -- confirm.
    model_2_train_count = model_2_train_count + 1
    # NOTE(review): as written, every image is paired with every label.
    for label in noisified_labels:
        # Add r so that the input is still perturbed, then reshape to the
        # (BATCH_SIZE, MNIST_IMG_SIZE_SQ, MNIST_IMG_SIZE_SQ, 1) layout.
        current_worst_pert_image_rs_tensor = tf.convert_to_tensor(
            (image + r).reshape(BATCH_SIZE, MNIST_IMG_SIZE_SQ, MNIST_IMG_SIZE_SQ, 1)
        )
        print( "epoch_loss_l1 shape" + str(np.shape(epoch_loss_L1)))
        print( "epoch_loss_l2 shape" + str(np.shape(epoch_loss_L2)))
        print( "epoch_loss_l3 shape" + str(np.shape(epoch_loss_L3)))

        # Hand the accumulated d to the optimizer before computing the step.
        model_2_optimizer.set_changing_hypers(d)
        loss_value = model_2_loss_function(
            small_lambda,
            tlpb1_copy[model_2_train_count],
            tlpb2_copy[model_2_train_count],
            tlpb3_copy[model_2_train_count],
        )
        # Convert once and reuse for both the debug print and the grad call.
        loss_value_tensor = tf.convert_to_tensor(loss_value)
        print("Investigation of loss_value:")
        # Was print(print(...)), which also emitted a stray "None" line.
        print(str(loss_value_tensor))
        print("trainables:")
        print(model_2.trainable_variables)

        loss_and_grads = model_2_get_grad(
            model_2, current_worst_pert_image_rs_tensor, label, loss_value_tensor
        )
        print("grads:")
        print(loss_and_grads)
        print(repr(loss_and_grads))
        print(isinstance(loss_and_grads, tf.Tensor))

        # NOTE(review): `grads` is not assigned anywhere in this fragment; the
        # result of model_2_get_grad is stored in `loss_and_grads`. Unless
        # `grads` is defined earlier in the file this raises NameError. Whether
        # the fix is `grads = loss_and_grads` or `_, grads = loss_and_grads`
        # depends on what model_2_get_grad returns -- confirm before changing.
        print(grads)
        model_2_optimizer.apply_gradients(zip(grads, model_2.trainable_variables))
        train_loss_L_Overall.update_state(loss_value)
        train_accuracy_L_Overall.update_state(
            label, model_2(current_worst_pert_image_rs_tensor, training=True)
        )
# ^ Call Training here ^
# Clear d for the next epoch, so accumulation on a later epoch starts from zero.
d = 0
# For testing just one epoch: leave the enclosing loop after the first pass.
# NOTE(review): the loop this exits is outside the visible fragment --
# presumably the epoch loop; remove this break to train for all epochs.
break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment