-
-
Save tosen1990/2209f546a7a12e139bf14e5ac5e2c24b to your computer and use it in GitHub Desktop.
Use the new tlm (TileLayerMetadata) to generate the TIFF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.dcits.spark.classify | |
import geotrellis.layer.{Boundable, Bounds, FloatingLayoutScheme, KeyBounds, LayoutDefinition, Metadata, SpaceTimeKey, SpatialComponent, SpatialKey, TemporalKey, TileLayerMetadata} | |
import geotrellis.proj4.{CRS, LatLng} | |
import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoders, Row, SparkSession} | |
import org.locationtech.rasterframes.WithBKryoMethods | |
import org.locationtech.rasterframes._ | |
import org.locationtech.rasterframes.util._ | |
import geotrellis.raster.{MultibandTile, Raster, RasterSource} | |
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder} | |
import geotrellis.raster._ | |
import geotrellis.spark._ | |
import geotrellis.raster.io.geotiff.MultibandGeoTiff | |
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod} | |
import geotrellis.spark.tiling.Tiler | |
import geotrellis.spark.{ContextRDD, MultibandTileLayerRDD, withCollectMetadataMethods} | |
import geotrellis.store.hadoop.SerializableConfiguration | |
import geotrellis.vector.{Extent, ProjectedExtent} | |
import org.apache.spark.ml.{Pipeline, linalg} | |
import org.apache.spark.ml.classification.{DecisionTreeClassifier, RandomForestClassificationModel, RandomForestClassifier} | |
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator | |
import org.apache.spark.ml.feature.VectorAssembler | |
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.functions.{array, col, lit, udf} | |
import org.locationtech.rasterframes.datasource.raster.DataFrameReaderHasRasterSourceFormat | |
import org.apache.spark.ml.linalg.DenseVector | |
import org.apache.log4j.Logger | |
import org.apache.log4j.Level | |
/**
 * Supervised raster classification job built on RasterFrames and Spark ML.
 *
 * The job loads the input rasters as tiled RasterFrames, rasterizes the labelled
 * GeoJSON polygons (property `CLASS_ID`) onto the tiles, trains a random forest on
 * the labelled cells, prints training diagnostics, scores every tile, and writes
 * the per-class probabilities (multiplied by 100) as one band per class of a
 * multiband GeoTIFF.
 *
 * Arguments, in order:
 *  1. `inputPath`    - comma-separated raster paths
 *  1. `outputPath`   - path of the GeoTIFF to write
 *  1. `geojsonPath`  - GeoJSON file with the training polygons
 *  1. `bandIndex`    - comma-separated 1-based band indexes to read
 *  1. `treenum`      - number of trees in the random forest
 *  1. `tilesize`     - tile width and height in cells
 *  1. `partitionnum` - number of partitions for the loaded tiles
 *  1. `defaultFS`    - optional; value for Hadoop's `fs.defaultFS`
 *                      (defaults to `hdfs://10.0.25.227:8020`)
 *
 * Example arguments:
 * {{{
 * /Users/ethan/work/data/L2a10m4326/xiangfu202009/202009_crop.tif /Users/ethan/work/data/L2a10m4326/xiangfu202009/xiangfu_9_classified.tif /Users/ethan/work/data/L2a10m4326/zds/roi/xiangfuqu_2020_9.geojson 1 100 256 100
 * }}}
 */
object ClassificationTest2 {

  /** Value used for `fs.defaultFS` when no eighth argument is supplied. */
  private val DefaultFs = "hdfs://10.0.25.227:8020"

  /** Name of the rasterized label column fed to the classifier. */
  private val TargetCol = "label"

  private val Usage =
    "Usage: ClassificationTest2 <inputPath[,inputPath...]> <outputPath> <geojsonPath> " +
      "<bandIndex[,bandIndex...]> <treenum> <tilesize> <partitionnum> [defaultFS]"

  /** Parsed command-line arguments of the job. */
  private final case class JobArgs(
    inputPath: String,
    outputPath: String,
    geojsonPath: String,
    bandIndex: String,
    treeNum: Int,
    tileSize: Int,
    partitionNum: Int,
    defaultFs: String
  )

  /**
   * Validates the argument count and converts the numeric arguments.
   *
   * @throws IllegalArgumentException if fewer than seven arguments are given
   * @throws NumberFormatException    if a numeric argument is not an integer
   */
  private def parseArgs(args: Array[String]): JobArgs = {
    require(args.length >= 7, s"Expected at least 7 arguments but got ${args.length}. $Usage")
    JobArgs(
      inputPath = args(0),
      outputPath = args(1),
      geojsonPath = args(2),
      bandIndex = args(3),
      treeNum = args(4).toInt,
      tileSize = args(5).toInt,
      partitionNum = args(6).toInt,
      defaultFs = if (args.length > 7) args(7) else DefaultFs
    )
  }

