Skip to content

Commit

Permalink
Merge pull request #512 from databrickslabs/feature/fix_raster_to_grid
Browse files Browse the repository at this point in the history
Feature/fix raster to grid
  • Loading branch information
Milos Colic authored Mar 6, 2024
2 parents c142240 + c539615 commit 8701525
Show file tree
Hide file tree
Showing 153 changed files with 3,765 additions and 666 deletions.
2 changes: 1 addition & 1 deletion .github/actions/scala_build/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ runs:
sudo apt-get update -y
# - install natives
sudo apt-get install -y unixodbc libcurl3-gnutls libsnappy-dev libopenjp2-7
sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal
sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal zip unzip
# - install pip libs
pip install --upgrade pip
pip install gdal==${{ matrix.gdal }}
Expand Down
4 changes: 2 additions & 2 deletions R/sparkR-mosaic/sparkrMosaic/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Description: This package extends SparkR to bring the Databricks Mosaic for geos
License: Databricks
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Collate:
'enableGDAL.R'
'enableMosaic.R'
Expand All @@ -20,4 +20,4 @@ Imports:
Suggests:
testthat (>= 3.0.0),
readr (>= 2.1.5)
Config/testthat/edition: 3
Config/testthat/edition: 3
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test_that("mosaic can read single-band GeoTiff", {
expect_equal(row$srid, 0)
expect_equal(row$bandCount, 1)
expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m")
expect_equal(row$tile[[1]]$driver, "GTiff")
expect_equal(row$tile[[1]]$metadata$driver, "GTiff")

})

Expand Down Expand Up @@ -61,7 +61,7 @@ test_that("raster flatmap functions behave as intended", {
tessellate_sdf <- withColumn(tessellate_sdf, "rst_tessellate", rst_tessellate(column("tile"), lit(3L)))

expect_no_error(write.df(tessellate_sdf, source = "noop", mode = "overwrite"))
expect_equal(nrow(tessellate_sdf), 66)
expect_equal(nrow(tessellate_sdf), 63)

overlap_sdf <- generate_singleband_raster_df()
overlap_sdf <- withColumn(overlap_sdf, "rst_to_overlapping_tiles", rst_to_overlapping_tiles(column("tile"), lit(200L), lit(200L), lit(10L)))
Expand Down Expand Up @@ -117,7 +117,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", {
raster_sdf <- read.df(
path = "sparkrMosaic/tests/testthat/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc",
source = "gdal",
raster.read.strategy = "retile_on_read"
raster.read.strategy = "in_memory"
)

raster_sdf <- withColumn(raster_sdf, "tile", rst_separatebands(column("tile")))
Expand Down
8 changes: 6 additions & 2 deletions R/sparkR-mosaic/tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ print("Looking for mosaic jar in")
mosaic_jar_path <- paste0(staging_dir, mosaic_jar)
print(mosaic_jar_path)

pwd <- getwd()
spark <- sparkR.session(
master = "local[*]"
,sparkJars = mosaic_jar_path
,sparkJars = mosaic_jar_path,
sparkConfig = list(
spark.databricks.labs.mosaic.raster.tmp.prefix = paste0(pwd, "/mosaic_tmp", sep="")
,spark.databricks.labs.mosaic.raster.checkpoint = paste0(pwd, "/mosaic_checkpoint", sep="")
)
)

enableMosaic()

testthat::test_local(path="./sparkrMosaic")
4 changes: 2 additions & 2 deletions R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Description: This package extends sparklyr to bring the Databricks Mosaic for ge
License: Databricks
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Collate:
'enableGDAL.R'
'enableMosaic.R'
Expand All @@ -20,4 +20,4 @@ Suggests:
testthat (>= 3.0.0),
sparklyr.nested (>= 0.0.4),
readr (>= 2.1.5)
Config/testthat/edition: 3
Config/testthat/edition: 3
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ test_that("mosaic can read single-band GeoTiff", {
expect_equal(row$srid, 0)
expect_equal(row$bandCount, 1)
expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m")
expect_equal(row$tile[[1]]$driver, "GTiff")
expect_equal(row$tile[[1]]$metadata$driver, "GTiff")

})

Expand Down Expand Up @@ -90,7 +90,7 @@ test_that("raster flatmap functions behave as intended", {
mutate(rst_tessellate = rst_tessellate(tile, 3L))

expect_no_error(spark_write_source(tessellate_sdf, "noop", mode = "overwrite"))
expect_equal(sdf_nrow(tessellate_sdf), 66)
expect_equal(sdf_nrow(tessellate_sdf), 63)

overlap_sdf <- generate_singleband_raster_df() %>%
mutate(rst_to_overlapping_tiles = rst_to_overlapping_tiles(tile, 200L, 200L, 10L))
Expand Down Expand Up @@ -157,7 +157,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", {
name = "raster_raw",
source = "gdal",
path = "data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc",
options = list("raster.read.strategy" = "retile_on_read")
options = list("raster.read.strategy" = "in_memory")
) %>%
mutate(tile = rst_separatebands(tile)) %>%
sdf_register("raster")
Expand Down
2 changes: 2 additions & 0 deletions R/sparklyr-mosaic/tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ print(paste("Looking for mosaic jar in", mosaic_jar_path))

config <- sparklyr::spark_config()
config$`sparklyr.jars.default` <- c(mosaic_jar_path)
config$`spark.databricks.labs.mosaic.raster.tmp.prefix` <- paste0(getwd(), "/mosaic_tmp", sep="")
config$`spark.databricks.labs.mosaic.raster.checkpoint` <- paste0(getwd(), "/mosaic_checkpoint", sep="")

sc <- spark_connect(master="local[*]", config=config)
enableMosaic(sc)
Expand Down
118 changes: 116 additions & 2 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pyspark.sql.functions import lit
from typing import Any


#######################
# Raster functions #
#######################
Expand All @@ -15,16 +14,19 @@
"rst_boundingbox",
"rst_clip",
"rst_combineavg",
"rst_convolve",
"rst_derivedband",
"rst_frombands",
"rst_fromcontent",
"rst_fromfile",
"rst_filter",
"rst_georeference",
"rst_getnodata",
"rst_getsubdataset",
"rst_height",
"rst_initnodata",
"rst_isempty",
"rst_maketiles",
"rst_mapalgebra",
"rst_memsize",
"rst_merge",
Expand Down Expand Up @@ -55,6 +57,7 @@
"rst_subdivide",
"rst_summary",
"rst_tessellate",
"rst_transform",
"rst_to_overlapping_tiles",
"rst_tryopen",
"rst_upperleftx",
Expand Down Expand Up @@ -156,6 +159,32 @@ def rst_combineavg(raster_tiles: ColumnOrName) -> Column:
)


def rst_convolve(raster_tile: ColumnOrName, kernel: ColumnOrName) -> Column:
    """
    Convolve a raster tile with the supplied kernel.

    The filtered raster is returned as a Mosaic raster tile struct column
    and is stored in the checkpoint directory.

    Parameters
    ----------
    raster_tile : Column (RasterTileType)
        Mosaic raster tile struct column.
    kernel : Column (ArrayType(ArrayType(DoubleType)))
        The kernel to apply to the raster.

    Returns
    -------
    Column (RasterTileType)
        Mosaic raster tile struct column.

    """
    # Resolve the dispatcher first, then convert the arguments left to right.
    invoke = config.mosaic_context.invoke_function
    java_args = [pyspark_to_java_column(arg) for arg in (raster_tile, kernel)]
    return invoke("rst_convolve", *java_args)


def rst_derivedband(
raster_tile: ColumnOrName, python_func: ColumnOrName, func_name: ColumnOrName
) -> Column:
Expand Down Expand Up @@ -316,6 +345,43 @@ def rst_isempty(raster_tile: ColumnOrName) -> Column:
)


def rst_maketiles(input: ColumnOrName, driver: Any = "no_driver", size_in_mb: Any = -1,
                  with_checkpoint: Any = False) -> Column:
    """
    Tiles the raster into tiles of the given size.

    Plain Python values passed for ``driver``, ``size_in_mb`` and
    ``with_checkpoint`` are wrapped with ``lit`` before being handed to the
    JVM function; any other value is forwarded as given.

    :param input: If the raster is stored on disk, the path
                  to the raster is provided. If the raster is stored in memory, the bytes of
                  the raster are provided.
    :param driver: The driver to use for reading the raster. If not specified, the driver is
                   inferred from the file extension. If the input is a byte array, the driver
                   has to be specified.
    :param size_in_mb: The size of the tiles in MB. If set to -1, the file is loaded and returned
                       as a single tile. If set to 0, the file is loaded and subdivided into
                       tiles of size 64MB. If set to a positive value, the file is loaded and
                       subdivided into tiles of the specified size. If the file is too big to fit
                       in memory, it is subdivided into tiles of size 64MB.
    :param with_checkpoint: If set to true, the tiles are written to the checkpoint directory. If set
                            to false, the tiles are returned as in-memory byte arrays.
    :return: A collection of tiles of the raster.
    """
    # isinstance (rather than an exact type comparison) also accepts subclasses.
    # bool is a subclass of int, so it is excluded here and is not wrapped
    # as a tile size.
    if isinstance(size_in_mb, int) and not isinstance(size_in_mb, bool):
        size_in_mb = lit(size_in_mb)

    if isinstance(with_checkpoint, bool):
        with_checkpoint = lit(with_checkpoint)

    # A plain string driver is wrapped as a literal value.
    if isinstance(driver, str):
        driver = lit(driver)

    return config.mosaic_context.invoke_function(
        "rst_maketiles",
        pyspark_to_java_column(input),
        pyspark_to_java_column(driver),
        pyspark_to_java_column(size_in_mb),
        pyspark_to_java_column(with_checkpoint),
    )


def rst_mapalgebra(raster_tile: ColumnOrName, json_spec: ColumnOrName) -> Column:
"""
Parameters
Expand Down Expand Up @@ -630,7 +696,7 @@ def rst_rastertogridmin(raster_tile: ColumnOrName, resolution: ColumnOrName) ->


def rst_rastertoworldcoord(
raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName
raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName
) -> Column:
"""
Computes the world coordinates of the raster pixel at the given x and y coordinates.
Expand Down Expand Up @@ -997,6 +1063,32 @@ def rst_tessellate(raster_tile: ColumnOrName, resolution: ColumnOrName) -> Colum
)


def rst_transform(raster_tile: ColumnOrName, srid: ColumnOrName) -> Column:
    """
    Reproject a raster tile to the requested SRID.

    The transformed raster is returned as a Mosaic raster tile struct and
    is stored in the checkpoint directory.

    Parameters
    ----------
    raster_tile : Column (RasterTileType)
        Mosaic raster tile struct column.
    srid : Column (IntegerType)
        EPSG authority code for the file's projection.

    Returns
    -------
    Column (RasterTileType)
        Mosaic raster tile struct column.

    """
    # Look up the dispatcher before converting the arguments, in call order.
    dispatch = config.mosaic_context.invoke_function
    tile_jcol = pyspark_to_java_column(raster_tile)
    srid_jcol = pyspark_to_java_column(srid)
    return dispatch("rst_transform", tile_jcol, srid_jcol)


def rst_fromcontent(
raster_bin: ColumnOrName, driver: ColumnOrName, size_in_mb: Any = -1
) -> Column:
Expand Down Expand Up @@ -1035,6 +1127,28 @@ def rst_fromfile(raster_path: ColumnOrName, size_in_mb: Any = -1) -> Column:
)


def rst_filter(raster_tile: ColumnOrName, kernel_size: Any, operation: Any) -> Column:
    """
    Applies a filter to the raster.

    A plain Python integer ``kernel_size`` and a plain string ``operation``
    are wrapped with ``lit`` before being handed to the JVM function; any
    other value is forwarded as given.

    :param raster_tile: Mosaic raster tile struct column.
    :param kernel_size: The size of the kernel. Has to be odd.
    :param operation: The operation to apply to the kernel.
    :return: A new raster tile with the filter applied.
    """
    # isinstance (rather than an exact type comparison) also accepts subclasses.
    # bool is a subclass of int, so it is excluded here and is not wrapped
    # as a kernel size.
    if isinstance(kernel_size, int) and not isinstance(kernel_size, bool):
        kernel_size = lit(kernel_size)

    # A plain string operation is wrapped as a literal value.
    if isinstance(operation, str):
        operation = lit(operation)

    return config.mosaic_context.invoke_function(
        "rst_filter",
        pyspark_to_java_column(raster_tile),
        pyspark_to_java_column(kernel_size),
        pyspark_to_java_column(operation),
    )


def rst_to_overlapping_tiles(
raster_tile: ColumnOrName,
width: ColumnOrName,
Expand Down
7 changes: 4 additions & 3 deletions python/test/test_raster_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_read_raster(self):
result.metadata["LONGNAME"],
"MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m",
)
self.assertEqual(result.tile["driver"], "GTiff")
self.assertEqual(result.tile["metadata"]["driver"], "GTiff")

def test_raster_scalar_functions(self):
result = (
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_raster_flatmap_functions(self):
)

tessellate_result.write.format("noop").mode("overwrite").save()
self.assertEqual(tessellate_result.count(), 66)
self.assertEqual(tessellate_result.count(), 63)

overlap_result = (
self.generate_singleband_raster_df()
Expand Down Expand Up @@ -187,11 +187,12 @@ def test_netcdf_load_tessellate_clip_merge(self):

df = (
self.spark.read.format("gdal")
.option("raster.read.strategy", "retile_on_read")
.option("raster.read.strategy", "in_memory")
.load(
"test/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc"
)
.select(api.rst_separatebands("tile").alias("tile"))
.repartition(self.spark.sparkContext.defaultParallelism)
.withColumn(
"timestep",
element_at(
Expand Down
4 changes: 1 addition & 3 deletions python/test/test_vector_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def test_st_z(self):
.select(col("id").cast("double"))
.withColumn(
"points",
api.st_geomfromwkt(
concat(lit("POINT (9 9 "), "id", lit(")"))
),
api.st_geomfromwkt(concat(lit("POINT (9 9 "), "id", lit(")"))),
)
.withColumn("z", api.st_z("points"))
.collect()
Expand Down
10 changes: 10 additions & 0 deletions python/test/utils/spark_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ def setUpClass(cls) -> None:
if not os.path.exists(cls.library_location):
cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-SNAPSHOT-jar-with-dependencies.jar"

pwd_dir = os.getcwd()
tmp_dir = f"{pwd_dir}/mosaic_test/"
check_dir = f"{pwd_dir}/checkpoint"
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
if not os.path.exists(check_dir):
os.makedirs(check_dir)

cls.spark = (
SparkSession.builder.master("local[*]")
.config("spark.jars", cls.library_location)
Expand All @@ -33,6 +41,8 @@ def setUpClass(cls) -> None:
.getOrCreate()
)
cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false")
cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", tmp_dir)
cls.spark.conf.set("spark.databricks.labs.mosaic.raster.checkpoint", check_dir)
cls.spark.sparkContext.setLogLevel("FATAL")

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ object FormatLookup {
"CAD" -> "dwg",
"CEOS" -> "ceos",
"COASP" -> "coasp",
"COG" -> "tif",
"COSAR" -> "cosar",
"CPG" -> "cpg",
"CSW" -> "csw",
Expand Down
Loading

0 comments on commit 8701525

Please sign in to comment.