Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yogeshpv/a1047449045655331a5ae2e2814074aa to your computer and use it in GitHub Desktop.
Save yogeshpv/a1047449045655331a5ae2e2814074aa to your computer and use it in GitHub Desktop.
Spark MLlib Twitter Quickstart
// Load tweets.
import scala.util.control.Breaks._
import scala.collection.JavaConversions._
import twitter4j.{Twitter,Query,TwitterFactory}
// Shared twitter4j client and the single search query object that the
// helper functions further down mutate and re-run.
val twitter = TwitterFactory.getSingleton()
val query = new Query("lang:en")
// Date window passed to the search.
query.setSince("2016-01-13")
query.setUntil("2016-01-24")
// Count setting of 100 for each search call.
query.setCount(100)
// Returns the largest status id found in `tweets`.
// The Java list is treated as a Scala collection through the implicit
// JavaConversions import; `max` fails on an empty list, so callers must
// pass at least one tweet.
def getMaxId(tweets: java.util.List[twitter4j.Status]) = {
  val ids = tweets.map(status => status.getId)
  ids.max
}
// Returns the smallest status id found in `tweets`.
// The Java list is treated as a Scala collection through the implicit
// JavaConversions import; `min` fails on an empty list, so callers must
// pass at least one tweet.
def getMinId(tweets: java.util.List[twitter4j.Status]) = {
  val ids = tweets.map(status => status.getId)
  ids.min
}
// Sets the shared query's since-id to `lastId`, runs the search and
// returns the tweets of that result page.
// Note: the since-id stays set on the shared `query` after this call.
def getTweetsAfterId(lastId: Long) = {
  query.setSinceId(lastId)
  val result = twitter.search(query)
  result.getTweets
}
// Sets the shared query's max-id to `firstId - 1`, runs the search and
// returns the tweets of that result page.
// Note: the max-id stays set on the shared `query` after this call.
def getTweetsBeforeId(firstId: Long) = {
  val upperBound = firstId - 1
  query.setMaxId(upperBound)
  val result = twitter.search(query)
  result.getTweets
}
// Collect up to tweetMaxCount tweets by paging backwards through the search
// results: each request asks for tweets with ids below the smallest id
// gathered so far.
// maxId is not used by any code visible in this script.
var maxId = 0L
// Starting sentinel: getTweetsBeforeId subtracts 1, and Long.MinValue - 1
// wraps around to Long.MaxValue, so the first request carries the largest
// possible max-id.
var minId = Long.MinValue
val tweetMaxCount = 10000
val tweetList = new java.util.ArrayList[twitter4j.Status]
// Loop flag used instead of scala.util.control.Breaks.break: break is only
// valid inside a breakable { ... } block, and without one its control
// exception escapes the loop uncaught.
var keepFetching = true
while (keepFetching && tweetList.size < tweetMaxCount) {
  try {
    val tweets = getTweetsBeforeId(minId)
    println("tweetList.size=" + tweetList.size)
    tweetList.addAll(tweets)
    if (tweets.isEmpty) {
      // An empty page means there is nothing older left to fetch.
      keepFetching = false
    } else {
      minId = getMinId(tweets)
    }
  } catch {
    // Stop on the first failed request, keeping whatever was collected.
    // NonFatal lets JVM-fatal errors and control-flow throwables propagate
    // instead of being swallowed.
    case scala.util.control.NonFatal(e) => {
      println(e.getMessage)
      keepFetching = false
    }
  }
}
// Build model.
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{RandomForestRegressionModel,
RandomForestRegressor}
// Computes ln(1 + x), so an input of 0 maps to 0.0 rather than to negative
// infinity as a plain logarithm would.
// log1p is used instead of log(x + 1) because it stays accurate when x is
// close to zero, where forming x + 1 first rounds the small part away.
// Inputs below -1 give NaN and an input of exactly -1 gives negative infinity.
def safeLog(x: Double) = {
  java.lang.Math.log1p(x)
}
// Build the training DataFrame from the collected tweets.
// Only tweets with more than 10 retweets are kept. The label is
// safeLog(retweet count); the features, in order, are safeLog of the
// author's follower count, the number of media entities, the number of user
// mentions, the number of hashtags and the text length.
val data = {
  val popularTweets = sc.parallelize(tweetList).filter(tweet => tweet.getRetweetCount > 10)
  val labeledTweets = popularTweets.map { tweet =>
    val label = safeLog(tweet.getRetweetCount)
    val features = Vectors.dense(
      safeLog(tweet.getUser.getFollowersCount),
      safeLog(tweet.getMediaEntities.size),
      safeLog(tweet.getUserMentionEntities.size),
      safeLog(tweet.getHashtagEntities.size),
      safeLog(tweet.getText.length))
    LabeledPoint(label, features)
  }
  labeledTweets.toDF
}
// Fit a VectorIndexer on the full data set; it reads the "features" column
// and writes "indexedFeatures", with the max-categories setting at 4.
val featureIndexer = {
  val indexer = new VectorIndexer()
  indexer.setInputCol("features")
  indexer.setOutputCol("indexedFeatures")
  indexer.setMaxCategories(4)
  indexer.fit(data)
}
// Randomly split the rows with weights 0.7 (training) and 0.3 (test).
val (trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) match {
  case Array(trainPart, testPart) => (trainPart, testPart)
}
// Configure the random-forest regressor: labels come from "label", inputs
// from the indexer's "indexedFeatures" output column.
val rf = {
  val forest = new RandomForestRegressor()
  forest.setLabelCol("label")
  forest.setFeaturesCol("indexedFeatures")
  forest
}
// Pipeline with two stages: the fitted feature indexer, then the forest.
val pipeline = new Pipeline()
pipeline.setStages(Array(featureIndexer, rf))
// Fit the whole pipeline on the training split.
val model = pipeline.fit(trainingData)
// Apply the fitted pipeline to the held-out test split.
val predictions = model.transform(testData)
// Print five example rows of predictions next to labels and raw features.
predictions.select("prediction", "label", "features").show(5)
// Evaluator that compares "prediction" against "label" using the "rmse" metric.
val evaluator = {
  val rmseEvaluator = new RegressionEvaluator()
  rmseEvaluator.setLabelCol("label")
  rmseEvaluator.setPredictionCol("prediction")
  rmseEvaluator.setMetricName("rmse")
  rmseEvaluator
}
// Evaluate the test-set predictions and report the result.
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")
// Take the second stage of the fitted pipeline (the forest, per the stage
// order given to the Pipeline) and inspect it. The two bare expressions are
// echoed when the script is run in the interactive shell.
val rfModel = {
  val forestStage = model.stages(1)
  forestStage.asInstanceOf[RandomForestRegressionModel]
}
rfModel.numFeatures
rfModel.featureImportances
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment