Last active
May 18, 2020 16:11
-
-
Save zredlined/314566048b88573c80c015386af0f7af to your computer and use it in GitHub Desktop.
Optimal settings for training a synthetic data generation model on the UCI heart disease dataset from Kaggle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path

from gretel_synthetics.config import LocalConfig

# Build one config shared by both training and generation, tuned for CPU.
# The library defaults for ``max_chars`` and ``epochs`` are better suited
# to GPUs, so they are overridden here.
config = LocalConfig(
    max_lines=0,              # 0 = read all lines of the input file
    epochs=15,                # use 15-30 epochs for production runs
    vocab_size=20000,         # tokenizer model vocabulary size
    character_coverage=1.0,   # tokenizer model character coverage percent
    gen_chars=0,              # 0 = no cap on characters per generated line
    gen_lines=10000,          # number of text lines to generate
    rnn_units=256,            # dimensionality of LSTM output space
    batch_size=64,            # training batch size
    buffer_size=1000,         # buffer size used to shuffle the dataset
    dropout_rate=0.2,         # fraction of the inputs to drop
    dp=False,                 # differential privacy is DISABLED here; set True to enable
    dp_learning_rate=0.015,   # optimizer learning rate (applies when dp=True)
    dp_noise_multiplier=1.1,  # how much noise is added to gradients (dp=True)
    dp_l2_norm_clip=1.0,      # bound sensitivity to individual training points (dp=True)
    dp_microbatches=256,      # split batches into microbatches for parallelism (dp=True)
    checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),
    save_all_checkpoints=False,
    # NOTE(review): ``annotated_file`` is not defined in this snippet — it must
    # be set by the caller to a local filepath or S3 URI before this runs.
    input_data_path=annotated_file,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.