@tosen1990
Created January 26, 2021 05:22
Use the new tlm (TileLayerMetadata) to generate a GeoTIFF
package com.dcits.spark.classify

import geotrellis.layer.{Bounds, FloatingLayoutScheme, LayoutDefinition, SpatialKey, TileLayerMetadata}
import geotrellis.proj4.CRS
import geotrellis.raster._
import geotrellis.raster.io.geotiff.MultibandGeoTiff
import geotrellis.raster.resample.NearestNeighbor
import geotrellis.spark._
import geotrellis.spark.tiling.Tiler
import geotrellis.store.hadoop.SerializableConfiguration
import geotrellis.vector.{Extent, ProjectedExtent}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.{Pipeline, linalg}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.{array, col, lit, udf}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.datasource.raster._
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}
import org.locationtech.rasterframes.util._

object ClassificationTest2 {

  def main(args: Array[String]): Unit = {
    // Quiet Spark's and Akka's internal logging
    Logger.getLogger("org").setLevel(Level.WARN)
    Logger.getLogger("akka").setLevel(Level.WARN)

    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName(getClass.getName)
      .withKryoSerialization
      .getOrCreate()
      .withRasterFrames

    import spark.implicits._

    /**
     * Example args:
     * /Users/ethan/work/data/L2a10m4326/xiangfu202009/202009_crop.tif /Users/ethan/work/data/L2a10m4326/xiangfu202009/xiangfu_9_classified.tif /Users/ethan/work/data/L2a10m4326/zds/roi/xiangfuqu_2020_9.geojson 1 100 256 100
     */
    val (inputPath: String, outputPath: String, geojsonPath: String, bandIndex: String, treenum: Int, tilesize: Int, partitionnum: Int) =
      (args(0), args(1), args(2), args(3), args(4).toInt, args(5).toInt, args(6).toInt)
    val imageName: Array[String] = inputPath.split(",")
      .map(_.split("\\/").last)
      .map(_.replace(".tif", ""))
    // Band indexes arrive 1-based on the command line; the reader wants 0-based
    val bandIndexInt: Array[Int] = bandIndex.split(",").map(_.toInt - 1)

    // Two-line CSV catalog: a header row of image names, then the matching paths
    val cat =
      s"""
         |${imageName.mkString(",")}
         |${inputPath}
      """.stripMargin.trim
    val df_init: DataFrame = spark.read.raster
      .withTileDimensions(tilesize, tilesize)
      .withBandIndexes(bandIndexInt: _*)
      .fromCSV(cat)
      .load()
      .repartition(partitionnum)
    df_init.printSchema()
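
    // Tag each row with its CRS and extent so tiles can later be joined
    // against the training polygons and reassembled into a GeoTrellis layer.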
    val tileCols: Seq[Column] = df_init.tileColumns
    val tileColName = tileCols.head.columnName
    val tileColsName: Array[String] = tileCols.map(_.columnName).toArray
    val df = df_init
      .withColumn("crs", rf_crs(col(tileColName)))
      .withColumn("extent", rf_extent(col(tileColName)))
    val crs: CRS = df.select(rf_crs(col("crs"))).distinct().first()
    val targetCol = "label"
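
    // Collect GeoTrellis TileLayerMetadata (the tlm of the gist title) over the
    // projected extents. FloatingLayoutScheme keys the tiles on a grid of
    // tilesize x tilesize cells without resampling them.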
    val df_tmp = df.select(col("extent"), col(tileColName + ".tile")).as[(Extent, Tile)]
    val (_, tlm) = df_tmp
      .map { case (ext, tile) => (ProjectedExtent(ext, crs), tile) }
      .rdd
      .collectMetadata[SpatialKey](FloatingLayoutScheme(tilesize, tilesize))
    val gb = tlm.layout.gridBoundsFor(tlm.extent)
    val totalDim = Dimensions(gb.width, gb.height)
    val totalCols = totalDim.cols.toInt
    val totalRows = totalDim.rows.toInt
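
    // Load the training polygons. Each GeoJSON feature is expected to carry a
    // CLASS_ID property; the broadcast hint keeps the (small) label table on
    // every executor for the spatial join below.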
    import org.locationtech.rasterframes.datasource.geojson._

    val jsonDF: DataFrame = spark.read.geojson.load(geojsonPath)
    val class_num = jsonDF.select("CLASS_ID").distinct().count.toInt
    val label_df: DataFrame = jsonDF
      .select($"CLASS_ID", st_reproject($"geometry", crs, crs).alias("geometry"))
      .hint("broadcast")
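
    // Spatial join: keep every tile whose footprint intersects a training
    // polygon, then burn the polygon's CLASS_ID into a "label" tile of the
    // same dimensions, giving each pixel a class value.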
    val df_joined = df.join(label_df, st_intersects(st_geometry($"extent"), $"geometry"))
      .withColumn("dims", rf_dimensions(col(tileColName)))
    val df_labeled: DataFrame = df_joined.withColumn(
      "label",
      rf_rasterize($"geometry", st_geometry($"extent"), $"CLASS_ID", $"dims.cols", $"dims.rows")
    )
    df_labeled.printSchema()

    // Keep only tiles that actually contain labeled pixels
    val tmp = df_labeled.filter(rf_tile_sum($"label") > 0).cache()
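
    // Pipeline stages: TileExploder turns each tile row into one row per pixel,
    // NoDataFilter drops pixels where any band (or the label) is NoData,
    // VectorAssembler packs the band values into a feature vector, and the
    // random forest is trained on those per-pixel vectors.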
    val exploder = new TileExploder()
    val noDataFilter = new NoDataFilter().setInputCols(tileColsName :+ targetCol)
    val assembler = new VectorAssembler()
      .setInputCols(tileColsName)
      .setOutputCol("features")
    val classifier = new RandomForestClassifier()
      .setLabelCol(targetCol)
      .setFeaturesCol(assembler.getOutputCol)
      .setNumTrees(treenum)

    // NOTE: paramGrid and cv are built but never wired into the pipeline below,
    // so the forest is fit with its default maxDepth.
    val paramGrid = new ParamGridBuilder()
      .addGrid(classifier.maxDepth, Array(5, 10, 15))
      .build()
    val cv: CrossValidator = new CrossValidator()
      .setEstimator(classifier)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(5)

    val pipeline = new Pipeline()
      .setStages(Array(exploder, noDataFilter, assembler, classifier))
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol(targetCol)
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
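
    // Fit on the labeled tiles, then report feature importances, the learned
    // forest, and accuracy. Note the evaluation below runs on the same data the
    // model was trained on, so the reported accuracy is optimistic.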
    val model = pipeline.fit(tmp)

    // stages(3) is the RandomForestClassifier fit above
    val rfModel = model.stages(3).asInstanceOf[RandomForestClassificationModel]
    val importance: linalg.Vector = rfModel.featureImportances
    println("featureImportances: ")
    importance.toArray.foreach(println)
    println("Learned classification forest model:\n" + rfModel.toDebugString)

    val prediction_df = model.transform(tmp)
    val accuracy = evaluator.evaluate(prediction_df)
    println("accuracy = " + accuracy)

    // Confusion matrix: predictions down the rows, true labels across the columns
    val cnf_mtrx: Dataset[Row] = prediction_df.groupBy("prediction")
      .pivot("label")
      .count()
      .sort("prediction")
    cnf_mtrx.show(10, false)
    prediction_df.groupBy($"prediction" as "class").count().show()
    prediction_df.show(10)
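
    // Score the full scene, not just the labeled tiles. After the exploder,
    // rows are pixels carrying column_index/row_index, so rf_assemble_tile can
    // rebuild one tile per class from the per-pixel probability vector.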
    val scored = model.transform(df)

    // Pull a single class probability out of the prediction vector, scaled to 0-100
    val vectorToColumn = udf { (x: DenseVector, index: Int) => x(index) * 100 }

    // Classes 1 .. class_num-1 become bands bd_2 .. bd_<class_num>
    val col_tmp: Seq[Column] = (1 until class_num).map { n =>
      rf_assemble_tile(
        $"column_index", $"row_index", vectorToColumn($"probability", lit(n)),
        tlm.tileCols, tlm.tileRows, ByteConstantNoDataCellType
      ).alias("bd_" + (n + 1))
    }
    val col_total: Seq[Column] = (1 to class_num).map(n => col("bd_" + n))

    // Re-tile the exploded pixels, one band of class probabilities per column
    val retiled: DataFrame = scored.groupBy($"crs", $"extent").agg(
      rf_assemble_tile(
        $"column_index", $"row_index", vectorToColumn($"probability", lit(0)),
        tlm.tileCols, tlm.tileRows, ByteConstantNoDataCellType
      ).alias("bd_1"),
      col_tmp: _*
    )
    retiled.printSchema()
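
    // Converting the re-tiled DataFrame into a RasterFrameLayer re-attaches the
    // layer metadata (layout, CRS, bounds) collected earlier.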
    val rf: RasterFrameLayer = retiled.toLayer(tlm)
    rf.printSchema()

    // NOTE: noDataTile is never used below
    val noDataTile: Tile = ArrayTile.alloc(ByteConstantNoDataCellType, 256, 256)
      .fill(byteNODATA)
      .interpretAs(ByteConstantNoDataCellType)

    // Don't use toMultibandRaster here -- it throws a NullPointerException.
    // Assemble the MultibandTile by hand from the per-class band columns instead.
    val rr: RDD[(SpatialKey, MultibandTile)] = rf
      .select(rf.spatialKeyColumn, array(col_total: _*)).as[(SpatialKey, Array[Tile])]
      .rdd
      .filter {
        case (_: SpatialKey, b) if b(0) == null => false // remove any null tiles
        case _ => true
      }
      .map { case (sk, tiles) =>
        (sk, MultibandTile(tiles))
      }
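
    // Stitching trick: define a layout whose single tile spans the whole extent,
    // re-tile the keyed tiles into it, and stitch. The result is one
    // MultibandTile covering the full scene, cropped back to the original grid.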
    val newLayout = LayoutDefinition(tlm.extent, TileLayout(1, 1, totalCols, totalRows))
    val cellType = rr.first()._2.cellType
    val newLayerMetadata: TileLayerMetadata[SpatialKey] =
      tlm.copy(layout = newLayout, bounds = Bounds(SpatialKey(0, 0), SpatialKey(0, 0)), cellType = cellType)

    val trans = tlm.mapTransform
    val newLayer = rr
      .map { case (key, tile) =>
        (ProjectedExtent(trans(key), tlm.crs), tile)
      }
      .tileToLayout(newLayerMetadata, Tiler.Options(NearestNeighbor))

    val stitchedTile: Raster[MultibandTile] = newLayer.stitch()
    val croppedTile = stitchedTile.crop(totalCols, totalRows)
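
    // Write the stitched raster out as a multiband GeoTIFF. fs.defaultFS is
    // pointed at the cluster so the output Path resolves to HDFS, not local disk.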
    import geotrellis.spark.store.hadoop._
    import org.apache.hadoop.fs._

    spark.sparkContext.hadoopConfiguration.set("fs.defaultFS", "hdfs://10.0.25.227:8020")
    val hconf = SerializableConfiguration(spark.sparkContext.hadoopConfiguration)
    MultibandGeoTiff(croppedTile.tile, newLayerMetadata.extent, newLayerMetadata.crs)
      .write(new Path(outputPath), hconf.value)
    // MultibandGeoTiff(croppedTile.tile, tlm.extent, tlm.crs).write(outputPath)

    spark.stop()
  }
}