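Character-level language model training loop: model() reads dinosaur names from dinos.txt, trains a recurrent network on them one name at a time, and periodically samples new names to show training progress. It depends on helper functions defined outside this snippet (initialize_parameters, get_initial_loss, optimize, smooth, sample, print_sample); a hedged sketch of the surrounding setup follows the function.
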
import numpy as np

def model(data, ix_to_char, char_to_ix, num_iterations=35000, n_a=50, dino_names=7, vocab_size=27):
    # Retrieve n_x and n_y from vocab_size
    n_x, n_y = vocab_size, vocab_size

    # Initialize parameters
    parameters = initialize_parameters(n_a, n_x, n_y)

    # Initialize the loss (needed so it can be smoothed during training)
    loss = get_initial_loss(vocab_size, dino_names)

    # Build the list of all dinosaur names (training examples)
    with open("dinos.txt") as f:
        examples = f.readlines()
    examples = [x.lower().strip() for x in examples]

    # Shuffle the list of dinosaur names
    np.random.seed(0)
    np.random.shuffle(examples)

    # Initialize the hidden state of the RNN
    a_prev = np.zeros((n_a, 1))
    # Optimization loop
    for j in range(num_iterations):

        # Define one training example (X, Y): X is the indexed name with a
        # leading None (treated as a zero input vector), and Y is X shifted
        # left by one, with "\n" appended as the end-of-name token
        index = j % len(examples)
        X = [None] + [char_to_ix[ch] for ch in examples[index]]
        Y = X[1:] + [char_to_ix["\n"]]
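        # For example, if the vocabulary maps '\n' -> 0 and 'a'-'z' -> 1-26
        # (a typical layout, assumed here), the name "abc" gives:
        #   X = [None, 1, 2, 3]   # None stands in for the zero input vector
        #   Y = [1, 2, 3, 0]      # X shifted left, ending with '\n'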
        # Perform one optimization step:
        # forward prop -> backward prop -> clip gradients -> update parameters
        # (optimize defaults to a learning rate of 0.01)
        curr_loss, gradients, a_prev = optimize(X, Y, a_prev, parameters)

        # Smooth the loss with a moving average so the printed curve is readable
        loss = smooth(loss, curr_loss)
        # Every 2000 iterations, sample dino_names names with sample() to
        # check that the model is learning properly
        if j % 2000 == 0:
            print('Iteration: %d, Loss: %f' % (j, loss) + '\n')

            seed = 0
            for name in range(dino_names):
                # Sample character indices and print the resulting name
                sampled_indices = sample(parameters, char_to_ix, seed)
                print_sample(sampled_indices, ix_to_char)
                seed += 1  # Increment the seed so results are reproducible for grading purposes

            print('\n')

    return parameters
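
For context, here is a minimal sketch of how the pieces around model() might be set up. The vocabulary construction and the two helper implementations below are assumptions, not part of the snippet above; initialize_parameters, optimize, sample, and print_sample must still be supplied separately.

# Build the character vocabulary from the training file
# (27 characters: 'a'-'z' plus '\n' as the end-of-name token).
data = open("dinos.txt").read().lower()
chars = sorted(set(data))
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}

def get_initial_loss(vocab_size, seq_length):
    # Cross-entropy of a uniform guess over the vocabulary, scaled by
    # an assumed typical name length (seq_length).
    return -np.log(1.0 / vocab_size) * seq_length

def smooth(loss, curr_loss):
    # Exponential moving average of the loss; the 0.999 factor is an assumption.
    return loss * 0.999 + curr_loss * 0.001

# Train the model and keep the learned parameters.
parameters = model(data, ix_to_char, char_to_ix, num_iterations=35000)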