Skip to content

Commit

Permalink
Merge pull request #540 from databrickslabs/feature/tiller_functions
Browse files Browse the repository at this point in the history
Add ST_AsGeojsonTileAgg aggregator function.
  • Loading branch information
Milos Colic authored Mar 15, 2024
2 parents 2ec5d9d + 8b7f276 commit 8874c39
Show file tree
Hide file tree
Showing 9 changed files with 577 additions and 11 deletions.
51 changes: 51 additions & 0 deletions python/mosaic/api/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#######################

__all__ = [
"st_asgeojsontile_agg",
"st_asmvttile_agg",
"st_union_agg",
"grid_cell_union_agg",
"grid_cell_intersection_agg",
Expand Down Expand Up @@ -45,6 +47,55 @@ def st_intersection_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Co
)


def st_asgeojsontile_agg(geom: ColumnOrName, attributes: ColumnOrName) -> Column:
"""
Returns the aggregated GeoJSON tile.
Parameters
----------
geom : Column
The geometry column to aggregate.
attributes : Column
The attributes column to aggregate.
Returns
-------
Column
The aggregated GeoJSON tile.
"""
return config.mosaic_context.invoke_function(
"st_asgeojsontile_agg",
pyspark_to_java_column(geom),
pyspark_to_java_column(attributes)
)


def st_asmvttile_agg(geom: ColumnOrName, attributes: ColumnOrName, zxyID: ColumnOrName) -> Column:
"""
Returns the aggregated MVT tile.
Parameters
----------
geom : Column
The geometry column to aggregate.
attributes : Column
The attributes column to aggregate.
zxyID : Column
The zxyID column to aggregate.
Returns
-------
Column
The aggregated MVT tile.
"""
return config.mosaic_context.invoke_function(
"st_asmvttile_agg",
pyspark_to_java_column(geom),
pyspark_to_java_column(attributes),
pyspark_to_java_column(zxyID)
)


def st_intersects_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Column:
"""
Tests if any `leftIndex` : `rightIndex` pairs intersect.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.expressions.geometry.base.AsTileExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.utils.PathUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.gdal.ogr._

import scala.collection.mutable

case class ST_AsGeojsonTileAgg(
geometryExpr: Expression,
attributesExpr: Expression,
expressionConfig: MosaicExpressionConfig,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int
) extends TypedImperativeAggregate[mutable.ArrayBuffer[Any]]
with BinaryLike[Expression]
with AsTileExpression {

val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI)
override lazy val deterministic: Boolean = true
override val left: Expression = geometryExpr
override val right: Expression = attributesExpr
override val nullable: Boolean = false
override val dataType: DataType = StringType

override def prettyName: String = "st_asgeojsontile_agg"

private lazy val tupleType =
StructType(
StructField("geom", geometryExpr.dataType, nullable = false) ::
StructField("attrs", attributesExpr.dataType, nullable = false) :: Nil
)
private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = tupleType, containsNull = false)))
private lazy val row = new UnsafeRow(2)

override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
val geom = geometryExpr.eval(input)
val attrs = attributesExpr.eval(input)
val value = InternalRow.fromSeq(Seq(geom, attrs))
buffer += InternalRow.copyValue(value)
buffer
}

def merge(buffer: mutable.ArrayBuffer[Any], input: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
buffer ++= input
}

override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
ogr.RegisterAll()
val driver = ogr.GetDriverByName("GeoJSON")
val tmpName = PathUtils.createTmpFilePath("geojson")
val ds: DataSource = driver.CreateDataSource(tmpName)

val srs = getSRS(buffer.head, geometryExpr, geometryAPI)

val layer = createLayer(ds, srs, attributesExpr.dataType.asInstanceOf[StructType])

insertRows(buffer, layer, geometryExpr, geometryAPI, attributesExpr)

ds.FlushCache()
ds.delete()

val source = scala.io.Source.fromFile(tmpName)
val result = source.getLines().collect { case x => x }.mkString("\n")
UTF8String.fromString(result)
}

override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = {
val array = new GenericArrayData(obj.toArray)
projection.apply(InternalRow.apply(array)).getBytes
}

override def deserialize(bytes: Array[Byte]): mutable.ArrayBuffer[Any] = {
val buffer = createAggregationBuffer()
row.pointTo(bytes, bytes.length)
row.getArray(0).foreach(tupleType, (_, x: Any) => buffer += x)
buffer
}

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ST_AsGeojsonTileAgg =
copy(geometryExpr = newLeft, attributesExpr = newRight)

}

object ST_AsGeojsonTileAgg {

def registryExpressionInfo(db: Option[String]): ExpressionInfo =
new ExpressionInfo(
classOf[ST_AsGeojsonTileAgg].getCanonicalName,
db.orNull,
"st_asgeojsontile_agg",
"""
| _FUNC_(geom, attrs) - Aggregate function that returns a GeoJSON string from a set of geometries and attributes.
""".stripMargin,
"",
"""
| Examples:
| > SELECT _FUNC_(a, b) FROM table GROUP BY tile_id;
| {"type":"FeatureCollection","features":[{"type":"Feature","geometry":{"type":"Point","coordinates":[1.0,1.0]},"properties":{"name":"a"}},{"type":"Feature","geometry":{"type":"Point","coordinates":[2.0,2.0]},"properties":{"name":"b"}}]}
| {"type":"FeatureCollection","features":[{"type":"Feature","geometry":{"type":"Point","coordinates":[3.0,3.0]},"properties":{"name":"c"}},{"type":"Feature","geometry":{"type":"Point","coordinates":[4.0,4.0]},"properties":{"name":"d"}}]}
| """.stripMargin,
"",
"agg_funcs",
"1.0",
"",
"built-in"
)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.expressions.geometry.base.AsTileExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.utils.{PathUtils, SysUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.gdal.ogr._

import java.nio.file.{Files, Paths}
import scala.collection.mutable

case class ST_AsMVTTileAgg(
geometryExpr: Expression,
attributesExpr: Expression,
zxyIDExpr: Expression,
expressionConfig: MosaicExpressionConfig,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int
) extends TypedImperativeAggregate[mutable.ArrayBuffer[Any]]
with TernaryLike[Expression]
with AsTileExpression {

val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI)
override lazy val deterministic: Boolean = true
override val first: Expression = geometryExpr
override val second: Expression = attributesExpr
override val third: Expression = zxyIDExpr
override val nullable: Boolean = false
override val dataType: DataType = BinaryType

override def prettyName: String = "st_asmvttile_agg"

// The tiling scheme for the MVT: https://gdal.org/drivers/vector/mvt.html
private val tilingScheme3857 = "EPSG:3857,-20037508.343,20037508.343,40075016.686"
private val tilingScheme4326 = "EPSG:4326,-180,180,360"

private lazy val tupleType =
StructType(
StructField("geom", geometryExpr.dataType, nullable = false) ::
StructField("attrs", attributesExpr.dataType, nullable = false) ::
StructField("zxyID", zxyIDExpr.dataType, nullable = false) ::
Nil
)
private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = tupleType, containsNull = false)))
private lazy val row = new UnsafeRow(2)

override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
val geom = geometryExpr.eval(input)
val attrs = attributesExpr.eval(input)
val zxyID = zxyIDExpr.eval(input)
val value = InternalRow.fromSeq(Seq(geom, attrs, zxyID))
buffer += InternalRow.copyValue(value)
buffer
}

def merge(buffer: mutable.ArrayBuffer[Any], input: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
buffer ++= input
}

override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
ogr.RegisterAll()
// We assume all zxyIDs are the same for all the rows in the buffer
val zxyID = buffer.head.asInstanceOf[InternalRow].get(2, zxyIDExpr.dataType).toString
val zoom = zxyID.split("/")(0).toInt
val driver = ogr.GetDriverByName("MVT")
val tmpName = PathUtils.createTmpFilePath("mvt")

val srs = getSRS(buffer.head, geometryExpr, geometryAPI)
val tilingScheme = srs.GetAttrValue("PROJCS", 0) match {
case "WGS 84 / Pseudo-Mercator" => tilingScheme3857
case "WGS 84" => tilingScheme4326
case _ => throw new Error(s"Unsupported SRS: ${srs.GetAttrValue("PROJCS", 0)}")
}

val createOptions = new java.util.Vector[String]()
createOptions.add("NAME=mvttile")
createOptions.add("TYPE=baselayer")
createOptions.add(s"MINZOOM=$zoom")
createOptions.add(s"MAXZOOM=$zoom")
createOptions.add(s"TILING_SCHEME=$tilingScheme")

val ds: DataSource = driver.CreateDataSource(tmpName, createOptions)

val layer = createLayer(ds, srs, attributesExpr.dataType.asInstanceOf[StructType])

insertRows(buffer, layer, geometryExpr, geometryAPI, attributesExpr)

ds.FlushCache()
ds.delete()

val tiles = SysUtils
.runCommand(s"ls $tmpName")
._1
.split("\n")
.filterNot(_.endsWith(".json"))
.flatMap(z =>
SysUtils
.runCommand(s"ls $tmpName/$z")
._1
.split("\n")
.flatMap(x =>
SysUtils
.runCommand(s"ls $tmpName/$z/$x")
._1
.split("\n")
.map(y => s"$tmpName/$z/$x/$y")
)
)

Files.readAllBytes(Paths.get(tiles.head))

}

override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = {
val array = new GenericArrayData(obj.toArray)
projection.apply(InternalRow.apply(array)).getBytes
}

override def deserialize(bytes: Array[Byte]): mutable.ArrayBuffer[Any] = {
val buffer = createAggregationBuffer()
row.pointTo(bytes, bytes.length)
row.getArray(0).foreach(tupleType, (_, x: Any) => buffer += x)
buffer
}

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): ST_AsMVTTileAgg =
copy(geometryExpr = newFirst, attributesExpr = newSecond, zxyIDExpr = newThird)

}

object ST_AsMVTTileAgg {

def registryExpressionInfo(db: Option[String]): ExpressionInfo =
new ExpressionInfo(
classOf[ST_AsMVTTileAgg].getCanonicalName,
db.orNull,
"st_asmvttile_agg",
"""
| _FUNC_(geom, attrs) - Returns a Mapbox Vector Tile (MVT) as a binary.
""".stripMargin,
"",
"""
| Examples:
| > SELECT st_asmvttile_agg(geom, attrs) FROM table;
| 0x1A2B3C4D5E6F
| 0x1A2B3C4D5E6F
""".stripMargin,
"",
"agg_funcs",
"1.0",
"",
"built-in"
)

}
Loading

0 comments on commit 8874c39

Please sign in to comment.