Created
January 6, 2021 14:19
-
-
Save bschifferer/78e2ca51d004ead898ea9d3d96879c6a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import os | |
# we can control how much memory to give tensorflow with this environment variable | |
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise | |
# TF will have claimed all free GPU memory | |
os.environ['TF_MEMORY_ALLOCATION'] = "0.8" # fraction of free memory | |
from nvtabular.loader.tensorflow import KerasSequenceLoader | |
BATCH_SIZE = 1024*128 | |
LABEL_NAMES = [<label column names>] | |
NUMERIC_COLUMNS = [<numeric column names>] | |
CATEGORICAL_COLUMNS = [<categorical column names>] | |
train_dataset_tf = KerasSequenceLoader( | |
'./train/*.parquet', # you could also use a glob pattern | |
batch_size=BATCH_SIZE, | |
label_names=LABEL_NAMES, | |
cat_names=CATEGORICAL_COLUMNS, | |
cont_names=NUMERIC_COLUMNS, | |
engine='parquet', | |
shuffle=True, | |
buffer_size=0.06, # how many batches to load at once | |
parts_per_chunk=1 | |
) | |
# Define model structure | |
model = get_model() | |
model.fit(train_dataset_tf, epochs=10) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment