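Post-grad FX forward graphs from a 2-rank run of an FSDP2 (per-parameter-sharding) model under torch.compile: the first graph is rank 0 (cuda:0), the second is rank 1 (cuda:1). Per the embedded stack traces, each graph packs and all-gathers the sharded parameter flats via foreach_all_gather, unpacks them via foreach_all_gather_copy_out, and then runs three F.linear layers (weights 64x32, 128x64, 256x128) on an 8x32 input.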
TRACED GRAPH
===== AFTER POST GRAD =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
getitem: "f32[1024]" = split_with_sizes[0]
getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
getitem_2: "f32[1024]" = _foreach_copy[0]
getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None
# No stacktrace found for following nodes
slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)
slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 0, 1056); empty = slice_scatter_default = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)
# No stacktrace found for following nodes
slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)
slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 0, 1056); slice_scatter_default_1 = slice_scatter_default_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 0, 1056); slice_scatter_default_3 = None
all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
getitem_16: "f32[4096]" = split_with_sizes_6[0]
getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
getitem_18: "f32[4096]" = _foreach_copy_2[0]
getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None
# No stacktrace found for following nodes
slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)
slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 0, 4160); empty_1 = slice_scatter_default_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)
# No stacktrace found for following nodes
slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)
slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 0, 4160); slice_scatter_default_5 = slice_scatter_default_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 0, 4160); slice_scatter_default_7 = None
all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
getitem_32: "f32[16384]" = split_with_sizes_12[0]
getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
getitem_34: "f32[16384]" = _foreach_copy_4[0]
getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None
# No stacktrace found for following nodes
slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)
slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 0, 16512); empty_2 = slice_scatter_default_8 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)
# No stacktrace found for following nodes
slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)
slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 0, 16512); slice_scatter_default_9 = slice_scatter_default_10 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 0, 16512); slice_scatter_default_11 = None
all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]
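The graph above is rank 0's copy (cuda:0, shard offsets 0..1056 / 0..4160 / 0..16512); the second graph below is rank 1's, identical except for the device index and the shard offsets (1056..2112 / 4160..8320 / 16512..33024). For orientation, the per-layer node sequence up to wait_tensor roughly corresponds to the sketch below. This is not the actual FSDP2 source: it assumes a 2-rank process group is already initialized, takes the fp32 1-D parameter shards (e.g. primals_2 / primals_3) as sharded_flats, ignores padding and dtype casting, and all function and variable names are hypothetical.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def sketch_foreach_all_gather(sharded_flats, group):
    # Hypothetical helper mirroring the traced nodes; names are illustrative only.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard_numel = sum(t.numel() for t in sharded_flats)   # 1024 + 32 = 1056 for layer 1
    # "empty: f32[2112]" -- flat all-gather output buffer of world_size * shard_numel elements
    all_gather_output = torch.empty(world_size * shard_numel, dtype=torch.float32, device="cuda")
    # "slice_1: f32[1056]" -- this rank's input region inside the output buffer
    # (offset 0 on rank 0, offset 1056 on rank 1, matching the two graphs)
    all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
    # "split_with_sizes" + "_foreach_copy" -- pack the sharded params into that region
    dsts = torch.split(all_gather_input, [t.numel() for t in sharded_flats])
    torch._foreach_copy_(list(dsts), list(sharded_flats))
    # "all_gather_into_tensor" + "wait_tensor" -- functional collective plus explicit wait
    gathered = funcol.all_gather_tensor(all_gather_input, gather_dim=0, group=group)
    return funcol.wait_tensor(gathered)

The slice_scatter chains marked "No stacktrace found" are how functionalization re-expresses the in-place _foreach_copy_ into slices of the empty buffer; they do not correspond to extra work in the eager program.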
TRACED GRAPH
===== AFTER POST GRAD =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
getitem: "f32[1024]" = split_with_sizes[0]
getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
getitem_2: "f32[1024]" = _foreach_copy[0]
getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None
# No stacktrace found for following nodes
slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)
slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 1056, 2112); empty = slice_scatter_default = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)
# No stacktrace found for following nodes
slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)
slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 1056, 2112); slice_scatter_default_1 = slice_scatter_default_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 1056, 2112); slice_scatter_default_3 = None
all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
getitem_16: "f32[4096]" = split_with_sizes_6[0]
getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
getitem_18: "f32[4096]" = _foreach_copy_2[0]
getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None
# No stacktrace found for following nodes
slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)
slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 4160, 8320); empty_1 = slice_scatter_default_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)
# No stacktrace found for following nodes
slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)
slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4160, 8320); slice_scatter_default_5 = slice_scatter_default_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4160, 8320); slice_scatter_default_7 = None
all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
getitem_32: "f32[16384]" = split_with_sizes_12[0]
getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
_foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
getitem_34: "f32[16384]" = _foreach_copy_4[0]
getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None
# No stacktrace found for following nodes
slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)
slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 16512, 33024); empty_2 = slice_scatter_default_8 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)
# No stacktrace found for following nodes
slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)
slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 16512, 33024); slice_scatter_default_9 = slice_scatter_default_10 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 16512, 33024); slice_scatter_default_11 = None
all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
_foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]
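After wait_tensor, the foreach_all_gather_copy_out nodes (reshape to [world_size, -1], split_with_sizes, clone/contiguous, as_strided, _foreach_copy_) unpack the gathered buffer back into the unsharded parameters that feed each addmm. Under the same assumptions as the sketch above (again, not the actual FSDP2 source; names are hypothetical), the pattern is roughly:

import torch

def sketch_copy_out(gathered_flat, unsharded_params, split_sizes):
    # Hypothetical helper mirroring the copy-out nodes; names are illustrative only.
    world_size = gathered_flat.numel() // sum(split_sizes)
    # "view_1: f32[2, 1056]" -- one row of the gathered buffer per rank
    per_rank = gathered_flat.view(world_size, -1)
    # "split_with_sizes_4" -- per-parameter columns, e.g. [1024, 32]
    splits = torch.split(per_rank, split_sizes, dim=1)
    # "clone" + "reshape" + "as_strided" -- make each column contiguous, flatten it,
    # then view it with the unsharded parameter's shape and stride
    splits_unpadded = [
        torch.as_strided(s.contiguous().view(s.numel()), p.shape, p.stride(), 0)
        for s, p in zip(splits, unsharded_params)
    ]
    # "_foreach_copy_1" -- write the gathered values into the unsharded params in place
    torch._foreach_copy_(list(unsharded_params), splits_unpadded)

In the graphs, primals_4/primals_5 (and primals_9/primals_10, primals_13/primals_14) play the role of unsharded_params, and the _foreach_copy outputs are what feed the permute/addmm of each F.linear call.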