Created
January 29, 2020 22:37
-
-
Save zredlined/40e3f4ec7a7ee2b56d271d51b099fb11 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def validate_csv(text: str, field_count: int = 4) -> bool:
    """Return True when *text* splits into exactly *field_count* fields.

    This is a quick sanity check, not a real CSV parser: fields are
    delimited by the literal two-character sequence ``", "`` (comma followed
    by one space). Quoting and escaping are not interpreted, and a trailing
    newline simply stays attached to the last field.

    Args:
        text: One candidate record.
        field_count: Number of fields a well-formed record must contain.

    Returns:
        True if the number of ``", "``-separated fields equals
        ``field_count``, otherwise False.
    """
    return len(text.split(", ")) == field_count
def generate_rec(model, start_string: str, num_chars: int, temperature: float,
                 expected_field_count: int) -> dict:
    """Sample characters from *model* and count valid/invalid CSV records.

    The model is primed with ``start_string`` and then fed its own sampled
    output one character at a time for ``num_chars`` steps. Each time a
    newline is sampled, the characters accumulated since the previous
    newline are checked with ``validate_csv``; records that pass are printed
    (stripped of surrounding whitespace) and counted as valid, the rest are
    counted as invalid.

    Relies on the module-level names ``tf``, ``char2idx``, ``idx2char`` and
    ``validate_csv``.

    Args:
        model: Callable model exposing ``reset_states()``.
            NOTE(review): presumably a stateful Keras character model built
            with batch size 1 -- confirm against the caller.
        start_string: Seed text; every character must be a key of
            ``char2idx``.
        num_chars: Number of characters to sample.
        temperature: Divisor applied to the model output before sampling;
            larger values flatten the sampling distribution, smaller values
            sharpen it.
        expected_field_count: Field count handed to ``validate_csv``.

    Returns:
        ``{'valid': <int>, 'invalid': <int>}``. Characters sampled after the
        last newline form an incomplete record and are dropped without being
        counted.
    """
    valid_record_count = 0
    invalid_record_count = 0

    # Convert the start string to ids (vectorizing) and add a leading axis.
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    # Characters of the record currently being generated; joined once per
    # record instead of growing a string with += on every step.
    record_chars = []

    # Start from a clean recurrent state so earlier calls do not leak in.
    model.reset_states()

    for _ in range(num_chars):
        predictions = model(input_eval)
        # Remove the batch dimension.
        predictions = tf.squeeze(predictions, 0)

        # Scale by the temperature, then draw one id from the resulting
        # categorical distribution; [-1, 0] keeps the sample for the last
        # row of the predictions.
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # The sampled character becomes the next model input; the model
        # keeps its own hidden state between calls.
        input_eval = tf.expand_dims([predicted_id], 0)

        next_char = idx2char[predicted_id]
        record_chars.append(next_char)

        if next_char == '\n':
            # A newline terminates the current record: validate and tally.
            text_generated = "".join(record_chars)
            if validate_csv(text_generated, expected_field_count):
                print(text_generated.strip())
                valid_record_count += 1
            else:
                invalid_record_count += 1
            record_chars = []

    return {'valid': valid_record_count, 'invalid': invalid_record_count}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment