@lepotatoguy
Created April 25, 2023 05:41
AssertionError                            Traceback (most recent call last)
<ipython-input-2-8a242b2868b1> in <cell line: 28>()
     26 
     27 # inpaint image
---> 28 inpainted_tensor = model(image_tensor, mask)
     29 
     30 # convert tensors to numpy arrays and show images

9 frames

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-1-f52b433f2bef> in forward(self, x, mask)
     86 
     87     def forward(self, x, mask):
---> 88         encoded_x = self.encoder(x)
     89         batch_size, channels, height, width = encoded_x.size()
     90         mask = F.interpolate(mask, size=(height, width), mode='bilinear', align_corners=False)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-1-f52b433f2bef> in forward(self, x)
     39         x = self.conv(x)
     40         for block in self.blocks:
---> 41             x = block(x)
     42         return x
     43 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-1-f52b433f2bef> in forward(self, x)
     18     def forward(self, x):
     19         x_norm = self.norm1(x)
---> 20         attn_output, _ = self.attn(x_norm, x_norm, x_norm)
     21         x = x + attn_output
     22         x_norm = self.norm2(x)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
   1187                 is_causal=is_causal)
   1188         else:
-> 1189             attn_output, attn_output_weights = F.multi_head_attention_forward(
   1190                 query, key, value, self.embed_dim, self.num_heads,
   1191                 self.in_proj_weight, self.in_proj_bias,

/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)
   5138     )
   5139 
-> 5140     is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
   5141 
   5142     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input

/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
   4975                     (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
   4976             else:
-> 4977                 raise AssertionError(
   4978                     f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
   4979 
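The assertion at the bottom of the trace comes from `nn.MultiheadAttention`: it only accepts an unbatched 2-D `(L, E)` query or a batched 3-D `(L, N, E)` / `(N, L, E)` query, but the block in `<ipython-input-1-f52b433f2bef>` feeds it the encoder's 4-D feature map (batch, channels, height, width) directly. Below is a minimal sketch of one way to adapt such a block: flatten the spatial dimensions into a sequence before attention and restore the shape afterwards. The class name `AttentionBlock`, the `embed_dim`/`num_heads` values, the extra MLP branch, and the `batch_first=True` choice are assumptions for illustration, not the gist author's actual code.

import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    """Self-attention over the spatial positions of a (B, C, H, W) feature map (illustrative sketch)."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        # Assumed feed-forward branch; the traced block only shows norm1/attn/norm2.
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels * 4),
            nn.GELU(),
            nn.Linear(channels * 4, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.size()
        # (B, C, H, W) -> (B, H*W, C): MultiheadAttention needs a 3-D batched sequence.
        seq = x.flatten(2).transpose(1, 2)
        seq_norm = self.norm1(seq)
        attn_output, _ = self.attn(seq_norm, seq_norm, seq_norm)
        seq = seq + attn_output
        seq = seq + self.mlp(self.norm2(seq))
        # (B, H*W, C) -> (B, C, H, W): restore the spatial layout for the next conv/block.
        return seq.transpose(1, 2).reshape(b, c, h, w)


# Quick shape check mirroring the failing call in the trace.
if __name__ == "__main__":
    block = AttentionBlock(channels=64)
    feat = torch.randn(1, 64, 32, 32)   # stand-in for encoded_x
    print(block(feat).shape)            # torch.Size([1, 64, 32, 32])

Whether the attention should run over the H*W spatial tokens (as above), over channels, or with `batch_first=False` depends on the rest of the model; the essential point is that the tensor handed to `self.attn` must be 2-D or 3-D, not the raw 4-D feature map.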