I have a DataFrame like the one below, with one row of data per id for each month.
import pandas as pd

df = pd.DataFrame({'id': ['A', 'A', 'A', 'A', 'B', 'B', 'B', 'B'],
                   'month': ['2020-01', '2020-02', '2020-03', '2020-04',
                             '2020-01', '2020-02', '2020-03', '2020-04'],
                   'amt': [2, 3, 4, 5, 2, 3, 1, 5]})
  id    month  amt
0  A  2020-01    2
1  A  2020-02    3
2  A  2020-03    4
3  A  2020-04    5
4  B  2020-01    2
5  B  2020-02    3
6  B  2020-03    1
7  B  2020-04    5
I need to aggregate values from multiple months. Below is the desired result.
  id month_start month_end  amt
0  A     2020-01   2020-03    9
1  A     2020-02   2020-04   12
2  B     2020-01   2020-03    6
3  B     2020-02   2020-04    9
I'm looking for a general solution, since the real case is more complicated: for example, there could be n months between start and end. I'd appreciate a solution in both Python (pandas) and PySpark. Thanks.
CodePudding user response:
groupby followed by rolling works well here. Group by "id" and create a rolling window object of size n. Then sum "amt" and find the first and last months of each rolling window. Note that since rolling only operates on numeric values, we roll over the row index (exposed as a column by reset_index) and use it to look up the correct start_month and end_month for each window.
n = 3
# index will be used to get the start and end months in rolling
df = df.reset_index()
r_obj = df.groupby('id').rolling(n)
out = r_obj['amt'].sum().dropna().droplevel(1).reset_index()
month_idx = r_obj['index'].agg({'start_month_idx': lambda x: x.iat[0], 'end_month_idx': lambda x: x.iat[-1]}).dropna().reset_index(drop=True)
out['start_month'] = df.loc[month_idx['start_month_idx'], 'month'].reset_index(drop=True)
out['end_month'] = df.loc[month_idx['end_month_idx'], 'month'].reset_index(drop=True)
out = out[['id', 'start_month', 'end_month', 'amt']]
Output:
id start_month end_month amt
0 A 2020-01 2020-03 9.0
1 A 2020-02 2020-04 12.0
2 B 2020-01 2020-03 6.0
3 B 2020-02 2020-04 9.0
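As a more compact variant of the same idea, the window start can be found with a grouped shift instead of rolling over the index. This is only a sketch: the function name rolling_window_sums is made up for illustration, and it assumes every id has one row per month with no gaps, because the window counts rows rather than calendar months.

```python
import pandas as pd

def rolling_window_sums(df, n):
    """Sum 'amt' over n consecutive rows per id (hypothetical helper)."""
    # 'YYYY-MM' strings sort chronologically, so a plain sort is enough.
    df = df.sort_values(['id', 'month']).reset_index(drop=True)
    g = df.groupby('id')
    out = pd.DataFrame({
        'id': df['id'],
        # The window starts n - 1 rows earlier within the same id.
        'month_start': g['month'].shift(n - 1),
        'month_end': df['month'],
        # Drop the group level so the sums align with df's row index.
        'amt': g['amt'].rolling(n).sum().reset_index(level=0, drop=True),
    })
    # The first n - 1 rows of each id have no complete window.
    out = out.dropna(subset=['month_start']).reset_index(drop=True)
    # rolling().sum() returns floats; restore the original dtype of 'amt'.
    out['amt'] = out['amt'].astype(df['amt'].dtype)
    return out
```

Calling rolling_window_sums(df, 3) on the question's DataFrame should give the four rows of the desired result, with amt kept as integers.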
CodePudding user response:
In PySpark you can use collect_list over a Window whose frame is rowsBetween(-n, Window.currentRow) to collect each row's month together with the n preceding months, and compute a running sum of amt over the same Window. Finally, keep only the rows whose months list has size n + 1 (with n = 2 this gives the 3-month windows from the question):
from pyspark.sql import functions as F, Window
# create spark df from pandas dataframe
sdf = spark.createDataFrame(df)
n = 2
w = Window.partitionBy("id").orderBy("month").rowsBetween(-n, Window.currentRow)
result = sdf.withColumn("months", F.collect_list("month").over(w)) \
    .withColumn("amt", F.sum("amt").over(w)) \
    .filter(F.size("months") == n + 1) \
    .select(
        F.col("id"),
        F.element_at(F.col("months"), 1).alias("month_start"),
        F.element_at(F.col("months"), -1).alias("month_end"),
        F.col("amt")
    )
result.show()
#+---+-----------+---------+---+
#| id|month_start|month_end|amt|
#+---+-----------+---------+---+
#|  A|    2020-01|  2020-03|  9|
#|  A|    2020-02|  2020-04| 12|
#|  B|    2020-01|  2020-03|  6|
#|  B|    2020-02|  2020-04|  9|
#+---+-----------+---------+---+
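To see why the filter on size n + 1 is needed, here is a small plain-Python emulation of the row frame for one id. It is only an illustration of the frame semantics, not Spark code, and row_frames is a made-up helper name: near the start of a partition the frame is truncated, so the first n rows see fewer than n + 1 months.

```python
def row_frames(values, n):
    """Emulate rowsBetween(-n, currentRow) over one ordered partition."""
    frames = []
    for i in range(len(values)):
        start = max(0, i - n)  # the frame is cut off at the partition start
        frames.append(values[start:i + 1])
    return frames

# Rows for id 'A' from the question, already ordered by month.
months = ['2020-01', '2020-02', '2020-03', '2020-04']
amts = [2, 3, 4, 5]
n = 2

# Keep only complete frames, mirroring the size == n + 1 filter.
full = [(m[0], m[-1], sum(a))
        for m, a in zip(row_frames(months, n), row_frames(amts, n))
        if len(m) == n + 1]
print(full)
```

The first two frames hold only one and two months, so they are dropped, leaving the two complete 3-month windows for id A.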
