zilunpeng / init_student_wav2vec2.py
Created March 23, 2021 21:25
Initialize the student model by taking alternating layers. Code below is part of student_wav2vec2.py (https://git.io/JYeXX)
step = num_trans_layer_student_init_model // num_trans_layer_student_model
student_init_model_selected_transformer_layers = list(range(0, num_trans_layer_student_init_model, step))
student_model_trans_layer_prefix = "encoder.layers."
student_model_transformer_layers = list(range(num_trans_layer_student_model))
# Copy each selected transformer layer of the initialization model into the
# corresponding student layer. `transformer_parts` and
# `student_init_model_trans_layer_prefix` are defined elsewhere in student_wav2vec2.py.
for student_layer_i, init_layer_i in zip(student_model_transformer_layers, student_init_model_selected_transformer_layers):
    for transformer_part in transformer_parts:
        layer_name = student_model_trans_layer_prefix + str(student_layer_i) + transformer_part
        param = student_init_model_state[student_init_model_trans_layer_prefix + str(init_layer_i) + transformer_part]
        student_model_state[layer_name].copy_(param)
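For intuition, here is a minimal standalone sketch of the selection logic with hypothetical layer counts (the numbers below are illustrative, not from the gist):

num_trans_layer_student_init_model = 12  # hypothetical 12-layer initialization model
num_trans_layer_student_model = 4        # hypothetical 4-layer student
step = num_trans_layer_student_init_model // num_trans_layer_student_model  # 3
print(list(range(0, num_trans_layer_student_init_model, step)))  # [0, 3, 6, 9]

So a 12-layer model distilled into a 4-layer student would copy init layers 0, 3, 6 and 9 into student layers 0 through 3.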
zilunpeng / set_kd_opt_scheduler.py
Created March 23, 2021 21:22
Set optimizer and learning rate scheduler. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

def lr_lambda(current_epoch):
    # Linear warm-up over the first num_lr_warm_up_epoch epochs, then linear
    # decay from 1.0 down to 0.0 at max_epoch.
    if current_epoch < self.num_lr_warm_up_epoch:
        return float(current_epoch + 1) / float(max(1, self.num_lr_warm_up_epoch))
    else:
        return max(0.0, float(self.max_epoch - current_epoch) / float(max(1, self.max_epoch - self.num_lr_warm_up_epoch)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
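As a sanity check on the schedule, here is a self-contained sketch with hypothetical hyperparameters (the model, epoch counts and learning rate are stand-ins, not from the toolkit):

import torch

model = torch.nn.Linear(2, 2)  # stand-in model
num_lr_warm_up_epoch, max_epoch = 5, 20  # hypothetical values

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def lr_lambda(e):
    if e < num_lr_warm_up_epoch:
        return float(e + 1) / max(1, num_lr_warm_up_epoch)
    return max(0.0, float(max_epoch - e) / max(1, max_epoch - num_lr_warm_up_epoch))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for epoch in range(max_epoch):
    # ... one epoch of training would go here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # lr ramps up to 1e-3 by epoch 4, then decays toward 0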
zilunpeng / calc_feat_pen.py
Created March 23, 2021 21:20
Calculate the feature penalty. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
features_pen = features.float().pow(2).mean()  # L2 penalty on the feature extractor outputs, as in the wav2vec 2.0 objective
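This penalty would then be added to the total training loss with a scalar weight; how the toolkit weights it is not shown in this snippet. A minimal sketch, with a hypothetical weight name and value:

feat_pen_weight = 10.0  # hypothetical, not from the toolkit
loss = kd_loss + feat_pen_weight * features_pen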
zilunpeng / calc_kd_loss.py
Created March 23, 2021 21:18
Calculate the knowledge distillation loss. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
# KL divergence between the temperature-softened student and teacher distributions.
# Scaling by temperature**2 keeps gradient magnitudes comparable across temperatures
# (Hinton et al., 2015).
kd_loss = torch.nn.functional.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (self.temperature**2)
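For context, both inputs are typically softened with the same temperature before the KL term, and kl_div expects log-probabilities for the student but plain probabilities for the teacher. A minimal sketch (the logit tensors are hypothetical stand-ins):

import torch
import torch.nn.functional as F

T = 2.0  # hypothetical temperature
student_logits = torch.randn(8, 100)  # stand-in (batch, vocab) logits
teacher_logits = torch.randn(8, 100)

student_log_prob = F.log_softmax(student_logits / T, dim=-1)
teacher_prob = F.softmax(teacher_logits / T, dim=-1)
kd_loss = F.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (T ** 2)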
zilunpeng / get_student_wav2vec2_log_prob.py
Created March 23, 2021 21:18
Get log probability of student model. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
student_net_output = self.student_model(*batch)
student_log_prob = student_net_output["log_prob"]
zilunpeng / get_teacher_wav2vec2_prob.py
Created March 23, 2021 21:16
Get teacher model's probability distribution. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
with torch.no_grad():  # the teacher is frozen, so skip building the autograd graph
    teacher_net_output = self.teacher_model(*batch)
    teacher_prob = teacher_net_output["prob"]
zilunpeng / set_teacher_wav2vec2.py
Created March 23, 2021 21:15
Set the teacher model to evaluation mode. Code below is part of the knowledge distillation toolkit (https://git.io/JYePf).
self.teacher_model.eval()  # disable dropout and other training-mode behaviour so the teacher gives deterministic outputs
zilunpeng / get_wav2vec2_decoder_output.py
Created March 23, 2021 21:08
Get output from decoder. Code below is part of the wav2vec 2.0 inference notebook (https://git.io/JYeKX).
decoder_out = decoder.decode(emissions)
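decoder_out holds the decoder's hypotheses for each utterance in the batch. With fairseq's wav2letter-style decoders it is, roughly, a list (per utterance) of hypothesis dicts; a sketch of pulling out the best transcript, assuming a fairseq task whose target_dictionary maps token ids back to characters (details vary by decoder):

for hypos in decoder_out:
    best = hypos[0]  # top-scoring hypothesis
    transcript = task.target_dictionary.string(best["tokens"].int().cpu())
    print(transcript)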
zilunpeng / get_wav2vec2_output.py
Created March 23, 2021 21:07
Get the output from wav2vec 2.0. Code below is part of the wav2vec 2.0 inference notebook (https://git.io/JYeKX).
encoder_out = model(**encoder_input)
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
emissions = emissions.transpose(0, 1).float().cpu().contiguous()  # (time, batch, vocab) -> (batch, time, vocab) for the decoder
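Here encoder_input is the keyword-argument dict the fairseq wav2vec 2.0 model expects. Roughly, it pairs the raw waveform with an optional padding mask (a sketch, assuming a single pre-loaded 16 kHz waveform tensor):

encoder_input = {
    "source": waveform,    # float tensor of raw audio, shape (batch, num_samples)
    "padding_mask": None,  # no padding needed with batch_size 1
}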
zilunpeng / create_dev_clean_data_loader.py
Created March 23, 2021 21:05
Create data loader. Code below is part of the wav2vec 2.0 inference notebook (https://git.io/JYeKX).
dev_clean_librispeech_data = torchaudio.datasets.LIBRISPEECH(data_path, url='dev-clean', download=False)
data_loader = torch.utils.data.DataLoader(dev_clean_librispeech_data, batch_size=1, shuffle=False)
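Each LIBRISPEECH item is a (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id) tuple, so a batch from this loader can be unpacked directly; batch_size=1 sidesteps collating variable-length audio:

for waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id in data_loader:
    print(waveform.shape, sample_rate.item(), transcript[0])
    break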