Skip to content

Instantly share code, notes, and snippets.

@azarnyx
Created May 1, 2021 16:24
Show Gist options
  • Save azarnyx/689800160b81d2869a9ab38d226e0d77 to your computer and use it in GitHub Desktop.
Save azarnyx/689800160b81d2869a9ab38d226e0d77 to your computer and use it in GitHub Desktop.
import os
import json
import numpy as np
from joblib import load
import argparse
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from joblib import dump, load
from pathlib import Path
from transformers import AutoTokenizer, AutoModel, TFAutoModel
# Location of the training CSV. It is read with an index column and must
# provide the `text` and `target` columns used by the __main__ block.
# The default below is a placeholder that must be edited, or overridden via
# the PATH_TO_DATA environment variable. Reading an s3:// URL with pandas
# requires the optional `s3fs` package.
PATH_TO_DATA = os.environ.get(
    "PATH_TO_DATA",
    "s3://{insert your path to training data}/train_data_cleaning.csv",
)

# Parameters for the HuggingFace library: the pretrained checkpoint used as a
# frozen feature extractor. Tokenizer and model are loaded once, at import time.
MODEL = "cardiffnlp/twitter-roberta-base"
TOKENIZER_EMB = AutoTokenizer.from_pretrained(MODEL)
MODEL_EMB = AutoModel.from_pretrained(MODEL)
# Helper functions: text normalisation and HuggingFace embedding extraction
def preprocess(text):
    """Mask user handles and links in *text*.

    The text is split on single spaces, so runs of spaces survive the re-join.
    Each token that starts with ``@`` and is longer than one character becomes
    ``@user``. Each remaining token that starts with ``http`` becomes
    ``http``. All other tokens pass through unchanged.
    """
    def _mask_token(token):
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(map(_mask_token, text.split(" ")))
def get_embedding(text):
    """Return a mean-pooled embedding of *text* from ``MODEL_EMB``.

    The text is first normalised with ``preprocess``, then tokenised with
    ``TOKENIZER_EMB`` and run through ``MODEL_EMB``. The token vectors of the
    single sequence in the batch are averaged. Any special tokens added by
    the tokenizer are included in that average.

    Args:
        text: Raw input string.

    Returns:
        1-D numpy array: the mean of the token vectors over the token axis.
    """
    # torch is already a hard requirement: return_tensors='pt' yields torch tensors.
    import torch

    text = preprocess(text)
    # truncation=True lets inputs longer than the model's maximum length be
    # embedded (truncated) rather than raising inside the model.
    encoded_input = TOKENIZER_EMB(text, return_tensors='pt', truncation=True)
    # Inference only: no_grad avoids building and retaining an autograd graph
    # for every call, which saves memory and time when applied row by row.
    with torch.no_grad():
        features = MODEL_EMB(**encoded_input)
    # features[0] is the model's first output, presumably the last hidden
    # state shaped (batch, seq_len, hidden) — TODO confirm. The trailing [0]
    # selects the single sequence in the batch.
    token_vectors = features[0].cpu().numpy()[0]
    return np.mean(token_vectors, axis=0)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Output directory for the trained model; defaults to the SM_MODEL_DIR
    # environment variable when present.
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    args, _ = parser.parse_known_args()
    # Fail fast: without this check a missing directory would only surface as
    # a TypeError in os.path.join after every embedding had been computed.
    if not args.model_dir:
        parser.error('--model-dir is required when SM_MODEL_DIR is not set')

    # Derive embeddings: one row per input text, one column per embedding dimension.
    initial_df = pd.read_csv(PATH_TO_DATA, index_col=[0])
    embed_df = initial_df.text.apply(get_embedding)
    embed_df = pd.DataFrame(embed_df.to_list(), index=embed_df.index)

    # Train model on 75% of the rows; 25% is held out for evaluation.
    X = embed_df
    y = initial_df.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=40)
    # Seeded like the split so repeated runs produce the same forest.
    classif = RandomForestClassifier(n_estimators=80, max_depth=8, random_state=40)
    classif.fit(X_train, y_train)

    # Persist the model first, so that an evaluation failure cannot lose it.
    Path(args.model_dir).mkdir(parents=True, exist_ok=True)
    dump(classif, os.path.join(args.model_dir, 'sklearnclf.joblib'))

    # Report F1 on the held-out split. f1_score's default average='binary'
    # assumes `target` has exactly two classes — NOTE(review): confirm.
    test_f1 = f1_score(y_test, classif.predict(X_test))
    print(f'Held-out F1: {test_f1:.4f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment