@moradology
Created January 31, 2024 20:33
A simple test of GDAL functionality based on a raster with two classes: 'urban' and 'other'
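For reference, a minimal build.sbt sketch of the dependencies this example assumes (artifact versions are illustrative, not pinned by the gist):

// build.sbt (sketch; adjust versions to your Spark/GeoTrellis environment)
libraryDependencies ++= Seq(
  "org.apache.spark"            %% "spark-sql"         % "3.3.0" % "provided",
  "org.locationtech.geotrellis" %% "geotrellis-raster" % "3.7.0",
  "org.locationtech.geotrellis" %% "geotrellis-gdal"   % "3.7.0",
  "org.locationtech.geotrellis" %% "geotrellis-spark"  % "3.7.0"
)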
package com.example.geotrellis

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import geotrellis.raster._
import geotrellis.raster.gdal.GDALRasterSource
import geotrellis.raster.summary._
import geotrellis.raster.summary.polygonal._
import geotrellis.vector._

// Run with:
// spark-submit --class com.example.geotrellis.ExampleSparkGDAL example-spark-gdal.jar s3://your-bucket/your-raster.tif
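// On EMR Serverless, the equivalent job submission might look like the
// following (sketch; application id, role ARN, and S3 paths are placeholders):
//   aws emr-serverless start-job-run \
//     --application-id <application-id> \
//     --execution-role-arn <job-role-arn> \
//     --job-driver '{"sparkSubmit": {
//       "entryPoint": "s3://your-bucket/example-spark-gdal.jar",
//       "entryPointArguments": ["s3://your-bucket/your-raster.tif"],
//       "sparkSubmitParameters": "--class com.example.geotrellis.ExampleSparkGDAL"
//     }}'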
object ExampleSparkGDAL extends App {

  // Configure Spark for EMR Serverless: Kryo serialization with the
  // GeoTrellis registrator so raster types serialize efficiently
  val conf = new SparkConf()
    .setAppName("GT GDAL Spark Example")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .set("spark.kryo.registrator", "geotrellis.spark.store.kryo.KryoRegistrator")
    .set("spark.debug.maxToStringFields", "255")
  val spark = SparkSession.builder.config(conf).getOrCreate()
  // The raster to summarize is passed as the first CLI argument,
  // e.g. s3://geotrellis-test-non-public/lebanon-inferred-lulc-2013.tif
  val url = args(0)
  println(s"THE PROVIDED URL IS... ${url}")

  // Open the raster through GDAL; this is lazy and reads no pixel data yet
  val rasterSource = GDALRasterSource(url)
  println(rasterSource)

  val totalExtent = rasterSource.extent
  println(totalExtent)
  // Cut the full extent into a 5x5 grid (25 sub-extents) so each window
  // can be read and summarized independently
  val subExtents = {
    val colSize = totalExtent.width / 5
    val rowSize = totalExtent.height / 5
    for {
      col <- 0 until 5
      row <- 0 until 5
    } yield {
      println(s"COL ${col}, ROW ${row}")
      Extent(
        xmin = totalExtent.xmin + col * colSize,
        ymin = totalExtent.ymin + row * rowSize,
        xmax = totalExtent.xmin + (col + 1) * colSize,
        ymax = totalExtent.ymin + (row + 1) * rowSize
      )
    }
  }
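  // Note: these windows are cut in map coordinates, so their edges will not
  // necessarily align with pixel boundaries. For pixel-exact tiling one could
  // derive window sizes from the source's pixel grid instead (sketch; cols
  // and rows are the RasterSource's pixel dimensions):
  //   val windowCols = rasterSource.cols / 5
  //   val windowRows = rasterSource.rows / 5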
println("Beginning distributed portion of task on executors")
println(f"SubExtents: ${subExtents}")
def countClasses(extent: Extent): (Long, Long) = {
rasterSource.read(extent) match {
case Some(raster) =>
val sbtile = raster.tile.band(0)
val histogram: Histogram[Int] = sbtile.histogram
val urbanCount = histogram.itemCount(2) // Urban/built up
val otherCount = histogram.itemCount(1) // Other
(urbanCount, otherCount)
case None =>
(0L, 0L)
}
}
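  // A quick sanity check on the driver might look like this (sketch; only
  // reasonable when the raster is small enough to read in one process):
  //   val (urban, other) = countClasses(totalExtent)
  //   println(s"urban=${urban} other=${other}")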
  def getCombinedCounts(): (Long, Long) = {
    // Capture the raster source in a local val so the closure serializes
    // only the GDALRasterSource itself, not the enclosing object
    val rs = rasterSource

    // Distribute the sub-extents across the executors
    val extentsRDD = spark.sparkContext.parallelize(subExtents)
    val countRDD = extentsRDD.map { extent =>
      rs.read(extent) match {
        case Some(raster) =>
          val sbtile = raster.tile.band(0)
          val histogram: Histogram[Int] = sbtile.histogram
          val urbanCount = histogram.itemCount(2) // Urban/built up
          val otherCount = histogram.itemCount(1) // Other
          (urbanCount, otherCount)
        case None =>
          (0L, 0L)
      }
    }

    // Sum the per-window (urban, other) pairs into a single total
    countRDD.reduce { (a, b) =>
      (a._1 + b._1, a._2 + b._2)
    }
  }
  val combinedCounts = getCombinedCounts()
  println(s"Combined Urban area count: ${combinedCounts._1}")
  println(s"Combined Other area count: ${combinedCounts._2}")

  // Stop the Spark session
  spark.stop()
}
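Note: GDALRasterSource binds to GDAL's native libraries (through gdal-warp-bindings), so a compatible GDAL install must be present on the driver and on every executor; on EMR Serverless that typically means a custom image or bundled native dependencies.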