Skip to content

Instantly share code, notes, and snippets.

@bschifferer
Created January 6, 2021 14:19
Show Gist options
  • Save bschifferer/78e2ca51d004ead898ea9d3d96879c6a to your computer and use it in GitHub Desktop.
Save bschifferer/78e2ca51d004ead898ea9d3d96879c6a to your computer and use it in GitHub Desktop.
# Example script: stream parquet training data into a Keras model through
# NVTabular's KerasSequenceLoader. The column-name lists and get_model() below
# are placeholders that must be filled in before the script can run.
# NOTE(review): `tf` is not referenced anywhere in this snippet — presumably
# it is used inside the user-supplied get_model(); confirm before removing.
import tensorflow as tf
import os
# We can control how much memory to give TensorFlow with this environment variable.
# IMPORTANT: make sure you do this before you initialize TF's runtime, otherwise
# TF will have claimed all free GPU memory.
# NOTE(review): presumably the variable is read when the nvtabular TensorFlow
# loader is imported on the next line, so keep this assignment above that
# import — confirm against the installed NVTabular version.
os.environ['TF_MEMORY_ALLOCATION'] = "0.8" # fraction of free memory
from nvtabular.loader.tensorflow import KerasSequenceLoader
# 1024 * 128 = 131072; passed to the loader as batch_size below.
BATCH_SIZE = 1024*128
# The three lists below are template placeholders: replace each <...> with the
# real column names (as strings). As written they are not valid Python.
LABEL_NAMES = [<label column names>]
NUMERIC_COLUMNS = [<numeric column names>]
CATEGORICAL_COLUMNS = [<categorical column names>]
# Build the data loader that is handed directly to model.fit() below.
train_dataset_tf = KerasSequenceLoader(
'./train/*.parquet', # glob pattern selecting the parquet files under ./train
batch_size=BATCH_SIZE,
label_names=LABEL_NAMES,
cat_names=CATEGORICAL_COLUMNS,
cont_names=NUMERIC_COLUMNS,
engine='parquet',
shuffle=True,
buffer_size=0.06, # size of the read buffer. NOTE(review): NVTabular appears to treat values in (0, 1) as a fraction of GPU memory, not a batch count — confirm
parts_per_chunk=1
)
# Define model structure.
# get_model() is not defined in this snippet; supply a function that returns
# an object with a Keras-style fit() method.
model = get_model()
# Train for 10 epochs, reading batches from the loader built above.
model.fit(train_dataset_tf, epochs=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment