@jrosell
Last active October 13, 2025 14:36
library(rextendr)
library(mirai)
rextendr::rust_source(
  file = "rust/xymatch_loop.rs",
  module = "xymatch_loop",
  profile = "release"
)
rextendr::rust_source(
  file = "rust/xymatch_binary.rs",
  module = "xymatch_binary",
  profile = "release"
)
rextendr::rust_source(
  file = "rust/xymatch_binary_par.rs",
  module = "xymatch_binary_par",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)
library(Rcpp)
old_cxx <- Sys.getenv("PKG_CXXFLAGS")
old_lib <- Sys.getenv("PKG_LIBS")
Sys.setenv(PKG_CXXFLAGS = "-fopenmp -O3")
Sys.setenv(PKG_LIBS = "-fopenmp")
Rcpp::sourceCpp("xymatch_rcpp.cpp")
Sys.setenv(PKG_CXXFLAGS = old_cxx)
Sys.setenv(PKG_LIBS = old_lib)
n_threads_test()
# [1] 20
xymatch_r_2010 <- function(x, y) {
  x[
    apply(
      outer(x, y, function(u, v) abs(u - v)),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}
xymatch_r_2010_par <- function(x, y, nthreads = 1) {
  x[
    apply(
      outer_abs_diff_par(x, y, nthreads = nthreads),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}
xymatch_r_vapply <- function(x, y) {
  vapply(seq_along(x), \(i) x[which.min(abs(x[i] - y))], 1)
}
xymatch_r_loop <- function(x, y) {
  n <- length(x) # was the global N; make the function self-contained
  idx <- integer(n)
  for (i in seq_len(n)) {
    idx[i] <- which.min(abs(x[i] - y))
  }
  x[idx]
}
xymatch_r_binary <- function(x, y) {
  order <- order(y)
  y <- y[order]
  # Assume no NAs for speed
  out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
  n <- length(y)
  idx <- rep(n, length(x)) # one slot per x, defaulting to the last index (covers out == n)
  idx[out == 0] <- 1L
  nx <- (out != 0 & (out != n))
  one <- abs(y[out[nx]] - x[nx])
  two <- abs(y[out[nx] + 1] - x[nx])
  idx[nx] <- out[nx] + (one > two)
  x[order][idx]
}
# Wrapper function for parallel execution (corrected)
xymatch_r_binary_par <- function(x, y, nchunks = 1) {
  if (nchunks < 1) stop("nchunks must be >= 1")
  order <- order(y)
  y <- y[order]
  mirai::daemons(nchunks)
  mirai::everywhere({}, y_val = y)
  if (nchunks == 1) {
    groups <- rep(1L, length(x)) # a single chunk; splitting by value was a bug
  } else {
    groups <- cut(
      seq_along(x),
      breaks = nchunks,
      labels = FALSE,
      include.lowest = TRUE
    )
  }
  x_chunks <- split(x, groups)
  idx_chunks_mirai <- mirai_map(
    x_chunks,
    .f = function(chunk) {
      # y_val arrives pre-sorted, so skip the sortedness check
      out <- findInterval(chunk, y_val, checkSorted = FALSE, checkNA = FALSE)
      n_y <- length(y_val)
      n_chunk <- length(chunk)
      idx <- rep(n_y, n_chunk)
      idx[out == 0] <- 1L
      nx <- (out != 0 & (out != n_y))
      if (any(nx)) {
        one <- abs(y_val[out[nx]] - chunk[nx])
        two <- abs(y_val[out[nx] + 1] - chunk[nx])
        idx[nx] <- out[nx] + (one > two)
      }
      idx
    }
  )
  idx <- unlist(idx_chunks_mirai[.flat])
  x[order][idx]
}
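# Note: daemons(nchunks) launches fresh background R processes on every call,
# so most of r_binary_par's time in the benchmarks below is process start-up
# and data transfer rather than matching. A sketch of amortizing that cost,
# assuming the pool is managed outside the function:
# mirai::daemons(8)   # start the pool once per session
# ...                 # call the chunked matcher repeatedly
# mirai::daemons(0)   # shut the pool down when finished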
N <- 1e4
set.seed(1)
x <- runif(N)
y <- runif(N)
bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  r_2010_par = xymatch_r_2010_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_vapply = xymatch_r_vapply(x, y),
  r_loop = xymatch_r_loop(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_loop = xymatch_cpp_loop(x, y),
  cpp_loop_par = xymatch_cpp_loop_par(x, y, nthreads = 8),
  cpp_ranges = xymatch_cpp_ranges(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_loop = xymatch_rust_loop(x, y),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 15 × 4
# expression median total_time n_itr
# <bch:expr> <dbl> <bch:tm> <int>
# 1 r_2010 3403. 10.2s 4
# 2 r_2010_par 2594. 11.7s 6
# 3 r_binary 1.97 10s 6661
# 4 r_vapply 241. 10.1s 52
# 5 r_loop 256. 10s 47
# 6 r_binary_par 5021. 11.4s 3
# 7 cpp_loop 140. 10s 95
# 8 cpp_loop_par 18.4 10s 719
# 9 cpp_ranges 72.1 10s 184
# 10 cpp_range_par 20.0 10s 641
# 11 cpp_binary 1.88 10s 7013
# 12 cpp_binary_par 1 7.8s 10000
# 13 rust_loop 140. 10.1s 95
# 14 rust_binary 1.59 10s 8327
# 15 rust_binary_par 1.68 10s 4851
N <- 21250
set.seed(1)
x <- runif(N)
y <- runif(N)
bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 8 × 4
# expression median total_time n_itr
# <bch:expr> <dbl> <bch:tm> <int>
# 1 r_2010 6396. 10.6s 1
# 2 cpp_range_par 28.6 10s 205
# 3 r_binary 2.11 10s 2736
# 4 r_binary_par 3366. 11.2s 2
# 5 cpp_binary 1.92 10s 3124
# 6 cpp_binary_par 1 10s 6000
# 7 rust_binary 1.69 10s 3573
# 8 rust_binary_par 5.32 10s 1324
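# The binary-search variants dominate at both sizes: after an O(m log m) sort
# of y, each lookup costs O(log m), i.e. O((n + m) log m) overall, while the
# loop/outer variants scan all of y for every x, i.e. O(n * m).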
// Author: @JosiahParry
use extendr_api::prelude::*;

#[extendr]
fn xymatch_rust_binary(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();
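    // NOTE: partial_cmp().unwrap() panics if x or y contain NaN, matching the
    // R version's "assume no NA" shortcut.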
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    let mut res = Vec::with_capacity(x.len());
    for &xi in x {
        let idx = match y_indexed.binary_search_by(|probe| {
            probe.0.partial_cmp(&xi).unwrap()
        }) {
            Ok(i) => i, // Exact match
            Err(i) => {
                if i == 0 {
                    0
                } else if i == y_indexed.len() {
                    y_indexed.len() - 1
                } else {
                    let left_dist = (xi - y_indexed[i - 1].0).abs();
                    let right_dist = (y_indexed[i].0 - xi).abs();
                    if left_dist <= right_dist {
                        i - 1
                    } else {
                        i
                    }
                }
            }
        };
        let orig_idx = y_indexed[idx].1;
        res.push(x[orig_idx]);
    }
    Doubles::from_values(res)
}

extendr_module! {
    mod xymatch_binary;
    fn xymatch_rust_binary;
}
use extendr_api::prelude::*;
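
// Brute-force O(n * m) baseline: for every x[i], scan all of y for the nearest
// value; mirrors xymatch_r_loop and the 2010 outer() approach.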
#[extendr]
fn xymatch_rust_loop(x: Vec<f64>, y: Vec<f64>) -> Vec<f64> {
    let n = x.len();
    let m = y.len();
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let xi = x[i];
        let mut min_index = 0;
        let mut min_dist = (xi - y[0]).abs();
        for j in 1..m {
            let dist = (xi - y[j]).abs();
            if dist < min_dist {
                min_index = j;
                min_dist = dist;
            }
        }
        result.push(x[min_index]);
    }
    result
}

extendr_module! {
    mod xymatch_loop;
    fn xymatch_rust_loop;
}
// $ clang++ -std=c++23 -O2 -Wall -fopenmp -I/usr/lib/llvm-22/include -I/home/jordi/vcpkg/installed/x64-linux/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/Rcpp/include -I/usr/share/R/include xymatch_rcpp.cpp -o xymatch_rcpp -L/usr/lib/llvm-22/lib -lomp -L/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib -lRInside -lR -Wl,-rpath,/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib && ./xymatch_rcpp
// > install.packages(c("Rcpp","RInside"))
// [[Rcpp::plugins(cpp23)]]
// [[Rcpp::depends(Rcpp)]]
// // [[Rcpp::depends(RInside)]]
// #include <RInside.h>
#include <Rcpp.h>
#include <iostream>
#include <ranges>
#include <vector>
#include <cmath>
#include <algorithm>
#ifdef _OPENMP
#include <omp.h> // clang++ --version / 22 => libomp-22-dev
#endif
#include <execution>
using namespace Rcpp;
// [[Rcpp::export]]
int n_threads_test() {
    int n = 0;
#ifdef _OPENMP
    #pragma omp parallel
    {
        #pragma omp single
        n = omp_get_num_threads();
    }
#endif
    std::cout << "OpenMP threads: " << n << "\n";
    return n;
}
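// The parallel region above uses OMP_NUM_THREADS (or all cores) by default;
// the *_par functions below override the count per call via num_threads().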
// R code:
// xymatch <- function(x, y) {
//   x[
//     apply(
//       outer(x, y, function(u, v) abs(u - v)),
//       1,
//       function(d) which(d == min(d))[1]
//     )
//   ]
// }
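// Argmin via ranges: min_element scans an index view 0..vy.size(), with the
// projection mapping each index j to its distance |vy[j] - xi|; *min_it is
// then the index of the nearest y, used to subscript x.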
double closest_x_range(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
    auto min_it = std::ranges::min_element(
        std::views::iota(size_t{0}, vy.size()),
        {},
        [&](size_t j){ return std::abs(vy[j] - xi); }
    );
    return vx[*min_it];
}
// [[Rcpp::export]]
NumericVector xymatch_cpp_ranges(NumericVector x, NumericVector y) {
    std::vector<double> vx(x.begin(), x.end());
    std::vector<double> vy(y.begin(), y.end());
    auto result_view = vx | std::views::transform([&](double xi){
        return closest_x_range(vx, vy, xi);
    });
    return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_range_par(NumericVector x, NumericVector y, int nthreads = 1) {
    std::vector<double> vx(x.begin(), x.end());
    std::vector<double> vy(y.begin(), y.end());
    std::vector<double> res(vx.size());
    #pragma omp parallel for num_threads(nthreads)
    for (size_t i = 0; i < vx.size(); ++i) {
        res[i] = closest_x_range(vx, vy, vx[i]);
    }
    return NumericVector(res.begin(), res.end());
}
double closest_x_loop(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
    size_t min_j = 0;
    double min_diff = std::abs(vy[0] - xi);
    for (size_t j = 1; j < vy.size(); ++j) {
        double diff = std::abs(vy[j] - xi);
        if (diff < min_diff) {
            min_diff = diff;
            min_j = j;
        }
    }
    return vx[min_j];
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop(NumericVector x, NumericVector y) {
    std::vector<double> vx(x.begin(), x.end());
    std::vector<double> vy(y.begin(), y.end());
    auto result_view = vx | std::views::transform([&](double xi){
        return closest_x_loop(vx, vy, xi);
    });
    return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop_par(NumericVector x, NumericVector y, int nthreads = 1) {
    std::vector<double> vx(x.begin(), x.end());
    std::vector<double> vy(y.begin(), y.end());
    std::vector<double> res(vx.size());
    #pragma omp parallel for num_threads(nthreads)
    for (size_t i = 0; i < vx.size(); ++i) {
        res[i] = closest_x_loop(vx, vy, vx[i]);
    }
    return NumericVector(res.begin(), res.end());
}
// xymatch_r_binary <- function(x, y) {
//   order <- order(y)
//   y <- y[order]
//   # Let's assume no NA for speediness
//   out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
//   n <- length(y)
//   idx <- integer(n) + n
//   idx[out == 0] <- 1L
//   nx <- (out != 0 & (out != n))
//   one <- abs(y[out[nx]] - x[nx])
//   two <- abs(y[out[nx] + 1] - x[nx])
//   idx[nx] <- out[nx] + (one > two)
//   x[order][idx]
// }
// [[Rcpp::export]]
NumericVector xymatch_cpp_binary(NumericVector x, NumericVector y) {
    size_t ny = y.size();
    std::vector<std::pair<double, size_t>> y_indexed(ny);
    for (size_t j = 0; j < ny; ++j)
        y_indexed[j] = {y[j], j};
    std::ranges::sort(y_indexed, {}, &std::pair<double, size_t>::first);
    auto result_view = x | std::views::transform([&](double xi) {
        // Use lower_bound to find the insertion point in sorted y
        auto it = std::ranges::lower_bound(
            y_indexed,
            xi,
            {},
            &std::pair<double, size_t>::first
        );
        auto candidates = [&]() -> std::vector<std::pair<double, size_t>> {
            if (it == y_indexed.begin()) return { *it };
            if (it == y_indexed.end()) return { *(it - 1) };
            return { *(it - 1), *it };
        }();
        auto nearest = *std::ranges::min_element(
            candidates,
            [&](const auto& a, const auto& b){
                return std::abs(a.first - xi) < std::abs(b.first - xi);
            }
        );
        return x[nearest.second]; // return x[j] where j is the index of the nearest y
    });
    return NumericVector(result_view.begin(), result_view.end());
}
// [[Rcpp::export]]
NumericVector xymatch_cpp_binary_par(NumericVector x, NumericVector y, int nthreads = 1) {
    size_t nx = x.size();
    size_t ny = y.size();
    // Step 1: sort y and keep original indices
    std::vector<std::pair<double, size_t>> y_indexed(ny);
    for (size_t j = 0; j < ny; ++j)
        y_indexed[j] = {y[j], j};
    std::ranges::sort(y_indexed, {}, &std::pair<double, size_t>::first);
    // Step 2: result vector
    std::vector<double> result(nx);
    // Step 3: parallel loop over x
    #pragma omp parallel for num_threads(nthreads)
    for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(nx); ++i) {
        double xi = x[i];
        // Binary search in sorted y
        auto it = std::ranges::lower_bound(
            y_indexed,
            xi,
            {},
            &std::pair<double, size_t>::first
        );
        size_t idx;
        if (it == y_indexed.begin()) {
            idx = 0;
        } else if (it == y_indexed.end()) {
            idx = ny - 1;
        } else {
            size_t right = std::distance(y_indexed.begin(), it);
            size_t left = right - 1;
            double left_dist = std::abs(xi - y_indexed[left].first);
            double right_dist = std::abs(xi - y_indexed[right].first);
            idx = (left_dist <= right_dist) ? left : right;
        }
        result[i] = x[y_indexed[idx].second];
    }
    return NumericVector(result.begin(), result.end());
}
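// Materializing the full N x M distance matrix is what makes the 2010
// approach expensive: at N = M = 1e4 that is 1e8 doubles, roughly 800 MB.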
// [[Rcpp::export]]
Rcpp::NumericMatrix outer_abs_diff_par(Rcpp::NumericVector a, Rcpp::NumericVector b, int nthreads = 0) {
    int N = a.length(); // number of rows of the output
    int M = b.length(); // number of columns of the output
    // Rcpp allocates the output outside the parallel region
    Rcpp::NumericMatrix out(N, M);
#ifdef _OPENMP
    if (nthreads > 0) {
        omp_set_num_threads(nthreads);
    }
#endif
    // Parallelize the outer loop: each thread gets a unique set of rows 'i',
    // so writes to 'out' never overlap and are thread-safe.
    #pragma omp parallel for
    for (int i = 0; i < N; ++i) {
        // The inner loop over j runs serially within each thread
        for (int j = 0; j < M; ++j) {
            out(i, j) = std::abs(a[i] - b[j]);
        }
    }
    return out;
}
int main(int argc, char *argv[]) {
    std::vector<double> vx{0.2655087, 0.3721239, 0.5728534};
    std::vector<double> vy{0.9082078, 0.2016819, 0.8983897};
    std::cout << "x: ";
    for (double v : vx) std::cout << v << " ";
    std::cout << "\n";
    std::cout << "y: ";
    for (double v : vy) std::cout << v << " ";
    std::cout << "\n";
    // Print results of the closest_x_loop version
    std::cout << "xymatch closest_x_loop: ";
    auto result_view_loop = vx | std::views::transform([&](double xi){
        return closest_x_loop(vx, vy, xi);
    });
    std::vector<double> result_loop(result_view_loop.begin(), result_view_loop.end());
    for (double v : result_loop) std::cout << v << " ";
    std::cout << "\n";
    // Print results of the closest_x_range version
    std::cout << "xymatch closest_x_range: ";
    auto result_view_range = vx | std::views::transform([&](double xi){
        return closest_x_range(vx, vy, xi);
    });
    std::vector<double> result_range(result_view_range.begin(), result_view_range.end());
    for (double v : result_range) std::cout << v << " ";
    std::cout << "\n";
    // Print results of the R version (requires RInside)
    // RInside R(argc, argv);
    // R["vx"] = NumericVector(vx.begin(), vx.end());
    // R["vy"] = NumericVector(vy.begin(), vy.end());
    // R.parseEvalQ("xymatch <- function(x, y) { x[apply(outer(x, y, function(u,v) abs(u-v)), 1, function(d) which(d==min(d))[1])] }");
    // NumericVector result_r = R.parseEval("xymatch(vx, vy)");
    // std::cout << "xymatch R: ";
    // for (double v : result_r) std::cout << v << " ";
    // std::cout << "\n";
    return 0;
}
$ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13700H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 14
Socket(s): 1
Stepping: 2
CPU(s) scaling MHz: 15%
CPU max MHz: 5000,0000
CPU min MHz: 400,0000
@TimTaylor

Cool - I'll take a gander. I wonder if it was written before the byte-code compiler was introduced (see https://homepage.divms.uiowa.edu/~luke/talks/Riot-2019.pdf for background).

I actually started with

N <- length(x)
idx <- integer(N)
for (i in seq_len(N)) {
    idx[i] <- which.min(abs(x[i]-y))
}
x[idx]

which is similar/quicker, but I converted it to a one-liner as I thought it interesting.

@jrosell (author) commented Oct 9, 2025

@TimTaylor Oh. That's so cool. In fact, I remember that months ago I was surprised to see that preallocation + a for loop was faster than a vectorized function in R 4.5.

(It was from 2010)

@TimTaylor commented Oct 9, 2025

I was surprised to see that preallocation + a for loop was faster than a vectorized function.

In both the examples, only the x[i] - y is vectorised (i.e. recycling and subtraction done in C). I think the main difference is the slight function overhead of the vapply call.

On that note - I'm not 100% sure at what point the bytecode compiler does its stuff. I've found in the past that I get more stable benchmarks for quicker operations by using microbenchmark::microbenchmark() over multiple runs (e.g. 100ish). I think bench::mark() tries to be clever but then only runs a handful of times (by default). I'm not sure whether that means that one of the calls is unoptimised and can affect the overall timings more (:shrug: - slightly out of my depth).

Related - I generally find packaged functions quicker than those defined in my current session. Not sure if this is for related reasons.

Anyways - cheers for the fun example!

@jrosell (author) commented Oct 9, 2025

@TimTaylor I added the R loop and R vapply solutions. They are slower than the Rust/C++23 solutions, but they are so much faster than the original 2010 solution.

@JosiahParry commented Oct 10, 2025

Here is an alternative approach to this problem—it's not apples to apples whatsoever. But it is a more efficient way of approaching the problem, I think, and it is only running on a single thread.

We create a KD-tree and query it. This should have much better performance.

On two vectors of size 1e6 it takes 450 ms on my Mac M1.

use extendr_api::prelude::*; // missing in the original snippet; needed for #[extendr] and Doubles
use kiddo::KdTree;
use kiddo::SquaredEuclidean;

#[extendr]
fn xymatch(x: &[f64], y: &[f64]) -> Doubles {
    let mut tree = KdTree::<_, 1>::new();

    for (i, yi) in y.iter().enumerate() {
        tree.add(&[*yi], i as u64);
    }

    let mut res = Vec::with_capacity(x.len());

    for xi in x {
        let idx = tree.nearest_one::<SquaredEuclidean>(&[*xi]).item as usize;
        res.push(x[idx])
    }

    Doubles::from_values(res)
}
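
For reference, compiling this with rextendr would follow the same pattern as the other Rust sources, with kiddo declared as a crate dependency (the file name, module name, and version pin below are assumptions, not from the original comment):

rextendr::rust_source(
  file = "rust/xymatch_kdtree.rs",
  module = "xymatch_kdtree",
  profile = "release",
  dependencies = list(`kiddo` = "5")
)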

Edit: the best approach is likely a binary search

@TimTaylor commented Oct 10, 2025

Funsies - @JosiahParry - how does this compare ...

xymatch_binary2 <- function(x,y) {
    order <- order(y)
    y <- y[order]
    # Let's assume no NA for speediness
    out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
    n <- length(y)
    idx <- integer(n) + n
    idx[out == 0] <- 1L
    nx <- (out != 0 & (out != n))
    one <- abs(y[out[nx]] - x[nx])
    two <- abs(y[out[nx] + 1] - x[nx])
    idx[nx] <- out[nx] + (one > two)
    x[order][idx]
}
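
The trick hinges on findInterval(): for each x it returns the index of the largest element of the sorted y that is <= x, with 0 meaning "below all of y", so the nearest neighbour must be either that index or the one just after it. A small illustration of the semantics:

findInterval(c(0.1, 0.5, 0.99), c(0.2, 0.4, 0.6))
# [1] 0 2 3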

@jrosell - what have you started 😅 (and yes, this completely ignores what you were illustrating, but 'tis fun)

EDIT: Added minor tweak

@jrosell (author) commented Oct 10, 2025

It was not me, it was Ross Ihaka :)

@JosiahParry

@TimTaylor the R approach is very good! Adding a binary search for Rust - I'm not good at this stuff, so I bet there is a better approach to this.
The KD-tree implementation above fails to match the R implementation at 1e5, probably due to floating-point precision in the squared Euclidean distance calculation. The Rust binary search works well, though!

This could be a good function to put in base R :)

use extendr_api::prelude::*; // missing in the original snippet; needed for #[extendr] and Doubles

#[extendr]
fn xymatch_binsearch(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let mut res = Vec::with_capacity(x.len());

    for &xi in x {
        let idx = match y_indexed.binary_search_by(|probe| {
            probe.0.partial_cmp(&xi).unwrap()
        }) {
            Ok(i) => i,  // Exact match
            Err(i) => {
                if i == 0 {
                    0
                } else if i == y_indexed.len() {
                    y_indexed.len() - 1
                } else {
                    let left_dist = (xi - y_indexed[i - 1].0).abs();
                    let right_dist = (y_indexed[i].0 - xi).abs();
                    if left_dist <= right_dist {
                        i - 1
                    } else {
                        i
                    }
                }
            }
        };

        let orig_idx = y_indexed[idx].1;
        res.push(x[orig_idx]);
    }

    Doubles::from_values(res)
}

Relative performance.

# A tibble: 3 × 4
  expression  median total_time n_itr
  <bch:expr>   <dbl>   <bch:tm> <int>
1 r_binary      1.44      9.27s  5231
2 rust_binary   1           10s  8263
3 rust_kdtree   1.59        10s  5077

@jrosell (author) commented Oct 10, 2025

Thanks @JosiahParry and @TimTaylor. I've just updated the results with your contributions and added cpp_binary and cpp_binary_par (8 threads).

@albersonmiranda

Here's @JosiahParry's binary search in parallel:

use extendr_api::prelude::*;
use rayon::prelude::*;

#[extendr]
pub fn xymatch_rust_binary_parallel(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let res: Vec<f64> = x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| {
                probe.0.partial_cmp(&xi).unwrap()
            }) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect();

    Doubles::from_values(res)
}

Note: rayon must be listed as a dependency in rust_source(), e.g.,

rextendr::rust_source(
  file = "xymatch_binary_parallel.rs",
  module = "xymatch_binary_parallel",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)

@jrosell (author) commented Oct 11, 2025

I want to set 8 threads...

use extendr_api::prelude::*;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

/// Parallel version of xymatch with optional number of threads.
///
/// @param x Numeric vector.
/// @param y Numeric vector.
/// @param nthreads Optional number of threads.
/// @export
#[extendr]
pub fn xymatch_rust_binary_par(x: &[f64], y: &[f64], nthreads: Option<usize>) -> Doubles {
    // Copy R slices into owned Vecs (safe for threads)
    let x_vec = x.to_vec();
    let y_vec = y.to_vec();

    // Run computation in parallel using Rayon pool if requested
    let res_vec: Vec<f64> = if let Some(n) = nthreads {
        let pool = ThreadPoolBuilder::new()
            .num_threads(n)
            .build()
            .expect("Failed to build Rayon thread pool");
        // Return Vec<f64> (Send)
        pool.install(|| run_xymatch_owned(&x_vec, &y_vec))
    } else {
        // Use default global Rayon pool
        run_xymatch_owned(&x_vec, &y_vec)
    };

    // Convert back to R type AFTER threading (safe)
    Doubles::from_values(res_vec)
}

/// Pure Rust computation — returns Vec<f64> (Send)
fn run_xymatch_owned(x: &[f64], y: &[f64]) -> Vec<f64> {
    // Pair y values with their original indices
    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    // Sort by y value
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Parallel computation over x
    x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| probe.0.partial_cmp(&xi).unwrap()) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect()
}

// Export to R
extendr_module! {
    mod xymatch_binary_par;
    fn xymatch_rust_binary_par;
}

Is there something wrong? Check the results:

# A tibble: 13 × 4
   expression       median total_time n_itr
   <bch:expr>        <dbl>   <bch:tm> <int>
 1 r_2010          3239.       11.57s     5
 2 r_binary           2.14        10s  6456
 3 r_vapply         339.       10.12s    41
 4 r_loop           255.       10.01s    51
 5 cpp_loop         149.       10.04s    95
 6 cpp_loop_par      19.5      10.01s   722
 7 cpp_ranges        76.2      10.05s   186
 8 cpp_range_par     19.2      10.02s   745
 9 cpp_binary         1.95        10s  7225
10 cpp_binary_par     1         7.22s 10000
11 rust_loop        149.       10.05s    95
12 rust_binary        1.67        10s  8423
13 rust_binary_par    7.23        10s  2311
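
A likely explanation for rust_binary_par lagging here: xymatch_rust_binary_par builds a fresh Rayon thread pool on every call when nthreads is supplied, so pool construction is timed inside the benchmark loop, and at this size each parallel lookup is tiny relative to that overhead. Initializing the global pool once per session (rayon::ThreadPoolBuilder::new().num_threads(8).build_global()) or reusing a pool built outside the benchmarked function should close most of the gap with cpp_binary_par.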
