How to compute the mean and standard deviation of columns ignoring NaN values

Time:02-08

I have a dataframe of doubles with some NaN/Null/NA values:

val dfDouble = Seq(
  (1.0, 1.0, 1.0, 3.0),
  (1.0, 2.0, 0.0, 0.0),
  (1.0, 3.0, 1.0, 1.0),
  (1.0, Double.NaN, 0.0, 2.0)).toDF("m1", "m2", "m3", "m4")

I would like to compute the mean, standard deviation, and number of non-null observations for each column, but it seems that the regular aggregate functions in Spark return NaN when a column contains a NaN value:

dfDouble.select(dfDouble.columns.map(c => mean(col(c))) :_*).show
// +-------+-------+-------+-------+
// |avg(m1)|avg(m2)|avg(m3)|avg(m4)|
// +-------+-------+-------+-------+
// |    1.0|    NaN|    0.5|    1.5|
// +-------+-------+-------+-------+
dfDouble.select(dfDouble.columns.map(c => stddev(col(c))) :_*).show
// +---------------+---------------+------------------+------------------+
// |stddev_samp(m1)|stddev_samp(m2)|   stddev_samp(m3)|   stddev_samp(m4)|
// +---------------+---------------+------------------+------------------+
// |            0.0|            NaN|0.5773502691896257|1.2909944487358056|
// +---------------+---------------+------------------+------------------+

How can I compute the mean, standard deviation, and # of non-null observations EXCLUDING NaN values?

Answer:

You can replace the NaN values with null before applying the mean and stddev functions, since Spark's aggregate functions skip nulls:

val df = dfDouble.na.fill(dfDouble.columns.map((_, "null")).toMap)

df.select(df.columns.map(c => mean(col(c))) :_*).show

//+-------+-------+-------+-------+
//|avg(m1)|avg(m2)|avg(m3)|avg(m4)|
//+-------+-------+-------+-------+
//|    1.0|    2.0|    0.5|    1.5|
//+-------+-------+-------+-------+
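
The question also asks for the standard deviation and the number of observations, which the snippet above does not show. One way to get all three in a single pass is sketched below. It is only a sketch: the helper name nanSafeStats and the output column suffixes (_mean, _stddev, _n) are made up for illustration, and it assumes every column is a float or double. Rather than going through na.fill, it turns NaN into null explicitly with when/isnan, which is a different route to the same idea.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, isnan, mean, stddev, when}

object NanSafeStats {
  // Hypothetical helper: returns one row with three aggregates per input column.
  // Assumes every column of `df` is a float or double column.
  def nanSafeStats(df: DataFrame): DataFrame = {
    val aggs = df.columns.toSeq.flatMap { c =>
      // `when` without `otherwise` yields null when the condition is false,
      // so NaN entries become null and the aggregates below skip them.
      val clean = when(!isnan(col(c)), col(c))
      Seq(
        mean(clean).as(s"${c}_mean"),
        stddev(clean).as(s"${c}_stddev"), // sample standard deviation (stddev_samp)
        count(clean).as(s"${c}_n")        // count(expr) counts non-null values only
      )
    }
    df.select(aggs: _*)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nan-safe-stats")
      .getOrCreate()
    import spark.implicits._

    val dfDouble = Seq(
      (1.0, 1.0, 1.0, 3.0),
      (1.0, 2.0, 0.0, 0.0),
      (1.0, 3.0, 1.0, 1.0),
      (1.0, Double.NaN, 0.0, 2.0)).toDF("m1", "m2", "m3", "m4")

    nanSafeStats(dfDouble).show()
    spark.stop()
  }
}
```

For the sample data, m2 has the non-NaN values 1.0, 2.0 and 3.0, so it should come out with mean 2.0, standard deviation 1.0 and 3 observations, while the other columns keep all 4 observations.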