Created
April 18, 2018 14:51
-
-
Save orico/1f58046ae28e6c2141fdf723a4f41eeb to your computer and use it in GitHub Desktop.
Al-TrainModel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TrainModel:
    """Train a wrapped model, time the fit, and report test-set metrics.

    The wrapped model is created by calling ``model_object()`` with no
    arguments. The resulting instance must provide ``fit_predict``,
    ``model_type`` and ``classifier``; those are the only members used here.

    Attributes:
        accuracies: test-set accuracy (percent) recorded by each call to
            ``get_test_accuracy``, in call order.
        model_object: the instance created from the ``model_object`` argument.
        val_y_predicted, test_y_predicted: predictions returned by the most
            recent ``train`` call (only set once ``train`` has run).
        run_time: seconds spent inside the most recent ``fit_predict`` call
            (only set once ``train`` has run).
    """

    def __init__(self, model_object):
        """Instantiate ``model_object`` and start with an empty accuracy history.

        Args:
            model_object: zero-argument callable (typically a class) that
                returns the model wrapper to train.
        """
        self.accuracies = []
        self.model_object = model_object()

    def print_model_type(self):
        """Print the wrapped model's ``model_type`` attribute."""
        print(self.model_object.model_type)

    def train(self, X_train, y_train, X_val, X_test, c_weight):
        """Fit on the training set and store predictions for val and test sets.

        We train normally and keep the validation-set output so the caller can
        select the most uncertain samples from it. Per the original note these
        are meant to be probabilities; whether they are depends on the wrapped
        ``fit_predict`` (TODO confirm).

        Args:
            X_train: training features (must expose ``.shape``).
            y_train: training labels (must expose ``.shape``).
            X_val: validation features (must expose ``.shape``).
            X_test: test features (must expose ``.shape``).
            c_weight: forwarded unchanged to ``fit_predict`` (presumably class
                weights; verify against the model wrappers).

        Returns:
            ``(X_train, X_val, X_test)`` as returned by ``fit_predict``. They
            are returned in case the model transforms them (e.g. with PCA);
            for the other algorithms the return value is not needed.
        """
        print('Train set:', X_train.shape, 'y:', y_train.shape)
        print('Val set:', X_val.shape)
        print('Test set:', X_test.shape)
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump (NTP/DST), corrupting the measured duration.
        t0 = time.perf_counter()
        (X_train, X_val, X_test, self.val_y_predicted,
         self.test_y_predicted) = self.model_object.fit_predict(
            X_train, y_train, X_val, X_test, c_weight)
        self.run_time = time.perf_counter() - t0
        return (X_train, X_val, X_test)

    def get_test_accuracy(self, i, y_test):
        """Record the test-set accuracy and print the evaluation reports.

        Accuracy is computed only for the test set, using the predictions
        stored by the most recent ``train`` call.

        Args:
            i: iteration number, used only in the printed header.
            y_test: true test labels; any array-like whose flattened length
                matches the stored test predictions.

        Raises:
            ValueError: if ``y_test`` and the stored test predictions contain
                different numbers of elements. The check runs before anything
                is appended to ``self.accuracies``.
        """
        y_true = np.ravel(y_test)
        y_pred = np.ravel(self.test_y_predicted)
        # Without this guard, older NumPy turns a length-mismatched `==` into a
        # scalar False, recording a bogus 0% accuracy before sklearn raises.
        if y_true.size != y_pred.size:
            raise ValueError(
                'y_test has %d labels but %d test predictions are stored'
                % (y_true.size, y_pred.size))
        classif_rate = np.mean(y_pred == y_true) * 100
        self.accuracies.append(classif_rate)
        print('--------------------------------')
        print('Iteration:', i)
        print('--------------------------------')
        print('y-test set:', np.shape(y_test))
        print('Example run in %.3f s' % self.run_time, '\n')
        print("Accuracy rate for %f " % (classif_rate))
        print("Classification report for classifier %s:\n%s\n" % (
            self.model_object.classifier,
            metrics.classification_report(y_test, self.test_y_predicted)))
        print("Confusion matrix:\n%s"
              % metrics.confusion_matrix(y_test, self.test_y_predicted))
        print('--------------------------------')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment