Created
September 25, 2023 00:58
-
-
Save benjaminrich/a0b5b1e6cbd269678cd5e90a90268aa6 to your computer and use it in GitHub Desktop.
Stepwise regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# S3 generics ----
# Every generic below dispatches on the class of its first argument.

#' Forward stepwise selection starting from `base_fit`.
stepwise_forward <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("stepwise_forward")
}

#' Backward stepwise elimination starting from `base_fit`.
stepwise_backward <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("stepwise_backward")
}

#' A single forward step (try adding each candidate once).
forward_step <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("forward_step")
}

#' A single backward step (try removing each candidate once).
backward_step <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("backward_step")
}

#' Fit one model per formula in `all_formulas`.
fit_all_models <- function(base_fit, all_formulas, ...) {
  UseMethod("fit_all_models")
}

#' Build a comparison table from a set of fits.
model_table <- function(obj, ...) {
  UseMethod("model_table")
}

#' Extract the p-value(s) stored on an object.
pvalue <- function(obj, ...) {
  UseMethod("pvalue")
}

#' Extract the selected term(s) stored on an object.
selected <- function(obj, ...) {
  UseMethod("selected")
}

#' Extract the final model stored on an object.
final_model <- function(obj, ...) {
  UseMethod("final_model")
}
# Default accessors ----
# Each accessor reads the attribute of the same name from `obj`, using exact
# (non-partial) attribute-name matching. A missing attribute yields NULL.

#' Read the "pvalue" attribute of `obj`.
pvalue.default <- function(obj, ...) {
  attr(obj, which = "pvalue", exact = TRUE)
}

#' Read the "selected" attribute of `obj`.
selected.default <- function(obj, ...) {
  attr(obj, which = "selected", exact = TRUE)
}

