Created
February 16, 2026 13:21
-
-
Save jpasquier/3a9729408e653e999fc89e69b69827ee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# library() fails immediately if Rcpp is missing; require() would only warn
# and let sourceCpp() fail later with a less helpful message.
library(Rcpp)

# Reject region
#
# fisher_reject_matrix(n1, n2, alpha) returns an (n1 + 1) x (n2 + 1) logical
# matrix. Entry [x1 + 1, x2 + 1] is TRUE when the outcome (x1 successes out of
# n1, x2 successes out of n2) is flagged as rejected at level alpha: conditional
# on the total m = x1 + x2, the outcomes are ordered by hypergeometric
# probability and the least probable ones are flagged as long as their
# cumulative probability stays <= alpha.
#
# NOTE: the C++ code lives inside a single-quoted R string, so it must not
# contain single quotes (apostrophes) or backslashes.
sourceCpp(code = '
#include <Rcpp.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

using namespace Rcpp;

// One conditional outcome: the count x1 in group 1 and its hypergeometric
// probability p given the total number of successes.
struct ProbPair { int x1; double p; };

// Strict weak ordering: least probable outcome first.
bool compareProb(const ProbPair& a, const ProbPair& b) {
  return a.p < b.p;
}

// Builds the rejection region for group sizes n1 and n2 at level alpha.
// Rows index x1 = 0..n1, columns index x2 = 0..n2.
// Raises an R error when a size is negative (this also catches NA_integer_,
// which is the most negative int) or when n1 + n2 does not fit in an int.
// [[Rcpp::export]]
LogicalMatrix fisher_reject_matrix(int n1, int n2, double alpha) {
  if (n1 < 0 || n2 < 0) {
    Rcpp::stop("n1 and n2 must be non-negative integers.");
  }
  if (n1 > std::numeric_limits<int>::max() - n2 - 1) {
    Rcpp::stop("n1 + n2 is too large.");
  }

  const int N = n1 + n2;
  LogicalMatrix reject_mat(n1 + 1, n2 + 1);  // starts out all FALSE
  const double tol = std::sqrt(std::numeric_limits<double>::epsilon());

  // Reused across totals; clear() keeps the capacity.
  std::vector<ProbPair> candidates;
  candidates.reserve(std::min(n1, n2) + 1);

  for (int m = 0; m <= N; ++m) {
    // Support of x1 given the total m.
    const int low = std::max(0, m - n2);
    const int high = std::min(n1, m);
    if (low > high) continue;

    candidates.clear();
    for (int x1 = low; x1 <= high; ++x1) {
      const double p = R::dhyper(x1, m, N - m, n1, 0);
      candidates.push_back({x1, p});
    }
    std::sort(candidates.begin(), candidates.end(), compareProb);

    // Walk from the least probable outcome upwards. Outcomes whose
    // probabilities agree up to a relative tolerance form one tie group and
    // are accepted or refused together. n_reject is the number of leading
    // sorted outcomes that belong to the rejection region.
    double current_sum = 0.0;
    std::size_t n_reject = 0;
    for (std::size_t i = 0; i < candidates.size(); ++i) {
      current_sum += candidates[i].p;
      const bool tied_with_next =
          (i + 1 < candidates.size()) &&
          (std::abs(candidates[i + 1].p - candidates[i].p) <
           candidates[i].p * tol);
      if (tied_with_next) continue;
      // Written as "<=" so that a NaN alpha rejects nothing.
      if (current_sum <= alpha + tol) {
        n_reject = i + 1;
      } else {
        break;
      }
    }

    // Flag by sorted position instead of re-thresholding the probabilities
    // with an absolute tolerance: an absolute slack could also flag an
    // outcome just above the cutoff that was never part of the accepted sum.
    for (std::size_t i = 0; i < n_reject; ++i) {
      reject_mat(candidates[i].x1, m - candidates[i].x1) = TRUE;
    }
  }
  return reject_mat;
}
')
# Compute power of Fisher's exact test (rejection probability given p1, p2, n1,
# n2, alpha)
#
# Arguments:
#   p1, p2  success probabilities in group 1 and group 2 (single values in
#           [0, 1])
#   n1, n2  group sizes (single non-negative whole numbers)
#   alpha   significance level (single value strictly between 0 and 1)
# Value:
#   The probability, under independent Binomial(n1, p1) and Binomial(n2, p2)
#   counts, that the observed table falls in the rejection region returned by
#   fisher_reject_matrix(n1, n2, alpha).
fisher_power <- function(p1, p2, n1, n2, alpha = .05) {
  # Sanity checks. Sizes must be whole numbers: the C++ side converts them to
  # int while dbinom() sees the raw value, so a fractional size would silently
  # produce NaN instead of an error.
  is_prob <- function(x) {
    is.numeric(x) && length(x) == 1L && !is.na(x) && x >= 0 && x <= 1
  }
  is_size <- function(x) {
    is.numeric(x) && length(x) == 1L && is.finite(x) && x >= 0 &&
      x == round(x) && x <= .Machine$integer.max
  }
  stopifnot(is_prob(p1), is_prob(p2), is_size(n1), is_size(n2),
            is_prob(alpha), alpha > 0, alpha < 1)
  # Rejection matrix (rows x1 = 0..n1, cols x2 = 0..n2)
  reject <- fisher_reject_matrix(as.integer(n1), as.integer(n2), alpha)
  # Joint pmf under H1 (matrix with rows x1 = 0..n1, cols x2 = 0..n2)
  # Use dbinom for stability
  p_x1 <- dbinom(0:n1, n1, p1)
  p_x2 <- dbinom(0:n2, n2, p2)
  joint_pmf <- outer(p_x1, p_x2, `*`)
  # Power = sum over rejection region of joint pmf under H1
  sum(joint_pmf[reject])
}
# Sample size (balanced) given the probabilities and the power
#
# Arguments:
#   p1, p2        success probabilities in the two groups
#   power         requested power
#   alpha         significance level, used both for the initial bounds and for
#                 every power evaluation
#   n_min, n_max  optional search bounds for the per-group sample size; when
#                 missing they are set to 80% / 125% of the power.prop.test
#                 sample size
# Value:
#   A per-group sample size n in [n_min, n_max] whose power reaches the
#   requested power, found by bracketing between n_min and n_max. A warning is
#   issued and the bound itself is returned when the requested power is
#   already reached at n_min or not reached at n_max.
# NOTE: the search assumes the power increases with n between the bounds; it
# returns the upper end of the final bracket [n, n + 1].
fisher_n <- function(p1, p2, power, alpha = .05, n_min = NULL, n_max = NULL) {
  # If no bounds are provided, estimate them using power.prop.test
  if (is.null(n_min) || is.null(n_max)) {
    n_prop_test <- power.prop.test(p1 = p1, p2 = p2, sig.level = alpha,
                                   power = power)$n
    if (is.null(n_min)) n_min <- floor(0.8 * n_prop_test)
    if (is.null(n_max)) n_max <- ceiling(1.25 * n_prop_test)
  }
  # alpha must be forwarded explicitly, otherwise fisher_power() falls back
  # to its own default level.
  power_min <- fisher_power(p1, p2, n_min, n_min, alpha)
  if (power_min >= power) {
    warning("Requested power is already reached at n_min; ",
            "a smaller sample size may suffice.")
    return(n_min)
  }
  power_max <- fisher_power(p1, p2, n_max, n_max, alpha)
  if (power_max < power) {
    warning("Requested power cannot be reached.")
    return(n_max)
  }
  # Invariant: power(n_min) < power <= power(n_max)
  while (n_max - n_min > 1) {
    # Try linear interpolation, but avoid division by very small numbers
    if (abs(power_max - power_min) > 1e-10) {
      n_interp <- n_min + (power - power_min) * (n_max - n_min) /
        (power_max - power_min)
      n_new <- round(n_interp)
      # Keep the candidate strictly inside the bracket so it always shrinks
      n_new <- max(n_min + 1, min(n_max - 1, n_new))
    } else {
      # Fallback to bisection if powers are too similar
      n_new <- round((n_min + n_max) / 2)
    }
    power_new <- fisher_power(p1, p2, n_new, n_new, alpha)
    if (power_new >= power) {
      n_max <- n_new
      power_max <- power_new
    } else {
      n_min <- n_new
      power_min <- power_new
    }
  }
  return(n_max)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment