-
-
Save jrosell/941b10596d496f5d56f49a192f1a8343 to your computer and use it in GitHub Desktop.
| library(rextendr) | |
| library(mirai) | |
| rextendr::rust_source( | |
| file = "rust/xymatch_loop.rs", | |
| module = "xymatch_loop", | |
| profile = "release" | |
| ) | |
| rextendr::rust_source( | |
| file = "rust/xymatch_binary.rs", | |
| module = "xymatch_binary", | |
| profile = "release", | |
| ) | |
| rextendr::rust_source( | |
| file = "rust/xymatch_binary_par.rs", | |
| module = "xymatch_binary_par", | |
| profile = "release", | |
| dependencies = list(`rayon` = "1.11") | |
| ) | |
| library(Rcpp) | |
| old_cxx <- Sys.getenv("PKG_CXXFLAGS") | |
| old_lib <- Sys.getenv("PKG_LIBS") | |
| Sys.setenv(PKG_CXXFLAGS = "-fopenmp -O3") | |
| Sys.setenv(PKG_LIBS = "-fopenmp") | |
| Rcpp::sourceCpp("xymatch_rcpp.cpp") | |
| Sys.setenv(PKG_CXXFLAGS = old_cxx) | |
| Sys.setenv(PKG_LIBS = old_lib) | |
| n_threads_test() | |
| # [1] 20 | |
# Reference implementation: build the full |x_i - y_j| distance matrix, take
# for every row the first column holding the row minimum, and return `x`
# indexed by those column positions.
xymatch_r_2010 <- function(x, y) {
  abs_diff <- function(u, v) abs(u - v)
  first_min_pos <- function(d) which(d == min(d))[1]

  dist_matrix <- outer(x, y, abs_diff)
  nearest_pos <- apply(dist_matrix, 1, first_min_pos)
  x[nearest_pos]
}
# Same algorithm as the reference implementation, but the distance matrix is
# produced by the compiled helper `outer_abs_diff_par()` using `nthreads`
# threads. The row-wise minimum search still runs in R.
xymatch_r_2010_par <- function(x, y, nthreads = 1) {
  first_min_pos <- function(d) which(d == min(d))[1]

  dist_matrix <- outer_abs_diff_par(x, y, nthreads = nthreads)
  nearest_pos <- apply(dist_matrix, 1, first_min_pos)
  x[nearest_pos]
}
# For each position i, locate the y value closest to x[i] with which.min()
# and return the element of x stored at that position.
xymatch_r_vapply <- function(x, y) {
  nearest_x <- function(i) {
    pos <- which.min(abs(x[i] - y))
    x[pos]
  }
  vapply(seq_along(x), nearest_x, 1)
}
# Explicit-loop variant: for each x[i], find the position of the closest y
# value and return x at those positions.
#
# The loop is sized from `x` itself rather than from a global `N`, so the
# function works for any input length and does not depend on the calling
# script's variables.
xymatch_r_loop <- function(x, y) {
  idx <- integer(length(x))
  for (i in seq_along(x)) {
    idx[i] <- which.min(abs(x[i] - y))
  }
  x[idx]
}
# Binary-search variant: sort y once, locate each x in the sorted y with
# findInterval(), then pick the closer of the two neighbouring y values
# (the lower one on a tie). Returns x at the original position of that y.
#
# The index vector holds one entry per element of `x`; sizing it by
# length(y) only works when both inputs happen to have the same length.
# NA values are not checked (checkNA = FALSE), so inputs are assumed NA-free.
xymatch_r_binary <- function(x, y) {
  ord <- order(y)
  y_sorted <- y[ord]
  n <- length(y_sorted)
  pos <- findInterval(x, y_sorted, checkSorted = FALSE, checkNA = FALSE)

  # Default: x lies at or beyond the largest y, so the last sorted y wins.
  idx <- rep.int(n, length(x))
  # x below the smallest y: the first sorted y wins.
  idx[pos == 0] <- 1L

  # x strictly inside the range of y: compare left and right neighbours.
  inner <- pos != 0 & pos != n
  left_gap <- abs(y_sorted[pos[inner]] - x[inner])
  right_gap <- abs(y_sorted[pos[inner] + 1] - x[inner])
  idx[inner] <- pos[inner] + (left_gap > right_gap)

  x[ord][idx]
}
# Parallel wrapper around the binary-search matcher: x is cut into `nchunks`
# contiguous chunks and each chunk is matched against the sorted y on a mirai
# daemon.
#
# `nchunks` is validated before any daemon is launched. With a single chunk
# every element goes into one group; grouping by the values of x would split
# by value and return the results in sorted-value order.
# The daemons started here are left running after the call; use
# mirai::daemons(0) to release them.
xymatch_r_binary_par <- function(x, y, nchunks = 1) {
  if (nchunks < 1) stop("`nchunks` must be at least 1")

  ord <- order(y)
  y_sorted <- y[ord]

  mirai::daemons(nchunks)
  # Ship the sorted y to every daemon once, under the name `y_val`.
  mirai::everywhere({}, y_val = y_sorted)

  groups <- if (nchunks == 1) {
    rep.int(1L, length(x))
  } else {
    cut(
      seq_along(x),
      breaks = nchunks,
      labels = FALSE,
      include.lowest = TRUE
    )
  }
  x_chunks <- split(x, groups)

  idx_chunks_mirai <- mirai_map(
    x_chunks,
    .f = function(chunk) {
      # y_val was sorted by the caller, so the sortedness check is skipped.
      out <- findInterval(chunk, y_val, checkSorted = FALSE, checkNA = FALSE)
      n_y <- length(y_val)
      idx <- rep.int(n_y, length(chunk))
      idx[out == 0] <- 1L
      nx <- (out != 0 & (out != n_y))
      if (any(nx)) {
        one <- abs(y_val[out[nx]] - chunk[nx])
        two <- abs(y_val[out[nx] + 1] - chunk[nx])
        idx[nx] <- out[nx] + (one > two)
      }
      idx
    }
  )

  # Chunks are contiguous and ordered, so flattening restores the order of x.
  idx <- unlist(idx_chunks_mirai[.flat])
  x[ord][idx]
}
# Benchmark of every implementation on two uniform random vectors of length
# 1e4. `check = TRUE` makes bench::mark() verify that all expressions return
# the same result. Each label names the function it actually times.
N <- 1e4
set.seed(1)
x <- runif(N)
y <- runif(N)
bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  r_2010_par = xymatch_r_2010_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_vapply = xymatch_r_vapply(x, y),
  r_loop = xymatch_r_loop(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_loop = xymatch_cpp_loop(x, y),
  cpp_loop_par = xymatch_cpp_loop_par(x, y, nthreads = 8),
  cpp_ranges = xymatch_cpp_ranges(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_loop = xymatch_rust_loop(x, y),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
| # # A tibble: 15 × 4 | |
| # expression median total_time n_itr | |
| # <bch:expr> <dbl> <bch:tm> <int> | |
| # 1 r_2010 3403. 10.2s 4 | |
| # 2 r_2010_par 2594. 11.7s 6 | |
| # 3 r_binary 1.97 10s 6661 | |
| # 4 r_vapply 256. 10s 47 | |
| # 5 r_loop 241. 10.1s 52 | |
| # 6 r_binary_par 5021. 11.4s 3 | |
| # 7 cpp_loop 140. 10s 95 | |
| # 8 cpp_loop_par 18.4 10s 719 | |
| # 9 cpp_ranges 72.1 10s 184 | |
| # 10 cpp_range_par 20.0 10s 641 | |
| # 11 cpp_binary 1.88 10s 7013 | |
| # 12 cpp_binary_par 1 7.8s 10000 | |
| # 13 rust_loop 140. 10.1s 95 | |
| # 14 rust_binary 1.59 10s 8327 | |
| # 15 rust_binary_par 1.68 10s 4851 | |
# Second benchmark: a larger input (N = 21250) and a subset of the
# implementations. The seed is reset so x and y are reproducible.
| N <- 21250 | |
| set.seed(1) | |
| x <- runif(N) | |
| y <- runif(N) | |
# relative = TRUE reports timings relative to the fastest expression;
# check = TRUE requires all expressions to return equal results.
| bench::mark( | |
| r_2010 = xymatch_r_2010(x, y), | |
| cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8), | |
| r_binary = xymatch_r_binary(x, y), | |
| r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8), | |
| cpp_binary = xymatch_cpp_binary(x, y), | |
| cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8), | |
| rust_binary = xymatch_rust_binary(x, y), | |
| rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8), | |
| relative = TRUE, | |
| min_time = 10, | |
| check = TRUE | |
| )[, c("expression", "median", "total_time", "n_itr")] | |
| # # A tibble: 8 × 4 | |
| # expression median total_time n_itr | |
| # <bch:expr> <dbl> <bch:tm> <int> | |
| # 1 r_2010 6396. 10.6s 1 | |
| # 2 cpp_range_par 28.6 10s 205 | |
| # 3 r_binary 2.11 10s 2736 | |
| # 4 r_binary_par 3366. 11.2s 2 | |
| # 5 cpp_binary 1.92 10s 3124 | |
| # 6 cpp_binary_par 1 10s 6000 | |
| # 7 rust_binary 1.69 10s 3573 | |
| # 8 rust_binary_par 5.32 10s 1324 |
| // Autor: @JosiahParry | |
| use extendr_api::prelude::*; | |
| #[extendr] | |
| fn xymatch_rust_binary(x: &[f64], y: &[f64]) -> Doubles { | |
| let mut y_indexed: Vec<(f64, usize)> = y.iter() | |
| .enumerate() | |
| .map(|(i, &val)| (val, i)) | |
| .collect(); | |
| y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); | |
| let mut res = Vec::with_capacity(x.len()); | |
| for &xi in x { | |
| let idx = match y_indexed.binary_search_by(|probe| { | |
| probe.0.partial_cmp(&xi).unwrap() | |
| }) { | |
| Ok(i) => i, // Exact match | |
| Err(i) => { | |
| if i == 0 { | |
| 0 | |
| } else if i == y_indexed.len() { | |
| y_indexed.len() - 1 | |
| } else { | |
| let left_dist = (xi - y_indexed[i - 1].0).abs(); | |
| let right_dist = (y_indexed[i].0 - xi).abs(); | |
| if left_dist <= right_dist { | |
| i - 1 | |
| } else { | |
| i | |
| } | |
| } | |
| } | |
| }; | |
| let orig_idx = y_indexed[idx].1; | |
| res.push(x[orig_idx]); | |
| } | |
| Doubles::from_values(res) | |
| } | |
| extendr_module! { | |
| mod xymatch_binary; | |
| fn xymatch_rust_binary; | |
| } |
| use extendr_api::prelude::*; | |
| #[extendr] | |
/// Brute-force nearest match: for each `x[i]`, scans all of `y` for the value
/// at minimum absolute distance (the first one on a tie) and returns the
/// element of `x` stored at that position.
///
/// When `y` is empty, or the matched position does not exist in `x` (which
/// can only happen when `y` is longer than `x`), the result for that element
/// is NaN rather than an out-of-bounds panic.
fn xymatch_rust_loop(x: Vec<f64>, y: Vec<f64>) -> Vec<f64> {
    /// Position of the first `y` value at minimum absolute distance from `xi`.
    fn nearest_pos(y: &[f64], xi: f64) -> Option<usize> {
        let mut best: Option<(usize, f64)> = None;
        for (j, &yj) in y.iter().enumerate() {
            let dist = (xi - yj).abs();
            // Strict `<` keeps the earliest position on ties and never lets a
            // NaN distance replace an existing candidate.
            let improves = match best {
                None => true,
                Some((_, best_dist)) => dist < best_dist,
            };
            if improves {
                best = Some((j, dist));
            }
        }
        best.map(|(j, _)| j)
    }

    x.iter()
        .map(|&xi| {
            nearest_pos(&y, xi)
                .and_then(|j| x.get(j).copied())
                .unwrap_or(f64::NAN)
        })
        .collect()
}
| extendr_module! { | |
| mod xymatch_loop; | |
| fn xymatch_rust_loop; | |
| } |
| // $ clang++ -std=c++23 -O2 -Wall -fopenmp -I/usr/lib/llvm-22/include -I/home/jordi/vcpkg/installed/x64-linux/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/Rcpp/include -I/usr/share/R/include xymatch_rcpp.cpp -o xymatch_rcpp -L/usr/lib/llvm-22/lib -lomp -L/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib -lRInside -lR -Wl,-rpath,/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib && ./xymatch_rcpp | |
| // > install.packages(c("Rcpp","RInside")) | |
| // [[Rcpp::plugins(cpp23)]] | |
| // [[Rcpp::depends(Rcpp)]] | |
| // // [[Rcpp::depends(RInside)]] | |
| // #include <RInside.h> | |
| #include <Rcpp.h> | |
| #include <iostream> | |
| #include <ranges> | |
| #include <vector> | |
| #include <cmath> | |
| #include <algorithm> | |
| #ifdef _OPENMP | |
| #include <omp.h> // clang++ --version / 22 => libomp-22-dev | |
| #endif | |
| #include <execution> | |
| using namespace Rcpp; | |
// Reports (and prints) the number of threads an OpenMP parallel region gets.
// When the file is compiled without OpenMP the function still builds and
// reports a single thread, matching the guarded <omp.h> include.
// [[Rcpp::export]]
int n_threads_test() {
    int n = 1;  // serial fallback when OpenMP is not enabled
#ifdef _OPENMP
    #pragma omp parallel
    {
        // Only one thread of the team records the team size.
        #pragma omp single
        n = omp_get_num_threads();
    }
#endif
    std::cout << "OpenMP threads: " << n << "\n";
    return n;
}
| // R code: | |
| // xymatch <- function(x, y) { | |
| // x[ | |
| // apply( | |
| // outer(x, y, function(u, v) abs(u - v)), | |
| // 1, | |
| // function(d) which(d == min(d))[1] | |
| // ) | |
| // ] | |
| // } | |
| double closest_x_range(const std::vector<double>& vx, const std::vector<double>& vy, double xi) { | |
| auto min_it = std::ranges::min_element( | |
| std::views::iota(size_t{0}, vy.size()), | |
| {}, | |
| [&](size_t j){ return std::abs(vy[j] - xi); } | |
| ); | |
| return vx[*min_it]; | |
| } | |
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_ranges(NumericVector x, NumericVector y) { | |
| std::vector<double> vx(x.begin(), x.end()); | |
| std::vector<double> vy(y.begin(), y.end()); | |
| auto result_view = vx | std::views::transform([&](double xi){ | |
| return closest_x_range(vx, vy, xi); | |
| }); | |
| return NumericVector(result_view.begin(), result_view.end()); | |
| } | |
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_range_par(NumericVector x, NumericVector y, int nthreads = 1) { | |
| std::vector<double> vx(x.begin(), x.end()); | |
| std::vector<double> vy(y.begin(), y.end()); | |
| std::vector<double> res(vx.size()); | |
| #pragma omp parallel for num_threads(nthreads) | |
| for (size_t i = 0; i < vx.size(); ++i) { | |
| res[i] = closest_x_range(vx, vy, vx[i]); | |
| } | |
| return NumericVector(res.begin(), res.end()); | |
| } | |
// Returns vx[j], where j is the position of the first value in vy at minimum
// absolute distance from xi (plain linear scan).
// Returns NaN when vy is empty or when j is not a valid position in vx,
// instead of reading vy[0] / vx[j] out of bounds.
double closest_x_loop(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
    if (vy.empty()) return NAN;
    size_t min_j = 0;
    double min_diff = std::abs(vy[0] - xi);
    for (size_t j = 1; j < vy.size(); ++j) {
        const double diff = std::abs(vy[j] - xi);
        // Strict comparison keeps the earliest position on ties.
        if (diff < min_diff) {
            min_diff = diff;
            min_j = j;
        }
    }
    return min_j < vx.size() ? vx[min_j] : NAN;
}
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_loop(NumericVector x, NumericVector y) { | |
| std::vector<double> vx(x.begin(), x.end()); | |
| std::vector<double> vy(y.begin(), y.end()); | |
| auto result_view = vx | std::views::transform([&](double xi){ | |
| return closest_x_loop(vx, vy, xi); | |
| }); | |
| return NumericVector(result_view.begin(), result_view.end()); | |
| } | |
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_loop_par(NumericVector x, NumericVector y, int nthreads = 1) { | |
| std::vector<double> vx(x.begin(), x.end()); | |
| std::vector<double> vy(y.begin(), y.end()); | |
| std::vector<double> res(vx.size()); | |
| #pragma omp parallel for num_threads(nthreads) | |
| for (size_t i = 0; i < vx.size(); ++i) { | |
| res[i] = closest_x_loop(vx, vy, vx[i]); | |
| } | |
| return NumericVector(res.begin(), res.end()); | |
| } | |
| // xymatch_r_binary <- function(x,y) { | |
| // order <- order(y) | |
| // y <- y[order] | |
| // # Let's assume no NA for speediness | |
| // out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE) | |
| // n <- length(y) | |
| // idx <- integer(n) + n | |
| // idx[out == 0] <- 1L | |
| // nx <- (out != 0 & (out != n)) | |
| // one <- abs(y[out[nx]] - x[nx]) | |
| // two <- abs(y[out[nx] + 1] - x[nx]) | |
| // idx[nx] <- out[nx] + (one > two) | |
| // x[order][idx] | |
| // } | |
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_binary(NumericVector x, NumericVector y) { | |
| size_t ny = y.size(); | |
| std::vector<std::pair<double, size_t>> y_indexed(ny); | |
| for (size_t j = 0; j < ny; ++j) | |
| y_indexed[j] = {y[j], j}; | |
| std::ranges::sort(y_indexed, {}, &std::pair<double,size_t>::first); | |
| auto result_view = x | std::views::transform([&](double xi) { | |
| // Use lower_bound to find insertion point | |
| auto it = std::ranges::lower_bound( | |
| y_indexed, | |
| xi, | |
| {}, | |
| &std::pair<double,size_t>::first | |
| ); | |
| auto candidates = [&]() -> std::vector<std::pair<double, size_t>> { | |
| if (it == y_indexed.begin()) return { *it }; | |
| if (it == y_indexed.end()) return { *(it-1) }; | |
| return { *(it-1), *it }; | |
| }(); | |
| auto nearest = *std::ranges::min_element( | |
| candidates, | |
| [&](const auto &a, const auto &b){ | |
| return std::abs(a.first - xi) < std::abs(b.first - xi); | |
| } | |
| ); | |
| return x[nearest.second]; // return x[j] where j is nearest y | |
| }); | |
| return NumericVector(result_view.begin(), result_view.end()); | |
| } | |
| // [[Rcpp::export]] | |
| NumericVector xymatch_cpp_binary_par(NumericVector x, NumericVector y, int nthreads = 1) { | |
| size_t nx = x.size(); | |
| size_t ny = y.size(); | |
| // Step 1: sort y and keep original indices | |
| std::vector<std::pair<double, size_t>> y_indexed(ny); | |
| for (size_t j = 0; j < ny; ++j) | |
| y_indexed[j] = {y[j], j}; | |
| std::ranges::sort(y_indexed, {}, &std::pair<double,size_t>::first); | |
| // Step 2: result vector | |
| std::vector<double> result(nx); | |
| // Step 3: parallel loop over x | |
| #pragma omp parallel for num_threads(nthreads) | |
| for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(nx); ++i) { | |
| double xi = x[i]; | |
| // binary search in sorted y | |
| auto it = std::ranges::lower_bound( | |
| y_indexed, | |
| xi, | |
| {}, | |
| &std::pair<double,size_t>::first | |
| ); | |
| size_t idx; | |
| if (it == y_indexed.begin()) { | |
| idx = 0; | |
| } else if (it == y_indexed.end()) { | |
| idx = ny - 1; | |
| } else { | |
| size_t right = std::distance(y_indexed.begin(), it); | |
| size_t left = right - 1; | |
| double left_dist = std::abs(xi - y_indexed[left].first); | |
| double right_dist = std::abs(xi - y_indexed[right].first); | |
| idx = (left_dist <= right_dist) ? left : right; | |
| } | |
| result[i] = x[y_indexed[idx].second]; | |
| } | |
| return NumericVector(result.begin(), result.end()); | |
| } | |
| // [[Rcpp::export]] | |
| Rcpp::NumericMatrix outer_abs_diff_par(Rcpp::NumericVector a, Rcpp::NumericVector b, int nthreads = 0) { | |
| // 1. Get dimensions | |
| int N = a.length(); // Number of rows for the output | |
| int M = b.length(); // Number of columns for the output | |
| // 2. Create the output matrix (Rcpp handles memory outside the parallel region) | |
| Rcpp::NumericMatrix out(N, M); | |
| // 3. Set the number of threads (optional) | |
| if (nthreads > 0) { | |
| omp_set_num_threads(nthreads); | |
| } | |
| // 4. Parallelize the outer loop (i) | |
| // The #pragma omp parallel for directive divides the 'i' loop iterations among threads. | |
| // Each thread gets a unique set of 'i's and therefore works on separate, non-overlapping | |
| // rows of the 'out' matrix, ensuring thread safety for writing. | |
| #pragma omp parallel for | |
| for (int i = 0; i < N; ++i) { | |
| // Inner loop (j) runs serially within each thread's assigned 'i' iteration. | |
| for (int j = 0; j < M; ++j) { | |
| // Safe Write: out(i, j) is only ever written to by the thread assigned to row 'i'. | |
| out(i, j) = std::abs(a[i] - b[j]); | |
| } | |
| } | |
| return out; | |
| } | |
| int main(int argc, char *argv[]) { | |
| std::vector<double> vx{0.2655087, 0.3721239, 0.5728534}; | |
| std::vector<double> vy{0.9082078, 0.2016819, 0.8983897}; | |
| std::cout << "x: "; | |
| for (double v : vx) std::cout << v << " "; | |
| std::cout << "\n"; | |
| std::cout << "y: "; | |
| for (double v : vy) std::cout << v << " "; | |
| std::cout << "\n"; | |
| // Print results closest_x_loop version | |
| std::cout << "xymatch closest_x_loop: "; | |
| auto result_view_loop = vx | std::views::transform([&](double xi){ | |
| return closest_x_loop(vx, vy, xi); | |
| }); | |
| std::vector<double> result_loop(result_view_loop.begin(), result_view_loop.end()); | |
| for (double v : result_loop) std::cout << v << " "; | |
| std::cout << "\n"; | |
| // Print results closest_x_range version | |
| std::cout << "xymatch closest_x_range: "; | |
| auto result_view_range = vx | std::views::transform([&](double xi){ | |
| return closest_x_range(vx, vy, xi); | |
| }); | |
| std::vector<double> result_range(result_view_range.begin(), result_view_range.end()); | |
| for (double v : result_range) std::cout << v << " "; | |
| std::cout << "\n"; | |
| // Print results R version | |
| // RInside R(argc, argv); | |
| // R["vx"] = NumericVector(vx.begin(), vx.end()); | |
| // R["vy"] = NumericVector(vy.begin(), vy.end()); | |
| // R.parseEvalQ("xymatch <- function(x, y) { x[apply(outer(x, y, function(u,v) abs(u-v)), 1, function(d) which(d==min(d))[1])] }"); | |
| // NumericVector result_r = R.parseEval("xymatch(vx, vy)"); | |
| // std::cout << "xymatch R: "; | |
| // for (double v : result_r) std::cout << v << " "; | |
| // std::cout << "\n"; | |
| return 0; | |
| } | |
| $ lscpu | |
| Architecture: x86_64 | |
| CPU op-mode(s): 32-bit, 64-bit | |
| Address sizes: 39 bits physical, 48 bits virtual | |
| Byte Order: Little Endian | |
| CPU(s): 20 | |
| On-line CPU(s) list: 0-19 | |
| Vendor ID: GenuineIntel | |
| Model name: 13th Gen Intel(R) Core(TM) i7-13700H | |
| CPU family: 6 | |
| Model: 186 | |
| Thread(s) per core: 2 | |
| Core(s) per socket: 14 | |
| Socket(s): 1 | |
| Stepping: 2 | |
| CPU(s) scaling MHz: 15% | |
| CPU max MHz: 5000,0000 | |
| CPU min MHz: 400,0000 |
Funsies - @JosiahParry - how does this compare ...
# Binary-search matcher: sort y once, locate each x in the sorted y with
# findInterval(), then pick the closer of the two neighbouring y values
# (the lower one on a tie). Returns x at the original position of that y.
# The index vector holds one entry per element of `x`, so inputs of different
# lengths are handled.
xymatch_binary2 <- function(x, y) {
    order <- order(y)
    y <- y[order]
    # Let's assume no NA for speediness
    out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
    n <- length(y)
    idx <- rep.int(n, length(x))
    idx[out == 0] <- 1L
    nx <- (out != 0 & (out != n))
    one <- abs(y[out[nx]] - x[nx])
    two <- abs(y[out[nx] + 1] - x[nx])
    idx[nx] <- out[nx] + (one > two)
    x[order][idx]
}

