Skip to content

Instantly share code, notes, and snippets.

@zredlined
Created January 29, 2020 22:37
Show Gist options
  • Save zredlined/40e3f4ec7a7ee2b56d271d51b099fb11 to your computer and use it in GitHub Desktop.
Save zredlined/40e3f4ec7a7ee2b56d271d51b099fb11 to your computer and use it in GitHub Desktop.
def validate_csv(text, field_count=4):
    """Return True when *text* holds exactly *field_count* fields.

    This is a quick sanity check for generated records. Fields are taken to
    be separated by a comma followed by one space (", "). No CSV quoting
    rules are applied, so a quoted value containing ", " counts as several
    fields. A trailing newline stays attached to the last field and does not
    change the count.
    """
    # N separators always delimit N + 1 fields (possibly empty ones).
    separator_count = text.count(", ")
    return separator_count + 1 == field_count
def generate_rec(model, start_string, num_chars, temperature, expected_field_count):
    """Sample characters from a character-level model and tally the records.

    Characters are drawn one at a time. Whenever a newline is sampled, the
    line built so far is checked with ``validate_csv``. Valid lines are
    printed (with surrounding whitespace stripped) and counted; invalid lines
    are only counted. Characters sampled after the final newline form an
    incomplete record and are discarded without being counted.

    Relies on module-level ``tf``, ``char2idx`` and ``idx2char`` defined
    elsewhere in the file. These are presumably the vocabulary mappings built
    for training; confirm against the caller.

    Args:
        model: Callable that takes a ``(1, seq_len)`` tensor of character ids
            and exposes ``reset_states()``. Assumed to be a stateful RNN
            built with batch size 1 (TODO confirm).
        start_string: Non-empty seed text. Any characters after its last
            newline are treated as the start of the first record.
        num_chars: Number of characters to sample.
        temperature: Positive value that divides the logits before sampling.
            Smaller values make sampling more deterministic.
        expected_field_count: Number of fields a line must have to count as
            valid.

    Returns:
        dict with keys ``'valid'`` and ``'invalid'`` holding record counts.

    Raises:
        ValueError: if ``start_string`` is empty or ``temperature`` is not
            positive.
    """
    if not start_string:
        raise ValueError("start_string must be non-empty")
    if temperature <= 0:
        # Zero divides by zero below; negative values invert the distribution.
        raise ValueError("temperature must be positive")

    invalid_record_count = 0
    valid_record_count = 0

    # Convert the seed to character ids and add a batch dimension of 1.
    input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)

    # The seed's trailing partial line belongs to the first generated record.
    # Without this the first record would be missing its prefix.
    record_chars = list(start_string.rsplit("\n", 1)[-1])

    # Clear hidden state left over from any previous call.
    model.reset_states()
    for _ in range(num_chars):
        predictions = model(input_eval)
        # Remove the batch dimension.
        predictions = tf.squeeze(predictions, 0)
        # Temperature-scale the logits, then sample one id from the
        # categorical distribution at the last time step.
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # Feed back only the sampled id. The model presumably carries earlier
        # context in its hidden state (stateful RNN).
        input_eval = tf.expand_dims([predicted_id], 0)

        next_char = idx2char[predicted_id]
        record_chars.append(next_char)
        if next_char == "\n":
            record = "".join(record_chars)
            if validate_csv(record, expected_field_count):
                print(record.strip())
                valid_record_count += 1
            else:
                invalid_record_count += 1
            record_chars = []

    return {"valid": valid_record_count, "invalid": invalid_record_count}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment