Skip to content

Commit

Permalink
added function RST_AsFormat() to change raster format / driver in-situ
Browse files Browse the repository at this point in the history
  • Loading branch information
sllynn committed Nov 12, 2024
1 parent 4aa2f10 commit a479ab7
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.databricks.labs.mosaic.core.raster.operator.RasterTranslate

import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate
import com.databricks.labs.mosaic.utils.PathUtils

object TranslateFormat {

/**
* Converts the data type of a raster's bands
*
* @param raster
* The raster to update.
* @param newFormat
* The new format of the raster.
* @return
* A MosaicRasterGDAL object.
*/
def update(
raster: MosaicRasterGDAL,
newFormat: String
): MosaicRasterGDAL = {

val outOptions = raster.getWriteOptions.copy(format = newFormat, extension = GDAL.getExtension(newFormat))
val resultFileName = PathUtils.createTmpFilePath(outOptions.extension)

val result = GDALTranslate.executeTranslate(
resultFileName,
raster,
command = s"gdal_translate",
outOptions
)

result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ object VectorClipper {
* The shapefile name.
*/
private def getShapefileName: String = {
val shapeFileName = PathUtils.createTmpFilePath(".shp")
val shapeFileName = PathUtils.createTmpFilePath("shp")
shapeFileName
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ case class MosaicRasterTile(

def getDriver: String = driver

def setDriver(value: String): MosaicRasterTile = {
new MosaicRasterTile(index, raster.copy(createInfo = raster.createInfo.updated("driver", value)))
}

def driver: String = raster.createInfo("driver")

def getRaster: MosaicRasterGDAL = raster
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead

val retiledDf = retileRaster(pathsDf, config)

val convertToFormat = if (config("convertToFormat").isEmpty) {
col("tile.metadata").getItem("driver") // which should be a noop
} else {
lit(config("convertToFormat"))
}
val rasterToGridCombiner = getRasterToGridFunc(config("combiner"))

val loadedDf = retiledDf
.withColumn("tile", rst_asformat(col("tile"), convertToFormat))
.withColumn(
"tile",
rst_tessellate(col("tile"), lit(resolution))
Expand Down Expand Up @@ -225,7 +231,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
"retile" -> this.extraOptions.getOrElse("retile", "false"),
"tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"),
"sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", "-1"),
"kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0")
"kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0"),
"convertToFormat" -> this.extraOptions.getOrElse("convertToFormat", "")
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.core.raster.operator.RasterTranslate.TranslateFormat
import com.databricks.labs.mosaic.core.types.RasterTileType
import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types.DataType
import org.apache.spark.unsafe.types.UTF8String

case class RST_AsFormat (
tileExpr: Expression,
newFormat: Expression,
expressionConfig: MosaicExpressionConfig
) extends Raster1ArgExpression[RST_AsFormat](
tileExpr,
newFormat,
returnsRaster = true,
expressionConfig
)
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = {
GDAL.enable(expressionConfig)
RasterTileType(expressionConfig.getCellIdType, tileExpr, expressionConfig.isRasterUseCheckpoint)
}

/** Changes the data type of a band of the raster. */
override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = {

val newFormat = arg1.asInstanceOf[UTF8String].toString
if (tile.getRaster.driverShortName.getOrElse("") == newFormat) {
return tile
}
val result = TranslateFormat.update(tile.getRaster, newFormat)
tile.copy(raster = result).setDriver(newFormat)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_AsFormat extends WithExpressionInfo {

override def name: String = "rst_asformat"

override def usage: String = "_FUNC_(expr1) - Returns a raster tile in a different underlying format"

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(tile, 'GTiff')
| {index_id, updated_raster, parentPath, driver}
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_AsFormat](2, expressionConfig)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

case class RST_Format (
tileExpr: Expression,
expressionConfig: MosaicExpressionConfig
) extends RasterExpression[RST_Format](
tileExpr,
returnsRaster = false,
expressionConfig
)
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = StringType

/** Returns the format of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
UTF8String.fromString(tile.getDriver)
}

}

/** Expression info required for the expression registration for spark SQL. */
object RST_Format extends WithExpressionInfo {

override def name: String = "rst_format"

override def usage: String = "_FUNC_(expr1) - Returns the driver used to read the raster"

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(tile)
| 'GTiff'
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_Format](1, expressionConfig)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
)

/** RasterAPI dependent functions */
mosaicRegistry.registerExpression[RST_AsFormat](expressionConfig)
mosaicRegistry.registerExpression[RST_Avg](expressionConfig)
mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig)
mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig)
Expand All @@ -285,6 +286,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
mosaicRegistry.registerExpression[RST_DerivedBand](expressionConfig)
mosaicRegistry.registerExpression[RST_DTMFromGeoms](expressionConfig)
mosaicRegistry.registerExpression[RST_Filter](expressionConfig)
mosaicRegistry.registerExpression[RST_Format](expressionConfig)
mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig)
mosaicRegistry.registerExpression[RST_GetNoData](expressionConfig)
mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig)
Expand Down Expand Up @@ -695,6 +697,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
def st_within(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Within(geom1.expr, geom2.expr, expressionConfig))