@jrosell - what have you started 😅 (and yes this completely ignores what you were illustrating but t'is fun)
EDIT: Added minor tweak
It was not me, it was Ross Ihaka :)
@TimTaylor the R approach is very good! Adding a binary search for Rust—I'm not good at this stuff, so I bet there is a good approach to this.
The KD-tree implementation above fails to match the R implementation at 1e5, probably due to floating-point precision in the squared Euclidean distance calculation. The Rust binary search works well, though!
This could be a good function to put in base R :)
/// For each `x[i]`, returns the element of `x` stored at the original
/// position of the `y` value nearest to `x[i]`, using a binary search over
/// the sorted `y`.
///
/// NaN values in `y` are ignored. A NaN query, an empty `y`, or a matched
/// position that does not exist in `x` yields NaN instead of a panic. Among
/// duplicate `y` values the smallest original position is used; on a distance
/// tie the lower `y` value wins.
#[extendr]
fn xymatch_binsearch(x: &[f64], y: &[f64]) -> Doubles {
    /// Position in `sorted` of the value nearest to `xi`, or `None` when
    /// `sorted` is empty or `xi` is NaN.
    fn nearest_sorted_pos(sorted: &[(f64, usize)], xi: f64) -> Option<usize> {
        if sorted.is_empty() || xi.is_nan() {
            return None;
        }
        // Lower bound: first position whose value is >= xi.
        let right = sorted.partition_point(|probe| probe.0 < xi);
        if right == 0 {
            return Some(0);
        }
        if right == sorted.len() {
            return Some(right - 1);
        }
        let left_dist = (xi - sorted[right - 1].0).abs();
        let right_dist = (sorted[right].0 - xi).abs();
        Some(if left_dist <= right_dist { right - 1 } else { right })
    }

    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .filter(|(_, val)| !val.is_nan())
        .map(|(i, &val)| (val, i))
        .collect();
    // Stable sort: equal values keep their original relative order.
    y_indexed.sort_by(|a, b| a.0.total_cmp(&b.0));

    let res: Vec<f64> = x
        .iter()
        .map(|&xi| {
            nearest_sorted_pos(&y_indexed, xi)
                .and_then(|pos| x.get(y_indexed[pos].1).copied())
                .unwrap_or(f64::NAN)
        })
        .collect();
    Doubles::from_values(res)
}

