@krsnewwave
Last active June 15, 2022 16:34
import os

import nvtabular as nvt
from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader

# define your categoricals, continuous variables, and labels
train_iter = TorchAsyncItr(
    train_dataset,
    batch_size=BATCH_SIZE,
    cats=CATEGORICAL_COLUMNS + CATEGORICAL_MH_COLUMNS,
    conts=NUMERIC_COLUMNS,
    labels=["rating"],
)
# TorchAsyncItr already yields full batches, so the wrapping loader
# disables its own batching and passes each batch through unchanged
train_loader = DLDataLoader(
    train_iter,
    batch_size=None,
    collate_fn=lambda x: x,
    pin_memory=False,
    num_workers=0,
)

# you can also use the saved workflow to get info about your data
# for example, if you have categoricals, you can get the vocabulary and embedding sizes
# (two dicts are returned here because the workflow contains multi-hot columns):
proc = nvt.Workflow.load(os.path.join(WORKING_DIR, "workflow"))
EMBEDDING_TABLE_SHAPES, MH_EMBEDDING_TABLE_SHAPES = nvt.ops.get_embedding_sizes(proc)
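
The embedding shapes retrieved above are typically used to size a model's input layer. Below is a minimal sketch of that arithmetic, assuming each dict maps a column name to a `(cardinality, embedding_dim)` pair and that each multi-hot column is pooled into a single vector. The column names, numbers, and the `model_input_width` helper are hypothetical and not part of NVTabular or the original gist.

```python
# Hypothetical shapes in the {column: (cardinality, embedding_dim)} layout.
# Column names and numbers are made up for illustration only.
example_single_hot_shapes = {"userId": (1000, 16), "movieId": (500, 16)}
example_multi_hot_shapes = {"genres": (20, 8)}
example_numeric_columns = ["age"]  # hypothetical continuous column


def model_input_width(single_hot, multi_hot, num_continuous):
    """Width of the concatenated feature vector fed to the first dense layer.

    Assumes every multi-hot column is pooled (e.g. summed or averaged)
    into one vector of its embedding dimension.
    """
    embedding_width = sum(dim for _, dim in single_hot.values())
    embedding_width += sum(dim for _, dim in multi_hot.values())
    return embedding_width + num_continuous


width = model_input_width(
    example_single_hot_shapes,
    example_multi_hot_shapes,
    len(example_numeric_columns),
)
print(width)
```

With real data, the two dicts returned from the workflow and `len(NUMERIC_COLUMNS)` would take the place of the made-up values.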