Last active
December 3, 2021 12:54
-
-
Save arijitx/b853c8bc29c3935e26f69343b32ddb4e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel | |
from torch.utils.data import Dataset | |
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel | |
from torch.utils.data import DataLoader | |
from transformers import TrainingArguments, Trainer | |
import librosa | |
import argparse | |
def process_audio_file(file):
    """Load an audio file and convert it to model-ready input values.

    Relies on the module-level ``feature_extractor`` global (created in
    ``__main__``) — this function must be called after model loading.

    Args:
        file: Path to an audio file readable by ``librosa.load``.

    Returns:
        A ``torch.Tensor`` of shape (1, num_samples): the feature
        extractor's ``input_values`` at 16 kHz.
    """
    data, sr = librosa.load(file)
    if sr != 16000:
        # librosa >= 0.10 removed positional sample-rate arguments;
        # keyword form works on both old and new versions.
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    input_values = feature_extractor(data, return_tensors="pt", sampling_rate=16000).input_values
    return input_values
class S2TDataset(Dataset):
    """Speech-to-text dataset read from a tab-separated transcript file.

    Expects ``<data_flder>/script/script.txt`` where each line is
    ``<wav_basename>\t<transcript>``, and audio files under
    ``<data_flder>/wavs/<wav_basename>.wav``.
    """

    def __init__(self, data_flder, max_samples=500):
        """Index the transcript file.

        Args:
            data_flder: Root folder containing ``script/`` and ``wavs/``.
            max_samples: Maximum number of examples to index (default 500,
                matching the original hard-coded cap).
        """
        self.wav_fns = []  # absolute-ish paths to .wav files
        self.text = []     # parallel list of transcripts
        # Use a context manager so the transcript file handle is closed
        # (the original kept it open on self.script for the object's lifetime).
        with open(data_flder + "/script/script.txt", encoding='utf8') as script:
            for line in script:
                if len(self.text) >= max_samples:
                    break
                line = line.strip()
                wfn, text = line.split("\t")
                self.wav_fns.append(data_flder + "/wavs/" + wfn + ".wav")
                self.text.append(text)

    def __len__(self):
        return len(self.wav_fns)

    def __getitem__(self, idx):
        """Return one example as {'input_values': 1-D tensor, 'input_ids': 1-D tensor}.

        Uses the module-level ``process_audio_file`` and ``tokenizer``
        globals (created in ``__main__``).
        """
        speech_input = process_audio_file(self.wav_fns[idx]).squeeze()
        text = self.text[idx]
        # Tokenize the transcript in target-language mode for the decoder.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(text, return_tensors="pt")
        tgt = labels['input_ids'].squeeze()
        return {'input_values': speech_input, 'input_ids': tgt}
def collate_fn(batch):
    """Pad a list of {'input_values', 'input_ids'} examples into one batch.

    Speech features are padded by the module-level ``feature_extractor``;
    token ids are padded by the module-level ``tokenizer``, and padded
    label positions are replaced with -100 so the loss ignores them.
    """
    speech = []
    labels = []
    for feature in batch:
        speech.append({"input_values": feature["input_values"]})
        labels.append({"input_ids": feature["input_ids"]})
    batch = feature_extractor.pad(speech, padding=True, return_tensors='pt')
    labels = tokenizer.pad(labels, padding=True, return_tensors='pt')
    # Mask out padding tokens (attention_mask == 0) with -100 for the loss.
    pad_positions = labels.attention_mask.ne(1)
    labels = labels["input_ids"].masked_fill(pad_positions, -100)
    batch['labels'] = labels
    return batch
if __name__ == "__main__":
    # Command-line interface: data folder, pretrained model name, output dir.
    parser = argparse.ArgumentParser()
    # parser.add_argument("-v",'--vocab',default='vocab.json')
    parser.add_argument("-d",'--data',default='bin')
    parser.add_argument("-m",'--model',default="facebook/wav2vec2-xls-r-300m-21-to-en")
    parser.add_argument("-o",'--outdir',default="outdir")
    # parser.add_argument("-b",'--batch_size',type=int,default=8)
    # parser.add_argument("-e",'--epoch',type=int,default=10)
    args = parser.parse_args()
    model_name = args.model
    print("Loading Model ...")
    # These three are module-level globals read by process_audio_file,
    # S2TDataset.__getitem__, and collate_fn above.
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, padding_value=0.0, sampling_rate=16000)
    # NOTE(review): padding_value/sampling_rate are unusual kwargs for a
    # tokenizer — presumably ignored by from_pretrained; verify.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, tgt_lang="en_XX", padding_value=-100, sampling_rate=16000)
    model = SpeechEncoderDecoderModel.from_pretrained(model_name)
    print('Loaded Model ...')
    # model.freeze_feature_extractor()
    # for param in model.encoder.parameters():
    #     param.requires_grad = False
    dset = S2TDataset(args.data)
    print('Created dataset ..')
    # Small per-device batch with gradient accumulation (effective batch 4).
    # eval_steps is set but no eval dataset/strategy is configured.
    training_args = TrainingArguments(
        output_dir=args.outdir,
        group_by_length=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        # evaluation_strategy="steps",
        num_train_epochs=30,
        gradient_checkpointing=False,
        fp16=False,
        save_steps=400,
        eval_steps=400,
        logging_steps=10,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2
    )
    trainer = Trainer(
        model=model,
        data_collator=collate_fn,
        args=training_args,
        train_dataset=dset,
        tokenizer=tokenizer
    )
    print('Starting Training')
    trainer.train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment