@jamesr66a
Created May 13, 2022
##### test.py #####
import torch
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5Config,
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    AutoTokenizer,
    set_seed,
)
import transformers.utils.fx as fx
# Build a small T5 model from the local config and run a sanity-check forward pass.
config = T5Config.from_pretrained('small.json')
model = T5ForConditionalGeneration(config)

BS, T = 5, 5
input_ids = torch.zeros(BS, T, dtype=torch.long)
decoder_input_ids = torch.zeros(BS, T, dtype=torch.long)
model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
# Symbolically trace the model with the HF tracer, passing every forward()
# argument that is not a real input as a concrete (default) value.
import inspect
input_names = ['input_ids', 'decoder_input_ids', 'attention_mask', 'labels']
sig = inspect.signature(model.forward)
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

tracer = fx.HFTracer()
tracer.trace(model, concrete_args=concrete_args)
# Round-trip the model through pickle after tracing.
import pickle
with open('file.pkl', 'wb') as f:
    pickle.dump(model, f)
with open('file.pkl', 'rb') as f:
    loaded = pickle.load(f)
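
# Not part of the original gist: a minimal sketch of how the trace produced by
# HFTracer is typically materialized. tracer.trace() returns a torch.fx.Graph;
# wrapping it in a torch.fx.GraphModule gives a module whose generated forward
# can be inspected or executed.
sketch_tracer = fx.HFTracer()  # hypothetical fresh tracer, just for this sketch
graph = sketch_tracer.trace(model, concrete_args=concrete_args)
gm = torch.fx.GraphModule(model, graph)
print(gm.code)  # generated Python source for the traced forward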
##### small.json #####
{
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "pad_token_id": 0,
  "decoder_start_token_id": 0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "d_ff": 1024,
  "d_kv": 128,
  "d_model": 512,
  "num_layers": 1,
  "num_decoder_layers": 1,
  "num_heads": 8,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 768,
  "dropout_rate": 0.1,
  "output_past": true,
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.16.2",
  "use_cache": true,
  "vocab_size": 1000
}