Last active
August 6, 2023 16:35
-
-
Save vnegi10/d5ce1e0cf7a4a92b8f75d2da653d1b69 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    run_training(loss_change, learn, x_in, y_in; max_epochs = typemax(Int))

Train a freshly initialised `Dense(1 => 1)` Flux model until the absolute change in loss
between two consecutive epochs is `≤ loss_change`.

Each epoch calls `update_model!(learn, model, x_in, y_in)` and evaluates the returned model
with `get_loss(model, x_in, y_in)`. Training also stops:
- after `max_epochs` epochs, or
- as soon as the loss becomes non-finite (`NaN`/`Inf`). In that case a warning is logged.
  Without this guard, a diverging run could never satisfy the stopping criterion, because
  comparisons with `NaN` are always `false`, and the loop would never terminate.

# Arguments
- `loss_change`: non-negative tolerance on `|loss_new - loss_prev|`.
- `learn`: passed through unchanged to `update_model!`. Presumably the learning rate or
  optimiser setting; confirm against `update_model!`.
- `x_in`, `y_in`: training inputs and targets, passed to `update_model!` and `get_loss`.

# Keywords
- `max_epochs::Integer = typemax(Int)`: upper bound on the number of epochs. The default
  keeps the "run until converged" behaviour.

# Returns
A tuple `(model, all_losses, num_epochs)`:
- `model`: the model returned by the last `update_model!` call. This is the untrained model
  if no epoch ran (`max_epochs == 0`).
- `all_losses`: the recorded losses. The first entry is the loss of the untrained model, so
  `length(all_losses) == num_epochs + 1`.
- `num_epochs`: the number of epochs performed.

# Throws
- `ArgumentError` if `loss_change` is negative or `NaN`, or if `max_epochs` is negative.
"""
function run_training(loss_change,
                      learn,
                      x_in,
                      y_in;
                      max_epochs::Integer = typemax(Int))
    # A negative or NaN tolerance can never be met, so it would otherwise loop forever.
    loss_change ≥ 0 ||
        throw(ArgumentError("loss_change must be non-negative, got $loss_change"))
    max_epochs ≥ 0 ||
        throw(ArgumentError("max_epochs must be non-negative, got $max_epochs"))

    # Initialize Flux model
    flux_model = Dense(1 => 1)
    loss_prev = get_loss(flux_model, x_in, y_in)
    all_losses = [loss_prev]
    num_epochs = 0

    while num_epochs < max_epochs
        # Feed the result of the previous epoch into the next update.
        flux_model = update_model!(learn, flux_model, x_in, y_in)
        loss_new = get_loss(flux_model, x_in, y_in)
        num_epochs += 1
        push!(all_losses, loss_new)

        # A diverged loss (NaN/Inf) never satisfies the Δloss test below; stop here.
        if !isfinite(loss_new)
            @warn "Training stopped: loss became non-finite" num_epochs loss_new
            break
        end

        abs(loss_new - loss_prev) ≤ loss_change && break
        loss_prev = loss_new
    end

    return flux_model, all_losses, num_epochs
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment