Skip to content

Instantly share code, notes, and snippets.

@azarnyx
Last active May 10, 2021 15:36
Show Gist options
  • Save azarnyx/5930f5a778f9c4aefc34811cd21ed739 to your computer and use it in GitHub Desktop.
Save azarnyx/5930f5a778f9c4aefc34811cd21ed739 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, TFAutoModel
# Hugging Face model identifier shared by the tokenizer and the encoder below.
MODEL = "cardiffnlp/twitter-roberta-base"
# Both objects are created once, at import time, so that every call to
# get_embedding() reuses them instead of loading the model again.
TOKENIZER_EMB = AutoTokenizer.from_pretrained(MODEL)
MODEL_EMB = AutoModel.from_pretrained(MODEL)
def preprocess(text):
    """Mask user mentions and links in *text* with fixed placeholders.

    The text is split on single spaces (so runs of spaces and any other
    whitespace are preserved as-is) and each token is rewritten:

    * a token starting with ``@`` and longer than one character becomes
      ``@user``;
    * a token starting with ``http`` becomes ``http``;
    * every other token is kept unchanged.

    Args:
        text: The raw string to normalise.

    Returns:
        The tokens re-joined with single spaces.
    """
    def _mask(token):
        # A lone "@" is not a mention, hence the length check.
        if token.startswith('@') and len(token) > 1:
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_mask(token) for token in text.split(" "))
def get_embedding(text, max_length=512):
    """Return the mean-pooled token embedding of a single text.

    The text is normalised with ``preprocess``, tokenised with
    ``TOKENIZER_EMB``, passed through ``MODEL_EMB``, and the first output
    of the model is averaged over the token axis.

    Args:
        text: A single string to embed.
        max_length: Maximum number of tokens passed to the model; longer
            inputs are truncated instead of crashing the forward pass.
            NOTE(review): 512 is assumed to be the input limit of this
            RoBERTa-base checkpoint -- confirm against the model config.

    Returns:
        A 1-D NumPy array holding the mean of the per-token vectors.
    """
    encoded_input = TOKENIZER_EMB(
        preprocess(text),
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    )
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        outputs = MODEL_EMB(**encoded_input)
    # outputs[0] is the per-token output for the batch; [0] selects the
    # only sequence in it, leaving a (tokens, hidden) matrix.
    token_vectors = outputs[0][0].cpu().numpy()
    return np.mean(token_vectors, axis=0)
# Load the training data; the first CSV column is used as the row index.
initial_df = pd.read_csv("train_data_cleaning.csv", index_col=[0])
# Embed the "text" column row by row, then spread each embedding vector
# across columns so that every row of embed_df is one text's embedding.
embed_df = pd.DataFrame(
    initial_df["text"].apply(get_embedding).tolist(),
    index=initial_df.index,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment