Skip to content

Instantly share code, notes, and snippets.

@Navjotbians
Created May 22, 2021 20:11
Show Gist options
  • Save Navjotbians/56f6ae962b23a52ea9680d439925dc42 to your computer and use it in GitHub Desktop.
Save Navjotbians/56f6ae962b23a52ea9680d439925dc42 to your computer and use it in GitHub Desktop.
train function
from sklearn.multiclass import OneVsRestClassifier
from sklearn import model_selection
from sklearn.metrics import accuracy_score, classification_report, f1_score, roc_auc_score
from sklearn.metrics import multilabel_confusion_matrix
### OneVsRestClassifier
def _continuous_scores(clf, X):
    """Return continuous scores from fitted *clf* for use with ``roc_auc_score``.

    Tries ``predict_proba`` first, then ``decision_function``. Returns None when
    the fitted estimator exposes neither, so the caller can fall back to hard
    predictions.
    """
    for method_name in ("predict_proba", "decision_function"):
        if hasattr(clf, method_name):
            scores = getattr(clf, method_name)(X)
            # A OneVsRestClassifier fitted on a binary target holds a single
            # estimator, but its predict_proba still returns two columns
            # (negative, positive); roc_auc_score wants only the positive one.
            if len(clf.estimators_) == 1 and scores.ndim == 2:
                scores = scores[:, -1]
            return scores
    return None


def train_model(classifier, X, y, max_feature=1000, embedding='bow'):
    """Split the data, vectorize it, fit a one-vs-rest model and report metrics.

    The data is split 75/25 (``random_state=42``). Features are built by
    ``get_embeddings``, which is defined elsewhere; it receives the raw train/test
    inputs and returns ``(train_features, test_features, vectorizer)``.
    *classifier* is wrapped in ``OneVsRestClassifier`` and fitted on the
    training features.

    Args:
        classifier: An unfitted scikit-learn estimator to wrap in
            ``OneVsRestClassifier``.
        X: Raw inputs handed to ``get_embeddings``. They are presumably text
            documents; confirm against ``get_embeddings``.
        y: Targets. ``multilabel_confusion_matrix`` is applied to them, so
            they are presumably a multilabel indicator matrix.
        max_feature: Forwarded to ``get_embeddings`` as ``max_feature``.
        embedding: Forwarded to ``get_embeddings`` as ``embedding_type``.

    Returns:
        tuple: ``(clf, vectorizer, score_sumry)``.
            ``clf`` is the fitted ``OneVsRestClassifier``.
            ``vectorizer`` is whatever ``get_embeddings`` returned.
            ``score_sumry`` is ``[accuracy, j_score, macro_f1, roc_auc]``.
            For multilabel targets, sklearn's ``accuracy_score`` is
            exact-match (subset) accuracy.
    """
    # Train-test split
    print("... Performing train test split")
    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, test_size=0.25, random_state=42)

    # Feature extraction with word embedding
    print("... Extracting features")
    Xv_train, Xv_test, vectorizer = get_embeddings(
        X_train, X_test, max_feature=max_feature, embedding_type=embedding)

    # Train the model
    print('... Training {} model'.format(classifier.__class__.__name__))
    clf = OneVsRestClassifier(classifier)
    clf.fit(Xv_train, y_train)

    # Evaluate on the held-out split
    print("... Computing accuracy")
    prediction = clf.predict(Xv_test)
    score = accuracy_score(y_test, prediction)
    type2_score = j_score(y_test, prediction)
    f1_s = f1_score(y_test, prediction, average='macro')
    # ROC AUC ranks samples by confidence. Hard 0/1 predictions collapse the
    # curve to a single operating point, so use probabilities or decision
    # values whenever the fitted estimator provides them.
    auc_input = _continuous_scores(clf, Xv_test)
    if auc_input is None:
        auc_input = prediction
    roc_auc = roc_auc_score(y_test, auc_input)
    confusion_matrix = multilabel_confusion_matrix(y_test, prediction)
    score_sumry = [score, type2_score, f1_s, roc_auc]

    print('\n')
    print("Model evaluation")
    print("------")
    # NOTE(review): print_score is defined elsewhere; if it prints its report
    # itself and returns None, this line also prints "None" -- confirm.
    print(print_score(prediction, y_test, classifier))
    print('Accuracy is {}'.format(score))
    print("ROC_AUC - {}".format(roc_auc))
    print("------")
    print("Multilabel confusion matrix \n {}".format(confusion_matrix))
    return clf, vectorizer, score_sumry
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment