If you’ve worked with Spark, you have probably written custom UDFs or UDAFs.
UDFs are ‘User Defined Functions’: they let you introduce complex logic into your queries/jobs, for instance to calculate a digest for a string, or to use a Java/Scala library in your queries.
UDAF stands for ‘User Defined Aggregate Function’; these work on aggregates, so you can implement functions that can be used in a GROUP BY clause, similar to AVG.
You may not be familiar with window functions, which are similar to aggregate functions but add a layer of complexity, since they are applied within a PARTITION BY clause. An example of a window function is RANK(). You can read more about window functions here.
When using Spark SQL, the built-in SQL functions sometimes cannot satisfy your needs, and you can define your own: UDFs or UDAFs (as the name suggests, User Defined Functions).
A UDF simply transforms a field inside the SQL, like the built-in trim function that strips leading and trailing whitespace from a string column. A UDAF, by contrast, is used in aggregation statements and must be combined with GROUP BY, like the built-in count and sum functions. There is, however, a third kind of custom function, the UDWF (user-defined window function), which far fewer people know about. If you are not familiar with window functions, refer to the previous article or the official introduction.
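To make the distinction concrete, here is a minimal Scala sketch showing a trivial UDF (a string digest) next to the built-in rank() window function applied over a PARTITION BY; the data and column names are invented for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("udf-vs-window").master("local[*]").getOrCreate()
import spark.implicits._

// A UDF works row by row: here, an MD5 digest of a string column
val digest = udf((s: String) =>
  java.security.MessageDigest.getInstance("MD5")
    .digest(s.getBytes("UTF-8")).map("%02x".format(_)).mkString)

val df = Seq(("user1", "page1", 10), ("user1", "page2", 20), ("user2", "page1", 5))
  .toDF("user", "page", "visits")

// A window function works over a PARTITION BY: rank pages by visits within each user
val w = Window.partitionBy($"user").orderBy($"visits".desc)
df.withColumn("page_digest", digest($"page"))
  .withColumn("rank", rank().over(w))
  .show()
```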
While aggregate functions work over a group, window functions work over a logical window of records and allow you to produce new columns from the combination of a record and one or more other records in the window.
Describing what window functions are is beyond the scope of this article, so for that refer to the previously mentioned article from Databricks; in particular, we are interested in the ‘previous event in time for a user’ in order to figure out sessions.
There is plenty of documentation on how to write UDFs and UDAFs, see for instance this link for UDFs or this link for UDAFs.
I was surprised to find out there’s not much info on how to build a custom window function, so I dug up the Spark source code and started looking at how window functions are implemented. That opened a whole new world to me: window functions, although conceptually similar to UDAFs, use a lower-level Spark API than UDAFs and are written using Catalyst expressions.
Window functions are a special class of SQL functions. Like aggregate functions, a window function takes multiple rows as input. The difference is that an aggregate function operates on the group produced by a GROUP BY clause, while a window function operates on a window.
How should we understand a window here? Let me explain how a window is defined.
In a window clause, partition by specifies the partitioning columns: rows in the same partition belong to the same window.
order by specifies how the rows within a window are ordered.
The windowing_clause specifies the frame, and Spark SQL supports a few frame types:
- The whole partition as one window: UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING (unbounded on both sides). In this case Spark SQL takes all rows as input and evaluates the function once.
- Growing frame: UNBOUNDED PRECEDING AND … (unbounded at the start). The current row keeps being added to the window and nothing is removed. Example: .rowsBetween(Long.MinValue, 0) defines a window from the lowest bound up to the current row; as the data is iterated, each current row is added to the window.
- Shrinking frame: … AND UNBOUNDED FOLLOWING (unbounded at the end). The opposite of a growing frame: as iteration proceeds, the current row is removed from the window.
- Moving frame: a sliding window. Example: .rowsBetween(-1, 1) defines a window from -1 (the row before the current row) to 1 (the row after the current row), so every sliding window holds exactly 3 rows.
- Offset frame: the window contains a single row, at a given offset from the current row. Example: lag(field, n) takes the value of the field n rows before the current row.
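For reference, here is a small sketch of how these frame types look with the DataFrame API, assuming a DataFrame df with user, ts and amount columns (names invented for illustration):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byUser = Window.partitionBy(col("user")).orderBy(col("ts"))

// Growing frame: from the start of the partition up to the current row
val growing   = byUser.rowsBetween(Long.MinValue, 0)
// Shrinking frame: from the current row to the end of the partition
val shrinking = byUser.rowsBetween(0, Long.MaxValue)
// Moving frame: previous row, current row, next row (3 rows per window)
val moving    = byUser.rowsBetween(-1, 1)

df.select(
  col("user"), col("ts"),
  sum(col("amount")).over(growing).as("running_total"),
  sum(col("amount")).over(shrinking).as("remaining_total"),
  avg(col("amount")).over(moving).as("moving_avg"),
  lag(col("amount"), 1).over(byUser).as("previous_amount")  // offset frame
).show()
```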
That is all we will cover about window functions here; if anything is unclear, consult the documentation to strengthen your understanding. When using Spark SQL you will find many tutorials teaching you to write custom UDFs and UDAFs, but none for UDWFs. Why is that? UDFs and UDAFs are exposed as high-level APIs and can be written in a few lines of Scala, whereas UDWFs are not exposed to users at all: they can only be defined with Catalyst expressions, the low-level expression language of the Catalyst framework, which is hard to get right without a deep dive into the source code. I have written a few UDWFs at work, but they are too complex to extract as a concise example, so here I translate an article to walk through the technique.
Use cases for window functions
Now, for what kind of problem do we need window functions in the first place?
A common problem when working on any kind of website is to determine ‘user sessions’, periods of user activity. If a user is inactive for a certain time T, then a new ‘session’ starts. Statistics over sessions are used, for instance, to determine if the user is a bot, to find out which pages have the most activity, etc.
Let’s say that we consider a session over if we don’t see any activity for one hour (sixty minutes). Let’s see an example of user activity, where ‘event’ has the name of the page the user visited and time is the time of the event. I simplified it, since the event would be a URL, while the time would be a full timestamp, and the session id would be generated as a random UUID, but I put simpler names/times just to illustrate the logic.
Let's take a concrete example of where window functions are useful. Among website metrics there is the notion of a user session. What is a user session? On the server side we manage user state with sessions, roughly as follows:
1) A server-side session is an object the server creates the first time a user visits the application; it represents one visit and can hold data. The server assigns every session a unique session ID, so that each user gets a distinct session object.
2) After creating the session, the server returns the session ID to the user's browser in a cookie. On the second and subsequent requests the browser sends the session ID back via the cookie, so the server can find the session object that belongs to that user.
3) A session usually has an expiration time, for example one hour. When it expires, the server destroys the old session and creates a new one for the user. But as long as the user sends a new request within the expiration window, the server typically extends the session's expiration by another hour from the time of that request.
In other words, if a user produces no events for more than an hour, the current session ends; any later events belong to a new session. We now want to count user sessions with Spark SQL, and this is a natural fit for a window function: deciding whether the current row starts a new session depends on the gap between the current row's timestamp and the previous row's. The table below illustrates the idea; it has a user column, an event column (the page the user visited) and a time column (the timestamp of the visit):
user | event | time | session |
---|---|---|---|
user1 | page1 | 10:12 | session1 (new session) |
user1 | page2 | 10:20 | session1 (same session, 8 minutes from last event) |
user1 | page1 | 11:13 | session1 (same session, 53 minutes from last event) |
user1 | page3 | 14:12 | session2 (new session, 3 hours after last event) |
Note that this is the activity for one user. We do have many users, and in fact partitioning by user is the job of the window function.
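The key ingredient is the ‘previous event in time for the same user’ mentioned earlier. With the built-in window functions you can already compute that gap with lag(); a minimal sketch, assuming a DataFrame df with user and ts (epoch milliseconds) columns:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byUser = Window.partitionBy(col("user")).orderBy(col("ts"))

// Time elapsed since the previous event of the same user (null for the first event)
val withGap = df.withColumn("prev_ts", lag(col("ts"), 1).over(byUser))
                .withColumn("gap_ms", col("ts") - col("prev_ts"))

// A gap larger than one hour (or no previous event at all) marks the start of a new session
val withFlag = withGap.withColumn(
  "new_session", col("prev_ts").isNull || col("gap_ms") > lit(60 * 60 * 1000L))
```

This gives you the gap, but assigning actual session IDs and carrying forward ones that already exist in the input is what the custom window function below takes care of.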
Diving deeper
It’s better to use an example to illustrate how the function works with respect to the window definition.
Let’s assume we have some very simple user activity data, with a user ID called user, a numeric timestamp ts, and a session ID session that may already be present. While we may start with no sessions at all, in most practical cases we process data hourly, so at hour N + 1 we want to continue the sessions we calculated at hour N.
Let’s create some test data and show what we want to achieve.
```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}

case class UserActivityData(user: String, ts: Long, session: String)

// a start timestamp and a helper constant (defined here for completeness; not shown in the original snippet)
val st = System.currentTimeMillis()
val one_minute = 60 * 1000

// our sample data
val d = Array[UserActivityData](
  UserActivityData("user1", st, "ss1"),
  UserActivityData("user2", st + 5 * one_minute, null),
  UserActivityData("user1", st + 10 * one_minute, null),
  UserActivityData("user1", st + 15 * one_minute, null),
  UserActivityData("user2", st + 15 * one_minute, null),
  UserActivityData("user1", st + 140 * one_minute, null),
  UserActivityData("user1", st + 160 * one_minute, null))

// creating the DataFrame
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(sc.parallelize(d))

// Window specification
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)

// create the session
val res = df.withColumn(
  "newsession",
  calculateSession(f.col("ts"), f.col("session")) over specs)
```
First, the window specification: sessions are created per user, and the ordering is of course by timestamp. Hence, we want to partitionBy user and orderBy timestamp.
We want to write a calculateSession function that will use the following logic:
```
IF (no previous event)
    create new session
ELSE IF (current event is past the session window)
    create new session
ELSE
    use current session
```
and will produce something like this:
user | ts | session | newsession |
---|---|---|---|
user1 | 1508863564166 | f237e656-1e.. | f237e656-1e.. |
user1 | 1508864164166 | null | f237e656-1e.. |
user1 | 1508864464166 | null | f237e656-1e5.. |
user1 | 1508871964166 | null | 51c05c35-6f.. |
user1 | 1508873164166 | null | 51c05c35-6f.. |
user2 | 1508863864166 | null | 2c16b61a-6c.. |
user2 | 1508864464166 | null | 2c16b61a-6c.. |
Note that we are using random UUIDs as it’s pretty much the standard, and we’re shortening them for typographical reasons.
As you see, for each user, it will create a new session whenever the difference between two events is bigger than the session threshold.
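To make the required state explicit, here is a rough plain-Scala sketch of the same decision, applied to one user's events sorted by timestamp (the names and the session-window constant are illustrative, not the Spark code we are about to write):

```scala
import java.util.UUID

val sessionWindowMs = 60 * 60 * 1000L  // one hour

// events for a single user, sorted by timestamp: (ts, optional pre-assigned session id)
def assignSessions(events: Seq[(Long, Option[String])]): Seq[String] =
  events.foldLeft((List.empty[String], 0L)) { case ((ids, prevTs), (ts, existing)) =>
    val id = existing.getOrElse {
      if (ids.isEmpty || ts - prevTs > sessionWindowMs)
        UUID.randomUUID().toString   // no previous event, or gap too large: new session
      else
        ids.head                     // same session as the previous event
    }
    (id :: ids, ts)                  // remember the session id and the current timestamp
  }._1.reverse
```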
Internally, for every record, we want to keep track of:
- The current session ID
- The timestamp of the previous event in the session
This is going to be the state that we must maintain. Spark takes care of initializing it for us.
It is also going to be the parameters the function expects.
We use a random UUID as the session ID: whenever the gap between the current row's timestamp and the previous one exceeds the session window (one hour by default), a new session ID becomes the column value; otherwise the existing session ID is reused. Let's see the skeleton of the function:
```scala
import java.util.UUID
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

object MyUDWF {
  // longer than this, and it's a new session
  val defaultMaxSessionLengthms = 3600 * 1000

  case class SessionUDWF(timestamp: Expression, session: Expression,
                         sessionWindow: Expression = Literal(defaultMaxSessionLengthms))
    extends AggregateWindowFunction {
    self: Product =>

    override def children: Seq[Expression] = Seq(timestamp, session)
    override def dataType: DataType = StringType

    protected val zero = Literal(0L)
    protected val nullString = Literal(null: String)

    protected val currentSession =
      AttributeReference("currentSession", StringType, nullable = true)()
    protected val previousTs =
      AttributeReference("previousTs", LongType, nullable = false)()

    override val aggBufferAttributes: Seq[AttributeReference] = currentSession :: previousTs :: Nil

    override val initialValues: Seq[Expression] = nullString :: zero :: Nil

    override def prettyName: String = "makeSession"

    // we have to write these ones
    override val updateExpressions: Seq[Expression] = ...
    override val evaluateExpression: Expression = ...
  }
}
```
A few notes here:
- Our ‘state’ is going to be a Seq[AttributeReference].
- Each AttributeReference must be declared with its type. As we said, we keep the current session ID and the timestamp of the previous event.
- We initialize it by overriding initialValues.
- For every record within the window, Spark will first call updateExpressions, then produce the values by calling evaluateExpression.
Now it's time to implement the updateExpressions and evaluateExpression functions.
For every row it iterates over within the window, Spark SQL calls updateExpressions. If the gap between the current row's timestamp and the previous row's is within one hour, the old session ID kept in aggBufferAttributes(0) is reused; if the gap is larger, createNewSession is wrapped as a ScalaUDF sub-expression to generate a new session ID. In either case, the current row's timestamp is stored as the latest activity timestamp.
```scala
// (still inside SessionUDWF)
// this is invoked whenever we need to create a new session ID. You can use your own logic, here we create UUIDs
protected val createNewSession = () => org.apache.spark.unsafe.types.
  UTF8String.fromString(UUID.randomUUID().toString)

// initialize with no session, zero previous timestamp
override val initialValues: Seq[Expression] = nullString :: zero :: Nil

// if a session is already assigned, keep it, otherwise, assign one
override val updateExpressions: Seq[Expression] =
  If(IsNotNull(session), session, assignSession) ::
  timestamp ::
  Nil

// assign session: if previous timestamp was longer than interval,
// new session, otherwise, keep current.
protected val assignSession =
  If(LessThanOrEqual(Subtract(timestamp, aggBufferAttributes(1)), sessionWindow),
    aggBufferAttributes(0),
    ScalaUDF(createNewSession, StringType, children = Nil))

// just return the current session in the buffer
override val evaluateExpression: Expression = aggBufferAttributes(0)
```
Notice how we use Catalyst expressions here, while in normal UDAFs we just write plain Scala code.
Last thing, we need to declare static methods that we can invoke from the query to instantiate the function. Notice how I created two: one that allows the user to specify the max duration of a session, and one that takes the default:
```scala
def calculateSession(ts: Column, sess: Column): Column = withExpr {
  SessionUDWF(ts.expr, sess.expr, Literal(defaultMaxSessionLengthms))
}

def calculateSession(ts: Column, sess: Column, sessionWindow: Column): Column = withExpr {
  SessionUDWF(ts.expr, sess.expr, sessionWindow.expr)
}
```
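The withExpr helper used above is part of the piping omitted later in the article; a minimal sketch, assuming it simply wraps a Catalyst Expression back into a user-facing Column:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression

// assumed helper, living in the same object as calculateSession
private def withExpr(expr: Expression): Column = new Column(expr)
```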
Now creating session IDs is as easy as:
```scala
// Window specification
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)

// create the session, with a 10-second session window (duration is in ms)
val res = df.withColumn(
  "newsession",
  calculateSession(f.col("ts"), f.col("session"), f.lit(10 * 1000)) over specs)
```
Notice that here we specified 10 second sessions.
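From here, session-level statistics are ordinary aggregations; for example, a quick way to count distinct sessions per user on the resulting DataFrame (column names as above):

```scala
import org.apache.spark.sql.functions._

res.groupBy(col("user"))
   .agg(countDistinct(col("newsession")).as("session_count"))
   .show()
```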
There’s a little more piping involved, which was omitted for clarity, but you can find the complete code, including unit tests, in my GitHub project.