  /**
   * Entry point: parses the arguments, starts a local Spark session with
   * RasterFrames enabled, runs the job and always stops the session afterwards.
   */
  def main(args: Array[String]): Unit = {
    // Fail on bad arguments before paying for a Spark session.
    val job = parseArgs(args)

    Logger.getLogger("org").setLevel(Level.WARN)
    Logger.getLogger("akka").setLevel(Level.WARN)

    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName(getClass.getName)
      .withKryoSerialization
      .getOrCreate()
      .withRasterFrames

    // Stop the session even when the job throws.
    try run(job) finally spark.stop()
  }

  /** Runs the load / label / train / score / export sequence for one set of arguments. */
  private def run(job: JobArgs)(implicit spark: SparkSession): Unit = {
    import spark.implicits._

    // ---- Load the rasters through an inline CSV catalog --------------------------------
    // The catalog header is the file names without directory and ".tif"; the single row
    // holds the paths themselves.
    val imageName: Array[String] = job.inputPath.split(",")
      .map(_.split("\\/").last).map(_.replace(".tif", ""))
    // Band indexes are given 1-based on the command line.
    val bandIndexInt: Array[Int] = job.bandIndex.split(",").map(_.toInt - 1)
    val cat =
      s"""
         |${imageName.mkString(",")}
         |${job.inputPath}
      """.stripMargin.trim

    val df_init: DataFrame = spark.read.raster
      .withTileDimensions(job.tileSize, job.tileSize)
      .withBandIndexes(bandIndexInt: _*)
      .fromCSV(cat).load()
      .repartition(job.partitionNum)
    df_init.printSchema()

    val tileCols: Seq[Column] = df_init.tileColumns
    require(tileCols.nonEmpty, "The loaded raster catalog contains no tile columns")
    val tileColsName: Array[String] = tileCols.map(_.columnName).toArray
    val tileColName = tileColsName.head

    val df = df_init.withColumn("crs", rf_crs(col(tileColName)))
      .withColumn("extent", rf_extent(col(tileColName)))
    val crs: CRS = df.select(rf_crs(col("crs"))).distinct().first()

    // ---- Collect layer metadata for the tiled input --------------------------------------
    val df_tmp = df.select(col("extent"), col(tileColName + ".tile")).as[(Extent, Tile)]
    val (_, tlm) = df_tmp
      .map { case (ext, tile) => (ProjectedExtent(ext, crs), tile) }
      .rdd.collectMetadata[SpatialKey](FloatingLayoutScheme(job.tileSize, job.tileSize))
    val gb = tlm.layout.gridBoundsFor(tlm.extent)
    val totalCols = gb.width.toInt
    val totalRows = gb.height.toInt

    // ---- Rasterize the training polygons onto the tiles ----------------------------------
    import org.locationtech.rasterframes.datasource.geojson._
    val jsonDF: DataFrame = spark.read.geojson.load(job.geojsonPath)
    val class_num = jsonDF.select("CLASS_ID").distinct().count.toInt
    require(class_num > 0, s"No CLASS_ID values found in ${job.geojsonPath}")

    // GeoJSON coordinates are WGS84 (RFC 7946), so reproject from LatLng into the raster
    // CRS. The previous source CRS was the raster CRS itself, which made this a no-op.
    val label_df: DataFrame = jsonDF
      .select($"CLASS_ID", st_reproject($"geometry", LatLng, crs).alias("geometry"))
      .hint("broadcast")
    val df_joined = df.join(label_df, st_intersects(st_geometry($"extent"), $"geometry"))
      .withColumn("dims", rf_dimensions(col(tileColName)))
    val df_labeled: DataFrame = df_joined.withColumn(
      TargetCol,
      rf_rasterize($"geometry", st_geometry($"extent"), $"CLASS_ID", $"dims.cols", $"dims.rows")
    )
    df_labeled.printSchema()
    // Keep only tiles whose rasterized label sums to more than zero.
    val tmp = df_labeled.filter(rf_tile_sum($"label") > 0).cache()

    // ---- Train ---------------------------------------------------------------------------
    val exploder = new TileExploder()
    val noDataFilter = new NoDataFilter().setInputCols(tileColsName :+ TargetCol)
    val assembler = new VectorAssembler()
      .setInputCols(tileColsName)
      .setOutputCol("features")
    val classifier = new RandomForestClassifier()
      .setLabelCol(TargetCol)
      .setFeaturesCol(assembler.getOutputCol)
      .setNumTrees(job.treeNum)
    val pipeline = new Pipeline()
      .setStages(Array(exploder, noDataFilter, assembler, classifier))
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol(TargetCol)
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val model = pipeline.fit(tmp)
    // Look the forest up by type rather than by its position in the pipeline.
    val rfModel = model.stages
      .collectFirst { case forest: RandomForestClassificationModel => forest }
      .getOrElse(throw new IllegalStateException("Fitted pipeline has no random forest stage"))

    println("featureImportances: ")
    rfModel.featureImportances.toArray.foreach(println)
    println("Learned classification forest model:\n" + rfModel.toDebugString)

    // ---- Diagnostics -----------------------------------------------------------------------
    // NOTE(review): this evaluates on the same DataFrame the model was fitted on, so the
    // reported figure is a training accuracy.
    val prediction_df = model.transform(tmp)
    val accuracy = evaluator.evaluate(prediction_df)
    println("accuracy = " + accuracy)
    val cnf_mtrx: Dataset[Row] = prediction_df.groupBy("prediction")
      .pivot("label")
      .count()
      .sort("prediction")
    cnf_mtrx.show(10, truncate = false)
    prediction_df.groupBy($"prediction" as "class").count().show()
    prediction_df.show(10)

    // ---- Score every tile and reassemble one probability band per class -------------------
    val scored = model.transform(df)
    val vectorToColumn = udf { (x: DenseVector, index: Int) => x(index) * 100 }

    // NOTE(review): assumes the probability vector has an entry for every index in
    // 0 until class_num - confirm against the CLASS_ID values in the training data.
    def probabilityBand(classIndex: Int): Column =
      rf_assemble_tile(
        $"column_index", $"row_index", vectorToColumn($"probability", lit(classIndex)),
        tlm.tileCols, tlm.tileRows, ByteConstantNoDataCellType
      ).alias("bd_" + (classIndex + 1))

    val bandAggregates: Seq[Column] = (0 until class_num).map(probabilityBand)
    val col_total: Seq[Column] = (1 to class_num).map(n => col("bd_" + n))

    val retiled: DataFrame = scored.groupBy($"crs", $"extent")
      .agg(bandAggregates.head, bandAggregates.tail: _*)
    retiled.printSchema()
    val rf: RasterFrameLayer = retiled.toLayer(tlm)
    rf.printSchema()

    // Don't use toMultibandRaster here, otherwise it throws a NullPointerException.
    val rr: RDD[(SpatialKey, MultibandTile)] = rf
      .select(rf.spatialKeyColumn, array(col_total: _*)).as[(SpatialKey, Array[Tile])]
      .rdd
      .filter { case (_, tiles) => tiles(0) != null } // remove any null Tiles
      .map { case (sk, tiles) => (sk, MultibandTile(tiles)) }

    // ---- Merge everything into a single tile covering the whole extent --------------------
    val newLayout = LayoutDefinition(tlm.extent, TileLayout(1, 1, totalCols, totalRows))
    val cellType = rr.first()._2.cellType
    val newLayerMetadata: TileLayerMetadata[SpatialKey] =
      tlm.copy(layout = newLayout, bounds = Bounds(SpatialKey(0, 0), SpatialKey(0, 0)), cellType = cellType)
    val trans = tlm.mapTransform
    val newLayer = rr
      .map { case (key, tile) => (ProjectedExtent(trans(key), tlm.crs), tile) }
      .tileToLayout(newLayerMetadata, Tiler.Options(NearestNeighbor))
    val stitchedTile: Raster[MultibandTile] = newLayer.stitch()
    val croppedTile = stitchedTile.crop(totalCols, totalRows)

    // ---- Write the GeoTIFF -----------------------------------------------------------------
    import org.apache.hadoop.fs._
    import geotrellis.spark.store.hadoop._
    // Set Hadoop's default filesystem before writing; presumably a scheme-less outputPath
    // is then resolved against it - confirm for the target deployment.
    val hadoopConf = spark.sparkContext.hadoopConfiguration
    hadoopConf.set("fs.defaultFS", job.defaultFs)
    MultibandGeoTiff(croppedTile.tile, newLayerMetadata.extent, newLayerMetadata.crs)
      .write(new Path(job.outputPath), hadoopConf)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment