Pyspark dataframe column value dependent on value from another row

Time:01-24

I have a dataframe like this:

columns = ['manufacturer', 'product_id']
data = [("Factory", "AE222"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0"),("Factory", "AE333"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0")]
rdd = spark.sparkContext.parallelize(data)
df = rdd.toDF(columns)
 
+-------------+----------+
| manufacturer|product_id|
+-------------+----------+
|      Factory|     AE222|
|Sub-Factory-1|         0|
|Sub-Factory-2|         0|
|      Factory|     AE333|
|Sub-Factory-1|         0|
|Sub-Factory-2|         0|
+-------------+----------+

Which I want to turn into this:

+-------------+----------+
| manufacturer|product_id|
+-------------+----------+
|      Factory|     AE222|
|Sub-Factory-1|     AE222|
|Sub-Factory-2|     AE222|
|      Factory|     AE333|
|Sub-Factory-1|     AE333|
|Sub-Factory-2|     AE333|
+-------------+----------+

Each Sub-Factory row should take the product_id of the closest Factory row above it. I can solve this with a nested for loop, but that is not very efficient since there could be millions of rows. I have looked into PySpark Window functions but cannot really wrap my head around them. Any ideas?

CodePudding user response:

You can use the first function with ignorenulls=True over a Window, but you first need to identify each group of manufacturers (a Factory row plus the Sub-Factory rows that follow it) so that you can partition by that group.

As you didn't give any ID column, I'm using monotonically_increasing_id and a cumulative conditional sum to create a group column:

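from pyspark.sql import Window
from pyspark.sql import functions as F

# Note: a window ordered without partitionBy moves all rows to a single partition.
df1 = df.withColumn(
    "row_id",
    F.monotonically_increasing_id()
).withColumn(
    "group",
    # running count of "Factory" rows: each Factory row starts a new group
    F.sum(F.when(F.col("manufacturer") == "Factory", 1)).over(Window.orderBy("row_id"))
).withColumn(
    "product_id",
    F.when(
        F.col("product_id") == "0",  # product_id is a string column
        F.first("product_id", ignorenulls=True).over(Window.partitionBy("group").orderBy("row_id"))
    ).otherwise(F.col("product_id"))
).drop("row_id", "group")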

df1.show()
#+-------------+----------+
#| manufacturer|product_id|
#+-------------+----------+
#|      Factory|     AE222|
#|Sub-Factory-1|     AE222|
#|Sub-Factory-2|     AE222|
#|      Factory|     AE333|
#|Sub-Factory-1|     AE333|
#|Sub-Factory-2|     AE333|
#+-------------+----------+
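
To make the two window steps easier to follow, here is a rough plain-Python trace of the same logic on the sample rows, without Spark. It is only an illustration of what the "group" column and the first-value lookup compute, not a replacement for the window version. The function name fill_from_factory is made up for this sketch, and it assumes the rows are already in their original order and that the data starts with a Factory row.

```python
# Plain-Python trace of the grouping idea (illustrative only, no Spark involved).
rows = [
    ("Factory", "AE222"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0"),
    ("Factory", "AE333"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0"),
]


def fill_from_factory(rows):
    """Return (group, manufacturer, product_id) tuples with "0" placeholders filled.

    Hypothetical helper: assumes rows are in their original order and that
    the first row is a Factory row.
    """
    filled = []
    group = 0              # running count of Factory rows seen so far
    first_in_group = None  # product_id of the Factory row that opened the group
    for manufacturer, product_id in rows:
        if manufacturer == "Factory":
            group += 1                  # a Factory row starts a new group
            first_in_group = product_id
        if product_id == "0":
            product_id = first_in_group  # take the group's first product_id
        filled.append((group, manufacturer, product_id))
    return filled


for row in fill_from_factory(rows):
    print(row)
```

The running counter plays the role of the cumulative conditional sum ordered by row_id, and first_in_group plays the role of first("product_id") within each group.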