Relative performance.
# A tibble: 3 × 4
expression median total_time n_itr
<bch:expr> <dbl> <bch:tm> <int>
1 r_binary 1.44 9.27s 5231
2 rust_binary 1 10s 8263
3 rust_kdtree 1.59 10s 5077
Thanks @JosiahParry and @TimTaylor. I've just updated the results with your contributions and added cpp binary and cpp binary with 8 threads.
Here's @JosiahParry bin search in parallel:
use extendr_api::prelude::*;
use rayon::prelude::*;
/// Parallel binary-search matcher: `y` is sorted once, then the queries in
/// `x` are processed on Rayon's global thread pool.
///
/// NaN values in `y` are ignored. A NaN query, an empty `y`, or a matched
/// position that does not exist in `x` yields NaN instead of a panic. Among
/// duplicate `y` values the smallest original position is used; on a distance
/// tie the lower `y` value wins.
#[extendr]
pub fn xymatch_rust_binary_parallel(x: &[f64], y: &[f64]) -> Doubles {
    /// Position in `sorted` of the value nearest to `xi`, or `None` when
    /// `sorted` is empty or `xi` is NaN.
    fn nearest_sorted_pos(sorted: &[(f64, usize)], xi: f64) -> Option<usize> {
        if sorted.is_empty() || xi.is_nan() {
            return None;
        }
        // Lower bound: first position whose value is >= xi.
        let right = sorted.partition_point(|probe| probe.0 < xi);
        if right == 0 {
            return Some(0);
        }
        if right == sorted.len() {
            return Some(right - 1);
        }
        let left_dist = (xi - sorted[right - 1].0).abs();
        let right_dist = (sorted[right].0 - xi).abs();
        Some(if left_dist <= right_dist { right - 1 } else { right })
    }

    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .filter(|(_, val)| !val.is_nan())
        .map(|(i, &val)| (val, i))
        .collect();
    // Stable sort: equal values keep their original relative order.
    y_indexed.sort_by(|a, b| a.0.total_cmp(&b.0));

    let res: Vec<f64> = x
        .par_iter()
        .map(|&xi| {
            nearest_sorted_pos(&y_indexed, xi)
                .and_then(|pos| x.get(y_indexed[pos].1).copied())
                .unwrap_or(f64::NAN)
        })
        .collect();
    Doubles::from_values(res)
}
Note: Rayon must be listed as a dependency in rust_source(), e.g.,
rextendr::rust_source(
file = "xymatch_binary_parallel.rs",
module = "xymatch_binary_parallel",
profile = "release",
dependencies = list(`rayon` = "1.11")
)
I want to set 8 threads...
use extendr_api::prelude::*;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;
/// Parallel version of xymatch with optional number of threads.
///
/// When `nthreads` is given, a dedicated Rayon thread pool of that size is
/// built for this call and the computation runs inside it; otherwise Rayon's
/// global pool is used. The input slices are shared with the worker threads
/// by reference, so no copy of `x` or `y` is made.
///
/// Panics (reported to R as an error) if the thread pool cannot be built.
///
/// @param x Numeric vector.
/// @param y Numeric vector.
/// @param nthreads Optional number of threads.
/// @export
#[extendr]
pub fn xymatch_rust_binary_par(x: &[f64], y: &[f64], nthreads: Option<usize>) -> Doubles {
    let res_vec: Vec<f64> = match nthreads {
        Some(n) => {
            let pool = ThreadPoolBuilder::new()
                .num_threads(n)
                .build()
                .expect("Failed to build Rayon thread pool");
            // Parallel iterators started inside `install` run on this pool.
            pool.install(|| run_xymatch_owned(x, y))
        }
        // Use default global Rayon pool
        None => run_xymatch_owned(x, y),
    };
    // Only plain `Vec<f64>` crosses threads; the R vector is built here,
    // after the parallel work has finished.
    Doubles::from_values(res_vec)
}
/// Pure Rust computation — returns `Vec<f64>` (Send).
///
/// For each `x[i]`, returns the element of `x` stored at the original
/// position of the `y` value nearest to `x[i]`. `y` is sorted once and the
/// queries run in parallel on whichever Rayon pool is current.
///
/// NaN values in `y` are ignored. A NaN query, an empty `y`, or a matched
/// position that does not exist in `x` yields NaN instead of a panic. Among
/// duplicate `y` values the smallest original position is used; on a distance
/// tie the lower `y` value wins.
fn run_xymatch_owned(x: &[f64], y: &[f64]) -> Vec<f64> {
    /// Position in `sorted` of the value nearest to `xi`, or `None` when
    /// `sorted` is empty or `xi` is NaN.
    fn nearest_sorted_pos(sorted: &[(f64, usize)], xi: f64) -> Option<usize> {
        if sorted.is_empty() || xi.is_nan() {
            return None;
        }
        // Lower bound: first position whose value is >= xi.
        let right = sorted.partition_point(|probe| probe.0 < xi);
        if right == 0 {
            return Some(0);
        }
        if right == sorted.len() {
            return Some(right - 1);
        }
        let left_dist = (xi - sorted[right - 1].0).abs();
        let right_dist = (sorted[right].0 - xi).abs();
        Some(if left_dist <= right_dist { right - 1 } else { right })
    }

    // Pair y values with their original indices, dropping NaN
    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .filter(|(_, val)| !val.is_nan())
        .map(|(i, &val)| (val, i))
        .collect();
    // Sort by y value (stable: equal values keep their original order)
    y_indexed.sort_by(|a, b| a.0.total_cmp(&b.0));

    // Parallel computation over x
    x.par_iter()
        .map(|&xi| {
            nearest_sorted_pos(&y_indexed, xi)
                .and_then(|pos| x.get(y_indexed[pos].1).copied())
                .unwrap_or(f64::NAN)
        })
        .collect()
}
// Export to R
extendr_module! {
mod xymatch_binary_par;
fn xymatch_rust_binary_par;
}
Is there something wrong? Check the results:
# A tibble: 13 × 4
expression median total_time n_itr
<bch:expr> <dbl> <bch:tm> <int>
1 r_2010 3239. 11.57s 5
2 r_binary 2.14 10s 6456
3 r_vapply 339. 10.12s 41
4 r_loop 255. 10.01s 51
5 cpp_loop 149. 10.04s 95
6 cpp_loop_par 19.5 10.01s 722
7 cpp_ranges 76.2 10.05s 186
8 cpp_range_par 19.2 10.02s 745
9 cpp_binary 1.95 10s 7225
10 cpp_binary_par 1 7.22s 10000
11 rust_loop 149. 10.05s 95
12 rust_binary 1.67 10s 8423
13 rust_binary_par 7.23 10s 2311
``
Here is an alternative approach to this problem—it's not apples to apples whatsoever. But it is a more efficient way of approaching the problem, I think, and it is only running on a single thread.
We create a KD-tree and query the kd tree. This will have much better performance.
On two vectors of size 1e6 it takes 450 ms on my Mac M1.
Edit: the best approach is likely a binary search.