@ezyang
Created October 30, 2022 17:20
Sweep logs for symbolic-shapes (TORCHDYNAMO_DYNAMIC_SHAPES=0) - Sun Oct 30 06:22:09 PDT 2022
This file has been truncated.
Running BERT_pytorch...
cuda train BERT_pytorch PASS
Running Background_Matting...
cuda train Background_Matting PASS
WARNING:root:DALLE2_pytorch failed to load
Eager model failed to run
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 903, in validate_model
self.model_iter_fn(model, example_inputs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 337, in forward_and_backward_pass
self.grad_scaler.scale(loss).backward()
File "/data/users/ezyang/pytorch-tmp/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/data/users/ezyang/pytorch-tmp/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1746, in main
device, name, model, example_inputs, batch_size = runner.load_model(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 282, in load_model
self.validate_model(model, example_inputs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 905, in validate_model
raise NotImplementedError("Eager model failed to run")
NotImplementedError: Eager model failed to run
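
A minimal repro of the underlying RuntimeError outside the benchmark (illustrative only, not the DALLE2_pytorch code path): calling backward() on a loss built entirely from tensors that do not require grad raises the same message.

import torch

x = torch.randn(4)    # requires_grad defaults to False, so no grad_fn is recorded
loss = (x * 2).sum()
loss.backward()       # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
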
Running LearningToPaint...
cuda train LearningToPaint PASS
Running Super_SloMo...
cuda train Super_SloMo PASS
Running alexnet...
cuda train alexnet PASS
Running attention_is_all_you_need_pytorch...
cuda train attention_is_all_you_need_pytorch PASS
Running dcgan...
cuda train dcgan PASS
Running demucs...
[2022-10-30 06:27:40,119] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of LSTM
cuda train demucs PASS
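
Per the warning above, the AOT Autograd backend is bypassed when the traced module contains an nn.LSTM, which demucs does. A rough sketch of that condition (illustrative, not the torch._dynamo source):

import torch.nn as nn

def contains_lstm(mod: nn.Module) -> bool:
    # Any nn.LSTM submodule is enough to hit the "Unable to use Aot Autograd" path.
    return any(isinstance(m, nn.LSTM) for m in mod.modules())
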
Running densenet121...
cuda train densenet121 PASS
WARNING:root:detectron2_fcos_r_50_fpn failed to load
FCOS train is not supported by upstream detectron2. See GH Issue: https://github.com/facebookresearch/detectron2/issues/4369.
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1746, in main
device, name, model, example_inputs, batch_size = runner.load_model(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 251, in load_model
benchmark = benchmark_cls(
File "/data/users/ezyang/benchmark/torchbenchmark/util/model.py", line 18, in __call__
obj = type.__call__(cls, *args, **kwargs)
File "/data/users/ezyang/benchmark/torchbenchmark/models/detectron2_fcos_r_50_fpn/__init__.py", line 15, in __init__
super().__init__(variant="COCO-Detection/fcos_R_50_FPN_1x.py", test=test, device=device,
File "/data/users/ezyang/benchmark/torchbenchmark/util/framework/detectron2/model_factory.py", line 100, in __init__
loader = self.setup_train(cfg, args)
File "/data/users/ezyang/benchmark/torchbenchmark/util/framework/detectron2/model_factory.py", line 110, in setup_train
raise NotImplementedError("FCOS train is not supported by upstream detectron2. " \
NotImplementedError: FCOS train is not supported by upstream detectron2. See GH Issue: https://github.com/facebookresearch/detectron2/issues/4369.
WARNING:root:detectron2_maskrcnn_r_50_c4 failed to load
Eager model failed to run
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 903, in validate_model
self.model_iter_fn(model, example_inputs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 336, in forward_and_backward_pass
loss = self.compute_loss(pred)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 326, in compute_loss
return reduce_to_scalar_loss(pred)
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/testing.py", line 87, in reduce_to_scalar_loss
return sum([reduce_to_scalar_loss(x) for x in out]) / len(out)
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/testing.py", line 87, in <listcomp>
return sum([reduce_to_scalar_loss(x) for x in out]) / len(out)
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/testing.py", line 97, in reduce_to_scalar_loss
return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/testing.py", line 97, in <listcomp>
return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len(
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/testing.py", line 102, in reduce_to_scalar_loss
raise NotImplementedError("Don't know how to reduce", type(out))
NotImplementedError: ("Don't know how to reduce", <class 'detectron2.structures.instances.Instances'>)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1746, in main
device, name, model, example_inputs, batch_size = runner.load_model(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 282, in load_model
self.validate_model(model, example_inputs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 905, in validate_model
raise NotImplementedError("Eager model failed to run")
NotImplementedError: Eager model failed to run
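
The frames above show reduce_to_scalar_loss recursing through lists and dicts of model outputs and giving up on detectron2's Instances type. A rough sketch of that reduction, inferred from the traceback (not the torch._dynamo.testing source):

import torch

def reduce_to_scalar_loss(out):
    # Collapse a model output into a scalar so a backward pass can be exercised.
    if isinstance(out, torch.Tensor):
        return out.sum() / out.numel()
    if isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
    if isinstance(out, dict):
        return sum(reduce_to_scalar_loss(v) for v in out.values()) / len(out)
    raise NotImplementedError("Don't know how to reduce", type(out))
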
Running dlrm...
incomplete graph:
class joint_forward_backward(torch.nn.Module):
def forward(self, orig_primals, orig_tangents):
orig_primals_1: f32[512, 512], [512, 1], orig_primals_2: f32[512], [1], orig_primals_3: f32[64, 512], [512, 1], orig_primals_4: f32[64], [1], orig_primals_5: f32[1000000, 64], [64, 1], orig_primals_6: f32[1000000, 64], [64, 1], orig_primals_7: f32[1000000, 64], [64, 1], orig_primals_8: f32[1000000, 64], [64, 1], orig_primals_9: f32[1000000, 64], [64, 1], orig_primals_10: f32[1000000, 64], [64, 1], orig_primals_11: f32[1000000, 64], [64, 1], orig_primals_12: f32[1000000, 64], [64, 1], orig_primals_13: f32[1024, 100], [100, 1], orig_primals_14: f32[1024], [1], orig_primals_15: f32[1024, 1024], [1024, 1], orig_primals_16: f32[1024], [1], orig_primals_17: f32[1024, 1024], [1024, 1], orig_primals_18: f32[1024], [1], orig_primals_19: f32[1, 1024], [1024, 1], orig_primals_20: f32[1], [1], orig_primals_21: f32[2, 512], [512, 1], orig_primals_22: i64[s2, 2], [2, 1], orig_primals_23: i64[s3], [1], orig_primals_24: i64[s3], [1], orig_primals_25: i64[s3], [1], orig_primals_26: i64[s3], [1], orig_primals_27: i64[s3], [1], orig_primals_28: i64[s3], [1], orig_primals_29: i64[s3], [1], orig_primals_30: i64[s3], [1], orig_tangents_1: f32[2, 1], [1, 1], = fx_pytree.tree_flatten_spec([orig_primals, orig_tangents], self._in_spec)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t: f32[512, 512], [1, 512] = torch.ops.aten.t.default(orig_primals_1); orig_primals_1 = None
addmm: f32[2, 512], [512, 1] = torch.ops.aten.addmm.default(orig_primals_2, orig_primals_21, t); orig_primals_2 = t = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu: f32[2, 512], [512, 1] = torch.ops.aten.relu.default(addmm); addmm = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_1: f32[512, 64], [1, 512] = torch.ops.aten.t.default(orig_primals_3); orig_primals_3 = None
addmm_1: f32[2, 64], [64, 1] = torch.ops.aten.addmm.default(orig_primals_4, relu, t_1); orig_primals_4 = relu = t_1 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu_1: f32[2, 64], [64, 1] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 0)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag = torch.ops.aten._embedding_bag.default(orig_primals_5, orig_primals_23, select, False, 0, True, None); orig_primals_5 = orig_primals_23 = select = None
getitem: f32[2, 64], [64, 1] = _embedding_bag[0]
getitem_1: i64[s3], [1] = _embedding_bag[1]
getitem_2: i64[2], [1] = _embedding_bag[2]
getitem_3: i64[0], [1] = _embedding_bag[3]; _embedding_bag = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_1: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 1)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_1 = torch.ops.aten._embedding_bag.default(orig_primals_6, orig_primals_24, select_1, False, 0, True, None); orig_primals_6 = orig_primals_24 = select_1 = None
getitem_4: f32[2, 64], [64, 1] = _embedding_bag_1[0]
getitem_5: i64[s3], [1] = _embedding_bag_1[1]
getitem_6: i64[2], [1] = _embedding_bag_1[2]
getitem_7: i64[0], [1] = _embedding_bag_1[3]; _embedding_bag_1 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_2: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 2)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_2 = torch.ops.aten._embedding_bag.default(orig_primals_7, orig_primals_25, select_2, False, 0, True, None); orig_primals_7 = orig_primals_25 = select_2 = None
getitem_8: f32[2, 64], [64, 1] = _embedding_bag_2[0]
getitem_9: i64[s3], [1] = _embedding_bag_2[1]
getitem_10: i64[2], [1] = _embedding_bag_2[2]
getitem_11: i64[0], [1] = _embedding_bag_2[3]; _embedding_bag_2 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_3: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 3)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_3 = torch.ops.aten._embedding_bag.default(orig_primals_8, orig_primals_26, select_3, False, 0, True, None); orig_primals_8 = orig_primals_26 = select_3 = None
getitem_12: f32[2, 64], [64, 1] = _embedding_bag_3[0]
getitem_13: i64[s3], [1] = _embedding_bag_3[1]
getitem_14: i64[2], [1] = _embedding_bag_3[2]
getitem_15: i64[0], [1] = _embedding_bag_3[3]; _embedding_bag_3 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_4: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 4)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_4 = torch.ops.aten._embedding_bag.default(orig_primals_9, orig_primals_27, select_4, False, 0, True, None); orig_primals_9 = orig_primals_27 = select_4 = None
getitem_16: f32[2, 64], [64, 1] = _embedding_bag_4[0]
getitem_17: i64[s3], [1] = _embedding_bag_4[1]
getitem_18: i64[2], [1] = _embedding_bag_4[2]
getitem_19: i64[0], [1] = _embedding_bag_4[3]; _embedding_bag_4 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_5: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 5)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_5 = torch.ops.aten._embedding_bag.default(orig_primals_10, orig_primals_28, select_5, False, 0, True, None); orig_primals_10 = orig_primals_28 = select_5 = None
getitem_20: f32[2, 64], [64, 1] = _embedding_bag_5[0]
getitem_21: i64[s3], [1] = _embedding_bag_5[1]
getitem_22: i64[2], [1] = _embedding_bag_5[2]
getitem_23: i64[0], [1] = _embedding_bag_5[3]; _embedding_bag_5 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_6: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 6)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_6 = torch.ops.aten._embedding_bag.default(orig_primals_11, orig_primals_29, select_6, False, 0, True, None); orig_primals_11 = orig_primals_29 = select_6 = None
getitem_24: f32[2, 64], [64, 1] = _embedding_bag_6[0]
getitem_25: i64[s3], [1] = _embedding_bag_6[1]
getitem_26: i64[2], [1] = _embedding_bag_6[2]
getitem_27: i64[0], [1] = _embedding_bag_6[3]; _embedding_bag_6 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:288, code: sparse_offset_group_batch = lS_o[k]
select_7: i64[2], [1] = torch.ops.aten.select.int(orig_primals_22, 0, 7); orig_primals_22 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
_embedding_bag_7 = torch.ops.aten._embedding_bag.default(orig_primals_12, orig_primals_30, select_7, False, 0, True, None); orig_primals_12 = select_7 = None
getitem_28: f32[2, 64], [64, 1] = _embedding_bag_7[0]
getitem_29: i64[s3], [1] = _embedding_bag_7[1]
getitem_30: i64[2], [1] = _embedding_bag_7[2]
getitem_31: i64[0], [1] = _embedding_bag_7[3]; _embedding_bag_7 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:306, code: T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
cat: f32[2, 576], [576, 1] = torch.ops.aten.cat.default([relu_1, getitem, getitem_4, getitem_8, getitem_12, getitem_16, getitem_20, getitem_24, getitem_28], 1); getitem = getitem_4 = getitem_8 = getitem_12 = getitem_16 = getitem_20 = getitem_24 = getitem_28 = None
view: f32[2, 9, 64], [576, 64, 1] = torch.ops.aten.view.default(cat, [2, -1, 64]); cat = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:308, code: Z = torch.bmm(T, torch.transpose(T, 1, 2))
transpose: f32[2, 64, 9], [576, 1, 64] = torch.ops.aten.transpose.int(view, 1, 2)
bmm: f32[2, 9, 9], [81, 9, 1] = torch.ops.aten.bmm.default(view, transpose)
# No stacktrace found for following nodes
_tensor_constant0 = self._tensor_constant0
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:319, code: li = torch.tensor([i for i in range(ni) for j in range(i + offset)], device=x.device)
lift_fresh_copy: i64[36], [1] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
# No stacktrace found for following nodes
_tensor_constant1 = self._tensor_constant1
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:320, code: lj = torch.tensor([j for i in range(nj) for j in range(i + offset)], device=x.device)
lift_fresh_copy_1: i64[36], [1] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1); _tensor_constant1 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:321, code: Zflat = Z[:, li, lj]
slice_1: f32[2, 9, 9], [81, 9, 1] = torch.ops.aten.slice.Tensor(bmm, 0, 0, 9223372036854775807); bmm = None
index: f32[2, 36], [36, 1] = torch.ops.aten.index.Tensor(slice_1, [None, lift_fresh_copy, lift_fresh_copy_1])
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:323, code: R = torch.cat([x] + [Zflat], dim=1)
cat_1: f32[2, 100], [100, 1] = torch.ops.aten.cat.default([relu_1, index], 1); relu_1 = index = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_2: f32[100, 1024], [1, 100] = torch.ops.aten.t.default(orig_primals_13); orig_primals_13 = None
addmm_2: f32[2, 1024], [1024, 1] = torch.ops.aten.addmm.default(orig_primals_14, cat_1, t_2); orig_primals_14 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu_2: f32[2, 1024], [1024, 1] = torch.ops.aten.relu.default(addmm_2); addmm_2 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_3: f32[1024, 1024], [1, 1024] = torch.ops.aten.t.default(orig_primals_15); orig_primals_15 = None
addmm_3: f32[2, 1024], [1024, 1] = torch.ops.aten.addmm.default(orig_primals_16, relu_2, t_3); orig_primals_16 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu_3: f32[2, 1024], [1024, 1] = torch.ops.aten.relu.default(addmm_3); addmm_3 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_4: f32[1024, 1024], [1, 1024] = torch.ops.aten.t.default(orig_primals_17); orig_primals_17 = None
addmm_4: f32[2, 1024], [1024, 1] = torch.ops.aten.addmm.default(orig_primals_18, relu_3, t_4); orig_primals_18 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu_4: f32[2, 1024], [1024, 1] = torch.ops.aten.relu.default(addmm_4); addmm_4 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_5: f32[1024, 1], [1, 1024] = torch.ops.aten.t.default(orig_primals_19); orig_primals_19 = None
addmm_5: f32[2, 1], [1, 1] = torch.ops.aten.addmm.default(orig_primals_20, relu_4, t_5); orig_primals_20 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
relu_5: f32[2, 1], [1, 1] = torch.ops.aten.relu.default(addmm_5); addmm_5 = None
# No stacktrace found for following nodes
is_same_size = torch.ops.aten.is_same_size.default(relu_5, orig_tangents_1)
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
threshold_backward: f32[2, 1], [1, 1] = torch.ops.aten.threshold_backward.default(orig_tangents_1, relu_5, 0); orig_tangents_1 = relu_5 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_6: f32[1, 1024], [1024, 1] = torch.ops.aten.t.default(t_5); t_5 = None
mm: f32[2, 1024], [1024, 1] = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
t_7: f32[1, 2], [1, 1] = torch.ops.aten.t.default(threshold_backward)
mm_1: f32[1, 1024], [1024, 1] = torch.ops.aten.mm.default(t_7, relu_4); t_7 = None
t_8: f32[1024, 1], [1, 1024] = torch.ops.aten.t.default(mm_1); mm_1 = None
sum_1: f32[1, 1], [0, 1] = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
view_1: f32[1], [0] = torch.ops.aten.view.default(sum_1, [1]); sum_1 = None
t_9: f32[1, 1024], [1024, 1] = torch.ops.aten.t.default(t_8); t_8 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
threshold_backward_1: f32[2, 1024], [1024, 1] = torch.ops.aten.threshold_backward.default(mm, relu_4, 0); mm = relu_4 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_10: f32[1024, 1024], [1024, 1] = torch.ops.aten.t.default(t_4); t_4 = None
mm_2: f32[2, 1024], [1024, 1] = torch.ops.aten.mm.default(threshold_backward_1, t_10); t_10 = None
t_11: f32[1024, 2], [1, 1024] = torch.ops.aten.t.default(threshold_backward_1)
mm_3: f32[1024, 1024], [1024, 1] = torch.ops.aten.mm.default(t_11, relu_3); t_11 = None
t_12: f32[1024, 1024], [1, 1024] = torch.ops.aten.t.default(mm_3); mm_3 = None
sum_2: f32[1, 1024], [0, 1] = torch.ops.aten.sum.dim_IntList(threshold_backward_1, [0], True); threshold_backward_1 = None
view_2: f32[1024], [1] = torch.ops.aten.view.default(sum_2, [1024]); sum_2 = None
t_13: f32[1024, 1024], [1024, 1] = torch.ops.aten.t.default(t_12); t_12 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
threshold_backward_2: f32[2, 1024], [1024, 1] = torch.ops.aten.threshold_backward.default(mm_2, relu_3, 0); mm_2 = relu_3 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_14: f32[1024, 1024], [1024, 1] = torch.ops.aten.t.default(t_3); t_3 = None
mm_4: f32[2, 1024], [1024, 1] = torch.ops.aten.mm.default(threshold_backward_2, t_14); t_14 = None
t_15: f32[1024, 2], [1, 1024] = torch.ops.aten.t.default(threshold_backward_2)
mm_5: f32[1024, 1024], [1024, 1] = torch.ops.aten.mm.default(t_15, relu_2); t_15 = None
t_16: f32[1024, 1024], [1, 1024] = torch.ops.aten.t.default(mm_5); mm_5 = None
sum_3: f32[1, 1024], [0, 1] = torch.ops.aten.sum.dim_IntList(threshold_backward_2, [0], True); threshold_backward_2 = None
view_3: f32[1024], [1] = torch.ops.aten.view.default(sum_3, [1024]); sum_3 = None
t_17: f32[1024, 1024], [1024, 1] = torch.ops.aten.t.default(t_16); t_16 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
threshold_backward_3: f32[2, 1024], [1024, 1] = torch.ops.aten.threshold_backward.default(mm_4, relu_2, 0); mm_4 = relu_2 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:274, code: return layers(x)
t_18: f32[1024, 100], [100, 1] = torch.ops.aten.t.default(t_2); t_2 = None
mm_6: f32[2, 100], [100, 1] = torch.ops.aten.mm.default(threshold_backward_3, t_18); t_18 = None
t_19: f32[1024, 2], [1, 1024] = torch.ops.aten.t.default(threshold_backward_3)
mm_7: f32[1024, 100], [100, 1] = torch.ops.aten.mm.default(t_19, cat_1); t_19 = cat_1 = None
t_20: f32[100, 1024], [1, 100] = torch.ops.aten.t.default(mm_7); mm_7 = None
sum_4: f32[1, 1024], [0, 1] = torch.ops.aten.sum.dim_IntList(threshold_backward_3, [0], True); threshold_backward_3 = None
view_4: f32[1024], [1] = torch.ops.aten.view.default(sum_4, [1024]); sum_4 = None
t_21: f32[1024, 100], [100, 1] = torch.ops.aten.t.default(t_20); t_20 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:323, code: R = torch.cat([x] + [Zflat], dim=1)
slice_2: f32[2, 64], [100, 1] = torch.ops.aten.slice.Tensor(mm_6, 1, 0, 64)
slice_3: f32[2, 36], [100, 1] = torch.ops.aten.slice.Tensor(mm_6, 1, 64, 100); mm_6 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:321, code: Zflat = Z[:, li, lj]
sym_size: Sym(2) = torch.ops.aten.sym_size(slice_1, 0); slice_1 = None
sym_size_1: Sym(9) = torch.ops.aten.sym_size(view, 1)
new_zeros: f32[2, 9, 9], [81, 9, 1] = torch.ops.aten.new_zeros.default(slice_3, [sym_size, sym_size_1, sym_size_1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); sym_size = None
index_put: f32[2, 9, 9], [81, 9, 1] = torch.ops.aten.index_put.default(new_zeros, [None, lift_fresh_copy, lift_fresh_copy_1], slice_3, True); new_zeros = lift_fresh_copy = lift_fresh_copy_1 = slice_3 = None
sym_size_2: Sym(2) = torch.ops.aten.sym_size(orig_primals_21, 0); orig_primals_21 = None
slice_backward: f32[2, 9, 9], [81, 9, 1] = torch.ops.aten.slice_backward.default(index_put, [sym_size_2, sym_size_1, sym_size_1], 0, 0, 9223372036854775807, 1); index_put = sym_size_1 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:308, code: Z = torch.bmm(T, torch.transpose(T, 1, 2))
transpose_1: f32[2, 64, 9], [576, 1, 64] = torch.ops.aten.transpose.int(view, 1, 2); view = None
bmm_1: f32[2, 64, 9], [576, 9, 1] = torch.ops.aten.bmm.default(transpose_1, slice_backward); transpose_1 = None
transpose_2: f32[2, 9, 64], [576, 64, 1] = torch.ops.aten.transpose.int(transpose, 1, 2); transpose = None
bmm_2: f32[2, 9, 64], [576, 64, 1] = torch.ops.aten.bmm.default(slice_backward, transpose_2); slice_backward = transpose_2 = None
transpose_3: f32[2, 9, 64], [576, 1, 9] = torch.ops.aten.transpose.int(bmm_1, 1, 2); bmm_1 = None
# Gradient addition node due to multiple use of tensor around:, File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:308, code: Z = torch.bmm(T, torch.transpose(T, 1, 2))
add: f32[2, 9, 64], [576, 64, 1] = torch.ops.aten.add.Tensor(bmm_2, transpose_3); bmm_2 = transpose_3 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:306, code: T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
view_5: f32[2, 576], [576, 1] = torch.ops.aten.view.default(add, [sym_size_2, 576]); add = sym_size_2 = None
slice_4: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 0, 64)
slice_5: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 64, 128)
slice_6: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 128, 192)
slice_7: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 192, 256)
slice_8: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 256, 320)
slice_9: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 320, 384)
slice_10: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 384, 448)
slice_11: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 448, 512)
slice_12: f32[2, 64], [576, 1] = torch.ops.aten.slice.Tensor(view_5, 1, 512, 576); view_5 = None
# Gradient addition node due to multiple use of tensor around:, File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:306, code: T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
add_1: f32[2, 64], [64, 1] = torch.ops.aten.add.Tensor(slice_2, slice_4); slice_2 = slice_4 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
index_select: f32[s3, 64], [64, 1] = torch.ops.aten.index_select.default(slice_12, 0, getitem_29)
sym_size_3: Sym(s3) = torch.ops.aten.sym_size(orig_primals_30, 0)
# No stacktrace found for following nodes
floordiv: Sym(s3) = sym_size_3 // 1; sym_size_3 = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
view_6: i64[1, s3], [s3, 1] = torch.ops.aten.view.default(orig_primals_30, [1, floordiv]); orig_primals_30 = floordiv = None
sym_numel: Sym(s3) = torch.ops.aten.sym_numel(getitem_29); getitem_29 = None
sym_size_4: Sym(64) = torch.ops.aten.sym_size(slice_12, 1); slice_12 = None
# No stacktrace found for following nodes
mul: Sym(64*s3) = sym_numel * sym_size_4; sym_numel = None
floordiv_1: Sym(s3) = mul // sym_size_4; mul = None
# File: /data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py:295, code: V = E(sparse_index_group_batch, sparse_offset_group_batch)
view_7: f32[s3, 64], [64, 1] = torch.ops.aten.view.default(index_select, [floordiv_1, sym_size_4]); index_select = floordiv_1 = None
_sparse_coo_tensor_with_dims_and_tensors: f32[1000000, 64], [0, 0] = torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors.default(1, 1, [1000000, sym_size_4], view_6, view_7, dtype = torch.float32, layout = torch.sparse_coo, device = device(type='cuda', index=0), pin_memory = None); sym_size_4 = view_6 = view_7 = None
ERROR:common:Cannot access storage of SparseTensorImpl
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1041, in check_accuracy
new_result = optimized_model_iter_fn(model, example_inputs)
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 945, in run_n_iterations
self.model_iter_fn(mod, inputs, collect_outputs=False)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 332, in forward_and_backward_pass
cloned_inputs = clone_inputs(inputs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 335, in <graph break in forward_and_backward_pass>
pred = mod(*cloned_inputs)
File "/data/users/ezyang/pytorch-tmp/torch/nn/modules/module.py", line 1423, in _call_impl
return forward_call(*input, **kwargs)
File "/data/users/ezyang/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 336, in forward
def forward(self, dense_x, lS_o, lS_i):
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/data/users/ezyang/pytorch-tmp/functorch/_src/aot_autograd.py", line 893, in forward
return compiled_f(
File "/data/users/ezyang/pytorch-tmp/functorch/_src/aot_autograd.py", line 880, in new_func
compiled_fn = create_aot_dispatcher_function(
File "/data/users/ezyang/pytorch-tmp/functorch/_src/aot_autograd.py", line 600, in create_aot_dispatcher_function
aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
File "/data/users/ezyang/pytorch-tmp/functorch/_src/aot_autograd.py", line 434, in aot_dispatch_autograd
fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)
File "/data/users/ezyang/pytorch-tmp/torch/fx/experimental/proxy_tensor.py", line 671, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/data/users/ezyang/pytorch-tmp/torch/fx/experimental/proxy_tensor.py", line 422, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/data/users/ezyang/pytorch-tmp/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/data/users/ezyang/pytorch-tmp/torch/fx/_symbolic_trace.py", line 739, in trace
(self.create_arg(fn(*args)),),
File "/data/users/ezyang/pytorch-tmp/torch/fx/_symbolic_trace.py", line 614, in flatten_fn
tree_out = root_fn(*tree_args)
File "/data/users/ezyang/pytorch-tmp/torch/fx/experimental/proxy_tensor.py", line 439, in wrapped
out = f(*tensors)
File "/data/users/ezyang/pytorch-tmp/functorch/_src/aot_autograd.py", line 180, in joint_forward_backward
backward_out = torch.autograd.grad(
File "/data/users/ezyang/pytorch-tmp/torch/autograd/__init__.py", line 300, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
NotImplementedError: Cannot access storage of SparseTensorImpl
TorchDynamo optimized model failed to run because of the following error
cuda train dlrm FAIL
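
The captured graph above ends at a _sparse_coo_tensor_with_dims_and_tensors node: dlrm's embedding bags are created with sparse=True (the sixth argument to aten._embedding_bag in the graph), so their weight gradients are sparse COO tensors, which the proxy-tensor tracer rejects with "Cannot access storage of SparseTensorImpl". A minimal sketch of the same gradient layout outside the harness (illustrative, not the dlrm model code):

import torch
import torch.nn as nn

emb = nn.EmbeddingBag(1000, 64, mode="sum", sparse=True)
indices = torch.tensor([1, 2, 3, 4])
offsets = torch.tensor([0, 2])
emb(indices, offsets).sum().backward()
print(emb.weight.grad.layout)   # torch.sparse_coo
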
/data/users/ezyang/pytorch-tmp/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if not hasattr(tensorboard, "__version__") or LooseVersion(
/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/gym/core.py:317: DeprecationWarning: WARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.
deprecation(
Running drq...
cuda train drq PASS
Running fastNLP_Bert...
cuda train fastNLP_Bert PASS
Running functorch_dp_cifar10...
cuda train functorch_dp_cifar10 PASS
Running functorch_maml_omniglot...
cuda train functorch_maml_omniglot PASS
Running hf_Albert...
cuda train hf_Albert PASS
Running hf_Bart...
cuda train hf_Bart PASS
Running hf_Bert...
cuda train hf_Bert PASS
Running hf_BigBird...
[2022-10-30 06:36:13,648] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:16,197] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:18,691] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:20,979] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:23,510] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:25,824] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:28,390] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:30,734] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:33,075] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:35,647] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:37,965] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:36:40,288] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-10-30 06:37:00,601] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.00419, (ref-fp64): 0.00000 and shape=torch.Size([2, 768])
[2022-10-30 06:37:00,601] torch._dynamo.utils: [ERROR] Accuracy failed for key name bert.embeddings.token_type_embeddings.weight.grad
cuda train hf_BigBird FAIL
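
The accuracy failure above appears to report an RMSE against an fp64 reference for both the compiled result (res-fp64) and the eager result (ref-fp64), here for the bert.embeddings.token_type_embeddings.weight.grad tensor of shape [2, 768]. A hypothetical helper with the same shape of check (not the torch._dynamo.utils code):

import torch

def rmse(res: torch.Tensor, ref: torch.Tensor) -> float:
    # Root-mean-square error of a result tensor against its fp64 reference.
    return torch.sqrt(torch.mean((res.double() - ref.double()) ** 2)).item()
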
Running hf_DistilBert...
cuda train hf_DistilBert PASS
Running hf_GPT2...
cuda train hf_GPT2 PASS
Running hf_GPT2_large...
Traceback (most recent call last):
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 349, in <module>
main(TorchBenchmarkRunner(), original_dir)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1775, in main
runner.run_one_model(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 768, in inner
return fn(self, *args, **kwargs)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1213, in run_one_model
status = self.check_accuracy(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 1023, in check_accuracy
correct_rerun_result = self.run_n_iterations(
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/common.py", line 946, in run_n_iterations
return self.model_iter_fn(mod, inputs, collect_outputs=True)
File "/data/users/ezyang/pytorch-tmp/benchmarks/dynamo/torchbench.py", line 335, in forward_and_backward_pass
pred = mod(*cloned_inputs)
File "/data/users/ezyang/pytorch-tmp/torch/nn/modules/module.py", line 1423, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1048, in forward
transformer_outputs = self.transformer(
File "/data/users/ezyang/pytorch-tmp/torch/nn/modules/module.py", line 1423, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 891, in forward
outputs = block(
File "/data/users/ezyang/pytorch-tmp/torch/nn/modules/module.py", line 1423, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 391, in forward
attn_outputs = self.attn(
File "/data/users/ezyang/pytorch-tmp/torch/nn/modules/module.py", line 1423, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 332, in forward
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
File "/home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 192, in _attn
attn_weights = attn_weights / (value.size(-1) ** 0.5)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 40.00 MiB (GPU 0; 39.59 GiB total capacity; 36.48 GiB already allocated; 14.44 MiB free; 37.57 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
cuda train hf_GPT2_large FAIL
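
The OOM message above suggests setting max_split_size_mb via PYTORCH_CUDA_ALLOC_CONF when reserved memory far exceeds allocated memory. An illustrative setting (the 128 MiB value is an assumption, not taken from the log), which must be applied before the process makes its first CUDA allocation:

import os

# e.g. PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python benchmarks/dynamo/torchbench.py ...
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
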
Running hf_Longformer...
[2022-10-30 06:39:01,644] torch._dynamo.variables.builtin: [WARNING] incorrect arg count <bound method BuiltinVariable._call_min_max of BuiltinVariable(max)> missing a required argument: 'b'
[2022-10-30 06:39:02,132] torch._dynamo.variables.builtin: [WARNING] incorrect arg count <bound method BuiltinVariable._call_min_max of BuiltinVariable(max)> missing a required argument: 'b'
incomplete graph:
class joint_forward_backward(torch.nn.Module):
def forward(self, orig_primals, orig_tangents):
orig_primals_1: f32[768, 768], [768, 1], orig_primals_2: f32[768], [1], orig_primals_3: f32[768, 768], [768, 1], orig_primals_4: f32[768], [1], orig_primals_5: f32[768, 768], [768, 1], orig_primals_6: f32[768], [1], orig_primals_7: f32[768, 768], [768, 1], orig_primals_8: f32[768], [1], orig_primals_9: f32[768], [1], orig_primals_10: f32[768], [1], orig_primals_11: f32[3072, 768], [768, 1], orig_primals_12: f32[3072], [1], orig_primals_13: f32[768, 3072], [3072, 1], orig_primals_14: f32[768], [1], orig_primals_15: f32[768], [1], orig_primals_16: f32[768], [1], orig_primals_17: f32[768, 768], [768, 1], orig_primals_18: f32[768], [1], orig_primals_19: f32[768, 768], [768, 1], orig_primals_20: f32[768], [1], orig_primals_21: f32[768, 768], [768, 1], orig_primals_22: f32[768], [1], orig_primals_23: f32[768, 768], [768, 1], orig_primals_24: f32[768], [1], orig_primals_25: f32[768], [1], orig_primals_26: f32[768], [1], orig_primals_27: f32[3072, 768], [768, 1], orig_primals_28: f32[3072], [1], orig_primals_29: f32[768, 3072], [3072, 1], orig_primals_30: f32[768], [1], orig_primals_31: f32[768], [1], orig_primals_32: f32[768], [1], orig_primals_33: f32[768, 768], [768, 1], orig_primals_34: f32[768], [1], orig_primals_35: f32[768, 768], [768, 1], orig_primals_36: f32[768], [1], orig_primals_37: f32[768, 768], [768, 1], orig_primals_38: f32[768], [1], orig_primals_39: f32[768, 768], [768, 1], orig_primals_40: f32[768], [1], orig_primals_41: f32[768], [1], orig_primals_42: f32[768], [1], orig_primals_43: f32[3072, 768], [768, 1], orig_primals_44: f32[3072], [1], orig_primals_45: f32[768, 3072], [3072, 1], orig_primals_46: f32[768], [1], orig_primals_47: f32[768], [1], orig_primals_48: f32[768], [1], orig_primals_49: f32[768, 768], [768, 1], orig_primals_50: f32[768], [1], orig_primals_51: f32[768, 768], [768, 1], orig_primals_52: f32[768], [1], orig_primals_53: f32[768, 768], [768, 1], orig_primals_54: f32[768], [1], orig_primals_55: f32[768, 768], [768, 1], orig_primals_56: f32[768], [1], orig_primals_57: f32[768], [1], orig_primals_58: f32[768], [1], orig_primals_59: f32[3072, 768], [768, 1], orig_primals_60: f32[3072], [1], orig_primals_61: f32[768, 3072], [3072, 1], orig_primals_62: f32[768], [1], orig_primals_63: f32[768], [1], orig_primals_64: f32[768], [1], orig_primals_65: f32[768, 768], [768, 1], orig_primals_66: f32[768], [1], orig_primals_67: f32[768, 768], [768, 1], orig_primals_68: f32[768], [1], orig_primals_69: f32[768, 768], [768, 1], orig_primals_70: f32[768], [1], orig_primals_71: f32[768, 768], [768, 1], orig_primals_72: f32[768], [1], orig_primals_73: f32[768], [1], orig_primals_74: f32[768], [1], orig_primals_75: f32[3072, 768], [768, 1], orig_primals_76: f32[3072], [1], orig_primals_77: f32[768, 3072], [3072, 1], orig_primals_78: f32[768], [1], orig_primals_79: f32[768], [1], orig_primals_80: f32[768], [1], orig_primals_81: f32[768, 768], [768, 1], orig_primals_82: f32[768], [1], orig_primals_83: f32[768, 768], [768, 1], orig_primals_84: f32[768], [1], orig_primals_85: f32[768, 768], [768, 1], orig_primals_86: f32[768], [1], orig_primals_87: f32[768, 768], [768, 1], orig_primals_88: f32[768], [1], orig_primals_89: f32[768], [1], orig_primals_90: f32[768], [1], orig_primals_91: f32[3072, 768], [768, 1], orig_primals_92: f32[3072], [1], orig_primals_93: f32[768, 3072], [3072, 1], orig_primals_94: f32[768], [1], orig_primals_95: f32[768], [1], orig_primals_96: f32[768], [1], orig_primals_97: f32[768, 768], [768, 1], orig_primals_98: f32[768], [1], orig_primals_99: f32[768, 768], 
[768, 1], orig_primals_100: f32[768], [1], orig_primals_101: f32[768, 768], [768, 1], orig_primals_102: f32[768], [1], orig_primals_103: f32[768, 768], [768, 1], orig_primals_104: f32[768], [1], orig_primals_105: f32[768], [1], orig_primals_106: f32[768], [1], orig_primals_107: f32[3072, 768], [768, 1], orig_primals_108: f32[3072], [1], orig_primals_109: f32[768, 3072], [3072, 1], orig_primals_110: f32[768], [1], orig_primals_111: f32[768], [1], orig_primals_112: f32[768], [1], orig_primals_113: f32[768, 768], [768, 1], orig_primals_114: f32[768], [1], orig_primals_115: f32[768, 768], [768, 1], orig_primals_116: f32[768], [1], orig_primals_117: f32[768, 768], [768, 1], orig_primals_118: f32[768], [1], orig_primals_119: f32[768, 768], [768, 1], orig_primals_120: f32[768], [1], orig_primals_121: f32[768], [1], orig_primals_122: f32[768], [1], orig_primals_123: f32[3072, 768], [768, 1], orig_primals_124: f32[3072], [1], orig_primals_125: f32[768, 3072], [3072, 1], orig_primals_126: f32[768], [1], orig_primals_127: f32[768], [1], orig_primals_128: f32[768], [1], orig_primals_129: f32[768, 768], [768, 1], orig_primals_130: f32[768], [1], orig_primals_131: f32[768, 768], [768, 1], orig_primals_132: f32[768], [1], orig_primals_133: f32[768, 768], [768, 1], orig_primals_134: f32[768], [1], orig_primals_135: f32[768, 768], [768, 1], orig_primals_136: f32[768], [1], orig_primals_137: f32[768], [1], orig_primals_138: f32[768], [1], orig_primals_139: f32[3072, 768], [768, 1], orig_primals_140: f32[3072], [1], orig_primals_141: f32[768, 3072], [3072, 1], orig_primals_142: f32[768], [1], orig_primals_143: f32[768], [1], orig_primals_144: f32[768], [1], orig_primals_145: f32[768, 768], [768, 1], orig_primals_146: f32[768], [1], orig_primals_147: f32[768, 768], [768, 1], orig_primals_148: f32[768], [1], orig_primals_149: f32[768, 768], [768, 1], orig_primals_150: f32[768], [1], orig_primals_151: f32[768, 768], [768, 1], orig_primals_152: f32[768], [1], orig_primals_153: f32[768], [1], orig_primals_154: f32[768], [1], orig_primals_155: f32[3072, 768], [768, 1], orig_primals_156: f32[3072], [1], orig_primals_157: f32[768, 3072], [3072, 1], orig_primals_158: f32[768], [1], orig_primals_159: f32[768], [1], orig_primals_160: f32[768], [1], orig_primals_161: f32[768, 768], [768, 1], orig_primals_162: f32[768], [1], orig_primals_163: f32[768, 768], [768, 1], orig_primals_164: f32[768], [1], orig_primals_165: f32[768, 768], [768, 1], orig_primals_166: f32[768], [1], orig_primals_167: f32[768, 768], [768, 1], orig_primals_168: f32[768], [1], orig_primals_169: f32[768], [1], orig_primals_170: f32[768], [1], orig_primals_171: f32[3072, 768], [768, 1], orig_primals_172: f32[3072], [1], orig_primals_173: f32[768, 3072], [3072, 1], orig_primals_174: f32[768], [1], orig_primals_175: f32[768], [1], orig_primals_176: f32[768], [1], orig_primals_177: f32[768, 768], [768, 1], orig_primals_178: f32[768], [1], orig_primals_179: f32[768, 768], [768, 1], orig_primals_180: f32[768], [1], orig_primals_181: f32[768, 768], [768, 1], orig_primals_182: f32[768], [1], orig_primals_183: f32[768, 768], [768, 1], orig_primals_184: f32[768], [1], orig_primals_185: f32[768], [1], orig_primals_186: f32[768], [1], orig_primals_187: f32[3072, 768], [768, 1], orig_primals_188: f32[3072], [1], orig_primals_189: f32[768, 3072], [3072, 1], orig_primals_190: f32[768], [1], orig_primals_191: f32[768], [1], orig_primals_192: f32[768], [1], orig_primals_193: f32[2, 1024, 768], [786432, 768, 1], orig_primals_194: f32[2, 1024], [1024, 1], 
orig_primals_195: b8[2, 1024], [1024, 1], orig_tangents_1: f32[2, 1024, 768], [786432, 768, 1], = fx_pytree.tree_flatten_spec([orig_primals, orig_tangents], self._in_spec)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(orig_primals_193, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_1); orig_primals_1 = None
clone: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format)
sym_size: Sym(1024) = torch.ops.aten.sym_size(orig_primals_193, 1)
sym_size_1: Sym(2) = torch.ops.aten.sym_size(orig_primals_193, 0)
# No stacktrace found for following nodes
mul: Sym(2048) = sym_size * sym_size_1
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
sym_size_2: Sym(768) = torch.ops.aten.sym_size(orig_primals_193, 2)
view: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone, [mul, sym_size_2]); clone = mul = None
mm: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view, t); view = t = None
view_1: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 768]); mm = None
add: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_1, orig_primals_2); view_1 = orig_primals_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_1: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_3); orig_primals_3 = None
clone_1: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_1: Sym(2048) = sym_size * sym_size_1
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_2: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_2]); clone_1 = mul_1 = None
mm_1: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_2, t_1); view_2 = t_1 = None
view_3: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_1, [sym_size, sym_size_1, 768]); mm_1 = None
add_1: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_3, orig_primals_4); view_3 = orig_primals_4 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_2: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_5); orig_primals_5 = None
clone_2: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None
# No stacktrace found for following nodes
mul_2: Sym(2048) = sym_size * sym_size_1
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_4: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_2, [mul_2, sym_size_2]); clone_2 = mul_2 = sym_size_2 = None
mm_2: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_4, t_2); view_4 = t_2 = None
view_5: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_2, [sym_size, sym_size_1, 768]); mm_2 = sym_size = sym_size_1 = None
add_2: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_5, orig_primals_6); view_5 = orig_primals_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add, 8.0); add = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_6: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div, [1024, 2, 12, 64])
transpose_1: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_6, 0, 1); view_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_7: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_1, [1024, 2, 12, 64]); add_1 = None
transpose_2: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_7, 0, 1); view_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_3: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_1, 1, 2); transpose_1 = None
view_8: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_3, [24, 1024, 64]); transpose_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_4: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_2, 1, 2); transpose_2 = None
view_9: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_4, [24, 1024, 64]); transpose_4 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_10: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_8, [24, 2, 512, 64]); view_8 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_10, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_11: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_9, [24, 2, 512, 64]); view_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_1: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_11, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided, 4); as_strided = None
permute: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3]); unsqueeze = None
unsqueeze_1: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_1, 4); as_strided_1 = None
permute_1: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_1, [0, 1, 4, 2, 3]); unsqueeze_1 = None
permute_2: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute, [0, 1, 2, 4, 3]); permute = None
view_12: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div, [1024, 2, 12, 64]); div = None
transpose_5: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_12, 0, 1); view_12 = None
transpose_6: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_5, 1, 2); transpose_5 = None
view_13: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_6, [24, 1024, 64]); transpose_6 = None
view_14: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_13, [24, 2, 512, 64]); view_13 = None
as_strided_2: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_14, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_14 = None
unsqueeze_2: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_2, 4); as_strided_2 = None
permute_3: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_2, [0, 1, 2, 4, 3]); unsqueeze_2 = None
permute_4: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_3, [0, 1, 2, 4, 3]); permute_3 = None
clone_3: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_4, memory_format = torch.contiguous_format); permute_4 = None
_unsafe_view: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_3, [72, 512, 64]); clone_3 = None
permute_5: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_1, [0, 1, 4, 3, 2]); permute_1 = None
clone_4: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format); permute_5 = None
_unsafe_view_1: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_4, [72, 64, 512]); clone_4 = None
bmm: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view, _unsafe_view_1); _unsafe_view = _unsafe_view_1 = None
view_15: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm, [24, 3, 512, 1, 512]); bmm = None
permute_6: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_15, [0, 1, 2, 4, 3]); view_15 = None
view_16: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_6, [24, 3, 512, 512]); permute_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_16, [0, 0, 0, 1], 0.0); view_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_17: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd, [24, 3, 512, 513]); constant_pad_nd = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_17, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_17, 0, 0, 9223372036854775807)
slice_2: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807); slice_1 = None
slice_3: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 256); slice_2 = None
slice_4: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_3, 3, 0, 257); slice_3 = None
slice_5: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
slice_6: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_5, 1, 0, -1); slice_5 = None
slice_7: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_6, 2, 0, 9223372036854775807); slice_6 = None
slice_8: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_7, 3, 256, 9223372036854775807); slice_7 = None
slice_9: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
slice_10: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_9, 1, 0, -1); slice_9 = None
slice_11: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_10, 2, 0, 9223372036854775807); slice_10 = None
slice_12: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_11, 3, 256, 9223372036854775807); slice_11 = None
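
The slice/slice_scatter ladder above is how the in-place sliced assignment at modeling_longformer.py:845 (diagonal_attention_scores[:, :-1, :, window_overlap:] = ...) looks after functionalization: instead of mutating the tensor, each indexing level is rebuilt with aten.slice_scatter. A rough sketch of the same rewrite on toy shapes (the sizes and the two-level nesting below are assumptions for illustration, not copied from the log):

    import torch

    dst = torch.zeros(2, 4, 3, 6)
    src = torch.ones(2, 3, 3, 2)

    # eager, mutating form
    dst_eager = dst.clone()
    dst_eager[:, :-1, :, 4:] = src

    # functional form, roughly what the trace records
    inner = torch.slice_scatter(dst[:, :-1], src, dim=3, start=4)
    dst_func = torch.slice_scatter(dst, inner, dim=1, start=0, end=-1)
    torch.testing.assert_close(dst_eager, dst_func)
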
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_13: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_17, 0, 0, 9223372036854775807)
select: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_13, 1, -1); slice_13 = None
slice_14: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select, 1, 256, 9223372036854775807); select = None
slice_15: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 257); slice_14 = None
slice_16: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
select_1: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_16, 1, -1); slice_16 = None
slice_17: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_1, 1, 0, 9223372036854775807); select_1 = None
slice_18: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_17, 2, 256, 9223372036854775807); slice_17 = None
slice_19: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
slice_20: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_19, 1, 0, -1)
slice_21: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_20, 2, 0, 9223372036854775807)
slice_scatter: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_21, slice_4, 3, 256, 9223372036854775807); slice_21 = slice_4 = None
slice_scatter_1: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_20, slice_scatter, 2, 0, 9223372036854775807); slice_20 = slice_scatter = None
slice_scatter_2: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_19, slice_scatter_1, 1, 0, -1); slice_19 = slice_scatter_1 = None
slice_scatter_3: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty, slice_scatter_2, 0, 0, 9223372036854775807); slice_scatter_2 = None
slice_22: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 0, 9223372036854775807)
select_2: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_22, 1, -1); slice_22 = None
slice_23: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_2, 1, 0, 9223372036854775807); select_2 = None
slice_24: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_23, 2, 256, 9223372036854775807); slice_23 = None
slice_25: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
select_3: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_25, 1, -1); slice_25 = None
slice_26: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_3, 1, 0, 9223372036854775807); select_3 = None
slice_27: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_26, 2, 256, 9223372036854775807); slice_26 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_28: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_17, 0, 0, 9223372036854775807)
slice_29: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_28, 1, 0, 9223372036854775807); slice_28 = None
slice_30: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_29, 2, -257, -1); slice_29 = None
slice_31: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_30, 3, 257, 9223372036854775807); slice_30 = None
slice_32: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
slice_33: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_32, 1, 1, 9223372036854775807); slice_32 = None
slice_34: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
slice_35: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_34, 3, 0, 256); slice_34 = None
slice_36: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 0, 9223372036854775807)
select_4: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_36, 1, -1)
slice_37: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_4, 1, 0, 9223372036854775807)
slice_scatter_4: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_37, slice_15, 2, 256, 9223372036854775807); slice_37 = slice_15 = None
slice_scatter_5: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_4, slice_scatter_4, 1, 0, 9223372036854775807); select_4 = slice_scatter_4 = None
select_scatter: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_36, slice_scatter_5, 1, -1); slice_36 = slice_scatter_5 = None
slice_scatter_6: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_3, select_scatter, 0, 0, 9223372036854775807); slice_scatter_3 = select_scatter = None
slice_38: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_6, 0, 0, 9223372036854775807)
slice_39: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_38, 1, 1, 9223372036854775807); slice_38 = None
slice_40: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_39, 2, 0, 9223372036854775807); slice_39 = None
slice_41: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_40, 3, 0, 256); slice_40 = None
slice_42: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
slice_43: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_42, 1, 1, 9223372036854775807); slice_42 = None
slice_44: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_43, 2, 0, 9223372036854775807); slice_43 = None
slice_45: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_44, 3, 0, 256); slice_44 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_46: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_17, 0, 0, 9223372036854775807); view_17 = None
select_5: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_46, 1, 0); slice_46 = None
slice_47: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_5, 1, 0, 255); select_5 = None
slice_48: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_47, 2, -255, 9223372036854775807); slice_47 = None
slice_49: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
select_6: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_49, 1, 0); slice_49 = None
slice_50: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_6, 1, 1, 256); select_6 = None
slice_51: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_50, 2, 1, 256); slice_50 = None
slice_52: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_6, 0, 0, 9223372036854775807)
slice_53: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_52, 1, 1, 9223372036854775807)
slice_54: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_53, 2, 0, 9223372036854775807)
slice_scatter_7: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_54, slice_31, 3, 0, 256); slice_54 = slice_31 = None
slice_scatter_8: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_53, slice_scatter_7, 2, 0, 9223372036854775807); slice_53 = slice_scatter_7 = None
slice_scatter_9: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_52, slice_scatter_8, 1, 1, 9223372036854775807); slice_52 = slice_scatter_8 = None
slice_scatter_10: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_6, slice_scatter_9, 0, 0, 9223372036854775807); slice_scatter_6 = slice_scatter_9 = None
slice_55: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_10, 0, 0, 9223372036854775807)
select_7: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_55, 1, 0); slice_55 = None
slice_56: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_7, 1, 1, 256); select_7 = None
slice_57: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_56, 2, 1, 256); slice_56 = None
slice_58: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty, 0, 0, 9223372036854775807)
select_8: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_58, 1, 0); slice_58 = None
slice_59: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_8, 1, 1, 256); select_8 = None
slice_60: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_59, 2, 1, 256); slice_59 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_18: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513])
transpose_7: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_18, 2, 1); view_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_61: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_10, 0, 0, 9223372036854775807)
select_9: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_61, 1, 0)
slice_62: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_9, 1, 1, 256)
slice_scatter_11: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_62, slice_48, 2, 1, 256); slice_62 = slice_48 = None
slice_scatter_12: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_9, slice_scatter_11, 1, 1, 256); select_9 = slice_scatter_11 = None
select_scatter_1: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_61, slice_scatter_12, 1, 0); slice_61 = slice_scatter_12 = None
slice_scatter_13: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_10, select_scatter_1, 0, 0, 9223372036854775807); slice_scatter_10 = select_scatter_1 = None
view_19: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_13, [2, 12, 1024, 513])
transpose_8: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_19, 2, 1); view_19 = None
new_ones: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_8, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones); new_ones = None
flip: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril, [0]); tril = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_3: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip, 0); flip = None
slice_63: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_3, 1, 0, 9223372036854775807); unsqueeze_3 = None
unsqueeze_4: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_63, 2); slice_63 = None
slice_64: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_1: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_64, [1, 3])
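
new_ones/tril/flip above build the masks from modeling_longformer.py:792-794: a lower-triangular (affected_seq_len, affected_seq_len + 1) matrix flipped along dim 0 marks positions near the start of the sequence that would otherwise attend past the boundary, and flipping it again along dims (1, 3) gives the mirror-image mask for the end. A short sketch with a small stand-in window (w below is assumed, not the 256 used here):

    import torch

    w = 3  # stand-in for affected_seq_len / window_overlap
    beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
    beginning_mask = beginning_mask_2d[None, :, None, :]   # (1, w, 1, w + 1)
    ending_mask = beginning_mask.flip(dims=(1, 3))
    print(beginning_mask_2d)
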
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_65: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_8, 0, 0, 9223372036854775807)
slice_66: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_65, 1, 0, 256); slice_65 = None
slice_67: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_66, 2, 0, 9223372036854775807); slice_66 = None
slice_68: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_67, 3, 0, 257); slice_67 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_64, [2, 256, 12, 257]); slice_64 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand, 1); expand = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_20: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_13, [2, 12, 1024, 513])
transpose_9: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_20, 2, 1); view_20 = None
slice_69: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_9, 0, 0, 9223372036854775807); transpose_9 = None
slice_70: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_69, 1, 0, 256); slice_69 = None
slice_71: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_70, 2, 0, 9223372036854775807); slice_70 = None
slice_72: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_71, 3, 0, 257); slice_71 = None
masked_fill: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_72, eq, -inf); slice_72 = eq = None
view_21: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513])
transpose_10: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_21, 2, 1); view_21 = None
slice_73: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_10, 0, 0, 9223372036854775807); transpose_10 = None
slice_74: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_73, 1, 0, 256); slice_73 = None
slice_75: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_74, 2, 0, 9223372036854775807); slice_74 = None
slice_76: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_75, 3, 0, 257); slice_75 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_22: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513])
transpose_11: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_22, 2, 1); view_22 = None
slice_77: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_8, 0, 0, 9223372036854775807); transpose_8 = None
slice_78: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_77, 1, -256, 9223372036854775807); slice_77 = None
slice_79: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_78, 2, 0, 9223372036854775807); slice_78 = None
slice_80: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_79, 3, -257, 9223372036854775807); slice_79 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_1: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_1, [2, 256, 12, 257]); flip_1 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_1: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_1, 1); expand_1 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_23: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_13, [2, 12, 1024, 513]); slice_scatter_13 = None
transpose_12: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_23, 2, 1); view_23 = None
slice_81: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_12, 0, 0, 9223372036854775807)
slice_82: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_81, 1, 0, 256)
slice_83: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_82, 2, 0, 9223372036854775807)
slice_scatter_14: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_83, masked_fill, 3, 0, 257); slice_83 = masked_fill = None
slice_scatter_15: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_82, slice_scatter_14, 2, 0, 9223372036854775807); slice_82 = slice_scatter_14 = None
slice_scatter_16: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_81, slice_scatter_15, 1, 0, 256); slice_81 = slice_scatter_15 = None
slice_scatter_17: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_12, slice_scatter_16, 0, 0, 9223372036854775807); transpose_12 = slice_scatter_16 = None
transpose_13: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_17, 2, 1); slice_scatter_17 = None
view_24: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_13, [24, 4, 256, 513]); transpose_13 = None
view_25: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_24, [2, 12, 1024, 513])
transpose_14: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_25, 2, 1); view_25 = None
slice_84: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_14, 0, 0, 9223372036854775807); transpose_14 = None
slice_85: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_84, 1, -256, 9223372036854775807); slice_84 = None
slice_86: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_85, 2, 0, 9223372036854775807); slice_85 = None
slice_87: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_86, 3, -257, 9223372036854775807); slice_86 = None
masked_fill_1: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_87, eq_1, -inf); slice_87 = eq_1 = None
view_26: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513])
transpose_15: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_26, 2, 1); view_26 = None
slice_88: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_15, 0, 0, 9223372036854775807); transpose_15 = None
slice_89: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_88, 1, -256, 9223372036854775807); slice_88 = None
slice_90: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_89, 2, 0, 9223372036854775807); slice_89 = None
slice_91: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_90, 3, -257, 9223372036854775807); slice_90 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_92: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne, 0, 0, 9223372036854775807); ne = None
slice_93: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_92, 1, 0, 9223372036854775807); slice_92 = None
unsqueeze_5: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_93, 2); slice_93 = None
unsqueeze_6: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_5, 3); unsqueeze_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_6, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_2: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy, unsqueeze_6, -10000.0); _to_copy = unsqueeze_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_1: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_2, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_16: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_1, 1, 2); new_ones_1 = None
view_27: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_16, [2, 1024, 1]); transpose_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_17: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_2, 1, 2); masked_fill_2 = None
view_28: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_17, [2, 1024, 1]); transpose_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_29: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_27, [2, 2, 512, 1]); view_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_3: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_29, [2, 3, 512, 1], [1024, 256, 1, 1]); view_29 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_30: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_28, [2, 2, 512, 1]); view_28 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_4: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_30, [2, 3, 512, 1], [1024, 256, 1, 1]); view_30 = None
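
view_29/view_30 plus the two as_strided calls above are the overlapping-chunk trick from modeling_longformer.py:775-788: the (batch, seq, dim) buffer is reinterpreted so consecutive chunks of length 2*w overlap by w, with no copy. A hedged sketch on toy sizes (w below is assumed; the log uses w = 256):

    import torch

    batch, seq, dim, w = 1, 8, 1, 2
    x = torch.arange(batch * seq * dim, dtype=torch.float32).view(batch, seq, dim)
    n_chunks = seq // w - 1
    chunks = x.as_strided(
        size=(batch, n_chunks, 2 * w, dim),
        stride=(seq * dim, w * dim, dim, 1),
    )
    print(chunks[0, :, :, 0])  # consecutive rows overlap by w elements
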
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_7: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_3, 4); as_strided_3 = None
permute_7: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_7, [0, 1, 2, 4, 3]); unsqueeze_7 = None
unsqueeze_8: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_4, 4); as_strided_4 = None
permute_8: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_8, [0, 1, 4, 2, 3]); unsqueeze_8 = None
mul_3: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_7, permute_8); permute_7 = permute_8 = None
view_31: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_3, [2, 3, 512, 512]); mul_3 = None
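
This is the same "bcxd,bcyd->bcxy" einsum (modeling_longformer.py:827), but applied to the padding mask where the channel dimension d is 1, so instead of aten.bmm the trace records a broadcasted aten.mul between a (b, c, x, 1, 1) and a (b, c, 1, y, 1) operand. A toy check of that equivalence (shapes below are illustrative only):

    import torch

    b, c, x, y = 2, 3, 4, 4
    q = torch.randn(b, c, x, 1)
    k = torch.randn(b, c, y, 1)
    via_einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
    via_mul = (q.unsqueeze(4).permute(0, 1, 2, 4, 3)
               * k.unsqueeze(4).permute(0, 1, 4, 2, 3)).view(b, c, x, y)
    torch.testing.assert_close(via_einsum, via_mul)
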
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_1: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_31, [0, 0, 0, 1], 0.0); view_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_32: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_1, [2, 3, 512, 513]); constant_pad_nd_1 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_1: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_32, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_94: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_32, 0, 0, 9223372036854775807)
slice_95: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_94, 1, 0, 9223372036854775807); slice_94 = None
slice_96: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_95, 2, 0, 256); slice_95 = None
slice_97: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_96, 3, 0, 257); slice_96 = None
slice_98: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_1, 0, 0, 9223372036854775807)
slice_99: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_98, 1, 0, -1); slice_98 = None
slice_100: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_99, 2, 0, 9223372036854775807); slice_99 = None
slice_101: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_100, 3, 256, 9223372036854775807); slice_100 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_102: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_32, 0, 0, 9223372036854775807)
select_10: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_102, 1, -1); slice_102 = None
slice_103: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_10, 1, 256, 9223372036854775807); select_10 = None
slice_104: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_103, 2, 0, 257); slice_103 = None
slice_105: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_1, 0, 0, 9223372036854775807)
select_11: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_105, 1, -1); slice_105 = None
slice_106: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_11, 1, 0, 9223372036854775807); select_11 = None
slice_107: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_106, 2, 256, 9223372036854775807); slice_106 = None
slice_108: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_1, 0, 0, 9223372036854775807)
slice_109: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_108, 1, 0, -1)
slice_110: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_109, 2, 0, 9223372036854775807)
slice_scatter_18: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_110, slice_97, 3, 256, 9223372036854775807); slice_110 = slice_97 = None
slice_scatter_19: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_109, slice_scatter_18, 2, 0, 9223372036854775807); slice_109 = slice_scatter_18 = None
slice_scatter_20: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_108, slice_scatter_19, 1, 0, -1); slice_108 = slice_scatter_19 = None
slice_scatter_21: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_1, slice_scatter_20, 0, 0, 9223372036854775807); slice_scatter_20 = None
slice_111: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_21, 0, 0, 9223372036854775807)
select_12: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_111, 1, -1); slice_111 = None
slice_112: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_12, 1, 0, 9223372036854775807); select_12 = None
slice_113: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_112, 2, 256, 9223372036854775807); slice_112 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_114: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_32, 0, 0, 9223372036854775807)
slice_115: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_114, 1, 0, 9223372036854775807); slice_114 = None
slice_116: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_115, 2, -257, -1); slice_115 = None
slice_117: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_116, 3, 257, 9223372036854775807); slice_116 = None
slice_118: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_1, 0, 0, 9223372036854775807)
slice_119: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_118, 1, 1, 9223372036854775807); slice_118 = None
slice_120: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_119, 2, 0, 9223372036854775807); slice_119 = None
slice_121: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_120, 3, 0, 256); slice_120 = None
slice_122: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_21, 0, 0, 9223372036854775807)
select_13: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_122, 1, -1)
slice_123: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_13, 1, 0, 9223372036854775807)
slice_scatter_22: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_123, slice_104, 2, 256, 9223372036854775807); slice_123 = slice_104 = None
slice_scatter_23: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_13, slice_scatter_22, 1, 0, 9223372036854775807); select_13 = slice_scatter_22 = None
select_scatter_2: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_122, slice_scatter_23, 1, -1); slice_122 = slice_scatter_23 = None
slice_scatter_24: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_21, select_scatter_2, 0, 0, 9223372036854775807); slice_scatter_21 = select_scatter_2 = None
slice_124: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_24, 0, 0, 9223372036854775807)
slice_125: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_124, 1, 1, 9223372036854775807); slice_124 = None
slice_126: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_125, 2, 0, 9223372036854775807); slice_125 = None
slice_127: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_126, 3, 0, 256); slice_126 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_128: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_32, 0, 0, 9223372036854775807); view_32 = None
select_14: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_128, 1, 0); slice_128 = None
slice_129: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_14, 1, 0, 255); select_14 = None
slice_130: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_129, 2, -255, 9223372036854775807); slice_129 = None
slice_131: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_1, 0, 0, 9223372036854775807)
select_15: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_131, 1, 0); slice_131 = None
slice_132: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_15, 1, 1, 256); select_15 = None
slice_133: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_132, 2, 1, 256); slice_132 = None
slice_134: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_24, 0, 0, 9223372036854775807)
slice_135: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_134, 1, 1, 9223372036854775807)
slice_136: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_135, 2, 0, 9223372036854775807)
slice_scatter_25: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_136, slice_117, 3, 0, 256); slice_136 = slice_117 = None
slice_scatter_26: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_135, slice_scatter_25, 2, 0, 9223372036854775807); slice_135 = slice_scatter_25 = None
slice_scatter_27: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_134, slice_scatter_26, 1, 1, 9223372036854775807); slice_134 = slice_scatter_26 = None
slice_scatter_28: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_24, slice_scatter_27, 0, 0, 9223372036854775807); slice_scatter_24 = slice_scatter_27 = None
slice_137: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_28, 0, 0, 9223372036854775807)
select_16: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_137, 1, 0); slice_137 = None
slice_138: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_16, 1, 1, 256); select_16 = None
slice_139: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_138, 2, 1, 256); slice_138 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_33: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_1, [2, 1, 1024, 513]); new_empty_1 = None
transpose_18: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_33, 2, 1); view_33 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_140: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_28, 0, 0, 9223372036854775807)
select_17: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_140, 1, 0)
slice_141: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_17, 1, 1, 256)
slice_scatter_29: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_141, slice_130, 2, 1, 256); slice_141 = slice_130 = None
slice_scatter_30: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_17, slice_scatter_29, 1, 1, 256); select_17 = slice_scatter_29 = None
select_scatter_3: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_140, slice_scatter_30, 1, 0); slice_140 = slice_scatter_30 = None
slice_scatter_31: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_28, select_scatter_3, 0, 0, 9223372036854775807); slice_scatter_28 = select_scatter_3 = None
view_34: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_31, [2, 1, 1024, 513])
transpose_19: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_34, 2, 1); view_34 = None
new_ones_2: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_19, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_1: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_2); new_ones_2 = None
flip_2: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_1, [0]); tril_1 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_9: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_2, 0); flip_2 = None
slice_142: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_9, 1, 0, 9223372036854775807); unsqueeze_9 = None
unsqueeze_10: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_142, 2); slice_142 = None
slice_143: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_3: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_143, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_144: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_19, 0, 0, 9223372036854775807)
slice_145: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_144, 1, 0, 256); slice_144 = None
slice_146: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_145, 2, 0, 9223372036854775807); slice_145 = None
slice_147: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_146, 3, 0, 257); slice_146 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_2: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_143, [2, 256, 1, 257]); slice_143 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_2: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_2, 1); expand_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_35: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_31, [2, 1, 1024, 513])
transpose_20: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_35, 2, 1); view_35 = None
slice_148: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_20, 0, 0, 9223372036854775807); transpose_20 = None
slice_149: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_148, 1, 0, 256); slice_148 = None
slice_150: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_149, 2, 0, 9223372036854775807); slice_149 = None
slice_151: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_150, 3, 0, 257); slice_150 = None
masked_fill_3: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_151, eq_2, -inf); slice_151 = eq_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_152: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_19, 0, 0, 9223372036854775807); transpose_19 = None
slice_153: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_152, 1, -256, 9223372036854775807); slice_152 = None
slice_154: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_153, 2, 0, 9223372036854775807); slice_153 = None
slice_155: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_154, 3, -257, 9223372036854775807); slice_154 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_3: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_3, [2, 256, 1, 257]); flip_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_3: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_3, 1); expand_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_36: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_31, [2, 1, 1024, 513]); slice_scatter_31 = None
transpose_21: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_36, 2, 1); view_36 = None
slice_156: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_21, 0, 0, 9223372036854775807)
slice_157: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_156, 1, 0, 256)
slice_158: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_157, 2, 0, 9223372036854775807)
slice_scatter_32: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_158, masked_fill_3, 3, 0, 257); slice_158 = masked_fill_3 = None
slice_scatter_33: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_157, slice_scatter_32, 2, 0, 9223372036854775807); slice_157 = slice_scatter_32 = None
slice_scatter_34: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_156, slice_scatter_33, 1, 0, 256); slice_156 = slice_scatter_33 = None
slice_scatter_35: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_21, slice_scatter_34, 0, 0, 9223372036854775807); transpose_21 = slice_scatter_34 = None
transpose_22: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_35, 2, 1); slice_scatter_35 = None
view_37: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_22, [2, 4, 256, 513]); transpose_22 = None
view_38: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_37, [2, 1, 1024, 513])
transpose_23: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_38, 2, 1); view_38 = None
slice_159: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_23, 0, 0, 9223372036854775807); transpose_23 = None
slice_160: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_159, 1, -256, 9223372036854775807); slice_159 = None
slice_161: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_160, 2, 0, 9223372036854775807); slice_160 = None
slice_162: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_161, 3, -257, 9223372036854775807); slice_161 = None
masked_fill_4: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_162, eq_3, -inf); slice_162 = eq_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_39: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513])
transpose_24: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_39, 2, 1); view_39 = None
view_40: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_24, [2, 12, 1024, 513]); view_24 = None
transpose_25: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_40, 2, 1); view_40 = None
slice_163: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_25, 0, 0, 9223372036854775807)
slice_164: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_163, 1, -256, 9223372036854775807)
slice_165: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_164, 2, 0, 9223372036854775807)
slice_scatter_36: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_165, masked_fill_1, 3, -257, 9223372036854775807); slice_165 = masked_fill_1 = None
slice_scatter_37: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_164, slice_scatter_36, 2, 0, 9223372036854775807); slice_164 = slice_scatter_36 = None
slice_scatter_38: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_163, slice_scatter_37, 1, -256, 9223372036854775807); slice_163 = slice_scatter_37 = None
slice_scatter_39: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_25, slice_scatter_38, 0, 0, 9223372036854775807); transpose_25 = slice_scatter_38 = None
transpose_26: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_39, 2, 1); slice_scatter_39 = None
view_41: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_26, [24, 4, 256, 513]); transpose_26 = None
view_42: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_41, [2, 12, 1024, 513]); view_41 = None
transpose_27: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_42, 2, 1); view_42 = None
view_43: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_37, [2, 1, 1024, 513]); view_37 = None
transpose_28: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_43, 2, 1); view_43 = None
slice_166: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_28, 0, 0, 9223372036854775807)
slice_167: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_166, 1, -256, 9223372036854775807)
slice_168: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_167, 2, 0, 9223372036854775807)
slice_scatter_40: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_168, masked_fill_4, 3, -257, 9223372036854775807); slice_168 = masked_fill_4 = None
slice_scatter_41: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_167, slice_scatter_40, 2, 0, 9223372036854775807); slice_167 = slice_scatter_40 = None
slice_scatter_42: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_166, slice_scatter_41, 1, -256, 9223372036854775807); slice_166 = slice_scatter_41 = None
slice_scatter_43: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_28, slice_scatter_42, 0, 0, 9223372036854775807); transpose_28 = slice_scatter_42 = None
transpose_29: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_43, 2, 1); slice_scatter_43 = None
view_44: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_29, [2, 4, 256, 513]); transpose_29 = None
view_45: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_44, [2, 1, 1024, 513]); view_44 = None
transpose_30: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_45, 2, 1); view_45 = None
add_3: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_27, transpose_30); transpose_27 = transpose_30 = None
view_46: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty, [2, 12, 1024, 513]); new_empty = None
transpose_31: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_46, 2, 1); view_46 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_3, -1, False); add_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_169: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_170: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_169, 1, 0, 9223372036854775807); slice_169 = None
unsqueeze_11: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_170, 2); slice_170 = None
unsqueeze_12: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_11, 3); unsqueeze_11 = None
masked_fill_5: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax, unsqueeze_12, 0.0); _softmax = unsqueeze_12 = None
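
The masked_fill above implements modeling_longformer.py:646: the (batch, seq_len) boolean padding mask gets two trailing singleton dims so a single call zeroes every head and window slot of the masked tokens. Sketch with assumed toy sizes:

    import torch

    attn_probs = torch.rand(2, 5, 3, 7)          # (batch, seq, heads, window)
    is_index_masked = torch.tensor([[False, False, True, True, True],
                                    [False, True, True, True, True]])
    attn_probs = attn_probs.masked_fill(is_index_masked[:, :, None, None], 0.0)
    print(attn_probs[0, 2].abs().sum())          # masked rows are all zero
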
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_47: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_2, [1024, 2, 12, 64]); add_2 = None
transpose_32: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_47, 0, 1); view_47 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_33: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_5, 1, 2); masked_fill_5 = None
clone_5: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_33, memory_format = torch.contiguous_format); transpose_33 = None
_unsafe_view_2: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_5, [24, 4, 256, 513]); clone_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_34: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_32, 1, 2); transpose_32 = None
view_48: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_34, [24, 1024, 64]); transpose_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_2: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_48, [0, 0, 256, 256], -1.0); view_48 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_5: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_2, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_3: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_2, [0, 257], 0.0); _unsafe_view_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_49: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_3, [24, 4, -1]); constant_pad_nd_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_171: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_49, 0, 0, 9223372036854775807); view_49 = None
slice_172: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_171, 1, 0, 9223372036854775807); slice_171 = None
slice_173: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_172, 2, 0, -256); slice_172 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_50: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_173, [24, 4, 256, 769]); slice_173 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_174: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_50, 0, 0, 9223372036854775807); view_50 = None
slice_175: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_174, 1, 0, 9223372036854775807)
slice_176: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_175, 2, 0, 9223372036854775807); slice_175 = None
slice_177: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_176, 3, 0, -1); slice_176 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_13: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_177, 4)
permute_9: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_13, [0, 1, 2, 4, 3]); unsqueeze_13 = None
unsqueeze_14: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_5, 4); as_strided_5 = None
permute_10: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_14, [0, 1, 4, 3, 2]); unsqueeze_14 = None
permute_11: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_9, [0, 1, 2, 4, 3]); permute_9 = None
sym_size_3: Sym(24) = torch.ops.aten.sym_size(slice_174, 0); slice_174 = None
# No stacktrace found for following nodes
mul_4: Sym(96) = sym_size_3 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_4: Sym(768) = torch.ops.aten.sym_size(slice_177, 3); slice_177 = None
view_51: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_11, [mul_4, 256, sym_size_4]); permute_11 = None
permute_12: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_10, [0, 1, 4, 3, 2]); permute_10 = None
clone_6: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format); permute_12 = None
_unsafe_view_3: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_6, [mul_4, sym_size_4, 64]); clone_6 = mul_4 = sym_size_4 = None
bmm_1: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_51, _unsafe_view_3); view_51 = _unsafe_view_3 = None
view_52: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_1, [sym_size_3, 4, 256, 1, 64]); bmm_1 = None
permute_13: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_52, [0, 1, 2, 4, 3])
sym_size_5: Sym(4) = torch.ops.aten.sym_size(view_52, 1); view_52 = None
view_53: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_13, [sym_size_3, sym_size_5, 256, 64]); permute_13 = sym_size_3 = sym_size_5 = None
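The unsqueeze/permute/view/bmm nodes above are the decomposition of the torch.einsum("bcwd,bcdh->bcwh", ...) call quoted at modeling_longformer.py:906 into a single batched matmul. A small numerical sketch (sizes made up to match the trace) showing the two forms agree:

import torch

b, c, w, d, h = 24, 4, 256, 768, 64
chunked_attn_probs = torch.randn(b, c, w, d)
chunked_value = torch.randn(b, c, d, h)

ref = torch.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
# the traced graph folds b and c into one batch dimension and calls bmm
via_bmm = torch.bmm(
    chunked_attn_probs.reshape(b * c, w, d),
    chunked_value.reshape(b * c, d, h),
).view(b, c, w, h)
assert torch.allclose(ref, via_bmm, atol=1e-3)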
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_54: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_53, [2, 12, 1024, 64]); view_53 = None
transpose_35: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_54, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_36: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_35, 0, 1); transpose_35 = None
clone_7: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_36, memory_format = torch.contiguous_format); transpose_36 = None
_unsafe_view_4: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_7, [1024, 2, 768]); clone_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_37: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_4, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_3: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_7); orig_primals_7 = None
clone_8: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_37, memory_format = torch.contiguous_format); transpose_37 = None
sym_size_6: Sym(1024) = torch.ops.aten.sym_size(view_54, 2); view_54 = None
# No stacktrace found for following nodes
mul_5: Sym(2048) = 2 * sym_size_6
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_7: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_4, 2); _unsafe_view_4 = None
view_55: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_8, [mul_5, sym_size_7]); clone_8 = mul_5 = sym_size_7 = None
mm_3: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_55, t_3); view_55 = t_3 = None
view_56: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_3, [2, sym_size_6, 768]); mm_3 = sym_size_6 = None
add_4: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_56, orig_primals_8); orig_primals_8 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_5: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_4, orig_primals_193); add_4 = orig_primals_193 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm = torch.ops.aten.native_layer_norm.default(add_5, [768], orig_primals_9, orig_primals_10, 1e-05); add_5 = orig_primals_9 = orig_primals_10 = None
getitem: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm[0]
getitem_1: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm[1]
getitem_2: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm[2]; native_layer_norm = None
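Nodes t_3 through getitem_2 are the attention output projection followed by the residual add and LayerNorm (modeling_longformer.py:1116/1118). A rough eager equivalent, with freshly initialized modules standing in for the real weights:

import torch
import torch.nn as nn

hidden_size = 768
dense = nn.Linear(hidden_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-5)

attn_output = torch.randn(2, 1024, hidden_size)    # result of the self-attention block
input_tensor = torch.randn(2, 1024, hidden_size)   # residual branch (layer input)

hidden_states = dense(attn_output)
hidden_states = layer_norm(hidden_states + input_tensor)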
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_4: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_11); orig_primals_11 = None
sym_size_8: Sym(1024) = torch.ops.aten.sym_size(view_56, 1); view_56 = None
# No stacktrace found for following nodes
mul_6: Sym(2048) = 2 * sym_size_8
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_57: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem, [mul_6, 768]); mul_6 = None
addmm: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_12, view_57, t_4); orig_primals_12 = view_57 = t_4 = None
view_58: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm, [2, sym_size_8, 3072]); addmm = sym_size_8 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_58)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_5: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_13); orig_primals_13 = None
sym_size_9: Sym(1024) = torch.ops.aten.sym_size(view_58, 1); view_58 = None
# No stacktrace found for following nodes
mul_7: Sym(2048) = 2 * sym_size_9
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_59: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu, [mul_7, 3072]); gelu = mul_7 = None
addmm_1: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_14, view_59, t_5); orig_primals_14 = view_59 = t_5 = None
view_60: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_1, [2, sym_size_9, 768]); addmm_1 = sym_size_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_6: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_60, getitem); getitem = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_1 = torch.ops.aten.native_layer_norm.default(add_6, [768], orig_primals_15, orig_primals_16, 1e-05); add_6 = orig_primals_15 = orig_primals_16 = None
getitem_3: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_1[0]
getitem_4: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_1[1]
getitem_5: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_1[2]; native_layer_norm_1 = None
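Nodes t_4 through getitem_5 are the feed-forward block of the layer: dense 768->3072, GELU, dense 3072->768, then residual add and LayerNorm (modeling_longformer.py:1182/1196/1198 and activations.py:56). A rough eager sketch with placeholder modules:

import torch
import torch.nn as nn

hidden_size, intermediate_size = 768, 3072
intermediate = nn.Linear(hidden_size, intermediate_size)
output = nn.Linear(intermediate_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-5)

attn_block_out = torch.randn(2, 1024, hidden_size)  # output of the attention sub-block
h = nn.functional.gelu(intermediate(attn_block_out))
hidden_states = layer_norm(output(h) + attn_block_out)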
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_38: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_3, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_6: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_17); orig_primals_17 = None
clone_9: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_38, memory_format = torch.contiguous_format)
sym_size_10: Sym(1024) = torch.ops.aten.sym_size(view_60, 1); view_60 = None
# No stacktrace found for following nodes
mul_8: Sym(2048) = sym_size_10 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_61: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_9, [mul_8, 768]); clone_9 = mul_8 = None
mm_4: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_61, t_6); view_61 = t_6 = None
view_62: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_4, [sym_size_10, 2, 768]); mm_4 = None
add_7: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_62, orig_primals_18); view_62 = orig_primals_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_7: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_19); orig_primals_19 = None
clone_10: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_38, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_9: Sym(2048) = sym_size_10 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_63: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_10, [mul_9, 768]); clone_10 = mul_9 = None
mm_5: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_63, t_7); view_63 = t_7 = None
view_64: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_5, [sym_size_10, 2, 768]); mm_5 = None
add_8: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_64, orig_primals_20); view_64 = orig_primals_20 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_8: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_21); orig_primals_21 = None
clone_11: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_38, memory_format = torch.contiguous_format); transpose_38 = None
# No stacktrace found for following nodes
mul_10: Sym(2048) = sym_size_10 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_65: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_11, [mul_10, 768]); clone_11 = mul_10 = None
mm_6: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_65, t_8); view_65 = t_8 = None
view_66: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_6, [sym_size_10, 2, 768]); mm_6 = sym_size_10 = None
add_9: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_66, orig_primals_22); view_66 = orig_primals_22 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_1: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_7, 8.0); add_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_67: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_1, [1024, 2, 12, 64])
transpose_39: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_67, 0, 1); view_67 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_68: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_8, [1024, 2, 12, 64]); add_8 = None
transpose_40: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_68, 0, 1); view_68 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_41: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_39, 1, 2); transpose_39 = None
view_69: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_41, [24, 1024, 64]); transpose_41 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_42: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_40, 1, 2); transpose_40 = None
view_70: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_42, [24, 1024, 64]); transpose_42 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_71: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_69, [24, 2, 512, 64]); view_69 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_6: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_71, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_71 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_72: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_70, [24, 2, 512, 64]); view_70 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_7: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_72, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_72 = None
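The view/as_strided pairs above implement the overlapping-chunk trick from modeling_longformer.py:775/788: a (batch*heads, seq_len, head_dim) tensor is re-viewed as chunks of size 2*window_overlap that overlap by window_overlap, without copying. A standalone sketch of the same striding, assuming a contiguous input and window_overlap = 256:

import torch

bh, seq_len, head_dim, w = 24, 1024, 64, 256   # w = window_overlap
x = torch.randn(bh, seq_len, head_dim)

# non-overlapping chunks of size 2*w ...
chunks = x.view(bh, seq_len // (2 * w), 2 * w, head_dim)
# ... re-expressed as chunks overlapping by w, by halving the chunk stride
chunk_size = [bh, seq_len // w - 1, 2 * w, head_dim]
chunk_stride = list(chunks.stride())
chunk_stride[1] = chunk_stride[1] // 2
overlapping = chunks.as_strided(size=chunk_size, stride=chunk_stride)

# consecutive chunks share their middle w rows
assert torch.equal(overlapping[:, 1, :w], overlapping[:, 0, w:])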
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_15: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_6, 4); as_strided_6 = None
permute_14: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_15, [0, 1, 2, 4, 3]); unsqueeze_15 = None
unsqueeze_16: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_7, 4); as_strided_7 = None
permute_15: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_16, [0, 1, 4, 2, 3]); unsqueeze_16 = None
permute_16: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_14, [0, 1, 2, 4, 3]); permute_14 = None
view_73: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_1, [1024, 2, 12, 64]); div_1 = None
transpose_43: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_73, 0, 1); view_73 = None
transpose_44: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_43, 1, 2); transpose_43 = None
view_74: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_44, [24, 1024, 64]); transpose_44 = None
view_75: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_74, [24, 2, 512, 64]); view_74 = None
as_strided_8: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_75, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_75 = None
unsqueeze_17: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_8, 4); as_strided_8 = None
permute_17: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_17, [0, 1, 2, 4, 3]); unsqueeze_17 = None
permute_18: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_17, [0, 1, 2, 4, 3]); permute_17 = None
clone_12: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_18, memory_format = torch.contiguous_format); permute_18 = None
_unsafe_view_5: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_12, [72, 512, 64]); clone_12 = None
permute_19: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_15, [0, 1, 4, 3, 2]); permute_15 = None
clone_13: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_19, memory_format = torch.contiguous_format); permute_19 = None
_unsafe_view_6: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_13, [72, 64, 512]); clone_13 = None
bmm_2: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_5, _unsafe_view_6); _unsafe_view_5 = _unsafe_view_6 = None
view_76: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_2, [24, 3, 512, 1, 512]); bmm_2 = None
permute_20: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_76, [0, 1, 2, 4, 3]); view_76 = None
view_77: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_20, [24, 3, 512, 512]); permute_20 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_4: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_77, [0, 0, 0, 1], 0.0); view_77 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_78: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_4, [24, 3, 512, 513]); constant_pad_nd_4 = None
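constant_pad_nd_4 and view_78 are the pad-then-reshape trick from modeling_longformer.py:713/716: padding one extra row and re-viewing the last two dims skews each row by one position, so the diagonal of each chunked score block lines up in a single column. A tiny numerical sketch of the same trick on a 4x4 block:

import torch

x = torch.arange(16.0).view(4, 4)                    # one 4x4 block of chunked scores
padded = torch.nn.functional.pad(x, (0, 0, 0, 1))    # pad one extra row -> (5, 4)
skewed = padded.view(4, 5)                           # re-view: row i is shifted by i positions
assert torch.equal(skewed[:, 0], torch.diagonal(x))  # the diagonal now sits in column 0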
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_2: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_78, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_178: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_78, 0, 0, 9223372036854775807)
slice_179: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_178, 1, 0, 9223372036854775807); slice_178 = None
slice_180: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_179, 2, 0, 256); slice_179 = None
slice_181: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_180, 3, 0, 257); slice_180 = None
slice_182: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
slice_183: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_182, 1, 0, -1); slice_182 = None
slice_184: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_183, 2, 0, 9223372036854775807); slice_183 = None
slice_185: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_184, 3, 256, 9223372036854775807); slice_184 = None
slice_186: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
slice_187: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_186, 1, 0, -1); slice_186 = None
slice_188: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_187, 2, 0, 9223372036854775807); slice_187 = None
slice_189: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_188, 3, 256, 9223372036854775807); slice_188 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_190: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_78, 0, 0, 9223372036854775807)
select_18: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_190, 1, -1); slice_190 = None
slice_191: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_18, 1, 256, 9223372036854775807); select_18 = None
slice_192: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_191, 2, 0, 257); slice_191 = None
slice_193: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
select_19: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_193, 1, -1); slice_193 = None
slice_194: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_19, 1, 0, 9223372036854775807); select_19 = None
slice_195: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_194, 2, 256, 9223372036854775807); slice_194 = None
slice_196: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
slice_197: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_196, 1, 0, -1)
slice_198: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_197, 2, 0, 9223372036854775807)
slice_scatter_44: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_198, slice_181, 3, 256, 9223372036854775807); slice_198 = slice_181 = None
slice_scatter_45: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_197, slice_scatter_44, 2, 0, 9223372036854775807); slice_197 = slice_scatter_44 = None
slice_scatter_46: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_196, slice_scatter_45, 1, 0, -1); slice_196 = slice_scatter_45 = None
slice_scatter_47: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_2, slice_scatter_46, 0, 0, 9223372036854775807); slice_scatter_46 = None
slice_199: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_47, 0, 0, 9223372036854775807)
select_20: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_199, 1, -1); slice_199 = None
slice_200: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_20, 1, 0, 9223372036854775807); select_20 = None
slice_201: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_200, 2, 256, 9223372036854775807); slice_200 = None
slice_202: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
select_21: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_202, 1, -1); slice_202 = None
slice_203: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_21, 1, 0, 9223372036854775807); select_21 = None
slice_204: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_203, 2, 256, 9223372036854775807); slice_203 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_205: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_78, 0, 0, 9223372036854775807)
slice_206: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_205, 1, 0, 9223372036854775807); slice_205 = None
slice_207: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_206, 2, -257, -1); slice_206 = None
slice_208: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_207, 3, 257, 9223372036854775807); slice_207 = None
slice_209: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
slice_210: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_209, 1, 1, 9223372036854775807); slice_209 = None
slice_211: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_210, 2, 0, 9223372036854775807); slice_210 = None
slice_212: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_211, 3, 0, 256); slice_211 = None
slice_213: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_47, 0, 0, 9223372036854775807)
select_22: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_213, 1, -1)
slice_214: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_22, 1, 0, 9223372036854775807)
slice_scatter_48: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_214, slice_192, 2, 256, 9223372036854775807); slice_214 = slice_192 = None
slice_scatter_49: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_22, slice_scatter_48, 1, 0, 9223372036854775807); select_22 = slice_scatter_48 = None
select_scatter_4: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_213, slice_scatter_49, 1, -1); slice_213 = slice_scatter_49 = None
slice_scatter_50: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_47, select_scatter_4, 0, 0, 9223372036854775807); slice_scatter_47 = select_scatter_4 = None
slice_215: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_50, 0, 0, 9223372036854775807)
slice_216: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_215, 1, 1, 9223372036854775807); slice_215 = None
slice_217: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_216, 2, 0, 9223372036854775807); slice_216 = None
slice_218: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_217, 3, 0, 256); slice_217 = None
slice_219: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
slice_220: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_219, 1, 1, 9223372036854775807); slice_219 = None
slice_221: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_220, 2, 0, 9223372036854775807); slice_220 = None
slice_222: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_221, 3, 0, 256); slice_221 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_223: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_78, 0, 0, 9223372036854775807); view_78 = None
select_23: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_223, 1, 0); slice_223 = None
slice_224: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_23, 1, 0, 255); select_23 = None
slice_225: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_224, 2, -255, 9223372036854775807); slice_224 = None
slice_226: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
select_24: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_226, 1, 0); slice_226 = None
slice_227: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_24, 1, 1, 256); select_24 = None
slice_228: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_227, 2, 1, 256); slice_227 = None
slice_229: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_50, 0, 0, 9223372036854775807)
slice_230: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_229, 1, 1, 9223372036854775807)
slice_231: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_230, 2, 0, 9223372036854775807)
slice_scatter_51: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_231, slice_208, 3, 0, 256); slice_231 = slice_208 = None
slice_scatter_52: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_230, slice_scatter_51, 2, 0, 9223372036854775807); slice_230 = slice_scatter_51 = None
slice_scatter_53: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_229, slice_scatter_52, 1, 1, 9223372036854775807); slice_229 = slice_scatter_52 = None
slice_scatter_54: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_50, slice_scatter_53, 0, 0, 9223372036854775807); slice_scatter_50 = slice_scatter_53 = None
slice_232: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_54, 0, 0, 9223372036854775807)
select_25: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_232, 1, 0); slice_232 = None
slice_233: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_25, 1, 1, 256); select_25 = None
slice_234: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_233, 2, 1, 256); slice_233 = None
slice_235: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_2, 0, 0, 9223372036854775807)
select_26: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_235, 1, 0); slice_235 = None
slice_236: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_26, 1, 1, 256); select_26 = None
slice_237: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_236, 2, 1, 256); slice_236 = None
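The long chain of aten.slice / aten.slice_scatter / aten.select_scatter nodes above is the functionalized form of the in-place slice assignments into diagonal_attention_scores at modeling_longformer.py:845-856. A reduced-size sketch (hypothetical shapes) of how one such assignment turns into out-of-place slice_scatter calls:

import torch

w = 2
src = torch.randn(1, 2, 2, w + 1)

# eager: in-place slice assignment
dest_eager = torch.zeros(1, 3, 2, 2 * w + 1)
dest_eager[:, :-1, :, w:] = src

# functionalized: the same update expressed with out-of-place slice_scatter,
# which is what the slice / slice_scatter chain in the trace encodes
dest = torch.zeros(1, 3, 2, 2 * w + 1)
inner = torch.slice_scatter(dest[:, :-1], src, dim=3, start=w)
dest_func = torch.slice_scatter(dest, inner, dim=1, start=0, end=-1)

assert torch.equal(dest_eager, dest_func)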
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_79: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513])
transpose_45: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_79, 2, 1); view_79 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_238: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_54, 0, 0, 9223372036854775807)
select_27: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_238, 1, 0)
slice_239: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_27, 1, 1, 256)
slice_scatter_55: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_239, slice_225, 2, 1, 256); slice_239 = slice_225 = None
slice_scatter_56: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_27, slice_scatter_55, 1, 1, 256); select_27 = slice_scatter_55 = None
select_scatter_5: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_238, slice_scatter_56, 1, 0); slice_238 = slice_scatter_56 = None
slice_scatter_57: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_54, select_scatter_5, 0, 0, 9223372036854775807); slice_scatter_54 = select_scatter_5 = None
view_80: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_57, [2, 12, 1024, 513])
transpose_46: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_80, 2, 1); view_80 = None
new_ones_3: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_46, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_2: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_3); new_ones_3 = None
flip_4: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_2, [0]); tril_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_18: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_4, 0); flip_4 = None
slice_240: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_18, 1, 0, 9223372036854775807); unsqueeze_18 = None
unsqueeze_19: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_240, 2); slice_240 = None
slice_241: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_19, 3, 0, 9223372036854775807); unsqueeze_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_5: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_241, [1, 3])
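new_ones_3 through flip_5 build the beginning/ending masks from modeling_longformer.py:792-794; positions where these masks equal 1 are later filled with -inf (the masked_fill nodes below). A reduced-size sketch with window_overlap = 3 instead of 256:

import torch

w = 3   # window_overlap (256 in the trace above)
beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]   # shape (1, w, 1, w+1)
ending_mask = beginning_mask.flip(dims=(1, 3))
# entries equal to 1 mark window positions that fall outside the sequence;
# they are filled with -inf before the softmax (the masked_fill nodes below)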
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_242: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_46, 0, 0, 9223372036854775807)
slice_243: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_242, 1, 0, 256); slice_242 = None
slice_244: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_243, 2, 0, 9223372036854775807); slice_243 = None
slice_245: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_244, 3, 0, 257); slice_244 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_4: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_241, [2, 256, 12, 257]); slice_241 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_4: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_4, 1); expand_4 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_81: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_57, [2, 12, 1024, 513])
transpose_47: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_81, 2, 1); view_81 = None
slice_246: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_47, 0, 0, 9223372036854775807); transpose_47 = None
slice_247: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_246, 1, 0, 256); slice_246 = None
slice_248: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_247, 2, 0, 9223372036854775807); slice_247 = None
slice_249: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_248, 3, 0, 257); slice_248 = None
masked_fill_6: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_249, eq_4, -inf); slice_249 = eq_4 = None
view_82: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513])
transpose_48: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_82, 2, 1); view_82 = None
slice_250: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_48, 0, 0, 9223372036854775807); transpose_48 = None
slice_251: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_250, 1, 0, 256); slice_250 = None
slice_252: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_251, 2, 0, 9223372036854775807); slice_251 = None
slice_253: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_252, 3, 0, 257); slice_252 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_83: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513])
transpose_49: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_83, 2, 1); view_83 = None
slice_254: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_46, 0, 0, 9223372036854775807); transpose_46 = None
slice_255: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_254, 1, -256, 9223372036854775807); slice_254 = None
slice_256: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_255, 2, 0, 9223372036854775807); slice_255 = None
slice_257: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_256, 3, -257, 9223372036854775807); slice_256 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_5: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_5, [2, 256, 12, 257]); flip_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_5: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_5, 1); expand_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_84: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_57, [2, 12, 1024, 513]); slice_scatter_57 = None
transpose_50: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_84, 2, 1); view_84 = None
slice_258: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_50, 0, 0, 9223372036854775807)
slice_259: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_258, 1, 0, 256)
slice_260: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_259, 2, 0, 9223372036854775807)
slice_scatter_58: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_260, masked_fill_6, 3, 0, 257); slice_260 = masked_fill_6 = None
slice_scatter_59: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_259, slice_scatter_58, 2, 0, 9223372036854775807); slice_259 = slice_scatter_58 = None
slice_scatter_60: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_258, slice_scatter_59, 1, 0, 256); slice_258 = slice_scatter_59 = None
slice_scatter_61: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_50, slice_scatter_60, 0, 0, 9223372036854775807); transpose_50 = slice_scatter_60 = None
transpose_51: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_61, 2, 1); slice_scatter_61 = None
view_85: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_51, [24, 4, 256, 513]); transpose_51 = None
view_86: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_85, [2, 12, 1024, 513])
transpose_52: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_86, 2, 1); view_86 = None
slice_261: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_52, 0, 0, 9223372036854775807); transpose_52 = None
slice_262: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_261, 1, -256, 9223372036854775807); slice_261 = None
slice_263: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_262, 2, 0, 9223372036854775807); slice_262 = None
slice_264: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_263, 3, -257, 9223372036854775807); slice_263 = None
masked_fill_7: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_264, eq_5, -inf); slice_264 = eq_5 = None
view_87: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513])
transpose_53: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_87, 2, 1); view_87 = None
slice_265: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_53, 0, 0, 9223372036854775807); transpose_53 = None
slice_266: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_265, 1, -256, 9223372036854775807); slice_265 = None
slice_267: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_266, 2, 0, 9223372036854775807); slice_266 = None
slice_268: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_267, 3, -257, 9223372036854775807); slice_267 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_1: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_269: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_1, 0, 0, 9223372036854775807); ne_1 = None
slice_270: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_269, 1, 0, 9223372036854775807); slice_269 = None
unsqueeze_20: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_270, 2); slice_270 = None
unsqueeze_21: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_20, 3); unsqueeze_20 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_1: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_21, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_8: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_1, unsqueeze_21, -10000.0); _to_copy_1 = unsqueeze_21 = None
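ne_1 through masked_fill_8 compute the float mask quoted at modeling_longformer.py:585/588: tokens excluded from the sliding-window pattern get a large negative bias. A minimal eager sketch with a hypothetical attention_mask:

import torch

# hypothetical mask values; nonzero entries mark tokens excluded from the
# sliding-window pattern (padding or global-attention positions)
attention_mask = torch.tensor([[0.0, 0.0, 3.0, -1.0]])
query_vectors = torch.randn(1, 4, 8)

remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
    remove_from_windowed_attention_mask, -10000.0
)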
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_4: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_8, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_54: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_4, 1, 2); new_ones_4 = None
view_88: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_54, [2, 1024, 1]); transpose_54 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_55: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_8, 1, 2); masked_fill_8 = None
view_89: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_55, [2, 1024, 1]); transpose_55 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_90: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_88, [2, 2, 512, 1]); view_88 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_9: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_90, [2, 3, 512, 1], [1024, 256, 1, 1]); view_90 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_91: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_89, [2, 2, 512, 1]); view_89 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_10: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_91, [2, 3, 512, 1], [1024, 256, 1, 1]); view_91 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_22: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_9, 4); as_strided_9 = None
permute_21: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_22, [0, 1, 2, 4, 3]); unsqueeze_22 = None
unsqueeze_23: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_10, 4); as_strided_10 = None
permute_22: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_23, [0, 1, 4, 2, 3]); unsqueeze_23 = None
mul_11: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_21, permute_22); permute_21 = permute_22 = None
view_92: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_11, [2, 3, 512, 512]); mul_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_5: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_92, [0, 0, 0, 1], 0.0); view_92 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_93: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_5, [2, 3, 512, 513]); constant_pad_nd_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_3: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_93, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_271: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_93, 0, 0, 9223372036854775807)
slice_272: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_271, 1, 0, 9223372036854775807); slice_271 = None
slice_273: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_272, 2, 0, 256); slice_272 = None
slice_274: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_273, 3, 0, 257); slice_273 = None
slice_275: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_3, 0, 0, 9223372036854775807)
slice_276: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_275, 1, 0, -1); slice_275 = None
slice_277: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_276, 2, 0, 9223372036854775807); slice_276 = None
slice_278: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_277, 3, 256, 9223372036854775807); slice_277 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_279: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_93, 0, 0, 9223372036854775807)
select_28: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_279, 1, -1); slice_279 = None
slice_280: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_28, 1, 256, 9223372036854775807); select_28 = None
slice_281: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_280, 2, 0, 257); slice_280 = None
slice_282: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_3, 0, 0, 9223372036854775807)
select_29: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_282, 1, -1); slice_282 = None
slice_283: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_29, 1, 0, 9223372036854775807); select_29 = None
slice_284: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_283, 2, 256, 9223372036854775807); slice_283 = None
slice_285: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_3, 0, 0, 9223372036854775807)
slice_286: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_285, 1, 0, -1)
slice_287: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_286, 2, 0, 9223372036854775807)
slice_scatter_62: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_287, slice_274, 3, 256, 9223372036854775807); slice_287 = slice_274 = None
slice_scatter_63: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_286, slice_scatter_62, 2, 0, 9223372036854775807); slice_286 = slice_scatter_62 = None
slice_scatter_64: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_285, slice_scatter_63, 1, 0, -1); slice_285 = slice_scatter_63 = None
slice_scatter_65: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_3, slice_scatter_64, 0, 0, 9223372036854775807); slice_scatter_64 = None
slice_288: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_65, 0, 0, 9223372036854775807)
select_30: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_288, 1, -1); slice_288 = None
slice_289: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_30, 1, 0, 9223372036854775807); select_30 = None
slice_290: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_289, 2, 256, 9223372036854775807); slice_289 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_291: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_93, 0, 0, 9223372036854775807)
slice_292: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_291, 1, 0, 9223372036854775807); slice_291 = None
slice_293: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_292, 2, -257, -1); slice_292 = None
slice_294: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_293, 3, 257, 9223372036854775807); slice_293 = None
slice_295: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_3, 0, 0, 9223372036854775807)
slice_296: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_295, 1, 1, 9223372036854775807); slice_295 = None
slice_297: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_296, 2, 0, 9223372036854775807); slice_296 = None
slice_298: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_297, 3, 0, 256); slice_297 = None
slice_299: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_65, 0, 0, 9223372036854775807)
select_31: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_299, 1, -1)
slice_300: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_31, 1, 0, 9223372036854775807)
slice_scatter_66: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_300, slice_281, 2, 256, 9223372036854775807); slice_300 = slice_281 = None
slice_scatter_67: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_31, slice_scatter_66, 1, 0, 9223372036854775807); select_31 = slice_scatter_66 = None
select_scatter_6: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_299, slice_scatter_67, 1, -1); slice_299 = slice_scatter_67 = None
slice_scatter_68: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_65, select_scatter_6, 0, 0, 9223372036854775807); slice_scatter_65 = select_scatter_6 = None
slice_301: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_68, 0, 0, 9223372036854775807)
slice_302: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_301, 1, 1, 9223372036854775807); slice_301 = None
slice_303: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_302, 2, 0, 9223372036854775807); slice_302 = None
slice_304: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_303, 3, 0, 256); slice_303 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_305: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_93, 0, 0, 9223372036854775807); view_93 = None
select_32: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_305, 1, 0); slice_305 = None
slice_306: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_32, 1, 0, 255); select_32 = None
slice_307: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_306, 2, -255, 9223372036854775807); slice_306 = None
slice_308: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_3, 0, 0, 9223372036854775807)
select_33: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_308, 1, 0); slice_308 = None
slice_309: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_33, 1, 1, 256); select_33 = None
slice_310: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_309, 2, 1, 256); slice_309 = None
slice_311: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_68, 0, 0, 9223372036854775807)
slice_312: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_311, 1, 1, 9223372036854775807)
slice_313: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_312, 2, 0, 9223372036854775807)
slice_scatter_69: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_313, slice_294, 3, 0, 256); slice_313 = slice_294 = None
slice_scatter_70: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_312, slice_scatter_69, 2, 0, 9223372036854775807); slice_312 = slice_scatter_69 = None
slice_scatter_71: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_311, slice_scatter_70, 1, 1, 9223372036854775807); slice_311 = slice_scatter_70 = None
slice_scatter_72: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_68, slice_scatter_71, 0, 0, 9223372036854775807); slice_scatter_68 = slice_scatter_71 = None
slice_314: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_72, 0, 0, 9223372036854775807)
select_34: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_314, 1, 0); slice_314 = None
slice_315: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_34, 1, 1, 256); select_34 = None
slice_316: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_315, 2, 1, 256); slice_315 = None
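The slice/slice_scatter chains above are the functionalized form of the in-place slice assignments at modeling_longformer.py:845-856: each assignment of the form x[...] = y becomes a read of the target region followed by slice_scatter calls that write it back. A minimal sketch of that rewrite, using small made-up shapes rather than the traced [2, 4, 256, 513] tensors:

import torch

dst = torch.zeros(2, 4, 8, 9)                        # hypothetical buffer, stands in for diagonal_attention_scores
src = torch.randn(2, 3, 8, 5)                        # hypothetical values, stands in for the chunked scores

ref = dst.clone()
ref[:, :-1, :, 4:] = src                             # eager in-place slice assignment

inner = torch.ops.aten.slice.Tensor(dst, 1, 0, -1)                  # read dst[:, :-1]
inner = torch.ops.aten.slice_scatter.default(inner, src, 3, 4, 9)   # write the last-dim slice into the read copy
out = torch.ops.aten.slice_scatter.default(dst, inner, 1, 0, -1)    # write the region back into a fresh dst
assert torch.equal(out, ref)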
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_94: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_3, [2, 1, 1024, 513]); new_empty_3 = None
transpose_56: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_94, 2, 1); view_94 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_317: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_72, 0, 0, 9223372036854775807)
select_35: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_317, 1, 0)
slice_318: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_35, 1, 1, 256)
slice_scatter_73: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_318, slice_307, 2, 1, 256); slice_318 = slice_307 = None
slice_scatter_74: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_35, slice_scatter_73, 1, 1, 256); select_35 = slice_scatter_73 = None
select_scatter_7: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_317, slice_scatter_74, 1, 0); slice_317 = slice_scatter_74 = None
slice_scatter_75: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_72, select_scatter_7, 0, 0, 9223372036854775807); slice_scatter_72 = select_scatter_7 = None
view_95: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_75, [2, 1, 1024, 513])
transpose_57: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_95, 2, 1); view_95 = None
new_ones_5: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_57, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_3: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_5); new_ones_5 = None
flip_6: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_3, [0]); tril_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_24: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_6, 0); flip_6 = None
slice_319: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_24, 1, 0, 9223372036854775807); unsqueeze_24 = None
unsqueeze_25: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_319, 2); slice_319 = None
slice_320: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_25, 3, 0, 9223372036854775807); unsqueeze_25 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_7: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_320, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_321: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_57, 0, 0, 9223372036854775807)
slice_322: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_321, 1, 0, 256); slice_321 = None
slice_323: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_322, 2, 0, 9223372036854775807); slice_322 = None
slice_324: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_323, 3, 0, 257); slice_323 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_6: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_320, [2, 256, 1, 257]); slice_320 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_6: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_6, 1); expand_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_96: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_75, [2, 1, 1024, 513])
transpose_58: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_96, 2, 1); view_96 = None
slice_325: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_58, 0, 0, 9223372036854775807); transpose_58 = None
slice_326: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_325, 1, 0, 256); slice_325 = None
slice_327: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_326, 2, 0, 9223372036854775807); slice_326 = None
slice_328: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_327, 3, 0, 257); slice_327 = None
masked_fill_9: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_328, eq_6, -inf); slice_328 = eq_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_329: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_57, 0, 0, 9223372036854775807); transpose_57 = None
slice_330: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_329, 1, -256, 9223372036854775807); slice_329 = None
slice_331: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_330, 2, 0, 9223372036854775807); slice_330 = None
slice_332: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_331, 3, -257, 9223372036854775807); slice_331 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_7: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_7, [2, 256, 1, 257]); flip_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_7: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_7, 1); expand_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_97: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_75, [2, 1, 1024, 513]); slice_scatter_75 = None
transpose_59: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_97, 2, 1); view_97 = None
slice_333: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_59, 0, 0, 9223372036854775807)
slice_334: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_333, 1, 0, 256)
slice_335: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_334, 2, 0, 9223372036854775807)
slice_scatter_76: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_335, masked_fill_9, 3, 0, 257); slice_335 = masked_fill_9 = None
slice_scatter_77: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_334, slice_scatter_76, 2, 0, 9223372036854775807); slice_334 = slice_scatter_76 = None
slice_scatter_78: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_333, slice_scatter_77, 1, 0, 256); slice_333 = slice_scatter_77 = None
slice_scatter_79: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_59, slice_scatter_78, 0, 0, 9223372036854775807); transpose_59 = slice_scatter_78 = None
transpose_60: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_79, 2, 1); slice_scatter_79 = None
view_98: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_60, [2, 4, 256, 513]); transpose_60 = None
view_99: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_98, [2, 1, 1024, 513])
transpose_61: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_99, 2, 1); view_99 = None
slice_336: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_61, 0, 0, 9223372036854775807); transpose_61 = None
slice_337: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_336, 1, -256, 9223372036854775807); slice_336 = None
slice_338: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_337, 2, 0, 9223372036854775807); slice_337 = None
slice_339: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_338, 3, -257, 9223372036854775807); slice_338 = None
masked_fill_10: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_339, eq_7, -inf); slice_339 = eq_7 = None
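The new_ones/tril/flip/expand/masked_fill chain above corresponds to the masking of the invalid upper-left and lower-right window positions (modeling_longformer.py:792-800), which sets out-of-range corners of the sliding-window scores to -inf before the softmax. A rough sketch of the same idea with a hypothetical window_overlap of 4 instead of the traced 256:

import torch

w = 4                                                # hypothetical window_overlap
scores = torch.zeros(2, 16, 1, 2 * w + 1)            # (batch, seq_len, num_heads, window) with num_heads = 1 here

beginning_mask = scores.new_ones(w, w + 1).tril().flip(dims=[0])[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))

# the first w query positions cannot attend left of the sequence start ...
scores[:, :w, :, : w + 1] = scores[:, :w, :, : w + 1].masked_fill(beginning_mask == 1, float("-inf"))
# ... and the last w positions cannot attend right of the sequence end
scores[:, -w:, :, -(w + 1):] = scores[:, -w:, :, -(w + 1):].masked_fill(ending_mask == 1, float("-inf"))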
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_100: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513])
transpose_62: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_100, 2, 1); view_100 = None
view_101: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_85, [2, 12, 1024, 513]); view_85 = None
transpose_63: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_101, 2, 1); view_101 = None
slice_340: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_63, 0, 0, 9223372036854775807)
slice_341: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_340, 1, -256, 9223372036854775807)
slice_342: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_341, 2, 0, 9223372036854775807)
slice_scatter_80: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_342, masked_fill_7, 3, -257, 9223372036854775807); slice_342 = masked_fill_7 = None
slice_scatter_81: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_341, slice_scatter_80, 2, 0, 9223372036854775807); slice_341 = slice_scatter_80 = None
slice_scatter_82: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_340, slice_scatter_81, 1, -256, 9223372036854775807); slice_340 = slice_scatter_81 = None
slice_scatter_83: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_63, slice_scatter_82, 0, 0, 9223372036854775807); transpose_63 = slice_scatter_82 = None
transpose_64: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_83, 2, 1); slice_scatter_83 = None
view_102: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_64, [24, 4, 256, 513]); transpose_64 = None
view_103: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_102, [2, 12, 1024, 513]); view_102 = None
transpose_65: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_103, 2, 1); view_103 = None
view_104: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_98, [2, 1, 1024, 513]); view_98 = None
transpose_66: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_104, 2, 1); view_104 = None
slice_343: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_66, 0, 0, 9223372036854775807)
slice_344: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_343, 1, -256, 9223372036854775807)
slice_345: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_344, 2, 0, 9223372036854775807)
slice_scatter_84: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_345, masked_fill_10, 3, -257, 9223372036854775807); slice_345 = masked_fill_10 = None
slice_scatter_85: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_344, slice_scatter_84, 2, 0, 9223372036854775807); slice_344 = slice_scatter_84 = None
slice_scatter_86: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_343, slice_scatter_85, 1, -256, 9223372036854775807); slice_343 = slice_scatter_85 = None
slice_scatter_87: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_66, slice_scatter_86, 0, 0, 9223372036854775807); transpose_66 = slice_scatter_86 = None
transpose_67: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_87, 2, 1); slice_scatter_87 = None
view_105: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_67, [2, 4, 256, 513]); transpose_67 = None
view_106: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_105, [2, 1, 1024, 513]); view_105 = None
transpose_68: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_106, 2, 1); view_106 = None
add_10: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_65, transpose_68); transpose_65 = transpose_68 = None
view_107: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_2, [2, 12, 1024, 513]); new_empty_2 = None
transpose_69: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_107, 2, 1); view_107 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_1: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_10, -1, False); add_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_346: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_347: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_346, 1, 0, 9223372036854775807); slice_346 = None
unsqueeze_26: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_347, 2); slice_347 = None
unsqueeze_27: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_26, 3); unsqueeze_26 = None
masked_fill_11: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_1, unsqueeze_27, 0.0); _softmax_1 = unsqueeze_27 = None
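The _softmax followed by masked_fill above matches modeling_longformer.py:635-646: attention probabilities are taken over the whole window and then zeroed at masked (padding) query positions. A standalone sketch with illustrative shapes, not the traced ones:

import torch
import torch.nn.functional as F

attn_scores = torch.randn(2, 16, 12, 9)              # (batch, seq_len, heads, window), made-up sizes
is_index_masked = torch.zeros(2, 16, dtype=torch.bool)
is_index_masked[:, -3:] = True                       # pretend the last three tokens are padding

attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)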
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_108: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_9, [1024, 2, 12, 64]); add_9 = None
transpose_70: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_108, 0, 1); view_108 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_71: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_11, 1, 2); masked_fill_11 = None
clone_14: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_71, memory_format = torch.contiguous_format); transpose_71 = None
_unsafe_view_7: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_14, [24, 4, 256, 513]); clone_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_72: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_70, 1, 2); transpose_70 = None
view_109: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_72, [24, 1024, 64]); transpose_72 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_6: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_109, [0, 0, 256, 256], -1.0); view_109 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_11: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_6, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_6 = None
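The pad + as_strided above builds overlapping value chunks of length 3 * window_overlap spaced window_overlap apart (modeling_longformer.py:891-902). A small sketch with hypothetical sizes that checks the strided view against explicit slicing of the padded tensor:

import torch
import torch.nn.functional as F

bh, seq, d, w = 3, 16, 2, 4                          # hypothetical batch*heads, seq_len, head_dim, window_overlap
value = torch.randn(bh, seq, d)
padded = F.pad(value, (0, 0, w, w), value=-1)        # pad w positions on both ends of the sequence

chunks = seq // w - 1                                # one less than the number of overlapping chunks built below
s0, s1, s2 = padded.stride()
chunked = padded.as_strided(size=(bh, chunks + 1, 3 * w, d), stride=(s0, w * s1, s1, s2))

# chunk i is simply a window of the padded sequence starting w*i positions in
for i in range(chunks + 1):
    assert torch.equal(chunked[:, i], padded[:, i * w : i * w + 3 * w])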
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_7: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_7, [0, 257], 0.0); _unsafe_view_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_110: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_7, [24, 4, -1]); constant_pad_nd_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_348: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_110, 0, 0, 9223372036854775807); view_110 = None
slice_349: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_348, 1, 0, 9223372036854775807); slice_348 = None
slice_350: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_349, 2, 0, -256); slice_349 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_111: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_350, [24, 4, 256, 769]); slice_350 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_351: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_111, 0, 0, 9223372036854775807); view_111 = None
slice_352: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_351, 1, 0, 9223372036854775807)
slice_353: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_352, 2, 0, 9223372036854775807); slice_352 = None
slice_354: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_353, 3, 0, -1); slice_353 = None
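The pad/view/slice sequence above is the diagonalization step at modeling_longformer.py:755-767: padding each row with window_overlap + 1 zeros and reflowing the flat buffer shifts row r of every chunk right by r positions, so the attention probabilities line up with the overlapping value chunks. A toy check with a hypothetical window_overlap of 3:

import torch
import torch.nn.functional as F

heads, chunks, w = 1, 1, 3                           # hypothetical sizes; row width hd = 2*w + 1
hd = 2 * w + 1
x = torch.arange(heads * chunks * w * hd, dtype=torch.float32).view(heads, chunks, w, hd)

y = F.pad(x, (0, w + 1))                             # each row gains w + 1 trailing zeros
y = y.view(heads, chunks, -1)[:, :, :-w]             # flatten rows, drop the last w pad elements
y = y.view(heads, chunks, w, w + hd)[:, :, :, :-1]   # reflow into w rows of width w + hd, drop last column

# row r of each chunk is now the original row shifted right by r positions
for r in range(w):
    assert torch.equal(y[0, 0, r, r : r + hd], x[0, 0, r])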
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_28: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_354, 4)
permute_23: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_28, [0, 1, 2, 4, 3]); unsqueeze_28 = None
unsqueeze_29: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_11, 4); as_strided_11 = None
permute_24: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_29, [0, 1, 4, 3, 2]); unsqueeze_29 = None
permute_25: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_23, [0, 1, 2, 4, 3]); permute_23 = None
sym_size_11: Sym(24) = torch.ops.aten.sym_size(slice_351, 0); slice_351 = None
# No stacktrace found for following nodes
mul_12: Sym(96) = sym_size_11 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_12: Sym(768) = torch.ops.aten.sym_size(slice_354, 3); slice_354 = None
view_112: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_25, [mul_12, 256, sym_size_12]); permute_25 = None
permute_26: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_24, [0, 1, 4, 3, 2]); permute_24 = None
clone_15: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_26, memory_format = torch.contiguous_format); permute_26 = None
_unsafe_view_8: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_15, [mul_12, sym_size_12, 64]); clone_15 = mul_12 = sym_size_12 = None
bmm_3: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_112, _unsafe_view_8); view_112 = _unsafe_view_8 = None
view_113: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_3, [sym_size_11, 4, 256, 1, 64]); bmm_3 = None
permute_27: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_113, [0, 1, 2, 4, 3])
sym_size_13: Sym(4) = torch.ops.aten.sym_size(view_113, 1); view_113 = None
view_114: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_27, [sym_size_11, sym_size_13, 256, 64]); permute_27 = sym_size_11 = sym_size_13 = None
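The unsqueeze/permute/bmm/view sequence above is the decomposed form of torch.einsum("bcwd,bcdh->bcwh", ...) from modeling_longformer.py:906. A quick equivalence check with illustrative shapes (not the traced [24, 4, 256, 768] ones):

import torch

probs = torch.randn(6, 4, 8, 12)     # b, c, w, d
value = torch.randn(6, 4, 12, 5)     # b, c, d, h

ref = torch.einsum("bcwd,bcdh->bcwh", probs, value)

b, c, w, d = probs.shape
h = value.shape[-1]
out = torch.bmm(probs.reshape(b * c, w, d), value.reshape(b * c, d, h)).view(b, c, w, h)
assert torch.allclose(ref, out, atol=1e-6)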
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_115: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_114, [2, 12, 1024, 64]); view_114 = None
transpose_73: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_115, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_74: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_73, 0, 1); transpose_73 = None
clone_16: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_74, memory_format = torch.contiguous_format); transpose_74 = None
_unsafe_view_9: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_16, [1024, 2, 768]); clone_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_75: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_9, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_9: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_23); orig_primals_23 = None
clone_17: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_75, memory_format = torch.contiguous_format); transpose_75 = None
sym_size_14: Sym(1024) = torch.ops.aten.sym_size(view_115, 2); view_115 = None
# No stacktrace found for following nodes
mul_13: Sym(2048) = 2 * sym_size_14
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_15: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_9, 2); _unsafe_view_9 = None
view_116: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_17, [mul_13, sym_size_15]); clone_17 = mul_13 = sym_size_15 = None
mm_7: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_116, t_9); view_116 = t_9 = None
view_117: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_7, [2, sym_size_14, 768]); mm_7 = sym_size_14 = None
add_11: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_117, orig_primals_24); orig_primals_24 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_12: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_11, getitem_3); add_11 = getitem_3 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_2 = torch.ops.aten.native_layer_norm.default(add_12, [768], orig_primals_25, orig_primals_26, 1e-05); add_12 = orig_primals_25 = orig_primals_26 = None
getitem_6: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_2[0]
getitem_7: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_2[1]
getitem_8: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_2[2]; native_layer_norm_2 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_10: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_27); orig_primals_27 = None
sym_size_16: Sym(1024) = torch.ops.aten.sym_size(view_117, 1); view_117 = None
# No stacktrace found for following nodes
mul_14: Sym(2048) = 2 * sym_size_16
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_118: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_6, [mul_14, 768]); mul_14 = None
addmm_2: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_28, view_118, t_10); orig_primals_28 = view_118 = t_10 = None
view_119: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_2, [2, sym_size_16, 3072]); addmm_2 = sym_size_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_1: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_119)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_11: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_29); orig_primals_29 = None
sym_size_17: Sym(1024) = torch.ops.aten.sym_size(view_119, 1); view_119 = None
# No stacktrace found for following nodes
mul_15: Sym(2048) = 2 * sym_size_17
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_120: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_1, [mul_15, 3072]); gelu_1 = mul_15 = None
addmm_3: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_30, view_120, t_11); orig_primals_30 = view_120 = t_11 = None
view_121: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_3, [2, sym_size_17, 768]); addmm_3 = sym_size_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_13: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_121, getitem_6); getitem_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_3 = torch.ops.aten.native_layer_norm.default(add_13, [768], orig_primals_31, orig_primals_32, 1e-05); add_13 = orig_primals_31 = orig_primals_32 = None
getitem_9: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_3[0]
getitem_10: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_3[1]
getitem_11: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_3[2]; native_layer_norm_3 = None
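The addmm/gelu/addmm pair plus the residual native_layer_norm above traces the feed-forward block of a Longformer layer (modeling_longformer.py:1182-1198): dense 768 -> 3072, GELU, dense 3072 -> 768, then residual add and LayerNorm. A sketch of the same computation written as modules, with the 768/3072 widths read off the trace:

import torch
from torch import nn
import torch.nn.functional as F

hidden, ffn = 768, 3072
dense_in = nn.Linear(hidden, ffn)                    # traced as t_10 / addmm_2
dense_out = nn.Linear(ffn, hidden)                   # traced as t_11 / addmm_3
layer_norm = nn.LayerNorm(hidden, eps=1e-5)

x = torch.randn(2, 1024, hidden)                     # output of the attention block for this layer
y = layer_norm(dense_out(F.gelu(dense_in(x))) + x)   # dense -> GELU -> dense -> residual -> LayerNorm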
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_76: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_9, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_12: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_33); orig_primals_33 = None
clone_18: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_76, memory_format = torch.contiguous_format)
sym_size_18: Sym(1024) = torch.ops.aten.sym_size(view_121, 1); view_121 = None
# No stacktrace found for following nodes
mul_16: Sym(2048) = sym_size_18 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_122: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_18, [mul_16, 768]); clone_18 = mul_16 = None
mm_8: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_122, t_12); view_122 = t_12 = None
view_123: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_8, [sym_size_18, 2, 768]); mm_8 = None
add_14: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_123, orig_primals_34); view_123 = orig_primals_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_13: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_35); orig_primals_35 = None
clone_19: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_76, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_17: Sym(2048) = sym_size_18 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_124: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_19, [mul_17, 768]); clone_19 = mul_17 = None
mm_9: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_124, t_13); view_124 = t_13 = None
view_125: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_9, [sym_size_18, 2, 768]); mm_9 = None
add_15: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_125, orig_primals_36); view_125 = orig_primals_36 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_14: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_37); orig_primals_37 = None
clone_20: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_76, memory_format = torch.contiguous_format); transpose_76 = None
# No stacktrace found for following nodes
mul_18: Sym(2048) = sym_size_18 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_126: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_20, [mul_18, 768]); clone_20 = mul_18 = None
mm_10: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_126, t_14); view_126 = t_14 = None
view_127: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_10, [sym_size_18, 2, 768]); mm_10 = sym_size_18 = None
add_16: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_127, orig_primals_38); view_127 = orig_primals_38 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_2: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_14, 8.0); add_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_128: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_2, [1024, 2, 12, 64])
transpose_77: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_128, 0, 1); view_128 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_129: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_15, [1024, 2, 12, 64]); add_15 = None
transpose_78: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_129, 0, 1); view_129 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_79: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_77, 1, 2); transpose_77 = None
view_130: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_79, [24, 1024, 64]); transpose_79 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_80: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_78, 1, 2); transpose_78 = None
view_131: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_80, [24, 1024, 64]); transpose_80 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_132: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_130, [24, 2, 512, 64]); view_130 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_12: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_132, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_132 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_133: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_131, [24, 2, 512, 64]); view_131 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_13: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_133, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_133 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_30: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_12, 4); as_strided_12 = None
permute_28: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_30, [0, 1, 2, 4, 3]); unsqueeze_30 = None
unsqueeze_31: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_13, 4); as_strided_13 = None
permute_29: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_31, [0, 1, 4, 2, 3]); unsqueeze_31 = None
permute_30: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_28, [0, 1, 2, 4, 3]); permute_28 = None
view_134: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_2, [1024, 2, 12, 64]); div_2 = None
transpose_81: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_134, 0, 1); view_134 = None
transpose_82: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_81, 1, 2); transpose_81 = None
view_135: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_82, [24, 1024, 64]); transpose_82 = None
view_136: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_135, [24, 2, 512, 64]); view_135 = None
as_strided_14: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_136, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_136 = None
unsqueeze_32: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_14, 4); as_strided_14 = None
permute_31: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_32, [0, 1, 2, 4, 3]); unsqueeze_32 = None
permute_32: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_31, [0, 1, 2, 4, 3]); permute_31 = None
clone_21: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_32, memory_format = torch.contiguous_format); permute_32 = None
_unsafe_view_10: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_21, [72, 512, 64]); clone_21 = None
permute_33: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_29, [0, 1, 4, 3, 2]); permute_29 = None
clone_22: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_33, memory_format = torch.contiguous_format); permute_33 = None
_unsafe_view_11: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_22, [72, 64, 512]); clone_22 = None
bmm_4: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_10, _unsafe_view_11); _unsafe_view_10 = _unsafe_view_11 = None
view_137: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_4, [24, 3, 512, 1, 512]); bmm_4 = None
permute_34: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_137, [0, 1, 2, 4, 3]); view_137 = None
view_138: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_34, [24, 3, 512, 512]); permute_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_8: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_138, [0, 0, 0, 1], 0.0); view_138 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_139: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_8, [24, 3, 512, 513]); constant_pad_nd_8 = None
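The constant_pad_nd + view pair above (modeling_longformer.py:713-716) appends one row of zeros and reinterprets the flat buffer, so that column j of row i in the new view points at element (i, i + j) of the original matrix; this is what lets the surrounding code read the chunked scores off as diagonals. A toy demonstration on a hypothetical 5x5 matrix instead of the traced 512x512 chunks:

import torch
import torch.nn.functional as F

n = 5                                                # hypothetical chunk size (512 in the trace)
scores = torch.arange(n * n, dtype=torch.float32).view(n, n)

padded = F.pad(scores, (0, 0, 0, 1))                 # one extra row of zeros -> shape (n + 1, n)
shifted = padded.view(n, n + 1)                      # reinterpret the same buffer as n rows of n + 1

# column j of row i now holds scores[i, i + j], i.e. a fixed offset from the diagonal
for i in range(n):
    for j in range(n - i):
        assert shifted[i, j] == scores[i, i + j]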
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_4: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_139, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_355: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_139, 0, 0, 9223372036854775807)
slice_356: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_355, 1, 0, 9223372036854775807); slice_355 = None
slice_357: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_356, 2, 0, 256); slice_356 = None
slice_358: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_357, 3, 0, 257); slice_357 = None
slice_359: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
slice_360: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_359, 1, 0, -1); slice_359 = None
slice_361: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_360, 2, 0, 9223372036854775807); slice_360 = None
slice_362: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_361, 3, 256, 9223372036854775807); slice_361 = None
slice_363: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
slice_364: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_363, 1, 0, -1); slice_363 = None
slice_365: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_364, 2, 0, 9223372036854775807); slice_364 = None
slice_366: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_365, 3, 256, 9223372036854775807); slice_365 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_367: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_139, 0, 0, 9223372036854775807)
select_36: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_367, 1, -1); slice_367 = None
slice_368: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_36, 1, 256, 9223372036854775807); select_36 = None
slice_369: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_368, 2, 0, 257); slice_368 = None
slice_370: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
select_37: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_370, 1, -1); slice_370 = None
slice_371: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_37, 1, 0, 9223372036854775807); select_37 = None
slice_372: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_371, 2, 256, 9223372036854775807); slice_371 = None
slice_373: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
slice_374: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_373, 1, 0, -1)
slice_375: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_374, 2, 0, 9223372036854775807)
slice_scatter_88: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_375, slice_358, 3, 256, 9223372036854775807); slice_375 = slice_358 = None
slice_scatter_89: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_374, slice_scatter_88, 2, 0, 9223372036854775807); slice_374 = slice_scatter_88 = None
slice_scatter_90: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_373, slice_scatter_89, 1, 0, -1); slice_373 = slice_scatter_89 = None
slice_scatter_91: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_4, slice_scatter_90, 0, 0, 9223372036854775807); slice_scatter_90 = None
slice_376: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_91, 0, 0, 9223372036854775807)
select_38: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_376, 1, -1); slice_376 = None
slice_377: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_38, 1, 0, 9223372036854775807); select_38 = None
slice_378: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_377, 2, 256, 9223372036854775807); slice_377 = None
slice_379: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
select_39: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_379, 1, -1); slice_379 = None
slice_380: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_39, 1, 0, 9223372036854775807); select_39 = None
slice_381: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_380, 2, 256, 9223372036854775807); slice_380 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_382: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_139, 0, 0, 9223372036854775807)
slice_383: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_382, 1, 0, 9223372036854775807); slice_382 = None
slice_384: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_383, 2, -257, -1); slice_383 = None
slice_385: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_384, 3, 257, 9223372036854775807); slice_384 = None
slice_386: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
slice_387: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_386, 1, 1, 9223372036854775807); slice_386 = None
slice_388: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_387, 2, 0, 9223372036854775807); slice_387 = None
slice_389: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_388, 3, 0, 256); slice_388 = None
slice_390: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_91, 0, 0, 9223372036854775807)
select_40: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_390, 1, -1)
slice_391: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_40, 1, 0, 9223372036854775807)
slice_scatter_92: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_391, slice_369, 2, 256, 9223372036854775807); slice_391 = slice_369 = None
slice_scatter_93: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_40, slice_scatter_92, 1, 0, 9223372036854775807); select_40 = slice_scatter_92 = None
select_scatter_8: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_390, slice_scatter_93, 1, -1); slice_390 = slice_scatter_93 = None
slice_scatter_94: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_91, select_scatter_8, 0, 0, 9223372036854775807); slice_scatter_91 = select_scatter_8 = None
slice_392: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_94, 0, 0, 9223372036854775807)
slice_393: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_392, 1, 1, 9223372036854775807); slice_392 = None
slice_394: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_393, 2, 0, 9223372036854775807); slice_393 = None
slice_395: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_394, 3, 0, 256); slice_394 = None
slice_396: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
slice_397: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_396, 1, 1, 9223372036854775807); slice_396 = None
slice_398: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_397, 2, 0, 9223372036854775807); slice_397 = None
slice_399: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_398, 3, 0, 256); slice_398 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_400: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_139, 0, 0, 9223372036854775807); view_139 = None
select_41: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_400, 1, 0); slice_400 = None
slice_401: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_41, 1, 0, 255); select_41 = None
slice_402: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_401, 2, -255, 9223372036854775807); slice_401 = None
slice_403: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
select_42: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_403, 1, 0); slice_403 = None
slice_404: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_42, 1, 1, 256); select_42 = None
slice_405: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_404, 2, 1, 256); slice_404 = None
slice_406: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_94, 0, 0, 9223372036854775807)
slice_407: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_406, 1, 1, 9223372036854775807)
slice_408: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_407, 2, 0, 9223372036854775807)
slice_scatter_95: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_408, slice_385, 3, 0, 256); slice_408 = slice_385 = None
slice_scatter_96: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_407, slice_scatter_95, 2, 0, 9223372036854775807); slice_407 = slice_scatter_95 = None
slice_scatter_97: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_406, slice_scatter_96, 1, 1, 9223372036854775807); slice_406 = slice_scatter_96 = None
slice_scatter_98: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_94, slice_scatter_97, 0, 0, 9223372036854775807); slice_scatter_94 = slice_scatter_97 = None
slice_409: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_98, 0, 0, 9223372036854775807)
select_43: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_409, 1, 0); slice_409 = None
slice_410: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_43, 1, 1, 256); select_43 = None
slice_411: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_410, 2, 1, 256); slice_410 = None
slice_412: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_4, 0, 0, 9223372036854775807)
select_44: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_412, 1, 0); slice_412 = None
slice_413: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_44, 1, 1, 256); select_44 = None
slice_414: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_413, 2, 1, 256); slice_413 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_140: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513])
transpose_83: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_140, 2, 1); view_140 = None
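# Note (editor): the new_empty / slice_scatter chain above is the functionalized form of the four in-place
# slice assignments quoted from modeling_longformer.py:839-861. A minimal standalone sketch of what those
# assignments do, with shapes taken from the [24, 4, 256, 513] tensors traced here; names such as `chunked`,
# `diag`, and `w` are placeholders, not identifiers from the trace:
#
#     import torch
#
#     w = 256                                                      # window_overlap
#     chunked = torch.randn(24, 3, 2 * w, 2 * w + 1)               # diagonal_chunked_attention_scores
#     diag = chunked.new_empty(24, 4, w, 2 * w + 1)                # diagonal_attention_scores (uninitialized, like new_empty_4)
#
#     diag[:, :-1, :, w:] = chunked[:, :, :w, : w + 1]             # line 845
#     diag[:, -1, :, w:] = chunked[:, -1, w:, : w + 1]             # line 848
#     diag[:, 1:, :, :w] = chunked[:, :, -(w + 1) : -1, w + 1 :]   # line 852
#     diag[:, 0, 1:w, 1:w] = chunked[:, 0, : w - 1, 1 - w :]       # line 856
#
#     diag = diag.view(2, 12, 4 * w, 2 * w + 1).transpose(2, 1)    # line 861: [2, 1024, 12, 513]
#
# Entries not written by the four assignments stay uninitialized, just as with new_empty_4 above.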
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_415: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_98, 0, 0, 9223372036854775807)
select_45: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_415, 1, 0)
slice_416: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_45, 1, 1, 256)
slice_scatter_99: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_416, slice_402, 2, 1, 256); slice_416 = slice_402 = None
slice_scatter_100: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_45, slice_scatter_99, 1, 1, 256); select_45 = slice_scatter_99 = None
select_scatter_9: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_415, slice_scatter_100, 1, 0); slice_415 = slice_scatter_100 = None
slice_scatter_101: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_98, select_scatter_9, 0, 0, 9223372036854775807); slice_scatter_98 = select_scatter_9 = None
view_141: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_101, [2, 12, 1024, 513])
transpose_84: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_141, 2, 1); view_141 = None
new_ones_6: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_84, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_4: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_6); new_ones_6 = None
flip_8: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_4, [0]); tril_4 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_33: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_8, 0); flip_8 = None
slice_417: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_33, 1, 0, 9223372036854775807); unsqueeze_33 = None
unsqueeze_34: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_417, 2); slice_417 = None
slice_418: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_34, 3, 0, 9223372036854775807); unsqueeze_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_9: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_418, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_419: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_84, 0, 0, 9223372036854775807)
slice_420: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_419, 1, 0, 256); slice_419 = None
slice_421: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_420, 2, 0, 9223372036854775807); slice_420 = None
slice_422: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_421, 3, 0, 257); slice_421 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_8: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_418, [2, 256, 12, 257]); slice_418 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_8: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_8, 1); expand_8 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_142: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_101, [2, 12, 1024, 513])
transpose_85: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_142, 2, 1); view_142 = None
slice_423: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_85, 0, 0, 9223372036854775807); transpose_85 = None
slice_424: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_423, 1, 0, 256); slice_423 = None
slice_425: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_424, 2, 0, 9223372036854775807); slice_424 = None
slice_426: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_425, 3, 0, 257); slice_425 = None
masked_fill_12: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_426, eq_8, -inf); slice_426 = eq_8 = None
view_143: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513])
transpose_86: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_143, 2, 1); view_143 = None
slice_427: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_86, 0, 0, 9223372036854775807); transpose_86 = None
slice_428: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_427, 1, 0, 256); slice_427 = None
slice_429: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_428, 2, 0, 9223372036854775807); slice_428 = None
slice_430: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_429, 3, 0, 257); slice_429 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_144: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513])
transpose_87: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_144, 2, 1); view_144 = None
slice_431: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_84, 0, 0, 9223372036854775807); transpose_84 = None
slice_432: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_431, 1, -256, 9223372036854775807); slice_431 = None
slice_433: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_432, 2, 0, 9223372036854775807); slice_432 = None
slice_434: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_433, 3, -257, 9223372036854775807); slice_433 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_9: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_9, [2, 256, 12, 257]); flip_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_9: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_9, 1); expand_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_145: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_101, [2, 12, 1024, 513]); slice_scatter_101 = None
transpose_88: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_145, 2, 1); view_145 = None
slice_435: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_88, 0, 0, 9223372036854775807)
slice_436: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_435, 1, 0, 256)
slice_437: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_436, 2, 0, 9223372036854775807)
slice_scatter_102: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_437, masked_fill_12, 3, 0, 257); slice_437 = masked_fill_12 = None
slice_scatter_103: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_436, slice_scatter_102, 2, 0, 9223372036854775807); slice_436 = slice_scatter_102 = None
slice_scatter_104: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_435, slice_scatter_103, 1, 0, 256); slice_435 = slice_scatter_103 = None
slice_scatter_105: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_88, slice_scatter_104, 0, 0, 9223372036854775807); transpose_88 = slice_scatter_104 = None
transpose_89: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_105, 2, 1); slice_scatter_105 = None
view_146: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_89, [24, 4, 256, 513]); transpose_89 = None
view_147: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_146, [2, 12, 1024, 513])
transpose_90: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_147, 2, 1); view_147 = None
slice_438: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_90, 0, 0, 9223372036854775807); transpose_90 = None
slice_439: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_438, 1, -256, 9223372036854775807); slice_438 = None
slice_440: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_439, 2, 0, 9223372036854775807); slice_439 = None
slice_441: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_440, 3, -257, 9223372036854775807); slice_440 = None
masked_fill_13: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_441, eq_9, -inf); slice_441 = eq_9 = None
view_148: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513])
transpose_91: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_148, 2, 1); view_148 = None
slice_442: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_91, 0, 0, 9223372036854775807); transpose_91 = None
slice_443: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_442, 1, -256, 9223372036854775807); slice_442 = None
slice_444: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_443, 2, 0, 9223372036854775807); slice_443 = None
slice_445: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_444, 3, -257, 9223372036854775807); slice_444 = None
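# Note (editor): the tril / flip / expand / masked_fill ops above correspond to the _mask_invalid_locations
# logic at modeling_longformer.py:792-800 (quoted in the comments). A rough standalone equivalent with
# w = 256 as in the traced shapes; `scores` and the other names are illustrative placeholders:
#
#     import torch
#
#     w = 256
#     scores = torch.randn(2, 1024, 12, 2 * w + 1)                              # [batch, seq, heads, 2w+1]
#
#     beginning_mask = scores.new_ones(w, w + 1).tril().flip(dims=[0])[None, :, None, :]
#     ending_mask = beginning_mask.flip(dims=(1, 3))
#
#     beginning_input = scores[:, :w, :, : w + 1]
#     beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, float("-inf"))
#     ending_input = scores[:, -w:, :, -(w + 1):]
#     ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, float("-inf"))
#
# Because beginning_input / ending_input are views, the -inf writes land directly in `scores`, which is why
# the functionalized trace rewrites them as the slice_scatter chains seen above.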
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_2: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_446: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_2, 0, 0, 9223372036854775807); ne_2 = None
slice_447: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_446, 1, 0, 9223372036854775807); slice_446 = None
unsqueeze_35: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_447, 2); slice_447 = None
unsqueeze_36: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_35, 3); unsqueeze_35 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_2: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_36, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_14: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_2, unsqueeze_36, -10000.0); _to_copy_2 = unsqueeze_36 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_7: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_14, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
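# Note (editor): the ne / _to_copy / masked_fill / new_ones sequence above is the attention-mask
# preprocessing from modeling_longformer.py:585-593 (quoted). A small sketch of the same transformation;
# the example `attention_mask` (0 = local attention, nonzero = masked) and `query_vectors` are hypothetical:
#
#     import torch
#
#     attention_mask = torch.zeros(2, 1024)                         # hypothetical input mask
#     query_vectors = torch.randn(1024, 2, 12, 64)                  # only used here to pick the dtype
#
#     remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
#     float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
#         remove_from_windowed_attention_mask, -10000.0
#     )                                                             # [2, 1024, 1, 1], -10000 at masked positions
#     ones = float_mask.new_ones(size=float_mask.size())            # all-ones tensor fed into the mask matmul below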
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_92: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_7, 1, 2); new_ones_7 = None
view_149: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_92, [2, 1024, 1]); transpose_92 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_93: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_14, 1, 2); masked_fill_14 = None
view_150: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_93, [2, 1024, 1]); transpose_93 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_151: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_149, [2, 2, 512, 1]); view_149 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_15: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_151, [2, 3, 512, 1], [1024, 256, 1, 1]); view_151 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_152: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_150, [2, 2, 512, 1]); view_150 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_16: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_152, [2, 3, 512, 1], [1024, 256, 1, 1]); view_152 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_37: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_15, 4); as_strided_15 = None
permute_35: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_37, [0, 1, 2, 4, 3]); unsqueeze_37 = None
unsqueeze_38: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_16, 4); as_strided_16 = None
permute_36: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_38, [0, 1, 4, 2, 3]); unsqueeze_38 = None
mul_19: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_35, permute_36); permute_35 = permute_36 = None
view_153: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_19, [2, 3, 512, 512]); mul_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_9: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_153, [0, 0, 0, 1], 0.0); view_153 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_154: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_9, [2, 3, 512, 513]); constant_pad_nd_9 = None
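# Note (editor): the view / as_strided / einsum / pad / view sequence above is the overlapping-chunk trick of
# _sliding_chunks_query_key_matmul (modeling_longformer.py:775-839, quoted), applied here to the all-ones
# tensor and the float mask (head_dim 1). A standalone sketch under those shapes; all names are illustrative:
#
#     import torch
#
#     bnh, seq_len, head_dim, w = 2, 1024, 1, 256                   # batch*heads, seq, head_dim, window_overlap
#
#     def chunk(x, w):
#         # non-overlapping chunks of 2w, then re-stride with a hop of w to get 50% overlap (line 788)
#         x = x.view(x.size(0), x.size(1) // (2 * w), 2 * w, x.size(2))
#         size = list(x.size())
#         size[1] = size[1] * 2 - 1
#         stride = list(x.stride())
#         stride[1] = stride[1] // 2
#         return x.as_strided(size=size, stride=stride)
#
#     query = torch.randn(bnh, seq_len, head_dim)
#     key = torch.randn(bnh, seq_len, head_dim)
#     scores = torch.einsum("bcxd,bcyd->bcxy", chunk(query, w), chunk(key, w))              # [2, 3, 512, 512]
#     # pad one row and reinterpret the last two dims (lines 713-716): each successive row is
#     # shifted by one extra position, so the diagonals of each chunk line up along columns
#     padded = torch.nn.functional.pad(scores, (0, 0, 0, 1))                                # [2, 3, 513, 512]
#     scores = padded.view(padded.size(0), padded.size(1), padded.size(3), padded.size(2))  # [2, 3, 512, 513]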
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_5: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_154, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_448: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_154, 0, 0, 9223372036854775807)
slice_449: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_448, 1, 0, 9223372036854775807); slice_448 = None
slice_450: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_449, 2, 0, 256); slice_449 = None
slice_451: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_450, 3, 0, 257); slice_450 = None
slice_452: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_5, 0, 0, 9223372036854775807)
slice_453: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_452, 1, 0, -1); slice_452 = None
slice_454: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_453, 2, 0, 9223372036854775807); slice_453 = None
slice_455: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_454, 3, 256, 9223372036854775807); slice_454 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_456: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_154, 0, 0, 9223372036854775807)
select_46: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_456, 1, -1); slice_456 = None
slice_457: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_46, 1, 256, 9223372036854775807); select_46 = None
slice_458: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_457, 2, 0, 257); slice_457 = None
slice_459: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_5, 0, 0, 9223372036854775807)
select_47: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_459, 1, -1); slice_459 = None
slice_460: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_47, 1, 0, 9223372036854775807); select_47 = None
slice_461: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_460, 2, 256, 9223372036854775807); slice_460 = None
slice_462: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_5, 0, 0, 9223372036854775807)
slice_463: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_462, 1, 0, -1)
slice_464: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_463, 2, 0, 9223372036854775807)
slice_scatter_106: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_464, slice_451, 3, 256, 9223372036854775807); slice_464 = slice_451 = None
slice_scatter_107: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_463, slice_scatter_106, 2, 0, 9223372036854775807); slice_463 = slice_scatter_106 = None
slice_scatter_108: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_462, slice_scatter_107, 1, 0, -1); slice_462 = slice_scatter_107 = None
slice_scatter_109: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_5, slice_scatter_108, 0, 0, 9223372036854775807); slice_scatter_108 = None
slice_465: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_109, 0, 0, 9223372036854775807)
select_48: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_465, 1, -1); slice_465 = None
slice_466: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_48, 1, 0, 9223372036854775807); select_48 = None
slice_467: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_466, 2, 256, 9223372036854775807); slice_466 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_468: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_154, 0, 0, 9223372036854775807)
slice_469: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_468, 1, 0, 9223372036854775807); slice_468 = None
slice_470: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_469, 2, -257, -1); slice_469 = None
slice_471: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_470, 3, 257, 9223372036854775807); slice_470 = None
slice_472: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_5, 0, 0, 9223372036854775807)
slice_473: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_472, 1, 1, 9223372036854775807); slice_472 = None
slice_474: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_473, 2, 0, 9223372036854775807); slice_473 = None
slice_475: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_474, 3, 0, 256); slice_474 = None
slice_476: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_109, 0, 0, 9223372036854775807)
select_49: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_476, 1, -1)
slice_477: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_49, 1, 0, 9223372036854775807)
slice_scatter_110: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_477, slice_458, 2, 256, 9223372036854775807); slice_477 = slice_458 = None
slice_scatter_111: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_49, slice_scatter_110, 1, 0, 9223372036854775807); select_49 = slice_scatter_110 = None
select_scatter_10: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_476, slice_scatter_111, 1, -1); slice_476 = slice_scatter_111 = None
slice_scatter_112: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_109, select_scatter_10, 0, 0, 9223372036854775807); slice_scatter_109 = select_scatter_10 = None
slice_478: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_112, 0, 0, 9223372036854775807)
slice_479: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_478, 1, 1, 9223372036854775807); slice_478 = None
slice_480: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_479, 2, 0, 9223372036854775807); slice_479 = None
slice_481: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_480, 3, 0, 256); slice_480 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_482: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_154, 0, 0, 9223372036854775807); view_154 = None
select_50: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_482, 1, 0); slice_482 = None
slice_483: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_50, 1, 0, 255); select_50 = None
slice_484: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_483, 2, -255, 9223372036854775807); slice_483 = None
slice_485: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_5, 0, 0, 9223372036854775807)
select_51: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_485, 1, 0); slice_485 = None
slice_486: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_51, 1, 1, 256); select_51 = None
slice_487: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_486, 2, 1, 256); slice_486 = None
slice_488: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_112, 0, 0, 9223372036854775807)
slice_489: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_488, 1, 1, 9223372036854775807)
slice_490: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_489, 2, 0, 9223372036854775807)
slice_scatter_113: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_490, slice_471, 3, 0, 256); slice_490 = slice_471 = None
slice_scatter_114: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_489, slice_scatter_113, 2, 0, 9223372036854775807); slice_489 = slice_scatter_113 = None
slice_scatter_115: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_488, slice_scatter_114, 1, 1, 9223372036854775807); slice_488 = slice_scatter_114 = None
slice_scatter_116: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_112, slice_scatter_115, 0, 0, 9223372036854775807); slice_scatter_112 = slice_scatter_115 = None
slice_491: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_116, 0, 0, 9223372036854775807)
select_52: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_491, 1, 0); slice_491 = None
slice_492: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_52, 1, 1, 256); select_52 = None
slice_493: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_492, 2, 1, 256); slice_492 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_155: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_5, [2, 1, 1024, 513]); new_empty_5 = None
transpose_94: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_155, 2, 1); view_155 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_494: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_116, 0, 0, 9223372036854775807)
select_53: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_494, 1, 0)
slice_495: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_53, 1, 1, 256)
slice_scatter_117: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_495, slice_484, 2, 1, 256); slice_495 = slice_484 = None
slice_scatter_118: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_53, slice_scatter_117, 1, 1, 256); select_53 = slice_scatter_117 = None
select_scatter_11: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_494, slice_scatter_118, 1, 0); slice_494 = slice_scatter_118 = None
slice_scatter_119: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_116, select_scatter_11, 0, 0, 9223372036854775807); slice_scatter_116 = select_scatter_11 = None
view_156: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_119, [2, 1, 1024, 513])
transpose_95: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_156, 2, 1); view_156 = None
new_ones_8: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_95, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_5: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_8); new_ones_8 = None
flip_10: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_5, [0]); tril_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_39: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_10, 0); flip_10 = None
slice_496: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_39, 1, 0, 9223372036854775807); unsqueeze_39 = None
unsqueeze_40: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_496, 2); slice_496 = None
slice_497: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_40, 3, 0, 9223372036854775807); unsqueeze_40 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_11: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_497, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_498: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_95, 0, 0, 9223372036854775807)
slice_499: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_498, 1, 0, 256); slice_498 = None
slice_500: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_499, 2, 0, 9223372036854775807); slice_499 = None
slice_501: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_500, 3, 0, 257); slice_500 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_10: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_497, [2, 256, 1, 257]); slice_497 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_10: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_10, 1); expand_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_157: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_119, [2, 1, 1024, 513])
transpose_96: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_157, 2, 1); view_157 = None
slice_502: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_96, 0, 0, 9223372036854775807); transpose_96 = None
slice_503: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_502, 1, 0, 256); slice_502 = None
slice_504: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_503, 2, 0, 9223372036854775807); slice_503 = None
slice_505: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_504, 3, 0, 257); slice_504 = None
masked_fill_15: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_505, eq_10, -inf); slice_505 = eq_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_506: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_95, 0, 0, 9223372036854775807); transpose_95 = None
slice_507: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_506, 1, -256, 9223372036854775807); slice_506 = None
slice_508: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_507, 2, 0, 9223372036854775807); slice_507 = None
slice_509: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_508, 3, -257, 9223372036854775807); slice_508 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_11: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_11, [2, 256, 1, 257]); flip_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_11: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_11, 1); expand_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_158: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_119, [2, 1, 1024, 513]); slice_scatter_119 = None
transpose_97: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_158, 2, 1); view_158 = None
slice_510: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_97, 0, 0, 9223372036854775807)
slice_511: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_510, 1, 0, 256)
slice_512: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_511, 2, 0, 9223372036854775807)
slice_scatter_120: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_512, masked_fill_15, 3, 0, 257); slice_512 = masked_fill_15 = None
slice_scatter_121: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_511, slice_scatter_120, 2, 0, 9223372036854775807); slice_511 = slice_scatter_120 = None
slice_scatter_122: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_510, slice_scatter_121, 1, 0, 256); slice_510 = slice_scatter_121 = None
slice_scatter_123: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_97, slice_scatter_122, 0, 0, 9223372036854775807); transpose_97 = slice_scatter_122 = None
transpose_98: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_123, 2, 1); slice_scatter_123 = None
view_159: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_98, [2, 4, 256, 513]); transpose_98 = None
view_160: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_159, [2, 1, 1024, 513])
transpose_99: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_160, 2, 1); view_160 = None
slice_513: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_99, 0, 0, 9223372036854775807); transpose_99 = None
slice_514: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_513, 1, -256, 9223372036854775807); slice_513 = None
slice_515: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_514, 2, 0, 9223372036854775807); slice_514 = None
slice_516: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_515, 3, -257, 9223372036854775807); slice_515 = None
masked_fill_16: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_516, eq_11, -inf); slice_516 = eq_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_161: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513])
transpose_100: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_161, 2, 1); view_161 = None
view_162: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_146, [2, 12, 1024, 513]); view_146 = None
transpose_101: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_162, 2, 1); view_162 = None
slice_517: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_101, 0, 0, 9223372036854775807)
slice_518: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_517, 1, -256, 9223372036854775807)
slice_519: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_518, 2, 0, 9223372036854775807)
slice_scatter_124: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_519, masked_fill_13, 3, -257, 9223372036854775807); slice_519 = masked_fill_13 = None
slice_scatter_125: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_518, slice_scatter_124, 2, 0, 9223372036854775807); slice_518 = slice_scatter_124 = None
slice_scatter_126: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_517, slice_scatter_125, 1, -256, 9223372036854775807); slice_517 = slice_scatter_125 = None
slice_scatter_127: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_101, slice_scatter_126, 0, 0, 9223372036854775807); transpose_101 = slice_scatter_126 = None
transpose_102: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_127, 2, 1); slice_scatter_127 = None
view_163: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_102, [24, 4, 256, 513]); transpose_102 = None
view_164: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_163, [2, 12, 1024, 513]); view_163 = None
transpose_103: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_164, 2, 1); view_164 = None
view_165: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_159, [2, 1, 1024, 513]); view_159 = None
transpose_104: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_165, 2, 1); view_165 = None
slice_520: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_104, 0, 0, 9223372036854775807)
slice_521: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_520, 1, -256, 9223372036854775807)
slice_522: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_521, 2, 0, 9223372036854775807)
slice_scatter_128: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_522, masked_fill_16, 3, -257, 9223372036854775807); slice_522 = masked_fill_16 = None
slice_scatter_129: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_521, slice_scatter_128, 2, 0, 9223372036854775807); slice_521 = slice_scatter_128 = None
slice_scatter_130: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_520, slice_scatter_129, 1, -256, 9223372036854775807); slice_520 = slice_scatter_129 = None
slice_scatter_131: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_104, slice_scatter_130, 0, 0, 9223372036854775807); transpose_104 = slice_scatter_130 = None
transpose_105: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_131, 2, 1); slice_scatter_131 = None
view_166: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_105, [2, 4, 256, 513]); transpose_105 = None
view_167: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_166, [2, 1, 1024, 513]); view_166 = None
transpose_106: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_167, 2, 1); view_167 = None
add_17: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_103, transpose_106); transpose_103 = transpose_106 = None
view_168: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_4, [2, 12, 1024, 513]); new_empty_4 = None
transpose_107: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_168, 2, 1); view_168 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_2: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_17, -1, False); add_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_523: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_524: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_523, 1, 0, 9223372036854775807); slice_523 = None
unsqueeze_41: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_524, 2); slice_524 = None
unsqueeze_42: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_41, 3); unsqueeze_41 = None
masked_fill_17: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_2, unsqueeze_42, 0.0); _softmax_2 = unsqueeze_42 = None
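# Note (editor): _softmax_2 / masked_fill_17 above are modeling_longformer.py:635-646 (quoted): softmax over
# the 2w+1 window dimension, then zeroing rows that belong to padding tokens. A short sketch with
# hypothetical inputs; `attn_scores` and `is_index_masked` here are placeholders:
#
#     import torch
#
#     attn_scores = torch.randn(2, 1024, 12, 513)                   # [batch, seq, heads, 2w+1]
#     is_index_masked = torch.zeros(2, 1024, dtype=torch.bool)      # hypothetical padding indicator
#
#     attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
#     # rows for padded tokens can be all -inf before the softmax; overwrite them with exact zeros
#     attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)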
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_169: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_16, [1024, 2, 12, 64]); add_16 = None
transpose_108: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_169, 0, 1); view_169 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_109: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_17, 1, 2); masked_fill_17 = None
clone_23: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_109, memory_format = torch.contiguous_format); transpose_109 = None
_unsafe_view_12: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_23, [24, 4, 256, 513]); clone_23 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_110: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_108, 1, 2); transpose_108 = None
view_170: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_110, [24, 1024, 64]); transpose_110 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_10: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_170, [0, 0, 256, 256], -1.0); view_170 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_17: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_10, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_11: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_12, [0, 257], 0.0); _unsafe_view_12 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_171: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_11, [24, 4, -1]); constant_pad_nd_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_525: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_171, 0, 0, 9223372036854775807); view_171 = None
slice_526: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_525, 1, 0, 9223372036854775807); slice_525 = None
slice_527: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_526, 2, 0, -256); slice_526 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_172: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_527, [24, 4, 256, 769]); slice_527 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_528: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_172, 0, 0, 9223372036854775807); view_172 = None
slice_529: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_528, 1, 0, 9223372036854775807)
slice_530: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_529, 2, 0, 9223372036854775807); slice_529 = None
slice_531: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_530, 3, 0, -1); slice_530 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_43: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_531, 4)
permute_37: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_43, [0, 1, 2, 4, 3]); unsqueeze_43 = None
unsqueeze_44: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_17, 4); as_strided_17 = None
permute_38: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_44, [0, 1, 4, 3, 2]); unsqueeze_44 = None
permute_39: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_37, [0, 1, 2, 4, 3]); permute_37 = None
sym_size_19: Sym(24) = torch.ops.aten.sym_size(slice_528, 0); slice_528 = None
# No stacktrace found for following nodes
mul_20: Sym(96) = sym_size_19 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_20: Sym(768) = torch.ops.aten.sym_size(slice_531, 3); slice_531 = None
view_173: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_39, [mul_20, 256, sym_size_20]); permute_39 = None
permute_40: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_38, [0, 1, 4, 3, 2]); permute_38 = None
clone_24: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_40, memory_format = torch.contiguous_format); permute_40 = None
_unsafe_view_13: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_24, [mul_20, sym_size_20, 64]); clone_24 = mul_20 = sym_size_20 = None
bmm_5: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_173, _unsafe_view_13); view_173 = _unsafe_view_13 = None
view_174: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_5, [sym_size_19, 4, 256, 1, 64]); bmm_5 = None
permute_41: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_174, [0, 1, 2, 4, 3])
sym_size_21: Sym(4) = torch.ops.aten.sym_size(view_174, 1); view_174 = None
view_175: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_41, [sym_size_19, sym_size_21, 256, 64]); permute_41 = sym_size_19 = sym_size_21 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_176: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_175, [2, 12, 1024, 64]); view_175 = None
transpose_111: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_176, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_112: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_111, 0, 1); transpose_111 = None
clone_25: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_112, memory_format = torch.contiguous_format); transpose_112 = None
_unsafe_view_14: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_25, [1024, 2, 768]); clone_25 = None
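# Note (editor): the pad / as_strided / pad-and-diagonalize / einsum block above is
# _sliding_chunks_matmul_attn_probs_value (modeling_longformer.py:755-907, quoted), which multiplies the
# windowed attention probabilities against overlapping chunks of the value tensor. A sketch under the traced
# shapes; names and the use of already-chunked `attn_probs` are assumptions for illustration:
#
#     import torch
#
#     bnh, seq_len, head_dim, w = 24, 1024, 64, 256                 # batch*heads, seq, head_dim, window_overlap
#     attn_probs = torch.rand(bnh, seq_len // w, w, 2 * w + 1)      # already chunked: [24, 4, 256, 513]
#     value = torch.randn(bnh, seq_len, head_dim)
#
#     # pad the value sequence by w on both sides, then build overlapping chunks of 3w (lines 891-902)
#     padded_value = torch.nn.functional.pad(value, (0, 0, w, w), value=-1)
#     chunked_value = padded_value.as_strided(
#         size=(bnh, seq_len // w, 3 * w, head_dim),
#         stride=(padded_value.stride(0), w * head_dim, head_dim, 1),
#     )
#
#     # _pad_and_diagonalize (lines 755-767): pad each row by w+1, flatten, drop the tail, and
#     # re-view so that row i of every chunk is shifted by i columns
#     probs = torch.nn.functional.pad(attn_probs, (0, w + 1))                   # [24, 4, 256, 770]
#     probs = probs.view(bnh, seq_len // w, -1)[:, :, :-w]                      # drop the last w entries
#     probs = probs.view(bnh, seq_len // w, w, 3 * w + 1)[:, :, :, :-1]         # [24, 4, 256, 768]
#
#     context = torch.einsum("bcwd,bcdh->bcwh", probs, chunked_value)           # [24, 4, 256, 64]
#     context = context.view(2, 12, seq_len, head_dim).transpose(1, 2)          # [2, 1024, 12, 64]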
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_113: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_14, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_15: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_39); orig_primals_39 = None
clone_26: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_113, memory_format = torch.contiguous_format); transpose_113 = None
sym_size_22: Sym(1024) = torch.ops.aten.sym_size(view_176, 2); view_176 = None
# No stacktrace found for following nodes
mul_21: Sym(2048) = 2 * sym_size_22
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_23: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_14, 2); _unsafe_view_14 = None
view_177: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_26, [mul_21, sym_size_23]); clone_26 = mul_21 = sym_size_23 = None
mm_11: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_177, t_15); view_177 = t_15 = None
view_178: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_11, [2, sym_size_22, 768]); mm_11 = sym_size_22 = None
add_18: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_178, orig_primals_40); orig_primals_40 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_19: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_18, getitem_9); add_18 = getitem_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_4 = torch.ops.aten.native_layer_norm.default(add_19, [768], orig_primals_41, orig_primals_42, 1e-05); add_19 = orig_primals_41 = orig_primals_42 = None
getitem_12: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_4[0]
getitem_13: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_4[1]
getitem_14: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_4[2]; native_layer_norm_4 = None
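
native_layer_norm returns three tensors, which is why the trace keeps getitem_12/13/14: the normalized output plus the per-row mean and rstd that the backward pass reuses. A small sketch of the same call on hypothetical standalone tensors:

import torch

x = torch.randn(2, 1024, 768)
weight = torch.ones(768)
bias = torch.zeros(768)
out, mean, rstd = torch.ops.aten.native_layer_norm.default(x, [768], weight, bias, 1e-05)
# shapes: (2, 1024, 768), (2, 1024, 1), (2, 1024, 1), matching the getitem_* results above
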
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_16: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_43); orig_primals_43 = None
sym_size_24: Sym(1024) = torch.ops.aten.sym_size(view_178, 1); view_178 = None
# No stacktrace found for following nodes
mul_22: Sym(2048) = 2 * sym_size_24
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_179: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_12, [mul_22, 768]); mul_22 = None
addmm_4: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_44, view_179, t_16); orig_primals_44 = view_179 = t_16 = None
view_180: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_4, [2, sym_size_24, 3072]); addmm_4 = sym_size_24 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_2: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_180)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_17: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_45); orig_primals_45 = None
sym_size_25: Sym(1024) = torch.ops.aten.sym_size(view_180, 1); view_180 = None
# No stacktrace found for following nodes
mul_23: Sym(2048) = 2 * sym_size_25
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_181: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_2, [mul_23, 3072]); gelu_2 = mul_23 = None
addmm_5: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_46, view_181, t_17); orig_primals_46 = view_181 = t_17 = None
view_182: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_5, [2, sym_size_25, 768]); addmm_5 = sym_size_25 = None
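
Each nn.Linear in the intermediate/output blocks lowers to a flatten over the batch and (symbolic) sequence dims, an addmm against the transposed weight, and a view back, with mul_22 / mul_23 carrying the 2 * seq_len product. A sketch of that lowering next to the eager call, on illustrative tensors with the same shapes as the second dense above:

import torch

x = torch.randn(2, 1024, 3072)
linear = torch.nn.Linear(3072, 768)

y_eager = linear(x)
y_lowered = torch.addmm(linear.bias, x.view(2 * 1024, 3072), linear.weight.t()).view(2, 1024, 768)
torch.testing.assert_close(y_eager, y_lowered)
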
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_20: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_182, getitem_12); getitem_12 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_5 = torch.ops.aten.native_layer_norm.default(add_20, [768], orig_primals_47, orig_primals_48, 1e-05); add_20 = orig_primals_47 = orig_primals_48 = None
getitem_15: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_5[0]
getitem_16: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_5[1]
getitem_17: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_5[2]; native_layer_norm_5 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_114: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_15, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_18: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_49); orig_primals_49 = None
clone_27: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_114, memory_format = torch.contiguous_format)
sym_size_26: Sym(1024) = torch.ops.aten.sym_size(view_182, 1); view_182 = None
# No stacktrace found for following nodes
mul_24: Sym(2048) = sym_size_26 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_183: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_27, [mul_24, 768]); clone_27 = mul_24 = None
mm_12: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_183, t_18); view_183 = t_18 = None
view_184: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_12, [sym_size_26, 2, 768]); mm_12 = None
add_21: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_184, orig_primals_50); view_184 = orig_primals_50 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_19: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_51); orig_primals_51 = None
clone_28: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_114, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_25: Sym(2048) = sym_size_26 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_185: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_28, [mul_25, 768]); clone_28 = mul_25 = None
mm_13: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_185, t_19); view_185 = t_19 = None
view_186: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_13, [sym_size_26, 2, 768]); mm_13 = None
add_22: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_186, orig_primals_52); view_186 = orig_primals_52 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_20: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_53); orig_primals_53 = None
clone_29: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_114, memory_format = torch.contiguous_format); transpose_114 = None
# No stacktrace found for following nodes
mul_26: Sym(2048) = sym_size_26 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_187: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_29, [mul_26, 768]); clone_29 = mul_26 = None
mm_14: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_187, t_20); view_187 = t_20 = None
view_188: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_14, [sym_size_26, 2, 768]); mm_14 = sym_size_26 = None
add_23: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_188, orig_primals_54); view_188 = orig_primals_54 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_3: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_21, 8.0); add_21 = None
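
The literal 8.0 in div_3 is math.sqrt(self.head_dim) folded into a constant: head_dim = hidden_size / num_heads = 768 / 12 = 64, and sqrt(64) = 8.0.
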
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_189: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_3, [1024, 2, 12, 64])
transpose_115: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_189, 0, 1); view_189 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_190: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_22, [1024, 2, 12, 64]); add_22 = None
transpose_116: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_190, 0, 1); view_190 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_117: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_115, 1, 2); transpose_115 = None
view_191: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_117, [24, 1024, 64]); transpose_117 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_118: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_116, 1, 2); transpose_116 = None
view_192: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_118, [24, 1024, 64]); transpose_118 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_193: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_191, [24, 2, 512, 64]); view_191 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_18: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_193, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_193 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_194: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_192, [24, 2, 512, 64]); view_192 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_19: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_194, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_194 = None
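
The view to [24, 2, 512, 64] followed by as_strided to [24, 3, 512, 64] with a half-sized chunk stride is the overlapping-chunk trick: three 512-token chunks that each share 256 tokens with their neighbour, built without copying. A standalone sketch of the same pattern on a contiguous tensor (the traced strides differ because the input there is a transposed view); sizes chosen to match the 1024-token, 256-overlap case above:

import torch

B, S, D, w = 24, 1024, 64, 256
x = torch.randn(B, S, D)                   # contiguous
n_chunks = S // w - 1                      # 3 chunks of length 2*w
chunks = x.as_strided(
    size=(B, n_chunks, 2 * w, D),
    stride=(x.stride(0), w * x.stride(1), x.stride(1), x.stride(2)),
)
# consecutive chunks overlap by w positions and share storage with x
assert torch.equal(chunks[:, 1, :w], chunks[:, 0, w:])
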
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_45: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_18, 4); as_strided_18 = None
permute_42: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_45, [0, 1, 2, 4, 3]); unsqueeze_45 = None
unsqueeze_46: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_19, 4); as_strided_19 = None
permute_43: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_46, [0, 1, 4, 2, 3]); unsqueeze_46 = None
permute_44: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_42, [0, 1, 2, 4, 3]); permute_42 = None
view_195: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_3, [1024, 2, 12, 64]); div_3 = None
transpose_119: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_195, 0, 1); view_195 = None
transpose_120: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_119, 1, 2); transpose_119 = None
view_196: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_120, [24, 1024, 64]); transpose_120 = None
view_197: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_196, [24, 2, 512, 64]); view_196 = None
as_strided_20: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_197, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_197 = None
unsqueeze_47: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_20, 4); as_strided_20 = None
permute_45: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_47, [0, 1, 2, 4, 3]); unsqueeze_47 = None
permute_46: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_45, [0, 1, 2, 4, 3]); permute_45 = None
clone_30: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_46, memory_format = torch.contiguous_format); permute_46 = None
_unsafe_view_15: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_30, [72, 512, 64]); clone_30 = None
permute_47: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_43, [0, 1, 4, 3, 2]); permute_43 = None
clone_31: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_47, memory_format = torch.contiguous_format); permute_47 = None
_unsafe_view_16: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_31, [72, 64, 512]); clone_31 = None
bmm_6: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_15, _unsafe_view_16); _unsafe_view_15 = _unsafe_view_16 = None
view_198: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_6, [24, 3, 512, 1, 512]); bmm_6 = None
permute_48: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_198, [0, 1, 2, 4, 3]); view_198 = None
view_199: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_48, [24, 3, 512, 512]); permute_48 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_12: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_199, [0, 0, 0, 1], 0.0); view_199 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_200: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_12, [24, 3, 512, 513]); constant_pad_nd_12 = None
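
constant_pad_nd with one extra row followed by a view that trades that row for an extra column is the skewing trick used at modeling_longformer.py:713-716 above: after the reshape, row i of each chunk starts at element (i, i) of the original, so the diagonals of the chunked score matrix line up as columns. A tiny numeric sketch of the same reinterpretation:

import torch
import torch.nn.functional as F

x = torch.arange(16.).view(4, 4)
padded = F.pad(x, [0, 0, 0, 1])    # (5, 4): one zero row appended at the bottom
skewed = padded.view(4, 5)         # same 20 values read back with one extra column per row
# row i of skewed begins at x[i, i], so the main diagonal of x becomes column 0
assert torch.equal(skewed[:, 0], x.diagonal())
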
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_6: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_200, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_532: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_200, 0, 0, 9223372036854775807)
slice_533: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_532, 1, 0, 9223372036854775807); slice_532 = None
slice_534: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_533, 2, 0, 256); slice_533 = None
slice_535: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_534, 3, 0, 257); slice_534 = None
slice_536: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
slice_537: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_536, 1, 0, -1); slice_536 = None
slice_538: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_537, 2, 0, 9223372036854775807); slice_537 = None
slice_539: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_538, 3, 256, 9223372036854775807); slice_538 = None
slice_540: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
slice_541: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_540, 1, 0, -1); slice_540 = None
slice_542: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_541, 2, 0, 9223372036854775807); slice_541 = None
slice_543: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_542, 3, 256, 9223372036854775807); slice_542 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_544: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_200, 0, 0, 9223372036854775807)
select_54: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_544, 1, -1); slice_544 = None
slice_545: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_54, 1, 256, 9223372036854775807); select_54 = None
slice_546: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_545, 2, 0, 257); slice_545 = None
slice_547: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
select_55: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_547, 1, -1); slice_547 = None
slice_548: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_55, 1, 0, 9223372036854775807); select_55 = None
slice_549: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_548, 2, 256, 9223372036854775807); slice_548 = None
slice_550: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
slice_551: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_550, 1, 0, -1)
slice_552: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_551, 2, 0, 9223372036854775807)
slice_scatter_132: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_552, slice_535, 3, 256, 9223372036854775807); slice_552 = slice_535 = None
slice_scatter_133: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_551, slice_scatter_132, 2, 0, 9223372036854775807); slice_551 = slice_scatter_132 = None
slice_scatter_134: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_550, slice_scatter_133, 1, 0, -1); slice_550 = slice_scatter_133 = None
slice_scatter_135: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_6, slice_scatter_134, 0, 0, 9223372036854775807); slice_scatter_134 = None
slice_553: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_135, 0, 0, 9223372036854775807)
select_56: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_553, 1, -1); slice_553 = None
slice_554: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_56, 1, 0, 9223372036854775807); select_56 = None
slice_555: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_554, 2, 256, 9223372036854775807); slice_554 = None
slice_556: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
select_57: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_556, 1, -1); slice_556 = None
slice_557: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_57, 1, 0, 9223372036854775807); select_57 = None
slice_558: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_557, 2, 256, 9223372036854775807); slice_557 = None
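
The slice / slice_scatter / select_scatter chains here are the functionalized form of the in-place indexed assignments such as diagonal_attention_scores[:, :-1, :, window_overlap:] = ... in the source: each write is re-expressed as an out-of-place scatter back along the sliced path. A small sketch of the primitive on hypothetical tensors:

import torch

base = torch.zeros(4, 6)
src = torch.ones(4, 3)
# functional counterpart of the in-place write base[:, 3:] = src
out = torch.slice_scatter(base, src, dim=1, start=3)
assert torch.equal(out[:, 3:], src) and torch.equal(out[:, :3], torch.zeros(4, 3))
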
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_559: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_200, 0, 0, 9223372036854775807)
slice_560: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_559, 1, 0, 9223372036854775807); slice_559 = None
slice_561: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_560, 2, -257, -1); slice_560 = None
slice_562: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_561, 3, 257, 9223372036854775807); slice_561 = None
slice_563: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
slice_564: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_563, 1, 1, 9223372036854775807); slice_563 = None
slice_565: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_564, 2, 0, 9223372036854775807); slice_564 = None
slice_566: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_565, 3, 0, 256); slice_565 = None
slice_567: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_135, 0, 0, 9223372036854775807)
select_58: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_567, 1, -1)
slice_568: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_58, 1, 0, 9223372036854775807)
slice_scatter_136: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_568, slice_546, 2, 256, 9223372036854775807); slice_568 = slice_546 = None
slice_scatter_137: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_58, slice_scatter_136, 1, 0, 9223372036854775807); select_58 = slice_scatter_136 = None
select_scatter_12: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_567, slice_scatter_137, 1, -1); slice_567 = slice_scatter_137 = None
slice_scatter_138: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_135, select_scatter_12, 0, 0, 9223372036854775807); slice_scatter_135 = select_scatter_12 = None
slice_569: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_138, 0, 0, 9223372036854775807)
slice_570: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_569, 1, 1, 9223372036854775807); slice_569 = None
slice_571: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_570, 2, 0, 9223372036854775807); slice_570 = None
slice_572: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_571, 3, 0, 256); slice_571 = None
slice_573: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
slice_574: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_573, 1, 1, 9223372036854775807); slice_573 = None
slice_575: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_574, 2, 0, 9223372036854775807); slice_574 = None
slice_576: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_575, 3, 0, 256); slice_575 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_577: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_200, 0, 0, 9223372036854775807); view_200 = None
select_59: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_577, 1, 0); slice_577 = None
slice_578: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_59, 1, 0, 255); select_59 = None
slice_579: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_578, 2, -255, 9223372036854775807); slice_578 = None
slice_580: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
select_60: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_580, 1, 0); slice_580 = None
slice_581: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_60, 1, 1, 256); select_60 = None
slice_582: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_581, 2, 1, 256); slice_581 = None
slice_583: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_138, 0, 0, 9223372036854775807)
slice_584: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_583, 1, 1, 9223372036854775807)
slice_585: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_584, 2, 0, 9223372036854775807)
slice_scatter_139: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_585, slice_562, 3, 0, 256); slice_585 = slice_562 = None
slice_scatter_140: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_584, slice_scatter_139, 2, 0, 9223372036854775807); slice_584 = slice_scatter_139 = None
slice_scatter_141: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_583, slice_scatter_140, 1, 1, 9223372036854775807); slice_583 = slice_scatter_140 = None
slice_scatter_142: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_138, slice_scatter_141, 0, 0, 9223372036854775807); slice_scatter_138 = slice_scatter_141 = None
slice_586: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_142, 0, 0, 9223372036854775807)
select_61: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_586, 1, 0); slice_586 = None
slice_587: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_61, 1, 1, 256); select_61 = None
slice_588: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_587, 2, 1, 256); slice_587 = None
slice_589: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_6, 0, 0, 9223372036854775807)
select_62: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_589, 1, 0); slice_589 = None
slice_590: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_62, 1, 1, 256); select_62 = None
slice_591: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_590, 2, 1, 256); slice_590 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_201: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513])
transpose_121: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_201, 2, 1); view_201 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_592: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_142, 0, 0, 9223372036854775807)
select_63: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_592, 1, 0)
slice_593: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_63, 1, 1, 256)
slice_scatter_143: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_593, slice_579, 2, 1, 256); slice_593 = slice_579 = None
slice_scatter_144: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_63, slice_scatter_143, 1, 1, 256); select_63 = slice_scatter_143 = None
select_scatter_13: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_592, slice_scatter_144, 1, 0); slice_592 = slice_scatter_144 = None
slice_scatter_145: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_142, select_scatter_13, 0, 0, 9223372036854775807); slice_scatter_142 = select_scatter_13 = None
view_202: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_145, [2, 12, 1024, 513])
transpose_122: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_202, 2, 1); view_202 = None
new_ones_9: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_122, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_6: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_9); new_ones_9 = None
flip_12: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_6, [0]); tril_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_48: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_12, 0); flip_12 = None
slice_594: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_48, 1, 0, 9223372036854775807); unsqueeze_48 = None
unsqueeze_49: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_594, 2); slice_594 = None
slice_595: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_49, 3, 0, 9223372036854775807); unsqueeze_49 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_13: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_595, [1, 3])
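
new_ones(256, 257).tril().flip(dims=[0]) builds the triangular mask that blanks out the attention-score positions at the very start of the sequence whose window would reach before token 0, and flipping it on dims (1, 3) gives the mirror-image mask for the end. A tiny sketch with a 3-wide window instead of 256:

import torch

w = 3   # stand-in for the 256-wide one-sided window in the trace
m = torch.ones(w, w + 1).tril().flip(dims=[0])
# tensor([[1., 1., 1., 0.],
#         [1., 1., 0., 0.],
#         [1., 0., 0., 0.]])
beginning_mask = m[None, :, None, :]            # broadcast over batch and heads
ending_mask = beginning_mask.flip(dims=(1, 3))  # mirror image for the last w positions
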
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_596: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_122, 0, 0, 9223372036854775807)
slice_597: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_596, 1, 0, 256); slice_596 = None
slice_598: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_597, 2, 0, 9223372036854775807); slice_597 = None
slice_599: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_598, 3, 0, 257); slice_598 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_12: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_595, [2, 256, 12, 257]); slice_595 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_12: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_12, 1); expand_12 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_203: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_145, [2, 12, 1024, 513])
transpose_123: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_203, 2, 1); view_203 = None
slice_600: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_123, 0, 0, 9223372036854775807); transpose_123 = None
slice_601: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_600, 1, 0, 256); slice_600 = None
slice_602: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_601, 2, 0, 9223372036854775807); slice_601 = None
slice_603: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_602, 3, 0, 257); slice_602 = None
masked_fill_18: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_603, eq_12, -inf); slice_603 = eq_12 = None
view_204: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513])
transpose_124: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_204, 2, 1); view_204 = None
slice_604: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_124, 0, 0, 9223372036854775807); transpose_124 = None
slice_605: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_604, 1, 0, 256); slice_604 = None
slice_606: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_605, 2, 0, 9223372036854775807); slice_605 = None
slice_607: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_606, 3, 0, 257); slice_606 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_205: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513])
transpose_125: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_205, 2, 1); view_205 = None
slice_608: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_122, 0, 0, 9223372036854775807); transpose_122 = None
slice_609: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_608, 1, -256, 9223372036854775807); slice_608 = None
slice_610: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_609, 2, 0, 9223372036854775807); slice_609 = None
slice_611: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_610, 3, -257, 9223372036854775807); slice_610 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_13: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_13, [2, 256, 12, 257]); flip_13 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_13: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_13, 1); expand_13 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_206: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_145, [2, 12, 1024, 513]); slice_scatter_145 = None
transpose_126: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_206, 2, 1); view_206 = None
slice_612: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_126, 0, 0, 9223372036854775807)
slice_613: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_612, 1, 0, 256)
slice_614: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_613, 2, 0, 9223372036854775807)
slice_scatter_146: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_614, masked_fill_18, 3, 0, 257); slice_614 = masked_fill_18 = None
slice_scatter_147: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_613, slice_scatter_146, 2, 0, 9223372036854775807); slice_613 = slice_scatter_146 = None
slice_scatter_148: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_612, slice_scatter_147, 1, 0, 256); slice_612 = slice_scatter_147 = None
slice_scatter_149: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_126, slice_scatter_148, 0, 0, 9223372036854775807); transpose_126 = slice_scatter_148 = None
transpose_127: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_149, 2, 1); slice_scatter_149 = None
view_207: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_127, [24, 4, 256, 513]); transpose_127 = None
view_208: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_207, [2, 12, 1024, 513])
transpose_128: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_208, 2, 1); view_208 = None
slice_615: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_128, 0, 0, 9223372036854775807); transpose_128 = None
slice_616: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_615, 1, -256, 9223372036854775807); slice_615 = None
slice_617: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_616, 2, 0, 9223372036854775807); slice_616 = None
slice_618: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_617, 3, -257, 9223372036854775807); slice_617 = None
masked_fill_19: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_618, eq_13, -inf); slice_618 = eq_13 = None
view_209: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513])
transpose_129: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_209, 2, 1); view_209 = None
slice_619: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_129, 0, 0, 9223372036854775807); transpose_129 = None
slice_620: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_619, 1, -256, 9223372036854775807); slice_619 = None
slice_621: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_620, 2, 0, 9223372036854775807); slice_620 = None
slice_622: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_621, 3, -257, 9223372036854775807); slice_621 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_3: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_623: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_3, 0, 0, 9223372036854775807); ne_3 = None
slice_624: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_623, 1, 0, 9223372036854775807); slice_623 = None
unsqueeze_50: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_624, 2); slice_624 = None
unsqueeze_51: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_50, 3); unsqueeze_50 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_3: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_51, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_20: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_3, unsqueeze_51, -10000.0); _to_copy_3 = unsqueeze_51 = None
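
_to_copy plus masked_fill turn the boolean remove_from_windowed_attention_mask into an additive float mask: masked positions become -10000.0 and everything else stays 0, ready to be folded into the attention scores. A minimal sketch of the mechanics on a hypothetical mask:

import torch

remove = torch.tensor([[False, False, True, True]])   # True = drop from local attention
float_mask = remove.to(torch.float32).masked_fill(remove, -10000.0)
# tensor([[     0.,      0., -10000., -10000.]])
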
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_10: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_20, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_130: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_10, 1, 2); new_ones_10 = None
view_210: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_130, [2, 1024, 1]); transpose_130 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_131: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_20, 1, 2); masked_fill_20 = None
view_211: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_131, [2, 1024, 1]); transpose_131 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_212: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_210, [2, 2, 512, 1]); view_210 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_21: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_212, [2, 3, 512, 1], [1024, 256, 1, 1]); view_212 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_213: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_211, [2, 2, 512, 1]); view_211 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_22: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_213, [2, 3, 512, 1], [1024, 256, 1, 1]); view_213 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_52: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_21, 4); as_strided_21 = None
permute_49: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_52, [0, 1, 2, 4, 3]); unsqueeze_52 = None
unsqueeze_53: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_22, 4); as_strided_22 = None
permute_50: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_53, [0, 1, 4, 2, 3]); unsqueeze_53 = None
mul_27: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_49, permute_50); permute_49 = permute_50 = None
view_214: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_27, [2, 3, 512, 512]); mul_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_13: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_214, [0, 0, 0, 1], 0.0); view_214 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_215: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_13, [2, 3, 512, 513]); constant_pad_nd_13 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_7: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_215, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_625: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_215, 0, 0, 9223372036854775807)
slice_626: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_625, 1, 0, 9223372036854775807); slice_625 = None
slice_627: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_626, 2, 0, 256); slice_626 = None
slice_628: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_627, 3, 0, 257); slice_627 = None
slice_629: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_7, 0, 0, 9223372036854775807)
slice_630: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_629, 1, 0, -1); slice_629 = None
slice_631: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_630, 2, 0, 9223372036854775807); slice_630 = None
slice_632: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_631, 3, 256, 9223372036854775807); slice_631 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_633: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_215, 0, 0, 9223372036854775807)
select_64: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_633, 1, -1); slice_633 = None
slice_634: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_64, 1, 256, 9223372036854775807); select_64 = None
slice_635: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_634, 2, 0, 257); slice_634 = None
slice_636: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_7, 0, 0, 9223372036854775807)
select_65: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_636, 1, -1); slice_636 = None
slice_637: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_65, 1, 0, 9223372036854775807); select_65 = None
slice_638: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_637, 2, 256, 9223372036854775807); slice_637 = None
slice_639: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_7, 0, 0, 9223372036854775807)
slice_640: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_639, 1, 0, -1)
slice_641: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_640, 2, 0, 9223372036854775807)
slice_scatter_150: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_641, slice_628, 3, 256, 9223372036854775807); slice_641 = slice_628 = None
slice_scatter_151: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_640, slice_scatter_150, 2, 0, 9223372036854775807); slice_640 = slice_scatter_150 = None
slice_scatter_152: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_639, slice_scatter_151, 1, 0, -1); slice_639 = slice_scatter_151 = None
slice_scatter_153: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_7, slice_scatter_152, 0, 0, 9223372036854775807); slice_scatter_152 = None
slice_642: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_153, 0, 0, 9223372036854775807)
select_66: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_642, 1, -1); slice_642 = None
slice_643: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_66, 1, 0, 9223372036854775807); select_66 = None
slice_644: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_643, 2, 256, 9223372036854775807); slice_643 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_645: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_215, 0, 0, 9223372036854775807)
slice_646: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_645, 1, 0, 9223372036854775807); slice_645 = None
slice_647: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_646, 2, -257, -1); slice_646 = None
slice_648: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_647, 3, 257, 9223372036854775807); slice_647 = None
slice_649: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_7, 0, 0, 9223372036854775807)
slice_650: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_649, 1, 1, 9223372036854775807); slice_649 = None
slice_651: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_650, 2, 0, 9223372036854775807); slice_650 = None
slice_652: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_651, 3, 0, 256); slice_651 = None
slice_653: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_153, 0, 0, 9223372036854775807)
select_67: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_653, 1, -1)
slice_654: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_67, 1, 0, 9223372036854775807)
slice_scatter_154: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_654, slice_635, 2, 256, 9223372036854775807); slice_654 = slice_635 = None
slice_scatter_155: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_67, slice_scatter_154, 1, 0, 9223372036854775807); select_67 = slice_scatter_154 = None
select_scatter_14: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_653, slice_scatter_155, 1, -1); slice_653 = slice_scatter_155 = None
slice_scatter_156: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_153, select_scatter_14, 0, 0, 9223372036854775807); slice_scatter_153 = select_scatter_14 = None
slice_655: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_156, 0, 0, 9223372036854775807)
slice_656: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_655, 1, 1, 9223372036854775807); slice_655 = None
slice_657: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_656, 2, 0, 9223372036854775807); slice_656 = None
slice_658: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_657, 3, 0, 256); slice_657 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_659: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_215, 0, 0, 9223372036854775807); view_215 = None
select_68: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_659, 1, 0); slice_659 = None
slice_660: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_68, 1, 0, 255); select_68 = None
slice_661: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_660, 2, -255, 9223372036854775807); slice_660 = None
slice_662: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_7, 0, 0, 9223372036854775807)
select_69: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_662, 1, 0); slice_662 = None
slice_663: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_69, 1, 1, 256); select_69 = None
slice_664: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_663, 2, 1, 256); slice_663 = None
slice_665: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_156, 0, 0, 9223372036854775807)
slice_666: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_665, 1, 1, 9223372036854775807)
slice_667: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_666, 2, 0, 9223372036854775807)
slice_scatter_157: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_667, slice_648, 3, 0, 256); slice_667 = slice_648 = None
slice_scatter_158: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_666, slice_scatter_157, 2, 0, 9223372036854775807); slice_666 = slice_scatter_157 = None
slice_scatter_159: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_665, slice_scatter_158, 1, 1, 9223372036854775807); slice_665 = slice_scatter_158 = None
slice_scatter_160: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_156, slice_scatter_159, 0, 0, 9223372036854775807); slice_scatter_156 = slice_scatter_159 = None
slice_668: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_160, 0, 0, 9223372036854775807)
select_70: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_668, 1, 0); slice_668 = None
slice_669: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_70, 1, 1, 256); select_70 = None
slice_670: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_669, 2, 1, 256); slice_669 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_216: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_7, [2, 1, 1024, 513]); new_empty_7 = None
transpose_132: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_216, 2, 1); view_216 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_671: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_160, 0, 0, 9223372036854775807)
select_71: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_671, 1, 0)
slice_672: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_71, 1, 1, 256)
slice_scatter_161: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_672, slice_661, 2, 1, 256); slice_672 = slice_661 = None
slice_scatter_162: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_71, slice_scatter_161, 1, 1, 256); select_71 = slice_scatter_161 = None
select_scatter_15: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_671, slice_scatter_162, 1, 0); slice_671 = slice_scatter_162 = None
slice_scatter_163: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_160, select_scatter_15, 0, 0, 9223372036854775807); slice_scatter_160 = select_scatter_15 = None
view_217: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_163, [2, 1, 1024, 513])
transpose_133: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_217, 2, 1); view_217 = None
new_ones_11: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_133, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_7: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_11); new_ones_11 = None
flip_14: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_7, [0]); tril_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_54: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_14, 0); flip_14 = None
slice_673: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_54, 1, 0, 9223372036854775807); unsqueeze_54 = None
unsqueeze_55: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_673, 2); slice_673 = None
slice_674: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_55, 3, 0, 9223372036854775807); unsqueeze_55 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_15: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_674, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_675: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_133, 0, 0, 9223372036854775807)
slice_676: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_675, 1, 0, 256); slice_675 = None
slice_677: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_676, 2, 0, 9223372036854775807); slice_676 = None
slice_678: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_677, 3, 0, 257); slice_677 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_14: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_674, [2, 256, 1, 257]); slice_674 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_14: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_14, 1); expand_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_218: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_163, [2, 1, 1024, 513])
transpose_134: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_218, 2, 1); view_218 = None
slice_679: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_134, 0, 0, 9223372036854775807); transpose_134 = None
slice_680: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_679, 1, 0, 256); slice_679 = None
slice_681: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_680, 2, 0, 9223372036854775807); slice_680 = None
slice_682: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_681, 3, 0, 257); slice_681 = None
masked_fill_21: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_682, eq_14, -inf); slice_682 = eq_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_683: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_133, 0, 0, 9223372036854775807); transpose_133 = None
slice_684: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_683, 1, -256, 9223372036854775807); slice_683 = None
slice_685: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_684, 2, 0, 9223372036854775807); slice_684 = None
slice_686: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_685, 3, -257, 9223372036854775807); slice_685 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_15: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_15, [2, 256, 1, 257]); flip_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_15: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_15, 1); expand_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_219: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_163, [2, 1, 1024, 513]); slice_scatter_163 = None
transpose_135: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_219, 2, 1); view_219 = None
slice_687: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_135, 0, 0, 9223372036854775807)
slice_688: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_687, 1, 0, 256)
slice_689: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_688, 2, 0, 9223372036854775807)
slice_scatter_164: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_689, masked_fill_21, 3, 0, 257); slice_689 = masked_fill_21 = None
slice_scatter_165: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_688, slice_scatter_164, 2, 0, 9223372036854775807); slice_688 = slice_scatter_164 = None
slice_scatter_166: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_687, slice_scatter_165, 1, 0, 256); slice_687 = slice_scatter_165 = None
slice_scatter_167: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_135, slice_scatter_166, 0, 0, 9223372036854775807); transpose_135 = slice_scatter_166 = None
transpose_136: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_167, 2, 1); slice_scatter_167 = None
view_220: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_136, [2, 4, 256, 513]); transpose_136 = None
view_221: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_220, [2, 1, 1024, 513])
transpose_137: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_221, 2, 1); view_221 = None
slice_690: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_137, 0, 0, 9223372036854775807); transpose_137 = None
slice_691: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_690, 1, -256, 9223372036854775807); slice_690 = None
slice_692: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_691, 2, 0, 9223372036854775807); slice_691 = None
slice_693: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_692, 3, -257, 9223372036854775807); slice_692 = None
masked_fill_22: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_693, eq_15, -inf); slice_693 = eq_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_222: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513])
transpose_138: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_222, 2, 1); view_222 = None
view_223: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_207, [2, 12, 1024, 513]); view_207 = None
transpose_139: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_223, 2, 1); view_223 = None
slice_694: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_139, 0, 0, 9223372036854775807)
slice_695: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_694, 1, -256, 9223372036854775807)
slice_696: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_695, 2, 0, 9223372036854775807)
slice_scatter_168: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_696, masked_fill_19, 3, -257, 9223372036854775807); slice_696 = masked_fill_19 = None
slice_scatter_169: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_695, slice_scatter_168, 2, 0, 9223372036854775807); slice_695 = slice_scatter_168 = None
slice_scatter_170: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_694, slice_scatter_169, 1, -256, 9223372036854775807); slice_694 = slice_scatter_169 = None
slice_scatter_171: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_139, slice_scatter_170, 0, 0, 9223372036854775807); transpose_139 = slice_scatter_170 = None
transpose_140: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_171, 2, 1); slice_scatter_171 = None
view_224: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_140, [24, 4, 256, 513]); transpose_140 = None
view_225: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_224, [2, 12, 1024, 513]); view_224 = None
transpose_141: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_225, 2, 1); view_225 = None
view_226: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_220, [2, 1, 1024, 513]); view_220 = None
transpose_142: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_226, 2, 1); view_226 = None
slice_697: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_142, 0, 0, 9223372036854775807)
slice_698: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_697, 1, -256, 9223372036854775807)
slice_699: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_698, 2, 0, 9223372036854775807)
slice_scatter_172: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_699, masked_fill_22, 3, -257, 9223372036854775807); slice_699 = masked_fill_22 = None
slice_scatter_173: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_698, slice_scatter_172, 2, 0, 9223372036854775807); slice_698 = slice_scatter_172 = None
slice_scatter_174: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_697, slice_scatter_173, 1, -256, 9223372036854775807); slice_697 = slice_scatter_173 = None
slice_scatter_175: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_142, slice_scatter_174, 0, 0, 9223372036854775807); transpose_142 = slice_scatter_174 = None
transpose_143: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_175, 2, 1); slice_scatter_175 = None
view_227: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_143, [2, 4, 256, 513]); transpose_143 = None
view_228: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_227, [2, 1, 1024, 513]); view_227 = None
transpose_144: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_228, 2, 1); view_228 = None
add_24: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_141, transpose_144); transpose_141 = transpose_144 = None
view_229: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_6, [2, 12, 1024, 513]); new_empty_6 = None
transpose_145: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_229, 2, 1); view_229 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_3: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_24, -1, False); add_24 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_700: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_701: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_700, 1, 0, 9223372036854775807); slice_700 = None
unsqueeze_56: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_701, 2); slice_701 = None
unsqueeze_57: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_56, 3); unsqueeze_56 = None
masked_fill_23: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_3, unsqueeze_57, 0.0); _softmax_3 = unsqueeze_57 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_230: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_23, [1024, 2, 12, 64]); add_23 = None
transpose_146: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_230, 0, 1); view_230 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_147: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_23, 1, 2); masked_fill_23 = None
clone_32: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_147, memory_format = torch.contiguous_format); transpose_147 = None
_unsafe_view_17: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_32, [24, 4, 256, 513]); clone_32 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_148: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_146, 1, 2); transpose_146 = None
view_231: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_148, [24, 1024, 64]); transpose_148 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_14: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_231, [0, 0, 256, 256], -1.0); view_231 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_23: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_14, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_15: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_17, [0, 257], 0.0); _unsafe_view_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_232: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_15, [24, 4, -1]); constant_pad_nd_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_702: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_232, 0, 0, 9223372036854775807); view_232 = None
slice_703: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_702, 1, 0, 9223372036854775807); slice_702 = None
slice_704: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_703, 2, 0, -256); slice_703 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_233: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_704, [24, 4, 256, 769]); slice_704 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_705: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_233, 0, 0, 9223372036854775807); view_233 = None
slice_706: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_705, 1, 0, 9223372036854775807)
slice_707: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_706, 2, 0, 9223372036854775807); slice_706 = None
slice_708: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_707, 3, 0, -1); slice_707 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_58: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_708, 4)
permute_51: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_58, [0, 1, 2, 4, 3]); unsqueeze_58 = None
unsqueeze_59: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_23, 4); as_strided_23 = None
permute_52: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_59, [0, 1, 4, 3, 2]); unsqueeze_59 = None
permute_53: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_51, [0, 1, 2, 4, 3]); permute_51 = None
sym_size_27: Sym(24) = torch.ops.aten.sym_size(slice_705, 0); slice_705 = None
# No stacktrace found for following nodes
mul_28: Sym(96) = sym_size_27 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_28: Sym(768) = torch.ops.aten.sym_size(slice_708, 3); slice_708 = None
view_234: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_53, [mul_28, 256, sym_size_28]); permute_53 = None
permute_54: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_52, [0, 1, 4, 3, 2]); permute_52 = None
clone_33: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_54, memory_format = torch.contiguous_format); permute_54 = None
_unsafe_view_18: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_33, [mul_28, sym_size_28, 64]); clone_33 = mul_28 = sym_size_28 = None
bmm_7: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_234, _unsafe_view_18); view_234 = _unsafe_view_18 = None
view_235: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_7, [sym_size_27, 4, 256, 1, 64]); bmm_7 = None
permute_55: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_235, [0, 1, 2, 4, 3])
sym_size_29: Sym(4) = torch.ops.aten.sym_size(view_235, 1); view_235 = None
view_236: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_55, [sym_size_27, sym_size_29, 256, 64]); permute_55 = sym_size_27 = sym_size_29 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_237: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_236, [2, 12, 1024, 64]); view_236 = None
transpose_149: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_237, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_150: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_149, 0, 1); transpose_149 = None
clone_34: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_150, memory_format = torch.contiguous_format); transpose_150 = None
_unsafe_view_19: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_34, [1024, 2, 768]); clone_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_151: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_19, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_21: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_55); orig_primals_55 = None
clone_35: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_151, memory_format = torch.contiguous_format); transpose_151 = None
sym_size_30: Sym(1024) = torch.ops.aten.sym_size(view_237, 2); view_237 = None
# No stacktrace found for following nodes
mul_29: Sym(2048) = 2 * sym_size_30
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_31: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_19, 2); _unsafe_view_19 = None
view_238: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_35, [mul_29, sym_size_31]); clone_35 = mul_29 = sym_size_31 = None
mm_15: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_238, t_21); view_238 = t_21 = None
view_239: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_15, [2, sym_size_30, 768]); mm_15 = sym_size_30 = None
add_25: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_239, orig_primals_56); orig_primals_56 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_26: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_25, getitem_15); add_25 = getitem_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_6 = torch.ops.aten.native_layer_norm.default(add_26, [768], orig_primals_57, orig_primals_58, 1e-05); add_26 = orig_primals_57 = orig_primals_58 = None
getitem_18: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_6[0]
getitem_19: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_6[1]
getitem_20: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_6[2]; native_layer_norm_6 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_22: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_59); orig_primals_59 = None
sym_size_32: Sym(1024) = torch.ops.aten.sym_size(view_239, 1); view_239 = None
# No stacktrace found for following nodes
mul_30: Sym(2048) = 2 * sym_size_32
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_240: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_18, [mul_30, 768]); mul_30 = None
addmm_6: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_60, view_240, t_22); orig_primals_60 = view_240 = t_22 = None
view_241: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_6, [2, sym_size_32, 3072]); addmm_6 = sym_size_32 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_3: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_241)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_23: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_61); orig_primals_61 = None
sym_size_33: Sym(1024) = torch.ops.aten.sym_size(view_241, 1); view_241 = None
# No stacktrace found for following nodes
mul_31: Sym(2048) = 2 * sym_size_33
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_242: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_3, [mul_31, 3072]); gelu_3 = mul_31 = None
addmm_7: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_62, view_242, t_23); orig_primals_62 = view_242 = t_23 = None
view_243: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_7, [2, sym_size_33, 768]); addmm_7 = sym_size_33 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_27: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_243, getitem_18); getitem_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_7 = torch.ops.aten.native_layer_norm.default(add_27, [768], orig_primals_63, orig_primals_64, 1e-05); add_27 = orig_primals_63 = orig_primals_64 = None
getitem_21: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_7[0]
getitem_22: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_7[1]
getitem_23: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_7[2]; native_layer_norm_7 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_152: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_21, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_24: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_65); orig_primals_65 = None
clone_36: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_152, memory_format = torch.contiguous_format)
sym_size_34: Sym(1024) = torch.ops.aten.sym_size(view_243, 1); view_243 = None
# No stacktrace found for following nodes
mul_32: Sym(2048) = sym_size_34 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_244: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_36, [mul_32, 768]); clone_36 = mul_32 = None
mm_16: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_244, t_24); view_244 = t_24 = None
view_245: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_16, [sym_size_34, 2, 768]); mm_16 = None
add_28: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_245, orig_primals_66); view_245 = orig_primals_66 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_25: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_67); orig_primals_67 = None
clone_37: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_152, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_33: Sym(2048) = sym_size_34 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_246: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_37, [mul_33, 768]); clone_37 = mul_33 = None
mm_17: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_246, t_25); view_246 = t_25 = None
view_247: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_17, [sym_size_34, 2, 768]); mm_17 = None
add_29: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_247, orig_primals_68); view_247 = orig_primals_68 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_26: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_69); orig_primals_69 = None
clone_38: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_152, memory_format = torch.contiguous_format); transpose_152 = None
# No stacktrace found for following nodes
mul_34: Sym(2048) = sym_size_34 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_248: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_38, [mul_34, 768]); clone_38 = mul_34 = None
mm_18: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_248, t_26); view_248 = t_26 = None
view_249: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_18, [sym_size_34, 2, 768]); mm_18 = sym_size_34 = None
add_30: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_249, orig_primals_70); view_249 = orig_primals_70 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_4: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_28, 8.0); add_28 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_250: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_4, [1024, 2, 12, 64])
transpose_153: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_250, 0, 1); view_250 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_251: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_29, [1024, 2, 12, 64]); add_29 = None
transpose_154: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_251, 0, 1); view_251 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_155: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_153, 1, 2); transpose_153 = None
view_252: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_155, [24, 1024, 64]); transpose_155 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_156: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_154, 1, 2); transpose_154 = None
view_253: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_156, [24, 1024, 64]); transpose_156 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_254: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_252, [24, 2, 512, 64]); view_252 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_24: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_254, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_254 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_255: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_253, [24, 2, 512, 64]); view_253 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_25: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_255, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_255 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_60: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_24, 4); as_strided_24 = None
permute_56: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_60, [0, 1, 2, 4, 3]); unsqueeze_60 = None
unsqueeze_61: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_25, 4); as_strided_25 = None
permute_57: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_61, [0, 1, 4, 2, 3]); unsqueeze_61 = None
permute_58: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_56, [0, 1, 2, 4, 3]); permute_56 = None
view_256: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_4, [1024, 2, 12, 64]); div_4 = None
transpose_157: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_256, 0, 1); view_256 = None
transpose_158: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_157, 1, 2); transpose_157 = None
view_257: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_158, [24, 1024, 64]); transpose_158 = None
view_258: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_257, [24, 2, 512, 64]); view_257 = None
as_strided_26: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_258, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_258 = None
unsqueeze_62: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_26, 4); as_strided_26 = None
permute_59: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_62, [0, 1, 2, 4, 3]); unsqueeze_62 = None
permute_60: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_59, [0, 1, 2, 4, 3]); permute_59 = None
clone_39: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_60, memory_format = torch.contiguous_format); permute_60 = None
_unsafe_view_20: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_39, [72, 512, 64]); clone_39 = None
permute_61: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_57, [0, 1, 4, 3, 2]); permute_57 = None
clone_40: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_61, memory_format = torch.contiguous_format); permute_61 = None
_unsafe_view_21: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_40, [72, 64, 512]); clone_40 = None
bmm_8: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_20, _unsafe_view_21); _unsafe_view_20 = _unsafe_view_21 = None
view_259: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_8, [24, 3, 512, 1, 512]); bmm_8 = None
permute_62: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_259, [0, 1, 2, 4, 3]); view_259 = None
view_260: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_62, [24, 3, 512, 512]); permute_62 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_16: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_260, [0, 0, 0, 1], 0.0); view_260 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_261: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_16, [24, 3, 512, 513]); constant_pad_nd_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_8: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_261, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_709: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_261, 0, 0, 9223372036854775807)
slice_710: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_709, 1, 0, 9223372036854775807); slice_709 = None
slice_711: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_710, 2, 0, 256); slice_710 = None
slice_712: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_711, 3, 0, 257); slice_711 = None
slice_713: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
slice_714: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_713, 1, 0, -1); slice_713 = None
slice_715: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_714, 2, 0, 9223372036854775807); slice_714 = None
slice_716: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_715, 3, 256, 9223372036854775807); slice_715 = None
slice_717: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
slice_718: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_717, 1, 0, -1); slice_717 = None
slice_719: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_718, 2, 0, 9223372036854775807); slice_718 = None
slice_720: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_719, 3, 256, 9223372036854775807); slice_719 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_721: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_261, 0, 0, 9223372036854775807)
select_72: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_721, 1, -1); slice_721 = None
slice_722: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_72, 1, 256, 9223372036854775807); select_72 = None
slice_723: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_722, 2, 0, 257); slice_722 = None
slice_724: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
select_73: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_724, 1, -1); slice_724 = None
slice_725: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_73, 1, 0, 9223372036854775807); select_73 = None
slice_726: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_725, 2, 256, 9223372036854775807); slice_725 = None
slice_727: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
slice_728: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_727, 1, 0, -1)
slice_729: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_728, 2, 0, 9223372036854775807)
slice_scatter_176: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_729, slice_712, 3, 256, 9223372036854775807); slice_729 = slice_712 = None
slice_scatter_177: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_728, slice_scatter_176, 2, 0, 9223372036854775807); slice_728 = slice_scatter_176 = None
slice_scatter_178: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_727, slice_scatter_177, 1, 0, -1); slice_727 = slice_scatter_177 = None
slice_scatter_179: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_8, slice_scatter_178, 0, 0, 9223372036854775807); slice_scatter_178 = None
slice_730: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_179, 0, 0, 9223372036854775807)
select_74: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_730, 1, -1); slice_730 = None
slice_731: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_74, 1, 0, 9223372036854775807); select_74 = None
slice_732: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_731, 2, 256, 9223372036854775807); slice_731 = None
slice_733: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
select_75: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_733, 1, -1); slice_733 = None
slice_734: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_75, 1, 0, 9223372036854775807); select_75 = None
slice_735: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_734, 2, 256, 9223372036854775807); slice_734 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_736: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_261, 0, 0, 9223372036854775807)
slice_737: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_736, 1, 0, 9223372036854775807); slice_736 = None
slice_738: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_737, 2, -257, -1); slice_737 = None
slice_739: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_738, 3, 257, 9223372036854775807); slice_738 = None
slice_740: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
slice_741: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_740, 1, 1, 9223372036854775807); slice_740 = None
slice_742: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_741, 2, 0, 9223372036854775807); slice_741 = None
slice_743: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_742, 3, 0, 256); slice_742 = None
slice_744: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_179, 0, 0, 9223372036854775807)
select_76: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_744, 1, -1)
slice_745: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_76, 1, 0, 9223372036854775807)
slice_scatter_180: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_745, slice_723, 2, 256, 9223372036854775807); slice_745 = slice_723 = None
slice_scatter_181: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_76, slice_scatter_180, 1, 0, 9223372036854775807); select_76 = slice_scatter_180 = None
select_scatter_16: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_744, slice_scatter_181, 1, -1); slice_744 = slice_scatter_181 = None
slice_scatter_182: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_179, select_scatter_16, 0, 0, 9223372036854775807); slice_scatter_179 = select_scatter_16 = None
slice_746: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_182, 0, 0, 9223372036854775807)
slice_747: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_746, 1, 1, 9223372036854775807); slice_746 = None
slice_748: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_747, 2, 0, 9223372036854775807); slice_747 = None
slice_749: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_748, 3, 0, 256); slice_748 = None
slice_750: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
slice_751: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_750, 1, 1, 9223372036854775807); slice_750 = None
slice_752: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_751, 2, 0, 9223372036854775807); slice_751 = None
slice_753: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_752, 3, 0, 256); slice_752 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_754: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_261, 0, 0, 9223372036854775807); view_261 = None
select_77: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_754, 1, 0); slice_754 = None
slice_755: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_77, 1, 0, 255); select_77 = None
slice_756: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_755, 2, -255, 9223372036854775807); slice_755 = None
slice_757: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
select_78: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_757, 1, 0); slice_757 = None
slice_758: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_78, 1, 1, 256); select_78 = None
slice_759: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_758, 2, 1, 256); slice_758 = None
slice_760: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_182, 0, 0, 9223372036854775807)
slice_761: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_760, 1, 1, 9223372036854775807)
slice_762: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_761, 2, 0, 9223372036854775807)
slice_scatter_183: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_762, slice_739, 3, 0, 256); slice_762 = slice_739 = None
slice_scatter_184: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_761, slice_scatter_183, 2, 0, 9223372036854775807); slice_761 = slice_scatter_183 = None
slice_scatter_185: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_760, slice_scatter_184, 1, 1, 9223372036854775807); slice_760 = slice_scatter_184 = None
slice_scatter_186: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_182, slice_scatter_185, 0, 0, 9223372036854775807); slice_scatter_182 = slice_scatter_185 = None
slice_763: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_186, 0, 0, 9223372036854775807)
select_79: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_763, 1, 0); slice_763 = None
slice_764: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_79, 1, 1, 256); select_79 = None
slice_765: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_764, 2, 1, 256); slice_764 = None
slice_766: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_8, 0, 0, 9223372036854775807)
select_80: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_766, 1, 0); slice_766 = None
slice_767: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_80, 1, 1, 256); select_80 = None
slice_768: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_767, 2, 1, 256); slice_767 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_262: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513])
transpose_159: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_262, 2, 1); view_262 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_769: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_186, 0, 0, 9223372036854775807)
select_81: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_769, 1, 0)
slice_770: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_81, 1, 1, 256)
slice_scatter_187: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_770, slice_756, 2, 1, 256); slice_770 = slice_756 = None
slice_scatter_188: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_81, slice_scatter_187, 1, 1, 256); select_81 = slice_scatter_187 = None
select_scatter_17: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_769, slice_scatter_188, 1, 0); slice_769 = slice_scatter_188 = None
slice_scatter_189: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_186, select_scatter_17, 0, 0, 9223372036854775807); slice_scatter_186 = select_scatter_17 = None
view_263: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_189, [2, 12, 1024, 513])
transpose_160: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_263, 2, 1); view_263 = None
new_ones_12: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_160, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_8: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_12); new_ones_12 = None
flip_16: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_8, [0]); tril_8 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_63: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_16, 0); flip_16 = None
slice_771: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_63, 1, 0, 9223372036854775807); unsqueeze_63 = None
unsqueeze_64: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_771, 2); slice_771 = None
slice_772: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_64, 3, 0, 9223372036854775807); unsqueeze_64 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_17: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_772, [1, 3])
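# [Illustrative sketch, not part of the captured graph] new_ones_12 / tril_8 / flip_16 above and
# flip_17 reproduce Longformer's boundary-mask construction for _mask_invalid_locations; a minimal
# standalone equivalent, with the one-sided window size 256 read off the shapes in this trace
# (variable names below are stand-ins, not nodes of the graph):
import torch
affected_seq_len = 256
beginning_mask_2d = torch.ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]    # [1, 256, 1, 257], matches unsqueeze_63/64
ending_mask = beginning_mask.flip(dims=(1, 3))          # mirrored mask for the end of the sequence
assert beginning_mask.shape == (1, 256, 1, 257) and ending_mask.shape == (1, 256, 1, 257)
# positions where these masks equal 1 are the ones filled with -inf by the masked_fill calls below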
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_773: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_160, 0, 0, 9223372036854775807)
slice_774: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_773, 1, 0, 256); slice_773 = None
slice_775: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_774, 2, 0, 9223372036854775807); slice_774 = None
slice_776: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_775, 3, 0, 257); slice_775 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_16: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_772, [2, 256, 12, 257]); slice_772 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_16: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_16, 1); expand_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_264: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_189, [2, 12, 1024, 513])
transpose_161: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_264, 2, 1); view_264 = None
slice_777: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_161, 0, 0, 9223372036854775807); transpose_161 = None
slice_778: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_777, 1, 0, 256); slice_777 = None
slice_779: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_778, 2, 0, 9223372036854775807); slice_778 = None
slice_780: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_779, 3, 0, 257); slice_779 = None
masked_fill_24: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_780, eq_16, -inf); slice_780 = eq_16 = None
view_265: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513])
transpose_162: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_265, 2, 1); view_265 = None
slice_781: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_162, 0, 0, 9223372036854775807); transpose_162 = None
slice_782: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_781, 1, 0, 256); slice_781 = None
slice_783: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_782, 2, 0, 9223372036854775807); slice_782 = None
slice_784: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_783, 3, 0, 257); slice_783 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_266: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513])
transpose_163: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_266, 2, 1); view_266 = None
slice_785: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_160, 0, 0, 9223372036854775807); transpose_160 = None
slice_786: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_785, 1, -256, 9223372036854775807); slice_785 = None
slice_787: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_786, 2, 0, 9223372036854775807); slice_786 = None
slice_788: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_787, 3, -257, 9223372036854775807); slice_787 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_17: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_17, [2, 256, 12, 257]); flip_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_17: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_17, 1); expand_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_267: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_189, [2, 12, 1024, 513]); slice_scatter_189 = None
transpose_164: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_267, 2, 1); view_267 = None
slice_789: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_164, 0, 0, 9223372036854775807)
slice_790: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_789, 1, 0, 256)
slice_791: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_790, 2, 0, 9223372036854775807)
slice_scatter_190: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_791, masked_fill_24, 3, 0, 257); slice_791 = masked_fill_24 = None
slice_scatter_191: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_790, slice_scatter_190, 2, 0, 9223372036854775807); slice_790 = slice_scatter_190 = None
slice_scatter_192: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_789, slice_scatter_191, 1, 0, 256); slice_789 = slice_scatter_191 = None
slice_scatter_193: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_164, slice_scatter_192, 0, 0, 9223372036854775807); transpose_164 = slice_scatter_192 = None
transpose_165: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_193, 2, 1); slice_scatter_193 = None
view_268: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_165, [24, 4, 256, 513]); transpose_165 = None
view_269: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_268, [2, 12, 1024, 513])
transpose_166: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_269, 2, 1); view_269 = None
slice_792: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_166, 0, 0, 9223372036854775807); transpose_166 = None
slice_793: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_792, 1, -256, 9223372036854775807); slice_792 = None
slice_794: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_793, 2, 0, 9223372036854775807); slice_793 = None
slice_795: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_794, 3, -257, 9223372036854775807); slice_794 = None
masked_fill_25: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_795, eq_17, -inf); slice_795 = eq_17 = None
view_270: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513])
transpose_167: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_270, 2, 1); view_270 = None
slice_796: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_167, 0, 0, 9223372036854775807); transpose_167 = None
slice_797: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_796, 1, -256, 9223372036854775807); slice_796 = None
slice_798: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_797, 2, 0, 9223372036854775807); slice_797 = None
slice_799: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_798, 3, -257, 9223372036854775807); slice_798 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_4: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_800: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_4, 0, 0, 9223372036854775807); ne_4 = None
slice_801: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_800, 1, 0, 9223372036854775807); slice_800 = None
unsqueeze_65: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_801, 2); slice_801 = None
unsqueeze_66: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_65, 3); unsqueeze_65 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_4: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_66, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_26: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_4, unsqueeze_66, -10000.0); _to_copy_4 = unsqueeze_66 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_13: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_26, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_168: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_13, 1, 2); new_ones_13 = None
view_271: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_168, [2, 1024, 1]); transpose_168 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_169: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_26, 1, 2); masked_fill_26 = None
view_272: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_169, [2, 1024, 1]); transpose_169 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_273: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_271, [2, 2, 512, 1]); view_271 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_27: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_273, [2, 3, 512, 1], [1024, 256, 1, 1]); view_273 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_274: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_272, [2, 2, 512, 1]); view_272 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_28: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_274, [2, 3, 512, 1], [1024, 256, 1, 1]); view_274 = None
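# [Illustrative sketch, not part of the captured graph] view_273/as_strided_27 and
# view_274/as_strided_28 implement Longformer's overlapping-chunk helper: a [2, 1024, 1] tensor is
# first viewed as two non-overlapping 512-token chunks, then re-exposed as three chunks that
# overlap by 256 tokens purely by halving the chunk stride -- no data is copied. A standalone
# equivalent with stand-in names:
import torch
x = torch.arange(2 * 1024 * 1, dtype=torch.float32).view(2, 1024, 1)
non_overlapping = x.view(2, 2, 512, 1)                                    # strides [1024, 512, 1, 1]
overlapping = non_overlapping.as_strided((2, 3, 512, 1), (1024, 256, 1, 1))
assert torch.equal(overlapping[:, 1, :256], overlapping[:, 0, 256:])      # adjacent chunks share 256 rows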
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_67: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_27, 4); as_strided_27 = None
permute_63: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_67, [0, 1, 2, 4, 3]); unsqueeze_67 = None
unsqueeze_68: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_28, 4); as_strided_28 = None
permute_64: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_68, [0, 1, 4, 2, 3]); unsqueeze_68 = None
mul_35: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_63, permute_64); permute_63 = permute_64 = None
view_275: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_35, [2, 3, 512, 512]); mul_35 = None
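# [Illustrative sketch, not part of the captured graph] unsqueeze_67 through view_275 are the
# pointwise decomposition of torch.einsum("bcxd,bcyd->bcxy", query, key): the contracted d
# dimension is moved to the end of each operand, broadcast-multiplied, and reduced -- here d == 1,
# so the reduction degenerates to the final view that drops the trailing singleton. Equivalent
# formulations with stand-in tensors:
import torch
q = torch.randn(2, 3, 512, 1)
k = torch.randn(2, 3, 512, 1)
scores_einsum = torch.einsum("bcxd,bcyd->bcxy", q, k)
scores_matmul = q @ k.transpose(-1, -2)                                   # [2, 3, 512, 512]
scores_decomposed = (q.unsqueeze(4).permute(0, 1, 2, 4, 3)
                     * k.unsqueeze(4).permute(0, 1, 4, 2, 3)).view(2, 3, 512, 512)
assert torch.allclose(scores_einsum, scores_matmul)
assert torch.allclose(scores_einsum, scores_decomposed)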
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_17: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_275, [0, 0, 0, 1], 0.0); view_275 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_276: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_17, [2, 3, 512, 513]); constant_pad_nd_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_9: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_276, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_802: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_276, 0, 0, 9223372036854775807)
slice_803: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_802, 1, 0, 9223372036854775807); slice_802 = None
slice_804: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_803, 2, 0, 256); slice_803 = None
slice_805: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_804, 3, 0, 257); slice_804 = None
slice_806: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_9, 0, 0, 9223372036854775807)
slice_807: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_806, 1, 0, -1); slice_806 = None
slice_808: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_807, 2, 0, 9223372036854775807); slice_807 = None
slice_809: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_808, 3, 256, 9223372036854775807); slice_808 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_810: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_276, 0, 0, 9223372036854775807)
select_82: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_810, 1, -1); slice_810 = None
slice_811: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_82, 1, 256, 9223372036854775807); select_82 = None
slice_812: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_811, 2, 0, 257); slice_811 = None
slice_813: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_9, 0, 0, 9223372036854775807)
select_83: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_813, 1, -1); slice_813 = None
slice_814: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_83, 1, 0, 9223372036854775807); select_83 = None
slice_815: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_814, 2, 256, 9223372036854775807); slice_814 = None
slice_816: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_9, 0, 0, 9223372036854775807)
slice_817: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_816, 1, 0, -1)
slice_818: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_817, 2, 0, 9223372036854775807)
slice_scatter_194: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_818, slice_805, 3, 256, 9223372036854775807); slice_818 = slice_805 = None
slice_scatter_195: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_817, slice_scatter_194, 2, 0, 9223372036854775807); slice_817 = slice_scatter_194 = None
slice_scatter_196: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_816, slice_scatter_195, 1, 0, -1); slice_816 = slice_scatter_195 = None
slice_scatter_197: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_9, slice_scatter_196, 0, 0, 9223372036854775807); slice_scatter_196 = None
slice_819: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_197, 0, 0, 9223372036854775807)
select_84: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_819, 1, -1); slice_819 = None
slice_820: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_84, 1, 0, 9223372036854775807); select_84 = None
slice_821: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_820, 2, 256, 9223372036854775807); slice_820 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_822: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_276, 0, 0, 9223372036854775807)
slice_823: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_822, 1, 0, 9223372036854775807); slice_822 = None
slice_824: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_823, 2, -257, -1); slice_823 = None
slice_825: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_824, 3, 257, 9223372036854775807); slice_824 = None
slice_826: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_9, 0, 0, 9223372036854775807)
slice_827: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_826, 1, 1, 9223372036854775807); slice_826 = None
slice_828: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_827, 2, 0, 9223372036854775807); slice_827 = None
slice_829: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_828, 3, 0, 256); slice_828 = None
slice_830: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_197, 0, 0, 9223372036854775807)
select_85: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_830, 1, -1)
slice_831: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_85, 1, 0, 9223372036854775807)
slice_scatter_198: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_831, slice_812, 2, 256, 9223372036854775807); slice_831 = slice_812 = None
slice_scatter_199: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_85, slice_scatter_198, 1, 0, 9223372036854775807); select_85 = slice_scatter_198 = None
select_scatter_18: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_830, slice_scatter_199, 1, -1); slice_830 = slice_scatter_199 = None
slice_scatter_200: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_197, select_scatter_18, 0, 0, 9223372036854775807); slice_scatter_197 = select_scatter_18 = None
slice_832: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_200, 0, 0, 9223372036854775807)
slice_833: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_832, 1, 1, 9223372036854775807); slice_832 = None
slice_834: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_833, 2, 0, 9223372036854775807); slice_833 = None
slice_835: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_834, 3, 0, 256); slice_834 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_836: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_276, 0, 0, 9223372036854775807); view_276 = None
select_86: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_836, 1, 0); slice_836 = None
slice_837: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_86, 1, 0, 255); select_86 = None
slice_838: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_837, 2, -255, 9223372036854775807); slice_837 = None
slice_839: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_9, 0, 0, 9223372036854775807)
select_87: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_839, 1, 0); slice_839 = None
slice_840: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_87, 1, 1, 256); select_87 = None
slice_841: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_840, 2, 1, 256); slice_840 = None
slice_842: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_200, 0, 0, 9223372036854775807)
slice_843: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_842, 1, 1, 9223372036854775807)
slice_844: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_843, 2, 0, 9223372036854775807)
slice_scatter_201: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_844, slice_825, 3, 0, 256); slice_844 = slice_825 = None
slice_scatter_202: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_843, slice_scatter_201, 2, 0, 9223372036854775807); slice_843 = slice_scatter_201 = None
slice_scatter_203: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_842, slice_scatter_202, 1, 1, 9223372036854775807); slice_842 = slice_scatter_202 = None
slice_scatter_204: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_200, slice_scatter_203, 0, 0, 9223372036854775807); slice_scatter_200 = slice_scatter_203 = None
slice_845: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_204, 0, 0, 9223372036854775807)
select_88: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_845, 1, 0); slice_845 = None
slice_846: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_88, 1, 1, 256); select_88 = None
slice_847: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_846, 2, 1, 256); slice_846 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_277: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_9, [2, 1, 1024, 513]); new_empty_9 = None
transpose_170: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_277, 2, 1); view_277 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_848: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_204, 0, 0, 9223372036854775807)
select_89: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_848, 1, 0)
slice_849: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_89, 1, 1, 256)
slice_scatter_205: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_849, slice_838, 2, 1, 256); slice_849 = slice_838 = None
slice_scatter_206: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_89, slice_scatter_205, 1, 1, 256); select_89 = slice_scatter_205 = None
select_scatter_19: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_848, slice_scatter_206, 1, 0); slice_848 = slice_scatter_206 = None
slice_scatter_207: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_204, select_scatter_19, 0, 0, 9223372036854775807); slice_scatter_204 = select_scatter_19 = None
view_278: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_207, [2, 1, 1024, 513])
transpose_171: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_278, 2, 1); view_278 = None
new_ones_14: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_171, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_9: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_14); new_ones_14 = None
flip_18: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_9, [0]); tril_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_69: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_18, 0); flip_18 = None
slice_850: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_69, 1, 0, 9223372036854775807); unsqueeze_69 = None
unsqueeze_70: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_850, 2); slice_850 = None
slice_851: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_70, 3, 0, 9223372036854775807); unsqueeze_70 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_19: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_851, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_852: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_171, 0, 0, 9223372036854775807)
slice_853: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_852, 1, 0, 256); slice_852 = None
slice_854: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_853, 2, 0, 9223372036854775807); slice_853 = None
slice_855: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_854, 3, 0, 257); slice_854 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_18: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_851, [2, 256, 1, 257]); slice_851 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_18: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_18, 1); expand_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_279: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_207, [2, 1, 1024, 513])
transpose_172: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_279, 2, 1); view_279 = None
slice_856: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_172, 0, 0, 9223372036854775807); transpose_172 = None
slice_857: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_856, 1, 0, 256); slice_856 = None
slice_858: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_857, 2, 0, 9223372036854775807); slice_857 = None
slice_859: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_858, 3, 0, 257); slice_858 = None
masked_fill_27: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_859, eq_18, -inf); slice_859 = eq_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_860: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_171, 0, 0, 9223372036854775807); transpose_171 = None
slice_861: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_860, 1, -256, 9223372036854775807); slice_860 = None
slice_862: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_861, 2, 0, 9223372036854775807); slice_861 = None
slice_863: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_862, 3, -257, 9223372036854775807); slice_862 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_19: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_19, [2, 256, 1, 257]); flip_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_19: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_19, 1); expand_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_280: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_207, [2, 1, 1024, 513]); slice_scatter_207 = None
transpose_173: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_280, 2, 1); view_280 = None
slice_864: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_173, 0, 0, 9223372036854775807)
slice_865: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_864, 1, 0, 256)
slice_866: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_865, 2, 0, 9223372036854775807)
slice_scatter_208: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_866, masked_fill_27, 3, 0, 257); slice_866 = masked_fill_27 = None
slice_scatter_209: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_865, slice_scatter_208, 2, 0, 9223372036854775807); slice_865 = slice_scatter_208 = None
slice_scatter_210: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_864, slice_scatter_209, 1, 0, 256); slice_864 = slice_scatter_209 = None
slice_scatter_211: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_173, slice_scatter_210, 0, 0, 9223372036854775807); transpose_173 = slice_scatter_210 = None
transpose_174: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_211, 2, 1); slice_scatter_211 = None
view_281: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_174, [2, 4, 256, 513]); transpose_174 = None
view_282: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_281, [2, 1, 1024, 513])
transpose_175: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_282, 2, 1); view_282 = None
slice_867: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_175, 0, 0, 9223372036854775807); transpose_175 = None
slice_868: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_867, 1, -256, 9223372036854775807); slice_867 = None
slice_869: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_868, 2, 0, 9223372036854775807); slice_868 = None
slice_870: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_869, 3, -257, 9223372036854775807); slice_869 = None
masked_fill_28: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_870, eq_19, -inf); slice_870 = eq_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_283: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513])
transpose_176: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_283, 2, 1); view_283 = None
view_284: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_268, [2, 12, 1024, 513]); view_268 = None
transpose_177: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_284, 2, 1); view_284 = None
slice_871: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_177, 0, 0, 9223372036854775807)
slice_872: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_871, 1, -256, 9223372036854775807)
slice_873: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_872, 2, 0, 9223372036854775807)
slice_scatter_212: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_873, masked_fill_25, 3, -257, 9223372036854775807); slice_873 = masked_fill_25 = None
slice_scatter_213: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_872, slice_scatter_212, 2, 0, 9223372036854775807); slice_872 = slice_scatter_212 = None
slice_scatter_214: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_871, slice_scatter_213, 1, -256, 9223372036854775807); slice_871 = slice_scatter_213 = None
slice_scatter_215: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_177, slice_scatter_214, 0, 0, 9223372036854775807); transpose_177 = slice_scatter_214 = None
transpose_178: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_215, 2, 1); slice_scatter_215 = None
view_285: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_178, [24, 4, 256, 513]); transpose_178 = None
view_286: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_285, [2, 12, 1024, 513]); view_285 = None
transpose_179: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_286, 2, 1); view_286 = None
view_287: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_281, [2, 1, 1024, 513]); view_281 = None
transpose_180: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_287, 2, 1); view_287 = None
slice_874: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_180, 0, 0, 9223372036854775807)
slice_875: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_874, 1, -256, 9223372036854775807)
slice_876: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_875, 2, 0, 9223372036854775807)
slice_scatter_216: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_876, masked_fill_28, 3, -257, 9223372036854775807); slice_876 = masked_fill_28 = None
slice_scatter_217: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_875, slice_scatter_216, 2, 0, 9223372036854775807); slice_875 = slice_scatter_216 = None
slice_scatter_218: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_874, slice_scatter_217, 1, -256, 9223372036854775807); slice_874 = slice_scatter_217 = None
slice_scatter_219: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_180, slice_scatter_218, 0, 0, 9223372036854775807); transpose_180 = slice_scatter_218 = None
transpose_181: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_219, 2, 1); slice_scatter_219 = None
view_288: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_181, [2, 4, 256, 513]); transpose_181 = None
view_289: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_288, [2, 1, 1024, 513]); view_288 = None
transpose_182: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_289, 2, 1); view_289 = None
add_31: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_179, transpose_182); transpose_179 = transpose_182 = None
view_290: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_8, [2, 12, 1024, 513]); new_empty_8 = None
transpose_183: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_290, 2, 1); view_290 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_4: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_31, -1, False); add_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_877: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_878: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_877, 1, 0, 9223372036854775807); slice_877 = None
unsqueeze_71: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_878, 2); slice_878 = None
unsqueeze_72: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_71, 3); unsqueeze_71 = None
masked_fill_29: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_4, unsqueeze_72, 0.0); _softmax_4 = unsqueeze_72 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_291: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_30, [1024, 2, 12, 64]); add_30 = None
transpose_184: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_291, 0, 1); view_291 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_185: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_29, 1, 2); masked_fill_29 = None
clone_41: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_185, memory_format = torch.contiguous_format); transpose_185 = None
_unsafe_view_22: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_41, [24, 4, 256, 513]); clone_41 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_186: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_184, 1, 2); transpose_184 = None
view_292: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_186, [24, 1024, 64]); transpose_186 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_18: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_292, [0, 0, 256, 256], -1.0); view_292 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_29: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_18, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_19: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_22, [0, 257], 0.0); _unsafe_view_22 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_293: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_19, [24, 4, -1]); constant_pad_nd_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_879: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_293, 0, 0, 9223372036854775807); view_293 = None
slice_880: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_879, 1, 0, 9223372036854775807); slice_879 = None
slice_881: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_880, 2, 0, -256); slice_880 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_294: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_881, [24, 4, 256, 769]); slice_881 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_882: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_294, 0, 0, 9223372036854775807); view_294 = None
slice_883: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_882, 1, 0, 9223372036854775807)
slice_884: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_883, 2, 0, 9223372036854775807); slice_883 = None
slice_885: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_884, 3, 0, -1); slice_884 = None
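# [Illustrative sketch, not part of the captured graph] constant_pad_nd_19 through slice_885 are
# Longformer's pad-and-flatten "diagonalization" (modeling_longformer.py lines 755-767): padding
# each row by window_overlap + 1, flattening, dropping the tail, and re-viewing shifts row i of
# every chunk i columns to the right, so columns of the result line up with diagonals of the
# attention scores. A tiny reproduction with window_overlap = 2 instead of 256:
import torch
w = 2
x = torch.arange(1.0, 11.0).view(1, 1, w, 2 * w + 1)        # one chunk with rows [1..5] and [6..10]
padded = torch.nn.functional.pad(x, (0, w + 1))             # pad the last dim, as constant_pad_nd_19 does
flat = padded.view(1, 1, -1)[:, :, :-w]                     # flatten the rows, drop the trailing overlap
shifted = flat.view(1, 1, w, 3 * w + 1)[:, :, :, :-1]       # re-chunk and drop the last column
print(shifted[0, 0])                                        # tensor([[1., 2., 3., 4., 5., 0.],
                                                            #         [0., 6., 7., 8., 9., 10.]])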
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_73: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_885, 4)
permute_65: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_73, [0, 1, 2, 4, 3]); unsqueeze_73 = None
unsqueeze_74: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_29, 4); as_strided_29 = None
permute_66: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_74, [0, 1, 4, 3, 2]); unsqueeze_74 = None
permute_67: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_65, [0, 1, 2, 4, 3]); permute_65 = None
sym_size_35: Sym(24) = torch.ops.aten.sym_size(slice_882, 0); slice_882 = None
# No stacktrace found for following nodes
mul_36: Sym(96) = sym_size_35 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_36: Sym(768) = torch.ops.aten.sym_size(slice_885, 3); slice_885 = None
view_295: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_67, [mul_36, 256, sym_size_36]); permute_67 = None
permute_68: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_66, [0, 1, 4, 3, 2]); permute_66 = None
clone_42: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_68, memory_format = torch.contiguous_format); permute_68 = None
_unsafe_view_23: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_42, [mul_36, sym_size_36, 64]); clone_42 = mul_36 = sym_size_36 = None
bmm_9: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_295, _unsafe_view_23); view_295 = _unsafe_view_23 = None
view_296: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_9, [sym_size_35, 4, 256, 1, 64]); bmm_9 = None
permute_69: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_296, [0, 1, 2, 4, 3])
sym_size_37: Sym(4) = torch.ops.aten.sym_size(view_296, 1); view_296 = None
view_297: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_69, [sym_size_35, sym_size_37, 256, 64]); permute_69 = sym_size_35 = sym_size_37 = None
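# [Illustrative sketch, not part of the captured graph] view_295 / _unsafe_view_23 / bmm_9 /
# view_296 / view_297 lower torch.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) to a
# single batched matmul by folding the b and c dimensions into one batch dimension (24 * 4 = 96 in
# this graph). The same folding on small stand-in tensors:
import torch
probs = torch.randn(2, 4, 8, 6)
value = torch.randn(2, 4, 6, 5)
ctx_einsum = torch.einsum("bcwd,bcdh->bcwh", probs, value)
ctx_bmm = torch.bmm(probs.reshape(8, 8, 6), value.reshape(8, 6, 5)).view(2, 4, 8, 5)
assert torch.allclose(ctx_einsum, ctx_bmm, atol=1e-5)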
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_298: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_297, [2, 12, 1024, 64]); view_297 = None
transpose_187: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_298, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_188: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_187, 0, 1); transpose_187 = None
clone_43: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_188, memory_format = torch.contiguous_format); transpose_188 = None
_unsafe_view_24: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_43, [1024, 2, 768]); clone_43 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_189: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_24, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_27: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_71); orig_primals_71 = None
clone_44: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_189, memory_format = torch.contiguous_format); transpose_189 = None
sym_size_38: Sym(1024) = torch.ops.aten.sym_size(view_298, 2); view_298 = None
# No stacktrace found for following nodes
mul_37: Sym(2048) = 2 * sym_size_38
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_39: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_24, 2); _unsafe_view_24 = None
view_299: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_44, [mul_37, sym_size_39]); clone_44 = mul_37 = sym_size_39 = None
mm_19: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_299, t_27); view_299 = t_27 = None
view_300: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_19, [2, sym_size_38, 768]); mm_19 = sym_size_38 = None
add_32: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_300, orig_primals_72); orig_primals_72 = None
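# [Illustrative sketch, not part of the captured graph] t_27 / view_299 / mm_19 / view_300 / add_32
# are the functionalized lowering of the nn.Linear call self.dense(hidden_states): flatten the 3-D
# input to 2-D, multiply by the transposed weight, reshape back, then add the bias. A standalone
# equivalent (weight and bias below are random stand-ins for orig_primals_71/72):
import torch
x = torch.randn(2, 1024, 768)
weight, bias = torch.randn(768, 768), torch.randn(768)
out_linear = torch.nn.functional.linear(x, weight, bias)
out_lowered = (x.reshape(2 * 1024, 768) @ weight.t()).view(2, 1024, 768) + bias
assert torch.allclose(out_linear, out_lowered, atol=1e-4)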
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_33: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_32, getitem_21); add_32 = getitem_21 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_8 = torch.ops.aten.native_layer_norm.default(add_33, [768], orig_primals_73, orig_primals_74, 1e-05); add_33 = orig_primals_73 = orig_primals_74 = None
getitem_24: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_8[0]
getitem_25: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_8[1]
getitem_26: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_8[2]; native_layer_norm_8 = None
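# [Illustrative sketch, not part of the captured graph] native_layer_norm_8 returns three tensors:
# the normalized output plus the per-row mean and reciprocal standard deviation that are kept for
# the backward pass, which is why the graph unpacks getitem_24/25/26. A standalone check of that
# decomposition with stand-in inputs:
import torch
x = torch.randn(2, 1024, 768)
weight, bias = torch.ones(768), torch.zeros(768)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [768], weight, bias, 1e-05)
assert mean.shape == (2, 1024, 1) and rstd.shape == (2, 1024, 1)
assert torch.allclose(out, torch.nn.functional.layer_norm(x, [768], weight, bias, 1e-05), atol=1e-5)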
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_28: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_75); orig_primals_75 = None
sym_size_40: Sym(1024) = torch.ops.aten.sym_size(view_300, 1); view_300 = None
# No stacktrace found for following nodes
mul_38: Sym(2048) = 2 * sym_size_40
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_301: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_24, [mul_38, 768]); mul_38 = None
addmm_8: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_76, view_301, t_28); orig_primals_76 = view_301 = t_28 = None
view_302: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_8, [2, sym_size_40, 3072]); addmm_8 = sym_size_40 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_4: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_302)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_29: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_77); orig_primals_77 = None
sym_size_41: Sym(1024) = torch.ops.aten.sym_size(view_302, 1); view_302 = None
# No stacktrace found for following nodes
mul_39: Sym(2048) = 2 * sym_size_41
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_303: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_4, [mul_39, 3072]); gelu_4 = mul_39 = None
addmm_9: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_78, view_303, t_29); orig_primals_78 = view_303 = t_29 = None
view_304: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_9, [2, sym_size_41, 768]); addmm_9 = sym_size_41 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_34: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_304, getitem_24); getitem_24 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_9 = torch.ops.aten.native_layer_norm.default(add_34, [768], orig_primals_79, orig_primals_80, 1e-05); add_34 = orig_primals_79 = orig_primals_80 = None
getitem_27: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_9[0]
getitem_28: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_9[1]
getitem_29: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_9[2]; native_layer_norm_9 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_190: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_27, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_30: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_81); orig_primals_81 = None
clone_45: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_190, memory_format = torch.contiguous_format)
sym_size_42: Sym(1024) = torch.ops.aten.sym_size(view_304, 1); view_304 = None
# No stacktrace found for following nodes
mul_40: Sym(2048) = sym_size_42 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_305: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_45, [mul_40, 768]); clone_45 = mul_40 = None
mm_20: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_305, t_30); view_305 = t_30 = None
view_306: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_20, [sym_size_42, 2, 768]); mm_20 = None
add_35: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_306, orig_primals_82); view_306 = orig_primals_82 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_31: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_83); orig_primals_83 = None
clone_46: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_190, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_41: Sym(2048) = sym_size_42 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_307: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_46, [mul_41, 768]); clone_46 = mul_41 = None
mm_21: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_307, t_31); view_307 = t_31 = None
view_308: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_21, [sym_size_42, 2, 768]); mm_21 = None
add_36: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_308, orig_primals_84); view_308 = orig_primals_84 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_32: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_85); orig_primals_85 = None
clone_47: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_190, memory_format = torch.contiguous_format); transpose_190 = None
# No stacktrace found for following nodes
mul_42: Sym(2048) = sym_size_42 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_309: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_47, [mul_42, 768]); clone_47 = mul_42 = None
mm_22: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_309, t_32); view_309 = t_32 = None
view_310: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_22, [sym_size_42, 2, 768]); mm_22 = sym_size_42 = None
add_37: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_310, orig_primals_86); view_310 = orig_primals_86 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_5: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_35, 8.0); add_35 = None
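# For illustration only: the div.Tensor by 8.0 above is the 1/sqrt(head_dim) scaling from the quoted
# source line, with head_dim = 64.
import math
assert math.sqrt(64) == 8.0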
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_311: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_5, [1024, 2, 12, 64])
transpose_191: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_311, 0, 1); view_311 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_312: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_36, [1024, 2, 12, 64]); add_36 = None
transpose_192: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_312, 0, 1); view_312 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_193: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_191, 1, 2); transpose_191 = None
view_313: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_193, [24, 1024, 64]); transpose_193 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_194: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_192, 1, 2); transpose_192 = None
view_314: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_194, [24, 1024, 64]); transpose_194 = None
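# A minimal sketch (not from the captured trace; shapes assumed from the nodes above): the
# view/transpose/reshape chain splits the 768-wide query/key projections into 12 heads of 64 and folds
# batch and heads into one leading dimension before chunking.
import torch
seq_len, batch, heads, head_dim = 1024, 2, 12, 64
q = torch.randn(seq_len, batch, heads * head_dim)
q = q.view(seq_len, batch, heads, head_dim).transpose(0, 1)       # like view_311 / transpose_191: [2, 1024, 12, 64]
q = q.transpose(1, 2).reshape(batch * heads, seq_len, head_dim)   # like transpose_193 / view_313: [24, 1024, 64]
print(q.shape)  # torch.Size([24, 1024, 64])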
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_315: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_313, [24, 2, 512, 64]); view_313 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_30: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_315, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_315 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_316: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_314, [24, 2, 512, 64]); view_314 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_31: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_316, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_316 = None
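# Hedged sketch (window_overlap = 256 assumed from the trace; the tensor here is contiguous, so the
# concrete strides differ from the ones printed above): Longformer's chunking builds overlapping
# 512-wide chunks purely by re-striding, which is why it shows up as view followed by as_strided with
# the chunk stride halved.
import torch
bh, seq_len, head_dim, w = 24, 1024, 64, 256
x = torch.randn(bh, seq_len, head_dim)
chunks = x.view(bh, seq_len // (2 * w), 2 * w, head_dim)          # like view_315: [24, 2, 512, 64]
size = (bh, seq_len // w - 1, 2 * w, head_dim)                    # 3 overlapping chunks
stride = list(chunks.stride())
stride[1] = stride[1] // 2                                        # overlap consecutive chunks by w rows
overlapping = chunks.as_strided(size, stride)                     # like as_strided_30: [24, 3, 512, 64]
assert torch.equal(overlapping[:, 1, :w], overlapping[:, 0, w:])  # chunk i+1 starts halfway into chunk i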
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_75: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_30, 4); as_strided_30 = None
permute_70: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_75, [0, 1, 2, 4, 3]); unsqueeze_75 = None
unsqueeze_76: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_31, 4); as_strided_31 = None
permute_71: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_76, [0, 1, 4, 2, 3]); unsqueeze_76 = None
permute_72: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_70, [0, 1, 2, 4, 3]); permute_70 = None
view_317: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_5, [1024, 2, 12, 64]); div_5 = None
transpose_195: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_317, 0, 1); view_317 = None
transpose_196: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_195, 1, 2); transpose_195 = None
view_318: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_196, [24, 1024, 64]); transpose_196 = None
view_319: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_318, [24, 2, 512, 64]); view_318 = None
as_strided_32: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_319, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_319 = None
unsqueeze_77: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_32, 4); as_strided_32 = None
permute_73: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_77, [0, 1, 2, 4, 3]); unsqueeze_77 = None
permute_74: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_73, [0, 1, 2, 4, 3]); permute_73 = None
clone_48: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_74, memory_format = torch.contiguous_format); permute_74 = None
_unsafe_view_25: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_48, [72, 512, 64]); clone_48 = None
permute_75: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_71, [0, 1, 4, 3, 2]); permute_71 = None
clone_49: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_75, memory_format = torch.contiguous_format); permute_75 = None
_unsafe_view_26: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_49, [72, 64, 512]); clone_49 = None
bmm_10: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_25, _unsafe_view_26); _unsafe_view_25 = _unsafe_view_26 = None
view_320: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_10, [24, 3, 512, 1, 512]); bmm_10 = None
permute_76: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_320, [0, 1, 2, 4, 3]); view_320 = None
view_321: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_76, [24, 3, 512, 512]); permute_76 = None
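# A minimal sketch (random inputs, shapes assumed from the nodes above): the einsum
# "bcxd,bcyd->bcxy" from the quoted source line is a batched matmul of query chunks against
# transposed key chunks; the unsqueeze/permute/bmm/view chain above is that same contraction after
# decomposition.
import torch
query = torch.randn(24, 3, 512, 64)
key = torch.randn(24, 3, 512, 64)
via_einsum = torch.einsum("bcxd,bcyd->bcxy", query, key)
via_matmul = torch.matmul(query, key.transpose(-1, -2))
assert torch.allclose(via_einsum, via_matmul, rtol=1e-4, atol=1e-4)
print(via_einsum.shape)  # torch.Size([24, 3, 512, 512])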
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_20: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_321, [0, 0, 0, 1], 0.0); view_321 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_322: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_20, [24, 3, 512, 513]); constant_pad_nd_20 = None
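# Hedged sketch (shapes assumed from the nodes above): the pad + view pair adds one extra row and then
# reinterprets the same buffer with the last two sizes swapped, so each row of the result starts one
# element later in storage than the previous one, the usual trick for lining up attention diagonals.
import torch
scores = torch.randn(24, 3, 512, 512)
padded = torch.nn.functional.pad(scores, (0, 0, 0, 1))   # like constant_pad_nd_20: [24, 3, 513, 512]
shifted = padded.view(24, 3, 512, 513)                   # like view_322: same storage, rows skewed by one
assert torch.equal(shifted[0, 0, 0, :512], scores[0, 0, 0])   # first 512 entries of row 0 unchanged
assert shifted[0, 0, 0, 512] == scores[0, 0, 1, 0]            # entry 513 spills over from the next row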
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_10: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_322, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_886: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_322, 0, 0, 9223372036854775807)
slice_887: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_886, 1, 0, 9223372036854775807); slice_886 = None
slice_888: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_887, 2, 0, 256); slice_887 = None
slice_889: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_888, 3, 0, 257); slice_888 = None
slice_890: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
slice_891: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_890, 1, 0, -1); slice_890 = None
slice_892: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_891, 2, 0, 9223372036854775807); slice_891 = None
slice_893: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_892, 3, 256, 9223372036854775807); slice_892 = None
slice_894: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
slice_895: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_894, 1, 0, -1); slice_894 = None
slice_896: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_895, 2, 0, 9223372036854775807); slice_895 = None
slice_897: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_896, 3, 256, 9223372036854775807); slice_896 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_898: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_322, 0, 0, 9223372036854775807)
select_90: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_898, 1, -1); slice_898 = None
slice_899: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_90, 1, 256, 9223372036854775807); select_90 = None
slice_900: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_899, 2, 0, 257); slice_899 = None
slice_901: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
select_91: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_901, 1, -1); slice_901 = None
slice_902: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_91, 1, 0, 9223372036854775807); select_91 = None
slice_903: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_902, 2, 256, 9223372036854775807); slice_902 = None
slice_904: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
slice_905: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_904, 1, 0, -1)
slice_906: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_905, 2, 0, 9223372036854775807)
slice_scatter_220: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_906, slice_889, 3, 256, 9223372036854775807); slice_906 = slice_889 = None
slice_scatter_221: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_905, slice_scatter_220, 2, 0, 9223372036854775807); slice_905 = slice_scatter_220 = None
slice_scatter_222: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_904, slice_scatter_221, 1, 0, -1); slice_904 = slice_scatter_221 = None
slice_scatter_223: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_10, slice_scatter_222, 0, 0, 9223372036854775807); slice_scatter_222 = None
slice_907: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_223, 0, 0, 9223372036854775807)
select_92: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_907, 1, -1); slice_907 = None
slice_908: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_92, 1, 0, 9223372036854775807); select_92 = None
slice_909: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_908, 2, 256, 9223372036854775807); slice_908 = None
slice_910: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
select_93: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_910, 1, -1); slice_910 = None
slice_911: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_93, 1, 0, 9223372036854775807); select_93 = None
slice_912: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_911, 2, 256, 9223372036854775807); slice_911 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_913: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_322, 0, 0, 9223372036854775807)
slice_914: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_913, 1, 0, 9223372036854775807); slice_913 = None
slice_915: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_914, 2, -257, -1); slice_914 = None
slice_916: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_915, 3, 257, 9223372036854775807); slice_915 = None
slice_917: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
slice_918: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_917, 1, 1, 9223372036854775807); slice_917 = None
slice_919: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_918, 2, 0, 9223372036854775807); slice_918 = None
slice_920: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_919, 3, 0, 256); slice_919 = None
slice_921: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_223, 0, 0, 9223372036854775807)
select_94: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_921, 1, -1)
slice_922: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_94, 1, 0, 9223372036854775807)
slice_scatter_224: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_922, slice_900, 2, 256, 9223372036854775807); slice_922 = slice_900 = None
slice_scatter_225: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_94, slice_scatter_224, 1, 0, 9223372036854775807); select_94 = slice_scatter_224 = None
select_scatter_20: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_921, slice_scatter_225, 1, -1); slice_921 = slice_scatter_225 = None
slice_scatter_226: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_223, select_scatter_20, 0, 0, 9223372036854775807); slice_scatter_223 = select_scatter_20 = None
slice_923: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_226, 0, 0, 9223372036854775807)
slice_924: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_923, 1, 1, 9223372036854775807); slice_923 = None
slice_925: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_924, 2, 0, 9223372036854775807); slice_924 = None
slice_926: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_925, 3, 0, 256); slice_925 = None
slice_927: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
slice_928: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_927, 1, 1, 9223372036854775807); slice_927 = None
slice_929: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_928, 2, 0, 9223372036854775807); slice_928 = None
slice_930: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_929, 3, 0, 256); slice_929 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_931: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_322, 0, 0, 9223372036854775807); view_322 = None
select_95: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_931, 1, 0); slice_931 = None
slice_932: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_95, 1, 0, 255); select_95 = None
slice_933: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_932, 2, -255, 9223372036854775807); slice_932 = None
slice_934: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
select_96: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_934, 1, 0); slice_934 = None
slice_935: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_96, 1, 1, 256); select_96 = None
slice_936: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_935, 2, 1, 256); slice_935 = None
slice_937: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_226, 0, 0, 9223372036854775807)
slice_938: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_937, 1, 1, 9223372036854775807)
slice_939: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_938, 2, 0, 9223372036854775807)
slice_scatter_227: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_939, slice_916, 3, 0, 256); slice_939 = slice_916 = None
slice_scatter_228: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_938, slice_scatter_227, 2, 0, 9223372036854775807); slice_938 = slice_scatter_227 = None
slice_scatter_229: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_937, slice_scatter_228, 1, 1, 9223372036854775807); slice_937 = slice_scatter_228 = None
slice_scatter_230: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_226, slice_scatter_229, 0, 0, 9223372036854775807); slice_scatter_226 = slice_scatter_229 = None
slice_940: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_230, 0, 0, 9223372036854775807)
select_97: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_940, 1, 0); slice_940 = None
slice_941: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_97, 1, 1, 256); select_97 = None
slice_942: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_941, 2, 1, 256); slice_941 = None
slice_943: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_10, 0, 0, 9223372036854775807)
select_98: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_943, 1, 0); slice_943 = None
slice_944: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_98, 1, 1, 256); select_98 = None
slice_945: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_944, 2, 1, 256); slice_944 = None
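# For illustration only (toy shapes, not the model's): the long slice / slice_scatter / select_scatter
# runs above are what the in-place indexed writes quoted from lines 845-856, such as
# diagonal_attention_scores[:, :-1, :, window_overlap:] = ..., become after functionalization: each
# nested slice is rebuilt and the write is expressed as an out-of-place scatter that returns a fresh
# tensor instead of mutating new_empty_10.
import torch
dst = torch.zeros(4, 6)
src = torch.ones(4, 3)
functional = torch.slice_scatter(dst, src, dim=1, start=3)   # out-of-place form of dst[:, 3:] = src
mutated = dst.clone()
mutated[:, 3:] = src
assert torch.equal(functional, mutated)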
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_323: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513])
transpose_197: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_323, 2, 1); view_323 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_946: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_230, 0, 0, 9223372036854775807)
select_99: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_946, 1, 0)
slice_947: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_99, 1, 1, 256)
slice_scatter_231: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_947, slice_933, 2, 1, 256); slice_947 = slice_933 = None
slice_scatter_232: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_99, slice_scatter_231, 1, 1, 256); select_99 = slice_scatter_231 = None
select_scatter_21: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_946, slice_scatter_232, 1, 0); slice_946 = slice_scatter_232 = None
slice_scatter_233: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_230, select_scatter_21, 0, 0, 9223372036854775807); slice_scatter_230 = select_scatter_21 = None
view_324: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_233, [2, 12, 1024, 513])
transpose_198: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_324, 2, 1); view_324 = None
new_ones_15: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_198, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_10: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_15); new_ones_15 = None
flip_20: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_10, [0]); tril_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_78: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_20, 0); flip_20 = None
slice_948: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_78, 1, 0, 9223372036854775807); unsqueeze_78 = None
unsqueeze_79: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_948, 2); slice_948 = None
slice_949: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_79, 3, 0, 9223372036854775807); unsqueeze_79 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_21: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_949, [1, 3])
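# A minimal sketch (window_overlap = 256 assumed from the trace): the new_ones/tril/flip nodes above
# build the beginning/ending masks from the quoted lines 792-794, a flipped lower-triangular band that
# later becomes the -inf fills for positions whose window would run past the sequence edge.
import torch
w = 256
beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])   # like new_ones_15 / tril_10 / flip_20
beginning_mask = beginning_mask_2d[None, :, None, :]             # like unsqueeze_78 .. slice_949
ending_mask = beginning_mask.flip(dims=(1, 3))                   # like flip_21
print(beginning_mask.shape, ending_mask.shape)  # [1, 256, 1, 257] for both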
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_950: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_198, 0, 0, 9223372036854775807)
slice_951: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_950, 1, 0, 256); slice_950 = None
slice_952: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_951, 2, 0, 9223372036854775807); slice_951 = None
slice_953: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_952, 3, 0, 257); slice_952 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_20: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_949, [2, 256, 12, 257]); slice_949 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_20: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_20, 1); expand_20 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_325: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_233, [2, 12, 1024, 513])
transpose_199: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_325, 2, 1); view_325 = None
slice_954: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_199, 0, 0, 9223372036854775807); transpose_199 = None
slice_955: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_954, 1, 0, 256); slice_954 = None
slice_956: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_955, 2, 0, 9223372036854775807); slice_955 = None
slice_957: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_956, 3, 0, 257); slice_956 = None
masked_fill_30: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_957, eq_20, -inf); slice_957 = eq_20 = None
view_326: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513])
transpose_200: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_326, 2, 1); view_326 = None
slice_958: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_200, 0, 0, 9223372036854775807); transpose_200 = None
slice_959: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_958, 1, 0, 256); slice_958 = None
slice_960: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_959, 2, 0, 9223372036854775807); slice_959 = None
slice_961: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_960, 3, 0, 257); slice_960 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_327: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513])
transpose_201: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_327, 2, 1); view_327 = None
slice_962: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_198, 0, 0, 9223372036854775807); transpose_198 = None
slice_963: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_962, 1, -256, 9223372036854775807); slice_962 = None
slice_964: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_963, 2, 0, 9223372036854775807); slice_963 = None
slice_965: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_964, 3, -257, 9223372036854775807); slice_964 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_21: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_21, [2, 256, 12, 257]); flip_21 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_21: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_21, 1); expand_21 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_328: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_233, [2, 12, 1024, 513]); slice_scatter_233 = None
transpose_202: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_328, 2, 1); view_328 = None
slice_966: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_202, 0, 0, 9223372036854775807)
slice_967: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_966, 1, 0, 256)
slice_968: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_967, 2, 0, 9223372036854775807)
slice_scatter_234: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_968, masked_fill_30, 3, 0, 257); slice_968 = masked_fill_30 = None
slice_scatter_235: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_967, slice_scatter_234, 2, 0, 9223372036854775807); slice_967 = slice_scatter_234 = None
slice_scatter_236: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_966, slice_scatter_235, 1, 0, 256); slice_966 = slice_scatter_235 = None
slice_scatter_237: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_202, slice_scatter_236, 0, 0, 9223372036854775807); transpose_202 = slice_scatter_236 = None
transpose_203: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_237, 2, 1); slice_scatter_237 = None
view_329: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_203, [24, 4, 256, 513]); transpose_203 = None
view_330: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_329, [2, 12, 1024, 513])
transpose_204: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_330, 2, 1); view_330 = None
slice_969: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_204, 0, 0, 9223372036854775807); transpose_204 = None
slice_970: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_969, 1, -256, 9223372036854775807); slice_969 = None
slice_971: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_970, 2, 0, 9223372036854775807); slice_970 = None
slice_972: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_971, 3, -257, 9223372036854775807); slice_971 = None
masked_fill_31: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_972, eq_21, -inf); slice_972 = eq_21 = None
view_331: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513])
transpose_205: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_331, 2, 1); view_331 = None
slice_973: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_205, 0, 0, 9223372036854775807); transpose_205 = None
slice_974: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_973, 1, -256, 9223372036854775807); slice_973 = None
slice_975: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_974, 2, 0, 9223372036854775807); slice_974 = None
slice_976: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_975, 3, -257, 9223372036854775807); slice_975 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_5: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_977: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_5, 0, 0, 9223372036854775807); ne_5 = None
slice_978: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_977, 1, 0, 9223372036854775807); slice_977 = None
unsqueeze_80: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_978, 2); slice_978 = None
unsqueeze_81: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_80, 3); unsqueeze_80 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_5: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_81, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_32: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_5, unsqueeze_81, -10000.0); _to_copy_5 = unsqueeze_81 = None
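# Hedged sketch (toy mask values, not the model's real ones): the ne.Scalar / _to_copy /
# masked_fill.Scalar nodes above implement the quoted float_mask construction, where nonzero
# attention_mask entries become -10000.0 and zero entries stay 0.0.
import torch
attention_mask = torch.tensor([[0, 0, 1, -1]])
remove = (attention_mask != 0)[:, :, None, None]                       # like ne_5 + unsqueeze_80/81
float_mask = remove.to(torch.float32).masked_fill(remove, -10000.0)    # like _to_copy_5 + masked_fill_32
print(float_mask.flatten())  # 0., 0., -10000., -10000.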
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_16: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_32, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_206: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_16, 1, 2); new_ones_16 = None
view_332: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_206, [2, 1024, 1]); transpose_206 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_207: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_32, 1, 2); masked_fill_32 = None
view_333: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_207, [2, 1024, 1]); transpose_207 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_334: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_332, [2, 2, 512, 1]); view_332 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_33: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_334, [2, 3, 512, 1], [1024, 256, 1, 1]); view_334 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_335: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_333, [2, 2, 512, 1]); view_333 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_34: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_335, [2, 3, 512, 1], [1024, 256, 1, 1]); view_335 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_82: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_33, 4); as_strided_33 = None
permute_77: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_82, [0, 1, 2, 4, 3]); unsqueeze_82 = None
unsqueeze_83: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_34, 4); as_strided_34 = None
permute_78: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_83, [0, 1, 4, 2, 3]); unsqueeze_83 = None
mul_43: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_77, permute_78); permute_77 = permute_78 = None
view_336: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_43, [2, 3, 512, 512]); mul_43 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_21: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_336, [0, 0, 0, 1], 0.0); view_336 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_337: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_21, [2, 3, 512, 513]); constant_pad_nd_21 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_11: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_337, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_979: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_337, 0, 0, 9223372036854775807)
slice_980: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_979, 1, 0, 9223372036854775807); slice_979 = None
slice_981: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_980, 2, 0, 256); slice_980 = None
slice_982: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_981, 3, 0, 257); slice_981 = None
slice_983: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_11, 0, 0, 9223372036854775807)
slice_984: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_983, 1, 0, -1); slice_983 = None
slice_985: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_984, 2, 0, 9223372036854775807); slice_984 = None
slice_986: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_985, 3, 256, 9223372036854775807); slice_985 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_987: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_337, 0, 0, 9223372036854775807)
select_100: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_987, 1, -1); slice_987 = None
slice_988: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_100, 1, 256, 9223372036854775807); select_100 = None
slice_989: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_988, 2, 0, 257); slice_988 = None
slice_990: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_11, 0, 0, 9223372036854775807)
select_101: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_990, 1, -1); slice_990 = None
slice_991: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_101, 1, 0, 9223372036854775807); select_101 = None
slice_992: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_991, 2, 256, 9223372036854775807); slice_991 = None
slice_993: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_11, 0, 0, 9223372036854775807)
slice_994: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_993, 1, 0, -1)
slice_995: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_994, 2, 0, 9223372036854775807)
slice_scatter_238: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_995, slice_982, 3, 256, 9223372036854775807); slice_995 = slice_982 = None
slice_scatter_239: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_994, slice_scatter_238, 2, 0, 9223372036854775807); slice_994 = slice_scatter_238 = None
slice_scatter_240: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_993, slice_scatter_239, 1, 0, -1); slice_993 = slice_scatter_239 = None
slice_scatter_241: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_11, slice_scatter_240, 0, 0, 9223372036854775807); slice_scatter_240 = None
slice_996: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_241, 0, 0, 9223372036854775807)
select_102: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_996, 1, -1); slice_996 = None
slice_997: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_102, 1, 0, 9223372036854775807); select_102 = None
slice_998: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_997, 2, 256, 9223372036854775807); slice_997 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_999: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_337, 0, 0, 9223372036854775807)
slice_1000: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_999, 1, 0, 9223372036854775807); slice_999 = None
slice_1001: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1000, 2, -257, -1); slice_1000 = None
slice_1002: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1001, 3, 257, 9223372036854775807); slice_1001 = None
slice_1003: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_11, 0, 0, 9223372036854775807)
slice_1004: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1003, 1, 1, 9223372036854775807); slice_1003 = None
slice_1005: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1004, 2, 0, 9223372036854775807); slice_1004 = None
slice_1006: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1005, 3, 0, 256); slice_1005 = None
slice_1007: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_241, 0, 0, 9223372036854775807)
select_103: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1007, 1, -1)
slice_1008: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_103, 1, 0, 9223372036854775807)
slice_scatter_242: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1008, slice_989, 2, 256, 9223372036854775807); slice_1008 = slice_989 = None
slice_scatter_243: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_103, slice_scatter_242, 1, 0, 9223372036854775807); select_103 = slice_scatter_242 = None
select_scatter_22: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1007, slice_scatter_243, 1, -1); slice_1007 = slice_scatter_243 = None
slice_scatter_244: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_241, select_scatter_22, 0, 0, 9223372036854775807); slice_scatter_241 = select_scatter_22 = None
slice_1009: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_244, 0, 0, 9223372036854775807)
slice_1010: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1009, 1, 1, 9223372036854775807); slice_1009 = None
slice_1011: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1010, 2, 0, 9223372036854775807); slice_1010 = None
slice_1012: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1011, 3, 0, 256); slice_1011 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1013: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_337, 0, 0, 9223372036854775807); view_337 = None
select_104: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1013, 1, 0); slice_1013 = None
slice_1014: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_104, 1, 0, 255); select_104 = None
slice_1015: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1014, 2, -255, 9223372036854775807); slice_1014 = None
slice_1016: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_11, 0, 0, 9223372036854775807)
select_105: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1016, 1, 0); slice_1016 = None
slice_1017: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_105, 1, 1, 256); select_105 = None
slice_1018: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1017, 2, 1, 256); slice_1017 = None
slice_1019: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_244, 0, 0, 9223372036854775807)
slice_1020: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1019, 1, 1, 9223372036854775807)
slice_1021: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1020, 2, 0, 9223372036854775807)
slice_scatter_245: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1021, slice_1002, 3, 0, 256); slice_1021 = slice_1002 = None
slice_scatter_246: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1020, slice_scatter_245, 2, 0, 9223372036854775807); slice_1020 = slice_scatter_245 = None
slice_scatter_247: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1019, slice_scatter_246, 1, 1, 9223372036854775807); slice_1019 = slice_scatter_246 = None
slice_scatter_248: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_244, slice_scatter_247, 0, 0, 9223372036854775807); slice_scatter_244 = slice_scatter_247 = None
slice_1022: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_248, 0, 0, 9223372036854775807)
select_106: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1022, 1, 0); slice_1022 = None
slice_1023: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_106, 1, 1, 256); select_106 = None
slice_1024: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1023, 2, 1, 256); slice_1023 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_338: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_11, [2, 1, 1024, 513]); new_empty_11 = None
transpose_208: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_338, 2, 1); view_338 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1025: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_248, 0, 0, 9223372036854775807)
select_107: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1025, 1, 0)
slice_1026: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_107, 1, 1, 256)
slice_scatter_249: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1026, slice_1015, 2, 1, 256); slice_1026 = slice_1015 = None
slice_scatter_250: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_107, slice_scatter_249, 1, 1, 256); select_107 = slice_scatter_249 = None
select_scatter_23: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1025, slice_scatter_250, 1, 0); slice_1025 = slice_scatter_250 = None
slice_scatter_251: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_248, select_scatter_23, 0, 0, 9223372036854775807); slice_scatter_248 = select_scatter_23 = None
view_339: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_251, [2, 1, 1024, 513])
transpose_209: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_339, 2, 1); view_339 = None
new_ones_17: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_209, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_11: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_17); new_ones_17 = None
flip_22: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_11, [0]); tril_11 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_84: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_22, 0); flip_22 = None
slice_1027: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_84, 1, 0, 9223372036854775807); unsqueeze_84 = None
unsqueeze_85: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1027, 2); slice_1027 = None
slice_1028: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_85, 3, 0, 9223372036854775807); unsqueeze_85 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_23: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1028, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1029: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_209, 0, 0, 9223372036854775807)
slice_1030: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1029, 1, 0, 256); slice_1029 = None
slice_1031: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1030, 2, 0, 9223372036854775807); slice_1030 = None
slice_1032: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1031, 3, 0, 257); slice_1031 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_22: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1028, [2, 256, 1, 257]); slice_1028 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_22: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_22, 1); expand_22 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_340: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_251, [2, 1, 1024, 513])
transpose_210: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_340, 2, 1); view_340 = None
slice_1033: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_210, 0, 0, 9223372036854775807); transpose_210 = None
slice_1034: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1033, 1, 0, 256); slice_1033 = None
slice_1035: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1034, 2, 0, 9223372036854775807); slice_1034 = None
slice_1036: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1035, 3, 0, 257); slice_1035 = None
masked_fill_33: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1036, eq_22, -inf); slice_1036 = eq_22 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_1037: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_209, 0, 0, 9223372036854775807); transpose_209 = None
slice_1038: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1037, 1, -256, 9223372036854775807); slice_1037 = None
slice_1039: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1038, 2, 0, 9223372036854775807); slice_1038 = None
slice_1040: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1039, 3, -257, 9223372036854775807); slice_1039 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_23: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_23, [2, 256, 1, 257]); flip_23 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_23: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_23, 1); expand_23 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_341: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_251, [2, 1, 1024, 513]); slice_scatter_251 = None
transpose_211: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_341, 2, 1); view_341 = None
slice_1041: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_211, 0, 0, 9223372036854775807)
slice_1042: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1041, 1, 0, 256)
slice_1043: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1042, 2, 0, 9223372036854775807)
slice_scatter_252: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1043, masked_fill_33, 3, 0, 257); slice_1043 = masked_fill_33 = None
slice_scatter_253: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1042, slice_scatter_252, 2, 0, 9223372036854775807); slice_1042 = slice_scatter_252 = None
slice_scatter_254: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1041, slice_scatter_253, 1, 0, 256); slice_1041 = slice_scatter_253 = None
slice_scatter_255: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_211, slice_scatter_254, 0, 0, 9223372036854775807); transpose_211 = slice_scatter_254 = None
transpose_212: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_255, 2, 1); slice_scatter_255 = None
view_342: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_212, [2, 4, 256, 513]); transpose_212 = None
view_343: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_342, [2, 1, 1024, 513])
transpose_213: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_343, 2, 1); view_343 = None
slice_1044: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_213, 0, 0, 9223372036854775807); transpose_213 = None
slice_1045: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1044, 1, -256, 9223372036854775807); slice_1044 = None
slice_1046: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1045, 2, 0, 9223372036854775807); slice_1045 = None
slice_1047: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1046, 3, -257, 9223372036854775807); slice_1046 = None
masked_fill_34: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1047, eq_23, -inf); slice_1047 = eq_23 = None
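# A minimal standalone sketch (not part of the captured graph) of the corner-masking step
# decomposed above, i.e. the tril/flip/expand/masked_fill sequence quoted from
# modeling_longformer.py:792-800. All tensor names and sizes below are hypothetical
# stand-ins chosen to mirror the shapes in this trace:
import torch
scores = torch.randn(2, 1024, 1, 513)                 # (batch, seq_len, heads, 2*window+1)
w = 256                                                # window_overlap / affected_seq_len
beginning_mask = scores.new_ones(w, w + 1).tril().flip(dims=[0])[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
beginning_input = scores[:, :w, :, : w + 1]
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))
ending_input = scores[:, -w:, :, -(w + 1):]
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))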
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_344: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513])
transpose_214: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_344, 2, 1); view_344 = None
view_345: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_329, [2, 12, 1024, 513]); view_329 = None
transpose_215: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_345, 2, 1); view_345 = None
slice_1048: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_215, 0, 0, 9223372036854775807)
slice_1049: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1048, 1, -256, 9223372036854775807)
slice_1050: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1049, 2, 0, 9223372036854775807)
slice_scatter_256: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1050, masked_fill_31, 3, -257, 9223372036854775807); slice_1050 = masked_fill_31 = None
slice_scatter_257: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1049, slice_scatter_256, 2, 0, 9223372036854775807); slice_1049 = slice_scatter_256 = None
slice_scatter_258: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1048, slice_scatter_257, 1, -256, 9223372036854775807); slice_1048 = slice_scatter_257 = None
slice_scatter_259: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_215, slice_scatter_258, 0, 0, 9223372036854775807); transpose_215 = slice_scatter_258 = None
transpose_216: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_259, 2, 1); slice_scatter_259 = None
view_346: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_216, [24, 4, 256, 513]); transpose_216 = None
view_347: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_346, [2, 12, 1024, 513]); view_346 = None
transpose_217: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_347, 2, 1); view_347 = None
view_348: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_342, [2, 1, 1024, 513]); view_342 = None
transpose_218: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_348, 2, 1); view_348 = None
slice_1051: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_218, 0, 0, 9223372036854775807)
slice_1052: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1051, 1, -256, 9223372036854775807)
slice_1053: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1052, 2, 0, 9223372036854775807)
slice_scatter_260: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1053, masked_fill_34, 3, -257, 9223372036854775807); slice_1053 = masked_fill_34 = None
slice_scatter_261: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1052, slice_scatter_260, 2, 0, 9223372036854775807); slice_1052 = slice_scatter_260 = None
slice_scatter_262: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1051, slice_scatter_261, 1, -256, 9223372036854775807); slice_1051 = slice_scatter_261 = None
slice_scatter_263: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_218, slice_scatter_262, 0, 0, 9223372036854775807); transpose_218 = slice_scatter_262 = None
transpose_219: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_263, 2, 1); slice_scatter_263 = None
view_349: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_219, [2, 4, 256, 513]); transpose_219 = None
view_350: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_349, [2, 1, 1024, 513]); view_349 = None
transpose_220: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_350, 2, 1); view_350 = None
add_38: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_217, transpose_220); transpose_217 = transpose_220 = None
view_351: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_10, [2, 12, 1024, 513]); new_empty_10 = None
transpose_221: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_351, 2, 1); view_351 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_5: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_38, -1, False); add_38 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_1054: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_1055: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1054, 1, 0, 9223372036854775807); slice_1054 = None
unsqueeze_86: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1055, 2); slice_1055 = None
unsqueeze_87: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_86, 3); unsqueeze_86 = None
masked_fill_35: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_5, unsqueeze_87, 0.0); _softmax_5 = unsqueeze_87 = None
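# A minimal standalone sketch (not part of the captured graph) of the two ops above,
# quoted from modeling_longformer.py:635 and 646: softmax over the window dimension,
# then zeroing rows at padding positions. The tensors below are hypothetical stand-ins
# with the shapes seen in this trace:
import torch
attn_scores = torch.randn(2, 1024, 12, 513)               # (batch, seq_len, heads, 2*window+1)
is_index_masked = torch.zeros(2, 1024, dtype=torch.bool)  # True at padding positions
attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1)
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)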
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_352: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_37, [1024, 2, 12, 64]); add_37 = None
transpose_222: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_352, 0, 1); view_352 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_223: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_35, 1, 2); masked_fill_35 = None
clone_50: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_223, memory_format = torch.contiguous_format); transpose_223 = None
_unsafe_view_27: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_50, [24, 4, 256, 513]); clone_50 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_224: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_222, 1, 2); transpose_222 = None
view_353: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_224, [24, 1024, 64]); transpose_224 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_22: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_353, [0, 0, 256, 256], -1.0); view_353 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_35: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_22, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_22 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_23: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_27, [0, 257], 0.0); _unsafe_view_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_354: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_23, [24, 4, -1]); constant_pad_nd_23 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_1056: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_354, 0, 0, 9223372036854775807); view_354 = None
slice_1057: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1056, 1, 0, 9223372036854775807); slice_1056 = None
slice_1058: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1057, 2, 0, -256); slice_1057 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_355: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_1058, [24, 4, 256, 769]); slice_1058 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_1059: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_355, 0, 0, 9223372036854775807); view_355 = None
slice_1060: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1059, 1, 0, 9223372036854775807)
slice_1061: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1060, 2, 0, 9223372036854775807); slice_1060 = None
slice_1062: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1061, 3, 0, -1); slice_1061 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_88: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1062, 4)
permute_79: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_88, [0, 1, 2, 4, 3]); unsqueeze_88 = None
unsqueeze_89: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_35, 4); as_strided_35 = None
permute_80: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_89, [0, 1, 4, 3, 2]); unsqueeze_89 = None
permute_81: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_79, [0, 1, 2, 4, 3]); permute_79 = None
sym_size_43: Sym(24) = torch.ops.aten.sym_size(slice_1059, 0); slice_1059 = None
# No stacktrace found for following nodes
mul_44: Sym(96) = sym_size_43 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_44: Sym(768) = torch.ops.aten.sym_size(slice_1062, 3); slice_1062 = None
view_356: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_81, [mul_44, 256, sym_size_44]); permute_81 = None
permute_82: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_80, [0, 1, 4, 3, 2]); permute_80 = None
clone_51: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_82, memory_format = torch.contiguous_format); permute_82 = None
_unsafe_view_28: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_51, [mul_44, sym_size_44, 64]); clone_51 = mul_44 = sym_size_44 = None
bmm_11: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_356, _unsafe_view_28); view_356 = _unsafe_view_28 = None
view_357: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_11, [sym_size_43, 4, 256, 1, 64]); bmm_11 = None
permute_83: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_357, [0, 1, 2, 4, 3])
sym_size_45: Sym(4) = torch.ops.aten.sym_size(view_357, 1); view_357 = None
view_358: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_83, [sym_size_43, sym_size_45, 256, 64]); permute_83 = sym_size_43 = sym_size_45 = None
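# A small standalone check (not part of the captured graph) of the lowering above: the
# quoted einsum("bcwd,bcdh->bcwh", ...) from modeling_longformer.py:906 is a batched
# matmul over the flattened leading b*c dimension, which is what the permute/bmm/view
# chain computes. Sizes below are small hypothetical ones, not the ones in this trace:
import torch
b, c, w, d, h = 2, 4, 8, 16, 5
chunked_attn_probs = torch.randn(b, c, w, d)
chunked_value = torch.randn(b, c, d, h)
context_einsum = torch.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
context_bmm = torch.bmm(chunked_attn_probs.reshape(b * c, w, d),
                        chunked_value.reshape(b * c, d, h)).view(b, c, w, h)
torch.testing.assert_close(context_einsum, context_bmm)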
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_359: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_358, [2, 12, 1024, 64]); view_358 = None
transpose_225: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_359, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_226: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_225, 0, 1); transpose_225 = None
clone_52: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_226, memory_format = torch.contiguous_format); transpose_226 = None
_unsafe_view_29: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_52, [1024, 2, 768]); clone_52 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_227: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_29, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_33: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_87); orig_primals_87 = None
clone_53: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_227, memory_format = torch.contiguous_format); transpose_227 = None
sym_size_46: Sym(1024) = torch.ops.aten.sym_size(view_359, 2); view_359 = None
# No stacktrace found for following nodes
mul_45: Sym(2048) = 2 * sym_size_46
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_47: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_29, 2); _unsafe_view_29 = None
view_360: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_53, [mul_45, sym_size_47]); clone_53 = mul_45 = sym_size_47 = None
mm_23: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_360, t_33); view_360 = t_33 = None
view_361: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_23, [2, sym_size_46, 768]); mm_23 = sym_size_46 = None
add_39: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_361, orig_primals_88); orig_primals_88 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_40: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_39, getitem_27); add_39 = getitem_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_10 = torch.ops.aten.native_layer_norm.default(add_40, [768], orig_primals_89, orig_primals_90, 1e-05); add_40 = orig_primals_89 = orig_primals_90 = None
getitem_30: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_10[0]
getitem_31: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_10[1]
getitem_32: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_10[2]; native_layer_norm_10 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_34: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_91); orig_primals_91 = None
sym_size_48: Sym(1024) = torch.ops.aten.sym_size(view_361, 1); view_361 = None
# No stacktrace found for following nodes
mul_46: Sym(2048) = 2 * sym_size_48
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_362: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_30, [mul_46, 768]); mul_46 = None
addmm_10: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_92, view_362, t_34); orig_primals_92 = view_362 = t_34 = None
view_363: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_10, [2, sym_size_48, 3072]); addmm_10 = sym_size_48 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_5: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_363)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_35: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_93); orig_primals_93 = None
sym_size_49: Sym(1024) = torch.ops.aten.sym_size(view_363, 1); view_363 = None
# No stacktrace found for following nodes
mul_47: Sym(2048) = 2 * sym_size_49
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_364: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_5, [mul_47, 3072]); gelu_5 = mul_47 = None
addmm_11: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_94, view_364, t_35); orig_primals_94 = view_364 = t_35 = None
view_365: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_11, [2, sym_size_49, 768]); addmm_11 = sym_size_49 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_41: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_365, getitem_30); getitem_30 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_11 = torch.ops.aten.native_layer_norm.default(add_41, [768], orig_primals_95, orig_primals_96, 1e-05); add_41 = orig_primals_95 = orig_primals_96 = None
getitem_33: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_11[0]
getitem_34: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_11[1]
getitem_35: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_11[2]; native_layer_norm_11 = None
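# A minimal module-level sketch (not part of the captured graph) of the
# addmm/gelu/addmm/native_layer_norm block decomposed above, following the quoted
# modeling_longformer.py:1182, 1196, 1198 and activations.py:56: intermediate dense
# 768->3072, GELU, output dense 3072->768, residual add, LayerNorm. The module
# instances below are hypothetical stand-ins, not the traced parameters:
import torch
hidden_states = torch.randn(2, 1024, 768)
intermediate_dense = torch.nn.Linear(768, 3072)
output_dense = torch.nn.Linear(3072, 768)
layer_norm = torch.nn.LayerNorm(768, eps=1e-05)
layer_output = layer_norm(
    output_dense(torch.nn.functional.gelu(intermediate_dense(hidden_states))) + hidden_states
)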
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_228: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_33, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_36: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_97); orig_primals_97 = None
clone_54: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_228, memory_format = torch.contiguous_format)
sym_size_50: Sym(1024) = torch.ops.aten.sym_size(view_365, 1); view_365 = None
# No stacktrace found for following nodes
mul_48: Sym(2048) = sym_size_50 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_366: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_54, [mul_48, 768]); clone_54 = mul_48 = None
mm_24: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_366, t_36); view_366 = t_36 = None
view_367: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_24, [sym_size_50, 2, 768]); mm_24 = None
add_42: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_367, orig_primals_98); view_367 = orig_primals_98 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_37: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_99); orig_primals_99 = None
clone_55: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_228, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_49: Sym(2048) = sym_size_50 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_368: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_55, [mul_49, 768]); clone_55 = mul_49 = None
mm_25: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_368, t_37); view_368 = t_37 = None
view_369: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_25, [sym_size_50, 2, 768]); mm_25 = None
add_43: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_369, orig_primals_100); view_369 = orig_primals_100 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_38: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_101); orig_primals_101 = None
clone_56: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_228, memory_format = torch.contiguous_format); transpose_228 = None
# No stacktrace found for following nodes
mul_50: Sym(2048) = sym_size_50 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_370: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_56, [mul_50, 768]); clone_56 = mul_50 = None
mm_26: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_370, t_38); view_370 = t_38 = None
view_371: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_26, [sym_size_50, 2, 768]); mm_26 = sym_size_50 = None
add_44: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_371, orig_primals_102); view_371 = orig_primals_102 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_6: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_42, 8.0); add_42 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_372: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_6, [1024, 2, 12, 64])
transpose_229: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_372, 0, 1); view_372 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_373: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_43, [1024, 2, 12, 64]); add_43 = None
transpose_230: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_373, 0, 1); view_373 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_231: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_229, 1, 2); transpose_229 = None
view_374: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_231, [24, 1024, 64]); transpose_231 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_232: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_230, 1, 2); transpose_230 = None
view_375: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_232, [24, 1024, 64]); transpose_232 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_376: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_374, [24, 2, 512, 64]); view_374 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_36: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_376, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_376 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_377: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_375, [24, 2, 512, 64]); view_375 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_37: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_377, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_377 = None
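# A minimal standalone sketch (not part of the captured graph) of the overlapping-chunk
# views built above via the quoted modeling_longformer.py:775 and 788: the sequence is
# first viewed as non-overlapping chunks of 2*window, then as_strided with the chunk
# stride halved re-exposes it as chunks overlapping by window (here 2 chunks become 3,
# matching the [24, 2, 512, 64] -> [24, 3, 512, 64] views in this trace). A contiguous
# hypothetical tensor is used for simplicity:
import torch
bh, seq_len, head_dim, w = 24, 1024, 64, 256
x = torch.randn(bh, seq_len, head_dim)
x = x.view(bh, seq_len // (2 * w), 2 * w, head_dim)   # (24, 2, 512, 64), no overlap
chunk_size = [bh, seq_len // w - 1, 2 * w, head_dim]   # (24, 3, 512, 64), overlap of w
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
chunks = x.as_strided(size=chunk_size, stride=chunk_stride)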
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_90: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_36, 4); as_strided_36 = None
permute_84: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_90, [0, 1, 2, 4, 3]); unsqueeze_90 = None
unsqueeze_91: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_37, 4); as_strided_37 = None
permute_85: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_91, [0, 1, 4, 2, 3]); unsqueeze_91 = None
permute_86: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_84, [0, 1, 2, 4, 3]); permute_84 = None
view_378: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_6, [1024, 2, 12, 64]); div_6 = None
transpose_233: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_378, 0, 1); view_378 = None
transpose_234: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_233, 1, 2); transpose_233 = None
view_379: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_234, [24, 1024, 64]); transpose_234 = None
view_380: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_379, [24, 2, 512, 64]); view_379 = None
as_strided_38: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_380, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_380 = None
unsqueeze_92: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_38, 4); as_strided_38 = None
permute_87: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_92, [0, 1, 2, 4, 3]); unsqueeze_92 = None
permute_88: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_87, [0, 1, 2, 4, 3]); permute_87 = None
clone_57: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_88, memory_format = torch.contiguous_format); permute_88 = None
_unsafe_view_30: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_57, [72, 512, 64]); clone_57 = None
permute_89: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_85, [0, 1, 4, 3, 2]); permute_85 = None
clone_58: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_89, memory_format = torch.contiguous_format); permute_89 = None
_unsafe_view_31: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_58, [72, 64, 512]); clone_58 = None
bmm_12: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_30, _unsafe_view_31); _unsafe_view_30 = _unsafe_view_31 = None
view_381: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_12, [24, 3, 512, 1, 512]); bmm_12 = None
permute_90: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_381, [0, 1, 2, 4, 3]); view_381 = None
view_382: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_90, [24, 3, 512, 512]); permute_90 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_24: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_382, [0, 0, 0, 1], 0.0); view_382 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_383: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_24, [24, 3, 512, 513]); constant_pad_nd_24 = None
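# A tiny standalone sketch (not part of the captured graph) of the pad+view trick
# decomposed above, quoted from modeling_longformer.py:713 and 716: padding one extra
# row and then viewing the same buffer with the last two sizes swapped offsets each row
# by one element relative to the previous one. Hypothetical toy sizes:
import torch
x = torch.arange(9.).view(1, 1, 3, 3)
padded = torch.nn.functional.pad(x, (0, 0, 0, 1))   # (1, 1, 4, 3): one extra zero row
shifted = padded.view(1, 1, 3, 4)                    # each row now shifted by one column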
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_12: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_383, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1063: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_383, 0, 0, 9223372036854775807)
slice_1064: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1063, 1, 0, 9223372036854775807); slice_1063 = None
slice_1065: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1064, 2, 0, 256); slice_1064 = None
slice_1066: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1065, 3, 0, 257); slice_1065 = None
slice_1067: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
slice_1068: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1067, 1, 0, -1); slice_1067 = None
slice_1069: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1068, 2, 0, 9223372036854775807); slice_1068 = None
slice_1070: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1069, 3, 256, 9223372036854775807); slice_1069 = None
slice_1071: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
slice_1072: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1071, 1, 0, -1); slice_1071 = None
slice_1073: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1072, 2, 0, 9223372036854775807); slice_1072 = None
slice_1074: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1073, 3, 256, 9223372036854775807); slice_1073 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1075: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_383, 0, 0, 9223372036854775807)
select_108: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1075, 1, -1); slice_1075 = None
slice_1076: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_108, 1, 256, 9223372036854775807); select_108 = None
slice_1077: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1076, 2, 0, 257); slice_1076 = None
slice_1078: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
select_109: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1078, 1, -1); slice_1078 = None
slice_1079: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_109, 1, 0, 9223372036854775807); select_109 = None
slice_1080: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1079, 2, 256, 9223372036854775807); slice_1079 = None
slice_1081: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
slice_1082: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1081, 1, 0, -1)
slice_1083: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1082, 2, 0, 9223372036854775807)
slice_scatter_264: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1083, slice_1066, 3, 256, 9223372036854775807); slice_1083 = slice_1066 = None
slice_scatter_265: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1082, slice_scatter_264, 2, 0, 9223372036854775807); slice_1082 = slice_scatter_264 = None
slice_scatter_266: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1081, slice_scatter_265, 1, 0, -1); slice_1081 = slice_scatter_265 = None
slice_scatter_267: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_12, slice_scatter_266, 0, 0, 9223372036854775807); slice_scatter_266 = None
slice_1084: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_267, 0, 0, 9223372036854775807)
select_110: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1084, 1, -1); slice_1084 = None
slice_1085: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_110, 1, 0, 9223372036854775807); select_110 = None
slice_1086: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1085, 2, 256, 9223372036854775807); slice_1085 = None
slice_1087: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
select_111: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1087, 1, -1); slice_1087 = None
slice_1088: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_111, 1, 0, 9223372036854775807); select_111 = None
slice_1089: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1088, 2, 256, 9223372036854775807); slice_1088 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1090: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_383, 0, 0, 9223372036854775807)
slice_1091: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1090, 1, 0, 9223372036854775807); slice_1090 = None
slice_1092: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1091, 2, -257, -1); slice_1091 = None
slice_1093: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1092, 3, 257, 9223372036854775807); slice_1092 = None
slice_1094: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
slice_1095: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1094, 1, 1, 9223372036854775807); slice_1094 = None
slice_1096: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1095, 2, 0, 9223372036854775807); slice_1095 = None
slice_1097: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1096, 3, 0, 256); slice_1096 = None
slice_1098: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_267, 0, 0, 9223372036854775807)
select_112: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1098, 1, -1)
slice_1099: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_112, 1, 0, 9223372036854775807)
slice_scatter_268: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1099, slice_1077, 2, 256, 9223372036854775807); slice_1099 = slice_1077 = None
slice_scatter_269: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_112, slice_scatter_268, 1, 0, 9223372036854775807); select_112 = slice_scatter_268 = None
select_scatter_24: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1098, slice_scatter_269, 1, -1); slice_1098 = slice_scatter_269 = None
slice_scatter_270: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_267, select_scatter_24, 0, 0, 9223372036854775807); slice_scatter_267 = select_scatter_24 = None
slice_1100: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_270, 0, 0, 9223372036854775807)
slice_1101: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1100, 1, 1, 9223372036854775807); slice_1100 = None
slice_1102: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1101, 2, 0, 9223372036854775807); slice_1101 = None
slice_1103: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1102, 3, 0, 256); slice_1102 = None
slice_1104: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
slice_1105: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1104, 1, 1, 9223372036854775807); slice_1104 = None
slice_1106: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1105, 2, 0, 9223372036854775807); slice_1105 = None
slice_1107: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1106, 3, 0, 256); slice_1106 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1108: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_383, 0, 0, 9223372036854775807); view_383 = None
select_113: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1108, 1, 0); slice_1108 = None
slice_1109: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_113, 1, 0, 255); select_113 = None
slice_1110: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1109, 2, -255, 9223372036854775807); slice_1109 = None
slice_1111: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
select_114: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1111, 1, 0); slice_1111 = None
slice_1112: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_114, 1, 1, 256); select_114 = None
slice_1113: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1112, 2, 1, 256); slice_1112 = None
slice_1114: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_270, 0, 0, 9223372036854775807)
slice_1115: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1114, 1, 1, 9223372036854775807)
slice_1116: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1115, 2, 0, 9223372036854775807)
slice_scatter_271: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1116, slice_1093, 3, 0, 256); slice_1116 = slice_1093 = None
slice_scatter_272: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1115, slice_scatter_271, 2, 0, 9223372036854775807); slice_1115 = slice_scatter_271 = None
slice_scatter_273: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1114, slice_scatter_272, 1, 1, 9223372036854775807); slice_1114 = slice_scatter_272 = None
slice_scatter_274: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_270, slice_scatter_273, 0, 0, 9223372036854775807); slice_scatter_270 = slice_scatter_273 = None
slice_1117: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_274, 0, 0, 9223372036854775807)
select_115: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1117, 1, 0); slice_1117 = None
slice_1118: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_115, 1, 1, 256); select_115 = None
slice_1119: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1118, 2, 1, 256); slice_1118 = None
slice_1120: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_12, 0, 0, 9223372036854775807)
select_116: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1120, 1, 0); slice_1120 = None
slice_1121: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_116, 1, 1, 256); select_116 = None
slice_1122: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1121, 2, 1, 256); slice_1121 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_384: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513])
transpose_235: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_384, 2, 1); view_384 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1123: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_274, 0, 0, 9223372036854775807)
select_117: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1123, 1, 0)
slice_1124: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_117, 1, 1, 256)
slice_scatter_275: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1124, slice_1110, 2, 1, 256); slice_1124 = slice_1110 = None
slice_scatter_276: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_117, slice_scatter_275, 1, 1, 256); select_117 = slice_scatter_275 = None
select_scatter_25: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1123, slice_scatter_276, 1, 0); slice_1123 = slice_scatter_276 = None
slice_scatter_277: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_274, select_scatter_25, 0, 0, 9223372036854775807); slice_scatter_274 = select_scatter_25 = None
view_385: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_277, [2, 12, 1024, 513])
transpose_236: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_385, 2, 1); view_385 = None
new_ones_18: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_236, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_12: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_18); new_ones_18 = None
flip_24: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_12, [0]); tril_12 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_93: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_24, 0); flip_24 = None
slice_1125: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_93, 1, 0, 9223372036854775807); unsqueeze_93 = None
unsqueeze_94: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1125, 2); slice_1125 = None
slice_1126: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_94, 3, 0, 9223372036854775807); unsqueeze_94 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_25: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1126, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1127: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_236, 0, 0, 9223372036854775807)
slice_1128: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1127, 1, 0, 256); slice_1127 = None
slice_1129: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1128, 2, 0, 9223372036854775807); slice_1128 = None
slice_1130: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1129, 3, 0, 257); slice_1129 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_24: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1126, [2, 256, 12, 257]); slice_1126 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_24: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_24, 1); expand_24 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_386: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_277, [2, 12, 1024, 513])
transpose_237: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_386, 2, 1); view_386 = None
slice_1131: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_237, 0, 0, 9223372036854775807); transpose_237 = None
slice_1132: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1131, 1, 0, 256); slice_1131 = None
slice_1133: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1132, 2, 0, 9223372036854775807); slice_1132 = None
slice_1134: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1133, 3, 0, 257); slice_1133 = None
masked_fill_36: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1134, eq_24, -inf); slice_1134 = eq_24 = None
view_387: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513])
transpose_238: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_387, 2, 1); view_387 = None
slice_1135: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_238, 0, 0, 9223372036854775807); transpose_238 = None
slice_1136: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1135, 1, 0, 256); slice_1135 = None
slice_1137: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1136, 2, 0, 9223372036854775807); slice_1136 = None
slice_1138: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1137, 3, 0, 257); slice_1137 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_388: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513])
transpose_239: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_388, 2, 1); view_388 = None
slice_1139: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_236, 0, 0, 9223372036854775807); transpose_236 = None
slice_1140: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1139, 1, -256, 9223372036854775807); slice_1139 = None
slice_1141: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1140, 2, 0, 9223372036854775807); slice_1140 = None
slice_1142: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1141, 3, -257, 9223372036854775807); slice_1141 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_25: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_25, [2, 256, 12, 257]); flip_25 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_25: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_25, 1); expand_25 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_389: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_277, [2, 12, 1024, 513]); slice_scatter_277 = None
transpose_240: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_389, 2, 1); view_389 = None
slice_1143: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_240, 0, 0, 9223372036854775807)
slice_1144: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1143, 1, 0, 256)
slice_1145: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1144, 2, 0, 9223372036854775807)
slice_scatter_278: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1145, masked_fill_36, 3, 0, 257); slice_1145 = masked_fill_36 = None
slice_scatter_279: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1144, slice_scatter_278, 2, 0, 9223372036854775807); slice_1144 = slice_scatter_278 = None
slice_scatter_280: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1143, slice_scatter_279, 1, 0, 256); slice_1143 = slice_scatter_279 = None
slice_scatter_281: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_240, slice_scatter_280, 0, 0, 9223372036854775807); transpose_240 = slice_scatter_280 = None
transpose_241: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_281, 2, 1); slice_scatter_281 = None
view_390: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_241, [24, 4, 256, 513]); transpose_241 = None
view_391: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_390, [2, 12, 1024, 513])
transpose_242: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_391, 2, 1); view_391 = None
slice_1146: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_242, 0, 0, 9223372036854775807); transpose_242 = None
slice_1147: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1146, 1, -256, 9223372036854775807); slice_1146 = None
slice_1148: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1147, 2, 0, 9223372036854775807); slice_1147 = None
slice_1149: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1148, 3, -257, 9223372036854775807); slice_1148 = None
masked_fill_37: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1149, eq_25, -inf); slice_1149 = eq_25 = None
view_392: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513])
transpose_243: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_392, 2, 1); view_392 = None
slice_1150: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_243, 0, 0, 9223372036854775807); transpose_243 = None
slice_1151: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1150, 1, -256, 9223372036854775807); slice_1150 = None
slice_1152: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1151, 2, 0, 9223372036854775807); slice_1151 = None
slice_1153: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1152, 3, -257, 9223372036854775807); slice_1152 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_6: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_1154: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_6, 0, 0, 9223372036854775807); ne_6 = None
slice_1155: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1154, 1, 0, 9223372036854775807); slice_1154 = None
unsqueeze_95: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1155, 2); slice_1155 = None
unsqueeze_96: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_95, 3); unsqueeze_95 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_6: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_96, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_38: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_6, unsqueeze_96, -10000.0); _to_copy_6 = unsqueeze_96 = None
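    # A minimal standalone sketch of the boolean-to-additive mask pattern traced above (tensor
    # values here are illustrative, not taken from this run): nonzero attention_mask entries
    # (padding and global-attention positions in Longformer) become a -10000.0 additive bias.
    import torch
    attention_mask = torch.tensor([[0.0, 0.0, 10000.0, -10000.0]])            # illustrative values only
    remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
    float_mask = remove_from_windowed_attention_mask.to(torch.float32).masked_fill(
        remove_from_windowed_attention_mask, -10000.0
    )
    # float_mask is 0.0 where attention_mask == 0 and -10000.0 elsewhere, so adding it to the
    # raw attention scores suppresses those positions before the softmax.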
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_19: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_38, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_244: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_19, 1, 2); new_ones_19 = None
view_393: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_244, [2, 1024, 1]); transpose_244 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_245: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_38, 1, 2); masked_fill_38 = None
view_394: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_245, [2, 1024, 1]); transpose_245 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_395: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_393, [2, 2, 512, 1]); view_393 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_39: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_395, [2, 3, 512, 1], [1024, 256, 1, 1]); view_395 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_396: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_394, [2, 2, 512, 1]); view_394 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_40: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_396, [2, 3, 512, 1], [1024, 256, 1, 1]); view_396 = None
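    # The view + as_strided pairs above are the overlapping-chunk trick quoted from _chunk
    # (modeling_longformer.py:775/788): 512-wide non-overlapping chunks are re-strided into
    # chunks that overlap by half a window. A minimal sketch with the shapes from the
    # annotations above (random data; only the stride arithmetic matters):
    import torch
    x = torch.randn(2, 1024, 1)                    # (batch_size * num_heads, seq_len, head_dim) on this mask path
    x = x.view(2, 1024 // 512, 512, 1)             # two non-overlapping chunks of width 512
    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2         # step by half a chunk -> 256-position overlap
    chunks = x.as_strided(size=(2, 3, 512, 1), stride=chunk_stride)
    # chunks[:, i] starts at position i * 256, matching as_strided_39/40's [1024, 256, 1, 1] strides.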
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_97: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_39, 4); as_strided_39 = None
permute_91: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_97, [0, 1, 2, 4, 3]); unsqueeze_97 = None
unsqueeze_98: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_40, 4); as_strided_40 = None
permute_92: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_98, [0, 1, 4, 2, 3]); unsqueeze_98 = None
mul_51: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_91, permute_92); permute_91 = permute_92 = None
view_397: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_51, [2, 3, 512, 512]); mul_51 = None
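    # The unsqueeze/permute/mul/view chain above is the decomposed form of the quoted
    # torch.einsum("bcxd,bcyd->bcxy", (query, key)); with d == 1 on this mask path the
    # contraction degenerates to a broadcasted multiply. The same contraction written as a
    # batched matmul (a sketch, not part of the traced graph):
    import torch
    query = torch.randn(2, 3, 512, 1)              # (b, c, x, d), shapes from the annotations above
    key = torch.randn(2, 3, 512, 1)                # (b, c, y, d)
    scores_einsum = torch.einsum("bcxd,bcyd->bcxy", query, key)
    scores_matmul = torch.matmul(query, key.transpose(-1, -2))
    assert torch.allclose(scores_einsum, scores_matmul)    # both (2, 3, 512, 512), as in view_397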
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_25: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_397, [0, 0, 0, 1], 0.0); view_397 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_398: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_25, [2, 3, 512, 513]); constant_pad_nd_25 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_13: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_398, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1156: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_398, 0, 0, 9223372036854775807)
slice_1157: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1156, 1, 0, 9223372036854775807); slice_1156 = None
slice_1158: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1157, 2, 0, 256); slice_1157 = None
slice_1159: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1158, 3, 0, 257); slice_1158 = None
slice_1160: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_13, 0, 0, 9223372036854775807)
slice_1161: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1160, 1, 0, -1); slice_1160 = None
slice_1162: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1161, 2, 0, 9223372036854775807); slice_1161 = None
slice_1163: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1162, 3, 256, 9223372036854775807); slice_1162 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1164: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_398, 0, 0, 9223372036854775807)
select_118: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1164, 1, -1); slice_1164 = None
slice_1165: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_118, 1, 256, 9223372036854775807); select_118 = None
slice_1166: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1165, 2, 0, 257); slice_1165 = None
slice_1167: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_13, 0, 0, 9223372036854775807)
select_119: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1167, 1, -1); slice_1167 = None
slice_1168: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_119, 1, 0, 9223372036854775807); select_119 = None
slice_1169: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1168, 2, 256, 9223372036854775807); slice_1168 = None
slice_1170: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_13, 0, 0, 9223372036854775807)
slice_1171: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1170, 1, 0, -1)
slice_1172: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1171, 2, 0, 9223372036854775807)
slice_scatter_282: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1172, slice_1159, 3, 256, 9223372036854775807); slice_1172 = slice_1159 = None
slice_scatter_283: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1171, slice_scatter_282, 2, 0, 9223372036854775807); slice_1171 = slice_scatter_282 = None
slice_scatter_284: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1170, slice_scatter_283, 1, 0, -1); slice_1170 = slice_scatter_283 = None
slice_scatter_285: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_13, slice_scatter_284, 0, 0, 9223372036854775807); slice_scatter_284 = None
slice_1173: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_285, 0, 0, 9223372036854775807)
select_120: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1173, 1, -1); slice_1173 = None
slice_1174: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_120, 1, 0, 9223372036854775807); select_120 = None
slice_1175: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1174, 2, 256, 9223372036854775807); slice_1174 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1176: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_398, 0, 0, 9223372036854775807)
slice_1177: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1176, 1, 0, 9223372036854775807); slice_1176 = None
slice_1178: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1177, 2, -257, -1); slice_1177 = None
slice_1179: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1178, 3, 257, 9223372036854775807); slice_1178 = None
slice_1180: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_13, 0, 0, 9223372036854775807)
slice_1181: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1180, 1, 1, 9223372036854775807); slice_1180 = None
slice_1182: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1181, 2, 0, 9223372036854775807); slice_1181 = None
slice_1183: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1182, 3, 0, 256); slice_1182 = None
slice_1184: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_285, 0, 0, 9223372036854775807)
select_121: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1184, 1, -1)
slice_1185: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_121, 1, 0, 9223372036854775807)
slice_scatter_286: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1185, slice_1166, 2, 256, 9223372036854775807); slice_1185 = slice_1166 = None
slice_scatter_287: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_121, slice_scatter_286, 1, 0, 9223372036854775807); select_121 = slice_scatter_286 = None
select_scatter_26: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1184, slice_scatter_287, 1, -1); slice_1184 = slice_scatter_287 = None
slice_scatter_288: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_285, select_scatter_26, 0, 0, 9223372036854775807); slice_scatter_285 = select_scatter_26 = None
slice_1186: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_288, 0, 0, 9223372036854775807)
slice_1187: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1186, 1, 1, 9223372036854775807); slice_1186 = None
slice_1188: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1187, 2, 0, 9223372036854775807); slice_1187 = None
slice_1189: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1188, 3, 0, 256); slice_1188 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1190: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_398, 0, 0, 9223372036854775807); view_398 = None
select_122: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1190, 1, 0); slice_1190 = None
slice_1191: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_122, 1, 0, 255); select_122 = None
slice_1192: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1191, 2, -255, 9223372036854775807); slice_1191 = None
slice_1193: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_13, 0, 0, 9223372036854775807)
select_123: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1193, 1, 0); slice_1193 = None
slice_1194: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_123, 1, 1, 256); select_123 = None
slice_1195: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1194, 2, 1, 256); slice_1194 = None
slice_1196: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_288, 0, 0, 9223372036854775807)
slice_1197: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1196, 1, 1, 9223372036854775807)
slice_1198: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1197, 2, 0, 9223372036854775807)
slice_scatter_289: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1198, slice_1179, 3, 0, 256); slice_1198 = slice_1179 = None
slice_scatter_290: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1197, slice_scatter_289, 2, 0, 9223372036854775807); slice_1197 = slice_scatter_289 = None
slice_scatter_291: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1196, slice_scatter_290, 1, 1, 9223372036854775807); slice_1196 = slice_scatter_290 = None
slice_scatter_292: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_288, slice_scatter_291, 0, 0, 9223372036854775807); slice_scatter_288 = slice_scatter_291 = None
slice_1199: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_292, 0, 0, 9223372036854775807)
select_124: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1199, 1, 0); slice_1199 = None
slice_1200: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_124, 1, 1, 256); select_124 = None
slice_1201: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1200, 2, 1, 256); slice_1200 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_399: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_13, [2, 1, 1024, 513]); new_empty_13 = None
transpose_246: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_399, 2, 1); view_399 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1202: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_292, 0, 0, 9223372036854775807)
select_125: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1202, 1, 0)
slice_1203: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_125, 1, 1, 256)
slice_scatter_293: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1203, slice_1192, 2, 1, 256); slice_1203 = slice_1192 = None
slice_scatter_294: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_125, slice_scatter_293, 1, 1, 256); select_125 = slice_scatter_293 = None
select_scatter_27: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1202, slice_scatter_294, 1, 0); slice_1202 = slice_scatter_294 = None
slice_scatter_295: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_292, select_scatter_27, 0, 0, 9223372036854775807); slice_scatter_292 = select_scatter_27 = None
view_400: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_295, [2, 1, 1024, 513])
transpose_247: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_400, 2, 1); view_400 = None
new_ones_20: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_247, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_13: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_20); new_ones_20 = None
flip_26: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_13, [0]); tril_13 = None
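    # new_ones_20 -> tril_13 -> flip_26 materializes beginning_mask_2d from the quoted source
    # line. On a toy window size the pattern is easy to see (a sketch; 3 stands in for the
    # window_overlap of 256 used above):
    import torch
    w = 3
    beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
    # tensor([[1., 1., 1., 0.],
    #         [1., 1., 0., 0.],
    #         [1., 0., 0., 0.]])
    # The 1s mark the out-of-window positions at the start of the sequence that the later
    # masked_fill ops (e.g. masked_fill_39) overwrite with -inf wherever the expanded mask == 1.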
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_99: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_26, 0); flip_26 = None
slice_1204: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_99, 1, 0, 9223372036854775807); unsqueeze_99 = None
unsqueeze_100: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1204, 2); slice_1204 = None
slice_1205: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_100, 3, 0, 9223372036854775807); unsqueeze_100 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_27: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1205, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1206: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_247, 0, 0, 9223372036854775807)
slice_1207: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1206, 1, 0, 256); slice_1206 = None
slice_1208: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1207, 2, 0, 9223372036854775807); slice_1207 = None
slice_1209: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1208, 3, 0, 257); slice_1208 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_26: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1205, [2, 256, 1, 257]); slice_1205 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_26: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_26, 1); expand_26 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_401: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_295, [2, 1, 1024, 513])
transpose_248: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_401, 2, 1); view_401 = None
slice_1210: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_248, 0, 0, 9223372036854775807); transpose_248 = None
slice_1211: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1210, 1, 0, 256); slice_1210 = None
slice_1212: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1211, 2, 0, 9223372036854775807); slice_1211 = None
slice_1213: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1212, 3, 0, 257); slice_1212 = None
masked_fill_39: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1213, eq_26, -inf); slice_1213 = eq_26 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_1214: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_247, 0, 0, 9223372036854775807); transpose_247 = None
slice_1215: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1214, 1, -256, 9223372036854775807); slice_1214 = None
slice_1216: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1215, 2, 0, 9223372036854775807); slice_1215 = None
slice_1217: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1216, 3, -257, 9223372036854775807); slice_1216 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_27: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_27, [2, 256, 1, 257]); flip_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_27: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_27, 1); expand_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_402: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_295, [2, 1, 1024, 513]); slice_scatter_295 = None
transpose_249: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_402, 2, 1); view_402 = None
slice_1218: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_249, 0, 0, 9223372036854775807)
slice_1219: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1218, 1, 0, 256)
slice_1220: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1219, 2, 0, 9223372036854775807)
slice_scatter_296: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1220, masked_fill_39, 3, 0, 257); slice_1220 = masked_fill_39 = None
slice_scatter_297: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1219, slice_scatter_296, 2, 0, 9223372036854775807); slice_1219 = slice_scatter_296 = None
slice_scatter_298: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1218, slice_scatter_297, 1, 0, 256); slice_1218 = slice_scatter_297 = None
slice_scatter_299: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_249, slice_scatter_298, 0, 0, 9223372036854775807); transpose_249 = slice_scatter_298 = None
transpose_250: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_299, 2, 1); slice_scatter_299 = None
view_403: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_250, [2, 4, 256, 513]); transpose_250 = None
view_404: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_403, [2, 1, 1024, 513])
transpose_251: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_404, 2, 1); view_404 = None
slice_1221: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_251, 0, 0, 9223372036854775807); transpose_251 = None
slice_1222: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1221, 1, -256, 9223372036854775807); slice_1221 = None
slice_1223: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1222, 2, 0, 9223372036854775807); slice_1222 = None
slice_1224: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1223, 3, -257, 9223372036854775807); slice_1223 = None
masked_fill_40: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1224, eq_27, -inf); slice_1224 = eq_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_405: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513])
transpose_252: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_405, 2, 1); view_405 = None
view_406: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_390, [2, 12, 1024, 513]); view_390 = None
transpose_253: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_406, 2, 1); view_406 = None
slice_1225: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_253, 0, 0, 9223372036854775807)
slice_1226: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1225, 1, -256, 9223372036854775807)
slice_1227: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1226, 2, 0, 9223372036854775807)
slice_scatter_300: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1227, masked_fill_37, 3, -257, 9223372036854775807); slice_1227 = masked_fill_37 = None
slice_scatter_301: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1226, slice_scatter_300, 2, 0, 9223372036854775807); slice_1226 = slice_scatter_300 = None
slice_scatter_302: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1225, slice_scatter_301, 1, -256, 9223372036854775807); slice_1225 = slice_scatter_301 = None
slice_scatter_303: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_253, slice_scatter_302, 0, 0, 9223372036854775807); transpose_253 = slice_scatter_302 = None
transpose_254: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_303, 2, 1); slice_scatter_303 = None
view_407: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_254, [24, 4, 256, 513]); transpose_254 = None
view_408: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_407, [2, 12, 1024, 513]); view_407 = None
transpose_255: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_408, 2, 1); view_408 = None
view_409: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_403, [2, 1, 1024, 513]); view_403 = None
transpose_256: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_409, 2, 1); view_409 = None
slice_1228: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_256, 0, 0, 9223372036854775807)
slice_1229: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1228, 1, -256, 9223372036854775807)
slice_1230: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1229, 2, 0, 9223372036854775807)
slice_scatter_304: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1230, masked_fill_40, 3, -257, 9223372036854775807); slice_1230 = masked_fill_40 = None
slice_scatter_305: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1229, slice_scatter_304, 2, 0, 9223372036854775807); slice_1229 = slice_scatter_304 = None
slice_scatter_306: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1228, slice_scatter_305, 1, -256, 9223372036854775807); slice_1228 = slice_scatter_305 = None
slice_scatter_307: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_256, slice_scatter_306, 0, 0, 9223372036854775807); transpose_256 = slice_scatter_306 = None
transpose_257: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_307, 2, 1); slice_scatter_307 = None
view_410: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_257, [2, 4, 256, 513]); transpose_257 = None
view_411: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_410, [2, 1, 1024, 513]); view_410 = None
transpose_258: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_411, 2, 1); view_411 = None
add_45: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_255, transpose_258); transpose_255 = transpose_258 = None
view_412: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_12, [2, 12, 1024, 513]); new_empty_12 = None
transpose_259: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_412, 2, 1); view_412 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_6: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_45, -1, False); add_45 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_1231: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_1232: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1231, 1, 0, 9223372036854775807); slice_1231 = None
unsqueeze_101: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1232, 2); slice_1232 = None
unsqueeze_102: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_101, 3); unsqueeze_101 = None
masked_fill_41: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_6, unsqueeze_102, 0.0); _softmax_6 = unsqueeze_102 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_413: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_44, [1024, 2, 12, 64]); add_44 = None
transpose_260: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_413, 0, 1); view_413 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_261: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_41, 1, 2); masked_fill_41 = None
clone_59: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_261, memory_format = torch.contiguous_format); transpose_261 = None
_unsafe_view_32: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_59, [24, 4, 256, 513]); clone_59 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_262: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_260, 1, 2); transpose_260 = None
view_414: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_262, [24, 1024, 64]); transpose_262 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_26: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_414, [0, 0, 256, 256], -1.0); view_414 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_41: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_26, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_26 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_27: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_32, [0, 257], 0.0); _unsafe_view_32 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_415: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_27, [24, 4, -1]); constant_pad_nd_27 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_1233: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_415, 0, 0, 9223372036854775807); view_415 = None
slice_1234: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1233, 1, 0, 9223372036854775807); slice_1233 = None
slice_1235: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1234, 2, 0, -256); slice_1234 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_416: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_1235, [24, 4, 256, 769]); slice_1235 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_1236: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_416, 0, 0, 9223372036854775807); view_416 = None
slice_1237: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1236, 1, 0, 9223372036854775807)
slice_1238: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1237, 2, 0, 9223372036854775807); slice_1237 = None
slice_1239: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1238, 3, 0, -1); slice_1238 = None
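    # The pad / flatten / slice / view chain above (modeling_longformer.py:755-767, Longformer's
    # _pad_and_diagonalize) is a skewing trick: padding each row by window_overlap + 1 and
    # reinterpreting the flat buffer shifts row i right by i positions, lining the diagonals up.
    # A toy sketch of the same arithmetic (w, h stand in for window_overlap=256, hidden_dim=513):
    import torch
    w, h = 3, 4
    chunk = torch.arange(float(w * h)).view(1, 1, w, h)        # toy chunked attention probabilities
    padded = torch.nn.functional.pad(chunk, (0, w + 1))        # like the (0, 257) pad above
    flat = padded.view(1, 1, -1)[:, :, :-w]                    # like the [:, :, :-256] slice
    skewed = flat.view(1, 1, w, w + h)[:, :, :, :-1]           # like the view to [.., 256, 769] then dropping the last column
    # Row i of skewed is row i of chunk shifted right by i positions.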
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_103: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1239, 4)
permute_93: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_103, [0, 1, 2, 4, 3]); unsqueeze_103 = None
unsqueeze_104: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_41, 4); as_strided_41 = None
permute_94: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_104, [0, 1, 4, 3, 2]); unsqueeze_104 = None
permute_95: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_93, [0, 1, 2, 4, 3]); permute_93 = None
sym_size_51: Sym(24) = torch.ops.aten.sym_size(slice_1236, 0); slice_1236 = None
# No stacktrace found for following nodes
mul_52: Sym(96) = sym_size_51 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_52: Sym(768) = torch.ops.aten.sym_size(slice_1239, 3); slice_1239 = None
view_417: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_95, [mul_52, 256, sym_size_52]); permute_95 = None
permute_96: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_94, [0, 1, 4, 3, 2]); permute_94 = None
clone_60: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_96, memory_format = torch.contiguous_format); permute_96 = None
_unsafe_view_33: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_60, [mul_52, sym_size_52, 64]); clone_60 = mul_52 = sym_size_52 = None
bmm_13: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_417, _unsafe_view_33); view_417 = _unsafe_view_33 = None
view_418: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_13, [sym_size_51, 4, 256, 1, 64]); bmm_13 = None
permute_97: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_418, [0, 1, 2, 4, 3])
sym_size_53: Sym(4) = torch.ops.aten.sym_size(view_418, 1); view_418 = None
view_419: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_97, [sym_size_51, sym_size_53, 256, 64]); permute_97 = sym_size_51 = sym_size_53 = None
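    # The permute / clone / view / bmm group above is the decomposed
    # torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)): b and c are folded
    # into one batch dimension (24 * 4 = 96, see mul_52) and the contraction over d becomes a
    # plain bmm. An equivalent sketch with the shapes from the annotations:
    import torch
    attn = torch.randn(24, 4, 256, 768)            # chunked_attn_probs
    value = torch.randn(24, 4, 768, 64)            # chunked_value
    ctx_einsum = torch.einsum("bcwd,bcdh->bcwh", attn, value)
    ctx_bmm = torch.bmm(attn.reshape(96, 256, 768), value.reshape(96, 768, 64)).view(24, 4, 256, 64)
    assert torch.allclose(ctx_einsum, ctx_bmm, atol=1e-4)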
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_420: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_419, [2, 12, 1024, 64]); view_419 = None
transpose_263: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_420, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_264: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_263, 0, 1); transpose_263 = None
clone_61: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_264, memory_format = torch.contiguous_format); transpose_264 = None
_unsafe_view_34: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_61, [1024, 2, 768]); clone_61 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_265: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_34, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_39: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_103); orig_primals_103 = None
clone_62: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_265, memory_format = torch.contiguous_format); transpose_265 = None
sym_size_54: Sym(1024) = torch.ops.aten.sym_size(view_420, 2); view_420 = None
# No stacktrace found for following nodes
mul_53: Sym(2048) = 2 * sym_size_54
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_55: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_34, 2); _unsafe_view_34 = None
view_421: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_62, [mul_53, sym_size_55]); clone_62 = mul_53 = sym_size_55 = None
mm_27: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_421, t_39); view_421 = t_39 = None
view_422: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_27, [2, sym_size_54, 768]); mm_27 = sym_size_54 = None
add_46: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_422, orig_primals_104); orig_primals_104 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_47: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_46, getitem_33); add_46 = getitem_33 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_12 = torch.ops.aten.native_layer_norm.default(add_47, [768], orig_primals_105, orig_primals_106, 1e-05); add_47 = orig_primals_105 = orig_primals_106 = None
getitem_36: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_12[0]
getitem_37: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_12[1]
getitem_38: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_12[2]; native_layer_norm_12 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_40: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_107); orig_primals_107 = None
sym_size_56: Sym(1024) = torch.ops.aten.sym_size(view_422, 1); view_422 = None
# No stacktrace found for following nodes
mul_54: Sym(2048) = 2 * sym_size_56
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_423: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_36, [mul_54, 768]); mul_54 = None
addmm_12: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_108, view_423, t_40); orig_primals_108 = view_423 = t_40 = None
view_424: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_12, [2, sym_size_56, 3072]); addmm_12 = sym_size_56 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_6: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_424)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_41: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_109); orig_primals_109 = None
sym_size_57: Sym(1024) = torch.ops.aten.sym_size(view_424, 1); view_424 = None
# No stacktrace found for following nodes
mul_55: Sym(2048) = 2 * sym_size_57
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_425: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_6, [mul_55, 3072]); gelu_6 = mul_55 = None
addmm_13: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_110, view_425, t_41); orig_primals_110 = view_425 = t_41 = None
view_426: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_13, [2, sym_size_57, 768]); addmm_13 = sym_size_57 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_48: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_426, getitem_36); getitem_36 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_13 = torch.ops.aten.native_layer_norm.default(add_48, [768], orig_primals_111, orig_primals_112, 1e-05); add_48 = orig_primals_111 = orig_primals_112 = None
getitem_39: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_13[0]
getitem_40: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_13[1]
getitem_41: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_13[2]; native_layer_norm_13 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_266: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_39, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_42: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_113); orig_primals_113 = None
clone_63: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_266, memory_format = torch.contiguous_format)
sym_size_58: Sym(1024) = torch.ops.aten.sym_size(view_426, 1); view_426 = None
# No stacktrace found for following nodes
mul_56: Sym(2048) = sym_size_58 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_427: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_63, [mul_56, 768]); clone_63 = mul_56 = None
mm_28: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_427, t_42); view_427 = t_42 = None
view_428: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_28, [sym_size_58, 2, 768]); mm_28 = None
add_49: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_428, orig_primals_114); view_428 = orig_primals_114 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_43: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_115); orig_primals_115 = None
clone_64: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_266, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_57: Sym(2048) = sym_size_58 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_429: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_64, [mul_57, 768]); clone_64 = mul_57 = None
mm_29: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_429, t_43); view_429 = t_43 = None
view_430: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_29, [sym_size_58, 2, 768]); mm_29 = None
add_50: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_430, orig_primals_116); view_430 = orig_primals_116 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_44: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_117); orig_primals_117 = None
clone_65: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_266, memory_format = torch.contiguous_format); transpose_266 = None
# No stacktrace found for following nodes
mul_58: Sym(2048) = sym_size_58 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_431: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_65, [mul_58, 768]); clone_65 = mul_58 = None
mm_30: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_431, t_44); view_431 = t_44 = None
view_432: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_30, [sym_size_58, 2, 768]); mm_30 = sym_size_58 = None
add_51: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_432, orig_primals_118); view_432 = orig_primals_118 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_7: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_49, 8.0); add_49 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_433: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_7, [1024, 2, 12, 64])
transpose_267: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_433, 0, 1); view_433 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_434: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_50, [1024, 2, 12, 64]); add_50 = None
transpose_268: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_434, 0, 1); view_434 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_269: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_267, 1, 2); transpose_267 = None
view_435: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_269, [24, 1024, 64]); transpose_269 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_270: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_268, 1, 2); transpose_268 = None
view_436: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_270, [24, 1024, 64]); transpose_270 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_437: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_435, [24, 2, 512, 64]); view_435 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_42: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_437, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_437 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_438: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_436, [24, 2, 512, 64]); view_436 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_43: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_438, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_438 = None
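
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch and a
# made-up helper name `chunk_overlapping`: the view_437/view_438 plus
# as_strided_42/as_strided_43 pairs above implement the overlapping-chunk trick --
# a 1024-token sequence is cut into chunks of 2*window rows that overlap by
# window rows (window_overlap = 256 here), without copying data.
import torch

def chunk_overlapping(x, window):
    # x: (batch*heads, seq_len, head_dim); seq_len must be a multiple of window
    bh, seq_len, dim = x.shape
    n_chunks = seq_len // window - 1                  # 1024 // 256 - 1 = 3
    s0, s1, s2 = x.stride()
    return x.as_strided(size=(bh, n_chunks, 2 * window, dim),
                        stride=(s0, window * s1, s1, s2))

q = torch.randn(24, 1024, 64)
chunks = chunk_overlapping(q, window=256)
print(chunks.shape)                                   # torch.Size([24, 3, 512, 64])
# neighbouring chunks share `window` rows:
assert torch.equal(chunks[:, 1, :256], q[:, 256:512])
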
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_105: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_42, 4); as_strided_42 = None
permute_98: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_105, [0, 1, 2, 4, 3]); unsqueeze_105 = None
unsqueeze_106: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_43, 4); as_strided_43 = None
permute_99: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_106, [0, 1, 4, 2, 3]); unsqueeze_106 = None
permute_100: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_98, [0, 1, 2, 4, 3]); permute_98 = None
view_439: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_7, [1024, 2, 12, 64]); div_7 = None
transpose_271: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_439, 0, 1); view_439 = None
transpose_272: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_271, 1, 2); transpose_271 = None
view_440: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_272, [24, 1024, 64]); transpose_272 = None
view_441: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_440, [24, 2, 512, 64]); view_440 = None
as_strided_44: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_441, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_441 = None
unsqueeze_107: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_44, 4); as_strided_44 = None
permute_101: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_107, [0, 1, 2, 4, 3]); unsqueeze_107 = None
permute_102: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_101, [0, 1, 2, 4, 3]); permute_101 = None
clone_66: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_102, memory_format = torch.contiguous_format); permute_102 = None
_unsafe_view_35: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_66, [72, 512, 64]); clone_66 = None
permute_103: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_99, [0, 1, 4, 3, 2]); permute_99 = None
clone_67: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_103, memory_format = torch.contiguous_format); permute_103 = None
_unsafe_view_36: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_67, [72, 64, 512]); clone_67 = None
bmm_14: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_35, _unsafe_view_36); _unsafe_view_35 = _unsafe_view_36 = None
view_442: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_14, [24, 3, 512, 1, 512]); bmm_14 = None
permute_104: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_442, [0, 1, 2, 4, 3]); view_442 = None
view_443: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_104, [24, 3, 512, 512]); permute_104 = None
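
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch: the
# unsqueeze/permute/clone/bmm sequence ending in view_443 is just the lowering of
# torch.einsum("bcxd,bcyd->bcxy", query, key), i.e. a batched query @ key^T with
# the (b, c) dims flattened into the bmm batch dimension (24 * 3 = 72).
import torch
q = torch.randn(24, 3, 512, 64)
k = torch.randn(24, 3, 512, 64)
scores_einsum = torch.einsum("bcxd,bcyd->bcxy", q, k)
scores_bmm = torch.bmm(q.reshape(72, 512, 64),
                       k.reshape(72, 512, 64).transpose(1, 2)).view(24, 3, 512, 512)
print(torch.allclose(scores_einsum, scores_bmm, atol=1e-4))   # True
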
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_28: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_443, [0, 0, 0, 1], 0.0); view_443 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_444: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_28, [24, 3, 512, 513]); constant_pad_nd_28 = None
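
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch: the
# pad-then-view pair constant_pad_nd_28/view_444 skews each (512, 512) block of
# chunked scores so that element (i, i+k) lands at (i, k), lining the diagonals up
# as columns. A tiny (4, 4) example of the same trick:
import torch
import torch.nn.functional as F
x = torch.arange(16.0).view(4, 4)
skewed = F.pad(x, (0, 0, 0, 1)).view(4, 5)   # (4,4) -> pad one row -> (5,4) -> view (4,5)
print(skewed)
# tensor([[ 0.,  1.,  2.,  3.,  4.],
#         [ 5.,  6.,  7.,  8.,  9.],
#         [10., 11., 12., 13., 14.],
#         [15.,  0.,  0.,  0.,  0.]])
# column 0 now holds the main diagonal 0, 5, 10, 15 of x.
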
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_14: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_444, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1240: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_444, 0, 0, 9223372036854775807)
slice_1241: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1240, 1, 0, 9223372036854775807); slice_1240 = None
slice_1242: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1241, 2, 0, 256); slice_1241 = None
slice_1243: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1242, 3, 0, 257); slice_1242 = None
slice_1244: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
slice_1245: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1244, 1, 0, -1); slice_1244 = None
slice_1246: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1245, 2, 0, 9223372036854775807); slice_1245 = None
slice_1247: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1246, 3, 256, 9223372036854775807); slice_1246 = None
slice_1248: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
slice_1249: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1248, 1, 0, -1); slice_1248 = None
slice_1250: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1249, 2, 0, 9223372036854775807); slice_1249 = None
slice_1251: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1250, 3, 256, 9223372036854775807); slice_1250 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1252: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_444, 0, 0, 9223372036854775807)
select_126: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1252, 1, -1); slice_1252 = None
slice_1253: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_126, 1, 256, 9223372036854775807); select_126 = None
slice_1254: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1253, 2, 0, 257); slice_1253 = None
slice_1255: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
select_127: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1255, 1, -1); slice_1255 = None
slice_1256: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_127, 1, 0, 9223372036854775807); select_127 = None
slice_1257: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1256, 2, 256, 9223372036854775807); slice_1256 = None
slice_1258: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
slice_1259: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1258, 1, 0, -1)
slice_1260: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1259, 2, 0, 9223372036854775807)
slice_scatter_308: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1260, slice_1243, 3, 256, 9223372036854775807); slice_1260 = slice_1243 = None
slice_scatter_309: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1259, slice_scatter_308, 2, 0, 9223372036854775807); slice_1259 = slice_scatter_308 = None
slice_scatter_310: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1258, slice_scatter_309, 1, 0, -1); slice_1258 = slice_scatter_309 = None
slice_scatter_311: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_14, slice_scatter_310, 0, 0, 9223372036854775807); slice_scatter_310 = None
slice_1261: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_311, 0, 0, 9223372036854775807)
select_128: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1261, 1, -1); slice_1261 = None
slice_1262: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_128, 1, 0, 9223372036854775807); select_128 = None
slice_1263: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1262, 2, 256, 9223372036854775807); slice_1262 = None
slice_1264: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
select_129: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1264, 1, -1); slice_1264 = None
slice_1265: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_129, 1, 0, 9223372036854775807); select_129 = None
slice_1266: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1265, 2, 256, 9223372036854775807); slice_1265 = None
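
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch: the
# slice / slice_scatter chains above (slice_scatter_308 ... slice_scatter_311 and
# friends) are the functionalized form of the in-place slice assignments at
# modeling_longformer.py:845/848 -- the same writes, rebuilt out-of-place.
import torch
scores = torch.zeros(24, 4, 256, 513)
chunk = torch.randn(24, 3, 256, 257)

eager = scores.clone()
eager[:, :-1, :, 256:] = chunk                          # in-place, as in the HF source

inner = torch.slice_scatter(scores[:, :-1], chunk, dim=3, start=256)
functional = torch.slice_scatter(scores, inner, dim=1, start=0, end=-1)
print(torch.equal(eager, functional))                   # True
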
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1267: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_444, 0, 0, 9223372036854775807)
slice_1268: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1267, 1, 0, 9223372036854775807); slice_1267 = None
slice_1269: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1268, 2, -257, -1); slice_1268 = None
slice_1270: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1269, 3, 257, 9223372036854775807); slice_1269 = None
slice_1271: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
slice_1272: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1271, 1, 1, 9223372036854775807); slice_1271 = None
slice_1273: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1272, 2, 0, 9223372036854775807); slice_1272 = None
slice_1274: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1273, 3, 0, 256); slice_1273 = None
slice_1275: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_311, 0, 0, 9223372036854775807)
select_130: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1275, 1, -1)
slice_1276: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_130, 1, 0, 9223372036854775807)
slice_scatter_312: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1276, slice_1254, 2, 256, 9223372036854775807); slice_1276 = slice_1254 = None
slice_scatter_313: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_130, slice_scatter_312, 1, 0, 9223372036854775807); select_130 = slice_scatter_312 = None
select_scatter_28: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1275, slice_scatter_313, 1, -1); slice_1275 = slice_scatter_313 = None
slice_scatter_314: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_311, select_scatter_28, 0, 0, 9223372036854775807); slice_scatter_311 = select_scatter_28 = None
slice_1277: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_314, 0, 0, 9223372036854775807)
slice_1278: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1277, 1, 1, 9223372036854775807); slice_1277 = None
slice_1279: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1278, 2, 0, 9223372036854775807); slice_1278 = None
slice_1280: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1279, 3, 0, 256); slice_1279 = None
slice_1281: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
slice_1282: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1281, 1, 1, 9223372036854775807); slice_1281 = None
slice_1283: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1282, 2, 0, 9223372036854775807); slice_1282 = None
slice_1284: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1283, 3, 0, 256); slice_1283 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1285: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_444, 0, 0, 9223372036854775807); view_444 = None
select_131: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1285, 1, 0); slice_1285 = None
slice_1286: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_131, 1, 0, 255); select_131 = None
slice_1287: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1286, 2, -255, 9223372036854775807); slice_1286 = None
slice_1288: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
select_132: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1288, 1, 0); slice_1288 = None
slice_1289: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_132, 1, 1, 256); select_132 = None
slice_1290: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1289, 2, 1, 256); slice_1289 = None
slice_1291: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_314, 0, 0, 9223372036854775807)
slice_1292: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1291, 1, 1, 9223372036854775807)
slice_1293: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1292, 2, 0, 9223372036854775807)
slice_scatter_315: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1293, slice_1270, 3, 0, 256); slice_1293 = slice_1270 = None
slice_scatter_316: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1292, slice_scatter_315, 2, 0, 9223372036854775807); slice_1292 = slice_scatter_315 = None
slice_scatter_317: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1291, slice_scatter_316, 1, 1, 9223372036854775807); slice_1291 = slice_scatter_316 = None
slice_scatter_318: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_314, slice_scatter_317, 0, 0, 9223372036854775807); slice_scatter_314 = slice_scatter_317 = None
slice_1294: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_318, 0, 0, 9223372036854775807)
select_133: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1294, 1, 0); slice_1294 = None
slice_1295: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_133, 1, 1, 256); select_133 = None
slice_1296: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1295, 2, 1, 256); slice_1295 = None
slice_1297: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_14, 0, 0, 9223372036854775807)
select_134: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1297, 1, 0); slice_1297 = None
slice_1298: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_134, 1, 1, 256); select_134 = None
slice_1299: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1298, 2, 1, 256); slice_1298 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_445: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513])
transpose_273: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_445, 2, 1); view_445 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1300: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_318, 0, 0, 9223372036854775807)
select_135: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1300, 1, 0)
slice_1301: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_135, 1, 1, 256)
slice_scatter_319: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1301, slice_1287, 2, 1, 256); slice_1301 = slice_1287 = None
slice_scatter_320: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_135, slice_scatter_319, 1, 1, 256); select_135 = slice_scatter_319 = None
select_scatter_29: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1300, slice_scatter_320, 1, 0); slice_1300 = slice_scatter_320 = None
slice_scatter_321: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_318, select_scatter_29, 0, 0, 9223372036854775807); slice_scatter_318 = select_scatter_29 = None
view_446: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_321, [2, 12, 1024, 513])
transpose_274: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_446, 2, 1); view_446 = None
new_ones_21: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_274, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_14: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_21); new_ones_21 = None
flip_28: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_14, [0]); tril_14 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_108: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_28, 0); flip_28 = None
slice_1302: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_108, 1, 0, 9223372036854775807); unsqueeze_108 = None
unsqueeze_109: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1302, 2); slice_1302 = None
slice_1303: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_109, 3, 0, 9223372036854775807); unsqueeze_109 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_29: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1303, [1, 3])
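
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch: the
# new_ones_21/tril_14/flip_28 chain builds the boundary mask for the first
# window_overlap rows -- the 1-entries are window positions that would look left
# past the start of the sequence and get filled with -inf (see masked_fill_42
# below). Shown here with window_overlap = 3 instead of 256:
import torch
w = 3
beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
print(beginning_mask_2d)
# tensor([[1., 1., 1., 0.],
#         [1., 1., 0., 0.],
#         [1., 0., 0., 0.]])
beginning_mask = beginning_mask_2d[None, :, None, :]    # (1, w, 1, w+1), cf. unsqueeze_108/109
ending_mask = beginning_mask.flip(dims=(1, 3))          # mirrored mask for the last w rows
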
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1304: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_274, 0, 0, 9223372036854775807)
slice_1305: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1304, 1, 0, 256); slice_1304 = None
slice_1306: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1305, 2, 0, 9223372036854775807); slice_1305 = None
slice_1307: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1306, 3, 0, 257); slice_1306 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_28: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1303, [2, 256, 12, 257]); slice_1303 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_28: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_28, 1); expand_28 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_447: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_321, [2, 12, 1024, 513])
transpose_275: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_447, 2, 1); view_447 = None
slice_1308: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_275, 0, 0, 9223372036854775807); transpose_275 = None
slice_1309: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1308, 1, 0, 256); slice_1308 = None
slice_1310: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1309, 2, 0, 9223372036854775807); slice_1309 = None
slice_1311: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1310, 3, 0, 257); slice_1310 = None
masked_fill_42: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1311, eq_28, -inf); slice_1311 = eq_28 = None
view_448: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513])
transpose_276: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_448, 2, 1); view_448 = None
slice_1312: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_276, 0, 0, 9223372036854775807); transpose_276 = None
slice_1313: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1312, 1, 0, 256); slice_1312 = None
slice_1314: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1313, 2, 0, 9223372036854775807); slice_1313 = None
slice_1315: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1314, 3, 0, 257); slice_1314 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_449: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513])
transpose_277: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_449, 2, 1); view_449 = None
slice_1316: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_274, 0, 0, 9223372036854775807); transpose_274 = None
slice_1317: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1316, 1, -256, 9223372036854775807); slice_1316 = None
slice_1318: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1317, 2, 0, 9223372036854775807); slice_1317 = None
slice_1319: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1318, 3, -257, 9223372036854775807); slice_1318 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_29: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_29, [2, 256, 12, 257]); flip_29 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_29: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_29, 1); expand_29 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_450: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_321, [2, 12, 1024, 513]); slice_scatter_321 = None
transpose_278: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_450, 2, 1); view_450 = None
slice_1320: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_278, 0, 0, 9223372036854775807)
slice_1321: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1320, 1, 0, 256)
slice_1322: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1321, 2, 0, 9223372036854775807)
slice_scatter_322: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1322, masked_fill_42, 3, 0, 257); slice_1322 = masked_fill_42 = None
slice_scatter_323: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1321, slice_scatter_322, 2, 0, 9223372036854775807); slice_1321 = slice_scatter_322 = None
slice_scatter_324: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1320, slice_scatter_323, 1, 0, 256); slice_1320 = slice_scatter_323 = None
slice_scatter_325: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_278, slice_scatter_324, 0, 0, 9223372036854775807); transpose_278 = slice_scatter_324 = None
transpose_279: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_325, 2, 1); slice_scatter_325 = None
view_451: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_279, [24, 4, 256, 513]); transpose_279 = None
view_452: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_451, [2, 12, 1024, 513])
transpose_280: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_452, 2, 1); view_452 = None
slice_1323: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_280, 0, 0, 9223372036854775807); transpose_280 = None
slice_1324: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1323, 1, -256, 9223372036854775807); slice_1323 = None
slice_1325: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1324, 2, 0, 9223372036854775807); slice_1324 = None
slice_1326: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1325, 3, -257, 9223372036854775807); slice_1325 = None
masked_fill_43: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1326, eq_29, -inf); slice_1326 = eq_29 = None
view_453: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513])
transpose_281: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_453, 2, 1); view_453 = None
slice_1327: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_281, 0, 0, 9223372036854775807); transpose_281 = None
slice_1328: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1327, 1, -256, 9223372036854775807); slice_1327 = None
slice_1329: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1328, 2, 0, 9223372036854775807); slice_1328 = None
slice_1330: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1329, 3, -257, 9223372036854775807); slice_1329 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_7: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_1331: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_7, 0, 0, 9223372036854775807); ne_7 = None
slice_1332: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1331, 1, 0, 9223372036854775807); slice_1331 = None
unsqueeze_110: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1332, 2); slice_1332 = None
unsqueeze_111: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_110, 3); unsqueeze_110 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_7: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_111, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_44: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_7, unsqueeze_111, -10000.0); _to_copy_7 = unsqueeze_111 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_22: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_44, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_282: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_22, 1, 2); new_ones_22 = None
view_454: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_282, [2, 1024, 1]); transpose_282 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_283: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_44, 1, 2); masked_fill_44 = None
view_455: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_283, [2, 1024, 1]); transpose_283 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_456: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_454, [2, 2, 512, 1]); view_454 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_45: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_456, [2, 3, 512, 1], [1024, 256, 1, 1]); view_456 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_457: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_455, [2, 2, 512, 1]); view_455 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_46: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_457, [2, 3, 512, 1], [1024, 256, 1, 1]); view_457 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_112: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_45, 4); as_strided_45 = None
permute_105: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_112, [0, 1, 2, 4, 3]); unsqueeze_112 = None
unsqueeze_113: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_46, 4); as_strided_46 = None
permute_106: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_113, [0, 1, 4, 2, 3]); unsqueeze_113 = None
mul_59: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_105, permute_106); permute_105 = permute_106 = None
view_458: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_59, [2, 3, 512, 512]); mul_59 = None
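
# Illustrative sketch (not part of the traced graph), assuming plain PyTorch: here
# the "key" is the float attention mask, so the contracted dimension d of
# einsum("bcxd,bcyd->bcxy", ...) has size 1 and the contraction collapses to a
# broadcasted outer product -- which is why this block lowers to aten.mul (mul_59)
# rather than to a bmm.
import torch
q = torch.randn(2, 3, 512, 1)
k = torch.randn(2, 3, 512, 1)
scores_einsum = torch.einsum("bcxd,bcyd->bcxy", q, k)
scores_mul = q * k.transpose(2, 3)                      # (2,3,512,1) * (2,3,1,512)
print(torch.allclose(scores_einsum, scores_mul))        # True
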
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_29: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_458, [0, 0, 0, 1], 0.0); view_458 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_459: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_29, [2, 3, 512, 513]); constant_pad_nd_29 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_15: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_459, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1333: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_459, 0, 0, 9223372036854775807)
slice_1334: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1333, 1, 0, 9223372036854775807); slice_1333 = None
slice_1335: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1334, 2, 0, 256); slice_1334 = None
slice_1336: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1335, 3, 0, 257); slice_1335 = None
slice_1337: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_15, 0, 0, 9223372036854775807)
slice_1338: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1337, 1, 0, -1); slice_1337 = None
slice_1339: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1338, 2, 0, 9223372036854775807); slice_1338 = None
slice_1340: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1339, 3, 256, 9223372036854775807); slice_1339 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1341: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_459, 0, 0, 9223372036854775807)
select_136: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1341, 1, -1); slice_1341 = None
slice_1342: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_136, 1, 256, 9223372036854775807); select_136 = None
slice_1343: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1342, 2, 0, 257); slice_1342 = None
slice_1344: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_15, 0, 0, 9223372036854775807)
select_137: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1344, 1, -1); slice_1344 = None
slice_1345: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_137, 1, 0, 9223372036854775807); select_137 = None
slice_1346: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1345, 2, 256, 9223372036854775807); slice_1345 = None
slice_1347: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_15, 0, 0, 9223372036854775807)
slice_1348: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1347, 1, 0, -1)
slice_1349: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1348, 2, 0, 9223372036854775807)
slice_scatter_326: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1349, slice_1336, 3, 256, 9223372036854775807); slice_1349 = slice_1336 = None
slice_scatter_327: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1348, slice_scatter_326, 2, 0, 9223372036854775807); slice_1348 = slice_scatter_326 = None
slice_scatter_328: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1347, slice_scatter_327, 1, 0, -1); slice_1347 = slice_scatter_327 = None
slice_scatter_329: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_15, slice_scatter_328, 0, 0, 9223372036854775807); slice_scatter_328 = None
slice_1350: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_329, 0, 0, 9223372036854775807)
select_138: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1350, 1, -1); slice_1350 = None
slice_1351: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_138, 1, 0, 9223372036854775807); select_138 = None
slice_1352: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1351, 2, 256, 9223372036854775807); slice_1351 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1353: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_459, 0, 0, 9223372036854775807)
slice_1354: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1353, 1, 0, 9223372036854775807); slice_1353 = None
slice_1355: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1354, 2, -257, -1); slice_1354 = None
slice_1356: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1355, 3, 257, 9223372036854775807); slice_1355 = None
slice_1357: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_15, 0, 0, 9223372036854775807)
slice_1358: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1357, 1, 1, 9223372036854775807); slice_1357 = None
slice_1359: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1358, 2, 0, 9223372036854775807); slice_1358 = None
slice_1360: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1359, 3, 0, 256); slice_1359 = None
slice_1361: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_329, 0, 0, 9223372036854775807)
select_139: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1361, 1, -1)
slice_1362: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_139, 1, 0, 9223372036854775807)
slice_scatter_330: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1362, slice_1343, 2, 256, 9223372036854775807); slice_1362 = slice_1343 = None
slice_scatter_331: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_139, slice_scatter_330, 1, 0, 9223372036854775807); select_139 = slice_scatter_330 = None
select_scatter_30: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1361, slice_scatter_331, 1, -1); slice_1361 = slice_scatter_331 = None
slice_scatter_332: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_329, select_scatter_30, 0, 0, 9223372036854775807); slice_scatter_329 = select_scatter_30 = None
slice_1363: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_332, 0, 0, 9223372036854775807)
slice_1364: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1363, 1, 1, 9223372036854775807); slice_1363 = None
slice_1365: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1364, 2, 0, 9223372036854775807); slice_1364 = None
slice_1366: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1365, 3, 0, 256); slice_1365 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1367: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_459, 0, 0, 9223372036854775807); view_459 = None
select_140: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1367, 1, 0); slice_1367 = None
slice_1368: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_140, 1, 0, 255); select_140 = None
slice_1369: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1368, 2, -255, 9223372036854775807); slice_1368 = None
slice_1370: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_15, 0, 0, 9223372036854775807)
select_141: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1370, 1, 0); slice_1370 = None
slice_1371: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_141, 1, 1, 256); select_141 = None
slice_1372: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1371, 2, 1, 256); slice_1371 = None
slice_1373: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_332, 0, 0, 9223372036854775807)
slice_1374: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1373, 1, 1, 9223372036854775807)
slice_1375: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1374, 2, 0, 9223372036854775807)
slice_scatter_333: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1375, slice_1356, 3, 0, 256); slice_1375 = slice_1356 = None
slice_scatter_334: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1374, slice_scatter_333, 2, 0, 9223372036854775807); slice_1374 = slice_scatter_333 = None
slice_scatter_335: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1373, slice_scatter_334, 1, 1, 9223372036854775807); slice_1373 = slice_scatter_334 = None
slice_scatter_336: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_332, slice_scatter_335, 0, 0, 9223372036854775807); slice_scatter_332 = slice_scatter_335 = None
slice_1376: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_336, 0, 0, 9223372036854775807)
select_142: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1376, 1, 0); slice_1376 = None
slice_1377: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_142, 1, 1, 256); select_142 = None
slice_1378: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1377, 2, 1, 256); slice_1377 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_460: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_15, [2, 1, 1024, 513]); new_empty_15 = None
transpose_284: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_460, 2, 1); view_460 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1379: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_336, 0, 0, 9223372036854775807)
select_143: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1379, 1, 0)
slice_1380: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_143, 1, 1, 256)
slice_scatter_337: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1380, slice_1369, 2, 1, 256); slice_1380 = slice_1369 = None
slice_scatter_338: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_143, slice_scatter_337, 1, 1, 256); select_143 = slice_scatter_337 = None
select_scatter_31: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1379, slice_scatter_338, 1, 0); slice_1379 = slice_scatter_338 = None
slice_scatter_339: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_336, select_scatter_31, 0, 0, 9223372036854775807); slice_scatter_336 = select_scatter_31 = None
view_461: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_339, [2, 1, 1024, 513])
transpose_285: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_461, 2, 1); view_461 = None
new_ones_23: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_285, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_15: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_23); new_ones_23 = None
flip_30: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_15, [0]); tril_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_114: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_30, 0); flip_30 = None
slice_1381: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_114, 1, 0, 9223372036854775807); unsqueeze_114 = None
unsqueeze_115: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1381, 2); slice_1381 = None
slice_1382: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_115, 3, 0, 9223372036854775807); unsqueeze_115 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_31: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1382, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1383: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_285, 0, 0, 9223372036854775807)
slice_1384: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1383, 1, 0, 256); slice_1383 = None
slice_1385: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1384, 2, 0, 9223372036854775807); slice_1384 = None
slice_1386: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1385, 3, 0, 257); slice_1385 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_30: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1382, [2, 256, 1, 257]); slice_1382 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_30: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_30, 1); expand_30 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_462: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_339, [2, 1, 1024, 513])
transpose_286: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_462, 2, 1); view_462 = None
slice_1387: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_286, 0, 0, 9223372036854775807); transpose_286 = None
slice_1388: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1387, 1, 0, 256); slice_1387 = None
slice_1389: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1388, 2, 0, 9223372036854775807); slice_1388 = None
slice_1390: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1389, 3, 0, 257); slice_1389 = None
masked_fill_45: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1390, eq_30, -inf); slice_1390 = eq_30 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_1391: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_285, 0, 0, 9223372036854775807); transpose_285 = None
slice_1392: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1391, 1, -256, 9223372036854775807); slice_1391 = None
slice_1393: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1392, 2, 0, 9223372036854775807); slice_1392 = None
slice_1394: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1393, 3, -257, 9223372036854775807); slice_1393 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_31: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_31, [2, 256, 1, 257]); flip_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_31: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_31, 1); expand_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_463: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_339, [2, 1, 1024, 513]); slice_scatter_339 = None
transpose_287: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_463, 2, 1); view_463 = None
slice_1395: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_287, 0, 0, 9223372036854775807)
slice_1396: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1395, 1, 0, 256)
slice_1397: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1396, 2, 0, 9223372036854775807)
slice_scatter_340: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1397, masked_fill_45, 3, 0, 257); slice_1397 = masked_fill_45 = None
slice_scatter_341: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1396, slice_scatter_340, 2, 0, 9223372036854775807); slice_1396 = slice_scatter_340 = None
slice_scatter_342: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1395, slice_scatter_341, 1, 0, 256); slice_1395 = slice_scatter_341 = None
slice_scatter_343: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_287, slice_scatter_342, 0, 0, 9223372036854775807); transpose_287 = slice_scatter_342 = None
transpose_288: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_343, 2, 1); slice_scatter_343 = None
view_464: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_288, [2, 4, 256, 513]); transpose_288 = None
view_465: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_464, [2, 1, 1024, 513])
transpose_289: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_465, 2, 1); view_465 = None
slice_1398: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_289, 0, 0, 9223372036854775807); transpose_289 = None
slice_1399: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1398, 1, -256, 9223372036854775807); slice_1398 = None
slice_1400: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1399, 2, 0, 9223372036854775807); slice_1399 = None
slice_1401: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1400, 3, -257, 9223372036854775807); slice_1400 = None
masked_fill_46: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1401, eq_31, -inf); slice_1401 = eq_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_466: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513])
transpose_290: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_466, 2, 1); view_466 = None
view_467: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_451, [2, 12, 1024, 513]); view_451 = None
transpose_291: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_467, 2, 1); view_467 = None
slice_1402: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_291, 0, 0, 9223372036854775807)
slice_1403: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1402, 1, -256, 9223372036854775807)
slice_1404: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1403, 2, 0, 9223372036854775807)
slice_scatter_344: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1404, masked_fill_43, 3, -257, 9223372036854775807); slice_1404 = masked_fill_43 = None
slice_scatter_345: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1403, slice_scatter_344, 2, 0, 9223372036854775807); slice_1403 = slice_scatter_344 = None
slice_scatter_346: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1402, slice_scatter_345, 1, -256, 9223372036854775807); slice_1402 = slice_scatter_345 = None
slice_scatter_347: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_291, slice_scatter_346, 0, 0, 9223372036854775807); transpose_291 = slice_scatter_346 = None
transpose_292: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_347, 2, 1); slice_scatter_347 = None
view_468: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_292, [24, 4, 256, 513]); transpose_292 = None
view_469: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_468, [2, 12, 1024, 513]); view_468 = None
transpose_293: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_469, 2, 1); view_469 = None
view_470: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_464, [2, 1, 1024, 513]); view_464 = None
transpose_294: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_470, 2, 1); view_470 = None
slice_1405: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_294, 0, 0, 9223372036854775807)
slice_1406: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1405, 1, -256, 9223372036854775807)
slice_1407: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1406, 2, 0, 9223372036854775807)
slice_scatter_348: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1407, masked_fill_46, 3, -257, 9223372036854775807); slice_1407 = masked_fill_46 = None
slice_scatter_349: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1406, slice_scatter_348, 2, 0, 9223372036854775807); slice_1406 = slice_scatter_348 = None
slice_scatter_350: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1405, slice_scatter_349, 1, -256, 9223372036854775807); slice_1405 = slice_scatter_349 = None
slice_scatter_351: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_294, slice_scatter_350, 0, 0, 9223372036854775807); transpose_294 = slice_scatter_350 = None
transpose_295: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_351, 2, 1); slice_scatter_351 = None
view_471: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_295, [2, 4, 256, 513]); transpose_295 = None
view_472: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_471, [2, 1, 1024, 513]); view_471 = None
transpose_296: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_472, 2, 1); view_472 = None
add_52: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_293, transpose_296); transpose_293 = transpose_296 = None
view_473: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_14, [2, 12, 1024, 513]); new_empty_14 = None
transpose_297: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_473, 2, 1); view_473 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_7: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_52, -1, False); add_52 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_1408: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_1409: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1408, 1, 0, 9223372036854775807); slice_1408 = None
unsqueeze_116: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1409, 2); slice_1409 = None
unsqueeze_117: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_116, 3); unsqueeze_116 = None
masked_fill_47: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_7, unsqueeze_117, 0.0); _softmax_7 = unsqueeze_117 = None
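
The _softmax_7 / masked_fill_47 pair above is the lowered form of the softmax-then-mask step at modeling_longformer.py:635 and :646. A minimal eager-mode sketch of the same pattern, with toy shapes and illustrative names rather than the real [2, 1024, 12, 513] tensors:

import torch

attn_scores = torch.randn(2, 8, 3, 5)                 # [batch, seq_len, heads, window]
is_index_masked = torch.zeros(2, 8, dtype=torch.bool)
is_index_masked[:, -2:] = True                        # pretend the last two positions are padding

attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1)
# masked positions get probability 0, matching aten.masked_fill.Scalar(..., 0.0)
attn_probs = attn_probs.masked_fill(is_index_masked[:, :, None, None], 0.0)
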
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_474: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_51, [1024, 2, 12, 64]); add_51 = None
transpose_298: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_474, 0, 1); view_474 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_299: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_47, 1, 2); masked_fill_47 = None
clone_68: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_299, memory_format = torch.contiguous_format); transpose_299 = None
_unsafe_view_37: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_68, [24, 4, 256, 513]); clone_68 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_300: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_298, 1, 2); transpose_298 = None
view_475: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_300, [24, 1024, 64]); transpose_300 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_30: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_475, [0, 0, 256, 256], -1.0); view_475 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_47: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_30, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_30 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_31: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_37, [0, 257], 0.0); _unsafe_view_37 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_476: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_31, [24, 4, -1]); constant_pad_nd_31 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_1410: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_476, 0, 0, 9223372036854775807); view_476 = None
slice_1411: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1410, 1, 0, 9223372036854775807); slice_1410 = None
slice_1412: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1411, 2, 0, -256); slice_1411 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_477: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_1412, [24, 4, 256, 769]); slice_1412 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_1413: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_477, 0, 0, 9223372036854775807); view_477 = None
slice_1414: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1413, 1, 0, 9223372036854775807)
slice_1415: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1414, 2, 0, 9223372036854775807); slice_1414 = None
slice_1416: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1415, 3, 0, -1); slice_1415 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_118: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1416, 4)
permute_107: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_118, [0, 1, 2, 4, 3]); unsqueeze_118 = None
unsqueeze_119: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_47, 4); as_strided_47 = None
permute_108: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_119, [0, 1, 4, 3, 2]); unsqueeze_119 = None
permute_109: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_107, [0, 1, 2, 4, 3]); permute_107 = None
sym_size_59: Sym(24) = torch.ops.aten.sym_size(slice_1413, 0); slice_1413 = None
# No stacktrace found for following nodes
mul_60: Sym(96) = sym_size_59 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_60: Sym(768) = torch.ops.aten.sym_size(slice_1416, 3); slice_1416 = None
view_478: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_109, [mul_60, 256, sym_size_60]); permute_109 = None
permute_110: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_108, [0, 1, 4, 3, 2]); permute_108 = None
clone_69: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_110, memory_format = torch.contiguous_format); permute_110 = None
_unsafe_view_38: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_69, [mul_60, sym_size_60, 64]); clone_69 = mul_60 = sym_size_60 = None
bmm_15: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_478, _unsafe_view_38); view_478 = _unsafe_view_38 = None
view_479: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_15, [sym_size_59, 4, 256, 1, 64]); bmm_15 = None
permute_111: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_479, [0, 1, 2, 4, 3])
sym_size_61: Sym(4) = torch.ops.aten.sym_size(view_479, 1); view_479 = None
view_480: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_111, [sym_size_59, sym_size_61, 256, 64]); permute_111 = sym_size_59 = sym_size_61 = None
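
The unsqueeze/permute/bmm chain above (unsqueeze_118 through view_480) is how the exporter decomposes the einsum at modeling_longformer.py:906. A hedged sketch with toy sizes showing that the einsum and the batched-matmul form agree:

import torch

b, c, w, d, h = 3, 4, 5, 6, 7
chunked_attn_probs = torch.randn(b, c, w, d)
chunked_value = torch.randn(b, c, d, h)

via_einsum = torch.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
via_bmm = torch.bmm(
    chunked_attn_probs.reshape(b * c, w, d),
    chunked_value.reshape(b * c, d, h),
).reshape(b, c, w, h)

assert torch.allclose(via_einsum, via_bmm, atol=1e-5)
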
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_481: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_480, [2, 12, 1024, 64]); view_480 = None
transpose_301: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_481, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_302: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_301, 0, 1); transpose_301 = None
clone_70: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_302, memory_format = torch.contiguous_format); transpose_302 = None
_unsafe_view_39: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_70, [1024, 2, 768]); clone_70 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_303: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_39, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_45: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_119); orig_primals_119 = None
clone_71: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_303, memory_format = torch.contiguous_format); transpose_303 = None
sym_size_62: Sym(1024) = torch.ops.aten.sym_size(view_481, 2); view_481 = None
# No stacktrace found for following nodes
mul_61: Sym(2048) = 2 * sym_size_62
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_63: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_39, 2); _unsafe_view_39 = None
view_482: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_71, [mul_61, sym_size_63]); clone_71 = mul_61 = sym_size_63 = None
mm_31: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_482, t_45); view_482 = t_45 = None
view_483: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_31, [2, sym_size_62, 768]); mm_31 = sym_size_62 = None
add_53: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_483, orig_primals_120); orig_primals_120 = None
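
The t_45 / view_482 / mm_31 / view_483 / add_53 sequence above is the flattened form of the nn.Linear at modeling_longformer.py:1116: the 3-D input is collapsed to 2-D, multiplied by the transposed weight, reshaped back, and the bias is added. A hedged sketch at toy size (weight/bias names here are illustrative, not the model's):

import torch

x = torch.randn(2, 6, 4)
weight = torch.randn(4, 4)
bias = torch.randn(4)

via_linear = torch.nn.functional.linear(x, weight, bias)
via_trace = (x.reshape(2 * 6, 4) @ weight.t()).reshape(2, 6, 4) + bias

assert torch.allclose(via_linear, via_trace, atol=1e-5)
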
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_54: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_53, getitem_39); add_53 = getitem_39 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_14 = torch.ops.aten.native_layer_norm.default(add_54, [768], orig_primals_121, orig_primals_122, 1e-05); add_54 = orig_primals_121 = orig_primals_122 = None
getitem_42: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_14[0]
getitem_43: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_14[1]
getitem_44: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_14[2]; native_layer_norm_14 = None
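
native_layer_norm_14 above unpacks into three getitems because aten.native_layer_norm returns the normalized output together with the saved per-row mean and rstd (kept for the backward pass). A quick sketch at toy size:

import torch

x = torch.randn(2, 5, 8)
weight, bias = torch.ones(8), torch.zeros(8)
out, mean, rstd = torch.ops.aten.native_layer_norm.default(x, [8], weight, bias, 1e-5)
assert out.shape == (2, 5, 8) and mean.shape == (2, 5, 1) and rstd.shape == (2, 5, 1)
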
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_46: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_123); orig_primals_123 = None
sym_size_64: Sym(1024) = torch.ops.aten.sym_size(view_483, 1); view_483 = None
# No stacktrace found for following nodes
mul_62: Sym(2048) = 2 * sym_size_64
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_484: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_42, [mul_62, 768]); mul_62 = None
addmm_14: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_124, view_484, t_46); orig_primals_124 = view_484 = t_46 = None
view_485: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_14, [2, sym_size_64, 3072]); addmm_14 = sym_size_64 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_7: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_485)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_47: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_125); orig_primals_125 = None
sym_size_65: Sym(1024) = torch.ops.aten.sym_size(view_485, 1); view_485 = None
# No stacktrace found for following nodes
mul_63: Sym(2048) = 2 * sym_size_65
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_486: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_7, [mul_63, 3072]); gelu_7 = mul_63 = None
addmm_15: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_126, view_486, t_47); orig_primals_126 = view_486 = t_47 = None
view_487: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_15, [2, sym_size_65, 768]); addmm_15 = sym_size_65 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_55: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_487, getitem_42); getitem_42 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_15 = torch.ops.aten.native_layer_norm.default(add_55, [768], orig_primals_127, orig_primals_128, 1e-05); add_55 = orig_primals_127 = orig_primals_128 = None
getitem_45: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_15[0]
getitem_46: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_15[1]
getitem_47: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_15[2]; native_layer_norm_15 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_304: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_45, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_48: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_129); orig_primals_129 = None
clone_72: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_304, memory_format = torch.contiguous_format)
sym_size_66: Sym(1024) = torch.ops.aten.sym_size(view_487, 1); view_487 = None
# No stacktrace found for following nodes
mul_64: Sym(2048) = sym_size_66 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_488: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_72, [mul_64, 768]); clone_72 = mul_64 = None
mm_32: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_488, t_48); view_488 = t_48 = None
view_489: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_32, [sym_size_66, 2, 768]); mm_32 = None
add_56: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_489, orig_primals_130); view_489 = orig_primals_130 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_49: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_131); orig_primals_131 = None
clone_73: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_304, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_65: Sym(2048) = sym_size_66 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_490: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_73, [mul_65, 768]); clone_73 = mul_65 = None
mm_33: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_490, t_49); view_490 = t_49 = None
view_491: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_33, [sym_size_66, 2, 768]); mm_33 = None
add_57: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_491, orig_primals_132); view_491 = orig_primals_132 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_50: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_133); orig_primals_133 = None
clone_74: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_304, memory_format = torch.contiguous_format); transpose_304 = None
# No stacktrace found for following nodes
mul_66: Sym(2048) = sym_size_66 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_492: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_74, [mul_66, 768]); clone_74 = mul_66 = None
mm_34: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_492, t_50); view_492 = t_50 = None
view_493: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_34, [sym_size_66, 2, 768]); mm_34 = sym_size_66 = None
add_58: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_493, orig_primals_134); view_493 = orig_primals_134 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_8: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_56, 8.0); add_56 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_494: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_8, [1024, 2, 12, 64])
transpose_305: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_494, 0, 1); view_494 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_495: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_57, [1024, 2, 12, 64]); add_57 = None
transpose_306: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_495, 0, 1); view_495 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_307: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_305, 1, 2); transpose_305 = None
view_496: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_307, [24, 1024, 64]); transpose_307 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_308: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_306, 1, 2); transpose_306 = None
view_497: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_308, [24, 1024, 64]); transpose_308 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_498: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_496, [24, 2, 512, 64]); view_496 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_48: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_498, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_498 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_499: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_497, [24, 2, 512, 64]); view_497 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_49: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_499, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_499 = None
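
as_strided_48 / as_strided_49 above implement the overlapping-chunk view from modeling_longformer.py:788: the [24, 1024, 64] query/key tensors are reinterpreted as 3 chunks of length 512 that overlap by 256, without copying. A hedged sketch of the same trick at toy size (seq_len=8, window_overlap=2 instead of 1024/256):

import torch

B, seq_len, dim, w = 2, 8, 3, 2
x = torch.arange(B * seq_len * dim, dtype=torch.float32).view(B, seq_len, dim)

n_chunks = seq_len // w - 1
chunk_size = (B, n_chunks, 2 * w, dim)
chunk_stride = (x.stride(0), w * x.stride(1), x.stride(1), x.stride(2))
chunks = x.as_strided(size=chunk_size, stride=chunk_stride)

# consecutive chunks overlap by w positions
assert torch.equal(chunks[:, 0, w:], chunks[:, 1, :w])
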
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_120: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_48, 4); as_strided_48 = None
permute_112: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_120, [0, 1, 2, 4, 3]); unsqueeze_120 = None
unsqueeze_121: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_49, 4); as_strided_49 = None
permute_113: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_121, [0, 1, 4, 2, 3]); unsqueeze_121 = None
permute_114: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_112, [0, 1, 2, 4, 3]); permute_112 = None
view_500: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_8, [1024, 2, 12, 64]); div_8 = None
transpose_309: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_500, 0, 1); view_500 = None
transpose_310: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_309, 1, 2); transpose_309 = None
view_501: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_310, [24, 1024, 64]); transpose_310 = None
view_502: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_501, [24, 2, 512, 64]); view_501 = None
as_strided_50: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_502, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_502 = None
unsqueeze_122: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_50, 4); as_strided_50 = None
permute_115: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_122, [0, 1, 2, 4, 3]); unsqueeze_122 = None
permute_116: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_115, [0, 1, 2, 4, 3]); permute_115 = None
clone_75: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_116, memory_format = torch.contiguous_format); permute_116 = None
_unsafe_view_40: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_75, [72, 512, 64]); clone_75 = None
permute_117: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_113, [0, 1, 4, 3, 2]); permute_113 = None
clone_76: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_117, memory_format = torch.contiguous_format); permute_117 = None
_unsafe_view_41: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_76, [72, 64, 512]); clone_76 = None
bmm_16: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_40, _unsafe_view_41); _unsafe_view_40 = _unsafe_view_41 = None
view_503: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_16, [24, 3, 512, 1, 512]); bmm_16 = None
permute_118: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_503, [0, 1, 2, 4, 3]); view_503 = None
view_504: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_118, [24, 3, 512, 512]); permute_118 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_32: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_504, [0, 0, 0, 1], 0.0); view_504 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_505: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_32, [24, 3, 512, 513]); constant_pad_nd_32 = None
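
constant_pad_nd_32 followed by view_505 is the pad-then-reinterpret trick from modeling_longformer.py:713-716: padding one extra row and viewing with the trailing sizes swapped skews row i by i columns, which lines the chunked query*key scores up along diagonals. A tiny hedged example:

import torch

x = torch.arange(9, dtype=torch.float32).view(1, 1, 3, 3)
padded = torch.nn.functional.pad(x, (0, 0, 0, 1))   # pad one extra row -> [1, 1, 4, 3]
skewed = padded.view(1, 1, 3, 4)                    # reinterpret -> row i shifted by i
# skewed[0, 0] is tensor([[0., 1., 2., 3.],
#                         [4., 5., 6., 7.],
#                         [8., 0., 0., 0.]])
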
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_16: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_505, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1417: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_505, 0, 0, 9223372036854775807)
slice_1418: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1417, 1, 0, 9223372036854775807); slice_1417 = None
slice_1419: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1418, 2, 0, 256); slice_1418 = None
slice_1420: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1419, 3, 0, 257); slice_1419 = None
slice_1421: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
slice_1422: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1421, 1, 0, -1); slice_1421 = None
slice_1423: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1422, 2, 0, 9223372036854775807); slice_1422 = None
slice_1424: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1423, 3, 256, 9223372036854775807); slice_1423 = None
slice_1425: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
slice_1426: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1425, 1, 0, -1); slice_1425 = None
slice_1427: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1426, 2, 0, 9223372036854775807); slice_1426 = None
slice_1428: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1427, 3, 256, 9223372036854775807); slice_1427 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1429: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_505, 0, 0, 9223372036854775807)
select_144: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1429, 1, -1); slice_1429 = None
slice_1430: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_144, 1, 256, 9223372036854775807); select_144 = None
slice_1431: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1430, 2, 0, 257); slice_1430 = None
slice_1432: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
select_145: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1432, 1, -1); slice_1432 = None
slice_1433: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_145, 1, 0, 9223372036854775807); select_145 = None
slice_1434: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1433, 2, 256, 9223372036854775807); slice_1433 = None
slice_1435: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
slice_1436: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1435, 1, 0, -1)
slice_1437: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1436, 2, 0, 9223372036854775807)
slice_scatter_352: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1437, slice_1420, 3, 256, 9223372036854775807); slice_1437 = slice_1420 = None
slice_scatter_353: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1436, slice_scatter_352, 2, 0, 9223372036854775807); slice_1436 = slice_scatter_352 = None
slice_scatter_354: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1435, slice_scatter_353, 1, 0, -1); slice_1435 = slice_scatter_353 = None
slice_scatter_355: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_16, slice_scatter_354, 0, 0, 9223372036854775807); slice_scatter_354 = None
slice_1438: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_355, 0, 0, 9223372036854775807)
select_146: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1438, 1, -1); slice_1438 = None
slice_1439: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_146, 1, 0, 9223372036854775807); select_146 = None
slice_1440: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1439, 2, 256, 9223372036854775807); slice_1439 = None
slice_1441: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
select_147: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1441, 1, -1); slice_1441 = None
slice_1442: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_147, 1, 0, 9223372036854775807); select_147 = None
slice_1443: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1442, 2, 256, 9223372036854775807); slice_1442 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1444: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_505, 0, 0, 9223372036854775807)
slice_1445: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1444, 1, 0, 9223372036854775807); slice_1444 = None
slice_1446: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1445, 2, -257, -1); slice_1445 = None
slice_1447: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1446, 3, 257, 9223372036854775807); slice_1446 = None
slice_1448: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
slice_1449: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1448, 1, 1, 9223372036854775807); slice_1448 = None
slice_1450: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1449, 2, 0, 9223372036854775807); slice_1449 = None
slice_1451: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1450, 3, 0, 256); slice_1450 = None
slice_1452: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_355, 0, 0, 9223372036854775807)
select_148: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1452, 1, -1)
slice_1453: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_148, 1, 0, 9223372036854775807)
slice_scatter_356: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1453, slice_1431, 2, 256, 9223372036854775807); slice_1453 = slice_1431 = None
slice_scatter_357: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_148, slice_scatter_356, 1, 0, 9223372036854775807); select_148 = slice_scatter_356 = None
select_scatter_32: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1452, slice_scatter_357, 1, -1); slice_1452 = slice_scatter_357 = None
slice_scatter_358: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_355, select_scatter_32, 0, 0, 9223372036854775807); slice_scatter_355 = select_scatter_32 = None
slice_1454: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_358, 0, 0, 9223372036854775807)
slice_1455: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1454, 1, 1, 9223372036854775807); slice_1454 = None
slice_1456: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1455, 2, 0, 9223372036854775807); slice_1455 = None
slice_1457: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1456, 3, 0, 256); slice_1456 = None
slice_1458: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
slice_1459: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1458, 1, 1, 9223372036854775807); slice_1458 = None
slice_1460: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1459, 2, 0, 9223372036854775807); slice_1459 = None
slice_1461: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1460, 3, 0, 256); slice_1460 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1462: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_505, 0, 0, 9223372036854775807); view_505 = None
select_149: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1462, 1, 0); slice_1462 = None
slice_1463: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_149, 1, 0, 255); select_149 = None
slice_1464: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1463, 2, -255, 9223372036854775807); slice_1463 = None
slice_1465: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
select_150: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1465, 1, 0); slice_1465 = None
slice_1466: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_150, 1, 1, 256); select_150 = None
slice_1467: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1466, 2, 1, 256); slice_1466 = None
slice_1468: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_358, 0, 0, 9223372036854775807)
slice_1469: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1468, 1, 1, 9223372036854775807)
slice_1470: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1469, 2, 0, 9223372036854775807)
slice_scatter_359: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1470, slice_1447, 3, 0, 256); slice_1470 = slice_1447 = None
slice_scatter_360: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1469, slice_scatter_359, 2, 0, 9223372036854775807); slice_1469 = slice_scatter_359 = None
slice_scatter_361: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1468, slice_scatter_360, 1, 1, 9223372036854775807); slice_1468 = slice_scatter_360 = None
slice_scatter_362: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_358, slice_scatter_361, 0, 0, 9223372036854775807); slice_scatter_358 = slice_scatter_361 = None
slice_1471: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_362, 0, 0, 9223372036854775807)
select_151: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1471, 1, 0); slice_1471 = None
slice_1472: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_151, 1, 1, 256); select_151 = None
slice_1473: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1472, 2, 1, 256); slice_1472 = None
slice_1474: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_16, 0, 0, 9223372036854775807)
select_152: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1474, 1, 0); slice_1474 = None
slice_1475: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_152, 1, 1, 256); select_152 = None
slice_1476: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1475, 2, 1, 256); slice_1475 = None
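
The new_empty / slice / slice_scatter / select_scatter chains above are the functionalized form of the in-place slice assignments at modeling_longformer.py:845-856 (e.g. diagonal_attention_scores[:, :-1, :, window_overlap:] = ...): each Python slice assignment becomes one out-of-place scatter per indexed dimension. A minimal single-dimension sketch of the equivalence:

import torch

x = torch.zeros(4, 6)
src = torch.ones(4, 3)

eager = x.clone()
eager[:, 2:5] = src                                        # in-place slice assignment

functional = torch.slice_scatter(x, src, dim=1, start=2, end=5)
assert torch.equal(eager, functional)
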
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_506: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513])
transpose_311: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_506, 2, 1); view_506 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1477: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_362, 0, 0, 9223372036854775807)
select_153: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1477, 1, 0)
slice_1478: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_153, 1, 1, 256)
slice_scatter_363: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1478, slice_1464, 2, 1, 256); slice_1478 = slice_1464 = None
slice_scatter_364: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_153, slice_scatter_363, 1, 1, 256); select_153 = slice_scatter_363 = None
select_scatter_33: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1477, slice_scatter_364, 1, 0); slice_1477 = slice_scatter_364 = None
slice_scatter_365: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_362, select_scatter_33, 0, 0, 9223372036854775807); slice_scatter_362 = select_scatter_33 = None
view_507: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_365, [2, 12, 1024, 513])
transpose_312: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_507, 2, 1); view_507 = None
new_ones_24: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_312, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_16: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_24); new_ones_24 = None
flip_32: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_16, [0]); tril_16 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_123: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_32, 0); flip_32 = None
slice_1479: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_123, 1, 0, 9223372036854775807); unsqueeze_123 = None
unsqueeze_124: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1479, 2); slice_1479 = None
slice_1480: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_124, 3, 0, 9223372036854775807); unsqueeze_124 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_33: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1480, [1, 3])
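
new_ones_24 / tril_16 / flip_32 / flip_33 above build the boundary masks from modeling_longformer.py:792-794, which blank out the window positions that would reach past the start (and, after the flip, past the end) of the sequence. The same construction at a toy window size of 3:

import torch

w = 3
beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
# beginning_mask_2d:
# tensor([[1., 1., 1., 0.],
#         [1., 1., 0., 0.],
#         [1., 0., 0., 0.]])
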
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1481: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_312, 0, 0, 9223372036854775807)
slice_1482: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1481, 1, 0, 256); slice_1481 = None
slice_1483: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1482, 2, 0, 9223372036854775807); slice_1482 = None
slice_1484: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1483, 3, 0, 257); slice_1483 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_32: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1480, [2, 256, 12, 257]); slice_1480 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_32: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_32, 1); expand_32 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_508: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_365, [2, 12, 1024, 513])
transpose_313: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_508, 2, 1); view_508 = None
slice_1485: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_313, 0, 0, 9223372036854775807); transpose_313 = None
slice_1486: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1485, 1, 0, 256); slice_1485 = None
slice_1487: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1486, 2, 0, 9223372036854775807); slice_1486 = None
slice_1488: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1487, 3, 0, 257); slice_1487 = None
masked_fill_48: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1488, eq_32, -inf); slice_1488 = eq_32 = None
view_509: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513])
transpose_314: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_509, 2, 1); view_509 = None
slice_1489: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_314, 0, 0, 9223372036854775807); transpose_314 = None
slice_1490: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1489, 1, 0, 256); slice_1489 = None
slice_1491: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1490, 2, 0, 9223372036854775807); slice_1490 = None
slice_1492: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1491, 3, 0, 257); slice_1491 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_510: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513])
transpose_315: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_510, 2, 1); view_510 = None
slice_1493: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_312, 0, 0, 9223372036854775807); transpose_312 = None
slice_1494: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1493, 1, -256, 9223372036854775807); slice_1493 = None
slice_1495: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1494, 2, 0, 9223372036854775807); slice_1494 = None
slice_1496: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1495, 3, -257, 9223372036854775807); slice_1495 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_33: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_33, [2, 256, 12, 257]); flip_33 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_33: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_33, 1); expand_33 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_511: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_365, [2, 12, 1024, 513]); slice_scatter_365 = None
transpose_316: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_511, 2, 1); view_511 = None
slice_1497: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_316, 0, 0, 9223372036854775807)
slice_1498: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1497, 1, 0, 256)
slice_1499: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1498, 2, 0, 9223372036854775807)
slice_scatter_366: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1499, masked_fill_48, 3, 0, 257); slice_1499 = masked_fill_48 = None
slice_scatter_367: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1498, slice_scatter_366, 2, 0, 9223372036854775807); slice_1498 = slice_scatter_366 = None
slice_scatter_368: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1497, slice_scatter_367, 1, 0, 256); slice_1497 = slice_scatter_367 = None
slice_scatter_369: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_316, slice_scatter_368, 0, 0, 9223372036854775807); transpose_316 = slice_scatter_368 = None
transpose_317: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_369, 2, 1); slice_scatter_369 = None
view_512: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_317, [24, 4, 256, 513]); transpose_317 = None
view_513: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_512, [2, 12, 1024, 513])
transpose_318: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_513, 2, 1); view_513 = None
slice_1500: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_318, 0, 0, 9223372036854775807); transpose_318 = None
slice_1501: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1500, 1, -256, 9223372036854775807); slice_1500 = None
slice_1502: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1501, 2, 0, 9223372036854775807); slice_1501 = None
slice_1503: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1502, 3, -257, 9223372036854775807); slice_1502 = None
masked_fill_49: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1503, eq_33, -inf); slice_1503 = eq_33 = None
view_514: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513])
transpose_319: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_514, 2, 1); view_514 = None
slice_1504: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_319, 0, 0, 9223372036854775807); transpose_319 = None
slice_1505: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1504, 1, -256, 9223372036854775807); slice_1504 = None
slice_1506: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1505, 2, 0, 9223372036854775807); slice_1505 = None
slice_1507: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1506, 3, -257, 9223372036854775807); slice_1506 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_8: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_1508: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_8, 0, 0, 9223372036854775807); ne_8 = None
slice_1509: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1508, 1, 0, 9223372036854775807); slice_1508 = None
unsqueeze_125: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1509, 2); slice_1509 = None
unsqueeze_126: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_125, 3); unsqueeze_125 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_8: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_126, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_50: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_8, unsqueeze_126, -10000.0); _to_copy_8 = unsqueeze_126 = None
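# A minimal standalone sketch (toy sizes, assumed names) of what the nodes above compute
# for modeling_longformer.py:585-588: a 0/1 attention mask is turned into an additive
# float mask via a dtype conversion plus masked_fill, mirroring the _to_copy_8 / masked_fill_50 pair.
import torch
attention_mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])            # toy (batch, seq)
remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
float_mask = remove_from_windowed_attention_mask.to(torch.float32).masked_fill(
    remove_from_windowed_attention_mask, -10000.0
)
# float_mask is -10000.0 wherever attention_mask != 0 and 0.0 elsewhere.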
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_25: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_50, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_320: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_25, 1, 2); new_ones_25 = None
view_515: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_320, [2, 1024, 1]); transpose_320 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_321: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_50, 1, 2); masked_fill_50 = None
view_516: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_321, [2, 1024, 1]); transpose_321 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_517: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_515, [2, 2, 512, 1]); view_515 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_51: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_517, [2, 3, 512, 1], [1024, 256, 1, 1]); view_517 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_518: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_516, [2, 2, 512, 1]); view_516 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_52: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_518, [2, 3, 512, 1], [1024, 256, 1, 1]); view_518 = None
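# A minimal standalone sketch (toy sizes, assumed names) of the chunking trick recorded by the
# view_517/as_strided_51 and view_518/as_strided_52 pairs above for modeling_longformer.py:775-788:
# chunks of length 2*w are re-strided so that neighbouring chunks overlap by w, without copying.
import torch
w = 2                                               # toy window_overlap (256 in the trace)
x = torch.arange(8.0).reshape(1, 8, 1)              # (batch*heads, seq_len, head_dim)
chunks = x.view(1, 8 // (2 * w), 2 * w, 1)          # (1, 2, 4, 1): non-overlapping view
overlapping = chunks.as_strided(
    size=(1, 8 // w - 1, 2 * w, 1),
    stride=(x.stride(0), w * x.stride(1), x.stride(1), x.stride(2)))
# overlapping[0, i, :, 0] covers rows [w*i, w*i + 2*w): [0..3], [2..5], [4..7]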
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_127: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_51, 4); as_strided_51 = None
permute_119: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_127, [0, 1, 2, 4, 3]); unsqueeze_127 = None
unsqueeze_128: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_52, 4); as_strided_52 = None
permute_120: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_128, [0, 1, 4, 2, 3]); unsqueeze_128 = None
mul_67: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_119, permute_120); permute_119 = permute_120 = None
view_519: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_67, [2, 3, 512, 512]); mul_67 = None
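# A minimal standalone check (toy sizes) of the contraction that the unsqueeze/permute/mul/view
# nodes above decompose, torch.einsum("bcxd,bcyd->bcxy", (query, key)) from
# modeling_longformer.py:827: it is a batched Q @ K^T (here head_dim is 1, so the traced
# decomposition degenerates to a broadcasted multiply).
import torch
b, c, x, y, d = 2, 3, 8, 8, 4
query, key = torch.randn(b, c, x, d), torch.randn(b, c, y, d)
scores = torch.einsum("bcxd,bcyd->bcxy", query, key)
assert torch.allclose(scores, query @ key.transpose(-1, -2), atol=1e-6)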
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_33: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_519, [0, 0, 0, 1], 0.0); view_519 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_520: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_33, [2, 3, 512, 513]); constant_pad_nd_33 = None
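# A minimal standalone sketch (toy sizes) of the pad-then-view trick recorded above for
# modeling_longformer.py:713-716: padding one extra row and re-viewing the block with the
# last two sizes swapped shifts each original row by one column, which diagonalizes the
# chunked attention scores (512x512 -> pad to 513x512 -> view as 512x513 in the trace).
import torch
x = torch.arange(16.0).reshape(1, 1, 4, 4)
padded = torch.nn.functional.pad(x, (0, 0, 0, 1))   # (1, 1, 5, 4): one zero row appended
skewed = padded.view(1, 1, 4, 5)                    # same 20 elements, re-viewed
# skewed[0, 0, 0] == tensor([0., 1., 2., 3., 4.]): each re-viewed row starts one element
# later than the previous original row boundary.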
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_17: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_520, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1510: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_520, 0, 0, 9223372036854775807)
slice_1511: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1510, 1, 0, 9223372036854775807); slice_1510 = None
slice_1512: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1511, 2, 0, 256); slice_1511 = None
slice_1513: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1512, 3, 0, 257); slice_1512 = None
slice_1514: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_17, 0, 0, 9223372036854775807)
slice_1515: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1514, 1, 0, -1); slice_1514 = None
slice_1516: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1515, 2, 0, 9223372036854775807); slice_1515 = None
slice_1517: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1516, 3, 256, 9223372036854775807); slice_1516 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1518: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_520, 0, 0, 9223372036854775807)
select_154: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1518, 1, -1); slice_1518 = None
slice_1519: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_154, 1, 256, 9223372036854775807); select_154 = None
slice_1520: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1519, 2, 0, 257); slice_1519 = None
slice_1521: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_17, 0, 0, 9223372036854775807)
select_155: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1521, 1, -1); slice_1521 = None
slice_1522: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_155, 1, 0, 9223372036854775807); select_155 = None
slice_1523: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1522, 2, 256, 9223372036854775807); slice_1522 = None
slice_1524: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_17, 0, 0, 9223372036854775807)
slice_1525: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1524, 1, 0, -1)
slice_1526: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1525, 2, 0, 9223372036854775807)
slice_scatter_370: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1526, slice_1513, 3, 256, 9223372036854775807); slice_1526 = slice_1513 = None
slice_scatter_371: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1525, slice_scatter_370, 2, 0, 9223372036854775807); slice_1525 = slice_scatter_370 = None
slice_scatter_372: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1524, slice_scatter_371, 1, 0, -1); slice_1524 = slice_scatter_371 = None
slice_scatter_373: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_17, slice_scatter_372, 0, 0, 9223372036854775807); slice_scatter_372 = None
slice_1527: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_373, 0, 0, 9223372036854775807)
select_156: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1527, 1, -1); slice_1527 = None
slice_1528: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_156, 1, 0, 9223372036854775807); select_156 = None
slice_1529: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1528, 2, 256, 9223372036854775807); slice_1528 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1530: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_520, 0, 0, 9223372036854775807)
slice_1531: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1530, 1, 0, 9223372036854775807); slice_1530 = None
slice_1532: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1531, 2, -257, -1); slice_1531 = None
slice_1533: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1532, 3, 257, 9223372036854775807); slice_1532 = None
slice_1534: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_17, 0, 0, 9223372036854775807)
slice_1535: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1534, 1, 1, 9223372036854775807); slice_1534 = None
slice_1536: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1535, 2, 0, 9223372036854775807); slice_1535 = None
slice_1537: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1536, 3, 0, 256); slice_1536 = None
slice_1538: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_373, 0, 0, 9223372036854775807)
select_157: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1538, 1, -1)
slice_1539: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_157, 1, 0, 9223372036854775807)
slice_scatter_374: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1539, slice_1520, 2, 256, 9223372036854775807); slice_1539 = slice_1520 = None
slice_scatter_375: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_157, slice_scatter_374, 1, 0, 9223372036854775807); select_157 = slice_scatter_374 = None
select_scatter_34: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1538, slice_scatter_375, 1, -1); slice_1538 = slice_scatter_375 = None
slice_scatter_376: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_373, select_scatter_34, 0, 0, 9223372036854775807); slice_scatter_373 = select_scatter_34 = None
slice_1540: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_376, 0, 0, 9223372036854775807)
slice_1541: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1540, 1, 1, 9223372036854775807); slice_1540 = None
slice_1542: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1541, 2, 0, 9223372036854775807); slice_1541 = None
slice_1543: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1542, 3, 0, 256); slice_1542 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1544: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_520, 0, 0, 9223372036854775807); view_520 = None
select_158: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1544, 1, 0); slice_1544 = None
slice_1545: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_158, 1, 0, 255); select_158 = None
slice_1546: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1545, 2, -255, 9223372036854775807); slice_1545 = None
slice_1547: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_17, 0, 0, 9223372036854775807)
select_159: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1547, 1, 0); slice_1547 = None
slice_1548: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_159, 1, 1, 256); select_159 = None
slice_1549: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1548, 2, 1, 256); slice_1548 = None
slice_1550: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_376, 0, 0, 9223372036854775807)
slice_1551: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1550, 1, 1, 9223372036854775807)
slice_1552: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1551, 2, 0, 9223372036854775807)
slice_scatter_377: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1552, slice_1533, 3, 0, 256); slice_1552 = slice_1533 = None
slice_scatter_378: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1551, slice_scatter_377, 2, 0, 9223372036854775807); slice_1551 = slice_scatter_377 = None
slice_scatter_379: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1550, slice_scatter_378, 1, 1, 9223372036854775807); slice_1550 = slice_scatter_378 = None
slice_scatter_380: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_376, slice_scatter_379, 0, 0, 9223372036854775807); slice_scatter_376 = slice_scatter_379 = None
slice_1553: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_380, 0, 0, 9223372036854775807)
select_160: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1553, 1, 0); slice_1553 = None
slice_1554: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_160, 1, 1, 256); select_160 = None
slice_1555: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1554, 2, 1, 256); slice_1554 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_521: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_17, [2, 1, 1024, 513]); new_empty_17 = None
transpose_322: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_521, 2, 1); view_521 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1556: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_380, 0, 0, 9223372036854775807)
select_161: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1556, 1, 0)
slice_1557: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_161, 1, 1, 256)
slice_scatter_381: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1557, slice_1546, 2, 1, 256); slice_1557 = slice_1546 = None
slice_scatter_382: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_161, slice_scatter_381, 1, 1, 256); select_161 = slice_scatter_381 = None
select_scatter_35: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1556, slice_scatter_382, 1, 0); slice_1556 = slice_scatter_382 = None
slice_scatter_383: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_380, select_scatter_35, 0, 0, 9223372036854775807); slice_scatter_380 = select_scatter_35 = None
view_522: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_383, [2, 1, 1024, 513])
transpose_323: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_522, 2, 1); view_522 = None
new_ones_26: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_323, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_17: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_26); new_ones_26 = None
flip_34: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_17, [0]); tril_17 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_129: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_34, 0); flip_34 = None
slice_1558: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_129, 1, 0, 9223372036854775807); unsqueeze_129 = None
unsqueeze_130: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1558, 2); slice_1558 = None
slice_1559: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_130, 3, 0, 9223372036854775807); unsqueeze_130 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_35: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1559, [1, 3])
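# A minimal standalone sketch (toy window size) of the masks built by the
# new_ones_26/tril_17/flip_34/flip_35 nodes above for modeling_longformer.py:792-794.
import torch
w = 3                                               # toy one_sided_attn_window_size (256 above)
beginning_mask_2d = torch.ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
# beginning_mask_2d:
# tensor([[1., 1., 1., 0.],
#         [1., 1., 0., 0.],
#         [1., 0., 0., 0.]])
# The 1s mark window positions that fall outside the sequence and get filled with -inf.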
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1560: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_323, 0, 0, 9223372036854775807)
slice_1561: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1560, 1, 0, 256); slice_1560 = None
slice_1562: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1561, 2, 0, 9223372036854775807); slice_1561 = None
slice_1563: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1562, 3, 0, 257); slice_1562 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_34: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1559, [2, 256, 1, 257]); slice_1559 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_34: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_34, 1); expand_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_523: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_383, [2, 1, 1024, 513])
transpose_324: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_523, 2, 1); view_523 = None
slice_1564: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_324, 0, 0, 9223372036854775807); transpose_324 = None
slice_1565: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1564, 1, 0, 256); slice_1564 = None
slice_1566: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1565, 2, 0, 9223372036854775807); slice_1565 = None
slice_1567: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1566, 3, 0, 257); slice_1566 = None
masked_fill_51: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1567, eq_34, -inf); slice_1567 = eq_34 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_1568: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_323, 0, 0, 9223372036854775807); transpose_323 = None
slice_1569: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1568, 1, -256, 9223372036854775807); slice_1568 = None
slice_1570: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1569, 2, 0, 9223372036854775807); slice_1569 = None
slice_1571: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1570, 3, -257, 9223372036854775807); slice_1570 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_35: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_35, [2, 256, 1, 257]); flip_35 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_35: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_35, 1); expand_35 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_524: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_383, [2, 1, 1024, 513]); slice_scatter_383 = None
transpose_325: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_524, 2, 1); view_524 = None
slice_1572: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_325, 0, 0, 9223372036854775807)
slice_1573: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1572, 1, 0, 256)
slice_1574: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1573, 2, 0, 9223372036854775807)
slice_scatter_384: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1574, masked_fill_51, 3, 0, 257); slice_1574 = masked_fill_51 = None
slice_scatter_385: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1573, slice_scatter_384, 2, 0, 9223372036854775807); slice_1573 = slice_scatter_384 = None
slice_scatter_386: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1572, slice_scatter_385, 1, 0, 256); slice_1572 = slice_scatter_385 = None
slice_scatter_387: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_325, slice_scatter_386, 0, 0, 9223372036854775807); transpose_325 = slice_scatter_386 = None
transpose_326: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_387, 2, 1); slice_scatter_387 = None
view_525: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_326, [2, 4, 256, 513]); transpose_326 = None
view_526: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_525, [2, 1, 1024, 513])
transpose_327: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_526, 2, 1); view_526 = None
slice_1575: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_327, 0, 0, 9223372036854775807); transpose_327 = None
slice_1576: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1575, 1, -256, 9223372036854775807); slice_1575 = None
slice_1577: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1576, 2, 0, 9223372036854775807); slice_1576 = None
slice_1578: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1577, 3, -257, 9223372036854775807); slice_1577 = None
masked_fill_52: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1578, eq_35, -inf); slice_1578 = eq_35 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_527: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513])
transpose_328: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_527, 2, 1); view_527 = None
view_528: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_512, [2, 12, 1024, 513]); view_512 = None
transpose_329: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_528, 2, 1); view_528 = None
slice_1579: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_329, 0, 0, 9223372036854775807)
slice_1580: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1579, 1, -256, 9223372036854775807)
slice_1581: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1580, 2, 0, 9223372036854775807)
slice_scatter_388: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1581, masked_fill_49, 3, -257, 9223372036854775807); slice_1581 = masked_fill_49 = None
slice_scatter_389: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1580, slice_scatter_388, 2, 0, 9223372036854775807); slice_1580 = slice_scatter_388 = None
slice_scatter_390: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1579, slice_scatter_389, 1, -256, 9223372036854775807); slice_1579 = slice_scatter_389 = None
slice_scatter_391: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_329, slice_scatter_390, 0, 0, 9223372036854775807); transpose_329 = slice_scatter_390 = None
transpose_330: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_391, 2, 1); slice_scatter_391 = None
view_529: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_330, [24, 4, 256, 513]); transpose_330 = None
view_530: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_529, [2, 12, 1024, 513]); view_529 = None
transpose_331: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_530, 2, 1); view_530 = None
view_531: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_525, [2, 1, 1024, 513]); view_525 = None
transpose_332: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_531, 2, 1); view_531 = None
slice_1582: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_332, 0, 0, 9223372036854775807)
slice_1583: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1582, 1, -256, 9223372036854775807)
slice_1584: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1583, 2, 0, 9223372036854775807)
slice_scatter_392: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1584, masked_fill_52, 3, -257, 9223372036854775807); slice_1584 = masked_fill_52 = None
slice_scatter_393: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1583, slice_scatter_392, 2, 0, 9223372036854775807); slice_1583 = slice_scatter_392 = None
slice_scatter_394: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1582, slice_scatter_393, 1, -256, 9223372036854775807); slice_1582 = slice_scatter_393 = None
slice_scatter_395: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_332, slice_scatter_394, 0, 0, 9223372036854775807); transpose_332 = slice_scatter_394 = None
transpose_333: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_395, 2, 1); slice_scatter_395 = None
view_532: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_333, [2, 4, 256, 513]); transpose_333 = None
view_533: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_532, [2, 1, 1024, 513]); view_532 = None
transpose_334: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_533, 2, 1); view_533 = None
add_59: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_331, transpose_334); transpose_331 = transpose_334 = None
view_534: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_16, [2, 12, 1024, 513]); new_empty_16 = None
transpose_335: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_534, 2, 1); view_534 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_8: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_59, -1, False); add_59 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_1585: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_1586: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1585, 1, 0, 9223372036854775807); slice_1585 = None
unsqueeze_131: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1586, 2); slice_1586 = None
unsqueeze_132: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_131, 3); unsqueeze_131 = None
masked_fill_53: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_8, unsqueeze_132, 0.0); _softmax_8 = unsqueeze_132 = None
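# A minimal standalone sketch (toy sizes, assumed names) of the _softmax_8/masked_fill_53 pair
# above for modeling_longformer.py:635-646: softmax over the window dimension, then the
# probabilities at fully-masked positions are zeroed out.
import torch
attn_scores = torch.randn(2, 6, 1, 5)               # toy (batch, seq, heads, window)
is_index_masked = torch.tensor([[False] * 4 + [True] * 2, [False] * 6])
attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1)
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)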
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_535: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_58, [1024, 2, 12, 64]); add_58 = None
transpose_336: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_535, 0, 1); view_535 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_337: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_53, 1, 2); masked_fill_53 = None
clone_77: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_337, memory_format = torch.contiguous_format); transpose_337 = None
_unsafe_view_42: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_77, [24, 4, 256, 513]); clone_77 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_338: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_336, 1, 2); transpose_336 = None
view_536: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_338, [24, 1024, 64]); transpose_338 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_34: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_536, [0, 0, 256, 256], -1.0); view_536 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_53: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_34, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_34 = None
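# A minimal standalone sketch (toy sizes, assumed names) of the constant_pad_nd_34/as_strided_53
# pair above for modeling_longformer.py:891-902: values are padded by the window size on both
# ends and then re-strided into chunks of length 3*w that step by w.
import torch
w = 2                                               # toy window_overlap (256 in the trace)
value = torch.arange(16.0).reshape(2, 8, 1)         # (batch*heads, seq_len, head_dim)
padded = torch.nn.functional.pad(value, (0, 0, w, w), value=-1)   # (2, 12, 1)
chunked = padded.as_strided(
    size=(2, 8 // w, 3 * w, 1),
    stride=(padded.stride(0), w * padded.stride(1), padded.stride(1), padded.stride(2)))
# chunked[b, i] covers padded rows [w*i, w*i + 3*w), so neighbouring chunks overlap by 2*w.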
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_35: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_42, [0, 257], 0.0); _unsafe_view_42 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_537: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_35, [24, 4, -1]); constant_pad_nd_35 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_1587: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_537, 0, 0, 9223372036854775807); view_537 = None
slice_1588: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1587, 1, 0, 9223372036854775807); slice_1587 = None
slice_1589: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1588, 2, 0, -256); slice_1588 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_538: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_1589, [24, 4, 256, 769]); slice_1589 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_1590: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_538, 0, 0, 9223372036854775807); view_538 = None
slice_1591: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1590, 1, 0, 9223372036854775807)
slice_1592: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1591, 2, 0, 9223372036854775807); slice_1591 = None
slice_1593: f32[24, 4, 256, 768], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1592, 3, 0, -1); slice_1592 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
unsqueeze_133: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1593, 4)
permute_121: f32[24, 4, 256, 1, 768], [788480, 197120, 769, 0, 1] = torch.ops.aten.permute.default(unsqueeze_133, [0, 1, 2, 4, 3]); unsqueeze_133 = None
unsqueeze_134: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_53, 4); as_strided_53 = None
permute_122: f32[24, 4, 1, 64, 768], [98304, 16384, 0, 1, 64] = torch.ops.aten.permute.default(unsqueeze_134, [0, 1, 4, 3, 2]); unsqueeze_134 = None
permute_123: f32[24, 4, 256, 768, 1], [788480, 197120, 769, 1, 0] = torch.ops.aten.permute.default(permute_121, [0, 1, 2, 4, 3]); permute_121 = None
sym_size_67: Sym(24) = torch.ops.aten.sym_size(slice_1590, 0); slice_1590 = None
# No stacktrace found for following nodes
mul_68: Sym(96) = sym_size_67 * 4
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:906, code: context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
sym_size_68: Sym(768) = torch.ops.aten.sym_size(slice_1593, 3); slice_1593 = None
view_539: f32[96, 256, 768], [197120, 769, 1] = torch.ops.aten.view.default(permute_123, [mul_68, 256, sym_size_68]); permute_123 = None
permute_124: f32[24, 4, 768, 64, 1], [98304, 16384, 64, 1, 0] = torch.ops.aten.permute.default(permute_122, [0, 1, 4, 3, 2]); permute_122 = None
clone_78: f32[24, 4, 768, 64, 1], [196608, 49152, 64, 1, 1] = torch.ops.aten.clone.default(permute_124, memory_format = torch.contiguous_format); permute_124 = None
_unsafe_view_43: f32[96, 768, 64], [49152, 64, 1] = torch.ops.aten._unsafe_view.default(clone_78, [mul_68, sym_size_68, 64]); clone_78 = mul_68 = sym_size_68 = None
bmm_17: f32[96, 256, 64], [16384, 64, 1] = torch.ops.aten.bmm.default(view_539, _unsafe_view_43); view_539 = _unsafe_view_43 = None
view_540: f32[24, 4, 256, 1, 64], [65536, 16384, 64, 64, 1] = torch.ops.aten.view.default(bmm_17, [sym_size_67, 4, 256, 1, 64]); bmm_17 = None
permute_125: f32[24, 4, 256, 64, 1], [65536, 16384, 64, 1, 64] = torch.ops.aten.permute.default(view_540, [0, 1, 2, 4, 3])
sym_size_69: Sym(4) = torch.ops.aten.sym_size(view_540, 1); view_540 = None
view_541: f32[24, 4, 256, 64], [65536, 16384, 64, 1] = torch.ops.aten.view.default(permute_125, [sym_size_67, sym_size_69, 256, 64]); permute_125 = sym_size_67 = sym_size_69 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:907, code: return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
view_542: f32[2, 12, 1024, 64], [786432, 65536, 64, 1] = torch.ops.aten.view.default(view_541, [2, 12, 1024, 64]); view_541 = None
transpose_339: f32[2, 1024, 12, 64], [786432, 64, 65536, 1] = torch.ops.aten.transpose.int(view_542, 1, 2)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:674, code: attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
transpose_340: f32[1024, 2, 12, 64], [64, 786432, 65536, 1] = torch.ops.aten.transpose.int(transpose_339, 0, 1); transpose_339 = None
clone_79: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.clone.default(transpose_340, memory_format = torch.contiguous_format); transpose_340 = None
_unsafe_view_44: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten._unsafe_view.default(clone_79, [1024, 2, 768]); clone_79 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:703, code: outputs = (attn_output.transpose(0, 1),)
transpose_341: f32[2, 1024, 768], [768, 1536, 1] = torch.ops.aten.transpose.int(_unsafe_view_44, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
t_51: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_135); orig_primals_135 = None
clone_80: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.clone.default(transpose_341, memory_format = torch.contiguous_format); transpose_341 = None
sym_size_70: Sym(1024) = torch.ops.aten.sym_size(view_542, 2); view_542 = None
# No stacktrace found for following nodes
mul_69: Sym(2048) = 2 * sym_size_70
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1116, code: hidden_states = self.dense(hidden_states)
sym_size_71: Sym(768) = torch.ops.aten.sym_size(_unsafe_view_44, 2); _unsafe_view_44 = None
view_543: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_80, [mul_69, sym_size_71]); clone_80 = mul_69 = sym_size_71 = None
mm_35: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_543, t_51); view_543 = t_51 = None
view_544: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(mm_35, [2, sym_size_70, 768]); mm_35 = sym_size_70 = None
add_60: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_544, orig_primals_136); orig_primals_136 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_61: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(add_60, getitem_45); add_60 = getitem_45 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1118, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_16 = torch.ops.aten.native_layer_norm.default(add_61, [768], orig_primals_137, orig_primals_138, 1e-05); add_61 = orig_primals_137 = orig_primals_138 = None
getitem_48: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_16[0]
getitem_49: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_16[1]
getitem_50: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_16[2]; native_layer_norm_16 = None
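# A minimal standalone sketch (toy sizes, assumed module names) of the pattern traced by
# mm_35/add_60/add_61/native_layer_norm_16 above for modeling_longformer.py:1116-1118:
# output projection, residual add with the attention input, then LayerNorm (dropout omitted).
import torch
hidden = torch.randn(2, 10, 16)                     # stands in for the (2, 1024, 768) attention output
input_tensor = torch.randn(2, 10, 16)
dense = torch.nn.Linear(16, 16)
layer_norm = torch.nn.LayerNorm(16, eps=1e-5)
out = layer_norm(dense(hidden) + input_tensor)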
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
t_52: f32[768, 3072], [1, 768] = torch.ops.aten.t.default(orig_primals_139); orig_primals_139 = None
sym_size_72: Sym(1024) = torch.ops.aten.sym_size(view_544, 1); view_544 = None
# No stacktrace found for following nodes
mul_70: Sym(2048) = 2 * sym_size_72
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1182, code: hidden_states = self.dense(hidden_states)
view_545: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(getitem_48, [mul_70, 768]); mul_70 = None
addmm_16: f32[2048, 3072], [3072, 1] = torch.ops.aten.addmm.default(orig_primals_140, view_545, t_52); orig_primals_140 = view_545 = t_52 = None
view_546: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.view.default(addmm_16, [2, sym_size_72, 3072]); addmm_16 = sym_size_72 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/activations.py:56, code: return self.act(input)
gelu_8: f32[2, 1024, 3072], [3145728, 3072, 1] = torch.ops.aten.gelu.default(view_546)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
t_53: f32[3072, 768], [1, 3072] = torch.ops.aten.t.default(orig_primals_141); orig_primals_141 = None
sym_size_73: Sym(1024) = torch.ops.aten.sym_size(view_546, 1); view_546 = None
# No stacktrace found for following nodes
mul_71: Sym(2048) = 2 * sym_size_73
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1196, code: hidden_states = self.dense(hidden_states)
view_547: f32[2048, 3072], [3072, 1] = torch.ops.aten.view.default(gelu_8, [mul_71, 3072]); gelu_8 = mul_71 = None
addmm_17: f32[2048, 768], [768, 1] = torch.ops.aten.addmm.default(orig_primals_142, view_547, t_53); orig_primals_142 = view_547 = t_53 = None
view_548: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.view.default(addmm_17, [2, sym_size_73, 768]); addmm_17 = sym_size_73 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
add_62: f32[2, 1024, 768], [786432, 768, 1] = torch.ops.aten.add.Tensor(view_548, getitem_48); getitem_48 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1198, code: hidden_states = self.LayerNorm(hidden_states + input_tensor)
native_layer_norm_17 = torch.ops.aten.native_layer_norm.default(add_62, [768], orig_primals_143, orig_primals_144, 1e-05); add_62 = orig_primals_143 = orig_primals_144 = None
getitem_51: f32[2, 1024, 768], [786432, 768, 1] = native_layer_norm_17[0]
getitem_52: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_17[1]
getitem_53: f32[2, 1024, 1], [1024, 1, 1] = native_layer_norm_17[2]; native_layer_norm_17 = None
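# A minimal standalone sketch (toy sizes, assumed module names) of the feed-forward block traced
# by addmm_16/gelu_8/addmm_17/native_layer_norm_17 above for modeling_longformer.py:1182-1198:
# Linear(768->3072), GELU, Linear(3072->768), residual add, LayerNorm.
import torch
x = torch.randn(2, 10, 16)                          # 768/3072 shrunk to 16/64 for the sketch
intermediate = torch.nn.Linear(16, 64)
output = torch.nn.Linear(64, 16)
layer_norm = torch.nn.LayerNorm(16, eps=1e-5)
y = layer_norm(output(torch.nn.functional.gelu(intermediate(x))) + x)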
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:562, code: hidden_states = hidden_states.transpose(0, 1)
transpose_342: f32[1024, 2, 768], [768, 786432, 1] = torch.ops.aten.transpose.int(getitem_51, 0, 1)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
t_54: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_145); orig_primals_145 = None
clone_81: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_342, memory_format = torch.contiguous_format)
sym_size_74: Sym(1024) = torch.ops.aten.sym_size(view_548, 1); view_548 = None
# No stacktrace found for following nodes
mul_72: Sym(2048) = sym_size_74 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:565, code: query_vectors = self.query(hidden_states)
view_549: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_81, [mul_72, 768]); clone_81 = mul_72 = None
mm_36: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_549, t_54); view_549 = t_54 = None
view_550: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_36, [sym_size_74, 2, 768]); mm_36 = None
add_63: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_550, orig_primals_146); view_550 = orig_primals_146 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
t_55: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_147); orig_primals_147 = None
clone_82: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_342, memory_format = torch.contiguous_format)
# No stacktrace found for following nodes
mul_73: Sym(2048) = sym_size_74 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:566, code: key_vectors = self.key(hidden_states)
view_551: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_82, [mul_73, 768]); clone_82 = mul_73 = None
mm_37: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_551, t_55); view_551 = t_55 = None
view_552: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_37, [sym_size_74, 2, 768]); mm_37 = None
add_64: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_552, orig_primals_148); view_552 = orig_primals_148 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
t_56: f32[768, 768], [1, 768] = torch.ops.aten.t.default(orig_primals_149); orig_primals_149 = None
clone_83: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.clone.default(transpose_342, memory_format = torch.contiguous_format); transpose_342 = None
# No stacktrace found for following nodes
mul_74: Sym(2048) = sym_size_74 * 2
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:567, code: value_vectors = self.value(hidden_states)
view_553: f32[2048, 768], [768, 1] = torch.ops.aten.view.default(clone_83, [mul_74, 768]); clone_83 = mul_74 = None
mm_38: f32[2048, 768], [768, 1] = torch.ops.aten.mm.default(view_553, t_56); view_553 = t_56 = None
view_554: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.view.default(mm_38, [sym_size_74, 2, 768]); mm_38 = sym_size_74 = None
add_65: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.add.Tensor(view_554, orig_primals_150); view_554 = orig_primals_150 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:575, code: query_vectors /= math.sqrt(self.head_dim)
div_9: f32[1024, 2, 768], [1536, 768, 1] = torch.ops.aten.div.Tensor(add_63, 8.0); add_63 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:577, code: query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_555: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_9, [1024, 2, 12, 64])
transpose_343: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_555, 0, 1); view_555 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:578, code: key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_556: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_64, [1024, 2, 12, 64]); add_64 = None
transpose_344: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_556, 0, 1); view_556 = None
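# A minimal standalone sketch (toy sizes, assumed module names) of the next layer's projections
# traced above for modeling_longformer.py:565-578: Q/K/V linear layers on a (seq, batch, embed)
# input, the query scaled by sqrt(head_dim) (the div by 8.0 = sqrt(64) in div_9), then reshaped
# to (batch, seq, heads, head_dim).
import math
import torch
seq_len, batch, embed, heads = 10, 2, 16, 4
head_dim = embed // heads
hidden = torch.randn(seq_len, batch, embed)
q_proj, k_proj, v_proj = (torch.nn.Linear(embed, embed) for _ in range(3))
q = q_proj(hidden) / math.sqrt(head_dim)
q = q.view(seq_len, batch, heads, head_dim).transpose(0, 1)
k = k_proj(hidden).view(seq_len, batch, heads, head_dim).transpose(0, 1)
v = v_proj(hidden).view(seq_len, batch, heads, head_dim).transpose(0, 1)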
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_345: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_343, 1, 2); transpose_343 = None
view_557: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_345, [24, 1024, 64]); transpose_345 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_346: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_344, 1, 2); transpose_344 = None
view_558: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_346, [24, 1024, 64]); transpose_346 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_559: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_557, [24, 2, 512, 64]); view_557 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_54: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_559, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_559 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_560: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_558, [24, 2, 512, 64]); view_558 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_55: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_560, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_560 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_135: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_54, 4); as_strided_54 = None
permute_126: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_135, [0, 1, 2, 4, 3]); unsqueeze_135 = None
unsqueeze_136: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_55, 4); as_strided_55 = None
permute_127: f32[24, 3, 1, 512, 64], [64, 393216, 0, 1536, 1] = torch.ops.aten.permute.default(unsqueeze_136, [0, 1, 4, 2, 3]); unsqueeze_136 = None
permute_128: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_126, [0, 1, 2, 4, 3]); permute_126 = None
view_561: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(div_9, [1024, 2, 12, 64]); div_9 = None
transpose_347: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_561, 0, 1); view_561 = None
transpose_348: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_347, 1, 2); transpose_347 = None
view_562: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_348, [24, 1024, 64]); transpose_348 = None
view_563: f32[24, 2, 512, 64], [64, 786432, 1536, 1] = torch.ops.aten.view.default(view_562, [24, 2, 512, 64]); view_562 = None
as_strided_56: f32[24, 3, 512, 64], [64, 393216, 1536, 1] = torch.ops.aten.as_strided.default(view_563, [24, 3, 512, 64], [64, 393216, 1536, 1]); view_563 = None
unsqueeze_137: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_56, 4); as_strided_56 = None
permute_129: f32[24, 3, 512, 1, 64], [64, 393216, 1536, 0, 1] = torch.ops.aten.permute.default(unsqueeze_137, [0, 1, 2, 4, 3]); unsqueeze_137 = None
permute_130: f32[24, 3, 512, 64, 1], [64, 393216, 1536, 1, 0] = torch.ops.aten.permute.default(permute_129, [0, 1, 2, 4, 3]); permute_129 = None
clone_84: f32[24, 3, 512, 64, 1], [98304, 32768, 64, 1, 1] = torch.ops.aten.clone.default(permute_130, memory_format = torch.contiguous_format); permute_130 = None
_unsafe_view_45: f32[72, 512, 64], [32768, 64, 1] = torch.ops.aten._unsafe_view.default(clone_84, [72, 512, 64]); clone_84 = None
permute_131: f32[24, 3, 64, 512, 1], [64, 393216, 1, 1536, 0] = torch.ops.aten.permute.default(permute_127, [0, 1, 4, 3, 2]); permute_127 = None
clone_85: f32[24, 3, 64, 512, 1], [98304, 32768, 512, 1, 1] = torch.ops.aten.clone.default(permute_131, memory_format = torch.contiguous_format); permute_131 = None
_unsafe_view_46: f32[72, 64, 512], [32768, 512, 1] = torch.ops.aten._unsafe_view.default(clone_85, [72, 64, 512]); clone_85 = None
bmm_18: f32[72, 512, 512], [262144, 512, 1] = torch.ops.aten.bmm.default(_unsafe_view_45, _unsafe_view_46); _unsafe_view_45 = _unsafe_view_46 = None
view_564: f32[24, 3, 512, 1, 512], [786432, 262144, 512, 512, 1] = torch.ops.aten.view.default(bmm_18, [24, 3, 512, 1, 512]); bmm_18 = None
permute_132: f32[24, 3, 512, 512, 1], [786432, 262144, 512, 1, 512] = torch.ops.aten.permute.default(view_564, [0, 1, 2, 4, 3]); view_564 = None
view_565: f32[24, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(permute_132, [24, 3, 512, 512]); permute_132 = None
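# A minimal sketch showing that the unsqueeze/permute/bmm/view sequence above is just the
# decomposition of the einsum quoted at modeling_longformer.py:827, a batched matmul over the
# chunk dimension:
import torch
query = torch.randn(24, 3, 512, 64)
key = torch.randn(24, 3, 512, 64)
scores_einsum = torch.einsum("bcxd,bcyd->bcxy", (query, key))
scores_bmm = torch.bmm(query.reshape(72, 512, 64), key.reshape(72, 512, 64).transpose(1, 2))
assert torch.allclose(scores_einsum, scores_bmm.view(24, 3, 512, 512), atol=1e-3)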
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_36: f32[24, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_565, [0, 0, 0, 1], 0.0); view_565 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_566: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_36, [24, 3, 512, 513]); constant_pad_nd_36 = None
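# A minimal sketch of the pad-and-reshape skew above (modeling_longformer.py:713/716): padding one
# extra row and re-viewing with the last two sizes swapped shifts each row so that the attention
# diagonals of the chunked scores line up as plain columns, which the copies below rely on.
import torch
import torch.nn.functional as F
chunked_scores = torch.randn(24, 3, 512, 512)
padded = F.pad(chunked_scores, (0, 0, 0, 1))         # -> [24, 3, 513, 512]
skewed = padded.view(24, 3, 512, 513)                # -> [24, 3, 512, 513]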
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_18: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_566, [24, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1594: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_566, 0, 0, 9223372036854775807)
slice_1595: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1594, 1, 0, 9223372036854775807); slice_1594 = None
slice_1596: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1595, 2, 0, 256); slice_1595 = None
slice_1597: f32[24, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1596, 3, 0, 257); slice_1596 = None
slice_1598: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
slice_1599: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1598, 1, 0, -1); slice_1598 = None
slice_1600: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1599, 2, 0, 9223372036854775807); slice_1599 = None
slice_1601: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1600, 3, 256, 9223372036854775807); slice_1600 = None
slice_1602: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
slice_1603: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1602, 1, 0, -1); slice_1602 = None
slice_1604: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1603, 2, 0, 9223372036854775807); slice_1603 = None
slice_1605: f32[24, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1604, 3, 256, 9223372036854775807); slice_1604 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1606: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_566, 0, 0, 9223372036854775807)
select_162: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1606, 1, -1); slice_1606 = None
slice_1607: f32[24, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_162, 1, 256, 9223372036854775807); select_162 = None
slice_1608: f32[24, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1607, 2, 0, 257); slice_1607 = None
slice_1609: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
select_163: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1609, 1, -1); slice_1609 = None
slice_1610: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_163, 1, 0, 9223372036854775807); select_163 = None
slice_1611: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1610, 2, 256, 9223372036854775807); slice_1610 = None
slice_1612: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
slice_1613: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1612, 1, 0, -1)
slice_1614: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1613, 2, 0, 9223372036854775807)
slice_scatter_396: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1614, slice_1597, 3, 256, 9223372036854775807); slice_1614 = slice_1597 = None
slice_scatter_397: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1613, slice_scatter_396, 2, 0, 9223372036854775807); slice_1613 = slice_scatter_396 = None
slice_scatter_398: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1612, slice_scatter_397, 1, 0, -1); slice_1612 = slice_scatter_397 = None
slice_scatter_399: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_18, slice_scatter_398, 0, 0, 9223372036854775807); slice_scatter_398 = None
slice_1615: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_399, 0, 0, 9223372036854775807)
select_164: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1615, 1, -1); slice_1615 = None
slice_1616: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_164, 1, 0, 9223372036854775807); select_164 = None
slice_1617: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1616, 2, 256, 9223372036854775807); slice_1616 = None
slice_1618: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
select_165: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1618, 1, -1); slice_1618 = None
slice_1619: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_165, 1, 0, 9223372036854775807); select_165 = None
slice_1620: f32[24, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1619, 2, 256, 9223372036854775807); slice_1619 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1621: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_566, 0, 0, 9223372036854775807)
slice_1622: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1621, 1, 0, 9223372036854775807); slice_1621 = None
slice_1623: f32[24, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1622, 2, -257, -1); slice_1622 = None
slice_1624: f32[24, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1623, 3, 257, 9223372036854775807); slice_1623 = None
slice_1625: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
slice_1626: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1625, 1, 1, 9223372036854775807); slice_1625 = None
slice_1627: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1626, 2, 0, 9223372036854775807); slice_1626 = None
slice_1628: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1627, 3, 0, 256); slice_1627 = None
slice_1629: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_399, 0, 0, 9223372036854775807)
select_166: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1629, 1, -1)
slice_1630: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_166, 1, 0, 9223372036854775807)
slice_scatter_400: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1630, slice_1608, 2, 256, 9223372036854775807); slice_1630 = slice_1608 = None
slice_scatter_401: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_166, slice_scatter_400, 1, 0, 9223372036854775807); select_166 = slice_scatter_400 = None
select_scatter_36: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1629, slice_scatter_401, 1, -1); slice_1629 = slice_scatter_401 = None
slice_scatter_402: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_399, select_scatter_36, 0, 0, 9223372036854775807); slice_scatter_399 = select_scatter_36 = None
slice_1631: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_402, 0, 0, 9223372036854775807)
slice_1632: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1631, 1, 1, 9223372036854775807); slice_1631 = None
slice_1633: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1632, 2, 0, 9223372036854775807); slice_1632 = None
slice_1634: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1633, 3, 0, 256); slice_1633 = None
slice_1635: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
slice_1636: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1635, 1, 1, 9223372036854775807); slice_1635 = None
slice_1637: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1636, 2, 0, 9223372036854775807); slice_1636 = None
slice_1638: f32[24, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1637, 3, 0, 256); slice_1637 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1639: f32[24, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_566, 0, 0, 9223372036854775807); view_566 = None
select_167: f32[24, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1639, 1, 0); slice_1639 = None
slice_1640: f32[24, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_167, 1, 0, 255); select_167 = None
slice_1641: f32[24, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1640, 2, -255, 9223372036854775807); slice_1640 = None
slice_1642: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
select_168: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1642, 1, 0); slice_1642 = None
slice_1643: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_168, 1, 1, 256); select_168 = None
slice_1644: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1643, 2, 1, 256); slice_1643 = None
slice_1645: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_402, 0, 0, 9223372036854775807)
slice_1646: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1645, 1, 1, 9223372036854775807)
slice_1647: f32[24, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1646, 2, 0, 9223372036854775807)
slice_scatter_403: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1647, slice_1624, 3, 0, 256); slice_1647 = slice_1624 = None
slice_scatter_404: f32[24, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1646, slice_scatter_403, 2, 0, 9223372036854775807); slice_1646 = slice_scatter_403 = None
slice_scatter_405: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1645, slice_scatter_404, 1, 1, 9223372036854775807); slice_1645 = slice_scatter_404 = None
slice_scatter_406: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_402, slice_scatter_405, 0, 0, 9223372036854775807); slice_scatter_402 = slice_scatter_405 = None
slice_1648: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_406, 0, 0, 9223372036854775807)
select_169: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1648, 1, 0); slice_1648 = None
slice_1649: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_169, 1, 1, 256); select_169 = None
slice_1650: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1649, 2, 1, 256); slice_1649 = None
slice_1651: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_18, 0, 0, 9223372036854775807)
select_170: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1651, 1, 0); slice_1651 = None
slice_1652: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_170, 1, 1, 256); select_170 = None
slice_1653: f32[24, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1652, 2, 1, 256); slice_1652 = None
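# A minimal sketch of the four copies quoted at modeling_longformer.py:845-856 (the slice/scatter
# ops above): the skewed chunk scores are written into a fresh [batch*heads, chunks + 1, w, 2*w + 1]
# buffer (new_empty_18) so each query position gets its 2*w + 1 window in one contiguous last dim.
import torch
w = 256
chunked = torch.randn(24, 3, 2 * w, 2 * w + 1)             # diagonal_chunked_attention_scores
diag = chunked.new_empty((24, 4, w, 2 * w + 1))            # diagonal_attention_scores
diag[:, :-1, :, w:] = chunked[:, :, :w, : w + 1]           # diagonal and upper triangle, all but last chunk
diag[:, -1, :, w:] = chunked[:, -1, w:, : w + 1]           # diagonal and upper triangle, last chunk
diag[:, 1:, :, :w] = chunked[:, :, -(w + 1):-1, w + 1:]    # lower triangle
diag[:, 0, 1:w, 1:w] = chunked[:, 0, : w - 1, 1 - w:]      # top-left corner of the first chunk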
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_567: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513])
transpose_349: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_567, 2, 1); view_567 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1654: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_406, 0, 0, 9223372036854775807)
select_171: f32[24, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1654, 1, 0)
slice_1655: f32[24, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_171, 1, 1, 256)
slice_scatter_407: f32[24, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1655, slice_1641, 2, 1, 256); slice_1655 = slice_1641 = None
slice_scatter_408: f32[24, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_171, slice_scatter_407, 1, 1, 256); select_171 = slice_scatter_407 = None
select_scatter_37: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1654, slice_scatter_408, 1, 0); slice_1654 = slice_scatter_408 = None
slice_scatter_409: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_406, select_scatter_37, 0, 0, 9223372036854775807); slice_scatter_406 = select_scatter_37 = None
view_568: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_409, [2, 12, 1024, 513])
transpose_350: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_568, 2, 1); view_568 = None
new_ones_27: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_350, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_18: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_27); new_ones_27 = None
flip_36: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_18, [0]); tril_18 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_138: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_36, 0); flip_36 = None
slice_1656: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_138, 1, 0, 9223372036854775807); unsqueeze_138 = None
unsqueeze_139: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1656, 2); slice_1656 = None
slice_1657: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_139, 3, 0, 9223372036854775807); unsqueeze_139 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_37: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1657, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1658: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_350, 0, 0, 9223372036854775807)
slice_1659: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1658, 1, 0, 256); slice_1658 = None
slice_1660: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1659, 2, 0, 9223372036854775807); slice_1659 = None
slice_1661: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1660, 3, 0, 257); slice_1660 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_36: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1657, [2, 256, 12, 257]); slice_1657 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_36: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_36, 1); expand_36 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_569: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_409, [2, 12, 1024, 513])
transpose_351: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_569, 2, 1); view_569 = None
slice_1662: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_351, 0, 0, 9223372036854775807); transpose_351 = None
slice_1663: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1662, 1, 0, 256); slice_1662 = None
slice_1664: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1663, 2, 0, 9223372036854775807); slice_1663 = None
slice_1665: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1664, 3, 0, 257); slice_1664 = None
masked_fill_54: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1665, eq_36, -inf); slice_1665 = eq_36 = None
view_570: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513])
transpose_352: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_570, 2, 1); view_570 = None
slice_1666: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_352, 0, 0, 9223372036854775807); transpose_352 = None
slice_1667: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1666, 1, 0, 256); slice_1666 = None
slice_1668: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1667, 2, 0, 9223372036854775807); slice_1667 = None
slice_1669: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1668, 3, 0, 257); slice_1668 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
view_571: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513])
transpose_353: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_571, 2, 1); view_571 = None
slice_1670: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_350, 0, 0, 9223372036854775807); transpose_350 = None
slice_1671: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1670, 1, -256, 9223372036854775807); slice_1670 = None
slice_1672: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1671, 2, 0, 9223372036854775807); slice_1671 = None
slice_1673: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1672, 3, -257, 9223372036854775807); slice_1672 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_37: f32[2, 256, 12, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_37, [2, 256, 12, 257]); flip_37 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_37: b8[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.eq.Scalar(expand_37, 1); expand_37 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_572: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_409, [2, 12, 1024, 513]); slice_scatter_409 = None
transpose_354: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_572, 2, 1); view_572 = None
slice_1674: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_354, 0, 0, 9223372036854775807)
slice_1675: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1674, 1, 0, 256)
slice_1676: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1675, 2, 0, 9223372036854775807)
slice_scatter_410: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1676, masked_fill_54, 3, 0, 257); slice_1676 = masked_fill_54 = None
slice_scatter_411: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1675, slice_scatter_410, 2, 0, 9223372036854775807); slice_1675 = slice_scatter_410 = None
slice_scatter_412: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1674, slice_scatter_411, 1, 0, 256); slice_1674 = slice_scatter_411 = None
slice_scatter_413: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_354, slice_scatter_412, 0, 0, 9223372036854775807); transpose_354 = slice_scatter_412 = None
transpose_355: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_413, 2, 1); slice_scatter_413 = None
view_573: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_355, [24, 4, 256, 513]); transpose_355 = None
view_574: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_573, [2, 12, 1024, 513])
transpose_356: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_574, 2, 1); view_574 = None
slice_1677: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_356, 0, 0, 9223372036854775807); transpose_356 = None
slice_1678: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1677, 1, -256, 9223372036854775807); slice_1677 = None
slice_1679: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1678, 2, 0, 9223372036854775807); slice_1678 = None
slice_1680: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1679, 3, -257, 9223372036854775807); slice_1679 = None
masked_fill_55: f32[2, 256, 12, 257], [789504, 3084, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1680, eq_37, -inf); slice_1680 = eq_37 = None
view_575: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513])
transpose_357: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_575, 2, 1); view_575 = None
slice_1681: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_357, 0, 0, 9223372036854775807); transpose_357 = None
slice_1682: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1681, 1, -256, 9223372036854775807); slice_1681 = None
slice_1683: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1682, 2, 0, 9223372036854775807); slice_1682 = None
slice_1684: f32[2, 256, 12, 257], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1683, 3, -257, 9223372036854775807); slice_1683 = None
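# A minimal sketch of the boundary masking quoted at modeling_longformer.py:792-800 (the tril/flip
# mask and the two masked_fill ops above): a flipped lower-triangular mask blanks window positions
# that would reach before the start of the sequence, and its double flip does the same for the end.
import torch
w = 256
scores = torch.randn(2, 1024, 12, 2 * w + 1)               # (batch, seq_len, heads, window)
beginning_mask = scores.new_ones(w, w + 1).tril().flip(dims=[0])[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
beginning_input = scores[:, :w, :, : w + 1]
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))
ending_input = scores[:, -w:, :, -(w + 1):]
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))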
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
ne_9: b8[2, 1024], [1024, 1] = torch.ops.aten.ne.Scalar(orig_primals_194, 0)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:585, code: remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
slice_1685: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(ne_9, 0, 0, 9223372036854775807); ne_9 = None
slice_1686: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1685, 1, 0, 9223372036854775807); slice_1685 = None
unsqueeze_140: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1686, 2); slice_1686 = None
unsqueeze_141: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_140, 3); unsqueeze_140 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:588, code: float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
_to_copy_9: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten._to_copy.default(unsqueeze_141, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
masked_fill_56: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.masked_fill.Scalar(_to_copy_9, unsqueeze_141, -10000.0); _to_copy_9 = unsqueeze_141 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:593, code: float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
new_ones_28: f32[2, 1024, 1, 1], [1024, 1, 1, 1] = torch.ops.aten.new_ones.default(masked_fill_56, [2, 1024, 1, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:817, code: query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_358: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(new_ones_28, 1, 2); new_ones_28 = None
view_576: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_358, [2, 1024, 1]); transpose_358 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:818, code: key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_359: f32[2, 1, 1024, 1], [1024, 1, 1, 1] = torch.ops.aten.transpose.int(masked_fill_56, 1, 2); masked_fill_56 = None
view_577: f32[2, 1024, 1], [1024, 1, 1] = torch.ops.aten.view.default(transpose_359, [2, 1024, 1]); transpose_359 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_578: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_576, [2, 2, 512, 1]); view_576 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_57: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_578, [2, 3, 512, 1], [1024, 256, 1, 1]); view_578 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:775, code: hidden_states = hidden_states.view(
view_579: f32[2, 2, 512, 1], [1024, 512, 1, 1] = torch.ops.aten.view.default(view_577, [2, 2, 512, 1]); view_577 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:788, code: return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
as_strided_58: f32[2, 3, 512, 1], [1024, 256, 1, 1] = torch.ops.aten.as_strided.default(view_579, [2, 3, 512, 1], [1024, 256, 1, 1]); view_579 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:827, code: diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
unsqueeze_142: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_57, 4); as_strided_57 = None
permute_133: f32[2, 3, 512, 1, 1], [1024, 256, 1, 0, 1] = torch.ops.aten.permute.default(unsqueeze_142, [0, 1, 2, 4, 3]); unsqueeze_142 = None
unsqueeze_143: f32[2, 3, 512, 1, 1], [1024, 256, 1, 1, 0] = torch.ops.aten.unsqueeze.default(as_strided_58, 4); as_strided_58 = None
permute_134: f32[2, 3, 1, 512, 1], [1024, 256, 0, 1, 1] = torch.ops.aten.permute.default(unsqueeze_143, [0, 1, 4, 2, 3]); unsqueeze_143 = None
mul_75: f32[2, 3, 512, 512, 1], [786432, 262144, 512, 1, 1] = torch.ops.aten.mul.Tensor(permute_133, permute_134); permute_133 = permute_134 = None
view_580: f32[2, 3, 512, 512], [786432, 262144, 512, 1] = torch.ops.aten.view.default(mul_75, [2, 3, 512, 512]); mul_75 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:713, code: hidden_states_padded = nn.functional.pad(
constant_pad_nd_37: f32[2, 3, 513, 512], [787968, 262656, 512, 1] = torch.ops.aten.constant_pad_nd.default(view_580, [0, 0, 0, 1], 0.0); view_580 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:716, code: hidden_states_padded = hidden_states_padded.view(
view_581: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.view.default(constant_pad_nd_37, [2, 3, 512, 513]); constant_pad_nd_37 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:839, code: diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
new_empty_19: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.new_empty.default(view_581, [2, 4, 256, 513], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:845, code: diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1687: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_581, 0, 0, 9223372036854775807)
slice_1688: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1687, 1, 0, 9223372036854775807); slice_1687 = None
slice_1689: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1688, 2, 0, 256); slice_1688 = None
slice_1690: f32[2, 3, 256, 257], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1689, 3, 0, 257); slice_1689 = None
slice_1691: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_19, 0, 0, 9223372036854775807)
slice_1692: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1691, 1, 0, -1); slice_1691 = None
slice_1693: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1692, 2, 0, 9223372036854775807); slice_1692 = None
slice_1694: f32[2, 3, 256, 257], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1693, 3, 256, 9223372036854775807); slice_1693 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:848, code: diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
slice_1695: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_581, 0, 0, 9223372036854775807)
select_172: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1695, 1, -1); slice_1695 = None
slice_1696: f32[2, 256, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_172, 1, 256, 9223372036854775807); select_172 = None
slice_1697: f32[2, 256, 257], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1696, 2, 0, 257); slice_1696 = None
slice_1698: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_19, 0, 0, 9223372036854775807)
select_173: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1698, 1, -1); slice_1698 = None
slice_1699: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_173, 1, 0, 9223372036854775807); select_173 = None
slice_1700: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1699, 2, 256, 9223372036854775807); slice_1699 = None
slice_1701: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_19, 0, 0, 9223372036854775807)
slice_1702: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1701, 1, 0, -1)
slice_1703: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1702, 2, 0, 9223372036854775807)
slice_scatter_414: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1703, slice_1690, 3, 256, 9223372036854775807); slice_1703 = slice_1690 = None
slice_scatter_415: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1702, slice_scatter_414, 2, 0, 9223372036854775807); slice_1702 = slice_scatter_414 = None
slice_scatter_416: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1701, slice_scatter_415, 1, 0, -1); slice_1701 = slice_scatter_415 = None
slice_scatter_417: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(new_empty_19, slice_scatter_416, 0, 0, 9223372036854775807); slice_scatter_416 = None
slice_1704: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_417, 0, 0, 9223372036854775807)
select_174: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1704, 1, -1); slice_1704 = None
slice_1705: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_174, 1, 0, 9223372036854775807); select_174 = None
slice_1706: f32[2, 256, 257], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1705, 2, 256, 9223372036854775807); slice_1705 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:852, code: diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
slice_1707: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_581, 0, 0, 9223372036854775807)
slice_1708: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1707, 1, 0, 9223372036854775807); slice_1707 = None
slice_1709: f32[2, 3, 256, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1708, 2, -257, -1); slice_1708 = None
slice_1710: f32[2, 3, 256, 256], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(slice_1709, 3, 257, 9223372036854775807); slice_1709 = None
slice_1711: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_19, 0, 0, 9223372036854775807)
slice_1712: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1711, 1, 1, 9223372036854775807); slice_1711 = None
slice_1713: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1712, 2, 0, 9223372036854775807); slice_1712 = None
slice_1714: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1713, 3, 0, 256); slice_1713 = None
slice_1715: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_417, 0, 0, 9223372036854775807)
select_175: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1715, 1, -1)
slice_1716: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_175, 1, 0, 9223372036854775807)
slice_scatter_418: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1716, slice_1697, 2, 256, 9223372036854775807); slice_1716 = slice_1697 = None
slice_scatter_419: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_175, slice_scatter_418, 1, 0, 9223372036854775807); select_175 = slice_scatter_418 = None
select_scatter_38: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1715, slice_scatter_419, 1, -1); slice_1715 = slice_scatter_419 = None
slice_scatter_420: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_417, select_scatter_38, 0, 0, 9223372036854775807); slice_scatter_417 = select_scatter_38 = None
slice_1717: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_420, 0, 0, 9223372036854775807)
slice_1718: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1717, 1, 1, 9223372036854775807); slice_1717 = None
slice_1719: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1718, 2, 0, 9223372036854775807); slice_1718 = None
slice_1720: f32[2, 3, 256, 256], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1719, 3, 0, 256); slice_1719 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:856, code: diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
slice_1721: f32[2, 3, 512, 513], [787968, 262656, 513, 1] = torch.ops.aten.slice.Tensor(view_581, 0, 0, 9223372036854775807); view_581 = None
select_176: f32[2, 512, 513], [787968, 513, 1] = torch.ops.aten.select.int(slice_1721, 1, 0); slice_1721 = None
slice_1722: f32[2, 255, 513], [787968, 513, 1] = torch.ops.aten.slice.Tensor(select_176, 1, 0, 255); select_176 = None
slice_1723: f32[2, 255, 255], [787968, 513, 1] = torch.ops.aten.slice.Tensor(slice_1722, 2, -255, 9223372036854775807); slice_1722 = None
slice_1724: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(new_empty_19, 0, 0, 9223372036854775807)
select_177: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1724, 1, 0); slice_1724 = None
slice_1725: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_177, 1, 1, 256); select_177 = None
slice_1726: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1725, 2, 1, 256); slice_1725 = None
slice_1727: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_420, 0, 0, 9223372036854775807)
slice_1728: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1727, 1, 1, 9223372036854775807)
slice_1729: f32[2, 3, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_1728, 2, 0, 9223372036854775807)
slice_scatter_421: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1729, slice_1710, 3, 0, 256); slice_1729 = slice_1710 = None
slice_scatter_422: f32[2, 3, 256, 513], [393984, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1728, slice_scatter_421, 2, 0, 9223372036854775807); slice_1728 = slice_scatter_421 = None
slice_scatter_423: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1727, slice_scatter_422, 1, 1, 9223372036854775807); slice_1727 = slice_scatter_422 = None
slice_scatter_424: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_420, slice_scatter_423, 0, 0, 9223372036854775807); slice_scatter_420 = slice_scatter_423 = None
slice_1730: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_424, 0, 0, 9223372036854775807)
select_178: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1730, 1, 0); slice_1730 = None
slice_1731: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_178, 1, 1, 256); select_178 = None
slice_1732: f32[2, 255, 255], [525312, 513, 1] = torch.ops.aten.slice.Tensor(slice_1731, 2, 1, 256); slice_1731 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:861, code: diagonal_attention_scores = diagonal_attention_scores.view(
view_582: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_19, [2, 1, 1024, 513]); new_empty_19 = None
transpose_360: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_582, 2, 1); view_582 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:792, code: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
slice_1733: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice.Tensor(slice_scatter_424, 0, 0, 9223372036854775807)
select_179: f32[2, 256, 513], [525312, 513, 1] = torch.ops.aten.select.int(slice_1733, 1, 0)
slice_1734: f32[2, 255, 513], [525312, 513, 1] = torch.ops.aten.slice.Tensor(select_179, 1, 1, 256)
slice_scatter_425: f32[2, 255, 513], [130815, 513, 1] = torch.ops.aten.slice_scatter.default(slice_1734, slice_1723, 2, 1, 256); slice_1734 = slice_1723 = None
slice_scatter_426: f32[2, 256, 513], [131328, 513, 1] = torch.ops.aten.slice_scatter.default(select_179, slice_scatter_425, 1, 1, 256); select_179 = slice_scatter_425 = None
select_scatter_39: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.select_scatter.default(slice_1733, slice_scatter_426, 1, 0); slice_1733 = slice_scatter_426 = None
slice_scatter_427: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.slice_scatter.default(slice_scatter_424, select_scatter_39, 0, 0, 9223372036854775807); slice_scatter_424 = select_scatter_39 = None
view_583: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_427, [2, 1, 1024, 513])
transpose_361: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_583, 2, 1); view_583 = None
new_ones_29: f32[256, 257], [257, 1] = torch.ops.aten.new_ones.default(transpose_361, [256, 257], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
tril_19: f32[256, 257], [257, 1] = torch.ops.aten.tril.default(new_ones_29); new_ones_29 = None
flip_38: f32[256, 257], [257, 1] = torch.ops.aten.flip.default(tril_19, [0]); tril_19 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:793, code: beginning_mask = beginning_mask_2d[None, :, None, :]
unsqueeze_144: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.unsqueeze.default(flip_38, 0); flip_38 = None
slice_1735: f32[1, 256, 257], [0, 257, 1] = torch.ops.aten.slice.Tensor(unsqueeze_144, 1, 0, 9223372036854775807); unsqueeze_144 = None
unsqueeze_145: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.unsqueeze.default(slice_1735, 2); slice_1735 = None
slice_1736: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.slice.Tensor(unsqueeze_145, 3, 0, 9223372036854775807); unsqueeze_145 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:794, code: ending_mask = beginning_mask.flip(dims=(1, 3))
flip_39: f32[1, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.flip.default(slice_1736, [1, 3])
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:795, code: beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
slice_1737: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_361, 0, 0, 9223372036854775807)
slice_1738: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1737, 1, 0, 256); slice_1737 = None
slice_1739: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1738, 2, 0, 9223372036854775807); slice_1738 = None
slice_1740: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1739, 3, 0, 257); slice_1739 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:796, code: beginning_mask = beginning_mask.expand(beginning_input.size())
expand_38: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(slice_1736, [2, 256, 1, 257]); slice_1736 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_38: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_38, 1); expand_38 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:797, code: beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_584: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_427, [2, 1, 1024, 513])
transpose_362: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_584, 2, 1); view_584 = None
slice_1741: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_362, 0, 0, 9223372036854775807); transpose_362 = None
slice_1742: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1741, 1, 0, 256); slice_1741 = None
slice_1743: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1742, 2, 0, 9223372036854775807); slice_1742 = None
slice_1744: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1743, 3, 0, 257); slice_1743 = None
masked_fill_57: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1744, eq_38, -inf); slice_1744 = eq_38 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:798, code: ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
slice_1745: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_361, 0, 0, 9223372036854775807); transpose_361 = None
slice_1746: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1745, 1, -256, 9223372036854775807); slice_1745 = None
slice_1747: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1746, 2, 0, 9223372036854775807); slice_1746 = None
slice_1748: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1747, 3, -257, 9223372036854775807); slice_1747 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:799, code: ending_mask = ending_mask.expand(ending_input.size())
expand_39: f32[2, 256, 1, 257], [0, 257, 0, 1] = torch.ops.aten.expand.default(flip_39, [2, 256, 1, 257]); flip_39 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:1297, code: layer_outputs = layer_module(
eq_39: b8[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.eq.Scalar(expand_39, 1); expand_39 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:800, code: ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
view_585: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(slice_scatter_427, [2, 1, 1024, 513]); slice_scatter_427 = None
transpose_363: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_585, 2, 1); view_585 = None
slice_1749: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_363, 0, 0, 9223372036854775807)
slice_1750: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1749, 1, 0, 256)
slice_1751: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1750, 2, 0, 9223372036854775807)
slice_scatter_428: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1751, masked_fill_57, 3, 0, 257); slice_1751 = masked_fill_57 = None
slice_scatter_429: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1750, slice_scatter_428, 2, 0, 9223372036854775807); slice_1750 = slice_scatter_428 = None
slice_scatter_430: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1749, slice_scatter_429, 1, 0, 256); slice_1749 = slice_scatter_429 = None
slice_scatter_431: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_363, slice_scatter_430, 0, 0, 9223372036854775807); transpose_363 = slice_scatter_430 = None
transpose_364: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_431, 2, 1); slice_scatter_431 = None
view_586: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_364, [2, 4, 256, 513]); transpose_364 = None
view_587: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_586, [2, 1, 1024, 513])
transpose_365: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_587, 2, 1); view_587 = None
slice_1752: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_365, 0, 0, 9223372036854775807); transpose_365 = None
slice_1753: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1752, 1, -256, 9223372036854775807); slice_1752 = None
slice_1754: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1753, 2, 0, 9223372036854775807); slice_1753 = None
slice_1755: f32[2, 256, 1, 257], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1754, 3, -257, 9223372036854775807); slice_1754 = None
masked_fill_58: f32[2, 256, 1, 257], [65792, 257, 257, 1] = torch.ops.aten.masked_fill.Scalar(slice_1755, eq_39, -inf); slice_1755 = eq_39 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:597, code: attn_scores += diagonal_mask
view_588: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513])
transpose_366: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_588, 2, 1); view_588 = None
view_589: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_573, [2, 12, 1024, 513]); view_573 = None
transpose_367: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_589, 2, 1); view_589 = None
slice_1756: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_367, 0, 0, 9223372036854775807)
slice_1757: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1756, 1, -256, 9223372036854775807)
slice_1758: f32[2, 256, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1757, 2, 0, 9223372036854775807)
slice_scatter_432: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1758, masked_fill_55, 3, -257, 9223372036854775807); slice_1758 = masked_fill_55 = None
slice_scatter_433: f32[2, 256, 12, 513], [1575936, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1757, slice_scatter_432, 2, 0, 9223372036854775807); slice_1757 = slice_scatter_432 = None
slice_scatter_434: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1756, slice_scatter_433, 1, -256, 9223372036854775807); slice_1756 = slice_scatter_433 = None
slice_scatter_435: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_367, slice_scatter_434, 0, 0, 9223372036854775807); transpose_367 = slice_scatter_434 = None
transpose_368: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_435, 2, 1); slice_scatter_435 = None
view_590: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_368, [24, 4, 256, 513]); transpose_368 = None
view_591: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(view_590, [2, 12, 1024, 513]); view_590 = None
transpose_369: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_591, 2, 1); view_591 = None
view_592: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_586, [2, 1, 1024, 513]); view_586 = None
transpose_370: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_592, 2, 1); view_592 = None
slice_1759: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(transpose_370, 0, 0, 9223372036854775807)
slice_1760: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1759, 1, -256, 9223372036854775807)
slice_1761: f32[2, 256, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice.Tensor(slice_1760, 2, 0, 9223372036854775807)
slice_scatter_436: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1761, masked_fill_58, 3, -257, 9223372036854775807); slice_1761 = masked_fill_58 = None
slice_scatter_437: f32[2, 256, 1, 513], [131328, 513, 131328, 1] = torch.ops.aten.slice_scatter.default(slice_1760, slice_scatter_436, 2, 0, 9223372036854775807); slice_1760 = slice_scatter_436 = None
slice_scatter_438: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(slice_1759, slice_scatter_437, 1, -256, 9223372036854775807); slice_1759 = slice_scatter_437 = None
slice_scatter_439: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.slice_scatter.default(transpose_370, slice_scatter_438, 0, 0, 9223372036854775807); transpose_370 = slice_scatter_438 = None
transpose_371: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.transpose.int(slice_scatter_439, 2, 1); slice_scatter_439 = None
view_593: f32[2, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten.view.default(transpose_371, [2, 4, 256, 513]); transpose_371 = None
view_594: f32[2, 1, 1024, 513], [525312, 525312, 513, 1] = torch.ops.aten.view.default(view_593, [2, 1, 1024, 513]); view_593 = None
transpose_372: f32[2, 1024, 1, 513], [525312, 513, 525312, 1] = torch.ops.aten.transpose.int(view_594, 2, 1); view_594 = None
add_66: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.add.Tensor(transpose_369, transpose_372); transpose_369 = transpose_372 = None
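        # Hedged note (not part of the captured trace): the slice / slice_scatter chains above are the
        # functionalized form of an in-place slice assignment on the attention scores. A rough eager-mode
        # equivalent, with shapes taken from the annotations in this trace (batch=2, seq=1024, heads=12,
        # window_overlap=256), might look like the sketch below; the variable names are illustrative only.
        import torch
        scores = torch.zeros(2, 1024, 12, 513)     # stands in for transpose_367 (the global path above uses 1 head)
        masked = torch.zeros(2, 256, 12, 257)      # stands in for masked_fill_55
        scores[:, -256:, :, -257:] = masked        # functionalization rewrites this mutation as nested slice_scatter calls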
view_595: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.view.default(new_empty_18, [2, 12, 1024, 513]); new_empty_18 = None
transpose_373: f32[2, 1024, 12, 513], [6303744, 513, 525312, 1] = torch.ops.aten.transpose.int(view_595, 2, 1); view_595 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:635, code: attn_probs = nn.functional.softmax(
_softmax_9: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten._softmax.default(add_66, -1, False); add_66 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:646, code: attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
slice_1762: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(orig_primals_195, 0, 0, 9223372036854775807)
slice_1763: b8[2, 1024], [1024, 1] = torch.ops.aten.slice.Tensor(slice_1762, 1, 0, 9223372036854775807); slice_1762 = None
unsqueeze_146: b8[2, 1024, 1], [1024, 1, 0] = torch.ops.aten.unsqueeze.default(slice_1763, 2); slice_1763 = None
unsqueeze_147: b8[2, 1024, 1, 1], [1024, 1, 0, 0] = torch.ops.aten.unsqueeze.default(unsqueeze_146, 3); unsqueeze_146 = None
masked_fill_59: f32[2, 1024, 12, 513], [6303744, 6156, 513, 1] = torch.ops.aten.masked_fill.Scalar(_softmax_9, unsqueeze_147, 0.0); _softmax_9 = unsqueeze_147 = None
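        # Hedged sketch (not part of the captured trace): _softmax_9 and masked_fill_59 above correspond to
        # the Hugging Face Longformer lines cited in the comments (modeling_longformer.py:635 and :646).
        # Shapes are taken from the annotations in this trace; the tensors below are placeholders.
        import torch
        import torch.nn as nn
        attn_scores = torch.zeros(2, 1024, 12, 513)                # add_66 above
        is_index_masked = torch.zeros(2, 1024, dtype=torch.bool)   # orig_primals_195 above
        attn_probs = nn.functional.softmax(attn_scores, dim=-1)
        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)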
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:655, code: value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
view_596: f32[1024, 2, 12, 64], [1536, 768, 64, 1] = torch.ops.aten.view.default(add_65, [1024, 2, 12, 64]); add_65 = None
transpose_374: f32[2, 1024, 12, 64], [768, 1536, 64, 1] = torch.ops.aten.transpose.int(view_596, 0, 1); view_596 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:883, code: chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
transpose_375: f32[2, 12, 1024, 513], [6303744, 513, 6156, 1] = torch.ops.aten.transpose.int(masked_fill_59, 1, 2); masked_fill_59 = None
clone_86: f32[2, 12, 1024, 513], [6303744, 525312, 513, 1] = torch.ops.aten.clone.default(transpose_375, memory_format = torch.contiguous_format); transpose_375 = None
_unsafe_view_47: f32[24, 4, 256, 513], [525312, 131328, 513, 1] = torch.ops.aten._unsafe_view.default(clone_86, [24, 4, 256, 513]); clone_86 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:888, code: value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
transpose_376: f32[2, 12, 1024, 64], [768, 64, 1536, 1] = torch.ops.aten.transpose.int(transpose_374, 1, 2); transpose_374 = None
view_597: f32[24, 1024, 64], [64, 1536, 1] = torch.ops.aten.view.default(transpose_376, [24, 1024, 64]); transpose_376 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:891, code: padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)
constant_pad_nd_38: f32[24, 1536, 64], [98304, 64, 1] = torch.ops.aten.constant_pad_nd.default(view_597, [0, 0, 256, 256], -1.0); view_597 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:902, code: chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
as_strided_59: f32[24, 4, 768, 64], [98304, 16384, 64, 1] = torch.ops.aten.as_strided.default(constant_pad_nd_38, [24, 4, 768, 64], [98304, 16384, 64, 1]); constant_pad_nd_38 = None
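        # Hedged sketch (not part of the captured trace): as_strided_59 builds four overlapping value chunks
        # of length 3 * window_overlap without copying, stepping window_overlap rows per chunk. Sizes assume
        # window_overlap=256, head_dim=64, and batch*num_heads=24, as annotated above.
        import torch
        padded_value = torch.zeros(24, 1536, 64)   # constant_pad_nd_38 above
        chunked_value = padded_value.as_strided(
            size=(24, 4, 768, 64),                 # chunked_value_size in the cited source
            stride=(98304, 16384, 64, 1),          # chunked_value_stride: 16384 = 256 rows * 64 cols per step
        )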
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:755, code: chunked_hidden_states = nn.functional.pad(
constant_pad_nd_39: f32[24, 4, 256, 770], [788480, 197120, 770, 1] = torch.ops.aten.constant_pad_nd.default(_unsafe_view_47, [0, 257], 0.0); _unsafe_view_47 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:758, code: chunked_hidden_states = chunked_hidden_states.view(
view_598: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.view.default(constant_pad_nd_39, [24, 4, -1]); constant_pad_nd_39 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:761, code: chunked_hidden_states = chunked_hidden_states[
slice_1764: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(view_598, 0, 0, 9223372036854775807); view_598 = None
slice_1765: f32[24, 4, 197120], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1764, 1, 0, 9223372036854775807); slice_1764 = None
slice_1766: f32[24, 4, 196864], [788480, 197120, 1] = torch.ops.aten.slice.Tensor(slice_1765, 2, 0, -256); slice_1765 = None
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:764, code: chunked_hidden_states = chunked_hidden_states.view(
view_599: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.view.default(slice_1766, [24, 4, 256, 769]); slice_1766 = None
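        # Hedged sketch (not part of the captured trace): constant_pad_nd_39 through the views/slices around
        # this point implement the Longformer pad-and-diagonalize trick cited in the comments
        # (modeling_longformer.py:755-767), which shifts rows so that columns line up as diagonals.
        # Assuming window_overlap=256 and the shapes annotated above:
        import torch
        import torch.nn.functional as F
        chunked = torch.zeros(24, 4, 256, 513)     # _unsafe_view_47 above
        x = F.pad(chunked, (0, 257))               # -> (24, 4, 256, 770), constant_pad_nd_39
        x = x.view(24, 4, -1)                      # -> (24, 4, 197120), view_598
        x = x[:, :, :-256]                         # -> (24, 4, 196864), slice_1766
        x = x.view(24, 4, 256, 769)                # -> (24, 4, 256, 769), view_599
        x = x[:, :, :, :-1]                        # -> (24, 4, 256, 768), the slices that follow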
# File: /home/ezyang/local/pytorch-tmp-env/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py:767, code: chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
slice_1767: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(view_599, 0, 0, 9223372036854775807); view_599 = None
slice_1768: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1767, 1, 0, 9223372036854775807)
slice_1769: f32[24, 4, 256, 769], [788480, 197120, 769, 1] = torch.ops.aten.slice.Tensor(slice_1768, 2, 0, 9223372036854775807); slice_1768 = None
slice_1770: f32[24, 4