Skip to content

Instantly share code, notes, and snippets.

# This file is largely inspired by and mostly follows the structure of
# ``fairscale.nn.FullyShardedDataParallel`` in
# https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
from collections import OrderedDict
import contextlib
from enum import Enum, auto
import functools
import gc
from itertools import chain
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@hjm-aws
hjm-aws / reduce_scatter_coalesce_graph_change.py
Last active November 15, 2022 18:58
Example demonstrating graph variation when callbacks accumulated in bwd pass
import os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, auto_wrap, default_auto_wrap_policy
from transformers import BertTokenizer, BertForMaskedLM, BertConfig
import functools
@hjm-aws
hjm-aws / test
Last active November 17, 2022 19:06
second version: i.e. my fixed version