@jrosell
Last active October 13, 2025
library(rextendr)
library(mirai)
rextendr::rust_source(
  file = "rust/xymatch_loop.rs",
  module = "xymatch_loop",
  profile = "release"
)
rextendr::rust_source(
  file = "rust/xymatch_binary.rs",
  module = "xymatch_binary",
  profile = "release"
)
rextendr::rust_source(
  file = "rust/xymatch_binary_par.rs",
  module = "xymatch_binary_par",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)

library(Rcpp)
# Compile the C++ file with OpenMP enabled, then restore the original
# build flags.
old_cxx <- Sys.getenv("PKG_CXXFLAGS")
old_lib <- Sys.getenv("PKG_LIBS")
Sys.setenv(PKG_CXXFLAGS = "-fopenmp -O3")
Sys.setenv(PKG_LIBS = "-fopenmp")
Rcpp::sourceCpp("xymatch_rcpp.cpp")
Sys.setenv(PKG_CXXFLAGS = old_cxx)
Sys.setenv(PKG_LIBS = old_lib)

n_threads_test()
# [1] 20
# Original 2010 implementation: builds the full N x N distance matrix with
# outer(), so it is O(N^2) in both time and memory (~0.8 GB for N = 1e4).
xymatch_r_2010 <- function(x, y) {
  x[
    apply(
      outer(x, y, function(u, v) abs(u - v)),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}

# Same approach, but with the distance matrix computed in parallel in C++.
xymatch_r_2010_par <- function(x, y, nthreads = 1) {
  x[
    apply(
      outer_abs_diff_par(x, y, nthreads = nthreads),
      1,
      function(d) which(d == min(d))[1]
    )
  ]
}

xymatch_r_vapply <- function(x, y) {
  vapply(seq_along(x), \(i) x[which.min(abs(x[i] - y))], 1)
}
xymatch_r_loop <- function(x, y) {
  N <- length(x)
  idx <- integer(N)
  for (i in seq_len(N)) {
    idx[i] <- which.min(abs(x[i] - y))
  }
  x[idx]
}
# Binary-search approach via findInterval(); assumes no NAs for speed.
xymatch_r_binary <- function(x, y) {
  order <- order(y)
  y <- y[order]
  out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
  n <- length(y)
  idx <- integer(n) + n
  idx[out == 0] <- 1L
  nx <- (out != 0 & (out != n))
  one <- abs(y[out[nx]] - x[nx])
  two <- abs(y[out[nx] + 1] - x[nx])
  idx[nx] <- out[nx] + (one > two)
  x[order][idx]
}
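# Worked illustration (an added note, not part of the original gist):
# findInterval() returns, for each x value, the index of the largest sorted
# y value not exceeding it, so the nearest neighbour is either that element
# or the one just after it:
findInterval(0.5, c(0.2, 0.6))
# [1] 1  -> candidates 0.2 and 0.6; |0.5 - 0.6| < |0.5 - 0.2|, so pick y[2]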
# Corrected wrapper function for parallel execution via mirai daemons.
xymatch_r_binary_par <- function(x, y, nchunks = 1) {
  if (nchunks < 1) stop("nchunks must be >= 1")
  order <- order(y)
  y <- y[order]
  # Note: daemons are launched on every call, which adds substantial
  # per-call overhead.
  mirai::daemons(nchunks)
  mirai::everywhere({}, y_val = y)
  if (nchunks == 1) groups <- rep(1L, length(x))  # one chunk: all together
  if (nchunks > 1) {
    groups <- cut(
      seq_along(x),
      breaks = nchunks,
      labels = FALSE,
      include.lowest = TRUE
    )
  }
  x_chunks <- split(x, groups)
  idx_chunks_mirai <- mirai_map(
    x_chunks,
    .f = function(chunk) {
      out <- findInterval(chunk, y_val, checkSorted = TRUE, checkNA = FALSE)
      n_y <- length(y_val)
      n_chunk <- length(chunk)
      idx <- integer(n_chunk) + n_y
      idx[out == 0] <- 1L
      nx <- (out != 0 & (out != n_y))
      if (any(nx)) {
        one <- abs(y_val[out[nx]] - chunk[nx])
        two <- abs(y_val[out[nx] + 1] - chunk[nx])
        idx[nx] <- out[nx] + (one > two)
      }
      return(idx)
    }
  )
  idx <- unlist(idx_chunks_mirai[.flat])
  x[order][idx]
}
N <- 1e4
set.seed(1)
x <- runif(N)
y <- runif(N)
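# Optional sanity check (not in the original script); bench::mark() with
# check = TRUE below performs the same verification on every expression:
stopifnot(identical(xymatch_r_2010(x, y), xymatch_r_binary(x, y)))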
bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  r_2010_par = xymatch_r_2010_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_vapply = xymatch_r_vapply(x, y),
  r_loop = xymatch_r_loop(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_loop = xymatch_cpp_loop(x, y),
  cpp_loop_par = xymatch_cpp_loop_par(x, y, nthreads = 8),
  cpp_ranges = xymatch_cpp_ranges(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_loop = xymatch_rust_loop(x, y),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 15 × 4
#    expression       median total_time n_itr
#    <bch:expr>        <dbl>   <bch:tm> <int>
#  1 r_2010          3403.        10.2s     4
#  2 r_2010_par      2594.        11.7s     6
#  3 r_binary           1.97        10s  6661
#  4 r_vapply         256.          10s    47
#  5 r_loop           241.         10.1s    52
#  6 r_binary_par    5021.        11.4s     3
#  7 cpp_loop         140.          10s    95
#  8 cpp_loop_par      18.4         10s   719
#  9 cpp_ranges        72.1         10s   184
# 10 cpp_range_par     20.0         10s   641
# 11 cpp_binary         1.88        10s  7013
# 12 cpp_binary_par     1          7.8s 10000
# 13 rust_loop        140.         10.1s    95
# 14 rust_binary        1.59        10s  8327
# 15 rust_binary_par    1.68        10s  4851
N <- 21250
set.seed(1)
x <- runif(N)
y <- runif(N)
bench::mark(
  r_2010 = xymatch_r_2010(x, y),
  cpp_range_par = xymatch_cpp_range_par(x, y, nthreads = 8),
  r_binary = xymatch_r_binary(x, y),
  r_binary_par = xymatch_r_binary_par(x, y, nchunks = 8),
  cpp_binary = xymatch_cpp_binary(x, y),
  cpp_binary_par = xymatch_cpp_binary_par(x, y, nthreads = 8),
  rust_binary = xymatch_rust_binary(x, y),
  rust_binary_par = xymatch_rust_binary_par(x, y, nthreads = 8),
  relative = TRUE,
  min_time = 10,
  check = TRUE
)[, c("expression", "median", "total_time", "n_itr")]
# # A tibble: 8 × 4
#   expression      median total_time n_itr
#   <bch:expr>       <dbl>   <bch:tm> <int>
# 1 r_2010         6396.        10.6s     1
# 2 cpp_range_par    28.6         10s   205
# 3 r_binary          2.11        10s  2736
# 4 r_binary_par   3366.        11.2s     2
# 5 cpp_binary        1.92        10s  3124
# 6 cpp_binary_par    1           10s  6000
# 7 rust_binary       1.69        10s  3573
# 8 rust_binary_par   5.32        10s  1324
// Author: @JosiahParry
use extendr_api::prelude::*;

#[extendr]
fn xymatch_rust_binary(x: &[f64], y: &[f64]) -> Doubles {
    // Pair each y value with its original index, then sort by value.
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let mut res = Vec::with_capacity(x.len());
    for &xi in x {
        let idx = match y_indexed.binary_search_by(|probe| {
            probe.0.partial_cmp(&xi).unwrap()
        }) {
            Ok(i) => i, // Exact match
            Err(i) => {
                if i == 0 {
                    0
                } else if i == y_indexed.len() {
                    y_indexed.len() - 1
                } else {
                    let left_dist = (xi - y_indexed[i - 1].0).abs();
                    let right_dist = (y_indexed[i].0 - xi).abs();
                    if left_dist <= right_dist {
                        i - 1
                    } else {
                        i
                    }
                }
            }
        };
        let orig_idx = y_indexed[idx].1;
        res.push(x[orig_idx]);
    }
    Doubles::from_values(res)
}

extendr_module! {
    mod xymatch_binary;
    fn xymatch_rust_binary;
}
use extendr_api::prelude::*;

#[extendr]
fn xymatch_rust_loop(x: Vec<f64>, y: Vec<f64>) -> Vec<f64> {
    let n = x.len();
    let m = y.len();
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let xi = x[i];
        // Linear scan for the y value closest to xi.
        let mut min_index = 0;
        let mut min_dist = (xi - y[0]).abs();
        for j in 1..m {
            let dist = (xi - y[j]).abs();
            if dist < min_dist {
                min_index = j;
                min_dist = dist;
            }
        }
        result.push(x[min_index]);
    }
    result
}

extendr_module! {
    mod xymatch_loop;
    fn xymatch_rust_loop;
}
// $ clang++ -std=c++23 -O2 -Wall -fopenmp -I/usr/lib/llvm-22/include -I/home/jordi/vcpkg/installed/x64-linux/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/include -I/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/Rcpp/include -I/usr/share/R/include xymatch_rcpp.cpp -o xymatch_rcpp -L/usr/lib/llvm-22/lib -lomp -L/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib -lRInside -lR -Wl,-rpath,/home/jordi/R/x86_64-pc-linux-gnu-library/4.5/RInside/lib && ./xymatch_rcpp
// > install.packages(c("Rcpp","RInside"))
// [[Rcpp::plugins(cpp23)]]
// [[Rcpp::depends(Rcpp)]]
// // [[Rcpp::depends(RInside)]]
// #include <RInside.h>
#include <Rcpp.h>
#include <iostream>
#include <ranges>
#include <vector>
#include <cmath>
#include <algorithm>
#ifdef _OPENMP
#include <omp.h> // needs libomp (e.g. libomp-22-dev for clang 22)
#endif
#include <execution>
using namespace Rcpp;
// [[Rcpp::export]]
int n_threads_test() {
  int n = 1;
#ifdef _OPENMP
  // Query the size of the default OpenMP thread team.
  #pragma omp parallel
  {
    #pragma omp single
    n = omp_get_num_threads();
  }
#endif
  std::cout << "OpenMP threads: " << n << "\n";
  return n;
}
// R code:
// xymatch <- function(x, y) {
//   x[
//     apply(
//       outer(x, y, function(u, v) abs(u - v)),
//       1,
//       function(d) which(d == min(d))[1]
//     )
//   ]
// }
// For a given xi, return x[j] where y[j] is the closest y value
// (mirroring the R original, which indexes x by the argmin over y).
double closest_x_range(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
  auto min_it = std::ranges::min_element(
    std::views::iota(size_t{0}, vy.size()),
    {},
    [&](size_t j){ return std::abs(vy[j] - xi); }
  );
  return vx[*min_it];
}
// [[Rcpp::export]]
NumericVector xymatch_cpp_ranges(NumericVector x, NumericVector y) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  auto result_view = vx | std::views::transform([&](double xi){
    return closest_x_range(vx, vy, xi);
  });
  return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_range_par(NumericVector x, NumericVector y, int nthreads = 1) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  std::vector<double> res(vx.size());
  #pragma omp parallel for num_threads(nthreads)
  for (size_t i = 0; i < vx.size(); ++i) {
    res[i] = closest_x_range(vx, vy, vx[i]);
  }
  return NumericVector(res.begin(), res.end());
}
double closest_x_loop(const std::vector<double>& vx, const std::vector<double>& vy, double xi) {
  size_t min_j = 0;
  double min_diff = std::abs(vy[0] - xi);
  for (size_t j = 1; j < vy.size(); ++j) {
    double diff = std::abs(vy[j] - xi);
    if (diff < min_diff) {
      min_diff = diff;
      min_j = j;
    }
  }
  return vx[min_j];
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop(NumericVector x, NumericVector y) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  auto result_view = vx | std::views::transform([&](double xi){
    return closest_x_loop(vx, vy, xi);
  });
  return NumericVector(result_view.begin(), result_view.end());
}

// [[Rcpp::export]]
NumericVector xymatch_cpp_loop_par(NumericVector x, NumericVector y, int nthreads = 1) {
  std::vector<double> vx(x.begin(), x.end());
  std::vector<double> vy(y.begin(), y.end());
  std::vector<double> res(vx.size());
  #pragma omp parallel for num_threads(nthreads)
  for (size_t i = 0; i < vx.size(); ++i) {
    res[i] = closest_x_loop(vx, vy, vx[i]);
  }
  return NumericVector(res.begin(), res.end());
}
// xymatch_r_binary <- function(x, y) {
//   order <- order(y)
//   y <- y[order]
//   # Let's assume no NA for speediness
//   out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
//   n <- length(y)
//   idx <- integer(n) + n
//   idx[out == 0] <- 1L
//   nx <- (out != 0 & (out != n))
//   one <- abs(y[out[nx]] - x[nx])
//   two <- abs(y[out[nx] + 1] - x[nx])
//   idx[nx] <- out[nx] + (one > two)
//   x[order][idx]
// }
// [[Rcpp::export]]
NumericVector xymatch_cpp_binary(NumericVector x, NumericVector y) {
  size_t ny = y.size();
  std::vector<std::pair<double, size_t>> y_indexed(ny);
  for (size_t j = 0; j < ny; ++j)
    y_indexed[j] = {y[j], j};
  std::ranges::sort(y_indexed, {}, &std::pair<double, size_t>::first);
  auto result_view = x | std::views::transform([&](double xi) {
    // Use lower_bound to find the insertion point in the sorted y values.
    auto it = std::ranges::lower_bound(
      y_indexed,
      xi,
      {},
      &std::pair<double, size_t>::first
    );
    // The nearest neighbour is the element at the insertion point or the
    // one just before it.
    auto candidates = [&]() -> std::vector<std::pair<double, size_t>> {
      if (it == y_indexed.begin()) return { *it };
      if (it == y_indexed.end()) return { *(it - 1) };
      return { *(it - 1), *it };
    }();
    auto nearest = *std::ranges::min_element(
      candidates,
      [&](const auto& a, const auto& b){
        return std::abs(a.first - xi) < std::abs(b.first - xi);
      }
    );
    return x[nearest.second]; // return x[j] where y[j] is nearest
  });
  return NumericVector(result_view.begin(), result_view.end());
}
// [[Rcpp::export]]
NumericVector xymatch_cpp_binary_par(NumericVector x, NumericVector y, int nthreads = 1) {
  size_t nx = x.size();
  size_t ny = y.size();
  // Step 1: sort y, keeping the original indices
  std::vector<std::pair<double, size_t>> y_indexed(ny);
  for (size_t j = 0; j < ny; ++j)
    y_indexed[j] = {y[j], j};
  std::ranges::sort(y_indexed, {}, &std::pair<double, size_t>::first);
  // Step 2: result vector
  std::vector<double> result(nx);
  // Step 3: parallel loop over x
  #pragma omp parallel for num_threads(nthreads)
  for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(nx); ++i) {
    double xi = x[i];
    // Binary search in the sorted y values
    auto it = std::ranges::lower_bound(
      y_indexed,
      xi,
      {},
      &std::pair<double, size_t>::first
    );
    size_t idx;
    if (it == y_indexed.begin()) {
      idx = 0;
    } else if (it == y_indexed.end()) {
      idx = ny - 1;
    } else {
      size_t right = std::distance(y_indexed.begin(), it);
      size_t left = right - 1;
      double left_dist = std::abs(xi - y_indexed[left].first);
      double right_dist = std::abs(xi - y_indexed[right].first);
      idx = (left_dist <= right_dist) ? left : right;
    }
    result[i] = x[y_indexed[idx].second];
  }
  return NumericVector(result.begin(), result.end());
}
// [[Rcpp::export]]
Rcpp::NumericMatrix outer_abs_diff_par(Rcpp::NumericVector a, Rcpp::NumericVector b, int nthreads = 0) {
  int N = a.length(); // rows of the output
  int M = b.length(); // columns of the output
  // Allocate the output matrix up front; Rcpp memory management happens
  // outside the parallel region.
  Rcpp::NumericMatrix out(N, M);
  // Optionally set the number of threads.
  if (nthreads > 0) {
#ifdef _OPENMP
    omp_set_num_threads(nthreads);
#endif
  }
  // Parallelise the outer loop: each thread gets a distinct set of rows i,
  // so writes to out(i, j) never overlap and are thread-safe.
  #pragma omp parallel for
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < M; ++j) {
      out(i, j) = std::abs(a[i] - b[j]);
    }
  }
  return out;
}
int main(int argc, char *argv[]) {
  std::vector<double> vx{0.2655087, 0.3721239, 0.5728534};
  std::vector<double> vy{0.9082078, 0.2016819, 0.8983897};
  std::cout << "x: ";
  for (double v : vx) std::cout << v << " ";
  std::cout << "\n";
  std::cout << "y: ";
  for (double v : vy) std::cout << v << " ";
  std::cout << "\n";
  // Print results of the closest_x_loop version
  std::cout << "xymatch closest_x_loop: ";
  auto result_view_loop = vx | std::views::transform([&](double xi){
    return closest_x_loop(vx, vy, xi);
  });
  std::vector<double> result_loop(result_view_loop.begin(), result_view_loop.end());
  for (double v : result_loop) std::cout << v << " ";
  std::cout << "\n";
  // Print results of the closest_x_range version
  std::cout << "xymatch closest_x_range: ";
  auto result_view_range = vx | std::views::transform([&](double xi){
    return closest_x_range(vx, vy, xi);
  });
  std::vector<double> result_range(result_view_range.begin(), result_view_range.end());
  for (double v : result_range) std::cout << v << " ";
  std::cout << "\n";
  // Print results of the R version (requires RInside; disabled above)
  // RInside R(argc, argv);
  // R["vx"] = NumericVector(vx.begin(), vx.end());
  // R["vy"] = NumericVector(vy.begin(), vy.end());
  // R.parseEvalQ("xymatch <- function(x, y) { x[apply(outer(x, y, function(u,v) abs(u-v)), 1, function(d) which(d==min(d))[1])] }");
  // NumericVector result_r = R.parseEval("xymatch(vx, vy)");
  // std::cout << "xymatch R: ";
  // for (double v : result_r) std::cout << v << " ";
  // std::cout << "\n";
  return 0;
}
$ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13700H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 14
Socket(s): 1
Stepping: 2
CPU(s) scaling MHz: 15%
CPU max MHz: 5000,0000
CPU min MHz: 400,0000
@TimTaylor

Somewhat separate to the benchmarks: per the description “for each y, find the closest value in x”, should the margin argument to the apply call be 2 rather than 1?

@jrosell (Author) commented Oct 9, 2025

@TimTaylor I copied the function from the paper "R: Lessons Learned, Directions for the Future."

I would have said:
"Match x values to their closest y element"

@TimTaylor

Fair enough. In that case it may be worth adding the following comparison, which, for me at least, is a quicker R implementation:

vapply(seq_along(x), \(i) x[which.min(abs(x[i]-y))], 1)

@jrosell (Author) commented Oct 9, 2025

I'll try it. BTW, if you read the paper, his point is that we are thinking too much about how to optimize code so that it runs fast in R. 😬

@TimTaylor

Cool - I'll take a gander. I wonder if it was written before the byte-code compiler was introduced (see https://homepage.divms.uiowa.edu/~luke/talks/Riot-2019.pdf for background).

I actually started with

N <- length(x)
idx <- integer(N)
for (i in seq_len(N)) {
    idx[i] <- which.min(abs(x[i]-y))
}
x[idx]

which is similar and a touch quicker, but I converted it to a one-liner as I thought that was more interesting.

@jrosell (Author) commented Oct 9, 2025

@TimTaylor Oh, that's so cool. In fact, I remember that months ago I was surprised to see that preallocation + a for loop was faster than a vectorized function in R 4.5.

(It was from 2010)

@TimTaylor commented Oct 9, 2025

I was surprised to see that preallocation + a for loop was faster than a vectorized function.

In both examples, only the x[i] - y is vectorised (i.e. the recycling and subtraction are done in C). I think the main difference is the slight function-call overhead of vapply.

On that note - I'm not 100% sure at what point the bytecode compiler does its stuff. I've found in the past that I get more stable benchmarks for quicker operations by using microbenchmark::microbenchmark() over multiple runs (e.g. ~100). I think bench::mark() tries to be clever but then only runs a handful of times (by default). I'm not sure if that means that one of the calls is unoptimised and can affect the overall timings more (:shrug: - slightly out of my depth).
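As a rough sketch of what I mean (a hypothetical example, reusing the functions defined in the gist):

library(microbenchmark)
microbenchmark(
  r_vapply = xymatch_r_vapply(x, y),
  r_loop   = xymatch_r_loop(x, y),
  times    = 100
)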

Related - I generally find packaged functions quicker than those defined in my current session. Not sure if this is for related reasons.

Anyways - cheers for the fun example!

@jrosell (Author) commented Oct 9, 2025

@TimTaylor I added the R loop and R vapply solutions. They are slower than the Rust/C++23 solutions, but so much faster than the original 2010 solution.

@JosiahParry commented Oct 10, 2025

Here is an alternative approach to this problem—it's not apples to apples whatsoever. But it is a more efficient way of approaching the problem, I think, and it is only running on a single thread.

We create a KD-tree and query it; this should have much better performance.

On two vectors of size 1e6 it takes ~450 ms on my Mac M1.

use extendr_api::prelude::*; // needed for #[extendr] and Doubles
use kiddo::KdTree;
use kiddo::SquaredEuclidean;

#[extendr]
fn xymatch(x: &[f64], y: &[f64]) -> Doubles {
    let mut tree = KdTree::<_, 1>::new();

    for (i, yi) in y.iter().enumerate() {
        tree.add(&[*yi], i as u64);
    }

    let mut res = Vec::with_capacity(x.len());

    for xi in x {
        let idx = tree.nearest_one::<SquaredEuclidean>(&[*xi]).item as usize;
        res.push(x[idx])
    }

    Doubles::from_values(res)
}
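(For this to compile via rextendr, the kiddo crate presumably needs to be listed as a dependency in rust_source(), e.g. dependencies = list(`kiddo` = "*"), mirroring the rayon setup further down; the exact version spec is a guess.)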

Edit: the best approach is likely a binary search

@TimTaylor commented Oct 10, 2025

Funsies - @JosiahParry - how does this compare ...

xymatch_binary2 <- function(x,y) {
    order <- order(y)
    y <- y[order]
    # Let's assume no NA for speediness
    out <- findInterval(x, y, checkSorted = FALSE, checkNA = FALSE)
    n <- length(y)
    idx <- integer(n) + n
    idx[out == 0] <- 1L
    nx <- (out != 0 & (out != n))
    one <- abs(y[out[nx]] - x[nx])
    two <- abs(y[out[nx] + 1] - x[nx])
    idx[nx] <- out[nx] + (one > two)
    x[order][idx]
}

@jrosell - what have you started 😅 (and yes, this completely ignores what you were illustrating, but 'tis fun)

EDIT: Added minor tweak

@jrosell (Author) commented Oct 10, 2025

It was not me, it was Ross Ihaka :)

@JosiahParry

@TimTaylor the R approach is very good! Adding a binary search for Rust; I'm not good at this stuff, so I bet there is a better approach to this.
The KD-tree implementation above fails to match the R implementation at 1e5, probably due to floating-point precision in the squared Euclidean distance calculation. The Rust binary search works well, though!

This could be a good function to put in base R :)

#[extendr]
fn xymatch_binsearch(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let mut res = Vec::with_capacity(x.len());

    for &xi in x {
        let idx = match y_indexed.binary_search_by(|probe| {
            probe.0.partial_cmp(&xi).unwrap()
        }) {
            Ok(i) => i,  // Exact match
            Err(i) => {
                if i == 0 {
                    0
                } else if i == y_indexed.len() {
                    y_indexed.len() - 1
                } else {
                    let left_dist = (xi - y_indexed[i - 1].0).abs();
                    let right_dist = (y_indexed[i].0 - xi).abs();
                    if left_dist <= right_dist {
                        i - 1
                    } else {
                        i
                    }
                }
            }
        };

        let orig_idx = y_indexed[idx].1;
        res.push(x[orig_idx]);
    }

    Doubles::from_values(res)
}

Relative performance.

# A tibble: 3 × 4
  expression  median total_time n_itr
  <bch:expr>   <dbl>   <bch:tm> <int>
1 r_binary      1.44      9.27s  5231
2 rust_binary   1           10s  8263
3 rust_kdtree   1.59        10s  5077

@jrosell (Author) commented Oct 10, 2025

Thanks @JosiahParry and @TimTaylor. I've just updated the results with your contributions and added cpp_binary and cpp_binary_par (with 8 threads).

@albersonmiranda

Here's @JosiahParry's binary search in parallel:

use extendr_api::prelude::*;
use rayon::prelude::*;

#[extendr]
pub fn xymatch_rust_binary_parallel(x: &[f64], y: &[f64]) -> Doubles {
    let mut y_indexed: Vec<(f64, usize)> = y.iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let res: Vec<f64> = x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| {
                probe.0.partial_cmp(&xi).unwrap()
            }) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect();

    Doubles::from_values(res)
}

extendr_module! {
    mod xymatch_binary_parallel;
    fn xymatch_rust_binary_parallel;
}

Note: rayon must be listed as a dependency in rust_source(), e.g.:

rextendr::rust_source(
  file = "xymatch_binary_parallel.rs",
  module = "xymatch_binary_parallel",
  profile = "release",
  dependencies = list(`rayon` = "1.11")
)

@jrosell (Author) commented Oct 11, 2025

I want to set 8 threads...

use extendr_api::prelude::*;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

/// Parallel version of xymatch with optional number of threads.
///
/// @param x Numeric vector.
/// @param y Numeric vector.
/// @param nthreads Optional number of threads.
/// @export
#[extendr]
pub fn xymatch_rust_binary_par(x: &[f64], y: &[f64], nthreads: Option<usize>) -> Doubles {
    // Copy R slices into owned Vecs (safe for threads)
    let x_vec = x.to_vec();
    let y_vec = y.to_vec();

    // Run computation in parallel using Rayon pool if requested
    let res_vec: Vec<f64> = if let Some(n) = nthreads {
        let pool = ThreadPoolBuilder::new()
            .num_threads(n)
            .build()
            .expect("Failed to build Rayon thread pool");
        // Return Vec<f64> (Send)
        pool.install(|| run_xymatch_owned(&x_vec, &y_vec))
    } else {
        // Use default global Rayon pool
        run_xymatch_owned(&x_vec, &y_vec)
    };

    // Convert back to R type AFTER threading (safe)
    Doubles::from_values(res_vec)
}

/// Pure Rust computation — returns Vec<f64> (Send)
fn run_xymatch_owned(x: &[f64], y: &[f64]) -> Vec<f64> {
    // Pair y values with their original indices
    let mut y_indexed: Vec<(f64, usize)> = y
        .iter()
        .enumerate()
        .map(|(i, &val)| (val, i))
        .collect();

    // Sort by y value
    y_indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Parallel computation over x
    x.par_iter()
        .map(|&xi| {
            let idx = match y_indexed.binary_search_by(|probe| probe.0.partial_cmp(&xi).unwrap()) {
                Ok(i) => i,
                Err(i) => {
                    if i == 0 {
                        0
                    } else if i == y_indexed.len() {
                        y_indexed.len() - 1
                    } else {
                        let left_dist = (xi - y_indexed[i - 1].0).abs();
                        let right_dist = (y_indexed[i].0 - xi).abs();
                        if left_dist <= right_dist {
                            i - 1
                        } else {
                            i
                        }
                    }
                }
            };
            let orig_idx = y_indexed[idx].1;
            x[orig_idx]
        })
        .collect()
}

// Export to R
extendr_module! {
    mod xymatch_binary_par;
    fn xymatch_rust_binary_par;
}

Is there something wrong? Check the results:

# A tibble: 13 × 4
   expression       median total_time n_itr
   <bch:expr>        <dbl>   <bch:tm> <int>
 1 r_2010          3239.       11.57s     5
 2 r_binary           2.14        10s  6456
 3 r_vapply         339.       10.12s    41
 4 r_loop           255.       10.01s    51
 5 cpp_loop         149.       10.04s    95
 6 cpp_loop_par      19.5      10.01s   722
 7 cpp_ranges        76.2      10.05s   186
 8 cpp_range_par     19.2      10.02s   745
 9 cpp_binary         1.95        10s  7225
10 cpp_binary_par     1         7.22s 10000
11 rust_loop        149.       10.05s    95
12 rust_binary        1.67        10s  8423
13 rust_binary_par    7.23        10s  2311
