
@jmp84
Created September 16, 2020 06:11
TorchScript MT model
import argparse
import logging

import torch
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.sequence_generator import SequenceGenerator


def get_args():
    parser = argparse.ArgumentParser(
        description="Script to convert a PyTorch model to a TorchScript model."
    )
    parser.add_argument(
        "--input-model", type=str, required=True, help="Path to the PyTorch model."
    )
    parser.add_argument(
        "--output-model",
        type=str,
        required=True,
        help="Path to the output TorchScript model.",
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        required=True,
        help="Beam size for the sequence generator.",
    )
    parser.add_argument("--quantize", action="store_true", help="Apply quantization.")
    return parser.parse_args()


def main():
    args = get_args()
    # Make the logging.info calls below visible on the console.
    logging.basicConfig(level=logging.INFO)

    logging.info("Loading model...")
    # load_model_ensemble_and_task returns (models, model_args, task);
    # take the first (and only) model of the ensemble.
    model = load_model_ensemble_and_task([args.input_model])[0][0]
    model.eval()
    logging.info("Model loaded.")

    # Wrap the model in a SequenceGenerator, which implements beam search
    # over the target dictionary.
    model_dict = model.decoder.dictionary
    generator = SequenceGenerator([model], model_dict, beam_size=args.beam_size)

    if args.quantize:
        # Dynamically quantize all Linear layers to int8 weights.
        generator = torch.quantization.quantize_dynamic(
            generator, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
        )

    logging.info("TorchScripting...")
    scripted_generator = torch.jit.script(generator)
    logging.info("Saving TorchScript model...")
    scripted_generator.save(args.output_model)
    logging.info("Done!")


if __name__ == "__main__":
    main()
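For reference, a minimal sketch of how the exported artifact could be used at inference time. It assumes a hypothetical output path scripted_generator.pt and that the scripted SequenceGenerator keeps fairseq's usual forward signature: a sample dict with a net_input entry holding src_tokens and src_lengths, returning a list of beam hypotheses per sentence, each with "tokens" and "score" fields. The token IDs below are placeholders; real input must be encoded with the model's source dictionary.

import torch

# Hypothetical output path produced by the conversion script above.
scripted_generator = torch.jit.load("scripted_generator.pt")

# Dummy source token IDs for illustration; in practice they come from
# encoding a sentence with the model's source dictionary (usually ending
# with the EOS index).
src_tokens = torch.tensor([[42, 17, 5, 2]], dtype=torch.long)
src_lengths = torch.tensor([src_tokens.size(1)], dtype=torch.long)
sample = {"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}}

with torch.no_grad():
    hypos = scripted_generator(sample)

# hypos[0] holds the beam hypotheses for the first (only) sentence,
# best-scoring first.
best_tokens = hypos[0][0]["tokens"]
print(best_tokens)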