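# Captured from a GitHub Gist page by Randall Lin (GitHub: rllin).
# The original page text is kept below as comments so this file parses as Python:
#
# Skip to content
#
# Instantly share code, notes, and snippets.
#
# View rllin's full-sized avatar
#
# Randall Lin rllin
#
# View GitHub Profile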
import os
from typing import List, Optional
import torch
from torch import distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
from torchrec.distributed import TrainPipelineSparseDist