Created
June 3, 2025 17:42
-
-
Save vpnagraj/59fa609c5adf47c8c7a5b156eb261be7 to your computer and use it in GitHub Desktop.
Demonstration of using "workflow sets" with tidymodels in R
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ## script to demonstrate running sets of workflows | |
| ## adapted from the workflowsets package vignette | |
| ## https://workflowsets.tidymodels.org/articles/evaluating-different-predictor-sets.html | |
| ## load packages | |
| ## NOTE: both tidymodels and tidyverse are "meta" packages ... | |
| ## ... so they will load lots of other packages under the hood | |
| library(tidymodels) | |
| library(tidyverse) | |
| ## use data from the modeldata package (from tidymodels) to demo | |
| ?modeldata::ischemic_stroke | |
| ## take a peek at the data | |
| glimpse(ischemic_stroke) | |
| ## keep only the outcome (stroke) plus the contiguous run of predictor columns ... | |
| ## ... from age through hypertension_history | |
| stroke_dat <- select(ischemic_stroke, stroke, age:hypertension_history) | |
| ## model specification: logistic regression fit with the "glm" engine | |
| logreg_model <- set_engine(logistic_reg(), "glm") | |
| ## set random seed for reproducibility | |
| ## NOTE: the seed is set before the split and the folds are created, so keep this order | |
| set.seed(123) | |
| ## create split for training and testing (using the default arguments of initial_split) | |
| to_split <- initial_split(stroke_dat) | |
| ## resample the training portion of the split with cross-validation | |
| folds <- vfold_cv(training(to_split)) | |
| ## what does the cross validation object look like? | |
| glimpse(folds) | |
| ## inspect the splits column of the resampling object | |
| folds$splits | |
| ## create list of formulas for models that iteratively leave out each variable | |
| ## NOTE: including the full_model = TRUE for demonstration purposes (default is TRUE) ... | |
| ## ... this argument will toggle whether the original formula will be included | |
| formulas <- leave_var_out_formulas(stroke ~ ., data = stroke_dat, full_model = TRUE) | |
| ## what does this formulas object look like? | |
| glimpse(formulas) | |
| ## create a "set" of multiple workflows | |
| ## a workflow set can combine several preprocessors with several model specifications | |
| ## here the preprocessors are the candidate formulas and there is a single logistic model ... | |
| ## ... so the set lets us evaluate every model formula | |
| stroke_workflows <- workflow_set(preproc = formulas, models = list(logistic = logreg_model)) | |
| ## what comes out? | |
| stroke_workflows | |
| ## fit every workflow in the set across the cross-validation folds | |
| stroke_fits <- workflow_map(stroke_workflows, fn = "fit_resamples", resamples = folds) | |
| ## extract accuracy values | |
| ## keep the per-resample (un-summarized) estimates rather than their mean ... | |
| ## ... so there is a distribution from which we can get error | |
| acc_values <- collect_metrics(stroke_fits, summarize = FALSE) | |
| acc_values <- filter(acc_values, .metric == "accuracy") | |
| ## strip the model suffix from the workflow id | |
| acc_values <- mutate(acc_values, wflow_id = gsub("_logistic", "", wflow_id)) | |
| ## pull out the full model ("everything") estimates for comparison below | |
| ## keep the resample id so the estimates can be matched up by id later | |
| full_model <- | |
|   select( | |
|     filter(acc_values, wflow_id == "everything"), | |
|     full_model = .estimate, | |
|     id | |
|   ) | |
| ## take the accuracy values for all the other (leave-one-out) model formulations | |
| ## join them to the full model accuracy values by resample id | |
| ## and compute the drop in performance compared to the full model | |
| differences <- | |
|   mutate( | |
|     full_join( | |
|       filter(acc_values, wflow_id != "everything"), | |
|       full_model, | |
|       by = "id" | |
|     ), | |
|     performance_drop = full_model - .estimate | |
|   ) | |
| ## lastly do some wrangling to make it easier to visualize | |
| ## per workflow: mean performance drop, its standard error, and a normal-approximation 95% interval | |
| summary_stats <- | |
|   differences %>% | |
|   group_by(wflow_id) %>% | |
|   summarize( | |
|     ## NOTE: std_err must come first; it needs the per-resample values ... | |
|     ## ... before performance_drop is overwritten with its mean on the next line | |
|     ## standard error of the mean is sd / sqrt(n), with n = number of non-missing values | |
|     std_err = sd(performance_drop, na.rm = TRUE) / sqrt(sum(!is.na(performance_drop))), | |
|     ## na.rm = TRUE keeps the mean consistent with the non-missing count used above | |
|     performance_drop = mean(performance_drop, na.rm = TRUE), | |
|     lower = performance_drop - qnorm(0.975) * std_err, | |
|     upper = performance_drop + qnorm(0.975) * std_err, | |
|     .groups = "drop" | |
|   ) %>% | |
|   ## order the workflow ids by mean drop so the plot is sorted | |
|   mutate( | |
|     wflow_id = factor(wflow_id), | |
|     wflow_id = reorder(wflow_id, performance_drop) | |
|   ) | |
| ## plot the mean drop for each left-out variable with horizontal error bars | |
| ggplot(data = summary_stats, mapping = aes(x = performance_drop, y = wflow_id)) + | |
|   geom_point(size = 2) + | |
|   geom_errorbar(mapping = aes(xmin = lower, xmax = upper), width = .25) + | |
|   labs(y = "") + | |
|   theme_minimal() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment