Skip to content

Instantly share code, notes, and snippets.

@Steboss
Created January 16, 2026 10:36
Show Gist options
  • Select an option

  • Save Steboss/6cac690b57d1edb26fd42b7ef89153e0 to your computer and use it in GitHub Desktop.

Select an option

Save Steboss/6cac690b57d1edb26fd42b7ef89153e0 to your computer and use it in GitHub Desktop.
Final iteration, in which the model starts to use a more NEON-like style.
use rayon::prelude::*;
use std::arch::aarch64::*;
use std::env;
use std::fs;
fn main() {
let args: Vec<String> = env::args().collect();
if args.len() != 4 {
eprintln!("Usage: {} M K N", args[0]);
return;
}
let m = args[1].parse::<usize>().unwrap();
let k = args[2].parse::<usize>().unwrap();
let n = args[3].parse::<usize>().unwrap();
let mut a = vec![0.0f32; m * k];
let mut b = vec![0.0f32; k * n];
let mut c = vec![0.0f32; m * n];
for i in 0..m * k {
a[i] = (i % m + 1) as f32;
}
for i in 0..k * n {
b[i] = (i % k + 1) as f32;
}
let start = std::time::Instant::now();
a.chunks_exact_mut(16).zip(b.chunks_exact(16)).enumerate().par_iter().for_each(|(i, (ab, bc))| {
let a_ptr = ab.as_mut_ptr();
let b_ptr = bc.as_ptr();
let mut m1 = vld1q_f32(a_ptr);
let mut m2 = vld1q_f32(a_ptr.add(4));
let mut m3 = vld1q_f32(a_ptr.add(8));
let mut m4 = vld1q_f32(b_ptr);
for j in 0..n {
let b_ptr_j = b_ptr.add(j * k);
let m5 = vld1q_f32(b_ptr_j.add(k));
let m6 = vld1q_f32(b_ptr_j.add(k).sub(4));
let m7 = vaddq_f32(vld1q_f32(b_ptr_j), vld1q_f32(b_ptr_j.add(4)));
let c11 = vaddq_f32(
vaddq_f32(vaddq_f32(m1, m4), m7),
vsubq_f32(vsubq_f32(vaddq_f32(m5, m7), m1), m4),
);
let c12 = vaddq_f32(m3, m5);
let c21 = vaddq_f32(m2, m4);
let c22 = vaddq_f32(
vaddq_f32(vsubq_f32(vsubq_f32(m1, m2), m3), m6),
m1,
);
let c_ptr = c.as_mut_ptr().add(i * n + j);
unsafe {
vst1q(c_ptr, c11);
vst1q(c_ptr.add(4), c12);
vst1q(c_ptr.add(n), c21);
vst1q(c_ptr.add(n).add(4), c22);
}
}
});
let duration = start.elapsed();
println!("{}", duration.as_secs_f64());
}
/// Loads four consecutive `f32` lanes starting at `ptr` into a NEON vector.
///
/// Thin wrapper over `std::arch::aarch64::vld1q_f32`.
///
/// # Safety
///
/// `ptr` must be valid for reads of four consecutive `f32` values.
unsafe fn vld1q_f32(ptr: *const f32) -> float32x4_t {
    // SAFETY: the caller guarantees `ptr` is readable for four lanes; NEON is
    // enabled by default on standard aarch64 Rust targets.
    // The full path is required: the bare name resolves to this wrapper and
    // would recurse forever.
    unsafe { std::arch::aarch64::vld1q_f32(ptr) }
}
/// Stores the four lanes of `value` to consecutive `f32` slots starting at `ptr`.
///
/// Thin wrapper over `std::arch::aarch64::vst1q_f32`.
///
/// # Safety
///
/// `ptr` must be valid for writes of four consecutive `f32` values.
unsafe fn vst1q(ptr: *mut f32, value: float32x4_t) {
    // SAFETY: the caller guarantees `ptr` is writable for four lanes; NEON is
    // enabled by default on standard aarch64 Rust targets. Calling the
    // intrinsic by its full path avoids recursing into this wrapper.
    unsafe { std::arch::aarch64::vst1q_f32(ptr, value) }
}
fn vaddq_f32(a: __m128, b: __m128) -> __m128 {
a.add(b)
}
fn vsubq_f32(a: __m128, b: __m128) -> __m128 {
a.sub(b)
}
fn vsubq_f32(a: __m128, b: __m128) -> __m128 {
a.sub(b)
}
fn vaddq_f32(a: __m128, b: __m128) -> __m128 {
a.add(b)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment