Spark SQL Window Function With Complex Condition

Spark SQL window function with complex condition

Spark >= 3.2

Recent Spark releases provide native support for session windows in both batch and structured streaming queries (see SPARK-10816 and its sub-tasks, especially SPARK-34893).

The official documentation provides a nice usage example.
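
For example, here is a minimal PySpark sketch of a session window over the login data used later in this answer (the 5-day gap mirrors the threshold used in the pre-3.2 solution below; the exact boundary semantics of session_window may differ slightly from the manual approach, so treat this as an illustration of the API rather than a drop-in replacement):

from pyspark.sql import functions as F

sessions = (df
    .withColumn("login_ts", F.to_timestamp("login_date"))
    .groupBy("user_name", F.session_window("login_ts", "5 days"))
    .agg(F.min("login_date").alias("became_active")))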

Spark < 3.2

Here is the trick. Import a bunch of functions:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, datediff, lag, lit, min, sum}

Define windows:

val userWindow = Window.partitionBy("user_name").orderBy("login_date")
val userSessionWindow = Window.partitionBy("user_name", "session")

Find the points where new sessions start:

val newSession = (coalesce(
  datediff($"login_date", lag($"login_date", 1).over(userWindow)),
  lit(0)
) > 5).cast("bigint")

val sessionized = df.withColumn("session", sum(newSession).over(userWindow))

Find the earliest date per session:

val result = sessionized
  .withColumn("became_active", min($"login_date").over(userSessionWindow))
  .drop("session")

With the dataset defined as:

val df = Seq(
  ("SirChillingtonIV", "2012-01-04"), ("Booooooo99900098", "2012-01-04"),
  ("Booooooo99900098", "2012-01-06"), ("OprahWinfreyJr", "2012-01-10"),
  ("SirChillingtonIV", "2012-01-11"), ("SirChillingtonIV", "2012-01-14"),
  ("SirChillingtonIV", "2012-08-11")
).toDF("user_name", "login_date")

The result is:

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04| <- The first session for user
|SirChillingtonIV|2012-01-11|   2012-01-11| <- The second session for user
|SirChillingtonIV|2012-01-14|   2012-01-11|
|SirChillingtonIV|2012-08-11|   2012-08-11| <- The third session for user
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+
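
For completeness, a rough PySpark equivalent of the Scala code above (same sessionization logic and window definitions; a sketch rather than a drop-in solution):

from pyspark.sql import functions as F
from pyspark.sql import Window

user_window = Window.partitionBy("user_name").orderBy("login_date")
user_session_window = Window.partitionBy("user_name", "session")

# a new session starts when the gap to the previous login is more than 5 days
new_session = (F.coalesce(
    F.datediff("login_date", F.lag("login_date", 1).over(user_window)),
    F.lit(0)
) > 5).cast("bigint")

result = (df
    .withColumn("session", F.sum(new_session).over(user_window))
    .withColumn("became_active", F.min("login_date").over(user_session_window))
    .drop("session"))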

Conditions in Spark window function

You can compute the minimum per group twice: once for rows with r = z, and once for all rows in the group. The first non-null value of the two can then be compared to e:

from pyspark.sql import functions as F
from pyspark.sql import Window

df = ...

w = Window.partitionBy("q")
# When ordering is not defined, an unbounded window frame is used by default.

df.withColumn("min_e_with_r_eq_z", F.expr("min(case when r='z' then e else null end)").over(w)) \
  .withColumn("min_e_overall", F.min("e").over(w)) \
  .withColumn("t", F.coalesce("min_e_with_r_eq_z", "min_e_overall") == F.col("e")) \
  .orderBy("w") \
  .show()

Output:

+---+---+---+---+-----------------+-------------+-----+
|  q|  w|  e|  r|min_e_with_r_eq_z|min_e_overall|    t|
+---+---+---+---+-----------------+-------------+-----+
|  a|  1| 20|  y|               22|           20|false|
|  a|  2| 22|  z|               22|           20| true|
|  b|  3| 10|  y|             null|           10| true|
|  b|  4| 12|  y|             null|           10|false|
+---+---+---+---+-----------------+-------------+-----+

Note: I assume that q is the grouping column for the window.
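
Side note: the same conditional minimum can also be written with the Column API instead of a SQL expression string (with F and w as defined above); a when without an otherwise returns null, which min simply ignores:

df.withColumn(
    "min_e_with_r_eq_z",
    F.min(F.when(F.col("r") == "z", F.col("e"))).over(w)
)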

OOM using Spark window function with 30 days interval

Filtering

It's always good to remove data that isn't needed. You said you only need the last 60 days, so you can filter out everything else.

This line keeps only rows whose date is within the last 60 days (up to today):

df = df.filter(F.to_date('txn_date', 'yyyyMMdd').between(F.current_date()-61, F.current_date()))

I won't apply it right away, in order to illustrate the other issues first.

Window

First, a simple thing: if the column is already in long format, you don't need to cast it to long again, so we can remove .cast(LongType()).

The other, bigger issue is that your window's lower bound is wrong. To see it, let's add one more row to the input:

[19990101, 'B', 9999999, "xxxxxxx"],

This row has a date from the year 1999. After adding it and running the code, we get this:

# +--------+--------+----------+----------------+------------------+
# |txn_date|txn_type|txn_amount|other_attributes|stddev_last_30days|
# +--------+--------+----------+----------------+------------------+
# |20210101|       A|       103|             abc|               1.0|
# |20210101|       A|       102|             def|               1.0|
# |20210101|       A|       101|             def|               1.0|
# |20210102|       A|        34|             ghu|34.009802508492555|
# |19990101|       B|   9999999|         xxxxxxx|              null|
# |20210101|       B|       180|             xyz|  7070939.82553808|
# |20210102|       B|       123|             kqt|  5773414.64605055|
# +--------+--------+----------+----------------+------------------+

You can see that the stddev for the 2021 rows was also affected, so the 30-day window does not work; your window actually takes all the data it can. We can check what the lower bound is for the date 20210101:

print(20210101-days(30))  # Returns 17618101 - I doubt you wanted this date as lower bound

This was probably your biggest problem. Never try to outsmart dates and times; always use functions specialized for dates and times.

You can use this window:

days = lambda i: i * 86400
w = Window.partitionBy('txn_type').orderBy(F.unix_timestamp(F.col('txn_date').cast('string'), 'yyyyMMdd')).rangeBetween(-days(30), 0)
df = df.withColumn('stddev_last_30days', F.stddev('txn_amount').over(w))

df.show()
# +--------+--------+----------+----------------+------------------+
# |txn_date|txn_type|txn_amount|other_attributes|stddev_last_30days|
# +--------+--------+----------+----------------+------------------+
# |20210101|       A|       103|             abc|               1.0|
# |20210101|       A|       102|             def|               1.0|
# |20210101|       A|       101|             def|               1.0|
# |20210102|       A|        34|             ghu|34.009802508492555|
# |19990101|       B|   9999999|         xxxxxxx|              null|
# |20210101|       B|       180|             xyz|              null|
# |20210102|       B|       123|             kqt| 40.30508652763321|
# +--------+--------+----------+----------------+------------------+

unix_timestamp transforms your 'yyyyMMdd' strings into a proper long value (Unix time in seconds). From that, you can now subtract seconds (30 days' worth of seconds).
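
Putting the pieces together, here is a sketch that combines the 60-day filter from the beginning of this answer with the corrected 30-day range window (column names as in the question; the extra cast to string is an assumption, added just to make the to_date parsing explicit):

from pyspark.sql import functions as F
from pyspark.sql import Window

days = lambda i: i * 86400

w = (Window.partitionBy('txn_type')
     .orderBy(F.unix_timestamp(F.col('txn_date').cast('string'), 'yyyyMMdd'))
     .rangeBetween(-days(30), 0))

df = (df
      .filter(F.to_date(F.col('txn_date').cast('string'), 'yyyyMMdd')
              .between(F.current_date() - 61, F.current_date()))
      .withColumn('stddev_last_30days', F.stddev('txn_amount').over(w)))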

Window function based on a condition

The following should do the trick (but I'm sure it can be further optimized).

Setup:

# imports used by the snippets in this answer (note: sum/max shadow the Python builtins)
from pyspark.sql.functions import col, max, sum, to_date, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql.window import Window

data_1 = [
    ("2022-01-08", 2, 0),
    ("2022-01-09", 4, 1),
    ("2022-01-10", 6, 1),
    ("2022-01-11", 8, 0),
    ("2022-01-12", 2, 1),
    ("2022-01-13", 5, 1),
    ("2022-01-14", 7, 0),
    ("2022-01-15", 9, 0),
    ("2022-01-16", 9, 0),
    ("2022-01-17", 9, 0)
]

schema_1 = StructType([
    StructField("Date", StringType(), True),
    StructField("Val", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
])

df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
df_1 = df_1.withColumn('Date', to_date("Date", "yyyy-MM-dd"))

+----------+---+----+
|      Date|Val|Cond|
+----------+---+----+
|2022-01-08|  2|   0|
|2022-01-09|  4|   1|
|2022-01-10|  6|   1|
|2022-01-11|  8|   0|
|2022-01-12|  2|   1|
|2022-01-13|  5|   1|
|2022-01-14|  7|   0|
|2022-01-15|  9|   0|
|2022-01-16|  9|   0|
|2022-01-17|  9|   0|
+----------+---+----+

Create a new DF only with Cond==1 rows to obtain the sum of two consecutive rows with that condition:

windowSpec = Window.partitionBy("Cond").orderBy("Date")
df_2 = df_1.where(df_1.Cond == 1).withColumn(
    "Sum",
    sum("Val").over(windowSpec.rowsBetween(-1, 0))
).withColumn('date_1', col('date')).drop('date')

+---+----+---+----------+
|Val|Cond|Sum|    date_1|
+---+----+---+----------+
|  4|   1|  4|2022-01-09|
|  6|   1| 10|2022-01-10|
|  2|   1|  8|2022-01-12|
|  5|   1|  7|2022-01-13|
+---+----+---+----------+

Do a left join to get the sum into the original data frame, and set the sum to zero for the rows with Cond==0:

df_3 = df_1.join(df_2.select('sum', col('date_1')), df_1.Date == df_2.date_1, "left").drop('date_1').fillna(0)

+----------+---+----+---+
|      Date|Val|Cond|sum|
+----------+---+----+---+
|2022-01-08|  2|   0|  0|
|2022-01-09|  4|   1|  4|
|2022-01-10|  6|   1| 10|
|2022-01-11|  8|   0|  0|
|2022-01-12|  2|   1|  8|
|2022-01-13|  5|   1|  7|
|2022-01-14|  7|   0|  0|
|2022-01-15|  9|   0|  0|
|2022-01-16|  9|   0|  0|
|2022-01-17|  9|   0|  0|
+----------+---+----+---+

Do a cumulative sum on the condition column:

df_3 = df_3.withColumn('cond_sum', sum('cond').over(Window.orderBy('Date')))

+----------+---+----+---+--------+
|      Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08|  2|   0|  0|       0|
|2022-01-09|  4|   1|  4|       1|
|2022-01-10|  6|   1| 10|       2|
|2022-01-11|  8|   0|  0|       2|
|2022-01-12|  2|   1|  8|       3|
|2022-01-13|  5|   1|  7|       4|
|2022-01-14|  7|   0|  0|       4|
|2022-01-15|  9|   0|  0|       4|
|2022-01-16|  9|   0|  0|       4|
|2022-01-17|  9|   0|  0|       4|
+----------+---+----+---+--------+

Finally, for each partition where the cond_sum is greater than 1, use the max sum for that partition:

df_3.withColumn('sum', when(df_3.cond_sum > 1, max('sum').over(Window.partitionBy('cond_sum'))).otherwise(0)).show()

+----------+---+----+---+--------+
|      Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08|  2|   0|  0|       0|
|2022-01-09|  4|   1|  0|       1|
|2022-01-10|  6|   1| 10|       2|
|2022-01-11|  8|   0| 10|       2|
|2022-01-12|  2|   1|  8|       3|
|2022-01-13|  5|   1|  7|       4|
|2022-01-14|  7|   0|  7|       4|
|2022-01-15|  9|   0|  7|       4|
|2022-01-16|  9|   0|  7|       4|
|2022-01-17|  9|   0|  7|       4|
+----------+---+----+---+--------+

Spark window function with condition on current row

Very good question! A couple of remarks: using a row-based frame (rowsBetween) creates a fixed frame that is based on the number of rows in it and not on their values, so it will be problematic in a couple of cases:

  1. the customer does not have orders every single day, so a 365-row window might contain rows with an order_date well before one year ago
  2. if the customer has more than one order per day, it will mess with the one-year coverage
  3. a combination of 1 and 2

Also, rangeBetween does not work with Date and Timestamp data types.

To solve this, it is possible to use a window function that collects lists, together with a UDF:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val df = spark.sparkContext.parallelize(Seq(
    (1, "2017-01-01", "2017-01-10", "A")
  , (2, "2017-02-01", "2017-02-10", "A")
  , (3, "2017-02-02", "2017-02-20", "A")
)).toDF("order_id", "order_date", "payment_date", "customer_id")
  .withColumn("order_date_ts", to_timestamp($"order_date", "yyyy-MM-dd").cast("long"))
  .withColumn("payment_date_ts", to_timestamp($"payment_date", "yyyy-MM-dd").cast("long"))

// df.printSchema()
// df.show(false)

val window = Window.partitionBy("customer_id").orderBy("order_date_ts").rangeBetween(Window.unboundedPreceding, -1)

val count_filtered_dates = udf((days: Int, top: Long, array: Seq[Long]) => {
  // these timestamps are in seconds, so compute the cutoff `days` days before the current order date
  val bottom = top - (days * 60 * 60 * 24).toLong
  array.count(v => v >= bottom && v < top)
})

val res = df.withColumn("paid_orders", collect_list("payment_date_ts") over window)
  .withColumn("paid_order_count", count_filtered_dates(lit(365), $"order_date_ts", $"paid_orders"))

res.show(false)

Output:

+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+
|order_id|order_date|payment_date|customer_id|order_date_ts|payment_date_ts|paid_orders             |paid_order_count|
+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+
|1       |2017-01-01|2017-01-10  |A          |1483228800   |1484006400     |[]                      |0               |
|2       |2017-02-01|2017-02-10  |A          |1485907200   |1486684800     |[1484006400]            |1               |
|3       |2017-02-02|2017-02-20  |A          |1485993600   |1487548800     |[1484006400, 1486684800]|1               |
+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+

Converting dates to Spark timestamps in seconds makes the lists more memory efficient.

This is the easiest code to implement, but not the most optimal, as the lists will take up some memory; a custom UDAF would be best, but requires more coding (I might do it later). It will still work if you have thousands of orders per customer.
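
For reference, a rough PySpark sketch of the same approach (collect_list over the window plus a Python UDF), assuming a DataFrame with the same order_date_ts and payment_date_ts long columns as in the Scala code above:

from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql.types import IntegerType

w = (Window.partitionBy("customer_id")
     .orderBy("order_date_ts")
     .rangeBetween(Window.unboundedPreceding, -1))

@F.udf(returnType=IntegerType())
def count_filtered_dates(days, top, arr):
    # the timestamps are in seconds, so the lower bound is `days` days before the current order date
    bottom = top - days * 60 * 60 * 24
    return sum(1 for v in (arr or []) if bottom <= v < top)

res = (df
    .withColumn("paid_orders", F.collect_list("payment_date_ts").over(w))
    .withColumn("paid_order_count",
                count_filtered_dates(F.lit(365), F.col("order_date_ts"), F.col("paid_orders"))))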

Spark first Window function is taking much longer than last

The solution that doesn't answer the question

In trying various things to speed up my routine, it occurred to me to rewrite my usages of first() as usages of last() with a reversed sort order.

So rewriting this:

win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
            .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))

df_part2 = (df_part1
            .withColumn('next_rank', F.first(F.col('rank'), ignorenulls=True).over(win_next))
            .withColumn('next_sf', F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
            )

As this:

win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
            .orderBy(F.desc('rank')).rowsBetween(Window.unboundedPreceding, 0))

df_part2 = (df_part1
            .withColumn('next_rank', F.last(F.col('rank'), ignorenulls=True).over(win_next))
            .withColumn('next_sf', F.last(F.col('scale_factor'), ignorenulls=True).over(win_next))
            )

Much to my amazement, this actually solved the performance problem, and now the entire dataframe is generated in just 3 seconds. I'm pleased, but still vexed.

As I somewhat predicted, the query plan now includes a new SORT step before creating these next two columns, and they've changed from Window to RunningWindowFunction, like the first two. Here's the new plan (without the code broken up into 3 separate cached parts anymore, because that was just to troubleshoot performance):
[image: screenshot of the new query plan]

As for the question:

Why do my calls to first() over Window.unboundedFollowing take so much longer than last() over Window.unboundedPreceding?

I'm hoping someone can still answer this, for academic reasons.

Understanding an example of window function

Note that count only counts non-null items, and that the grouping is only defined by the partitionBy clause, but not the orderBy clause.

When you specify an ordering column, the default window range is (according to the docs)

(rangeFrame, unboundedPreceding, currentRow)

So your window definition is actually

w = (Window().partitionBy("I_id", "p_id")
     .orderBy(F.col("xyz"))
     .rangeBetween(Window.unboundedPreceding, Window.currentRow)
    )

And so the window only includes the rows from xyz = -infinity up to the value of xyz in the current row. That's why the first row has a count of zero: it counts non-null items from xyz = -infinity to xyz = null, i.e. over the first two rows of the dataframe.

For the row where xyz = 2, the count includes non-null items from xyz = -infinity to xyz = 2, i.e. the first four rows. That's why you got a count of 2, because the non-null items are 1 and 2.
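
A quick way to convince yourself of the equivalence is to compute the count over both window definitions side by side; the two result columns come out identical, because the default frame with an orderBy is exactly rangeBetween(unboundedPreceding, currentRow). A sketch, assuming a df with the columns above:

from pyspark.sql import functions as F
from pyspark.sql import Window

w_default = Window.partitionBy("I_id", "p_id").orderBy(F.col("xyz"))
w_explicit = (Window.partitionBy("I_id", "p_id")
              .orderBy(F.col("xyz"))
              .rangeBetween(Window.unboundedPreceding, Window.currentRow))

df.select(
    "*",
    F.count("xyz").over(w_default).alias("count_default"),
    F.count("xyz").over(w_explicit).alias("count_explicit"),
).show()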

Spark Window function - Get all records in a partition in each row, with order maintained

Add a frame specification to the window, as shown in capital letters below. This ensures that all rows in the partition are included.

spark.sql("""
  select *,
    collect_list( concat(sender, ' : ', message) ) over (
      partition by sessionID
      order by timestamp asc
      ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    ) as conversation
  from messages
""")

