I have the following dataframe and need to compute the standard deviation of each vector in the column salary.
| dept_name | salary |
|---|---|
| Sales | [30, 36] |
| Finance | [10, 98] |
| Marketing | [20, 22] |
| IT | [40, 90] |
CodePudding user response:
Option 1 - using UDF
- Create a function to calculate the standard deviation for a Python list.
- Assign that function to a pyspark sql `udf`.
- Create a new `stdev_salary` column that applies the `udf` to the `salary` column using `withColumn`.
# imports required for this solution
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf
# calculate the population std dev (divides by n) for list input
def stdev_list(salary_list):
    mean = sum(salary_list) / len(salary_list)
    variance = sum((x - mean) ** 2 for x in salary_list) / len(salary_list)
    stdev = variance ** 0.5
    return stdev
# wrap the std dev function in a pyspark sql udf that returns a float
stdev_udf = udf(stdev_list, FloatType())
# make a new column using the pyspark sql udf
df = df.withColumn('stdev_salary', stdev_udf('salary'))
More about the pyspark sql udf function here: https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.udf.html
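Before wiring the function into Spark, it can help to check the arithmetic in plain Python. The following is a minimal sketch, assuming the four rows from the question's table are held in an ordinary dict (the `salaries` and `results` names are illustrative only). It compares `stdev_list` against `statistics.pstdev` from the standard library, which also computes the population standard deviation.

```python
import statistics

# Sample rows copied from the question's table (illustrative dict, not a Spark DataFrame).
salaries = {
    "Sales": [30, 36],
    "Finance": [10, 98],
    "Marketing": [20, 22],
    "IT": [40, 90],
}


def stdev_list(salary_list):
    """Population standard deviation: divides the squared deviations by n."""
    mean = sum(salary_list) / len(salary_list)
    variance = sum((x - mean) ** 2 for x in salary_list) / len(salary_list)
    return variance ** 0.5


results = {dept: stdev_list(values) for dept, values in salaries.items()}

for dept, value in results.items():
    # statistics.pstdev is the standard library's population standard deviation.
    assert abs(value - statistics.pstdev(salaries[dept])) < 1e-9
    print(dept, round(value, 4))
```

Because each array here has only two elements, the result is simply half the distance between the two salaries, which makes the values easy to verify by hand.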
Option 2 - not using UDF
- First, `explode` the `salary` column so each salary item is represented on a new row
from pyspark.sql import functions as F
df_exploded = df.select('dept_name', 'salary', F.explode('salary').alias('salary_item'))
- Then, calculate the standard deviation using the `salary_item` column while grouping by `dept_name` and `salary`
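df_final = df_exploded.groupBy('dept_name', 'salary').agg(F.stddev('salary_item').alias('stddev_salary'))
Note that `F.stddev` is an alias for `F.stddev_samp` (it divides by n - 1), whereas the `stdev_list` function in Option 1 divides by n. Use `F.stddev_pop` here if the two options should return the same values.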
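The explode-then-group steps can be emulated in plain Python to see what each one produces. This is only a sketch, assuming the question's four rows as a list of tuples; the `rows`, `exploded`, `grouped`, `sample`, and `population` names are illustrative, and it groups by department name alone rather than by `dept_name` and `salary`. `statistics.stdev` stands in for the sample standard deviation and `statistics.pstdev` for the population one.

```python
import statistics

# Rows copied from the question's table: (dept_name, salary array).
rows = [
    ("Sales", [30, 36]),
    ("Finance", [10, 98]),
    ("Marketing", [20, 22]),
    ("IT", [40, 90]),
]

# Emulate explode: one (dept_name, salary_item) pair per array element.
exploded = [(dept, item) for dept, salary in rows for item in salary]

# Emulate the groupBy step: collect the items back per department.
grouped = {}
for dept, item in exploded:
    grouped.setdefault(dept, []).append(item)

# Sample standard deviation divides by n - 1; population divides by n.
sample = {dept: statistics.stdev(items) for dept, items in grouped.items()}
population = {dept: statistics.pstdev(items) for dept, items in grouped.items()}

for dept in grouped:
    print(dept, round(sample[dept], 4), round(population[dept], 4))
```

With only two values per department, the sample figure is larger than the population figure by a factor of the square root of 2, so the choice of convention matters on arrays this small.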
