@rahuljantwal-8451
Created October 3, 2024 20:59
Process Dataset using NVTabular
import argparse
import os
import nvtabular as nvt
from nvtabular import ops
from merlin.io import Dataset
from merlin.dag.ops.subgraph import Subgraph


def parse_args():
    parser = argparse.ArgumentParser(description='Process the generated dataset using NVTabular')
    parser.add_argument('--input_path', type=str, default='./data/simulated/source_dataset/*.parquet', help='Input dataset path')
    parser.add_argument('--output_path', type=str, default='./data/simulated/processed_dataset', help='Output dataset path')
    parser.add_argument('--nvt_workspace', type=str, default='./data/simulated/nvt_workspace', help='NVTabular workspace directory')
    return parser.parse_args()
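
# Example invocation (the script filename is hypothetical; the paths shown
# are the defaults above):
#   python process_dataset.py \
#       --input_path './data/simulated/source_dataset/*.parquet' \
#       --output_path ./data/simulated/processed_dataset \
#       --nvt_workspace ./data/simulated/nvt_workspace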


def compile_workflow(col1, col2, col3, nvt_workspace_dir) -> nvt.Workflow:
    """Build a workflow that fills missing values, then encodes and tags
    the user columns (col1, col2) and the item column (col3)."""
    # Categorify operators for the user and item columns
    user_cat = ops.Categorify(dtype="int32",
                              out_path=nvt_workspace_dir,
                              on_host=False,
                              cat_cache="host")
    item_cat = ops.Categorify(dtype="int32",
                              out_path=nvt_workspace_dir,
                              on_host=False,
                              cat_cache="host")

    # Fill missing values in each column before encoding, so Categorify
    # never sees nulls
    user_id_filled = [col1] >> ops.FillMissing()
    user_feat_filled = [col2] >> ops.FillMissing()
    item_id_filled = [col3] >> ops.FillMissing()

    # Subgraph for user features (col1 is the user ID, col2 a user feature)
    subgraph_user = Subgraph(
        "user",
        (user_id_filled >> user_cat >> ops.TagAsUserID() >> ops.TagAsUserFeatures()) +
        (user_feat_filled >> user_cat >> ops.TagAsUserFeatures())
    )

    # Subgraph for item features (col3 is the item ID)
    subgraph_item = Subgraph(
        "item",
        item_id_filled >> item_cat >> ops.TagAsItemID() >> ops.TagAsItemFeatures()
    )

    # The two subgraphs already cover all three (filled) columns; emitting the
    # raw FillMissing branch as a separate output would duplicate column names
    outputs = subgraph_user + subgraph_item
    return nvt.Workflow(outputs)
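
# Note: because the feature branches above are wrapped in Subgraph nodes, the
# fitted sub-workflows can later be pulled back out by name, e.g. to serve the
# user and item pipelines separately (assuming an NVTabular release that
# provides Workflow.get_subworkflow):
#   user_workflow = workflow.get_subworkflow("user")
#   item_workflow = workflow.get_subworkflow("item")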


if __name__ == "__main__":
    args = parse_args()

    # Create a Dataset object over the source parquet files
    dataset = Dataset(args.input_path, engine='parquet')

    # Compile the workflow for the three columns
    workflow = compile_workflow('col1', 'col2', 'col3', args.nvt_workspace)

    # Fit the workflow, transform the dataset, and write the result to parquet
    workflow.fit_transform(dataset).to_parquet(output_path=args.output_path)

    # Save the fitted workflow into the NVTabular workspace so its files
    # do not mix with the parquet output
    workflow_path = os.path.join(args.nvt_workspace, 'workflow')
    workflow.save(workflow_path)

    # Reload the processed dataset to verify it is readable
    processed_dataset = Dataset(args.output_path, engine='parquet')
    print(f"Processed dataset saved to: {args.output_path}")
    print(f"Workflow saved to: {workflow_path}")