@jamesr66a
Last active March 24, 2022 03:28
diff --git a/test/t5_test.py b/test/t5_test.py
index e6b36b8..09c2f13 100644
--- a/test/t5_test.py
+++ b/test/t5_test.py
@@ -33,5 +33,5 @@ print(t5_pipe.split_gm.graph)
t5_input = torch.zeros(bs, seq_length, dtype=torch.long).random_(t5.config.vocab_size)
decoder_input_ids = torch.zeros(bs, seq_length, dtype=torch.long).random_(t5.config.vocab_size)
-t5_output = t5(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
-t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
+t5_output = t5(**t5.dummy_inputs)
+t5_pipe_output = t5_pipe(**t5.dummy_inputs)