@jamesonthecrow
Created November 11, 2018 03:37
Train a text classification model with CreateML to suggest subreddits based on a proposed title.
import CreateML
import Foundation
// Load our data into an MLDataTable object.
let dataFilename = "PATH/TO/data.json"
let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataFilename))
print(data.description)
/*
Columns:
label string
text string
Rows: 26985
Data:
+----------------+----------------+
| label          | text           |
+----------------+----------------+
| MachineLearn...| Realtime mu...|
| MachineLearn...| Keras Imple...|
| MachineLearn...| Generative ...|
| MachineLearn...| Landing the...|
| MachineLearn...| If you had ...|
| MachineLearn...| Realtime Ma...|
| MachineLearn...| Dedicated t...|
| MachineLearn...| StarGAN Uni...|
| MachineLearn...| Deep Image ...|
| MachineLearn...| Overview of...|
+----------------+----------------+
[26985 rows x 2 columns]
*/
// Split the dataset into a training set (about 80% of the rows) and a testing
// set (the rest). Holding out data the model never sees during training lets
// us detect overfitting; the fixed seed makes the split reproducible.
let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 5)
// Train the model itself.
let subredditClassifier = try MLTextClassifier(trainingData: trainingData,
                                               textColumn: "text",
                                               labelColumn: "label")
// Training accuracy as a percentage
let trainingAccuracy = (1.0 - subredditClassifier.trainingMetrics.classificationError) * 100
print("Training Accuracy: \(trainingAccuracy)")
/*
Training Accuracy: 99.37782530501143
Accuracy this close to 100% suggests the model may be overfitting.
*/
// Evaluate the model on the testing data we held out from training.
let evaluationMetrics = subredditClassifier.evaluation(on: testingData)
// Evaluation accuracy as a percentage
let evaluationAccuracy = (1.0 - evaluationMetrics.classificationError) * 100
print("Evaluation Accuracy: \(evaluationAccuracy)")
/*
Evaluation Accuracy: 63.9894419306184
Much lower than the training accuracy, a gap that confirms overfitting, but
still far above the ~3.6% (100 / 28) expected from guessing uniformly among
the 28 candidate subreddits.
*/
// Test the model on a single example
let title = "Saw this good boy at the park today with TensorFlow."
let predictedSubreddit = try subredditClassifier.prediction(from: title)
print("Suggested subreddit: r/\(predictedSubreddit)")
/*
Suggested subreddit: r/aww
*/
// Attach metadata (author, description, version) to the exported model.
let metadata = MLModelMetadata(author: "Jameson Toole",
                               shortDescription: "Predict which subreddit a post should go in based on a title.",
                               version: "1.0")
// Save the model
try subredditClassifier.write(to: URL(fileURLWithPath: "PATH/TO/subredditClassifier.mlmodel"),
                              metadata: metadata)
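The script holds out data with `data.randomSplit(by: 0.8, seed: 5)`. Below is a rough, stand-alone Swift sketch of what a seeded proportional split does. It assumes each row goes to the training set independently with probability 0.8. It is not CreateML's implementation, and `SplitMix64` and `randomSplit(_:by:seed:)` are hypothetical names used only for illustration. Because every row is assigned independently, the training share lands close to 80% but not exactly on it, and reusing the same seed reproduces the same split.

```swift
// Illustrative sketch only; CreateML's randomSplit uses its own internals.

// SplitMix64: a small deterministic pseudo-random generator. The same seed
// always yields the same sequence, which is what makes a seeded split repeatable.
struct SplitMix64: RandomNumberGenerator {
    private var state: UInt64

    init(seed: UInt64) {
        state = seed
    }

    mutating func next() -> UInt64 {
        state &+= 0x9E37_79B9_7F4A_7C15
        var z = state
        z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
        z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
        return z ^ (z >> 31)
    }
}

// Hypothetical helper: each row lands in `training` with probability
// `fraction`, otherwise in `testing`.
func randomSplit<T>(_ rows: [T], by fraction: Double, seed: UInt64) -> (training: [T], testing: [T]) {
    var generator = SplitMix64(seed: seed)
    var training: [T] = []
    var testing: [T] = []
    for row in rows {
        if Double.random(in: 0..<1, using: &generator) < fraction {
            training.append(row)
        } else {
            testing.append(row)
        }
    }
    return (training, testing)
}

// Stand-in rows: just indices, one per row of the 26,985-row table.
let rows = Array(0..<26985)
let (train, test) = randomSplit(rows, by: 0.8, seed: 5)
print("training rows: \(train.count), testing rows: \(test.count)")
```

Drawing one random number per row keeps the logic simple and streaming-friendly; the trade-off is that the split sizes vary slightly from seed to seed instead of being exactly 80/20.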
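The script reports accuracy as `(1.0 - classificationError) * 100`. A minimal pure-Swift sketch of that arithmetic follows. It assumes `classificationError` is the fraction of misclassified examples (the usual definition). `classificationError(predicted:actual:)` and `accuracyPercent(fromError:)` are hypothetical helpers, and the four labels are made-up toy data, not results from the subreddit dataset. It also computes the uniform-guessing baseline for 28 subreddits, 100 / 28 ≈ 3.57%, the natural point of comparison for the ~64% evaluation accuracy.

```swift
// Hypothetical helper: fraction of examples whose predicted label differs
// from the true label (assumed to match CreateML's classificationError).
func classificationError(predicted: [String], actual: [String]) -> Double {
    precondition(predicted.count == actual.count && !actual.isEmpty,
                 "label arrays must be non-empty and the same length")
    let mistakes = zip(predicted, actual).filter { pair in pair.0 != pair.1 }.count
    return Double(mistakes) / Double(actual.count)
}

// Same conversion the script uses: accuracy (%) = (1 - error) * 100.
func accuracyPercent(fromError error: Double) -> Double {
    return (1.0 - error) * 100
}

// Toy labels, made up for illustration: one of four predictions is wrong.
let actual    = ["aww", "MachineLearning", "aww", "MachineLearning"]
let predicted = ["aww", "MachineLearning", "MachineLearning", "MachineLearning"]

let toyError = classificationError(predicted: predicted, actual: actual)
print("Toy accuracy: \(accuracyPercent(fromError: toyError))%")  // Toy accuracy: 75.0%

// Uniform-guessing baseline with 28 candidate subreddits.
let chanceAccuracy = 100.0 / 28.0
print("Chance accuracy: \(chanceAccuracy)%")
```

A majority-class baseline (always predicting the most common subreddit) would be a tougher comparison, but computing it requires the per-label counts, which the script does not print.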