Created
November 11, 2018 03:37
-
-
Save jamesonthecrow/d14f5ee946f6981a3d60170cff18ecbd to your computer and use it in GitHub Desktop.
Train a text classification model with CreateML to suggest subreddits based on a proposed title.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import CreateML
import Foundation

// Trains a CreateML text classifier that suggests a subreddit for a proposed
// post title, reports training and held-out accuracy, tries one example
// prediction, and writes the trained model to disk.
// Replace the "PATH/TO/..." placeholders with real paths before running.

// Load our data into an MLDataTable object.
let dataFilename = "PATH/TO/data.json"
let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataFilename))
print(data.description)
/*
 Columns:
     label   string
     text    string
 Rows: 26985
 Data:
 +----------------+----------------+
 | label          | text           |
 +----------------+----------------+
 | MachineLearn...| Realtime mu...|
 | MachineLearn...| Keras Imple...|
 | MachineLearn...| Generative ...|
 | MachineLearn...| Landing the...|
 | MachineLearn...| If you had ...|
 | MachineLearn...| Realtime Ma...|
 | MachineLearn...| Dedicated t...|
 | MachineLearn...| StarGAN Uni...|
 | MachineLearn...| Deep Image ...|
 | MachineLearn...| Overview of...|
 +----------------+----------------+
 [26985 rows x 2 columns]
 */

// Split the dataset into two parts, training and testing.
// We make sure to hold out some data for testing so we can
// identify overfitting. The fixed seed keeps the split reproducible.
let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 5)

// Train the model itself, reading titles from the "text" column and the
// target subreddit from the "label" column.
let subredditClassifier = try MLTextClassifier(trainingData: trainingData,
                                               textColumn: "text",
                                               labelColumn: "label")

// Training accuracy as a percentage (classificationError is a fraction,
// so accuracy is its complement scaled by 100).
let trainingAccuracy = (1.0 - subredditClassifier.trainingMetrics.classificationError) * 100
print("Training Accuracy: \(trainingAccuracy)")
/*
 Training Accuracy: 99.37782530501143
 This is really high. On its own that is only a hint of overfitting;
 compare it with the accuracy on the held-out testing data below.
 */

// Evaluate the model on the testing data we kept secret from the model.
let evaluationMetrics = subredditClassifier.evaluation(on: testingData)

// Evaluation accuracy as a percentage
let evaluationAccuracy = (1.0 - evaluationMetrics.classificationError) * 100
print("Evaluation Accuracy: \(evaluationAccuracy)")
/*
 Evaluation Accuracy: 63.9894419306184
 Much lower than our training accuracy, but not bad considering there are 28
 potential subreddits to choose from.
 */

// Test the model on a single example
let title = "Saw this good boy at the park today with TensorFlow."
let predictedSubreddit = try subredditClassifier.prediction(from: title)
print("Suggested subreddit: r/\(predictedSubreddit)")
/*
 Suggested subreddit: r/aww
 */

// Add some metadata
let metadata = MLModelMetadata(author: "Jameson Toole",
                               shortDescription: "Predict which subreddit a post should go in based on a title.",
                               version: "1.0")

// Save the model
try subredditClassifier.write(to: URL(fileURLWithPath: "PATH/TO/subredditClassifier.mlmodel"),
                              metadata: metadata)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment