@zilunpeng
Created March 23, 2021 21:28
Prepare wav2vec 2.0 for inference after quantization. Code is part of wav2vec2.py (https://git.io/JYe1Y).
def prepare_for_inference_after_quantization(self):
    # After dynamic quantization, the attention projections are quantized Linear
    # modules whose parameters are exposed via weight()/bias() methods. Cache
    # plain float copies on each layer's self_attn so they can be read directly
    # at inference time.
    dequantizer = torch.nn.quantized.DeQuantize()
    for trans_layer in self.encoder.layers:
        # Biases come back as float tensors from the quantized Linear modules.
        trans_layer.self_attn.q_proj_bias = trans_layer.self_attn.q_proj.bias()
        trans_layer.self_attn.k_proj_bias = trans_layer.self_attn.k_proj.bias()
        trans_layer.self_attn.v_proj_bias = trans_layer.self_attn.v_proj.bias()
        # Concatenate q/k/v biases into a single in_proj_bias.
        trans_layer.self_attn.in_proj_bias = torch.cat((
            trans_layer.self_attn.q_proj_bias,
            trans_layer.self_attn.k_proj_bias,
            trans_layer.self_attn.v_proj_bias,
        ))
        trans_layer.self_attn.out_proj_bias = trans_layer.self_attn.out_proj.bias()
        # Weights are quantized tensors; dequantize them back to float.
        trans_layer.self_attn.out_proj_weight = dequantizer(trans_layer.self_attn.out_proj.weight())
        trans_layer.self_attn.q_proj_weight = dequantizer(trans_layer.self_attn.q_proj.weight())
        trans_layer.self_attn.k_proj_weight = dequantizer(trans_layer.self_attn.k_proj.weight())
        trans_layer.self_attn.v_proj_weight = dequantizer(trans_layer.self_attn.v_proj.weight())
    return
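
For context, here is a minimal usage sketch (not part of the gist): it assumes the model is an instance of the wav2vec 2.0 model class from wav2vec2.py, that it has already been loaded elsewhere, and that it is quantized with PyTorch's dynamic quantization before the method above is called once prior to inference.

import torch

# model = ...  # a wav2vec 2.0 model built from fairseq's wav2vec2.py (loading omitted)
model.eval()

# Replace every nn.Linear (including the attention q/k/v/out projections)
# with a dynamically quantized int8 version.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Cache float copies of the attention weights/biases so inference can use them directly.
quantized_model.prepare_for_inference_after_quantization()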