erick016/bigL.py
Created September 18, 2021 23:49
print("Reached Big L.")
for batch_idx in range(NUM_BATCHES):
    # print("epoch_overall_loss_per_batch Index:" + str(epoch * NUM_BATCHES + batch_idx))
    print("batch_idx:" + str(batch_idx))
    # print("train_loss_per_batch_L1 Shape:" + str(np.shape(train_loss_per_batch_L1)))
    # print("train_loss_per_batch_L2 Shape:" + str(np.shape(train_loss_per_batch_L2)))
    # print("train_loss_per_batch_L3 Shape:" + str(np.shape(train_loss_per_batch_L3)))
    # print("epoch_overall_loss_per_batch Shape:" + str(np.shape(epoch_overall_loss_per_batch)))
    # print("===================================")
    # Used to be indexed [(epoch + 1) * batch_idx] in all the brackets.
    epoch_overall_loss_per_batch[epoch * NUM_BATCHES + batch_idx, :] = (
        small_lambda * tlpb1_copy[batch_idx]
        + (1 - small_lambda) * (tlpb2_copy[batch_idx] + tlpb3_copy[batch_idx]))
    # print(np.shape(small_lambda * tlpb1_copy[batch_idx]
    #                + (1 - small_lambda) * (tlpb2_copy[batch_idx] + tlpb3_copy[batch_idx])))
    if batch_idx >= 99:
        break  # for testing with smaller sizes
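# d accumulates the per-batch combined losses for this epoch (skipping batch 0); on the very first
# epoch it starts from 1, on later epochs it continues from the d = 0 reset at the end of the
# previous epoch.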
if epoch == 0:
    d = 1
    for batch_idx in range(BATCH_SIZE):
        if batch_idx != 0:  # because d is supposed to be the accumulation of the changes in the loss function
            d = d + epoch_overall_loss_per_batch[epoch * NUM_BATCHES + batch_idx]  # doing per batch because of foreach loop size (A2)
else:
    for batch_idx in range(BATCH_SIZE):
        if batch_idx != 0:
            d = d + epoch_overall_loss_per_batch[epoch * NUM_BATCHES + batch_idx]
#v Call Training here v (commented out optimizers (trainable vars) for step 11)
#train and reset model_2_optimizer with new d. keep trainable parameters from last time. !or make a slot for d!
model_2_train_count = 0
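# Retrain model_2 on each noisified image (plus the perturbation r), feeding the accumulated d
# into the optimizer through set_changing_hypers before every update step.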
for image in noisified_images:  # image, num_batches, rect. size
    # is it batches or not?
    model_2_train_count = model_2_train_count + 1
    for label in noisified_labels:
        # (model, x, y, sl, L1, L2, L3)
        # add r so that the image is still perturbed
        current_worst_pert_image_rs_tensor = tf.convert_to_tensor(
            (image + r).reshape(BATCH_SIZE, MNIST_IMG_SIZE_SQ, MNIST_IMG_SIZE_SQ, 1))
        # should be the reshape's shape
        # print("image + r shape" + str(np.shape(current_worst_pert_image_rs)))
        print("epoch_loss_L1 shape" + str(np.shape(epoch_loss_L1)))
        print("epoch_loss_L2 shape" + str(np.shape(epoch_loss_L2)))
        print("epoch_loss_L3 shape" + str(np.shape(epoch_loss_L3)))
        model_2_optimizer.set_changing_hypers(d)
        loss_value = model_2_loss_function(small_lambda,
                                           tlpb1_copy[model_2_train_count],
                                           tlpb2_copy[model_2_train_count],
                                           tlpb3_copy[model_2_train_count])
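        # The remaining steps convert the loss to a tensor, query the gradients of model_2's
        # trainable variables via model_2_get_grad, and apply them with model_2_optimizer.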
#print( "datatype" + str(type(loss_value)) + "actual value:" + loss_value)
print("Investigation of loss_value:")
print(print(str(tf.convert_to_tensor(loss_value)))) # + " ||| " + print(str(type(tf.convert_to_tensor(loss_value)))))
loss_value_tensor = tf.convert_to_tensor(loss_value)
print("trainables:")
print(model_2.trainable_variables)
#'''loss_value,''' grads
loss_and_grads = model_2_get_grad(model_2, current_worst_pert_image_rs_tensor, label, loss_value_tensor)
#,small_lambda,epoch_loss_L1,epoch_loss_L2,epoch_loss_L3)
#print("grads:" + grads) #can only concatenate str (not "tuple") to str
print("grads:")
print(loss_and_grads)
print(repr(loss_and_grads))
print(isinstance(loss_and_grads,tf.Tensor))
#grads_toTFVar = tf.convert_to_tensor(grads)
print(grads)
grads_toTFVar = tf.Variable(grads)
model_2_optimizer.apply_gradients(zip(grads, model_2.trainable_variables))
train_loss_L_Overall.update_state(loss_value)
train_accuracy_L_Overall.update_state(label, model_2(current_worst_pert_image_rs_tensor, training=True))
#^ Call Training here ^
d = 0 # clear d for next epoch.
break #For testing just one epoch