/** RasterAPI dependent functions */
def rst_asformat(raster: Column, driver: Column): Column =
ColumnAdapter(RST_AsFormat(raster.expr, driver.expr, expressionConfig))
def rst_asformat(raster: Column, driver: String): Column =
ColumnAdapter(RST_AsFormat(raster.expr, lit(driver).expr, expressionConfig))
def rst_bandmetadata(raster: Column, band: Column): Column =
ColumnAdapter(RST_BandMetaData(raster.expr, band.expr, expressionConfig))
def rst_bandmetadata(raster: Column, band: Int): Column =
Expand All @@ -716,6 +722,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ColumnAdapter(RST_Filter(raster.expr, kernelSize.expr, operation.expr, expressionConfig))
def rst_filter(raster: Column, kernelSize: Int, operation: String): Column =
ColumnAdapter(RST_Filter(raster.expr, lit(kernelSize).expr, lit(operation).expr, expressionConfig))
def rst_format(raster: Column): Column =
ColumnAdapter(RST_Format(raster.expr, expressionConfig))
def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig))
def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig))
def rst_getsubdataset(raster: Column, subdatasetName: Column): Column =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ object PathUtils {
val destination = Paths.get(copyToPath, path.getFileName.toString)
// noinspection SimplifyBooleanMatch
if (Files.isDirectory(path)) FileUtils.copyDirectory(path.toFile, destination.toFile)
else FileUtils.copyFile(path.toFile, destination.toFile)
else if (path.toFile != destination.toFile) {
FileUtils.copyFile(path.toFile, destination.toFile)
}
}
}

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.test.SharedSparkSessionGDAL
import org.scalatest.Tag
import org.scalatest.matchers.must.Matchers.{be, noException}
import org.scalatest.matchers.should.Matchers.an
import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper}

import java.nio.file.{Files, Paths}

Expand Down Expand Up @@ -41,6 +41,36 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess

}

test("Read ECMWF netcdf with Raster As Grid Reader") {
assume(System.getProperty("os.name") == "Linux")
assume(checkpointingEnabled)
val mc = MosaicContext.build(H3IndexSystem, JTS)
mc.register(spark)


val netcdf = "/binary/netcdf-ECMWF/"
val filePath = this.getClass.getResource(netcdf).getPath

val result = MosaicContext.read
.format("raster_to_grid")
.option("sizeInMB", "16")
.option("convertToFormat", "GTiff")
.option("resolution", "0")
.option("readSubdataset", "true")
.option("subdatasetName", "t2m")
.option("retile", "true")
.option("tileSize", "600")
.option("combiner", "avg")
.load(filePath)
.select("measure")
.cache()

result.count shouldBe 1098

noException should be thrownBy result.take(1)

}

test("Read grib with Raster As Grid Reader", ExcludeLocalTag) {
assume(System.getProperty("os.name") == "Linux")
MosaicContext.build(H3IndexSystem, JTS)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.index.IndexSystem
import com.databricks.labs.mosaic.functions.MosaicContext
import com.databricks.labs.mosaic.test.mocks.filePath
import com.databricks.labs.mosaic.{MOSAIC_RASTER_READ_IN_MEMORY, MOSAIC_RASTER_READ_STRATEGY}
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.functions.lit
import org.scalatest.matchers.should.Matchers._

trait RST_AsFormatBehaviours extends QueryTest {

// noinspection MapGetGet
def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = {
spark.sparkContext.setLogLevel("ERROR")
val mc = MosaicContext.build(indexSystem, geometryAPI)
mc.register()
val sc = spark
import mc.functions._
import sc.implicits._

val subDataset = "t2m"

val rastersInMemory = spark.read
.format("gdal")
.option(MOSAIC_RASTER_READ_STRATEGY, MOSAIC_RASTER_READ_IN_MEMORY)
.load(filePath("/binary/netcdf-ECMWF/"))
.withColumn("tile", rst_getsubdataset($"tile", lit(subDataset)))

val newFormat = "GTiff"

val df = rastersInMemory
.withColumn("updated_tile", rst_asformat($"tile", lit(newFormat)))
.select(rst_format($"updated_tile").as("new_type"))

rastersInMemory
.createOrReplaceTempView("source")

noException should be thrownBy spark.sql(s"""
|select rst_asformat(tile, '$newFormat') from source
|""".stripMargin)

noException should be thrownBy rastersInMemory
.withColumn("tile", rst_updatetype($"tile", lit(newFormat)))
.select("tile")

val result = df.first.getString(0)

result shouldBe newFormat

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.geometry.api.JTS
import com.databricks.labs.mosaic.core.index.H3IndexSystem
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSessionGDAL

import scala.util.Try

class RST_AsFormatTest extends QueryTest with SharedSparkSessionGDAL with RST_AsFormatBehaviours {

private val noCodegen =
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString
) _

// Hotfix for SharedSparkSession afterAll cleanup.
override def afterAll(): Unit = Try(super.afterAll())

// These tests are not index system nor geometry API specific.
// Only testing one pairing is sufficient.
test("Testing RST_UpdateFormat with manual GDAL registration (H3, JTS).") {
noCodegen {
assume(System.getProperty("os.name") == "Linux")
behavior(H3IndexSystem, JTS)
}
}
}

Loading

0 comments on commit a479ab7

Please sign in to comment.