@dyerrington
Last active July 8, 2022 17:42

Installation

It's recommended that you install the requirements for these T5 models in a new environment, since they are known to conflict with common package requirements in the scientific Python stack.

Conda ENV Setup

Create

conda create -n nlp-t5

Activate

conda activate nlp-t5

Packages

conda install python==3.8
conda install -c conda-forge python-multipart pandas scikit-learn transformers protobuf

At least on an M1 MacBook Pro, conda resolved these dependencies correctly. However, you may need to brew install a few system-level packages and/or deal with gcc or Xcode modules if some of the lower-level prerequisites are missing.
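Before moving on, it can help to confirm that the new environment actually exposes the packages installed above. The following is a minimal sketch, not part of the original gist: the helper name missing_modules and the REQUIRED_MODULES list are hypothetical, and the import names are assumptions based on the conda package names (scikit-learn imports as sklearn, protobuf as google.protobuf).

```python
import importlib.util

# Assumed import names for the conda packages installed above.
# Note that scikit-learn is imported as `sklearn` and protobuf as `google.protobuf`.
REQUIRED_MODULES = ["pandas", "sklearn", "transformers", "google.protobuf"]


def missing_modules(names):
    """Return the module names from `names` that cannot be located."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised when a parent package (e.g. `google`) is itself absent.
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("All modules found." if not missing else f"Missing modules: {missing}")
```

Run it with the nlp-t5 environment activated; any name it reports as missing points to a package that conda did not install into that environment.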

References

Original Notebook https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb

Dataset https://github.com/dair-ai/emotion_dataset

Whitepaper https://arxiv.org/pdf/1910.10683.pdf

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# T5 is an encoder-decoder model, so load it with the seq2seq auto class
# (AutoModelWithLMHead is deprecated).
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-emotion")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-emotion")


def get_emotion(text):
    """Return the emotion label the model generates for `text`."""
    # Encode the text with an explicit end-of-sequence token.
    input_ids = tokenizer.encode(text + '</s>', return_tensors='pt')
    # max_length=2 leaves room for the decoder start token plus one label token.
    output = model.generate(input_ids=input_ids, max_length=2)
    # Drop special tokens such as <pad> so only the label text remains.
    dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
    return dec[0]


data = [(1, "I never drink the tap water."), (2, "This customer service is like an enraged wombat."), (3, "Sometimes it's lonely taking a survey.")]
df = pd.DataFrame(data, columns=["id", "text"])
df['T5_prediction'] = df['text'].map(get_emotion)

# Print the results; in a Jupyter cell, a bare df.head() also renders the table.
print(df.head())
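
get_emotion returns whatever text the model generates, which is not guaranteed to be one of the dataset's labels and, depending on how the output is decoded, can include special tokens such as <pad>. A minimal sketch of a post-processing step follows. The helper normalize_label and its fallback value are hypothetical and not part of the original gist; the label set assumes the six classes of the emotion dataset linked in the references.

```python
# Labels assumed from the emotion dataset linked in the references.
KNOWN_EMOTIONS = {"sadness", "joy", "love", "anger", "fear", "surprise"}

# Special tokens that a T5 tokenizer may leave in decoded text.
SPECIAL_TOKENS = ("<pad>", "</s>")


def normalize_label(decoded, fallback="unknown"):
    """Strip special tokens and return a known emotion label, else `fallback`."""
    for token in SPECIAL_TOKENS:
        decoded = decoded.replace(token, " ")
    label = decoded.strip().lower()
    return label if label in KNOWN_EMOTIONS else fallback
```

It could be applied to the prediction column with df['T5_prediction'].map(normalize_label), which makes unexpected generations easy to spot as "unknown".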