UserDefinedAggregateFunction — Contract for User-Defined Untyped Aggregate Functions (UDAFs)
UserDefinedAggregateFunction is the contract to define user-defined aggregate functions (UDAFs).
```scala
// Custom UDAF to count rows

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructType}

class MyCountUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = {
    new StructType().add("id", LongType, nullable = true)
  }

  override def bufferSchema: StructType = {
    new StructType().add("count", LongType, nullable = true)
  }

  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    println(s">>> initialize (buffer: $buffer)")
    // NOTE: Scala's update used under the covers
    buffer(0) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println(s">>> update (buffer: $buffer -> input: $input)")
    buffer(0) = buffer.getLong(0) + 1
  }

  override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
    println(s">>> merge (buffer: $buffer -> row: $row)")
    buffer(0) = buffer.getLong(0) + row.getLong(0)
  }

  override def evaluate(buffer: Row): Any = {
    println(s">>> evaluate (buffer: $buffer)")
    buffer.getLong(0)
  }
}
```
```scala
val dataset = spark.range(start = 0, end = 4, step = 1, numPartitions = 2)

// Use the UDAF
val mycount = new MyCountUDAF
val q = dataset.
  withColumn("group", 'id % 2).
  groupBy('group).
  agg(mycount.distinct('id) as "count")

scala> q.show
+-----+-----+
|group|count|
+-----+-----+
|    0|    2|
|    1|    2|
+-----+-----+
```
The lifecycle of a UserDefinedAggregateFunction is entirely managed using the ScalaUDAF expression container.
Figure 1. UserDefinedAggregateFunction and ScalaUDAF Expression Container
Note: Use UDFRegistration to register a (temporary) UserDefinedAggregateFunction and use it in SQL queries.
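As a minimal sketch of such a registration (not part of the original example), the MyCountUDAF class above could be registered and used from SQL as follows; the function name mycount and the temporary view name ids are assumptions made up for illustration:

```scala
// Hypothetical registration sketch: the names "mycount" and "ids" are made up
spark.udf.register("mycount", new MyCountUDAF)

spark.range(start = 0, end = 4, step = 1, numPartitions = 2)
  .createOrReplaceTempView("ids")

// The UDAF is now available by name in SQL queries
spark.sql("SELECT id % 2 AS grp, mycount(id) AS cnt FROM ids GROUP BY id % 2").show
```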
UserDefinedAggregateFunction Contract
```scala
package org.apache.spark.sql.expressions

abstract class UserDefinedAggregateFunction {
  // only required methods that have no implementation
  def bufferSchema: StructType
  def dataType: DataType
  def deterministic: Boolean
  def evaluate(buffer: Row): Any
  def initialize(buffer: MutableAggregationBuffer): Unit
  def inputSchema: StructType
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
}
```
| Method | Description |
|---|---|
| bufferSchema | Schema of the aggregation buffer (a StructType) |
| dataType | DataType of the final result |
| deterministic | Flag that says whether the function always returns the same result for the same input |
| evaluate | Computes the final result from the aggregation buffer |
| initialize | Initializes the aggregation buffer (before any input is processed) |
| inputSchema | Schema of the input rows (a StructType) |
| merge | Merges two aggregation buffers (e.g. computed for different partitions) |
| update | Updates the aggregation buffer with a new input row |

The hypothetical sketch below shows these methods cooperating over a buffer with more than one field.
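To complement MyCountUDAF above, here is a hypothetical MyAvgUDAF sketch (not part of the original text) with a two-slot buffer (sum and count), so initialize, update and merge have to maintain all buffer slots consistently:

```scala
// Hypothetical MyAvgUDAF: a sketch of a UDAF with a multi-field aggregation buffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class MyAvgUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    new StructType().add("value", DoubleType, nullable = true)

  // Two buffer slots: a running sum and a running count
  override def bufferSchema: StructType =
    new StructType().add("sum", DoubleType).add("count", LongType)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0 // sum
    buffer(1) = 0L  // count
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }

  // Both slots have to be merged, e.g. across partial buffers from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  override def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0) null else buffer.getDouble(0) / buffer.getLong(1)
}
```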
Creating Column for UDAF — apply Method
```scala
apply(exprs: Column*): Column
```
apply creates a Column with a ScalaUDAF expression (inside an AggregateExpression).
Note: AggregateExpression uses Complete mode and the isDistinct flag is disabled.
```scala
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
val myUDAF: UserDefinedAggregateFunction = ...
val myUdafCol = myUDAF.apply($"id", $"name")

scala> myUdafCol.explain(extended = true)
mycountudaf('id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)

scala> println(myUdafCol.expr.numberedTreeString)
00 mycountudaf('id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
01 +- MyCountUDAF('id,'name)
02    :- 'id
03    +- 'name

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
myUdafCol.expr.asInstanceOf[AggregateExpression]

import org.apache.spark.sql.execution.aggregate.ScalaUDAF
val scalaUdaf = myUdafCol.expr.children.head.asInstanceOf[ScalaUDAF]

scala> println(scalaUdaf.toString)
MyCountUDAF('id,'name)
```
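As a quick sanity check, a sketch like the following (reusing myUdafCol from the session above) would confirm the note's claims about Complete mode and the disabled isDistinct flag:

```scala
// Sketch: verify the AggregateExpression created by apply
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}

val aggExpr = myUdafCol.expr.asInstanceOf[AggregateExpression]
assert(aggExpr.mode == Complete) // Complete mode, per the note above
assert(!aggExpr.isDistinct)      // isDistinct flag is disabled
```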
Creating Column for UDAF with Distinct Values — distinct Method
```scala
distinct(exprs: Column*): Column
```
distinct creates a Column with a ScalaUDAF expression (inside an AggregateExpression).
Note: AggregateExpression uses Complete mode and the isDistinct flag is enabled.

Note: distinct is like apply but with the isDistinct flag enabled.
```scala
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
val myUDAF: UserDefinedAggregateFunction = ...

scala> val myUdafCol = myUDAF.distinct($"id", $"name")
myUdafCol: org.apache.spark.sql.Column = mycountudaf(DISTINCT id, name)

scala> myUdafCol.explain(extended = true)
mycountudaf(distinct 'id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
val aggExpr = myUdafCol.expr

scala> println(aggExpr.numberedTreeString)
00 mycountudaf(distinct 'id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
01 +- MyCountUDAF('id,'name)
02    :- 'id
03    +- 'name

scala> aggExpr.asInstanceOf[AggregateExpression].isDistinct
res0: Boolean = true
```
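The session above confirms the isDistinct flag; a similar sketch (reusing aggExpr from the session) would confirm the Complete mode claimed in the note:

```scala
// Sketch: AggregateExpression was already imported in the session above
import org.apache.spark.sql.catalyst.expressions.aggregate.Complete

assert(aggExpr.asInstanceOf[AggregateExpression].mode == Complete)
```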