Last active: October 13, 2025 14:36
library(rextendr)
library(mirai)

rextendr::rust_source(
  file = "rust/xymatch_loop.rs",
  module = "xymatch_loop",
  profile = "release"
)

rextendr::rust_source(
  file = "rust/xymatch_binary.rs",
  module = "xymatch_binary",
  profile = "release"
)
rextendr::rust_source(
  file = "rust/xymatch_binary_par.rs",
  module = "xymatch_binary_par",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)

library(Rcpp)
old_cxx <- Sys.getenv("PKG_CXXFLAGS")
old_lib <- Sys.getenv("PKG_LIBS")
Sys.setenv(PKG_CXXFLAGS = "-fopenmp -O3")
Sys.setenv(PKG_LIBS = "-fopenmp")
Rcpp::sourceCpp("xymatch_rcpp.cpp")
Sys.setenv(PKG_CXXFLAGS = old_cxx)
Sys.setenv(PKG_LIBS = old_lib)

n_threads_test()
# [1] 20
xymatch_r_2010 <- function(x, y) {
  x[
    apply(
      outer(x, y, function(u, v) abs(u - v)),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}
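
# Worked example (illustration only, not part of the original gist): with
# x = c(0.27, 0.37, 0.57) and y = c(0.91, 0.20, 0.90), the nearest y to 0.27
# and to 0.37 is y[2] = 0.20, and the nearest y to 0.57 is y[3] = 0.90, so
# xymatch_r_2010(c(0.27, 0.37, 0.57), c(0.91, 0.20, 0.90))
# should return c(0.37, 0.37, 0.57): for each x[i], the element of x sitting
# at the index of the nearest y value.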
xymatch_r_2010_par <- function(x, y, nthreads = 1) {
  x[
    apply(
      outer_abs_diff_par(x, y, nthreads = nthreads),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}

xymatch_r_vapply <- function(x, y) {
  vapply(seq_along(x), \(i) x[which.min(abs(x[i] - y))], 1)
}
xymatch_r_loop <- function(x, y) {
  n <- length(x)
  idx <- integer(n)
  for (i in seq_len(n)) {
    idx[i] <- which.min(abs(x[i] - y))
  }
  x[idx]
}
# Assumes length(x) == length(y), as in the benchmarks below
xymatch_r_binary <- function(x, y) {
  order <- order(y)
  y <- y[order]
  out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
  n <- length(y)
  idx <- integer(n) + n
  idx[out == 0] <- 1L
  nx <- (out != 0 & (out != n))
  one <- abs(y[out[nx]] - x[nx])
  two <- abs(y[out[nx] + 1] - x[nx])
  idx[nx] <- out[nx] + (one > two)
  x[order][idx]
}
# Corrected wrapper function for parallel execution
xymatch_r_binary_par <- function(x, y, nchunks = 1) {
  if (nchunks < 1) stop("nchunks must be >= 1")
  order <- order(y)
  y <- y[order]
  mirai::daemons(nchunks)
  mirai::everywhere({}, y_val = y)
  if (nchunks == 1) {
    groups <- rep(1L, length(x))
  } else {
    groups <- cut(
      seq_along(x),
      breaks = nchunks,
      labels = FALSE,
      include.lowest = TRUE
    )
  }
  x_chunks <- split(x, groups)
  idx_chunks_mirai <- mirai_map(
    x_chunks,
    .f = function(chunk) {
      out <- findInterval(chunk, y_val, checkSorted = TRUE, checkNA = FALSE)
      n_y <- length(y_val)
      n_chunk <- length(chunk)
      idx <- integer(n_chunk) + n_y
      idx[out == 0] <- 1L
      nx <- (out != 0 & (out != n_y))
      if (any(nx)) {
        one <- abs(y_val[out[nx]] - chunk[nx])
        two <- abs(y_val[out[nx] + 1] - chunk[nx])
        idx[nx] <- out[nx] + (one > two)
      }
      return(idx)
    }
  )
  idx <- unlist(idx_chunks_mirai[.flat])
  x[order][idx]
}
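
# Note (not in the original script): each call to xymatch_r_binary_par() starts
# a fresh set of daemons with mirai::daemons(nchunks), so daemon startup likely
# dominates its timings below; mirai::daemons(0) shuts the daemons down again
# once benchmarking is finished.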
N <- 1e4
set.seed(1)
x <- runif(N)
y <- runif(N)

bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  r_2010_par = xymatch_r_2010_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_vapply = xymatch_r_vapply(x, y),
  r_loop = xymatch_r_loop(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_loop = xymatch_cpp_loop(x, y),
  cpp_loop_par = xymatch_cpp_loop_par(x, y, nthreads = 8),
  cpp_ranges = xymatch_cpp_ranges(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_loop = xymatch_rust_loop(x, y),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 15 × 4
#    expression      median total_time n_itr
#    <bch:expr>       <dbl>   <bch:tm> <int>
#  1 r_2010         3403.        10.2s     4
#  2 r_2010_par     2594.        11.7s     6
#  3 r_binary          1.97        10s  6661
#  4 r_vapply        256.          10s    47
#  5 r_loop          241.        10.1s    52
#  6 r_binary_par   5021.        11.4s     3
#  7 cpp_loop        140.          10s    95
#  8 cpp_loop_par     18.4         10s   719
#  9 cpp_ranges       72.1         10s   184
# 10 cpp_range_par    20.0         10s   641
# 11 cpp_binary        1.88        10s  7013
# 12 cpp_binary_par    1          7.8s 10000
# 13 rust_loop       140.        10.1s    95
# 14 rust_binary       1.59        10s  8327
# 15 rust_binary_par   1.68        10s  4851

N <- 21250
set.seed(1)
x <- runif(N)
y <- runif(N)

bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 8 × 4
#   expression      median total_time n_itr
#   <bch:expr>       <dbl>   <bch:tm> <int>
# 1 r_2010         6396.        10.6s     1
# 2 cpp_range_par    28.6          10s   205
# 3 r_binary          2.11         10s  2736
# 4 r_binary_par   3366.        11.2s     2
# 5 cpp_binary        1.92         10s  3124
# 6 cpp_binary_par    1            10s  6000
# 7 rust_binary       1.69         10s  3573
# 8 rust_binary_par   5.32         10s  1324
rust/xymatch_binary.rs:
// Author: @JosiahParry
use extendr_api::prelude::*;

// Binary-search version: for each x[i], find the index of the nearest y value
// and return the element of x at that index (same semantics as the R versions).
#[extendr]
fn xymatch_rust_binary(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    let mut res = Vec::with_capacity(x.len());
    for &xi in x {
        let idx = match y_indexed.binary_search_by(|probe| {
            probe.0.partial_cmp(&xi).unwrap()
        }) {
            Ok(i) => i, // Exact match
            Err(i) => {
                if i == 0 {
                    0
                } else if i == y_indexed.len() {
                    y_indexed.len() - 1
                } else {
                    let left_dist = (xi - y_indexed[i - 1].0).abs();
                    let right_dist = (y_indexed[i].0 - xi).abs();
                    if left_dist <= right_dist {
                        i - 1
                    } else {
                        i
                    }
                }
            }
        };
        let orig_idx = y_indexed[idx].1;
        res.push(x[orig_idx]);
    }
    Doubles::from_values(res)
}

extendr_module! {
    mod xymatch_binary;
    fn xymatch_rust_binary;
}
rust/xymatch_loop.rs:
use extendr_api::prelude::*;

// Brute-force version: for each x[i], scan all of y for the nearest value and
// return the element of x at that index.
#[extendr]
fn xymatch_rust_loop(x: Vec<f64>, y: Vec<f64>) -> Vec<f64> {
    let n = x.len();
    let m = y.len();
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let xi = x[i];
        let mut min_index = 0;
        let mut min_dist = (xi - y[0]).abs();
        for j in 1..m {
            let dist = (xi - y[j]).abs();
            if dist < min_dist {
                min_index = j;
                min_dist = dist;
            }
        }
        result.push(x[min_index]);
    }
    result
}

extendr_module! {
    mod xymatch_loop;
    fn xymatch_rust_loop;
}
xymatch_rcpp.cpp:
// $ clang++ -std=c++23 -O2 -Wall -fopenmp -I/usr/lib/llvm-22/include -I/home/jordi/vcpkg/installed/x64-linux/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/Rcpp/include -I/usr/share/R/include xymatch_rcpp.cpp -o xymatch_rcpp -L/usr/lib/llvm-22/lib -lomp -L/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib -lRInside -lR -Wl,-rpath,/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib && ./xymatch_rcpp
// > install.packages(c("Rcpp","RInside"))
// [[Rcpp::plugins(cpp23)]]
// [[Rcpp::depends(Rcpp)]]
// // [[Rcpp::depends(RInside)]]
// #include <RInside.h>
#include <Rcpp.h>
#include <iostream>
#include <ranges>
#include <vector>
#include <cmath>
#include <algorithm>
#ifdef _OPENMP
#include <omp.h> // clang++ --version / 22 => libomp-22-dev
#endif
#include <execution>

using namespace Rcpp;

// [[Rcpp::export]]
int n_threads_test() {
  int n = 0;
  #pragma omp parallel
  {
    #pragma omp single
    n = omp_get_num_threads();
  }
  std::cout << "OpenMP threads: " << n << "\n";
  return n;
}

// R code:
// xymatch <- function(x, y) {
//   x[
//     apply(
//       outer(x, y, function(u, v) abs(u - v)),
//       1,
//       function(d) which(d == min(d))[1]
//     )
//   ]
// }

double closest_x_range(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
  auto min_it = std::ranges::min_element(
    std::views::iota(size_t{0}, vy.size()),
    {},
    [&](size_t j){ return std::abs(vy[j] - xi); }
  );
  return vx[*min_it];
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_ranges(NumericVector x, NumericVector y) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  auto result_view = vx | std::views::transform([&](double xi){
    return closest_x_range(vx, vy, xi);
  });
  return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_range_par(NumericVector x, NumericVector y, int nthreads = 1) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  std::vector<double> res(vx.size());
  #pragma omp parallel for num_threads(nthreads)
  for (size_t i = 0; i < vx.size(); ++i) {
    res[i] = closest_x_range(vx, vy, vx[i]);
  }
  return NumericVector(res.begin(), res.end());
}

double closest_x_loop(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
  size_t min_j = 0;
  double min_diff = std::abs(vy[0] - xi);
  for (size_t j = 1; j < vy.size(); ++j) {
    double diff = std::abs(vy[j] - xi);
    if (diff < min_diff) {
      min_diff = diff;
      min_j = j;
    }
  }
  return vx[min_j];
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop(NumericVector x, NumericVector y) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  auto result_view = vx | std::views::transform([&](double xi){
    return closest_x_loop(vx, vy, xi);
  });
  return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop_par(NumericVector x, NumericVector y, int nthreads = 1) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  std::vector<double> res(vx.size());
  #pragma omp parallel for num_threads(nthreads)
  for (size_t i = 0; i < vx.size(); ++i) {
    res[i] = closest_x_loop(vx, vy, vx[i]);
  }
  return NumericVector(res.begin(), res.end());
}

// xymatch_r_binary <- function(x,y) {
//   order <- order(y)
//   y <- y[order]
//   # Let's assume no NA for speediness
//   out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
//   n <- length(y)
//   idx <- integer(n) + n
//   idx[out == 0] <- 1L
//   nx <- (out != 0 & (out != n))
//   one <- abs(y[out[nx]] - x[nx])
//   two <- abs(y[out[nx] + 1] - x[nx])
//   idx[nx] <- out[nx] + (one > two)
//   x[order][idx]
// }

// [[Rcpp::export]]
NumericVector xymatch_cpp_binary(NumericVector x, NumericVector y) {
  size_t ny = y.size();
  std::vector<std::pair<double, size_t>> y_indexed(ny);
  for (size_t j = 0; j < ny; ++j)
    y_indexed[j] = {y[j], j};
  std::ranges::sort(y_indexed, {}, &std::pair<double,size_t>::first);
  auto result_view = x | std::views::transform([&](double xi) {
    // Use lower_bound to find insertion point
    auto it = std::ranges::lower_bound(
      y_indexed,
      xi,
      {},
      &std::pair<double,size_t>::first
    );
    auto candidates = [&]() -> std::vector<std::pair<double, size_t>> {
      if (it == y_indexed.begin()) return { *it };
      if (it == y_indexed.end()) return { *(it-1) };
      return { *(it-1), *it };
    }();
    auto nearest = *std::ranges::min_element(
      candidates,
      [&](const auto &a, const auto &b){
        return std::abs(a.first - xi) < std::abs(b.first - xi);
      }
    );
    return x[nearest.second]; // return x[j] where j is nearest y
  });
  return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_binary_par(NumericVector x, NumericVector y, int nthreads = 1) {
  size_t nx = x.size();
  size_t ny = y.size();
  // Step 1: sort y and keep original indices
  std::vector<std::pair<double, size_t>> y_indexed(ny);
  for (size_t j = 0; j < ny; ++j)
    y_indexed[j] = {y[j], j};
  std::ranges::sort(y_indexed, {}, &std::pair<double,size_t>::first);
  // Step 2: result vector
  std::vector<double> result(nx);
  // Step 3: parallel loop over x
  #pragma omp parallel for num_threads(nthreads)
  for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(nx); ++i) {
    double xi = x[i];
    // binary search in sorted y
    auto it = std::ranges::lower_bound(
      y_indexed,
      xi,
      {},
      &std::pair<double,size_t>::first
    );
    size_t idx;
    if (it == y_indexed.begin()) {
      idx = 0;
    } else if (it == y_indexed.end()) {
      idx = ny - 1;
    } else {
      size_t right = std::distance(y_indexed.begin(), it);
      size_t left = right - 1;
      double left_dist = std::abs(xi - y_indexed[left].first);
      double right_dist = std::abs(xi - y_indexed[right].first);
      idx = (left_dist <= right_dist) ? left : right;
    }
    result[i] = x[y_indexed[idx].second];
  }
  return NumericVector(result.begin(), result.end());
}

// [[Rcpp::export]]
Rcpp::NumericMatrix outer_abs_diff_par(Rcpp::NumericVector a, Rcpp::NumericVector b, int nthreads = 0) {
  // 1. Get dimensions
  int N = a.length(); // Number of rows for the output
  int M = b.length(); // Number of columns for the output
  // 2. Create the output matrix (Rcpp handles memory outside the parallel region)
  Rcpp::NumericMatrix out(N, M);
  // 3. Set the number of threads (optional)
  if (nthreads > 0) {
    omp_set_num_threads(nthreads);
  }
  // 4. Parallelize the outer loop (i)
  // The #pragma omp parallel for directive divides the 'i' loop iterations among threads.
  // Each thread gets a unique set of 'i's and therefore works on separate, non-overlapping
  // rows of the 'out' matrix, ensuring thread safety for writing.
  #pragma omp parallel for
  for (int i = 0; i < N; ++i) {
    // Inner loop (j) runs serially within each thread's assigned 'i' iteration.
    for (int j = 0; j < M; ++j) {
      // Safe write: out(i, j) is only ever written to by the thread assigned to row 'i'.
      out(i, j) = std::abs(a[i] - b[j]);
    }
  }
  return out;
}

int main(int argc, char *argv[]) {
  std::vector<double> vx{0.2655087, 0.3721239, 0.5728534};
  std::vector<double> vy{0.9082078, 0.2016819, 0.8983897};
  std::cout << "x: ";
  for (double v : vx) std::cout << v << " ";
  std::cout << "\n";
  std::cout << "y: ";
  for (double v : vy) std::cout << v << " ";
  std::cout << "\n";
  // Print results closest_x_loop version
  std::cout << "xymatch closest_x_loop: ";
  auto result_view_loop = vx | std::views::transform([&](double xi){
    return closest_x_loop(vx, vy, xi);
  });
  std::vector<double> result_loop(result_view_loop.begin(), result_view_loop.end());
  for (double v : result_loop) std::cout << v << " ";
  std::cout << "\n";
  // Print results closest_x_range version
  std::cout << "xymatch closest_x_range: ";
  auto result_view_range = vx | std::views::transform([&](double xi){
    return closest_x_range(vx, vy, xi);
  });
  std::vector<double> result_range(result_view_range.begin(), result_view_range.end());
  for (double v : result_range) std::cout << v << " ";
  std::cout << "\n";
  // Print results R version
  // RInside R(argc, argv);
  // R["vx"] = NumericVector(vx.begin(), vx.end());
  // R["vy"] = NumericVector(vy.begin(), vy.end());
  // R.parseEvalQ("xymatch <- function(x, y) { x[apply(outer(x, y, function(u,v) abs(u-v)), 1, function(d) which(d==min(d))[1])] }");
  // NumericVector result_r = R.parseEval("xymatch(vx, vy)");
  // std::cout << "xymatch R: ";
  // for (double v : result_r) std::cout << v << " ";
  // std::cout << "\n";
  return 0;
}
$ lscpu
Architecture:         x86_64
CPU op-mode(s):       32-bit, 64-bit
Address sizes:        39 bits physical, 48 bits virtual
Byte Order:           Little Endian
CPU(s):               20
On-line CPU(s) list:  0-19
Vendor ID:            GenuineIntel
Model name:           13th Gen Intel(R) Core(TM) i7-13700H
CPU family:           6
Model:                186
Thread(s) per core:   2
Core(s) per socket:   14
Socket(s):            1
Stepping:             2
CPU(s) scaling MHz:   15%
CPU max MHz:          5000,0000
CPU min MHz:          400,0000
Author
Here's @JosiahParry's binary search in parallel:
use extendr_api::prelude::*;
use rayon::prelude::*;

#[extendr]
pub fn xymatch_rust_binary_parallel(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    let res: Vec<f64> = x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| {
                probe.0.partial_cmp(&xi).unwrap()
            }) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect();
    Doubles::from_values(res)
}
Note: Rayon must be listed as a dependency in rust_source(), e.g.,

rextendr::rust_source(
  file = "xymatch_binary_parallel.rs",
  module = "xymatch_binary_parallel",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)
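
Once compiled, the exported function can be called like any other R function; a minimal sketch, reusing the x and y vectors from the benchmark script above:

xymatch_rust_binary_parallel(x, y)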
Author
I want to set 8 threads...
use extendr_api::prelude::*;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

/// Parallel version of xymatch with optional number of threads.
///
/// @param x Numeric vector.
/// @param y Numeric vector.
/// @param nthreads Optional number of threads.
/// @export
#[extendr]
pub fn xymatch_rust_binary_par(x: &[f64], y: &[f64], nthreads: Option<usize>) -> Doubles {
    // Copy R slices into owned Vecs (safe for threads)
    let x_vec = x.to_vec();
    let y_vec = y.to_vec();
    // Run computation in parallel using Rayon pool if requested
    let res_vec: Vec<f64> = if let Some(n) = nthreads {
        let pool = ThreadPoolBuilder::new()
            .num_threads(n)
            .build()
            .expect("Failed to build Rayon thread pool");
        // Return Vec<f64> (Send)
        pool.install(|| run_xymatch_owned(&x_vec, &y_vec))
    } else {
        // Use default global Rayon pool
        run_xymatch_owned(&x_vec, &y_vec)
    };
    // Convert back to R type AFTER threading (safe)
    Doubles::from_values(res_vec)
}

/// Pure Rust computation - returns Vec<f64> (Send)
fn run_xymatch_owned(x: &[f64], y: &[f64]) -> Vec<f64> {
    // Pair y values with their original indices
    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();
    // Sort by y value
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    // Parallel computation over x
    x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| probe.0.partial_cmp(&xi).unwrap()) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect()
}

// Export to R
extendr_module! {
    mod xymatch_binary_par;
    fn xymatch_rust_binary_par;
}
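
For reference, a sketch of how this module is compiled and called from R, matching the rust_source() call and the benchmark entry in the script above (the rust/ file path is an assumption):

rextendr::rust_source(
  file = "rust/xymatch_binary_par.rs",
  module = "xymatch_binary_par",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)
xymatch_rust_binary_par(x, y, nthreads = 8)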
Is there something wrong? Check the results:
# A tibble: 13 × 4
expression median total_time n_itr
<bch:expr> <dbl> <bch:tm> <int>
1 r_2010 3239. 11.57s 5
2 r_binary 2.14 10s 6456
3 r_vapply 339. 10.12s 41
4 r_loop 255. 10.01s 51
5 cpp_loop 149. 10.04s 95
6 cpp_loop_par 19.5 10.01s 722
7 cpp_ranges 76.2 10.05s 186
8 cpp_range_par 19.2 10.02s 745
9 cpp_binary 1.95 10s 7225
10 cpp_binary_par 1 7.22s 10000
11 rust_loop 149. 10.05s 95
12 rust_binary 1.67 10s 8423
13 rust_binary_par 7.23 10s 2311
Thanks @JosiahParry and @TimTaylor. I've just updated the results with your contributions and added the cpp_binary and cpp_binary_par (8 threads) implementations.