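"""Fine-tune facebook/wav2vec2-xls-r-300m-21-to-en (a SpeechEncoderDecoderModel)
for speech-to-text translation with the Hugging Face Trainer.

Reads wav/transcript pairs from a data folder (see S2TDataset below) and
trains the encoder-decoder model end to end.
"""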
import argparse

import librosa
from torch.utils.data import Dataset
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    SpeechEncoderDecoderModel,
    Trainer,
    TrainingArguments,
)
def process_audio_file(file):
    # Load at the file's native sampling rate, then resample to the 16 kHz
    # rate the wav2vec2 encoder expects.
    data, sr = librosa.load(file, sr=None)
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    # feature_extractor is created in __main__ below.
    input_values = feature_extractor(data, return_tensors="pt", sampling_rate=16000).input_values
    return input_values
class S2TDataset(Dataset):
    def __init__(self, data_folder):
        self.wav_fns = []
        self.text = []
        with open(data_folder + "/script/script.txt", encoding="utf8") as script:
            for line in script:
                # Cap the dataset at ~500 examples for a quick run.
                if len(self.text) > 500:
                    break
                wfn, text = line.strip().split("\t")
                self.wav_fns.append(data_folder + "/wavs/" + wfn + ".wav")
                self.text.append(text)

    def __len__(self):
        return len(self.wav_fns)

    def __getitem__(self, idx):
        speech_input = process_audio_file(self.wav_fns[idx]).squeeze()
        text = self.text[idx]
        # Tokenize the transcript as target text (en_XX) for the decoder.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(text, return_tensors="pt")
        tgt = labels["input_ids"].squeeze()
        return {"input_values": speech_input, "input_ids": tgt}
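# Expected data layout, inferred from the paths above:
#   <data>/script/script.txt  -- lines of "<wav name>\t<transcript>"
#   <data>/wavs/<wav name>.wav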
def collate_fn(batch):
    # Pad audio features and label token ids separately, then replace label
    # padding with -100 so it is ignored by the loss.
    speech = [{"input_values": feature["input_values"]} for feature in batch]
    labels = [{"input_ids": feature["input_ids"]} for feature in batch]
    batch = feature_extractor.pad(speech, padding=True, return_tensors="pt")
    labels = tokenizer.pad(labels, padding=True, return_tensors="pt")
    batch["labels"] = labels["input_ids"].masked_fill(labels.attention_mask.ne(1), -100)
    return batch
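# Quick sanity check for the collator (a hypothetical snippet, to run by hand;
# it needs the feature_extractor and tokenizer created in __main__ below):
#   dset = S2TDataset("bin")
#   batch = collate_fn([dset[0], dset[1]])
#   print(batch["input_values"].shape, batch["labels"].shape)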
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("-v", "--vocab", default="vocab.json")
    parser.add_argument("-d", "--data", default="bin")
    parser.add_argument("-m", "--model", default="facebook/wav2vec2-xls-r-300m-21-to-en")
    parser.add_argument("-o", "--outdir", default="outdir")
    # parser.add_argument("-b", "--batch_size", type=int, default=8)
    # parser.add_argument("-e", "--epoch", type=int, default=10)
    args = parser.parse_args()
    model_name = args.model

    print("Loading Model ...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, padding_value=0.0, sampling_rate=16000)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, tgt_lang="en_XX")
    model = SpeechEncoderDecoderModel.from_pretrained(model_name)
    print("Loaded Model ...")
    # model.freeze_feature_extractor()
    # for param in model.encoder.parameters():
    #     param.requires_grad = False

    dset = S2TDataset(args.data)
    print("Created dataset ..")
    training_args = TrainingArguments(
        output_dir=args.outdir,
        group_by_length=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        # evaluation_strategy="steps",
        num_train_epochs=30,
        gradient_checkpointing=False,
        fp16=False,
        save_steps=400,
        eval_steps=400,
        logging_steps=10,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2,
    )
    trainer = Trainer(
        model=model,
        data_collator=collate_fn,
        args=training_args,
        train_dataset=dset,
        tokenizer=tokenizer,
    )

    print("Starting Training")
    trainer.train()
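# Example invocation (the script name and data path are placeholders):
#   python finetune_s2t.py -d bin -m facebook/wav2vec2-xls-r-300m-21-to-en -o outdir
#
# Minimal inference sketch with a saved checkpoint (assumes outdir contains a
# checkpoint folder written per save_steps above; the step number is made up):
#   model = SpeechEncoderDecoderModel.from_pretrained("outdir/checkpoint-400")
#   generated_ids = model.generate(process_audio_file("sample.wav"))
#   print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))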