Tutorial on tidymodels for Machine Learning

caret is a well-known R package for machine learning that covers almost everything from data pre-processing to cross-validation. Its unofficial successor is tidymodels, which takes a modular approach: specific, smaller packages are designed to work hand in hand. Thus, tidymodels is to modeling what the tidyverse is to data wrangling. Herein, I will walk through a machine learning example from start to end and explain how to use the appropriate tidymodels packages at each step.

Set Up

Loading the tidymodels package attaches a set of packages for modeling as well as a few tidyverse packages such as ggplot2 and dplyr.

library("conflicted")
library("tidymodels")
#> -- Attaching packages ------------------------------------------------ tidymodels 0.1.0 --
#> v broom     0.5.5      v recipes   0.1.10
#> v dials     0.0.6      v rsample   0.0.6 
#> v dplyr     0.8.5      v tibble    3.0.0 
#> v ggplot2   3.3.0      v tune      0.1.0 
#> v infer     0.5.1      v workflows 0.1.1 
#> v parsnip   0.0.5      v yardstick 0.0.6 
#> v purrr     0.3.3

# Additional packages for dataviz etc.
library("ggrepel")     # for geom_label_repel()
library("corrplot")    # for corrplot()
#> corrplot 0.84 loaded

conflict_prefer("filter", "dplyr")
#> [conflicted] Will prefer dplyr::filter over any other package

ggplot2::theme_set(theme_light())

Data Set: Diamonds

With ggplot2 comes the diamonds data set, which has information on the size and quality of diamonds. Herein, we’ll use these features to predict the price of a diamond.

data("diamonds")
diamonds %>%
    sample_n(2000) %>% 
    mutate_if(is.factor, as.numeric) %>%
    select(price, everything()) %>%
    cor %>%
    {.[order(abs(.[, 1]), decreasing = TRUE), 
       order(abs(.[, 1]), decreasing = TRUE)]} %>%
    corrplot(method = "number", type = "upper", mar = c(0, 0, 1.5, 0),
             title = "Correlations between price and various features of diamonds")

Separating Testing and Training Data: rsample

First of all, we want to set aside a data set for testing the predictions at the end. Only a small proportion (10%) is used for training, simply to speed things up. Furthermore, the training data set is prepared for 3-fold cross-validation (three folds, again, for speed). Both splits are stratified by price. All this is accomplished using the rsample package:

set.seed(1243)

dia_split <- initial_split(diamonds, prop = .1, strata = price)

dia_train <- training(dia_split)
dia_test  <- testing(dia_split)

dim(dia_train)
#> [1] 5395   10
dim(dia_test)
#> [1] 48545    10

dia_vfold <- vfold_cv(dia_train, v = 3, repeats = 1, strata = price)
dia_vfold %>% 
  mutate(df_ana = map(splits, analysis),
         df_ass = map(splits, assessment))
#> #  3-fold cross-validation using stratification 
#> # A tibble: 3 x 4
#>   splits              id    df_ana                df_ass               
#> * <named list>        <chr> <named list>          <named list>         
#> 1 <split [3.6K/1.8K]> Fold1 <tibble [3,596 x 10]> <tibble [1,799 x 10]>
#> 2 <split [3.6K/1.8K]> Fold2 <tibble [3,596 x 10]> <tibble [1,799 x 10]>
#> 3 <split [3.6K/1.8K]> Fold3 <tibble [3,598 x 10]> <tibble [1,797 x 10]>
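For intuition, a stratified split followed by stratified fold assignment can be sketched in base R. This is a simplified illustration, not rsample's algorithm: it assumes the numeric outcome is binned into quartiles to form the strata, and the simulated vector y merely stands in for price.

```r
# Base-R sketch: stratified 10% training split plus stratified 3-fold CV.
# Assumptions: strata = quartiles of the outcome; y is a simulated stand-in for price.
set.seed(1243)
y <- rexp(1000, rate = 1 / 4000)  # hypothetical right-skewed "price"

# Form strata from the quartiles of the outcome
strata <- cut(y, breaks = quantile(y, probs = seq(0, 1, 0.25)),
              include.lowest = TRUE)

# Sample 10% of the rows within each stratum for training
rows_by_stratum <- split(seq_along(y), strata)
train_idx <- unlist(lapply(rows_by_stratum, function(i) {
  i[sample.int(length(i), size = round(0.1 * length(i)))]
}), use.names = FALSE)
test_idx <- setdiff(seq_along(y), train_idx)

# Assign the training rows to 3 folds, spreading each stratum over all folds
train_strata <- strata[train_idx]
fold <- integer(length(train_idx))
for (s in levels(train_strata)) {
  pos <- which(train_strata == s)
  fold[pos] <- sample(rep_len(1:3, length(pos)))
}

table(train_strata, fold)  # each stratum is spread evenly over the folds
```

Sampling within strata keeps the distribution of the (skewed) outcome similar across training set, test set, and folds, which is the motivation for strata = price above.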

Data Pre-Processing and Feature Engineering: recipes

The recipes package can be used to prepare a data set (for modeling) using different step_*() functions. For example, the plot below indicates that there may be a nonlinear relationship between price and carat, and I want to address that using higher-order terms.

qplot(carat, price, data = dia_train) +
    # log scale for price; log_trans() is namespaced because scales is not attached
    scale_y_continuous(trans = scales::log_trans(), labels = function(x) round(x, -2)) +
    geom_smooth(method = "lm", formula = y ~ poly(x, 4)) +
    labs(title = "Nonlinear relationship between price and carat of diamonds",
         subtitle = "The degree of the polynomial is a potential tuning parameter")
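The plot uses poly(x, 4), and step_poly() also works with orthogonal polynomials by default. A short base-R sketch of what such terms look like, using a made-up vector of carat values:

```r
# Orthogonal polynomial terms for a made-up vector of carat values
carat <- c(0.23, 0.31, 0.50, 0.70, 1.01, 1.50, 2.00)
p <- poly(carat, degree = 2)

# The columns are orthonormal: their cross-product is the identity matrix
round(crossprod(p), 10)

# Each column also sums to zero, so it is uncorrelated with the intercept
round(colSums(p), 10)

# predict() reuses the basis estimated from 'carat' for new values,
# the same estimate-then-apply pattern a trained recipe follows
predict(p, newdata = c(0.40, 1.20))
```

Because the linear and quadratic columns are uncorrelated, adding the quadratic term does not change the estimate of the linear one, which keeps the coefficients easier to interpret than raw powers of carat.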

recipe() takes a formula and a data set, and the individual steps are then added using the appropriate step_*() functions. The recipes package comes with a ton of useful step functions (see, e.g., vignette("Simple_Example", package = "recipes")).

Herein, I want to log-transform price (step_log()), center and scale all numeric predictors (step_normalize()), and dummy-code the categorical predictors (step_dummy()). Furthermore, a quadratic effect of carat is added using step_poly(), which replaces carat with polynomial terms of degree one and two.

dia_rec <-
    recipe(price ~ ., data = dia_train) %>%
    step_log(all_outcomes()) %>%
    step_normalize(all_predictors(), -all_nominal()) %>%
    step_dummy(all_nominal()) %>%
    step_poly(carat, degree = 2)

prep(dia_rec)
#> Data Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>    outcome          1
#>  predictor          9
#> 
#> Training data contained 5395 data points and no missing data.
#> 
#> Operations:
#> 
#> Log transformation on price [trained]
#> Centering and scaling for carat, depth, table, x, y, z [trained]
#> Dummy variables from cut, color, clarity [trained]
#> Orthogonal polynomials on carat [trained]

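Calling prep() on a recipe estimates everything the steps need from the training data (e.g., the means and standard deviations for step_normalize()) and applies the steps to that data. You can then call juice() to extract the transformed training data or call bake() to apply the trained recipe to a new data set.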

# Note the linear and quadratic term for carat and the dummies for e.g. color
dia_juiced <- juice(prep(dia_rec))
dim(dia_juiced)
#> [1] 5395   25
names(dia_juiced)
#>  [1] "depth"        "table"        "x"            "y"            "z"           
#>  [6] "price"        "cut_1"        "cut_2"        "cut_3"        "cut_4"       
#> [11] "color_1"      "color_2"      "color_3"      "color_4"      "color_5"     
#> [16] "color_6"      "clarity_1"    "clarity_2"    "clarity_3"    "clarity_4"   
#> [21] "clarity_5"    "clarity_6"    "clarity_7"    "carat_poly_1" "carat_poly_2"
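Why cut_1 to cut_4 rather than one column per cut level? In diamonds, cut, color, and clarity are ordered factors, and for ordered factors R's default contrasts are polynomial contrasts. The base-R sketch below, with a made-up data frame mirroring diamonds$cut, shows the analogous columns from model.matrix(); in the juiced output above they appear as cut_1 to cut_4.

```r
# Made-up data with an ordered factor, mirroring diamonds$cut
d <- data.frame(cut = factor(
  c("Fair", "Good", "Very Good", "Premium", "Ideal", "Good"),
  levels = c("Fair", "Good", "Very Good", "Premium", "Ideal"),
  ordered = TRUE
))

# With R's default contrasts, an ordered factor with 5 levels gets
# linear, quadratic, cubic, and quartic contrast columns
X <- model.matrix(~ cut, data = d)
colnames(X)  # "(Intercept)" "cut.L" "cut.Q" "cut.C" "cut^4"
```

The same logic explains six color columns (seven levels) and seven clarity columns (eight levels).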

Defining and Fitting Models: parsnip

The parsnip package has wrappers around many popular machine learning algorithms, and you can fit them using a unified interface. This is extremely helpful, since you have to remember only one interface rather than dozens.

The models are separated into two modes/categories, namely, regression and classification (set_mode()). The model is defined using a function specific to each algorithm (e.g., linear_reg(), rand_forest()). Finally, the backend/engine/implementation is selected using set_engine().

Herein, I will start with a basic linear regression model as implemented in stats::lm().

lm_model <-
    linear_reg() %>%
    set_mode("regression") %>%
    set_engine("lm")

As a further example, consider a random forest model. This could be fit using the packages ranger or randomForest. The two have different interfaces (e.g., the number of trees is num.trees in ranger but ntree in randomForest), and parsnip removes the hassle of remembering both. General arguments pertaining to the algorithm are specified in the algorithm function (e.g., rand_forest()), whereas arguments specific to the engine are specified in set_engine().

rand_forest(mtry = 3, trees = 500, min_n = 5) %>%
    set_mode("regression") %>%
    set_engine("ranger", importance = "impurity_corrected")
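To illustrate the idea behind this unified interface, the sketch below translates parsnip-style argument names into engine-specific ones. arg_map and translate_args() are hypothetical helpers written for this example, not parsnip internals; the mapping itself follows the rand_forest() engine documentation.

```r
# Hypothetical lookup table from unified argument names to engine arguments
arg_map <- list(
  ranger       = c(mtry = "mtry", trees = "num.trees", min_n = "min.node.size"),
  randomForest = c(mtry = "mtry", trees = "ntree",     min_n = "nodesize")
)

# Hypothetical helper: rename the supplied arguments for the chosen engine
translate_args <- function(engine, ...) {
  args <- list(...)
  names(args) <- unname(arg_map[[engine]][names(args)])
  args
}

str(translate_args("ranger", mtry = 3, trees = 500, min_n = 5))
str(translate_args("randomForest", mtry = 3, trees = 500, min_n = 5))
```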

Finally, we can fit() the model.

lm_fit1 <- fit(lm_model, price ~ ., dia_juiced)
lm_fit1
#> parsnip model object
#> 
#> Fit time:  20ms 
#> 
#> Call:
#> stats::lm(formula = formula, data = data)
#> 
#> Coefficients:
#>  (Intercept)         depth         table             x             y  
#>     7.711965      0.010871      0.005889      0.251155      0.054422  
#>            z         cut_1         cut_2         cut_3         cut_4  
#>     0.054196      0.106701     -0.026356      0.024207     -0.006191  
#>      color_1       color_2       color_3       color_4       color_5  
#>    -0.455831     -0.084108     -0.004810      0.009725     -0.005591  
#>      color_6     clarity_1     clarity_2     clarity_3     clarity_4  
#>    -0.009730      0.860961     -0.242698      0.132234     -0.052903  
#>    clarity_5     clarity_6     clarity_7  carat_poly_1  carat_poly_2  
#>     0.028996      0.002403      0.022235     51.663971    -17.508316

You can use fit() with a formula (e.g., price ~ .), or fit_xy() to specify x and y directly. In both cases, I recommend keeping only the variables you need when preparing the data set and then using all of them (e.g., price ~ .): writing out y ~ a + b + c makes it easy to forget a newly added variable d. Unnecessary variables can easily be dropped in the recipe using step_rm().

Summarizing Fitted Models: broom

Many models have implemented summary() or coef() methods. However, their output is usually not in a tidy format, and the broom package aims to resolve this issue.

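glance() gives us information about the model as a whole. Here, R squared is pretty high, and the residual standard error sigma, which is practically identical to the RMSE for this many observations, equals 0.154 (on the log scale).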

glance(lm_fit1$fit)
#> # A tibble: 1 x 11
#>   r.squared adj.r.squared sigma statistic p.value    df logLik    AIC    BIC deviance
#>       <dbl>         <dbl> <dbl>     <dbl>   <dbl> <int>  <dbl>  <dbl>  <dbl>    <dbl>
#> 1     0.977         0.977 0.154     9607.       0    25  2457. -4863. -4691.     127.
#> # ... with 1 more variable: df.residual <int>
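Strictly speaking, sigma is the residual standard error, which divides the residual sum of squares by the residual degrees of freedom, whereas the RMSE divides by the number of observations. A base-R sketch on the built-in mtcars data (chosen only because it ships with R):

```r
fit <- lm(mpg ~ wt + hp, data = mtcars)
rss <- sum(residuals(fit)^2)

sigma_hat <- sqrt(rss / df.residual(fit))  # residual standard error (glance()'s sigma)
rmse      <- sqrt(rss / nobs(fit))         # root mean squared error

c(sigma = sigma_hat, rmse = rmse)

# The ratio depends only on n and the number of coefficients p
sqrt((nobs(fit) - length(coef(fit))) / nobs(fit))
rmse / sigma_hat
```

With 5,395 observations and 25 coefficients in the diamonds model, this ratio is sqrt(5370 / 5395), i.e., above 0.997, so the two measures differ only in the third decimal.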

tidy() gives us information about the model parameters, and we see that we have a significant quadratic effect of carat.

tidy(lm_fit1) %>% 
  arrange(desc(abs(statistic)))
#> # A tibble: 25 x 5
#>    term         estimate std.error statistic  p.value
#>    <chr>           <dbl>     <dbl>     <dbl>    <dbl>
#>  1 (Intercept)    7.71     0.00431   1790.   0.      
#>  2 carat_poly_2 -17.5      0.259      -67.7  0.      
#>  3 clarity_1      0.861    0.0135      64.0  0.      
#>  4 color_1       -0.456    0.00738    -61.8  0.      
#>  5 carat_poly_1  51.7      1.10        47.0  0.      
#>  6 clarity_2     -0.243    0.0127     -19.2  3.63e-79
#>  7 color_2       -0.0841   0.00665    -12.7  3.50e-36
#>  8 clarity_3      0.132    0.0107      12.3  2.51e-34
#>  9 cut_1          0.107    0.00997     10.7  1.88e-26
#> 10 clarity_4     -0.0529   0.00839     -6.30 3.15e-10
#> # ... with 15 more rows
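For an lm fit, what tidy() returns corresponds to the coefficient table of summary(), reshaped into a data frame with one row per term. A rough base-R equivalent on mtcars (it mimics broom's column names but is not broom's code):

```r
fit <- lm(mpg ~ wt + hp, data = mtcars)
cf  <- coef(summary(fit))  # matrix: Estimate, Std. Error, t value, Pr(>|t|)

tidy_df <- data.frame(
  term      = rownames(cf),
  estimate  = cf[, "Estimate"],
  std.error = cf[, "Std. Error"],
  statistic = cf[, "t value"],
  p.value   = cf[, "Pr(>|t|)"],
  row.names = NULL
)

# Same ordering as tidy(...) %>% arrange(desc(abs(statistic)))
tidy_df[order(abs(tidy_df$statistic), decreasing = TRUE), ]
```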

Finally, augment() can be used to get model predictions, residuals, etc.

lm_predicted <- augment(lm_fit1$fit, data = dia_juiced) %>% 
    rowid_to_column()
select(lm_predicted, rowid, price, .fitted:.std.resid)
#> # A tibble: 5,395 x 9
#>    rowid price .fitted .se.fit  .resid    .hat .sigma   .cooksd .std.resid
#>    <int> <dbl>   <dbl>   <dbl>   <dbl>   <dbl>  <dbl>     <dbl>      <dbl>
#>  1     1  5.83    5.90 0.0133  -0.0769 0.00744  0.154 0.0000755     -0.502
#>  2     2  5.84    5.80 0.0110   0.0472 0.00507  0.154 0.0000193      0.307
#>  3     3  5.86    5.89 0.0138  -0.0286 0.00801  0.154 0.0000112     -0.187
#>  4     4  5.88    6.24 0.00988 -0.360  0.00413  0.154 0.000910      -2.34 
#>  5     5  6.00    6.24 0.0104  -0.244  0.00458  0.154 0.000467      -1.59 
#>  6     6  6.00    6.06 0.0100  -0.0615 0.00427  0.154 0.0000275     -0.401
#>  7     7  6.00    6.10 0.0111  -0.0999 0.00520  0.154 0.0000887     -0.651
#>  8     8  6.00    6.06 0.0104  -0.0566 0.00459  0.154 0.0000251     -0.369
#>  9     9  6.00    6.22 0.00895 -0.216  0.00339  0.154 0.000270      -1.41 
#> 10    10  6.32    6.55 0.00872 -0.233  0.00321  0.154 0.000297      -1.52 
#> # ... with 5,385 more rows
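For lm objects, the diagnostic columns added by augment() have direct base-R counterparts: .fitted, .resid, .hat, .cooksd, and .std.resid correspond to fitted(), residuals(), hatvalues(), cooks.distance(), and rstandard(). A sketch on mtcars:

```r
fit <- lm(mpg ~ wt + hp, data = mtcars)

aug <- data.frame(
  mpg        = mtcars$mpg,
  .fitted    = fitted(fit),
  .resid     = residuals(fit),
  .hat       = hatvalues(fit),
  .cooksd    = cooks.distance(fit),
  .std.resid = rstandard(fit),
  row.names  = NULL
)

# Observations with |standardized residual| > 2, a common rule of thumb
# for flagging potential outliers
aug[abs(aug$.std.resid) > 2, ]
```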

A plot of the predicted vs. actual prices shows small residuals with a few outliers, which are not well explained by the model.

ggplot(lm_predicted, aes(.fitted, price)) +
    geom_point(alpha = .2) +
    ggrepel::geom_label_repel(aes(label = rowid), 
                              data = filter(lm_predicted, abs(.resid) > 2)) +
    labs(title = "Actual vs. Predicted Price of Diamonds")

Evaluating Model Performance: yardstick

We already saw performance measures RMSE and R squared in the output of glance() above. The yardstick package is specifically designed for such measures for both numeric and categorical outcomes, and it plays well with grouped predictions (e.g., from cross-validation).

Let’s use rsample, parsnip, and yardstick for cross-validation to get a more realistic, out-of-sample estimate of the RMSE.

In the following pipeline, the model is fit() separately to the three analysis data sets, and then the fitted models are used to predict() on the three corresponding assessment data sets (i.e., 3-fold cross-validation).

Before that, analysis() and assessment() are used to extract the respective folds from dia_vfold. Furthermore, the recipe dia_rec is prepped (i.e., trained) using the analysis data of each fold, and this prepped recipe is then applied to the assessment data of each fold using bake(). Preparing the recipe separately for each fold (rather than once for the whole training data set dia_train) guards against data leakage: otherwise, for example, the means and standard deviations used by step_normalize() would be computed partly from the assessment data.
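The prep-on-analysis, bake-on-assessment pattern can be sketched in base R for a single centering-and-scaling step. prep_normalize() and bake_normalize() are hypothetical helpers for illustration, not the recipes API. The point is that the statistics come only from the analysis data and are then reused unchanged on the assessment data.

```r
# Hypothetical "prep": estimate means and SDs from the analysis data only
prep_normalize <- function(data, cols) {
  list(cols  = cols,
       means = vapply(data[cols], mean, numeric(1)),
       sds   = vapply(data[cols], sd, numeric(1)))
}

# Hypothetical "bake": apply the stored statistics to any data set
bake_normalize <- function(prepped, new_data) {
  for (col in prepped$cols) {
    new_data[[col]] <- (new_data[[col]] - prepped$means[[col]]) / prepped$sds[[col]]
  }
  new_data
}

set.seed(1)
dat  <- data.frame(carat = rexp(30, rate = 1.25), depth = rnorm(30, 61.7, 1.4))
fold <- rep(1:3, length.out = nrow(dat))

for (k in 1:3) {
  analysis   <- dat[fold != k, ]
  assessment <- dat[fold == k, ]
  prepped    <- prep_normalize(analysis, c("carat", "depth"))
  ana_baked  <- bake_normalize(prepped, analysis)    # analogous to juice()
  ass_baked  <- bake_normalize(prepped, assessment)  # analogous to bake()
  # The analysis data is exactly standardized; the assessment data generally is not
  cat(sprintf("Fold%d: mean carat analysis = %.2f, assessment = %.2f\n",
              k, mean(ana_baked$carat), mean(ass_baked$carat)))
}
```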

The code in the following chunk makes use of list columns to store all information about the three folds in a single tibble lm_fit2, and a combination of dplyr::mutate() and purrr::map() is used to “loop” across the three rows of the tibble.

# How does dia_vfold look?
dia_vfold
#> #  3-fold cross-validation using stratification 
#> # A tibble: 3 x 2
#>   splits              id   
#>   <named list>        <chr>
#> 1 <split [3.6K/1.8K]> Fold1
#> 2 <split [3.6K/1.8K]> Fold2
#> 3 <split [3.6K/1.8K]> Fold3

# Extract analysis/training and assessment/testing data
lm_fit2 <- mutate(dia_vfold,
                  df_ana = map (splits,  analysis),
                  df_ass = map (splits,  assessment))
lm_fit2
#> #  3-fold cross-validation using stratification 
#> # A tibble: 3 x 4
#>   splits              id    df_ana                df_ass               
#> * <named list>        <chr> <named list>          <named list>         
#> 1 <split [3.6K/1.8K]> Fold1 <tibble [3,596 x 10]> <tibble [1,799 x 10]>
#> 2 <split [3.6K/1.8K]> Fold2 <tibble [3,596 x 10]> <tibble [1,799 x 10]>
#> 3 <split [3.6K/1.8K]> Fold3 <tibble [3,598 x 10]> <tibble [1,797 x 10]>

lm_fit2 <- 
    lm_fit2 %>% 
    # prep, juice, bake
    mutate(
        recipe = map (df_ana, ~prep(dia_rec, training = .x)),
        df_ana = map (recipe,  juice),
        df_ass = map2(recipe, 
                      df_ass, ~bake(.x, new_data = .y))) %>% 
    # fit
    mutate(
        model_fit  = map(df_ana, ~fit(lm_model, price ~ ., data = .x))) %>% 
    # predict
    mutate(
        model_pred = map2(model_fit, df_ass, ~predict(.x, new_data = .y)))
  
select(lm_fit2, id, recipe:model_pred)
#> # A tibble: 3 x 4
#>   id    recipe       model_fit    model_pred          
#>   <chr> <named list> <named list> <named list>        
#> 1 Fold1 <recipe>     <fit[+]>     <tibble [1,799 x 1]>
#> 2 Fold2 <recipe>     <fit[+]>     <tibble [1,799 x 1]>
#> 3 Fold3 <recipe>     <fit[+]>     <tibble [1,797 x 1]>

Note that the cross-validation code above is deliberately verbose for didactic purposes. Wrapper functions (e.g., from the tune package) would lead to more concise code: fit_resamples() could be used here (without tuning), or tune_grid() and tune_bayes() (see below).

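Now, we can extract the actual (log-transformed) prices, price, from the assessment data and compare them to the predicted prices, .pred. This is where the yardstick package comes into play: the function metrics() calculates three different metrics for numeric outcomes. Furthermore, it automatically recognizes that lm_preds is grouped by fold and thus calculates the metrics for each fold separately.

Compared to the in-sample fit above (see the output of glance(lm_fit1$fit)), the cross-validated RMSE is higher on average (about 0.20 across the three folds), and R squared is smaller. This is mainly driven by Fold3, whose RMSE of 0.298 is about twice that of the other folds even though its MAE is the lowest, which points to a few large prediction errors rather than uniformly worse predictions. A drop in performance is expected, since out-of-sample prediction is harder, but it is also far more useful.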

lm_preds <- 
    lm_fit2 %>% 
    mutate(res = map2(df_ass, model_pred, ~data.frame(price = .x$price,
                                                      .pred = .y$.pred))) %>% 
    select(id, res) %>% 
    tidyr::unnest(res) %>% 
    group_by(id)
lm_preds
#> # A tibble: 5,395 x 3
#> # Groups:   id [3]
#>    id    price .pred
#>    <chr> <dbl> <dbl>
#>  1 Fold1  5.84  5.83
#>  2 Fold1  6.00  6.25
#>  3 Fold1  6.00  6.05
#>  4 Fold1  6.32  6.56
#>  5 Fold1  6.32  6.31
#>  6 Fold1  7.92  7.73
#>  7 Fold1  7.93  7.58
#>  8 Fold1  7.93  7.80
#>  9 Fold1  7.93  7.88
#> 10 Fold1  7.94  7.91
#> # ... with 5,385 more rows

metrics(lm_preds, truth = price, estimate = .pred)
#> # A tibble: 9 x 4
#>   id    .metric .estimator .estimate
#>   <chr> <chr>   <chr>          <dbl>
#> 1 Fold1 rmse    standard       0.168
#> 2 Fold2 rmse    standard       0.147
#> 3 Fold3 rmse    standard       0.298
#> 4 Fold1 rsq     standard       0.973
#> 5 Fold2 rsq     standard       0.979
#> 6 Fold3 rsq     standard       0.918
#> 7 Fold1 mae     standard       0.116
#> 8 Fold2 mae     standard       0.115
#> 9 Fold3 mae     standard       0.110

Note that metrics() has default measures for numeric and categorical outcomes, and here RMSE, R squared, and the mean absolute error (MAE) are returned. You could also use one metric directly, such as rmse(), or define a custom set of metrics via metric_set().
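For reference, the three default numeric metrics can be written out directly. The sketch below uses simulated predictions grouped by a made-up fold id. Note that yardstick's rsq() is the squared correlation between truth and estimate (the traditional 1 - SSE/SST version is rsq_trad()). The last part shows why a single large error inflates RMSE much more than MAE, which is how a fold can have the highest RMSE but not the highest MAE.

```r
set.seed(42)
preds <- data.frame(id    = rep(c("Fold1", "Fold2", "Fold3"), each = 50),
                    price = rnorm(150, mean = 8, sd = 1))
preds$.pred <- preds$price + rnorm(150, sd = 0.15)

# Plain-R versions of the default numeric metrics
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
rsq  <- function(truth, estimate) cor(truth, estimate)^2
mae  <- function(truth, estimate) mean(abs(truth - estimate))

fold_metrics <- function(d) {
  c(rmse = rmse(d$price, d$.pred),
    rsq  = rsq(d$price, d$.pred),
    mae  = mae(d$price, d$.pred))
}

# One row per fold, like metrics() on a grouped data frame
t(sapply(split(preds, preds$id), fold_metrics))

# Add a single large error to Fold3: RMSE jumps, MAE barely moves
i <- which(preds$id == "Fold3")[1]
preds$.pred[i] <- preds$price[i] + 2
t(sapply(split(preds, preds$id), fold_metrics))
```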

Tuning Model Parameters: tune and dials

Let’s get a little bit more involved and do some hyperparameter tuning. We turn to a different model, namely, a random forest model.

The tune package has functions for doing the actual tuning (e.g., via grid search), while all the parameters and their defaults (e.g., mtry(), neighbors()) are implemented in dials. Thus, the two packages are almost always used together.

Preparing a parsnip Model for Tuning

First, I want to tune the mtry parameter of a random forest model. Thus, the model is defined using parsnip as above. However, rather than using a default value (i.e., mtry = NULL) or one specific value (e.g., mtry = 3), we use tune() as a placeholder and let cross-validation decide on the best value for mtry later on.

As the output indicates, the default minimum of mtry is 1 and the maximum depends on the data.

rf_model <- 
    rand_forest(mtry = tune()) %>%
    set_mode("regression") %>%
    set_engine("ranger")

parameters(rf_model)
#> Collection of 1 parameters for tuning
#> 
#>    id parameter type object class
#>  mtry           mtry    nparam[?]
#> 
#> Model parameters needing finalization:
#>    # Randomly Selected Predictors ('mtry')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
mtry()
#> # Randomly Selected Predictors  (quantitative)
#> Range: [1, ?]

Thus, this model is not yet ready for tuning, because the upper bound of mtry is unknown. You can either specify the maximum for mtry yourself using update(), or you can use finalize() to let the data decide on the maximum.

rf_model %>% 
    parameters() %>% 
    update(mtry = mtry(c(1L, 5L)))
#> Collection of 1 parameters for tuning
#> 
#>    id parameter type object class
#>  mtry           mtry    nparam[+]

rf_model %>% 
    parameters() %>% 
    # Here, the maximum of mtry equals the number of predictors, i.e., 24.
    finalize(x = select(juice(prep(dia_rec)), -price)) %>% 
    pull("object")
#> [[1]]
#> # Randomly Selected Predictors  (quantitative)
#> Range: [1, 24]

Preparing Data for Tuning: recipes

The second thing I want to tune is the degree of the polynomial for the variable carat. As you saw in the plot above, polynomials up to a degree of four seemed well suited for the data. However, a simpler model might do equally well, and we want to use cross-validation to decide on the degree that works best.

Similar to tuning parameters in a model, certain aspects of a recipe can be tuned. Let’s define a second recipe and use tune() inside step_poly().

# Note that this recipe cannot be prepped (and juiced), since "degree" is a
# tuning parameter
dia_rec2 <-
    recipe(price ~ ., data = dia_train) %>%
    step_log(all_outcomes()) %>%
    step_normalize(all_predictors(), -all_nominal()) %>%
    step_dummy(all_nominal()) %>%
    step_poly(carat, degree = tune())

dia_rec2 %>% 
    parameters() %>% 
    pull("object")
#> [[1]]
#> Polynomial Degree  (quantitative)
#> Range: [1, 3]

Combine Everything: workflows

The workflows package is designed to bundle together different parts of a machine learning pipeline like a recipe or a model.

First, let’s create an initial workflow and add the recipe and the random forest model, both of which have a tuning parameter.

rf_wflow <-
    workflow() %>%
    add_model(rf_model) %>%
    add_recipe(dia_rec2)
rf_wflow
#> == Workflow ==============================================================================
#> Preprocessor: Recipe
#> Model: rand_forest()
#> 
#> -- Preprocessor --------------------------------------------------------------------------
#> 4 Recipe Steps
#> 
#> * step_log()
#> * step_normalize()
#> * step_dummy()
#> * step_poly()
#> 
#> -- Model ---------------------------------------------------------------------------------
#> Random Forest Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = tune()
#> 
#> Computational engine: ranger

Second, we need to update the parameters in rf_wflow, because the maximum of mtry is not yet known and the maximum of degree should be four (rather than the default of three). To keep the grid small, mtry is restricted to 3 to 5 and degree to 2 to 4.

rf_param <-
    rf_wflow %>%
    parameters() %>%
    update(mtry = mtry(range = c(3L, 5L)),
           degree = degree_int(range = c(2L, 4L)))
rf_param$object
#> [[1]]
#> # Randomly Selected Predictors  (quantitative)
#> Range: [3, 5]
#> 
#> [[2]]
#> Polynomial Degree  (quantitative)
#> Range: [2, 4]

Third, we want to use cross-validation for tuning, that is, to select the best combination of the hyperparameters. Bayesian optimization (see https://tidymodels.github.io/tune/) is recommended for complex tuning problems, and this can be done using tune_bayes().

Herein, however, grid search will suffice. To this end, let’s create a grid of all necessary parameter combinations.

rf_grid <- grid_regular(rf_param, levels = 3)
rf_grid
#> # A tibble: 9 x 2
#>    mtry degree
#>   <int>  <int>
#> 1     3      2
#> 2     4      2
#> 3     5      2
#> 4     3      3
#> 5     4      3
#> 6     5      3
#> 7     3      4
#> 8     4      4
#> 9     5      4

Cross-validation and hyperparameter tuning can involve fitting many models. Herein, for example, we have to fit 3 x 9 = 27 models (folds x parameter combinations). To increase speed, we can fit the models in parallel, which the tune package supports directly (see https://tidymodels.github.io/tune/).

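library("doFuture")
# Leave one physical core free for the rest of the system
all_cores <- parallel::detectCores(logical = FALSE) - 1

registerDoFuture()
# makeCluster() lives in the parallel package, which is not attached here
cl <- parallel::makeCluster(all_cores)
plan(future::cluster, workers = cl)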

Then, we can finally start tuning.

rf_search <- tune_grid(rf_wflow, grid = rf_grid, resamples = dia_vfold,
                       param_info = rf_param)

The results can be examined using autoplot() and show_best():

autoplot(rf_search, metric = "rmse") +
    labs(title = "Results of Grid Search for Two Tuning Parameters of a Random Forest")

show_best(rf_search, "rmse", n = 9)
#> # A tibble: 9 x 7
#>    mtry degree .metric .estimator  mean     n std_err
#>   <int>  <int> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1     5      2 rmse    standard   0.121     3 0.00498
#> 2     5      3 rmse    standard   0.121     3 0.00454
#> 3     4      2 rmse    standard   0.122     3 0.00463
#> 4     5      4 rmse    standard   0.122     3 0.00471
#> 5     4      3 rmse    standard   0.123     3 0.00469
#> 6     4      4 rmse    standard   0.124     3 0.00496
#> 7     3      3 rmse    standard   0.128     3 0.00502
#> 8     3      2 rmse    standard   0.128     3 0.00569
#> 9     3      4 rmse    standard   0.128     3 0.00501

select_best(rf_search, metric = "rmse")
#> # A tibble: 1 x 2
#>    mtry degree
#>   <int>  <int>
#> 1     5      2

select_by_one_std_err(rf_search, mtry, degree, metric = "rmse")
#> # A tibble: 1 x 9
#>    mtry degree .metric .estimator  mean     n std_err .best .bound
#>   <int>  <int> <chr>   <chr>      <dbl> <int>   <dbl> <dbl>  <dbl>
#> 1     4      2 rmse    standard   0.122     3 0.00463 0.121  0.126
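The one-standard-error rule behind select_by_one_std_err() can be reproduced by hand from the rounded numbers in the show_best() output: take the best mean RMSE, add its standard error to get a bound, keep all candidates at or below the bound, and pick the simplest one. Here, "simplest" is assumed to mean sorting by mtry and then degree in ascending order, mirroring the arguments passed to select_by_one_std_err(); this is a sketch, not tune's code.

```r
# Rounded results copied from show_best(rf_search, "rmse", n = 9) above
res <- data.frame(
  mtry    = c(5L, 5L, 4L, 5L, 4L, 4L, 3L, 3L, 3L),
  degree  = c(2L, 3L, 2L, 4L, 3L, 4L, 3L, 2L, 4L),
  mean    = c(0.121, 0.121, 0.122, 0.122, 0.123, 0.124, 0.128, 0.128, 0.128),
  std_err = c(0.00498, 0.00454, 0.00463, 0.00471, 0.00469, 0.00496,
              0.00502, 0.00569, 0.00501)
)

best  <- res[which.min(res$mean), ]   # mtry = 5, degree = 2
bound <- best$mean + best$std_err     # 0.121 + 0.00498 = 0.12598

# Candidates within one standard error of the best result
within <- res[res$mean <= bound, ]

# Pick the simplest candidate: smallest mtry, then smallest degree
choice <- within[order(within$mtry, within$degree), ][1, ]
choice
```

This reproduces the .best (0.121) and .bound (0.126) columns and the selected combination mtry = 4, degree = 2.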

With a cross-validation RMSE of ca. 0.12, the random forest model seems to outperform the linear regression from above. Furthermore, 0.12 is (hopefully) a realistic estimate of the out-of-sample error.

Selecting the Best Model to Make the Final Predictions

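We saw above that a quadratic trend (degree = 2) was enough to get a good model. Furthermore, although mtry = 5 yielded the lowest cross-validation RMSE, the one-standard-error rule selects mtry = 4, the simplest combination whose RMSE lies within one standard error of the best result.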

To use this combination of hyperparameters, we fit() the corresponding model (or workflow, more precisely) on the whole training data set dia_train.

rf_param_final <- select_by_one_std_err(rf_search, mtry, degree,
                                        metric = "rmse")

rf_wflow_final <- finalize_workflow(rf_wflow, rf_param_final)

rf_wflow_final_fit <- fit(rf_wflow_final, data = dia_train)

Now, we want to use this to predict() on data never seen before, namely, dia_test. Unfortunately, predict(rf_wflow_final_fit, new_data = dia_test) does not work in the present case, because the outcome is modified in the recipe via step_log().

Thus, we need a little workaround: The prepped recipe is extracted from the workflow, and this can then be used to bake() the testing data. This baked data set together with the extracted model can then be used for the final predictions.

dia_rec3     <- pull_workflow_prepped_recipe(rf_wflow_final_fit)
rf_final_fit <- pull_workflow_fit(rf_wflow_final_fit)

dia_test$.pred <- predict(rf_final_fit, 
                          new_data = bake(dia_rec3, dia_test))$.pred
dia_test$logprice <- log(dia_test$price)

metrics(dia_test, truth = logprice, estimate = .pred)
#> # A tibble: 3 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard      0.113 
#> 2 rsq     standard      0.988 
#> 3 mae     standard      0.0846

As you can see, we get an RMSE of 0.11 on the testing data, which is even slightly better than the cross-validation RMSE of about 0.12.
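Because the model predicts log(price), its errors are multiplicative on the original dollar scale: a log-scale error of e corresponds to a price ratio of exp(e). As a rough back-of-the-envelope reading of the test RMSE (not an exact conversion), the arithmetic, including back-transforming a made-up prediction, looks like this:

```r
rmse_log <- 0.113        # test-set RMSE on the log scale, from the output above
exp(rmse_log)            # about 1.12: roughly a 12% multiplicative error

# Back-transforming a (made-up) prediction from the log scale to dollars
actual_price <- 5000
pred_log     <- log(actual_price) + rmse_log
exp(pred_log)            # about 5598, i.e., roughly 12% above the actual price
```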

Summary

The tidymodels ecosystem bundles together a set of packages that work hand in hand to solve machine-learning problems from start to end. Together with the data-wrangling facilities in the tidyverse and the plotting tools from ggplot2, this makes for a rich toolbox for every data scientist working with R.

One thing that is clearly still missing in tidymodels is a package for combining different machine learning models (i.e., ensembles, stacking, or super learners). caret has caretEnsemble, and I am sure the tidymodels team at RStudio is working on something similar. Alex Hayes has a related blog post focusing on tidymodels, for those who can’t wait.

Further Resources

Session Info

sessioninfo::session_info()
#> - Session info -------------------------------------------------------------------------
#>  setting  value                       
#>  version  R version 3.6.2 (2019-12-12)
#>  os       Windows 10 x64              
#>  system   x86_64, mingw32             
#>  ui       RTerm                       
#>  language (EN)                        
#>  collate  German_Germany.1252         
#>  ctype    German_Germany.1252         
#>  tz       Europe/Berlin               
#>  date     2020-04-07                  
#> 
#> - Packages -----------------------------------------------------------------------------
#>  package       * version     date       lib source                      
#>  assertthat      0.2.1       2019-03-21 [1] CRAN (R 3.6.2)              
#>  backports       1.1.6       2020-04-05 [1] CRAN (R 3.6.2)              
#>  base64enc       0.1-3       2015-07-28 [1] CRAN (R 3.6.0)              
#>  bayesplot       1.7.1       2019-12-01 [1] CRAN (R 3.6.2)              
#>  blogdown        0.18        2020-03-04 [1] CRAN (R 3.6.3)              
#>  bookdown        0.18        2020-03-05 [1] CRAN (R 3.6.3)              
#>  boot            1.3-23      2019-07-05 [2] CRAN (R 3.6.2)              
#>  broom         * 0.5.5       2020-02-29 [1] CRAN (R 3.6.3)              
#>  callr           3.4.3       2020-03-28 [1] CRAN (R 3.6.2)              
#>  class           7.3-15      2019-01-01 [2] CRAN (R 3.6.2)              
#>  cli             2.0.2       2020-02-28 [1] CRAN (R 3.6.3)              
#>  codetools       0.2-16      2018-12-24 [2] CRAN (R 3.6.2)              
#>  colorspace      1.4-1       2019-03-18 [1] CRAN (R 3.6.1)              
#>  colourpicker    1.0         2017-09-27 [1] CRAN (R 3.6.2)              
#>  conflicted    * 1.0.4       2019-06-21 [1] CRAN (R 3.6.3)              
#>  corrplot      * 0.84        2017-10-16 [1] CRAN (R 3.6.3)              
#>  crayon          1.3.4       2017-09-16 [1] CRAN (R 3.6.2)              
#>  crosstalk       1.1.0.1     2020-03-13 [1] CRAN (R 3.6.3)              
#>  dials         * 0.0.6       2020-04-03 [1] CRAN (R 3.6.2)              
#>  DiceDesign      1.8-1       2019-07-31 [1] CRAN (R 3.6.2)              
#>  digest          0.6.25      2020-02-23 [1] CRAN (R 3.6.3)              
#>  dplyr         * 0.8.5       2020-03-07 [1] CRAN (R 3.6.3)              
#>  DT              0.13        2020-03-23 [1] CRAN (R 3.6.3)              
#>  dygraphs        1.1.1.6     2018-07-11 [1] CRAN (R 3.6.2)              
#>  ellipsis        0.3.0       2019-09-20 [1] CRAN (R 3.6.2)              
#>  evaluate        0.14        2019-05-28 [1] CRAN (R 3.6.2)              
#>  fansi           0.4.1       2020-01-08 [1] CRAN (R 3.6.2)              
#>  farver          2.0.3       2020-01-16 [1] CRAN (R 3.6.2)              
#>  fastmap         1.0.1       2019-10-08 [1] CRAN (R 3.6.2)              
#>  foreach         1.5.0       2020-03-30 [1] CRAN (R 3.6.2)              
#>  furrr           0.1.0       2018-05-16 [1] CRAN (R 3.6.2)              
#>  future          1.16.0      2020-01-16 [1] CRAN (R 3.6.2)              
#>  generics        0.0.2       2018-11-29 [1] CRAN (R 3.6.2)              
#>  ggplot2       * 3.3.0       2020-03-05 [1] CRAN (R 3.6.3)              
#>  ggrepel       * 0.8.2       2020-03-08 [1] CRAN (R 3.6.3)              
#>  ggridges        0.5.2       2020-01-12 [1] CRAN (R 3.6.2)              
#>  globals         0.12.5      2019-12-07 [1] CRAN (R 3.6.1)              
#>  glue            1.4.0       2020-04-03 [1] CRAN (R 3.6.2)              
#>  gower           0.2.1       2019-05-14 [1] CRAN (R 3.6.1)              
#>  GPfit           1.0-8       2019-02-08 [1] CRAN (R 3.6.2)              
#>  gridExtra       2.3         2017-09-09 [1] CRAN (R 3.6.2)              
#>  gtable          0.3.0       2019-03-25 [1] CRAN (R 3.6.2)              
#>  gtools          3.8.2       2020-03-31 [1] CRAN (R 3.6.3)              
#>  hardhat         0.1.2       2020-02-28 [1] CRAN (R 3.6.3)              
#>  htmltools       0.4.0       2019-10-04 [1] CRAN (R 3.6.2)              
#>  htmlwidgets     1.5.1       2019-10-08 [1] CRAN (R 3.6.2)              
#>  httpuv          1.5.2       2019-09-11 [1] CRAN (R 3.6.2)              
#>  igraph          1.2.5       2020-03-19 [1] CRAN (R 3.6.3)              
#>  infer         * 0.5.1       2019-11-19 [1] CRAN (R 3.6.2)              
#>  inline          0.3.15      2018-05-18 [1] CRAN (R 3.6.2)              
#>  ipred           0.9-9       2019-04-28 [1] CRAN (R 3.6.2)              
#>  iterators       1.0.12      2019-07-26 [1] CRAN (R 3.6.2)              
#>  janeaustenr     0.1.5       2017-06-10 [1] CRAN (R 3.6.2)              
#>  knitr           1.28        2020-02-06 [1] CRAN (R 3.6.2)              
#>  labeling        0.3         2014-08-23 [1] CRAN (R 3.6.0)              
#>  later           1.0.0       2019-10-04 [1] CRAN (R 3.6.2)              
#>  lattice         0.20-38     2018-11-04 [2] CRAN (R 3.6.2)              
#>  lava            1.6.7       2020-03-05 [1] CRAN (R 3.6.3)              
#>  lhs             1.0.1       2019-02-03 [1] CRAN (R 3.6.2)              
#>  lifecycle       0.2.0       2020-03-06 [1] CRAN (R 3.6.3)              
#>  listenv         0.8.0       2019-12-05 [1] CRAN (R 3.6.2)              
#>  lme4            1.1-21      2019-03-05 [1] CRAN (R 3.6.2)              
#>  loo             2.2.0       2019-12-19 [1] CRAN (R 3.6.2)              
#>  lubridate       1.7.4       2018-04-11 [1] CRAN (R 3.6.2)              
#>  magrittr        1.5         2014-11-22 [1] CRAN (R 3.6.2)              
#>  markdown        1.1         2019-08-07 [1] CRAN (R 3.6.2)              
#>  MASS            7.3-51.4    2019-03-31 [2] CRAN (R 3.6.2)              
#>  Matrix          1.2-18      2019-11-27 [2] CRAN (R 3.6.2)              
#>  matrixStats     0.56.0      2020-03-13 [1] CRAN (R 3.6.3)              
#>  memoise         1.1.0       2017-04-21 [1] CRAN (R 3.6.2)              
#>  mgcv            1.8-31      2019-11-09 [2] CRAN (R 3.6.2)              
#>  mime            0.9         2020-02-04 [1] CRAN (R 3.6.2)              
#>  miniUI          0.1.1.1     2018-05-18 [1] CRAN (R 3.6.2)              
#>  minqa           1.2.4       2014-10-09 [1] CRAN (R 3.6.2)              
#>  munsell         0.5.0       2018-06-12 [1] CRAN (R 3.6.2)              
#>  nlme            3.1-142     2019-11-07 [2] CRAN (R 3.6.2)              
#>  nloptr          1.2.2.1     2020-03-11 [1] CRAN (R 3.6.3)              
#>  nnet            7.3-12      2016-02-02 [2] CRAN (R 3.6.2)              
#>  parsnip       * 0.0.5       2020-01-07 [1] CRAN (R 3.6.2)              
#>  pillar          1.4.3       2019-12-20 [1] CRAN (R 3.6.2)              
#>  pkgbuild        1.0.6       2019-10-09 [1] CRAN (R 3.6.2)              
#>  pkgconfig       2.0.3       2019-09-22 [1] CRAN (R 3.6.2)              
#>  plyr            1.8.6       2020-03-03 [1] CRAN (R 3.6.3)              
#>  prettyunits     1.1.1       2020-01-24 [1] CRAN (R 3.6.2)              
#>  pROC            1.16.2      2020-03-19 [1] CRAN (R 3.6.3)              
#>  processx        3.4.2       2020-02-09 [1] CRAN (R 3.6.2)              
#>  prodlim         2019.11.13  2019-11-17 [1] CRAN (R 3.6.2)              
#>  promises        1.1.0       2019-10-04 [1] CRAN (R 3.6.2)              
#>  ps              1.3.2       2020-02-13 [1] CRAN (R 3.6.2)              
#>  purrr         * 0.3.3       2019-10-18 [1] CRAN (R 3.6.2)              
#>  R6              2.4.1       2019-11-12 [1] CRAN (R 3.6.2)              
#>  ranger          0.12.1      2020-01-10 [1] CRAN (R 3.6.2)              
#>  Rcpp            1.0.4       2020-03-17 [1] CRAN (R 3.6.3)              
#>  recipes       * 0.1.10      2020-03-18 [1] CRAN (R 3.6.3)              
#>  reshape2        1.4.3       2017-12-11 [1] CRAN (R 3.6.2)              
#>  rlang           0.4.5.9000  2020-03-25 [1] Github (r-lib/rlang@a90b04b)
#>  rmarkdown       2.1         2020-01-20 [1] CRAN (R 3.6.2)              
#>  rpart           4.1-15      2019-04-12 [2] CRAN (R 3.6.2)              
#>  rsample       * 0.0.6       2020-03-31 [1] CRAN (R 3.6.2)              
#>  rsconnect       0.8.16      2019-12-13 [1] CRAN (R 3.6.2)              
#>  rstan           2.19.3      2020-02-11 [1] CRAN (R 3.6.2)              
#>  rstanarm        2.19.3      2020-02-11 [1] CRAN (R 3.6.2)              
#>  rstantools      2.0.0       2019-09-15 [1] CRAN (R 3.6.2)              
#>  rstudioapi      0.11        2020-02-07 [1] CRAN (R 3.6.2)              
#>  scales        * 1.1.0       2019-11-18 [1] CRAN (R 3.6.2)              
#>  sessioninfo     1.1.1       2018-11-05 [1] CRAN (R 3.6.2)              
#>  shiny           1.4.0.2     2020-03-13 [1] CRAN (R 3.6.3)              
#>  shinyjs         1.1         2020-01-13 [1] CRAN (R 3.6.2)              
#>  shinystan       2.5.0       2018-05-01 [1] CRAN (R 3.6.2)              
#>  shinythemes     1.1.2       2018-11-06 [1] CRAN (R 3.6.2)              
#>  SnowballC       0.7.0       2020-04-01 [1] CRAN (R 3.6.3)              
#>  StanHeaders     2.21.0-1    2020-01-19 [1] CRAN (R 3.6.2)              
#>  stringi         1.4.6       2020-02-17 [1] CRAN (R 3.6.2)              
#>  stringr         1.4.0       2019-02-10 [1] CRAN (R 3.6.2)              
#>  survival        3.1-8       2019-12-03 [2] CRAN (R 3.6.2)              
#>  threejs         0.3.3       2020-01-21 [1] CRAN (R 3.6.2)              
#>  tibble        * 3.0.0       2020-03-30 [1] CRAN (R 3.6.2)              
#>  tidymodels    * 0.1.0       2020-02-16 [1] CRAN (R 3.6.2)              
#>  tidyposterior   0.0.2       2018-11-15 [1] CRAN (R 3.6.2)              
#>  tidypredict     0.4.5       2020-02-10 [1] CRAN (R 3.6.2)              
#>  tidyr           1.0.2       2020-01-24 [1] CRAN (R 3.6.2)              
#>  tidyselect      1.0.0       2020-01-27 [1] CRAN (R 3.6.2)              
#>  tidytext        0.2.3       2020-03-04 [1] CRAN (R 3.6.3)              
#>  timeDate        3043.102    2018-02-21 [1] CRAN (R 3.6.2)              
#>  tokenizers      0.2.1       2018-03-29 [1] CRAN (R 3.6.2)              
#>  tune          * 0.1.0       2020-04-02 [1] CRAN (R 3.6.3)              
#>  utf8            1.1.4       2018-05-24 [1] CRAN (R 3.6.2)              
#>  vctrs           0.2.99.9010 2020-04-03 [1] Github (r-lib/vctrs@5c69793)
#>  withr           2.1.2       2018-03-15 [1] CRAN (R 3.6.2)              
#>  workflows     * 0.1.1       2020-03-17 [1] CRAN (R 3.6.3)              
#>  xfun            0.12        2020-01-13 [1] CRAN (R 3.6.2)              
#>  xtable          1.8-4       2019-04-21 [1] CRAN (R 3.6.2)              
#>  xts             0.12-0      2020-01-19 [1] CRAN (R 3.6.2)              
#>  yaml            2.2.1       2020-02-01 [1] CRAN (R 3.6.2)              
#>  yardstick     * 0.0.6       2020-03-17 [1] CRAN (R 3.6.3)              
#>  zoo             1.8-7       2020-01-10 [1] CRAN (R 3.6.2)              
#> 
#> [1] C:/Users/hp/Documents/R/win-library/3.6
#> [2] C:/Program Files/R/R-3.6.2/library

Updates

This blog post was modified on 2020-04-02. In the previous version, cross-validation in the section Evaluating Model Performance above was done using a recipe that was prepped on the whole training data. This could potentially lead to data leakage, as pointed out by Aaron R. Williams. Thus, this was changed and the new version uses a recipe that is prepped separately for each fold.
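To make the difference concrete, here is a minimal sketch of prepping a recipe separately for each fold. It is not the code used earlier in this post. It assumes a random 2,000-row subset of diamonds, only carat, depth, and table as predictors, step_normalize() as the preprocessing step, and a plain lm() fit; the names dia, rec, and fold_rmse are hypothetical. The key point is that prep() is called on each fold's analysis set, so the assessment rows never influence the estimated means and standard deviations. As far as I know, tune::fit_resamples() does the equivalent automatically when it is given a workflow that contains the recipe.

```r
library(tidymodels)  # attaches rsample, recipes, purrr, dplyr, ggplot2, ...

set.seed(2020)
# Hypothetical small random subset of diamonds so the sketch runs quickly
dia <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 2000), ]
folds <- vfold_cv(dia, v = 5)

# Unprepped recipe: the normalization parameters are not estimated yet
rec <- recipe(price ~ carat + depth + table, data = dia) %>%
  step_normalize(all_predictors())

fold_rmse <- map_dbl(folds$splits, function(split) {
  train <- analysis(split)
  test  <- assessment(split)
  # Estimate means/SDs on this fold's analysis set only, so the
  # assessment rows cannot leak into the preprocessing
  rec_fold <- prep(rec, training = train)
  fit  <- lm(price ~ ., data = bake(rec_fold, new_data = train))
  pred <- predict(fit, newdata = bake(rec_fold, new_data = test))
  sqrt(mean((test$price - pred)^2))  # RMSE on the assessment set
})

fold_rmse
mean(fold_rmse)
```

The leaky variant would call prep(rec, training = dia) once, before splitting into folds, so every fold's assessment rows would already have contributed to the normalization.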


  1. For a list of models available via parsnip, see https://tidymodels.github.io/parsnip/articles/articles/Models.html.↩︎

  2. Predicting with a recipe (or a workflow containing one) that modifies the outcome requires skipping the respective recipe step at bake() time via step_log(skip = TRUE). While this is a good idea for predicting on the test data, it is a bad idea for predicting on the assessment data during cross-validation (see also https://github.com/tidymodels/workflows/issues/31).↩︎
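The effect of skip = TRUE can be seen in a minimal sketch. It assumes a recent recipes version (for bake(new_data = NULL)), a random 2,000-row subset of diamonds, carat, depth, and table as the only predictors, and a plain lm() fit; the object names are hypothetical. The log step is applied while the recipe is prepped, so the retained training data has log(price). When new data such as the test set is baked, the step is skipped and price stays in dollars, so the model's log-scale predictions must be exponentiated before they are compared with it.

```r
library(tidymodels)  # attaches rsample, recipes, yardstick, ...

set.seed(2020)
# Hypothetical small random subset of diamonds
dia <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 2000), ]
split <- initial_split(dia, prop = 0.8)
dia_train <- training(split)
dia_test  <- testing(split)

# Log-transform the outcome; skip = TRUE means the step is applied during
# prep() but skipped when bake() is called on new data
rec <- recipe(price ~ carat + depth + table, data = dia_train) %>%
  step_log(all_outcomes(), skip = TRUE) %>%
  step_normalize(all_predictors())

rec_prepped <- prep(rec, training = dia_train)

# new_data = NULL returns the processed training data retained by prep(),
# so price is on the log scale here
train_baked <- bake(rec_prepped, new_data = NULL)
# For new data the skipped step is not applied, so price stays in dollars
test_baked  <- bake(rec_prepped, new_data = dia_test)

fit <- lm(price ~ ., data = train_baked)
pred_price <- exp(predict(fit, newdata = test_baked))  # back to dollars
rmse_vec(truth = test_baked$price, estimate = pred_price)
```

During cross-validation, the same skipping would leave the assessment outcome untransformed while the model predicts on the log scale, which is why the footnote calls it a bad idea there.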

