
Baal with HuggingFace on Label Studio

Instructions to run Label Studio with Bayesian active learning for text classification.

Documentation | GitHub

Environment:

  • export LABEL_STUDIO_HOSTNAME=http://localhost:8080
  • export LABEL_STUDIO_ML_BACKEND_V2=True
  • export API_KEY=${YOUR_API_KEY}
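
Optionally, you can sanity-check these values before wiring anything up. The snippet below is not part of the gist; it assumes the /api/projects endpoint and simply reuses the same Token-based header that the backend script below sends to Label Studio:

import os
import requests

# Hypothetical connectivity check (not in the original gist).
hostname = os.environ.get('LABEL_STUDIO_HOSTNAME', 'http://localhost:8080')
api_key = os.environ['API_KEY']
response = requests.get(f'{hostname.rstrip("/")}/api/projects',
                        headers={'Authorization': f'Token {api_key}'})
response.raise_for_status()
print('Label Studio reachable, token accepted:', response.status_code)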

Dependencies:

  • pip install baal[nlp]

How to:

  • Run label-studio-ml init my_ml_backend --script label_studio_baal_hf.py --force
  • Run label-studio-ml start my_ml_backend
  • Run label-studio start my-annotation-project --init --ml-backend http://localhost:9090

In the Settings, do not forget to tick all of the checkboxes; to use active learning, order tasks by Predictions score.
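
The score attached to each prediction by the backend below is the BALD uncertainty, so ordering tasks by Predictions score (highest first) surfaces the tasks the model is most unsure about. A minimal sketch of how that score is computed, with dummy probabilities and shapes following Baal's [n_samples, n_classes, n_iterations] convention (not part of the gist):

import numpy as np
from baal.active.heuristics import BALD

# Fake Monte Carlo dropout output for 5 tasks, 3 classes, 20 stochastic passes.
probabilities = np.random.dirichlet(np.ones(3), size=(5, 20)).transpose(0, 2, 1)
uncertainties = BALD().get_uncertainties(probabilities)
print(uncertainties.shape)  # (5,) -- one score per task, higher means more uncertain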

Have fun!

import json
import os
from uuid import uuid4
import numpy as np
import requests
import torch
from baal.active.dataset.nlp_datasets import HuggingFaceDatasets
from baal.active.heuristics import BALD
from baal.bayesian.dropout import patch_module
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import DATA_UNDEFINED_NAME, get_env
HOSTNAME = get_env('HOSTNAME', 'http://localhost:8080')
API_KEY = get_env('API_KEY')
LABEL_STUDIO_ML_BACKEND_V2 = get_env('LABEL_STUDIO_ML_BACKEND_V2', default=False)
print('=> LABEL STUDIO HOSTNAME = ', HOSTNAME)
if not API_KEY:
    raise EnvironmentError('=> WARNING! API_KEY is not set')
if not LABEL_STUDIO_ML_BACKEND_V2:
    raise EnvironmentError('=> WARNING! LABEL_STUDIO_ML_BACKEND_V2 is not set to true!')
BASE_MODEL = 'distilbert-base-uncased'
class SimpleTextClassifier(LabelStudioMLBase):
    def __init__(self, **kwargs):
        # don't forget to initialize base class...
        super(SimpleTextClassifier, self).__init__(**kwargs)

        # then collect the keys from the label config that are used to extract task data and to build predictions
        # Parsed label config contains only one output of <Choices> type
        assert len(self.parsed_label_config) == 1, self.parsed_label_config
        self.from_name, self.info = list(self.parsed_label_config.items())[0]
        assert self.info['type'] == 'Choices'
        self.num_classes = len(self.info['labels'])

        # the model has only one textual input
        assert len(self.info['to_name']) == 1
        assert len(self.info['inputs']) == 1
        assert self.info['inputs'][0]['type'] == 'Text'
        self.to_name = self.info['to_name'][0]
        self.value = self.info['inputs'][0]['value']

        if not self.train_output:
            # This is an array of <Choice> labels
            self.labels = self.info['labels']
            self.reset_model()
            print('Initialized with from_name={from_name}, to_name={to_name}, labels={labels}'.format(
                from_name=self.from_name, to_name=self.to_name, labels=str(self.labels)
            ))
        else:
            self.reset_model()
            # otherwise load the model weights from the latest training results
            self.model.load_state_dict(torch.load(self.train_output['model_file']))
            # and use the labels from training outputs
            self.labels = self.train_output['labels']
            print('Loaded from train output with from_name={from_name}, to_name={to_name}, labels={labels}'.format(
                from_name=self.from_name, to_name=self.to_name, labels=str(self.labels)
            ))
    def reset_model(self):
        use_cuda = torch.cuda.is_available()
        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=BASE_MODEL,
                                                                        num_labels=self.num_classes)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BASE_MODEL)
        # patch_module swaps Dropout layers so they stay stochastic at inference time (MC-Dropout)
        self.model = patch_module(self.model)
        if use_cuda:
            self.model.cuda()
        self.trainer = BaalTransformersTrainer(model=self.model)
    def make_dataset(self, texts, labels):
        dataset = Dataset.from_dict({
            'text': texts,
            'label': labels
        })
        return HuggingFaceDatasets(dataset, self.tokenizer, target_key='label', input_key='text', max_seq_len=128)
    def predict(self, tasks, **kwargs):
        # collect input texts
        input_texts = []
        for task in tasks:
            input_text = task['data'].get(self.value) or task['data'].get(DATA_UNDEFINED_NAME)
            input_texts.append(input_text)

        # get model predictions: 20 stochastic forward passes (MC-Dropout iterations) per sample
        probabilities = self.trainer.predict_on_dataset(self.make_dataset(input_texts, [0] * len(input_texts)), 20)
        uncertainties = BALD().get_uncertainties(probabilities).tolist()
        # average the stochastic predictions over the iteration axis before taking the argmax
        predictions = probabilities.mean(-1)
        predicted_label_indices = np.argmax(predictions, axis=1).tolist()

        predictions = []
        for idx, score in zip(predicted_label_indices, uncertainties):
            predicted_label = self.labels[idx]
            # prediction result for the single task
            result = [{
                'from_name': self.from_name,
                'to_name': self.to_name,
                'type': 'choices',
                'value': {'choices': [predicted_label]}
            }]
            # expand predictions with their scores for all tasks
            predictions.append({'result': result, 'score': score})
        return predictions
    def _get_annotated_dataset(self, project_id):
        """Just for demo purposes: retrieve annotated data from Label Studio API"""
        download_url = f'{HOSTNAME.rstrip("/")}/api/projects/{project_id}/export'
        response = requests.get(download_url, headers={'Authorization': f'Token {API_KEY}'})
        if response.status_code != 200:
            raise Exception(f"Can't load task data using {download_url}, "
                            f"response status_code = {response.status_code}")
        return json.loads(response.content)
    def fit(self, annotations, workdir=None, **kwargs):
        # check if training is from web hook
        if kwargs.get('data'):
            project_id = kwargs['data']['project']['id']
            tasks = self._get_annotated_dataset(project_id)
        # ML training without web hook
        else:
            tasks = annotations

        input_texts = []
        output_labels, output_labels_idx = [], []
        label2idx = {l: i for i, l in enumerate(self.labels)}
        for task in tasks:
            if not task.get('annotations'):
                continue
            annotation = task['annotations'][0]
            # get input text from task data
            if annotation.get('skipped') or annotation.get('was_cancelled'):
                continue
            input_text = task['data'].get(self.value) or task['data'].get(DATA_UNDEFINED_NAME)
            input_texts.append(input_text)

            # get an annotation
            output_label = annotation['result'][0]['value']['choices'][0]
            output_labels.append(output_label)
            output_label_idx = label2idx[output_label]
            output_labels_idx.append(output_label_idx)

        new_labels = set(output_labels)
        if len(new_labels) != len(self.labels):
            self.labels = list(sorted(new_labels))
            print('Label set has been changed:' + str(self.labels))
            label2idx = {l: i for i, l in enumerate(self.labels)}
            output_labels_idx = [label2idx[label] for label in output_labels]

        # train the model
        print(f'Start training on {len(input_texts)} samples')
        self.reset_model()
        self.trainer.train_dataset = self.make_dataset(input_texts, output_labels_idx)
        self.trainer.train()

        # save output resources
        workdir = workdir or os.getenv('MODEL_DIR')
        model_name = str(uuid4())[:8]
        if workdir:
            model_file = os.path.join(workdir, f'{model_name}.pkl')
        else:
            model_file = f'{model_name}.pkl'
        print(f'Save model to {model_file}')
        torch.save(self.model.state_dict(), model_file)

        train_output = {
            'labels': self.labels,
            'model_file': model_file
        }
        return train_output
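
For reference, each element returned by predict() above follows the Label Studio prediction format; with illustrative label names it looks roughly like this:

# Example of one prediction entry (label and value names are made up for illustration).
{
    'result': [{
        'from_name': 'sentiment',
        'to_name': 'text',
        'type': 'choices',
        'value': {'choices': ['positive']}
    }],
    'score': 0.42  # BALD uncertainty, used by Label Studio to order tasks
}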