Created October 26, 2020 15:38
# Assumes msd (a Spark DataFrame), model_list (the ALS models) and ROEM() are defined earlier.
# Split the data into training and test sets (no seed is set, so this split varies between runs)
(training, test) = msd.randomSplit([0.8, 0.2])
# Build 5 folds within the training set
train1, train2, train3, train4, train5 = training.randomSplit([0.2, 0.2, 0.2, 0.2, 0.2], seed=1)
fold1 = train2.union(train3).union(train4).union(train5)
fold2 = train3.union(train4).union(train5).union(train1)
fold3 = train4.union(train5).union(train1).union(train2)
fold4 = train5.union(train1).union(train2).union(train3)
fold5 = train1.union(train2).union(train3).union(train4)
# Pair each fold's training data with its held-out part
foldlist = [(fold1, train1), (fold2, train2), (fold3, train3), (fold4, train4), (fold5, train5)]
# Empty list to fill with the validation ROEM of each model
ROEMS = []
# Loop through all models and all folds
for model in model_list:
    for ft_pair in foldlist:
        # Fit the model to the training part of this fold
        fitted_model = model.fit(ft_pair[0])
        # Generate predictions on the fold's held-out data
        predictions = fitted_model.transform(ft_pair[1])
        # Compute and print the ROEM on the held-out data
        r = ROEM(predictions)
        print("ROEM: ", r)
    # Fit the model to all of the training data and generate predictions for the test data
    v_fitted_model = model.fit(training)
    v_predictions = v_fitted_model.transform(test)
    v_ROEM = ROEM(v_predictions)
    # Add the validation ROEM to the ROEM list
    ROEMS.append(v_ROEM)
    print("Validation ROEM: ", v_ROEM)
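
The five folds in the snippet are written out by hand. A minimal sketch of building the same (training, holdout) pairs in a loop is shown here. It assumes only that each part has a `.union` method, which holds for Spark DataFrames and also for Python sets (used in the toy example so it runs without Spark). `make_folds` is a hypothetical helper name and is not part of the original gist.

```python
from functools import reduce

def make_folds(parts):
    """Return [(union of all other parts, part), ...], one pair per part.

    parts must be a list whose elements expose .union(other).
    """
    folds = []
    for i, holdout in enumerate(parts):
        # Every part except the one held out for this fold
        others = parts[:i] + parts[i + 1:]
        train = reduce(lambda a, b: a.union(b), others)
        folds.append((train, holdout))
    return folds

# Toy usage: Python sets stand in for the five Spark splits
parts = [{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}]
toy_foldlist = make_folds(parts)
```

With Spark, the equivalent call would be `make_folds([train1, train2, train3, train4, train5])`. The parts are unioned in a different order than in the hand-written version, but each fold contains the same rows.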
Now that we have several ALS models, each with a different set of hyperparameter values, we can train them on the training portion of the msd dataset using cross-validation, then run them on the held-out test set and evaluate how well each one performs using the ROEM function.
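
The `ROEM` function is not defined in this snippet. Assuming it is the usual rank ordering error metric for implicit feedback (each user's items are ranked by predicted score, each percent rank is weighted by the actual play count, and the total is divided by the sum of plays, so lower is better), a plain-Python sketch on a small list of rows might look like this. The name `roem_rows`, the `(user, plays, prediction)` tuple layout, and the tie handling are assumptions made for illustration; the gist's `ROEM` takes a Spark DataFrame of predictions instead.

```python
from collections import defaultdict

def roem_rows(rows):
    """rows: iterable of (user, plays, prediction) tuples. Lower is better."""
    by_user = defaultdict(list)
    for user, plays, prediction in rows:
        by_user[user].append((plays, prediction))

    numerator = 0.0
    denominator = 0.0
    for items in by_user.values():
        # Highest predicted score first; ties keep input order (an assumption)
        items.sort(key=lambda item: item[1], reverse=True)
        n = len(items)
        for position, (plays, _) in enumerate(items):
            # Percent rank in [0, 1]: 0 for the top prediction, 1 for the last
            pct_rank = position / (n - 1) if n > 1 else 0.0
            numerator += plays * pct_rank
            denominator += plays
    return numerator / denominator if denominator else float("nan")

# Toy data: in `good` the most-played song gets the highest prediction,
# in `bad` it gets the lowest
good = [("a", 10, 0.9), ("a", 1, 0.5), ("a", 0, 0.1)]
bad = [("a", 10, 0.1), ("a", 1, 0.5), ("a", 0, 0.9)]
```

On this toy data `roem_rows(good)` comes out smaller than `roem_rows(bad)`, which is the behaviour the evaluation loop relies on when comparing models.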