@Dref360
Created February 10, 2024 16:53
Example of uncertainty estimation using Baal on Speech Recognition
# Wav2Vec2 in Baal
from datasets import load_dataset
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments
from baal.active.heuristics import BALD
from baal.bayesian.dropout import patch_module
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
# Load the model and processor.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# Preprocess the audio and set format to torch.
ds_processed = (
    ds.map(
        lambda u: {k: v[0] for k, v in processor(u["audio"]["array"], return_tensors="pt", padding="longest").items()})
    .remove_columns(ds.column_names)
    .with_format("torch"))
def uncertainty_estimation(ds_processed):
    patched_model = patch_module(model)  # Replace Dropout layers so they stay active at inference (MC-Dropout).
    wrapper = BaalTransformersTrainer(
        model=patched_model,
        args=TrainingArguments('/tmp', per_device_eval_batch_size=1))
    predictions_generator = wrapper.predict_on_dataset_generator(
        ds_processed, iterations=20)  # 20 MC-Dropout iterations.
    # WARNING: Shape is [Batch Size, Num Classes, Num Tokens, Num Iterations].
    # Note: next() consumes the first batch, so it is not scored by the heuristic below.
    first_pred = next(predictions_generator)
    uncertainty = BALD(reduction='mean').get_uncertainties_generator(predictions_generator)
    return uncertainty


uncertainty_estimation(ds_processed.select([1, 2, 3, 4, 5]))
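
# --- Not part of the original gist: a minimal sketch of how the BALD scores could
# drive an active-learning query step. It assumes `uncertainty_estimation` returns a
# 1-D NumPy-like array of per-sample scores for the pool (minus the first batch
# consumed by next() above); the variable names and the top_k budget are hypothetical.
import numpy as np

pool = ds_processed.select([1, 2, 3, 4, 5])
scores = uncertainty_estimation(pool)

# Higher BALD score = higher epistemic uncertainty, so sort descending and take
# the top-k items as the next candidates to send for labelling.
top_k = 2  # hypothetical labelling budget
query_indices = np.argsort(scores)[::-1][:top_k]
print("Most informative pool indices:", query_indices)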