|
import json |
|
import os |
|
from uuid import uuid4 |
|
|
|
import numpy as np |
|
import requests |
|
import torch |
|
from baal.active.dataset.nlp_datasets import HuggingFaceDatasets |
|
from baal.active.heuristics import BALD |
|
from baal.bayesian.dropout import patch_module |
|
from baal.transformers_trainer_wrapper import BaalTransformersTrainer |
|
from datasets import Dataset |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
from label_studio_ml.model import LabelStudioMLBase |
|
from label_studio_ml.utils import DATA_UNDEFINED_NAME, get_env |
|
|
|
# Connection settings for the Label Studio instance, read via ``get_env``.
HOSTNAME = get_env('HOSTNAME', 'http://localhost:8080')
API_KEY = get_env('API_KEY')
LABEL_STUDIO_ML_BACKEND_V2 = get_env('LABEL_STUDIO_ML_BACKEND_V2', default=False)

# Pretrained checkpoint name handed to the transformers ``from_pretrained``
# loaders when the classifier (re)builds its model and tokenizer.
BASE_MODEL = 'distilbert-base-uncased'

print('=> LABEL STUDIO HOSTNAME = ', HOSTNAME)

# Abort the import when required settings are missing: the API key is sent
# as the Authorization token when annotated tasks are exported for training.
if not API_KEY:
    raise EnvironmentError('=> WARNING! API_KEY is not set')

if not LABEL_STUDIO_ML_BACKEND_V2:
    raise EnvironmentError('=> WARNING! LABEL_STUDIO_ML_BACKEND_V2 is not set to true!')
|
|
|
|
|
class SimpleTextClassifier(LabelStudioMLBase):
    """Single-label text classifier backend with BALD uncertainty scores.

    The label config must declare exactly one ``Choices`` output attached to
    exactly one ``Text`` input. ``BASE_MODEL`` is loaded for sequence
    classification, passed through baal's ``patch_module`` and wrapped in a
    ``BaalTransformersTrainer``. ``predict`` returns one choice per task,
    scored with the ``BALD`` heuristic; ``fit`` retrains from scratch on the
    annotated tasks and stores the resulting weights on disk.
    """

    def __init__(self, **kwargs):
        """Parse the label config and build (or restore) the model.

        When ``self.train_output`` is non-empty, the labels and weights of
        the latest training run are restored; otherwise a fresh model is
        created with the labels taken from the label config.
        """
        # don't forget to initialize base class...
        super().__init__(**kwargs)

        # Parsed label config must contain only one output of <Choices> type
        assert len(self.parsed_label_config) == 1, self.parsed_label_config
        self.from_name, self.info = list(self.parsed_label_config.items())[0]
        assert self.info['type'] == 'Choices'
        self.num_classes = len(self.info['labels'])

        # the model has only one textual input
        assert len(self.info['to_name']) == 1
        assert len(self.info['inputs']) == 1
        assert self.info['inputs'][0]['type'] == 'Text'
        self.to_name = self.info['to_name'][0]
        self.value = self.info['inputs'][0]['value']

        if self.train_output:
            # Use the labels from the training outputs, and size the
            # classification head exactly as it was when the weights were
            # saved. Older outputs without 'num_classes' fall back to the
            # label-config count.
            self.labels = self.train_output['labels']
            self.num_classes = self.train_output.get('num_classes', self.num_classes)
            self.reset_model()
            # map_location='cpu' lets a checkpoint written on a GPU host be
            # restored on a CPU-only host; load_state_dict copies the tensors
            # onto whatever device the model lives on.
            state_dict = torch.load(self.train_output['model_file'], map_location='cpu')
            self.trainer.load_state_dict(state_dict)
            print(
                f'Loaded from train output with from_name={self.from_name}, '
                f'to_name={self.to_name}, labels={self.labels}'
            )
        else:
            # This is an array of <Choice> labels
            self.labels = self.info['labels']
            self.reset_model()
            print(
                f'Initialized with from_name={self.from_name}, '
                f'to_name={self.to_name}, labels={self.labels}'
            )

    def reset_model(self):
        """Rebuild model, tokenizer and trainer from the ``BASE_MODEL`` checkpoint.

        The classification head gets ``self.num_classes`` outputs. The model
        is moved to the GPU when CUDA is available.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=BASE_MODEL,
            num_labels=self.num_classes,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BASE_MODEL)
        self.model = patch_module(self.model)
        if torch.cuda.is_available():
            self.model.cuda()

        self.trainer = BaalTransformersTrainer(model=self.model)

    def make_dataset(self, texts, labels):
        """Wrap parallel lists of texts and integer labels for the trainer.

        Returns a ``HuggingFaceDatasets`` built with this instance's
        tokenizer and ``max_seq_len=128``.
        """
        dataset = Dataset.from_dict({
            'text': texts,
            'label': labels,
        })
        return HuggingFaceDatasets(
            dataset, self.tokenizer, target_key='label', input_key='text', max_seq_len=128,
        )

    def predict(self, tasks, **kwargs):
        """Predict one choice per task, scored by BALD uncertainty.

        Args:
            tasks: Label Studio task dicts; the text is read from
                ``task['data']`` under the configured value key, falling
                back to ``DATA_UNDEFINED_NAME``.

        Returns:
            A list with one ``{'result': [...], 'score': uncertainty}`` dict
            per task, in task order. Empty when ``tasks`` is empty.
        """
        if not tasks:
            # Nothing to score; avoid building and tokenizing an empty dataset.
            return []

        # collect input texts
        input_texts = [
            task['data'].get(self.value) or task['data'].get(DATA_UNDEFINED_NAME)
            for task in tasks
        ]

        # Labels are placeholders: the dataset wrapper needs a target column.
        dataset = self.make_dataset(input_texts, [0] * len(input_texts))
        # NOTE(review): 20 is presumably the number of stochastic forward
        # passes per sample -- confirm against the baal trainer API.
        probabilities = self.trainer.predict_on_dataset(dataset, 20)
        uncertainties = BALD().get_uncertainties(probabilities).tolist()

        # Average over the last axis, then only consider classes that have a
        # label: the head can be wider than self.labels after a training run
        # that saw fewer labels, and an unrestricted argmax could index past
        # the end of self.labels.
        mean_probabilities = probabilities.mean(-1)[:, :len(self.labels)]
        predicted_label_indices = np.argmax(mean_probabilities, axis=1).tolist()

        return [
            {
                # prediction result for the single task
                'result': [{
                    'from_name': self.from_name,
                    'to_name': self.to_name,
                    'type': 'choices',
                    'value': {'choices': [self.labels[idx]]},
                }],
                'score': score,
            }
            for idx, score in zip(predicted_label_indices, uncertainties)
        ]

    def _get_annotated_dataset(self, project_id):
        """Just for demo purposes: retrieve annotated data from Label Studio API.

        Raises:
            Exception: when the export endpoint does not answer with HTTP 200.
        """
        download_url = f'{HOSTNAME.rstrip("/")}/api/projects/{project_id}/export'
        response = requests.get(download_url, headers={'Authorization': f'Token {API_KEY}'})
        if response.status_code != 200:
            raise Exception(f"Can't load task data using {download_url}, "
                            f"response status_code = {response.status_code}")
        return json.loads(response.content)

    def fit(self, annotations, workdir=None, **kwargs):
        """Retrain the classifier from scratch on the annotated tasks.

        Args:
            annotations: Annotated task dicts, used when the call does not
                come from a webhook.
            workdir: Directory for the saved weights; falls back to the
                ``MODEL_DIR`` environment variable, then to the current
                directory.
            **kwargs: A truthy ``data`` entry marks a webhook call; the
                tasks are then exported from the project named in
                ``data['project']['id']``.

        Returns:
            ``{'labels', 'model_file', 'num_classes'}`` describing the run.

        Raises:
            ValueError: when no task carries a usable annotation.
        """
        # check if training is from web hook
        if kwargs.get('data'):
            project_id = kwargs['data']['project']['id']
            tasks = self._get_annotated_dataset(project_id)
        else:
            # ML training without web hook
            tasks = annotations

        input_texts, output_labels = [], []
        for task in tasks:
            if not task.get('annotations'):
                continue
            annotation = task['annotations'][0]
            if annotation.get('skipped') or annotation.get('was_cancelled'):
                continue
            results = annotation.get('result')
            if not results:
                # Submitted without any choice: nothing to learn from.
                continue

            task_data = task['data']
            input_texts.append(task_data.get(self.value) or task_data.get(DATA_UNDEFINED_NAME))
            output_labels.append(results[0]['value']['choices'][0])

        if not input_texts:
            # Fail before touching self.labels so the current model stays usable.
            raise ValueError('No annotated tasks available for training')

        # Resolve the label set first and only then map labels to indices, so
        # a label unseen by the previous run does not raise KeyError.
        new_labels = set(output_labels)
        if new_labels != set(self.labels):
            self.labels = sorted(new_labels)
            print('Label set has been changed:' + str(self.labels))
        label2idx = {label: idx for idx, label in enumerate(self.labels)}
        output_labels_idx = [label2idx[label] for label in output_labels]

        # The head must cover every class index used for training.
        self.num_classes = max(self.num_classes, len(self.labels))

        # train the model
        print(f'Start training on {len(input_texts)} samples')
        self.reset_model()
        self.trainer.train_dataset = self.make_dataset(input_texts, output_labels_idx)
        self.trainer.train()

        # save output resources
        workdir = workdir or os.getenv('MODEL_DIR')
        file_name = f'{str(uuid4())[:8]}.pkl'
        model_file = os.path.join(workdir, file_name) if workdir else file_name
        print(f'Save model to {model_file}')
        torch.save(self.model.state_dict(), model_file)

        return {
            'labels': self.labels,
            'model_file': model_file,
            # Stored so that a reload rebuilds a head matching the saved weights.
            'num_classes': self.num_classes,
        }