diff --git a/python/mosaic/api/aggregators.py b/python/mosaic/api/aggregators.py index 291d7b534..87eaed84b 100644 --- a/python/mosaic/api/aggregators.py +++ b/python/mosaic/api/aggregators.py @@ -9,6 +9,8 @@ ####################### __all__ = [ + "st_asgeojsontile_agg", + "st_asmvttile_agg", "st_union_agg", "grid_cell_union_agg", "grid_cell_intersection_agg", @@ -45,6 +47,55 @@ def st_intersection_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Co ) +def st_asgeojsontile_agg(geom: ColumnOrName, attributes: ColumnOrName) -> Column: + """ + Returns the aggregated GeoJSON tile. + + Parameters + ---------- + geom : Column + The geometry column to aggregate. + attributes : Column + The attributes column to aggregate. + + Returns + ------- + Column + The aggregated GeoJSON tile. + """ + return config.mosaic_context.invoke_function( + "st_asgeojsontile_agg", + pyspark_to_java_column(geom), + pyspark_to_java_column(attributes) + ) + + +def st_asmvttile_agg(geom: ColumnOrName, attributes: ColumnOrName, zxyID: ColumnOrName) -> Column: + """ + Returns the aggregated MVT tile. + + Parameters + ---------- + geom : Column + The geometry column to aggregate. + attributes : Column + The attributes column to aggregate. + zxyID : Column + The zxyID column to aggregate. + + Returns + ------- + Column + The aggregated MVT tile. + """ + return config.mosaic_context.invoke_function( + "st_asmvttile_agg", + pyspark_to_java_column(geom), + pyspark_to_java_column(attributes), + pyspark_to_java_column(zxyID) + ) + + def st_intersects_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Column: """ Tests if any `leftIndex` : `rightIndex` pairs intersect. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeojsonTileAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeojsonTileAgg.scala new file mode 100644 index 000000000..3c02179b4 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeojsonTileAgg.scala @@ -0,0 +1,126 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.expressions.geometry.base.AsTileExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.gdal.ogr._ + +import scala.collection.mutable + +case class ST_AsGeojsonTileAgg( + geometryExpr: Expression, + attributesExpr: Expression, + expressionConfig: MosaicExpressionConfig, + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int +) extends TypedImperativeAggregate[mutable.ArrayBuffer[Any]] + with BinaryLike[Expression] + with AsTileExpression { + + val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + override lazy val deterministic: Boolean = true + override val left: Expression = geometryExpr + override val right: Expression = attributesExpr + override val nullable: Boolean = false + override val dataType: DataType = StringType + + override def prettyName: String = "st_asgeojsontile_agg" + + private lazy val tupleType = + StructType( + StructField("geom", geometryExpr.dataType, nullable = false) :: + StructField("attrs", attributesExpr.dataType, nullable = false) :: Nil + ) + private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = tupleType, containsNull = false))) + private lazy val row = new UnsafeRow(2) + + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + + def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = { + val geom = geometryExpr.eval(input) + val attrs = attributesExpr.eval(input) + val value = InternalRow.fromSeq(Seq(geom, attrs)) + buffer += InternalRow.copyValue(value) + buffer + } + + def merge(buffer: mutable.ArrayBuffer[Any], input: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { + buffer ++= input + } + + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { + ogr.RegisterAll() + val driver = ogr.GetDriverByName("GeoJSON") + val tmpName = PathUtils.createTmpFilePath("geojson") + val ds: DataSource = driver.CreateDataSource(tmpName) + + val srs = getSRS(buffer.head, geometryExpr, geometryAPI) + + val layer = createLayer(ds, srs, attributesExpr.dataType.asInstanceOf[StructType]) + + insertRows(buffer, layer, geometryExpr, geometryAPI, attributesExpr) + + ds.FlushCache() + ds.delete() + + val source = scala.io.Source.fromFile(tmpName) + val result = source.getLines().collect { case x => x }.mkString("\n") + UTF8String.fromString(result) + } + + override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = { + val array = new GenericArrayData(obj.toArray) + projection.apply(InternalRow.apply(array)).getBytes + } + + override def deserialize(bytes: Array[Byte]): mutable.ArrayBuffer[Any] = { + val buffer = createAggregationBuffer() + row.pointTo(bytes, bytes.length) + row.getArray(0).foreach(tupleType, (_, x: Any) => buffer += x) + buffer + } + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ST_AsGeojsonTileAgg = + copy(geometryExpr = newLeft, attributesExpr = newRight) + +} + +object ST_AsGeojsonTileAgg { + + def registryExpressionInfo(db: Option[String]): ExpressionInfo = + new ExpressionInfo( + classOf[ST_AsGeojsonTileAgg].getCanonicalName, + db.orNull, + "st_asgeojsontile_agg", + """ + | _FUNC_(geom, attrs) - Aggregate function that returns a GeoJSON string from a set of geometries and attributes. + """.stripMargin, + "", + """ + | Examples: + | > SELECT _FUNC_(a, b) FROM table GROUP BY tile_id; + | {"type":"FeatureCollection","features":[{"type":"Feature","geometry":{"type":"Point","coordinates":[1.0,1.0]},"properties":{"name":"a"}},{"type":"Feature","geometry":{"type":"Point","coordinates":[2.0,2.0]},"properties":{"name":"b"}}]} + | {"type":"FeatureCollection","features":[{"type":"Feature","geometry":{"type":"Point","coordinates":[3.0,3.0]},"properties":{"name":"c"}},{"type":"Feature","geometry":{"type":"Point","coordinates":[4.0,4.0]},"properties":{"name":"d"}}]} + | """.stripMargin, + "", + "agg_funcs", + "1.0", + "", + "built-in" + ) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAgg.scala new file mode 100644 index 000000000..34313b3a7 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAgg.scala @@ -0,0 +1,169 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.expressions.geometry.base.AsTileExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.{PathUtils, SysUtils} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.gdal.ogr._ + +import java.nio.file.{Files, Paths} +import scala.collection.mutable + +case class ST_AsMVTTileAgg( + geometryExpr: Expression, + attributesExpr: Expression, + zxyIDExpr: Expression, + expressionConfig: MosaicExpressionConfig, + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int +) extends TypedImperativeAggregate[mutable.ArrayBuffer[Any]] + with TernaryLike[Expression] + with AsTileExpression { + + val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + override lazy val deterministic: Boolean = true + override val first: Expression = geometryExpr + override val second: Expression = attributesExpr + override val third: Expression = zxyIDExpr + override val nullable: Boolean = false + override val dataType: DataType = BinaryType + + override def prettyName: String = "st_asmvttile_agg" + + // The tiling scheme for the MVT: https://gdal.org/drivers/vector/mvt.html + private val tilingScheme3857 = "EPSG:3857,-20037508.343,20037508.343,40075016.686" + private val tilingScheme4326 = "EPSG:4326,-180,180,360" + + private lazy val tupleType = + StructType( + StructField("geom", geometryExpr.dataType, nullable = false) :: + StructField("attrs", attributesExpr.dataType, nullable = false) :: + StructField("zxyID", zxyIDExpr.dataType, nullable = false) :: + Nil + ) + private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = tupleType, containsNull = false))) + private lazy val row = new UnsafeRow(2) + + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + + def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = { + val geom = geometryExpr.eval(input) + val attrs = attributesExpr.eval(input) + val zxyID = zxyIDExpr.eval(input) + val value = InternalRow.fromSeq(Seq(geom, attrs, zxyID)) + buffer += InternalRow.copyValue(value) + buffer + } + + def merge(buffer: mutable.ArrayBuffer[Any], input: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = { + buffer ++= input + } + + override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { + ogr.RegisterAll() + // We assume all zxyIDs are the same for all the rows in the buffer + val zxyID = buffer.head.asInstanceOf[InternalRow].get(2, zxyIDExpr.dataType).toString + val zoom = zxyID.split("/")(0).toInt + val driver = ogr.GetDriverByName("MVT") + val tmpName = PathUtils.createTmpFilePath("mvt") + + val srs = getSRS(buffer.head, geometryExpr, geometryAPI) + val tilingScheme = srs.GetAttrValue("PROJCS", 0) match { + case "WGS 84 / Pseudo-Mercator" => tilingScheme3857 + case "WGS 84" => tilingScheme4326 + case _ => throw new Error(s"Unsupported SRS: ${srs.GetAttrValue("PROJCS", 0)}") + } + + val createOptions = new java.util.Vector[String]() + createOptions.add("NAME=mvttile") + createOptions.add("TYPE=baselayer") + createOptions.add(s"MINZOOM=$zoom") + createOptions.add(s"MAXZOOM=$zoom") + createOptions.add(s"TILING_SCHEME=$tilingScheme") + + val ds: DataSource = driver.CreateDataSource(tmpName, createOptions) + + val layer = createLayer(ds, srs, attributesExpr.dataType.asInstanceOf[StructType]) + + insertRows(buffer, layer, geometryExpr, geometryAPI, attributesExpr) + + ds.FlushCache() + ds.delete() + + val tiles = SysUtils + .runCommand(s"ls $tmpName") + ._1 + .split("\n") + .filterNot(_.endsWith(".json")) + .flatMap(z => + SysUtils + .runCommand(s"ls $tmpName/$z") + ._1 + .split("\n") + .flatMap(x => + SysUtils + .runCommand(s"ls $tmpName/$z/$x") + ._1 + .split("\n") + .map(y => s"$tmpName/$z/$x/$y") + ) + ) + + Files.readAllBytes(Paths.get(tiles.head)) + + } + + override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = { + val array = new GenericArrayData(obj.toArray) + projection.apply(InternalRow.apply(array)).getBytes + } + + override def deserialize(bytes: Array[Byte]): mutable.ArrayBuffer[Any] = { + val buffer = createAggregationBuffer() + row.pointTo(bytes, bytes.length) + row.getArray(0).foreach(tupleType, (_, x: Any) => buffer += x) + buffer + } + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): ST_AsMVTTileAgg = + copy(geometryExpr = newFirst, attributesExpr = newSecond, zxyIDExpr = newThird) + +} + +object ST_AsMVTTileAgg { + + def registryExpressionInfo(db: Option[String]): ExpressionInfo = + new ExpressionInfo( + classOf[ST_AsMVTTileAgg].getCanonicalName, + db.orNull, + "st_asmvttile_agg", + """ + | _FUNC_(geom, attrs) - Returns a Mapbox Vector Tile (MVT) as a binary. + """.stripMargin, + "", + """ + | Examples: + | > SELECT st_asmvttile_agg(geom, attrs) FROM table; + | 0x1A2B3C4D5E6F + | 0x1A2B3C4D5E6F + """.stripMargin, + "", + "agg_funcs", + "1.0", + "", + "built-in" + ) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/base/AsTileExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/base/AsTileExpression.scala new file mode 100644 index 000000000..052f10c3b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/base/AsTileExpression.scala @@ -0,0 +1,88 @@ +package com.databricks.labs.mosaic.expressions.geometry.base + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.gdal.ogr.{DataSource, FieldDefn, Geometry, Layer, Feature, ogrConstants} +import org.gdal.osr.SpatialReference + +import scala.collection.mutable + +trait AsTileExpression { + + def getSRS(firstRow: Any, geometryExpr: Expression, geometryAPI: GeometryAPI): SpatialReference = { + val firstGeomRaw = firstRow + .asInstanceOf[InternalRow] + .get(0, geometryExpr.dataType) + + val firstGeom = geometryAPI.geometry(firstGeomRaw, geometryExpr.dataType) + val srsOSR = firstGeom.getSpatialReferenceOSR + + val srs = new org.gdal.osr.SpatialReference() + if (srsOSR != null) { + srs.ImportFromWkt(srsOSR.ExportToWkt()) + } else { + srs.ImportFromEPSG(4326) + } + + srs + } + + def createLayer(ds: DataSource, srs: SpatialReference, schema: StructType): Layer = { + val layer = ds.CreateLayer("tiles", srs, ogrConstants.wkbUnknown) + for (field <- schema) { + val fieldDefn = field.dataType match { + case StringType => + val fieldDefn = new FieldDefn(field.name, ogrConstants.OFTString) + fieldDefn.SetWidth(255) + fieldDefn + case IntegerType => new FieldDefn(field.name, ogrConstants.OFTInteger) + case LongType => new FieldDefn(field.name, ogrConstants.OFTInteger64) + case FloatType => new FieldDefn(field.name, ogrConstants.OFTReal) + case DoubleType => new FieldDefn(field.name, ogrConstants.OFTReal) + case BooleanType => new FieldDefn(field.name, ogrConstants.OFTInteger) + case DateType => new FieldDefn(field.name, ogrConstants.OFTDate) + case TimestampType => new FieldDefn(field.name, ogrConstants.OFTDateTime) + case _ => throw new Error(s"Unsupported data type: ${field.dataType}") + } + layer.CreateField(fieldDefn) + } + layer + } + + def insertRows( + buffer: mutable.ArrayBuffer[Any], + layer: Layer, + geometryExpr: Expression, + geometryAPI: GeometryAPI, + attributesExpr: Expression + ): Unit = { + for (row <- buffer) { + val geom = row.asInstanceOf[InternalRow].get(0, geometryExpr.dataType) + val geomOgr = Geometry.CreateFromWkb(geometryAPI.geometry(geom, geometryExpr.dataType).toWKB) + val attrs = row.asInstanceOf[InternalRow].get(1, attributesExpr.dataType) + val feature = new Feature(layer.GetLayerDefn) + feature.SetGeometryDirectly(geomOgr) + var i = 0 + for (field <- attributesExpr.dataType.asInstanceOf[StructType]) { + val value = attrs.asInstanceOf[InternalRow].get(i, field.dataType) + field.dataType match { + case StringType => feature.SetField(field.name, value.asInstanceOf[UTF8String].toString) + case IntegerType => feature.SetField(field.name, value.asInstanceOf[Int]) + case LongType => feature.SetField(field.name, value.asInstanceOf[Long]) + case FloatType => feature.SetField(field.name, value.asInstanceOf[Float]) + case DoubleType => feature.SetField(field.name, value.asInstanceOf[Double]) + case BooleanType => feature.SetField(field.name, value.asInstanceOf[Boolean].toString) + case DateType => feature.SetField(field.name, value.asInstanceOf[java.sql.Date].toString) + case TimestampType => feature.SetField(field.name, value.asInstanceOf[java.sql.Timestamp].toString) + case _ => throw new Error(s"Unsupported data type: ${field.dataType}") + } + } + layer.CreateFeature(feature) + i += 1 + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 4f73e1b00..d6ef4c1cb 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{LongType, StringType} import org.apache.spark.sql.{Column, SparkSession} import scala.reflect.runtime.universe +import scala.util.Try //noinspection DuplicatedCode class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends Serializable with Logging { @@ -321,6 +322,16 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_WorldToRasterCoordY](expressionConfig) /** Aggregators */ + registry.registerFunction( + FunctionIdentifier("st_asgeojsontile_agg", database), + ST_AsGeojsonTileAgg.registryExpressionInfo(database), + (exprs: Seq[Expression]) => ST_AsGeojsonTileAgg(exprs(0), exprs(1), expressionConfig, 0, 0) + ) + registry.registerFunction( + FunctionIdentifier("st_asmvttile_agg", database), + ST_AsMVTTileAgg.registryExpressionInfo(database), + (exprs: Seq[Expression]) => ST_AsMVTTileAgg(exprs(0), exprs(1), exprs(2), expressionConfig, 0, 0) + ) registry.registerFunction( FunctionIdentifier("st_intersection_aggregate", database), ST_IntersectionAgg.registryExpressionInfo(database), @@ -666,8 +677,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig)) def rst_boundingbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig)) def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig)) - def rst_convolve(raster: Column, kernel: Column): Column = - ColumnAdapter(RST_Convolve(raster.expr, kernel.expr, expressionConfig)) + def rst_convolve(raster: Column, kernel: Column): Column = ColumnAdapter(RST_Convolve(raster.expr, kernel.expr, expressionConfig)) def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, expressionConfig)) def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = @@ -736,8 +746,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_ReTile(raster.expr, tileWidth.expr, tileHeight.expr, expressionConfig)) def rst_retile(raster: Column, tileWidth: Int, tileHeight: Int): Column = ColumnAdapter(RST_ReTile(raster.expr, lit(tileWidth).expr, lit(tileHeight).expr, expressionConfig)) - def rst_separatebands(raster: Column): Column = - ColumnAdapter(RST_SeparateBands(raster.expr, expressionConfig)) + def rst_separatebands(raster: Column): Column = ColumnAdapter(RST_SeparateBands(raster.expr, expressionConfig)) def rst_rotation(raster: Column): Column = ColumnAdapter(RST_Rotation(raster.expr, expressionConfig)) def rst_scalex(raster: Column): Column = ColumnAdapter(RST_ScaleX(raster.expr, expressionConfig)) def rst_scaley(raster: Column): Column = ColumnAdapter(RST_ScaleY(raster.expr, expressionConfig)) @@ -752,8 +761,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_summary(raster: Column): Column = ColumnAdapter(RST_Summary(raster.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Column): Column = ColumnAdapter(RST_Tessellate(raster.expr, resolution.expr, expressionConfig)) - def rst_transform(raster: Column, srid: Column): Column = - ColumnAdapter(RST_Transform(raster.expr, srid.expr, expressionConfig)) + def rst_transform(raster: Column, srid: Column): Column = ColumnAdapter(RST_Transform(raster.expr, srid.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Int): Column = ColumnAdapter(RST_Tessellate(raster.expr, lit(resolution).expr, expressionConfig)) def rst_fromcontent(raster: Column, driver: Column): Column = @@ -795,14 +803,21 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_WorldToRasterCoordY(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) /** Aggregators */ + + def st_asgeojsontile_agg(geom: Column, attributes: Column): Column = + ColumnAdapter(ST_AsGeojsonTileAgg(geom.expr, attributes.expr, expressionConfig, 0, 0).toAggregateExpression(isDistinct = false)) + def st_asmvttile_agg(geom: Column, attributes: Column, zxyID: Column): Column = + ColumnAdapter( + ST_AsMVTTileAgg(geom.expr, attributes.expr, zxyID.expr, expressionConfig, 0, 0).toAggregateExpression(isDistinct = false) + ) def st_intersects_agg(leftIndex: Column, rightIndex: Column): Column = ColumnAdapter( - ST_IntersectsAgg(leftIndex.expr, rightIndex.expr, geometryAPI.name).toAggregateExpression(isDistinct = false) + ST_IntersectsAgg(leftIndex.expr, rightIndex.expr, geometryAPI.name).toAggregateExpression(isDistinct = false) ) def st_intersection_agg(leftIndex: Column, rightIndex: Column): Column = ColumnAdapter( - ST_IntersectionAgg(leftIndex.expr, rightIndex.expr, geometryAPI.name, indexSystem, 0, 0) - .toAggregateExpression(isDistinct = false) + ST_IntersectionAgg(leftIndex.expr, rightIndex.expr, geometryAPI.name, indexSystem, 0, 0) + .toAggregateExpression(isDistinct = false) ) def st_union_agg(geom: Column): Column = ColumnAdapter(ST_UnionAgg(geom.expr, geometryAPI.name).toAggregateExpression(isDistinct = false)) @@ -1031,10 +1046,10 @@ object MosaicContext extends Logging { val mosaicVersion: String = "0.4.0" private var instance: Option[MosaicContext] = None - + def tmpDir(mosaicConfig: MosaicExpressionConfig): String = { if (_tmpDir == "" || mosaicConfig != null) { - val prefix = mosaicConfig.getTmpPrefix + val prefix = Try { mosaicConfig.getTmpPrefix }.toOption.getOrElse("") _tmpDir = FileUtils.createMosaicTempDir(prefix) _tmpDir } else { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggBehaviors.scala new file mode 100644 index 000000000..a77b90df4 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggBehaviors.scala @@ -0,0 +1,41 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.functions.MosaicContext +import com.databricks.labs.mosaic.test.{MosaicSpatialQueryTest, mocks} +import org.apache.spark.sql.functions._ +import org.gdal.ogr.ogr +import org.scalatest.matchers.must.Matchers.noException +import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} + +trait ST_AsGeoJSONTileAggBehaviors extends MosaicSpatialQueryTest { + + def behavior(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + import mc.functions._ + val sc = spark + import sc.implicits._ + mc.register(spark) + + val result = mocks + .getWKTRowsDf(mc.getIndexSystem) + .select(st_centroid($"wkt").alias("centroid")) + .withColumn("ids", array((0 until 30).map(_ => rand() * 1000): _*)) // add some random data + .select(explode($"ids").alias("id"), st_translate($"centroid", rand(), rand()).alias("centroid")) + .withColumn("index_id", grid_pointascellid($"centroid", lit(6))) + .groupBy("index_id") + .agg(st_asgeojsontile_agg($"centroid", struct($"id")).alias("geojson")) + .collect() + + val row = result.head + + val payload = row.getAs[String]("geojson") + + val ds = ogr.GetDriverByName("GeoJSON").Open(payload) + + ds.GetLayerCount should be(1L) + ds.GetLayer(0).GetFeatureCount should be > 0L + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggTest.scala new file mode 100644 index 000000000..7414a0ca4 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsGeoJSONTileAggTest.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.test.SharedSparkSession + +class ST_AsGeoJSONTileAggTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_AsGeoJSONTileAggBehaviors { + + testAllNoCodegen("Testing stAsGeoJSONTileAgg") { behavior } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggBehaviors.scala new file mode 100644 index 000000000..2dbb0fbcf --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggBehaviors.scala @@ -0,0 +1,56 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.index.{BNGIndexSystem, H3IndexSystem} +import com.databricks.labs.mosaic.functions.MosaicContext +import com.databricks.labs.mosaic.test.{MosaicSpatialQueryTest, mocks} +import com.databricks.labs.mosaic.utils.SysUtils +import org.apache.spark.sql.functions._ +import org.gdal.ogr.ogr +import org.scalatest.matchers.must.Matchers.noException +import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} + +import java.nio.file.{Files, Paths} + +trait ST_AsMVTTileAggBehaviors extends MosaicSpatialQueryTest { + + def behavior(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + import mc.functions._ + val sc = spark + import sc.implicits._ + mc.register(spark) + + val srcSR = mc.getIndexSystem match { + case H3IndexSystem => 4326 + case BNGIndexSystem => 27700 + case _ => 4326 + } + + val result = mocks + .getWKTRowsDf(mc.getIndexSystem) + .select(st_centroid($"wkt").alias("centroid")) + .withColumn("ids", array((0 until 30).map(_ => rand() * 1000): _*)) // add some random data + .select(explode($"ids").alias("id"), st_translate($"centroid", rand(), rand()).alias("centroid")) + .withColumn("index_id", grid_pointascellid($"centroid", lit(6))) + .withColumn("centroid", as_json(st_asgeojson($"centroid"))) + .withColumn("centroid", st_updatesrid($"centroid", lit(srcSR), lit(3857))) + .groupBy("index_id") + .agg(st_asmvttile_agg($"centroid", struct($"id"), lit("5/21/9")).alias("mvt")) + .collect() + + val row = result.head + + val payload = row.getAs[Array[Byte]]("mvt") + + + val tmpFile = Files.createTempFile(Paths.get("/tmp"), "mvt", ".pbf") + Files.write(tmpFile, payload) + + val ds = ogr.GetDriverByName("MVT").Open(tmpFile.toAbsolutePath.toString) + + ds.GetLayerCount should be(1L) + ds.GetLayer(0).GetFeatureCount should be > 0L + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggTest.scala new file mode 100644 index 000000000..bba873e29 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_AsMVTTileAggTest.scala @@ -0,0 +1,10 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.test.SharedSparkSession + +class ST_AsMVTTileAggTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_AsMVTTileAggBehaviors { + + testAllNoCodegen("Testing stAsMVTTileAgg") { behavior } + +}