@aplz
Created September 5, 2018 16:15
sklearn cross-validation for fasttext
import argparse
import os

import fasttext
from sklearn.base import BaseEstimator
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_score, StratifiedKFold


def read_data(data_dir):
    """
    Import data from a directory. Each child directory is assumed to be the label for the files it contains.

    :param data_dir: the path to a data directory.
    :return: texts, labels: the list of (unprocessed) text files and the list of associated labels.
    """
    print("Converting files from %s to fasttext format." % data_dir)
    texts = []
    labels = []
    for root_directory, child_directories, files in os.walk(data_dir):
        for child_directory in sorted(child_directories):
            label = child_directory
            for sub_root, sub_child_directories, actual_files in os.walk(os.path.join(root_directory, child_directory)):
                for filename in actual_files:
                    file_path = os.path.join(root_directory, child_directory, filename)
                    with open(file_path, 'r') as text_file:
                        text = text_file.read()
                    if not text:
                        print("Could not extract text from %s!" % file_path)
                        continue
                    text = text.lower().replace("\n", " ")
                    texts.append(text)
                    # append the label in the format expected by fasttext
                    labels.append("__label__%s" % label)
    print("Imported %s text files." % len(texts))
    return texts, labels


class FasttextEstimator(BaseEstimator):

    def __init__(self, model_dir):
        self.model_dir = model_dir
        self.model = None

    def fit(self, features, labels):
        """
        Train fasttext on the given features and labels.

        :param features: a list of documents.
        :param labels: the list of labels associated with the list of features.
        """
        store_file(os.path.join(self.model_dir, "train.txt"), features, labels)
        fasttext.supervised(os.path.join(self.model_dir, "train.txt"), os.path.join(self.model_dir, "cv_model"),
                            thread=1, minn=0, maxn=10, bucket=1)
        self.model = fasttext.load_model(os.path.join(self.model_dir, "cv_model.bin"), encoding='utf-8')
        return self

    def score(self, features, labels):
        """
        Compute the macro-f1 score for the predictions on the given features.

        :param features: a list of documents.
        :param labels: the list of labels associated with the list of features.
        :return: f1_score: the macro-f1 score for the predictions on the given features.
        """
        predicted_labels = []
        predictions = self.model.predict_proba(features)
        for prediction in predictions:
            # each prediction is a list of (label, probability) pairs; take the top-ranked label
            predicted_label = prediction[0][0]
            predicted_labels.append(predicted_label)
        return f1_score(labels, predicted_labels, average="macro")


def store_file(output_file, features, labels):
    """Write the training data in fasttext format to disk.

    :param output_file: the name of the output file.
    :param features: the features, a list of strings.
    :param labels: the labels associated with the features.
    """
    with open(output_file, 'w') as f:
        for label, text in zip(labels, features):
            f.write("%s %s\n" % (label, text))


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument("--directory", help="path to the directory containing training data",
                    default="/home/user/fasttext_data/")
    ap.add_argument("--model_dir", help="path to the model directory", default="/home/user/temp/")
    args = vars(ap.parse_args())

    texts, labels = read_data(data_dir=args["directory"])
    estimator = FasttextEstimator(model_dir=args["model_dir"])
    score = cross_val_score(estimator, texts, labels, cv=StratifiedKFold(n_splits=5), verbose=5)
    print(score)
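The loader above assumes a flat layout in which each child directory of the data directory names one class. A minimal sketch (labels, file names, and texts here are made up for illustration) that builds such a toy layout and derives `(text, label)` pairs the same way `read_data` does:

```python
import os
import tempfile

# Build a toy data directory: one child directory per class label,
# each containing plain-text files (all names here are hypothetical).
data_dir = tempfile.mkdtemp()
for label, sample in [("spam", "win money now"), ("ham", "meeting at noon")]:
    os.makedirs(os.path.join(data_dir, label))
    with open(os.path.join(data_dir, label, "doc1.txt"), "w") as f:
        f.write(sample)

# Derive (text, label) pairs as read_data does: the child
# directory name becomes the fasttext label.
texts, labels = [], []
for label in sorted(os.listdir(data_dir)):
    class_dir = os.path.join(data_dir, label)
    for filename in os.listdir(class_dir):
        with open(os.path.join(class_dir, filename)) as f:
            texts.append(f.read().lower().replace("\n", " "))
        labels.append("__label__%s" % label)

print(labels)  # ['__label__ham', '__label__spam']
```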
@harunyasar

@aplz Could you share your folder structure with example data sets?

@alpha999999999999

thanks, it's very helpful for me ^0^

@karthik-ir

Awesome!

@tommp4

tommp4 commented Dec 17, 2020

This is awesome thanks! 🥇

@hafiz031

hafiz031 commented Sep 7, 2021

At line #54, self.model = fasttext.load_model(os.path.join(self.model_dir, "cv_model.bin"), encoding='utf-8'), the model file is first saved and then loaded back from disk. Will this stay consistent if we pass n_jobs > 1 (or n_jobs = -1) in GridSearchCV? That is, if more than one grid-search combination is tried simultaneously, isn't there a chance that one model file overwrites another? I have run into exactly this: at some point FastText raised an error like `Model file has wrong file format!`.

Possibly, we need to keep the model in memory, or use a unique filename for each grid-search trial. We would also have to remove all the temporary model files created during the search and keep only the best one.
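A minimal sketch of that unique-filename idea (all names below are my own, not from the gist): give each fit a freshly created temporary model directory, so parallel grid-search workers can never clobber each other's cv_model.bin.

```python
import os
import shutil
import tempfile


class IsolatedModelDir:
    """Create a unique model directory per trial so parallel
    grid-search workers cannot overwrite each other's model files."""

    def __enter__(self):
        # mkdtemp() guarantees a fresh, uniquely named directory
        self.path = tempfile.mkdtemp(prefix="fasttext_cv_")
        return self.path

    def __exit__(self, *exc):
        # drop the temporary model files once the trial is scored
        shutil.rmtree(self.path, ignore_errors=True)


# Two concurrent trials would each see a distinct directory:
with IsolatedModelDir() as dir_a, IsolatedModelDir() as dir_b:
    print(dir_a != dir_b)  # True
```

In the gist's estimator, the directory would have to live from fit() until score() has run, so a real version would create it in fit() and clean it up afterwards rather than in a single with-block.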

@aplz
Author

aplz commented Sep 7, 2021

@hafiz031

At line #54, self.model = fasttext.load_model(os.path.join(self.model_dir, "cv_model.bin"), encoding='utf-8'), the model file is first saved and then loaded back from disk. Will this stay consistent if we pass n_jobs > 1? That is, if more than one grid-search combination is tried simultaneously, isn't there a chance that one model file overwrites another?

There is no n_jobs parameter in the official documentation, but you are probably referring to parallelization. No, unfortunately, this setup does not currently support parallelization. But yes, the serialization step would most likely need to be adjusted.

@hafiz031

hafiz031 commented Sep 7, 2021

@aplz thanks, I understand. I am actually looking for an implementation that can be used with sklearn.model_selection.GridSearchCV with parallelization enabled. I implemented fit() and predict() for classification, by the way, and in my implementation I loaded the model inside the predict() method. Anyway, I got this error, so I probably need to change it as described above.
