-
-
Save moradology/e26faef47eced56f0a01ae47e4aff558 to your computer and use it in GitHub Desktop.
A simple test of GDAL functionality based on a raster with two classes: 'urban' and 'other'.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.example.geotrellis

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import geotrellis.raster._
import geotrellis.raster.gdal.GDALRasterSource
import geotrellis.raster.summary._
import geotrellis.raster.summary.polygonal._
import geotrellis.vector._

// spark-submit --class com.example.geotrellis.ExampleSparkGDAL example-spark-gdal.jar s3://your-bucket/your-raster.tif
/**
 * A simple smoke test of GDAL functionality: reads a two-class ('urban' / 'other')
 * raster via GDALRasterSource, splits its footprint into a 5x5 grid of sub-extents,
 * and counts the pixels of each class across the grid in parallel on Spark executors.
 */
object ExampleSparkGDAL {

  /**
   * Count 'urban' (class value 2) and 'other' (class value 1) pixels inside `extent`.
   *
   * Takes the source URL rather than a RasterSource so that executors build their own
   * GDAL-backed source locally: GDALRasterSource wraps native GDAL state and should
   * not be serialized across the wire in a Spark closure.
   *
   * @return (urbanCount, otherCount); (0, 0) when the extent reads empty.
   */
  private def countClasses(sourceUrl: String, extent: Extent): (Long, Long) =
    GDALRasterSource(sourceUrl).read(extent) match {
      case Some(raster) =>
        val band = raster.tile.band(0)
        val histogram: Histogram[Int] = band.histogram
        val urbanCount = histogram.itemCount(2) // Urban/built up
        val otherCount = histogram.itemCount(1) // Other
        (urbanCount, otherCount)
      case None =>
        // Extent fell outside the raster (or read failed benignly): contribute nothing.
        (0L, 0L)
    }

  /**
   * Entry point. Expects the raster URL as the first argument,
   * e.g. s3://geotrellis-test-non-public/lebanon-inferred-lulc-2013.tif
   *
   * NOTE: previously `extends App`; a plain `main` avoids the App trait's
   * delayed-initialization pitfalls and is the recommended entry-point form.
   */
  def main(args: Array[String]): Unit = {
    require(args.nonEmpty, "Expected the raster URL as the first argument")

    // Initialize SparkSession for EMR Serverless
    val conf = new SparkConf()
      .setAppName("GT GDAL Spark Example")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", "geotrellis.spark.store.kryo.KryoRegistrator")
      .set("spark.debug.maxToStringFields", "255")
    val spark = SparkSession.builder.config(conf).getOrCreate()

    val url = args(0)
    println(s"THE PROVIDED URL IS... ${url}")

    // Driver-side source, used only to discover the raster's full footprint.
    val rasterSource = GDALRasterSource(url)
    println(rasterSource)
    val totalExtent = rasterSource.extent
    println(totalExtent)

    // Split the footprint into a 5x5 grid of equal sub-extents.
    val gridDim = 5
    val subExtents = {
      val colSize = totalExtent.width / gridDim
      val rowSize = totalExtent.height / gridDim
      for {
        col <- 0 until gridDim
        row <- 0 until gridDim
      } yield {
        println(s"COL ${col}, ROW ${row}")
        Extent(
          xmin = totalExtent.xmin + col * colSize,
          ymin = totalExtent.ymin + row * rowSize,
          xmax = totalExtent.xmin + (col + 1) * colSize,
          ymax = totalExtent.ymin + (row + 1) * rowSize
        )
      }
    }

    println("Beginning distributed portion of task on executors")
    println(s"SubExtents: ${subExtents}")

    def getCombinedCounts(): (Long, Long) = {
      // Capture only the URL (a plain String) in the closure; each executor task
      // re-opens its own GDALRasterSource inside countClasses.
      val sourceUrl = url
      val extentsRDD = spark.sparkContext.parallelize(subExtents)
      extentsRDD
        .map(extent => countClasses(sourceUrl, extent))
        .reduce { (a, b) => (a._1 + b._1, a._2 + b._2) }
    }

    val combinedCounts = getCombinedCounts()
    println(s"Combined Urban area count: ${combinedCounts._1}")
    println(s"Combined Other area count: ${combinedCounts._2}")

    // Stop the Spark session
    spark.stop()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment