@titu1994
Created April 28, 2021 01:41
Finetuning recipe for Citrinet models
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Preparing the Tokenizer
Use the `create_tokenizer.py` script to prepare the tokenizer. An example invocation is shown below.
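For example, a SentencePiece BPE tokenizer can be built with NeMo's
`process_asr_text_tokenizer.py` script (a sketch - the paths are placeholders, and the
1024-token vocab size matches the pretrained Citrinet decoder). The script writes the
tokenizer into a subdirectory of `--data_root`; that subdirectory is what
`model.tokenizer.dir` expects:
python process_asr_text_tokenizer.py \
    --manifest="<PATH TO TRAIN MANIFEST>" \
    --data_root="<TOKENIZER OUTPUT DIRECTORY>" \
    --vocab_size=1024 \
    --tokenizer="spe" \
    --spe_type="bpe" \
    --log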
# Launch the fine-tuning script
HYDRA_FULL_ERROR=1 python finetune_model.py \
--config-path="configs/" \
--config-name="stt_en_citrinet_512" \
+model_pretrained_name="stt_en_citrinet_512" \
+freeze_encoder=false \
model.tokenizer.dir="<DIRECTORY TO TOKENIZER (not the full path to .model file, just the directory)>" \
model.tokenizer.type="bpe" \
model.train_ds.manifest_filepath="<PATH TO TRAIN MANIFEST>" \
model.train_ds.batch_size=32 \
+model.train_ds.num_workers=8 \
+model.train_ds.pin_memory=true \
model.validation_ds.manifest_filepath=["<PATH TO DEV SET>","<PATH TO TEST SET>"] \
model.validation_ds.batch_size=8 \
+model.validation_ds.num_workers=8 \
+model.validation_ds.pin_memory=true \
model.spec_augment.freq_masks=0 \
model.spec_augment.time_masks=0 \
model.optim.lr=0.01 \
model.optim.name='novograd' \
model.optim.betas=[0.8,0.25] \
model.optim.weight_decay=0.001 \
model.optim.sched.warmup_steps=1000 \
model.optim.sched.min_lr=0.00001 \
trainer.gpus=-1 \
trainer.accelerator='ddp' \
trainer.max_epochs=100 \
trainer.check_val_every_n_epoch=1 \
trainer.precision=32 \
trainer.sync_batchnorm=false \
trainer.benchmark=false \
exp_manager.resume_if_exists=false \
exp_manager.resume_ignore_no_checkpoint=false
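# Notes
`+freeze_encoder=true` freezes the whole encoder and trains only the decoder; the
BatchNorm and SqueezeExcite blocks inside the encoder are then re-enabled by
`enable_bn_se` (defined below) so that they can still adapt to the new data.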
"""
import torch
import torch.nn as nn
import pytorch_lightning as pl
from omegaconf import OmegaConf, open_dict
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def enable_bn_se(m):
    """Re-enable training for BatchNorm and SqueezeExcite submodules (used when the encoder is otherwise frozen)."""
    if type(m) == nn.BatchNorm1d:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)

    if 'SqueezeExcite' in type(m).__name__:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)
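
# A quick sanity check of the selective unfreezing (illustrative sketch only; the tiny
# Sequential below is not part of the recipe):
#
#     frozen = nn.Sequential(nn.Conv1d(8, 8, 3), nn.BatchNorm1d(8))
#     for p in frozen.parameters():
#         p.requires_grad_(False)
#     frozen.apply(enable_bn_se)
#     assert all(p.requires_grad for p in frozen[1].parameters())      # BatchNorm unfrozen
#     assert not any(p.requires_grad for p in frozen[0].parameters())  # Conv still frozen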


@hydra_runner(config_path="configs/", config_name="stt_en_citrinet_512")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    # Pop the custom keys added on the command line so they do not clash with the model config
    with open_dict(cfg):
        model_name = cfg.pop('model_pretrained_name')
        freeze_encoder = cfg.pop('freeze_encoder', False)

    if 'stt_en_citrinet' not in model_name:
        raise ValueError("`model_pretrained_name` must be a Citrinet model - `stt_en_citrinet_XYZ`, "
                         "where XYZ can be {256, 512, 1024}")

    # Load pretrained checkpoint
    checkpoint = EncDecCTCModelBPE.from_pretrained(
        model_name, map_location=torch.device('cpu')
    )  # type: EncDecCTCModelBPE

    # Preserve the model's decoder weights
    decoder_ckpt_copy = checkpoint.decoder.state_dict()

    # Load finetuning model
    asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)

    # Load up weights (partially / fully); `strict=False` also lets the decoder weights load
    # if they have the same shape as the original Citrinet decoder (1024 subword encodings)
    asr_model.load_state_dict(checkpoint.state_dict(), strict=False)

    # Insert preserved decoder weights if shapes match
    if decoder_ckpt_copy['decoder_layers.0.weight'].shape == asr_model.decoder.decoder_layers[0].weight.shape:
        asr_model.decoder.load_state_dict(decoder_ckpt_copy)
        logging.info("\n")
        logging.info("Decoder shapes matched - restored weights from pretrained model")
        logging.info("\n")

    # Release checkpoint memory
    del checkpoint

    # If freezing the encoder, unfreeze the batch norm and the squeeze and excite blocks
    # for transfer learning
    if freeze_encoder:
        asr_model.encoder.freeze()
        asr_model.encoder.apply(enable_bn_se)

    # Train model
    trainer.fit(asr_model)


if __name__ == '__main__':
    main()
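
# After training completes, exp_manager typically saves a `.nemo` checkpoint under the
# experiment directory. A minimal sketch for running the finetuned model on audio,
# assuming NeMo 1.x APIs (the checkpoint and audio paths are placeholders):
#
#     from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
#
#     asr_model = EncDecCTCModelBPE.restore_from("<PATH TO FINETUNED .nemo FILE>")
#     transcripts = asr_model.transcribe(paths2audio_files=["<PATH TO WAV FILE>"], batch_size=4)
#     print(transcripts)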