@j14159
Created September 17, 2014 18:57
Updated S3N RDD
/*
 * A more recent version of my S3N RDD. This exists because I needed
 * a reliable way to distribute the fetching of S3 data using instance
 * credentials as well as a simple way to filter out the inputs that
 * I didn't want in the RDD.
 *
 * This version is more eager than the last one and also provides a
 * simple RDD that allows you to tag each line with information about
 * its partition/source.
 *
 * Use at your own risk.
 */
import com.amazonaws.auth.InstanceProfileCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import com.amazonaws.services.s3.model.{ GetObjectRequest, ObjectListing }
import java.io.{ BufferedReader, InputStreamReader }
import org.apache.spark.{ Partition, SparkContext, TaskContext }
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
import scala.collection.JavaConverters._
case class S3NPartition(idx: Int, bucket: String, path: String) extends Partition {
  def index = idx
}
/**
 * The base implementation of S3NRDD that <b>requires</b> you to be
 * using instance credentials because I dislike the idea of keys
 * and secrets floating in source repositories, etc.
 */
abstract class BaseS3NRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
  val bucket: String
  val files: Seq[String]
  /**
   * I've abstracted this portion because in some cases I needed to
   * tag the items in the RDD with their provenance.
   * See [[TaggingS3NRDD]].
   */
  def partitionIterator(p: S3NPartition, seq: Seq[String]): Iterator[T]
  protected def instanceCreds() = new InstanceProfileCredentialsProvider().getCredentials
  override def getPartitions: Array[Partition] =
    files.zipWithIndex.map { case (fn, i) => S3NPartition(i, bucket, fn) }.toArray
  override def compute(split: Partition, context: TaskContext): Iterator[T] = split match {
    case p @ S3NPartition(_, bucket, path) =>
      val client = new AmazonS3Client(instanceCreds())
      val obj = client.getObject(new GetObjectRequest(bucket, path))
      val br = new BufferedReader(new InputStreamReader(obj.getObjectContent()))
      // Read the whole object into memory up front so the S3 stream can be
      // closed before the iterator is handed back to Spark.
      val lines = eagerBufferReader(br, Nil)
      br.close()
      obj.close()
      partitionIterator(p, lines)
  }
  /**
   * Preserving the ordering of items in the underlying file is
   * likely not strictly necessary but I need to re-evaluate
   * against my current work.
   */
  @scala.annotation.tailrec
  private def eagerBufferReader(br: BufferedReader, memo: List[String]): Seq[String] =
    br.readLine match {
      case null => memo.reverse
      case l => eagerBufferReader(br, l :: memo)
    }
}
/**
 * Construct and use this directly; it is roughly equivalent to SparkContext.textFile,
 * except that you give it the list/sequence of files you want to load. It currently
 * makes one Partition per file and, once constructed, behaves like any other RDD.
 *
 * The example below constructs an RDD from all files starting with "some-files/file-"
 * in the S3 bucket "my-bucket":
 *
 * {{{
 * new S3NRDD(yourSparkContext, "my-bucket", new S3NListing("my-bucket").list("some-files/file-"))
 * }}}
 */
class S3NRDD(sc: SparkContext, val bucket: String, val files: Seq[String]) extends BaseS3NRDD[String](sc) {
  def partitionIterator(p: S3NPartition, seq: Seq[String]) = seq.toIterator
}
/**
 * Similar to S3NRDD but tags each line of a file with the output of the supplied
 * tagging function.
 */
class TaggingS3NRDD(sc: SparkContext, val bucket: String, val files: Seq[String], tagF: S3NPartition => String) extends BaseS3NRDD[(String, String)](sc) {
  def partitionIterator(p: S3NPartition, seq: Seq[String]) =
    seq.map(s => tagF(p) -> s).toIterator
}
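/*
 * Illustrative usage sketch: tag every line with the S3 key it came from.
 * The bucket name and key prefix below are placeholders, and a live
 * SparkContext is assumed.
 */
object TaggingS3NRDDExample {
  def linesWithSource(sc: SparkContext): RDD[(String, String)] = {
    val files = new S3NListing("my-bucket").list("logs/2014-09-")
    // The tagging function receives the partition, so it can use the bucket,
    // the key, or both; here each line is paired with the key it came from.
    new TaggingS3NRDD(sc, "my-bucket", files, p => p.path)
  }
}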
/**
 * Simple helper to find files within the given bucket.
 */
class S3NListing(bucket: String) {
  private def instanceCreds() = new InstanceProfileCredentialsProvider().getCredentials
  lazy val client = new AmazonS3Client(instanceCreds())
  /**
   * List files behind a given prefix, e.g. "" for all, "my-folder",
   * "my-folder/files-that-start-like-this", etc. Will eagerly fetch
   * all truncated results.
   */
  def list(folder: String) = recursiveListing(folder, None, Nil)
  @scala.annotation.tailrec
  private def recursiveListing(folder: String, prev: Option[ObjectListing], memo: List[Seq[String]]): List[String] = prev match {
    case None =>
      val listing = client.listObjects(bucket, folder)
      val keys = listing.getObjectSummaries.asScala.map(_.getKey)
      if (listing.isTruncated)
        recursiveListing(folder, Some(listing), keys :: memo)
      else
        keys.toList
    case Some(lastListing) =>
      val listing = client.listNextBatchOfObjects(lastListing)
      val keys = listing.getObjectSummaries.asScala.map(_.getKey)
      if (listing.isTruncated)
        recursiveListing(folder, Some(listing), keys :: memo)
      else
        // memo holds earlier batches newest-first, so reverse before flattening
        // to keep the keys in the order S3 listed them.
        (keys :: memo).reverse.flatten
  }
}
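/*
 * Illustrative end-to-end sketch: list the keys under a prefix with S3NListing
 * and load them as a plain RDD[String], much like SparkContext.textFile. The
 * names here are placeholders; instance-profile credentials must be available
 * both where the listing runs and on the executors that fetch the objects.
 */
object S3NRDDExample {
  def linesUnderPrefix(sc: SparkContext, bucket: String, prefix: String): RDD[String] =
    new S3NRDD(sc, bucket, new S3NListing(bucket).list(prefix))
}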