@fauxneticien
Created July 30, 2023 20:12
Import HF model to torchaudio
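import os
import torch
from transformers import Wav2Vec2Model
from torchaudio.models.wav2vec2.utils import import_huggingface_model
# Download the pretrained XLS-R 300M encoder and convert it to torchaudio's Wav2Vec2Model
hf_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m")
ta_model = import_huggingface_model(hf_model)
# torch.save does not create missing directories, so create tmp/ first
os.makedirs("tmp", exist_ok=True)
torch.save(ta_model.state_dict(), "tmp/wav2vec2-300m_xls-r.pt")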
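## Usage
#
# import torch
# import torchaudio
#
# # aux_num_out adds a new, randomly initialised linear output head (here 29 outputs);
# # strict=False lets the load skip its weights, which the pretrained checkpoint lacks
# new_model = torchaudio.models.wav2vec2_xlsr_300m(aux_num_out=29)
# new_model.load_state_dict(torch.load("tmp/wav2vec2-300m_xls-r.pt"), strict=False)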
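Why strict=False matters can be shown with a toy sketch in plain PyTorch. The classes Backbone and BackboneWithAux are hypothetical stand-ins, not torchaudio classes: a checkpoint saved from a model without an output head is loaded into one that has an extra aux layer, and load_state_dict reports the head's parameters as missing instead of raising an error.

```python
import io

import torch
from torch import nn


# Hypothetical toy stand-ins: Backbone plays the role of the converted
# pretrained encoder; BackboneWithAux adds an output head, like aux_num_out does.
class Backbone(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)

    def forward(self, x):
        return self.encoder(x)


class BackboneWithAux(Backbone):
    def __init__(self, dim: int = 8, num_out: int = 29):
        super().__init__(dim)
        self.aux = nn.Linear(dim, num_out)  # new head, randomly initialised

    def forward(self, x):
        return self.aux(super().forward(x))


pretrained = Backbone()
buffer = io.BytesIO()  # stands in for the .pt file on disk
torch.save(pretrained.state_dict(), buffer)
buffer.seek(0)

model = BackboneWithAux()
# With strict=True this would raise, because aux.weight and aux.bias are absent
result = model.load_state_dict(torch.load(buffer), strict=False)
print(sorted(result.missing_keys))  # ['aux.bias', 'aux.weight']
print(result.unexpected_keys)       # []
# The shared encoder weights were copied over from the checkpoint
assert torch.equal(model.encoder.weight, pretrained.encoder.weight)
```

Checking missing_keys after a non-strict load is a cheap way to confirm that only the new head was skipped and that no encoder weights were silently left uninitialised.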