Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Last active March 24, 2022 03:28
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save jamesr66a/d53b08f2e6313c3c5d09bfac1ea42199 to your computer and use it in GitHub Desktop.
diff --git a/test/t5_test.py b/test/t5_test.py
index e6b36b8..bc7edde 100644
--- a/test/t5_test.py
+++ b/test/t5_test.py
@@ -15,7 +15,7 @@ for i in range(t5.config.num_layers):
for i in range(t5.config.num_decoder_layers):
annotate_split_points(t5, {f'decoder.block.{i}': PipeSplitWrapper.SplitPoint.BEGINNING})
-input_names = t5.dummy_inputs.keys()
+input_names = t5.dummy_inputs.keys() - ['decoder_attention_mask']
sig = inspect.signature(t5.forward)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
@@ -33,5 +33,5 @@ print(t5_pipe.split_gm.graph)
t5_input = torch.zeros(bs, seq_length, dtype=torch.long).random_(t5.config.vocab_size)
decoder_input_ids = torch.zeros(bs, seq_length, dtype=torch.long).random_(t5.config.vocab_size)
-t5_output = t5(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
-t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
+t5_output = t5(input_ids=t5_input, decoder_input_ids=decoder_input_ids)
+t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_input_ids=decoder_input_ids)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment