@zilunpeng
zilunpeng / use_ray_to_get_pred.py
Created March 23, 2021 22:22
Get all predictions using Ray. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
predictions = ray.get(prediction_futures)
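Once the futures are resolved, the predicted transcriptions can be scored against the ground truths collected during batching. A minimal word-error-rate sketch, assuming predictions and ground_truths are plain transcription strings (the exact output format of the notebook's processing function is not shown here):

def word_error_rate(hypothesis, reference):
    # Levenshtein distance over words, normalized by the reference length.
    hyp, ref = hypothesis.split(), reference.split()
    d = [[0] * (len(ref) + 1) for _ in range(len(hyp) + 1)]
    for i in range(len(hyp) + 1):
        d[i][0] = i
    for j in range(len(ref) + 1):
        d[0][j] = j
    for i in range(1, len(hyp) + 1):
        for j in range(1, len(ref) + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[len(hyp)][len(ref)] / max(len(ref), 1)

avg_wer = sum(word_error_rate(p, g) for p, g in zip(predictions, ground_truths)) / len(predictions)
print(f"average WER: {avg_wer:.3f}")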
@zilunpeng
zilunpeng / ray_process_data_sample.py
Created March 23, 2021 22:21
Process every data sample using Ray. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
prediction_futures, ground_truths = [], []
for i, batch in enumerate(data_loader):
    prediction_future = remote_process_batch_element.remote(batch, model_id, decoder_id, target_dict)
    prediction_futures.append(prediction_future)
    ground_truths.append(batch[2][0])
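Each .remote() call returns immediately with an object reference, so this loop only schedules work. If you prefer to consume results as they finish instead of blocking on all of them at once, ray.wait can be used; a small sketch, not part of the original notebook:

remaining = list(prediction_futures)
while remaining:
    ready, remaining = ray.wait(remaining, num_returns=1)
    prediction = ray.get(ready[0])  # handle each prediction as soon as it is available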
@zilunpeng
zilunpeng / share_model_decoder.py
Created March 23, 2021 22:19
Put the model and decoder into shared memory. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
model_id = ray.put(model)
decoder_id = ray.put(decoder)
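ray.put stores an object once in Ray's object store and returns a reference, so every worker task reads the same copy instead of reserializing the model per task. A self-contained toy sketch of the pattern (the names here are illustrative only, not from the notebook):

import ray

ray.init(ignore_reinit_error=True)  # safe to call even if Ray is already running

@ray.remote
def shared_length(shared_obj):
    # Object references passed as task arguments are dereferenced automatically.
    return len(shared_obj)

big_list = list(range(1_000_000))
big_list_id = ray.put(big_list)     # stored once in Ray's object store
futures = [shared_length.remote(big_list_id) for _ in range(4)]
print(ray.get(futures))             # [1000000, 1000000, 1000000, 1000000]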
@zilunpeng
zilunpeng / remote_process_data_sample.py
Created March 23, 2021 22:18
Define the Ray remote method for processing a data sample. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
@ray.remote
def remote_process_data_sample(batch, model, generator, target_dict):
    result = process_data_sample(batch, model, generator, target_dict)
    return result
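By default each Ray task reserves one CPU. If the per-sample work is heavier, or should be scheduled onto a GPU, resource requirements can be declared on the decorator; a hedged variant, not taken from the notebook:

@ray.remote(num_cpus=2)  # or num_gpus=1 to run the task on a GPU
def remote_process_data_sample(batch, model, generator, target_dict):
    return process_data_sample(batch, model, generator, target_dict)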
@zilunpeng
zilunpeng / import_init_ray.py
Created March 23, 2021 22:16
Import and initialize Ray. Code below is part of the distributed inference notebook (https://git.io/JYeQQ).
import ray
ray.init()
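Called with no arguments, ray.init() starts a local Ray instance and uses all available cores. The same code can also cap resources or attach to an existing cluster; a sketch of the common options, not from the notebook:

ray.init(num_cpus=8)        # cap the local Ray instance at 8 worker CPUs
# ray.init(address="auto")  # or attach to an already-running Ray cluster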
@zilunpeng
zilunpeng / call_viterbi_decode.py
Created March 23, 2021 22:13
Make calls to the C++ method for Viterbi decoding. Code below is part of utils.py (https://git.io/JYeHy).
def decode(self, emissions):
    B, T, N = emissions.size()
    hypos = list()
    if self.asg_transitions is None:
        transitions = torch.FloatTensor(N, N).zero_()
    else:
        transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
    viterbi_path = torch.IntTensor(B, T)
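The snippet ends just before the C++ call itself. In fairseq's W2lViterbiDecoder, which this method follows, the body continues roughly as below: a workspace buffer is allocated and CpuViterbiPath.compute fills viterbi_path in place (a sketch, not the exact contents of utils.py):

    workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
    CpuViterbiPath.compute(
        B, T, N,
        get_data_ptr_as_bytes(emissions),
        get_data_ptr_as_bytes(transitions),
        get_data_ptr_as_bytes(viterbi_path),
        get_data_ptr_as_bytes(workspace),
    )
    return [
        [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
        for b in range(B)
    ]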
@zilunpeng
zilunpeng / import_wav2letter.py
Created March 23, 2021 22:08
Import from wav2letter. Code below is part of utils.py (https://git.io/JYeHy).
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
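This import matches the wav2letter Python bindings used in the notebook. In more recent releases those bindings were folded into the flashlight package, so the equivalent import looks roughly like the line below (the module path is stated as an assumption; check the installed version):

from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes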
@zilunpeng
zilunpeng / call_wav2vec2_decoder.py
Created March 23, 2021 22:03
Get the decoder output. Code below is part of the wav2vec 2.0 inference notebook (https://git.io/JYeKX).
decoder_out = decoder.decode(emissions)
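decoder.decode returns, for each utterance in the batch, a ranked list of hypotheses whose best entry carries a tensor of token indices. A hedged sketch of turning that into text with fairseq's dictionary helpers (the post_process import path and the "letter" symbol are assumptions based on fairseq's speech-recognition example):

from fairseq.data.data_utils import post_process

best_hypo = decoder_out[0][0]                       # top hypothesis for the first utterance
letters = target_dict.string(best_hypo["tokens"])   # e.g. "h e l l o | w o r l d |"
transcription = post_process(letters, "letter")     # -> "hello world"
print(transcription)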
@zilunpeng
zilunpeng / quantize_wav2vec2.py
Created March 23, 2021 21:31
Quantize wav2vec 2.0. Code below is part of the quantized wav2vec 2.0 demo notebook (https://git.io/JYe1o).
quantized_model = torch.quantization.quantize_dynamic(pt_wav2vec2, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.prepare_for_inference_after_quantization()
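Dynamic quantization converts the weights of every torch.nn.Linear to int8, which mainly shrinks the model and speeds up linear layers on CPU. A quick way to check the size reduction, following the standard PyTorch recipe (the scratch file name is arbitrary; note that the notebook quantizes in place, so the fp32 size must be recorded before quantize_dynamic runs):

import os
import torch

def model_size_mb(model):
    # Serialize the state dict to a scratch file and report its size.
    torch.save(model.state_dict(), "temp_weights.p")
    size = os.path.getsize("temp_weights.p") / 1e6
    os.remove("temp_weights.p")
    return size

fp32_size = model_size_mb(pt_wav2vec2)      # measure before quantizing
# ... run the quantization cell above ...
int8_size = model_size_mb(quantized_model)
print(f"fp32: {fp32_size:.1f} MB -> int8: {int8_size:.1f} MB")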
@zilunpeng
zilunpeng / prepare_quantized_wav2vec2_for_inf.py
Created March 23, 2021 21:28
Prepare wav2vec 2.0 for inference after quantization. Code below is part of wav2vec2.py (https://git.io/JYe1Y).
def prepare_for_inference_after_quantization(self):
    dequantizer = torch.nn.quantized.DeQuantize()
    for trans_layer in self.encoder.layers:
        trans_layer.self_attn.q_proj_bias = trans_layer.self_attn.q_proj.bias()
        trans_layer.self_attn.k_proj_bias = trans_layer.self_attn.k_proj.bias()
        trans_layer.self_attn.v_proj_bias = trans_layer.self_attn.v_proj.bias()
        trans_layer.self_attn.in_proj_bias = torch.cat((trans_layer.self_attn.q_proj_bias, trans_layer.self_attn.k_proj_bias, trans_layer.self_attn.v_proj_bias))
        trans_layer.self_attn.out_proj_bias = trans_layer.self_attn.out_proj.bias()
        trans_layer.self_attn.out_proj_weight = dequantizer(trans_layer.self_attn.out_proj.weight())
        trans_layer.self_attn.q_proj_weight = dequantizer(trans_layer.self_attn.q_proj.weight())