Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Created June 15, 2022 16:27
Show Gist options
  • Save krsnewwave/ae93cabb0f6d61b5c0f2192e91f84442 to your computer and use it in GitHub Desktop.
Save krsnewwave/ae93cabb0f6d61b5c0f2192e91f84442 to your computer and use it in GitHub Desktop.
# Column groups, and the dtype map later handed to ``to_parquet`` as ``dtypes``:
# categorical ids map to int64, the label maps to float32.
CATEGORICAL_COLUMNS = ["userId", "movieId"]
LABEL_COLUMNS = ["rating"]
dict_dtypes = {
    **dict.fromkeys(CATEGORICAL_COLUMNS, np.int64),
    **dict.fromkeys(LABEL_COLUMNS, np.float32),
}

# Wrap the train/valid parquet files located under WORKING_DIR.
train_dataset = nvt.Dataset([os.path.join(WORKING_DIR, "train.parquet")])
# NOTE(review): valid_dataset is not referenced again in this snippet --
# presumably it is transformed elsewhere; confirm.
valid_dataset = nvt.Dataset([os.path.join(WORKING_DIR, "valid.parquet")])

# Fit the workflow on the training split only.
workflow.fit(train_dataset)

# Run the fitted workflow over the training split and write the result to
# the "train" directory under WORKING_DIR.
# NOTE(review): "genres" is listed in ``cats`` but has no entry in
# ``dict_dtypes`` -- confirm this is intended.
transformed_train = workflow.transform(train_dataset)
transformed_train.to_parquet(
    output_path=os.path.join(WORKING_DIR, "train"),
    shuffle=nvt.io.Shuffle.PER_PARTITION,
    cats=["userId", "movieId", "genres"],
    labels=["rating"],
    dtypes=dict_dtypes,
)

# Save the fitted workflow under WORKING_DIR/"workflow".
workflow.save(os.path.join(WORKING_DIR, "workflow"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment