@dvsrepo
Last active June 2, 2021 11:17
from datasets import Dataset
import rubrix as rb
# load rubrix dataset
df = rb.load('unlabelled_dataset_zeroshot')
# inputs can be dicts (to support multi-field classifiers); here we only use the 'text' field
df['text'] = df.inputs.transform(lambda r: r['text'])
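# each annotation is a list; keep its first element as the single label per record
df['labels'] = df.annotation.transform(lambda r: r[0])
# map labels to numeric ids; sorting keeps the mapping stable across runs (set order is not)
label2id = {label: idx for idx, label in enumerate(sorted(set(df.labels.values)))}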
# create 🤗 dataset from pandas with labels as numeric ids
dataset = Dataset.from_pandas(df[['text', 'labels']])
dataset = dataset.map(lambda example: {'labels': label2id[example['labels']]})
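
The label-encoding step above can be tried without Rubrix or 🤗 Datasets. The following is a minimal sketch in plain Python, assuming records shaped like the ones used above (an `inputs` dict with a `text` field and an `annotation` list). The sample records and the `id2label` inverse mapping are illustrative additions, not part of the original gist.

```python
# Toy records mimicking the shape used above (hypothetical data, not loaded from Rubrix).
records = [
    {"inputs": {"text": "great phone"}, "annotation": ["positive"]},
    {"inputs": {"text": "battery died"}, "annotation": ["negative"]},
    {"inputs": {"text": "love it"}, "annotation": ["positive"]},
]

# Keep only the text field and the first annotation of each record.
texts = [r["inputs"]["text"] for r in records]
labels = [r["annotation"][0] for r in records]

# Sort the unique labels so the label -> id mapping is the same on every run.
label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
# Inverse mapping (an assumption of this sketch), handy for decoding predictions.
id2label = {idx: label for label, idx in label2id.items()}

# Encode labels as numeric ids, then decode them back as a round-trip check.
encoded = [label2id[label] for label in labels]
decoded = [id2label[i] for i in encoded]
```

Keeping `id2label` next to `label2id` makes it easy to turn numeric model outputs back into the original label names later.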