I have a reasonably large data.table (10 million rows):
library(data.table)
# nrows
x <- 1e7
# data
df <- data.table(group = sample(letters[1:2], x, replace = TRUE)
, user = sample(1:x, x, replace = TRUE)
, price = sample(1:3, x, replace = TRUE)
, quantity = sample(1:3, x, replace = TRUE)
); df
group user price quantity
1: a 8286968 1 3
2: b 8652340 1 3
3: a 7388954 1 1
4: b 6932335 3 3
5: a 1468016 1 2
Here is the timing of the hard-coded approach:
# hardcode
system.time(
df[, .(price = sum(price), quantity = sum(quantity))
, .(user, group)
][, .(mean_price = mean( price ), mean_quantity = mean(quantity))
, .(group)
]
)
user system elapsed
2.77 0.28 1.41
compared with the lapply approach:
# lapply
x <- c('price', 'quantity')
system.time(
df[, lapply(.SD, \(i) sum(i))
, .SDcols = x
, .(user, group)
][, lapply(.SD, \(i) mean(i))
, .SDcols = x
, .(group)
]
)
user system elapsed
18.86 0.10 17.86
What causes such a large difference in run time?
Answer:
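This is because data.table's internal optimization of sum and mean (GForce) is not applied when they are wrapped in the anonymous functions \(i) sum(i) and \(i) mean(i). data.table only recognises a direct call such as sum(price) in j; the lambda hides that call, so j is evaluated as ordinary R code once per group, and with several million (user, group) groups that per-group overhead dominates.
The data.table documentation (?datatable.optimize) says the same about mean: "base::mean function is internally optimised to use data.table's fastmean function. mean() from base is an S3 generic and gets slow with many groups."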
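A minimal sketch for checking which path a query takes, assuming a small hypothetical table dt and column vector cols (not from the question). verbose = TRUE makes data.table print how it handled j: for the direct version the log should mention that GForce optimized j, while the lambda version should not (the exact wording varies between data.table versions). The hand-written list(...) query shows what lapply(.SD, sum) is expanded to. All three return the same numbers; only the execution path differs. The \(i) lambda syntax needs R 4.1 or later.

```r
library(data.table)

# Hypothetical small table, only used to inspect the query plan
set.seed(1)
n <- 1e4
dt <- data.table(
  group    = sample(letters[1:2], n, replace = TRUE),
  user     = sample(1:1e3, n, replace = TRUE),
  price    = sample(1:3, n, replace = TRUE),
  quantity = sample(1:3, n, replace = TRUE)
)
cols <- c("price", "quantity")

# sum passed directly: lapply is expanded to list(sum(price), sum(quantity)),
# which GForce can then replace with its internal grouped sum
direct <- dt[, lapply(.SD, sum), .SDcols = cols, by = .(user, group), verbose = TRUE]

# The same expansion written out by hand
by_hand <- dt[, list(price = sum(price), quantity = sum(quantity)), by = .(user, group)]

# sum hidden inside a lambda: j is evaluated as regular R code for every group
wrapped <- dt[, lapply(.SD, \(i) sum(i)), .SDcols = cols, by = .(user, group), verbose = TRUE]

# All three give the same result
stopifnot(isTRUE(all.equal(direct, by_hand)), isTRUE(all.equal(direct, wrapped)))
```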
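lapply itself is not the problem, because it is optimized as well (from ?datatable.optimize):
"The expression DT[, lapply(.SD, fun), by=.] gets optimised to DT[, list(fun(a), fun(b), ...), by=.] where a, b, ... are columns in .SD. This improves performance tremendously."
Passing sum and mean directly, instead of wrapping them in lambdas, brings the run time back in line with the hard-coded version: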
system.time(
df[, lapply(.SD, sum)
, .SDcols = x
, .(user, group)
][, lapply(.SD, mean)
, .SDcols = x
, .(group)
]
)
user system elapsed
2.58 0.39 1.45
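If the lambda was only there to pass extra arguments (for example na.rm = TRUE), one option is to hand those arguments to lapply, which forwards them to sum, instead of writing \(i) sum(i, na.rm = TRUE). This is a sketch assuming a hypothetical table with some missing prices; whether GForce is still applied to the forwarded call is worth confirming with verbose = TRUE in your data.table version.

```r
library(data.table)

# Hypothetical table with some missing prices
set.seed(2)
n <- 1e4
dt <- data.table(
  group    = sample(letters[1:2], n, replace = TRUE),
  user     = sample(1:1e3, n, replace = TRUE),
  price    = sample(c(1:3, NA), n, replace = TRUE),
  quantity = sample(1:3, n, replace = TRUE)
)
cols <- c("price", "quantity")

# Extra arguments go through lapply rather than into a lambda,
# so j stays a plain sum(...) / mean(...) call that data.table can recognise
res <- dt[, lapply(.SD, sum, na.rm = TRUE), .SDcols = cols, by = .(user, group)
  ][, lapply(.SD, mean), .SDcols = cols, by = .(group)]

print(res)
```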
