I have a Spark DataFrame that looks something like this:
| date | ID | window_size | qty |
|---|---|---|---|
| 01/01/2020 | 1 | 2 | 1 |
| 02/01/2020 | 1 | 2 | 2 |
| 03/01/2020 | 1 | 2 | 3 |
| 04/01/2020 | 1 | 2 | 4 |
| 01/01/2020 | 2 | 3 | 1 |
| 02/01/2020 | 2 | 3 | 2 |
| 03/01/2020 | 2 | 3 | 3 |
| 04/01/2020 | 2 | 3 | 4 |
I'm trying to apply a rolling window of size window_size to each ID in the DataFrame and get the rolling sum. Basically, I'm calculating a rolling sum (the equivalent of df.groupby('ID')['qty'].rolling(window=n).sum() in pandas) where the window size n can change per group.
Expected output
| date | ID | window_size | qty | rolling_sum |
|---|---|---|---|---|
| 01/01/2020 | 1 | 2 | 1 | null |
| 02/01/2020 | 1 | 2 | 2 | 3 |
| 03/01/2020 | 1 | 2 | 3 | 5 |
| 04/01/2020 | 1 | 2 | 4 | 7 |
| 01/01/2020 | 2 | 3 | 1 | null |
| 02/01/2020 | 2 | 3 | 2 | null |
| 03/01/2020 | 2 | 3 | 3 | 6 |
| 04/01/2020 | 2 | 3 | 4 | 9 |
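For reference, the intended semantics can be sketched in plain Python. This is only a small in-memory illustration of what the expected output above encodes, not a Spark solution; the function name `rolling_sum_per_group` and the tuple layout are made up for this sketch, and it assumes the rows are already sorted by date within each ID.

```python
from collections import defaultdict

# (date, ID, window_size, qty), assumed sorted by date within each ID
rows = [
    ("01/01/2020", 1, 2, 1), ("02/01/2020", 1, 2, 2),
    ("03/01/2020", 1, 2, 3), ("04/01/2020", 1, 2, 4),
    ("01/01/2020", 2, 3, 1), ("02/01/2020", 2, 3, 2),
    ("03/01/2020", 2, 3, 3), ("04/01/2020", 2, 3, 4),
]

def rolling_sum_per_group(rows):
    """Return one rolling sum per row, using each row's own window_size."""
    history = defaultdict(list)  # qty values seen so far, per ID
    result = []
    for _date, group_id, window_size, qty in rows:
        seen = history[group_id]
        seen.append(qty)
        if len(seen) < window_size:
            # Not enough rows yet: mirrors the null in the expected output
            result.append(None)
        else:
            result.append(sum(seen[-window_size:]))
    return result

print(rolling_sum_per_group(rows))
# [None, 3, 5, 7, None, None, 6, 9]
```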
I'm struggling to find a solution that works and is fast enough on a large DataFrame (~350M rows).
What I have tried
I tried the solution in the below thread:
The idea is to first use sf.collect_list and then slice the ArrayType column correctly.
from pyspark.sql import Window
import pyspark.sql.types as st
import pyspark.sql.functions as sf
window = Window.partitionBy('id').orderBy(params['date'])
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                .otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
However, this yields the following error:
TypeError: Column is not iterable
I have also tried using sf.expr, like below:
window = Window.partitionBy('id').orderBy(params['date'])
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                .otherwise(sf.expr("slice('qty_list', 'count', 'window_size')")))
).show()
Which yields:
data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;
I tried manually casting the qty_list column to ArrayType(IntegerType()) with the same result.
I tried using a UDF, but that fails with several out-of-memory errors after 1.5 hours or so.
Questions
- Reading the Spark documentation suggests to me that I should be able to pass columns to sf.slice(). Am I doing something wrong? Where is the TypeError coming from?
- Is there a better way to achieve what I want without using sf.collect_list() and/or sf.slice()?
- If all else fails, what would be the optimal way to do this using a UDF? I attempted different versions of the same UDF and tried to make sure the UDF is the last operation Spark has to perform, but all failed.
CodePudding user response:
About the errors you get:
- The first one means you can't pass a column to slice using the DataFrame API function (unless you have Spark 3.1+). But you already figured that out, since you tried using it within a SQL expression.
- The second error occurs because you quoted the column names in your expr. It should be slice(qty_list, count, window_size); otherwise Spark treats them as string literals, hence the error message.
That said, you almost had it. You need to change the slice expression so that it starts at the right position and returns an array of the correct size, then use the aggregate function to sum up the values of the resulting array. Try this:
from pyspark.sql import Window
import pyspark.sql.functions as F
w = Window.partitionBy('id').orderBy('date')
output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
    .withColumn("rn", F.row_number().over(w)) \
    .withColumn(
        "qty_list",
        F.when(
            F.col('rn') < F.col('window_size'),
            None
        ).otherwise(F.expr("slice(qty_list, rn - window_size + 1, window_size)"))
    ).withColumn(
        "rolling_sum",
        F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
    ).drop("qty_list", "rn")
output.show()
#+----------+---+-----------+---+-----------+
#|      date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020|  1|          2|  1|       null|
#|02/01/2020|  1|          2|  2|          3|
#|03/01/2020|  1|          2|  3|          5|
#|04/01/2020|  1|          2|  4|          7|
#|01/01/2020|  2|          3|  1|       null|
#|02/01/2020|  2|          3|  2|       null|
#|03/01/2020|  2|          3|  3|          6|
#|04/01/2020|  2|          3|  4|          9|
#+----------+---+-----------+---+-----------+
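To see why the slice has to start at rn - window_size + 1, the index arithmetic can be checked in plain Python without a cluster. This is only a sketch: `spark_slice` and `rolling_sums` are hypothetical helpers written for this illustration. It assumes Spark's slice uses a 1-based start (as documented) and that, with one row per date, the ordered window's collect_list holds all qty values up to and including the current row.

```python
def spark_slice(values, start, length):
    """Emulate Spark SQL slice(x, start, length) for a positive, 1-based start."""
    return values[start - 1:start - 1 + length]

def rolling_sums(qtys, window_size):
    """Replay the answer's logic for one ID, row by row."""
    sums = []
    for rn in range(1, len(qtys) + 1):       # rn plays the role of row_number()
        qty_list = qtys[:rn]                 # running collect_list at row rn
        if rn < window_size:
            sums.append(None)                # the F.when(...) branch
        else:
            window = spark_slice(qty_list, rn - window_size + 1, window_size)
            sums.append(sum(window))         # the aggregate(...) step
    return sums

print(rolling_sums([1, 2, 3, 4], 2))  # [None, 3, 5, 7]
print(rolling_sums([1, 2, 3, 4], 3))  # [None, None, 6, 9]
```

Starting the slice at count (or rn) instead would point at the last element of the array, so at most one value would be returned regardless of window_size.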