#' Read the "final_model" attribute of `obj`.
final_model.default <- function(obj, ...) {
  attr(obj, which = "final_model", exact = TRUE)
}
#' Repeatedly apply a single-step function until nothing more is selected.
#'
#' @param base_fit Starting model.
#' @param candidates Candidate terms; shrinks by the selected term each round.
#' @param alpha Threshold passed through to `step_fn`.
#' @param step_fn Function called as `step_fn(fit, candidates, alpha)`; its
#'   result must work with `selected()` and `final_model()`.
#' @return A one-element list holding the list of step results, with class
#'   "stepwise_search" and attributes `selected` (all selected terms, in
#'   order) and `final_model` (the last accepted model).
stepwise_search <- function(base_fit, candidates, alpha, step_fn) {
  current_fit <- base_fit
  remaining <- candidates
  steps <- list()

  repeat {
    # Nothing left to try.
    if (length(remaining) == 0) {
      break
    }

    this_step <- step_fn(current_fit, remaining, alpha)
    steps <- append(steps, list(this_step))

    # A step that selects nothing ends the search; it is still recorded.
    if (is.null(selected(this_step))) {
      break
    }

    current_fit <- final_model(this_step)
    remaining <- setdiff(remaining, selected(this_step))
  }

  out <- list(steps)
  class(out) <- "stepwise_search"
  attr(out, "selected") <- unlist(lapply(steps, selected))
  attr(out, "final_model") <- current_fit
  out
}
#' Default method: forward stepwise selection.
#'
#' Runs `stepwise_search()` with `forward_step` as the step function.
#'
#' @param base_fit Starting model.
#' @param candidates Terms that may be added.
#' @param alpha Threshold passed to each forward step.
#' @param ... Accepted so the method signature matches the generic
#'   `stepwise_forward(base_fit, candidates, alpha, ...)`; currently unused.
#' @return The search result with its single element named "forward", class
#'   "stepwise_forward", and attribute `selected` set to
#'   `list(forward = <selected terms>)`. Other attributes set by
#'   `stepwise_search()` (e.g. `final_model`) are carried over.
stepwise_forward.default <- function(base_fit, candidates, alpha = 0.05, ...) {
  res <- stepwise_search(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    step_fn = forward_step
  )
  structure(
    setNames(res, "forward"),
    class = "stepwise_forward",
    selected = list(forward = selected(res))
  )
}
#' Default method: backward stepwise elimination.
#'
#' Runs `stepwise_search()` with `backward_step` as the step function.
#'
#' @param base_fit Starting model.
#' @param candidates Terms that may be removed.
#' @param alpha Threshold passed to each backward step.
#' @param ... Accepted so the method signature matches the generic
#'   `stepwise_backward(base_fit, candidates, alpha, ...)`; currently unused.
#' @return The search result with its single element named "backward", class
#'   "stepwise_backward", and attribute `selected` set to
#'   `list(backward = <removed terms>)`. Other attributes set by
#'   `stepwise_search()` (e.g. `final_model`) are carried over.
stepwise_backward.default <- function(base_fit, candidates, alpha = 0.01, ...) {
  res <- stepwise_search(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    step_fn = backward_step
  )
  structure(
    setNames(res, "backward"),
    class = "stepwise_backward",
    selected = list(backward = selected(res))
  )
}
#' Backward elimination applied to the result of a forward search.
#'
#' @param base_fit A "stepwise_forward" object.
#' @param candidates Terms that may be removed. When omitted, the terms that
#'   the forward search selected are used.
#' @param alpha Threshold passed to the backward search.
#' @param ... Accepted so the method signature matches the generic
#'   `stepwise_backward(base_fit, candidates, alpha, ...)`; currently unused.
#' @return The forward and backward results concatenated, with class
#'   "stepwise_forward_backward", attribute `selected` combining both
#'   selections, and attribute `final_model` taken from the backward search.
stepwise_backward.stepwise_forward <- function(base_fit, candidates, alpha = 0.01, ...) {
  if (missing(candidates)) {
    # Spelled-out FALSE: `F` is an ordinary variable that can be reassigned.
    candidates <- unlist(selected(base_fit), use.names = FALSE)
  }
  # Start the backward pass from the model the forward pass ended on.
  back_fit <- stepwise_backward(final_model(base_fit), candidates, alpha)
  structure(
    c(base_fit, back_fit),
    class = "stepwise_forward_backward",
    selected = c(selected(base_fit), selected(back_fit)),
    final_model = final_model(back_fit)
  )
}
#' One selection step shared by the forward and backward directions.
#'
#' Builds one formula per candidate (added for "forward", removed for
#' "backward"), fits them all, tabulates them against `base_fit`, and picks
#' the candidate with the smallest (forward) or largest (backward) p-value.
#' That candidate is accepted only if `op(p_value, alpha)` is true.
#'
#' @param base_fit Current model.
#' @param candidates Terms to add or remove.
#' @param alpha Threshold compared against the chosen p-value.
#' @param direction "forward" or "backward".
#' @param op Comparison function; defaults to `<` for forward and `>=` for
#'   backward. The default is evaluated lazily, after `direction` has been
#'   resolved by `match.arg()`.
#' @param ... Passed on to `fit_all_models()`.
#' @return The model table, with extra attributes `base_fit`, `all_fits`,
#'   `selected` (name of the accepted candidate, or absent when none was
#'   accepted) and `final_model` (the accepted fit, or `base_fit`).
generic_step <- function(
  base_fit,
  candidates,
  alpha,
  direction = c("forward", "backward"),
  op = if (direction=="forward") `<` else `>=`,
  ...
) {
  direction <- match.arg(direction)
  is_forward <- direction == "forward"

  all_formulas <- derive_all_formulas(
    base_fit,
    base_formula = formula(base_fit),
    add = if (is_forward) candidates else NULL,
    subtract = if (is_forward) NULL else candidates
  )
  all_fits <- fit_all_models(base_fit, all_formulas, data = base_fit$data, ...)
  mtab <- model_table(all_fits, base_fit, direction = direction, sort = TRUE)
  pval <- pvalue(mtab)

  # which.min()/which.max() skip NA and return integer(0) when no p-value is
  # usable; in that case nothing can be selected.
  i <- if (is_forward) which.min(pval) else which.max(pval)
  accept <- length(i) == 1L && op(pval[i], alpha)

  if (accept) {
    chosen <- names(pval)[i]
    chosen_fit <- all_fits[[chosen]]
  } else {
    chosen <- NULL
    chosen_fit <- base_fit
  }

  structure(
    mtab,
    base_fit = base_fit,
    all_fits = all_fits,
    selected = chosen,
    final_model = chosen_fit
  )
}
#' Default method: a single forward step.
#'
#' Delegates to `generic_step()` with `direction = "forward"` and prepends
#' "forward_step" to the class of the result.
#'
#' @param base_fit Current model.
#' @param candidates Terms that may be added.
#' @param alpha Threshold passed to `generic_step()`.
#' @param ... Passed on to `generic_step()`.
forward_step.default <- function(base_fit, candidates, alpha=0.05, ...) {
  step <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "forward",
    ...
  )
  class(step) <- c("forward_step", class(step))
  step
}
#' Default method: a single backward step.
#'
#' Delegates to `generic_step()` with `direction = "backward"` and prepends
#' "backward_step" to the class of the result.
#'
#' @param base_fit Current model.
#' @param candidates Terms that may be removed.
#' @param alpha Threshold passed to `generic_step()`.
#' @param ... Passed on to `generic_step()`.
backward_step.default <- function(base_fit, candidates, alpha=0.01, ...) {
  step <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "backward",
    ...
  )
  class(step) <- c("backward_step", class(step))
  step
}
#' Label each argument by its names, falling back to its values.
#'
#' @param ... Any number of vectors (or NULL).
#' @return A list with one element per argument: `names(x)` when `x` has
#'   names, otherwise `as.character(x)`.
get_names <- function(...) {
  label_of <- function(x) {
    nms <- names(x)
    if (is.null(nms)) as.character(x) else nms
  }
  lapply(list(...), label_of)
}
#' Build one updated formula per term to add or subtract.
#'
#' Each term becomes an update spec of the form ".~.+term" or ".~.-term",
#' which is applied to `base_formula` with `update.formula()`.
#'
#' @param base_fit Model; only used for the default of `base_formula`.
#' @param base_formula Formula to update.
#' @param add Terms to add (character), or NULL.
#' @param subtract Terms to subtract (character), or NULL.
#' @param formula_names Names for the returned list.
#' @return A list of formulas named by `formula_names`, with the additions
#'   first and the subtractions after.
derive_all_formulas <- function(
  base_fit,
  base_formula = formula(base_fit),
  add = NULL,
  subtract = NULL,
  formula_names = unlist(get_names(add, subtract))
) {
  plus_terms <- if (is.null(add)) NULL else paste0("+", add)
  minus_terms <- if (is.null(subtract)) NULL else paste0("-", subtract)
  specs <- paste0(".~.", c(plus_terms, minus_terms))
  formulas <- lapply(specs, function(spec) {
    update.formula(old = base_formula, new = spec)
  })
  setNames(formulas, formula_names)
}
#' Default method: refit `base_fit` once per formula via `update()`.
#'
#' @param base_fit Model to update.
#' @param all_formulas List of formulas (or update specs) to fit.
#' @param data Data passed to `update()`; defaults to `base_fit$data`.
#' @param model_names Names for the returned list.
#' @param ... Passed on to `update()`.
#' @return A list of fitted models named by `model_names`.
fit_all_models.default <- function(
  base_fit,
  all_formulas,
  data = base_fit$data,
  model_names = names(all_formulas),
  ...
) {
  refit <- function(new_formula) {
    update(base_fit, formula. = new_formula, data = data, ...)
  }
  fits <- lapply(all_formulas, refit)
  setNames(fits, model_names)
}
#' Default method: likelihood-ratio comparison of each fit against a base fit.
#'
#' For every model in `all_fits` the table holds -2*log-likelihood and degrees
#' of freedom (both from `logLik()`), the same quantities for `base_fit`,
#' their differences, and a chi-squared upper-tail p-value. For "forward" the
#' differences are (base - model) for -2*loglik and (model - base) for df;
#' for "backward" both signs are flipped.
#'
#' @param all_fits Named list of fitted models.
#' @param base_fit Reference model.
#' @param alpha Not used in this method.
#' @param direction "forward" or "backward"; resolved with `match.arg()`, so
#'   omitting it means "forward".
#' @param sort If TRUE, rows are ordered by p-value.
#' @param decreasing Sort order used when `sort` is TRUE.
#' @param ... Not used in this method.
#' @return A data frame with class "model_table" prepended, and attributes
#'   `all_fits`, `base_fit` and `pvalue` (p-values named by model, in row
#'   order).
model_table.default <- function(
  all_fits,
  base_fit,
  alpha,
  direction = c("forward", "backward"),
  sort = TRUE,
  decreasing = FALSE,
  ...
) {
  # Without match.arg() the two-element default would be compared against
  # "forward" elementwise and every statistic would come out with length 2.
  direction <- match.arg(direction)
  f <- function(x) if (direction == "forward") x else -x

  # The base model is the same for every row, so evaluate it once.
  base_ll <- logLik(base_fit)
  base_m2ll <- -2 * as.numeric(base_ll)
  base_df <- attr(base_ll, "df", exact = TRUE)

  rows <- lapply(all_fits, function(x) {
    ll <- logLik(x)
    m2ll <- -2 * as.numeric(ll)
    df <- attr(ll, "df", exact = TRUE)
    delta_m2ll <- f(base_m2ll - m2ll)
    delta_df <- f(df - base_df)
    data.frame(
      check.names = FALSE,
      `Model` = NA,
      `-2*loglik` = m2ll,
      `df` = df,
      `Base(-2*loglik)` = base_m2ll,
      `Base(df)` = base_df,
      `Δ(-2*loglik)` = delta_m2ll,
      `Δdf` = delta_df,
      `P-value` = pchisq(delta_m2ll, delta_df, lower.tail = FALSE)
    )
  })
  mtab <- do.call(what = rbind, rows)
  mtab$`Model` <- names(all_fits)

  if (sort) {
    mtab <- mtab[order(mtab$`P-value`, decreasing = decreasing), ]
  }

  structure(
    mtab,
    class = c("model_table", class(mtab)),
    all_fits = all_fits,
    base_fit = base_fit,
    pvalue = setNames(mtab$`P-value`, mtab$`Model`)
  )
}
# Worked example ----
# Wrapped in `if (FALSE)` so nothing here runs when the file is sourced;
# intended to be stepped through interactively. Needs the mvtnorm package.
if (FALSE) {

  library(mvtnorm)

  # Simulated data: four correlated predictors (x1-x4) that enter the
  # response, plus five independent noise predictors (x5-x9).
  set.seed(123)
  n <- 100
  p <- 4
  S <- rWishart(1, p, toeplitz(1 / (1:p)))[, , 1]
  x <- rmvnorm(n, rep(0, p), S)

  dat <- data.frame(
    x1 = x[, 1],
    x2 = x[, 2],
    x3 = x[, 3],
    x4 = x[, 4],
    x5 = rnorm(n),
    x6 = rnorm(n),
    x7 = rnorm(n),
    x8 = rnorm(n),
    x9 = rnorm(n)
  )
  dat$y <- 6 + 0.3 * dat$x1 + 0.1 * dat$x2 + 0.4 * dat$x3 + 0.2 * dat$x4 +
    rnorm(n, 0, 1.3)

  base_fit <- glm(y ~ x1, data = dat)

  candidates <- paste0("x", 2:9)

  # Single forward step.
  x <- forward_step(base_fit, candidates, alpha = 0.05)
  x
  final_model(x)

  # Single backward step.
  x <- backward_step(base_fit, "x1", alpha = 0.01)
  x
  final_model(x)

  # Full forward search.
  x <- stepwise_forward(base_fit, candidates, alpha = 0.05)
  x
  selected(x)
  final_model(x)

  # Full backward search from the base model.
  y <- stepwise_backward(base_fit, c("x3"), alpha = 0.01)
  y
  selected(y)
  final_model(y)

  # Backward search on top of the forward result.
  y <- stepwise_backward(x, alpha = 0.01)
  y
  selected(y)
  final_model(y)

  # Forward then backward, chained with the native pipe.
  x <- base_fit |>
    stepwise_forward(candidates, alpha = 0.05) |>
    stepwise_backward(alpha = 0.01)
  x
  selected(x)
  final_model(x)

}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment