Suppose I have a DataFrame like the one below:
| Id | A | B | C | D |
|---|---|---|---|---|
| 1 | 100 | 10 | 20 | 5 |
| 2 | 0 | 5 | 10 | 5 |
| 3 | 0 | 7 | 2 | 3 |
| 4 | 0 | 1 | 3 | 7 |
It needs to be converted into something like this:
| Id | A | B | C | D | E |
|---|---|---|---|---|---|
| 1 | 100 | 10 | 20 | 5 | 75 |
| 2 | 75 | 5 | 10 | 5 | 60 |
| 3 | 60 | 7 | 2 | 3 | 50 |
| 4 | 50 | 1 | 3 | 7 | 40 |
The calculation works as follows:
- The DataFrame gets a new column E, which for row 1 is calculated as `col(A) - (max(col(B), col(C)) + col(D))` => `100 - (max(10, 20) + 5) = 75`.
- In the row with Id 2, the value of column E from row 1 is carried forward as the value of column A. So for row 2, column E is `75 - (max(5, 10) + 5) = 60`.
- Similarly, in the row with Id 3, the value of A becomes 60, and the new value of column E is calculated from it.
The problem is that, except for the first row, the value of column A depends on values calculated for the previous row.
Is there a way to solve this using windowing and lag?
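The rule above can be checked outside Spark first. This is a minimal plain-Scala sketch, not part of the question: the object `RecurrenceCheck`, the `Row` case class and the `computeE` helper are hypothetical names, and the sample rows are copied from the first table. Since each E depends only on the previous E, a `scanLeft` over the rows reproduces the expected column.

```scala
object RecurrenceCheck {
  // Hypothetical row type holding the question's columns (Id, A, B, C, D).
  final case class Row(id: Int, a: Int, b: Int, c: Int, d: Int)

  // Sample data from the question; only the first row's A is actually used.
  val rows: List[Row] = List(
    Row(1, 100, 10, 20, 5),
    Row(2, 0, 5, 10, 5),
    Row(3, 0, 7, 2, 3),
    Row(4, 0, 1, 3, 7)
  )

  // E(1) = A(1) - (max(B, C) + D); every later row uses the previous E as its A.
  def computeE(rows: List[Row]): List[Int] = rows match {
    case Nil => Nil
    case first :: rest =>
      val e1 = first.a - (math.max(first.b, first.c) + first.d)
      rest.scanLeft(e1)((prevE, r) => prevE - (math.max(r.b, r.c) + r.d))
  }

  def main(args: Array[String]): Unit =
    println(computeE(rows)) // List(75, 60, 50, 40)
}
```

The new A column is then the first A followed by every E except the last.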
Answer:
You can use the collect_list function over a window ordered by the Id column to get a cumulative array of structs holding the value of A and max(B, C) + D (as field T). Then apply aggregate to calculate column E.
Note that in this particular case you can't use the lag window function: each row needs a value that is itself calculated from the previous row, and lag can only read values that already exist in a column.
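To make the window-plus-`aggregate` idea concrete, here is a Spark-free sketch of what happens for each row: the window frame yields the prefix of collected `(A, T)` structs, `transform` maps the first element to `A - T` and every other element to `-T`, and `aggregate` sums the terms starting from 0. The names `PrefixAggregateSketch`, `AT` and `eForRow` are hypothetical; the sketch only mimics the SQL functions with ordinary Scala collections.

```scala
object PrefixAggregateSketch {
  // Mirrors the struct gathered by collect_list: A and T = max(B, C) + D.
  final case class AT(a: Int, t: Int)

  val collected: Vector[AT] = Vector(
    AT(100, math.max(10, 20) + 5),
    AT(0, math.max(5, 10) + 5),
    AT(0, math.max(7, 2) + 3),
    AT(0, math.max(1, 3) + 7)
  )

  // Row n sees the first n structs (the "tmp" array for that row).
  def eForRow(n: Int): Int = {
    val tmp = collected.take(n)
    // transform: index 0 keeps A - T, later indices contribute -T
    val terms = tmp.zipWithIndex.map { case (x, i) => if (i == 0) x.a - x.t else -x.t }
    // aggregate: sum the terms, starting from 0
    terms.foldLeft(0)(_ + _)
  }

  def main(args: Array[String]): Unit =
    println((1 to collected.size).map(eForRow).toList) // List(75, 60, 50, 40)
}
```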
```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// For each row, collect (A, T = max(B, C) + D) for all rows up to and including it.
val df2 = df.withColumn(
  "tmp",
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).withColumn(
  // E = A of the first row minus the sum of all T values collected so far
  "E",
  expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
).withColumn(
  // Rebuild A for every row: A = E + T, which equals the previous row's E
  "A",
  col("E") + greatest(col("B"), col("C")) + col("D")
).drop("tmp")

df2.show(false)
//+---+---+---+---+---+---+
//|Id |A  |B  |C  |D  |E  |
//+---+---+---+---+---+---+
//|1  |100|10 |20 |5  |75 |
//|2  |75 |5  |10 |5  |60 |
//|3  |60 |7  |2  |3  |50 |
//|4  |50 |1  |3  |7  |40 |
//+---+---+---+---+---+---+
```
You can display the intermediate column tmp (for example by removing the `.drop("tmp")` call) to understand the logic behind the calculation.
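Why E can be computed from the collected array, and why the last `withColumn` can rebuild A from E, follows from unrolling the recurrence. Writing $T_k$ for `greatest(B, C) + D` on row $k$:

$$
\begin{aligned}
T_k &= \max(B_k, C_k) + D_k \\
E_1 &= A_1 - T_1, \qquad E_n = E_{n-1} - T_n \quad (n \ge 2) \\
\Rightarrow\quad E_n &= A_1 - \sum_{k=1}^{n} T_k \\
A_n &= E_n + T_n \quad (\text{which equals } E_{n-1} \text{ for } n \ge 2)
\end{aligned}
$$

For example, for Id 3: $E_3 = 100 - (25 + 15 + 10) = 50$ and $A_3 = 50 + 10 = 60$, matching the expected table.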
Answer:
As blackbishop said, you can't use the lag function to retrieve the changing value of a column. Since you're using the Scala API, you can write your own user-defined aggregate function (UDAF).
Create the following case classes, representing the row currently being read and the aggregator's buffer:
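```scala
// java.lang.Integer fields are nullable: a null E in the buffer means no row has been read yet.
case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)
case class Buffer(var E: Integer, var A: Integer)
```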
Then you use them to define your RecursiveAggregator custom aggregator:
```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
  // Empty buffer: E == null marks that no row has been processed yet.
  override def zero: Buffer = Buffer(null, null)

  override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
    // First row: take A from the input; afterwards, carry forward the previous E.
    buffer.A = if (buffer.E == null) currentRow.A else buffer.E
    buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
    buffer
  }

  // Partial buffers cannot be combined, because the result depends on row order.
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    throw new NotImplementedError("should be used only over ordered window")
  }

  override def finish(reduction: Buffer): Buffer = reduction
  override def bufferEncoder: Encoder[Buffer] = Encoders.product[Buffer]
  override def outputEncoder: Encoder[Buffer] = Encoders.product[Buffer]
}
```
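To see how the buffer evolves, here is a Spark-free sketch of the same `reduce` logic applied row by row. It assumes, as this answer relies on, that over a window ordered by `Id` the aggregator receives the rows in order and the buffer is emitted after each row. It is only an illustration: it uses `Option[Int]` instead of nullable `Integer`, returns a new buffer instead of mutating it, and `ReduceTraceSketch` and `trace` are hypothetical names.

```scala
object ReduceTraceSketch {
  final case class InputRow(A: Int, B: Int, C: Int, D: Int)
  // Option plays the role of the null-initialised fields of the Spark Buffer.
  final case class Buffer(E: Option[Int], A: Option[Int])

  val zero: Buffer = Buffer(None, None)

  // Same logic as RecursiveAggregator.reduce, without in-place mutation.
  def reduce(buffer: Buffer, row: InputRow): Buffer = {
    val a = buffer.E.getOrElse(row.A) // first row: take A from the input
    Buffer(Some(a - (math.max(row.B, row.C) + row.D)), Some(a))
  }

  val input: List[InputRow] = List(
    InputRow(100, 10, 20, 5),
    InputRow(0, 5, 10, 5),
    InputRow(0, 7, 2, 3),
    InputRow(0, 1, 3, 7)
  )

  // A running window emits the buffer after each row: scanLeft without the zero.
  def trace: List[Buffer] = input.scanLeft(zero)(reduce).tail

  def main(args: Array[String]): Unit =
    trace.foreach(b => println(s"A=${b.A.get} E=${b.E.get}"))
}
```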
Finally, turn RecursiveAggregator into a user-defined aggregate function and apply it to your input DataFrame:
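```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

// Wrap the Aggregator so it can be called like a column function.
val recursiveAggregator = udaf(RecursiveAggregator)
// No partitionBy: Spark moves all rows to a single partition to evaluate this window.
val window = Window.orderBy("Id")

// "computed" is a struct column holding the buffer fields E and A.
val result = input
  .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")
```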
Using the DataFrame from the question as input, you get the following result:
```
+---+---+---+---+---+---+
|Id |A  |B  |C  |D  |E  |
+---+---+---+---+---+---+
|1  |100|10 |20 |5  |75 |
|2  |75 |5  |10 |5  |60 |
|3  |60 |7  |2  |3  |50 |
|4  |50 |1  |3  |7  |40 |
+---+---+---+---+---+---+
```
