Process Dataset using NVTabular
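This script builds an NVTabular preprocessing pipeline: it fills missing values in three columns, encodes the user columns (col1, col2) and the item column (col3) with Categorify, tags them with Merlin's user/item schema tags, and writes both the transformed parquet dataset and the fitted workflow to disk.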
import argparse
import os

import nvtabular as nvt
from nvtabular import ops
from merlin.io import Dataset
from merlin.dag.ops.subgraph import Subgraph

def parse_args():
    parser = argparse.ArgumentParser(description='Process the generated dataset using NVTabular')
    parser.add_argument('--input_path', type=str, default='./data/simulated/source_dataset/*.parquet', help='Input dataset path')
    parser.add_argument('--output_path', type=str, default='./data/simulated/processed_dataset', help='Output dataset path')
    parser.add_argument('--nvt_workspace', type=str, default='./data/simulated/nvt_workspace', help='NVTabular workspace directory')
    return parser.parse_args()

def compile_workflow(col1, col2, col3, nvt_workspace_dir) -> nvt.Workflow:
    cols = [col1, col2, col3]

    # Categorify operations for the user and item columns
    user_cat = ops.Categorify(dtype="int32",
                              out_path=nvt_workspace_dir,
                              on_host=False,
                              cat_cache="host")
    item_cat = ops.Categorify(dtype="int32",
                              out_path=nvt_workspace_dir,
                              on_host=False,
                              cat_cache="host")

    # Handle missing values in all columns before encoding
    filled = cols >> ops.FillMissing()

    # Subgraph for user features (col1 and col2)
    subgraph_user = Subgraph(
        "user",
        (filled[col1] >> user_cat >> ops.TagAsUserID() >> ops.TagAsUserFeatures()) +
        (filled[col2] >> user_cat >> ops.TagAsUserFeatures())
    )

    # Subgraph for item features (col3)
    subgraph_item = Subgraph(
        "item",
        filled[col3] >> item_cat >> ops.TagAsItemID() >> ops.TagAsItemFeatures()
    )

    # The FillMissing branch feeds the subgraphs rather than being emitted as a
    # separate output; emitting it alongside them would produce duplicate
    # column names in the transformed dataset.
    outputs = subgraph_user + subgraph_item
    return nvt.Workflow(outputs)

if __name__ == "__main__":
    args = parse_args()

    # Create a Dataset object from the source parquet files
    dataset = Dataset(args.input_path, engine='parquet')

    # Compile the workflow
    workflow = compile_workflow('col1', 'col2', 'col3', args.nvt_workspace)

    # Fit the workflow statistics and write the transformed dataset
    workflow.fit_transform(dataset).to_parquet(output_path=args.output_path)

    # Save the fitted workflow to its own subdirectory so its artifacts
    # don't mix with the parquet output files
    workflow_path = os.path.join(args.output_path, 'workflow')
    workflow.save(workflow_path)

    # Reload the processed dataset to verify the output is readable
    processed_dataset = Dataset(os.path.join(args.output_path, '*.parquet'), engine='parquet')

    print(f"Processed dataset saved to: {args.output_path}")
    print(f"Workflow saved to: {workflow_path}")