How can I apply linear model coefficients to input data for multiple groups?

Time:01-28

My dependent variable contains NA values, and I used it to build a separate linear model for each of several groups. How can I use each group's model coefficients to compute a predicted value for every row in that group, including the rows where the dependent variable is NA?

Example data

library(dplyr)

df <- iris
df[c(1:2,51:52,101:102),3] <- NA # set the dependent variable (Petal.Length, column 3) to NA in two rows per species

# Creates a linear model for each Species group

fitted_models <- df %>%
  group_by(Species) %>%
  do(model = lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .))

The code below works as long as the dependent variable has no NA values. With the NA-containing data (df), however, it fails with this error: Error: Incompatible lengths: 50, 48.

library(tidyr)
library(purrr)

# works
fitted_models <- iris %>%
  nest(data = -Species) %>% 
  mutate(fit = map(data, ~ lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .x)),
         fitted.values = map(fit, "fitted.values")) %>% 
  unnest(cols = c(data, fitted.values)) %>% 
  select(-fit)

# does not work
fitted_models <- df %>%
  nest(data = -Species) %>% 
  mutate(fit = map(data, ~ lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .x)),
         fitted.values = map(fit, "fitted.values")) %>% 
  unnest(cols = c(data, fitted.values)) %>% 
  select(-fit)
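The two lengths in the error message come from lm() itself: its default na.action (na.omit) drops the rows whose response is NA before fitting, so the stored fitted.values are shorter than the group's data and cannot be unnested alongside it. A minimal sketch that reproduces the mismatch for one group, assuming the same NA pattern as above:

```r
# Same NA pattern as the question: two NA responses per species
df <- iris
df[c(1:2, 51:52, 101:102), 3] <- NA

setosa <- subset(df, Species == "setosa")
fit <- lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = setosa)

n_rows   <- nrow(setosa)               # rows in the group
n_fitted <- length(fit$fitted.values)  # fitted values kept after na.omit
dropped  <- fit$na.action              # positions of the rows lm() dropped

print(c(rows = n_rows, fitted = n_fitted))
print(dropped)
```

Every group loses two rows this way, which is why unnest() sees 50 rows of data but only 48 fitted values.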

The ideal output, shown below, would use the coefficients from each Species model to derive a predicted value for every row, including the rows where the dependent variable is NA.

   Species Sepal.Length Sepal.Width Petal.Length Petal.Width predicted.values
   <fct>          <dbl>       <dbl>        <dbl>       <dbl>         <dbl>
 1 setosa           5.1         3.5          NA          0.2          1.47
 2 setosa           4.9         3            NA          0.2          1.46
 3 setosa           4.7         3.2          1.3         0.2          1.42
 4 setosa           4.6         3.1          1.5         0.2          1.41
 5 setosa           5           3.6          1.4         0.2          1.46
 6 setosa           5.4         3.9          1.7         0.4          1.51
 7 setosa           4.6         3.4          1.4         0.3          1.40
 8 setosa           5           3.4          1.5         0.2          1.46
 9 setosa           4.4         2.9          1.4         0.2          1.38
10 setosa           4.9         3.1          1.5         0.1          1.45
# ... with 140 more rows
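Applying a model's coefficients to input rows amounts to multiplying a design matrix built from the predictors alone by coef(fit). Because the response is left out of that matrix, rows with an NA response are kept. A minimal sketch, assuming the corrected formula with `+`; apply_coefs is a hypothetical helper name, not part of any package:

```r
library(dplyr)
library(tidyr)
library(purrr)

df <- iris
df[c(1:2, 51:52, 101:102), 3] <- NA

# Hypothetical helper: build the design matrix from the predictors only
# (intercept, Sepal.Length, Sepal.Width) and multiply it by the coefficients.
apply_coefs <- function(fit, new_data) {
  rhs <- delete.response(terms(fit))         # drop Petal.Length from the terms
  X   <- model.matrix(rhs, data = new_data)  # one row per input row
  drop(X %*% coef(fit))                      # plain vector of predictions
}

result <- df %>%
  nest(data = -Species) %>%
  mutate(
    fit = map(data, ~ lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .x)),
    predicted.values = map2(fit, data, apply_coefs)
  ) %>%
  unnest(cols = c(data, predicted.values)) %>%
  select(-fit)

result
```

This is essentially the computation predict() performs for an lm fit when given newdata, written out by hand.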

Answer:

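This version, which calls predict() with newdata inside purrr::map2(), works. Because predict() recomputes the predictions from each model's coefficients and the group's full data, it returns a value for all 50 rows of each group, including the two with an NA response.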

# Predict from a fitted model for every row of .new_data
f <- function(.fit, .new_data) {
  predict(.fit, newdata = .new_data)
}

df %>%
  nest(data = -Species) %>%
  mutate(
    fit  = map(data, ~ lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .x)),
    yhat = map2(.x = fit, .y = data, f)  # one prediction per row of each group
  ) %>%
  unnest(cols = c(data, yhat)) %>%
  select(-fit)

Output:

# A tibble: 150 × 6
   Species Sepal.Length Sepal.Width Petal.Length Petal.Width  yhat
   <fct>          <dbl>       <dbl>        <dbl>       <dbl> <dbl>
 1 setosa           5.1         3.5         NA           0.2  1.48
 2 setosa           4.9         3           NA           0.2  1.46
 3 setosa           4.7         3.2          1.3         0.2  1.42
 4 setosa           4.6         3.1          1.5         0.2  1.41
 5 setosa           5           3.6          1.4         0.2  1.46
 6 setosa           5.4         3.9          1.7         0.4  1.51
 7 setosa           4.6         3.4          1.4         0.3  1.40
 8 setosa           5           3.4          1.5         0.2  1.46
 9 setosa           4.4         2.9          1.4         0.2  1.39
10 setosa           4.9         3.1          1.5         0.1  1.46
# … with 140 more rows
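The same fit-then-predict idea can also be written without nesting, using dplyr::group_modify() (available in dplyr 0.8.1 and later). This is a swapped-in alternative, not the method above; a minimal sketch assuming the same df:

```r
library(dplyr)

df <- iris
df[c(1:2, 51:52, 101:102), 3] <- NA

# For each Species group: fit on the rows with a non-NA response (lm() drops
# the NA rows), then predict for every row of the group.
result <- df %>%
  group_by(Species) %>%
  group_modify(~ {
    fit <- lm(Petal.Length ~ Sepal.Length + Sepal.Width, data = .x)
    .x$yhat <- unname(predict(fit, newdata = .x))
    .x
  }) %>%
  ungroup()

result
```

group_modify() passes each group's rows as .x without the Species column and adds the key back in front of the returned data frame, so the result has the same six columns as the nested version.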

I wish I could have solved it with the fitted.values that were already calculated. I tried adding a row id before nesting so that the two data frames could be left-joined (because data keeps all 50 rows). But two unnest operations plus a join seemed like more conceptual work than re-running the linear equation, which is what predict() does.
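If aligning the existing fitted values with the rows is the goal, lm()'s na.action = na.exclude avoids the row id and the join: fitted() then pads the dropped rows back in as NA, so each group has 50 values. Those rows still get NA rather than a prediction, which is why predict() with newdata remains necessary. A minimal sketch, assuming the same df:

```r
library(dplyr)
library(tidyr)
library(purrr)

df <- iris
df[c(1:2, 51:52, 101:102), 3] <- NA

# na.exclude still fits on the non-NA rows, but fitted() re-inserts NA at the
# dropped positions. Note that map(fit, "fitted.values") would still return
# only 48 values: the padding is done by fitted(), not stored in the object.
aligned <- df %>%
  nest(data = -Species) %>%
  mutate(
    fit = map(data, ~ lm(Petal.Length ~ Sepal.Length + Sepal.Width,
                         data = .x, na.action = na.exclude)),
    fitted.values = map(fit, fitted)
  ) %>%
  unnest(cols = c(data, fitted.values)) %>%
  select(-fit)

# The rows with an NA response have an NA fitted value as well
aligned %>% filter(is.na(Petal.Length))
```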
