Skip to content

Instantly share code, notes, and snippets.

@hadifar
Created April 25, 2022 15:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hadifar/7b89bc435279829fb7923066a8a63869 to your computer and use it in GitHub Desktop.
Save hadifar/7b89bc435279829fb7923066a8a63869 to your computer and use it in GitHub Desktop.
A simple example for SVM
import argparse
import os
import numpy as np
from joblib import dump, load
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
def load_data(args):
    """Return a small hard-coded toy corpus as train and test splits.

    Args:
        args: Accepted for interface compatibility; not read here.

    Returns:
        A tuple ``(x_train, x_test, y_train, y_test)`` where the ``x_*``
        entries are lists of sentences and the ``y_*`` entries are the
        matching lists of integer labels.
    """
    # Keep each sentence next to its label so the pairing is obvious.
    train_pairs = [
        ('svm is all you need !!!', 0),
        ('test is test', 1),
        ('this is another test', 1),
        ('this is test 3', 1),
        ('svm is more than you need', 0),
    ]
    test_pairs = [
        ('svm is good ?', 0),
    ]

    x_train = [sentence for sentence, _ in train_pairs]
    y_train = [label for _, label in train_pairs]
    x_test = [sentence for sentence, _ in test_pairs]
    y_test = [label for _, label in test_pairs]
    return x_train, x_test, y_train, y_test
def do_inference(args):
    """Load the saved classifier and print its predicted labels for a sample.

    Args:
        args: Namespace with a ``save_path`` attribute naming the joblib
            file to load. The loaded object must provide ``predict_proba``
            and ``classes_``.
    """
    # NOTE: joblib.load unpickles the file; only point it at trusted files.
    classifier = load(args.save_path)
    inps = ['appli onlin medicar access comput right']  # gold class 1
    probabilities = classifier.predict_proba(inps)
    # The columns of predict_proba are ordered like classifier.classes_, so
    # translate the winning column into its label instead of printing the
    # raw column index (the two only coincide when labels are 0..k-1).
    winning_columns = np.argmax(probabilities, axis=1)
    predictions = classifier.classes_[winning_columns].tolist()
    print(predictions)
def train_classifier(args):
    """Train (or reload) the TF-IDF + linear SVC pipeline and print accuracy.

    If ``args.save_path`` already exists, the pipeline stored there is loaded
    and no training happens. Otherwise a new pipeline is fitted on the
    training split and dumped to ``args.save_path``. In both cases the
    accuracy on the validation split is printed.

    Args:
        args: Namespace with a ``save_path`` attribute; it is also forwarded
            to ``load_data``.
    """
    x_train, x_valid, y_train, y_valid = load_data(args)
    print('loading data finished...')

    if os.path.isfile(args.save_path):
        # Reuse a previously trained model instead of fitting again.
        print('load existing stat classifier')
        pipeline = load(args.save_path)
    else:
        pipeline = Pipeline(
            [
                ('tfidf', TfidfVectorizer(ngram_range=(1, 2))),
                # probability=True is required for predict_proba below.
                ('svc', SVC(kernel='linear', C=1, probability=True)),
            ],
            verbose=True,
        )
        pipeline.fit(x_train, y_train)
        dump(pipeline, args.save_path)

    probabilities = pipeline.predict_proba(x_valid)
    # predict_proba columns follow pipeline.classes_; map the best column
    # back to its label so accuracy is computed on labels, not on column
    # indices (which only match labels that happen to be 0..k-1).
    predictions = pipeline.classes_[np.argmax(probabilities, axis=1)]
    print('ACC ' + str(accuracy_score(y_valid, predictions)))
def main(args):
    """Run the demo: train (or reload) the classifier, then run inference.

    Args:
        args: Parsed command-line namespace, passed unchanged to each step.
    """
    # Only the statistical (SVM) path is wired up here. Earlier commented-out
    # code suggests a model-name based dispatch (e.g. a BERT branch) was
    # planned; that is not implemented in this file.
    pipeline_steps = (train_classifier, do_inference)
    for step in pipeline_steps:
        step(args)
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the demo.
    parser = argparse.ArgumentParser()
    # The default was already the integer 1; declaring type=int makes a value
    # given on the command line an int as well, instead of a str such as '0'
    # (which would be truthy).
    parser.add_argument('--debug', type=int, default=1,
                        help='debug flag (not read anywhere in this script)')
    parser.add_argument('--save_path', type=str, default='voting_domain_classifier.joblib',
                        help='joblib file the classifier is loaded from if it exists, '
                             'otherwise the file the newly trained classifier is saved to')
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment