Skip to content

Instantly share code, notes, and snippets.

@Steboss
Created January 16, 2026 10:35
Show Gist options
  • Select an option

  • Save Steboss/4dc163801f2b9ab9faf82526097bc862 to your computer and use it in GitHub Desktop.

Select an option

Save Steboss/4dc163801f2b9ab9faf82526097bc862 to your computer and use it in GitHub Desktop.
First iteration from the model: it generates a Strassen matrix-multiplication algorithm.
use rayon::prelude::*;
use std::env;
/// Side length of the square blocks handled by the kernels below
/// (`N = 2` means every block is a row-major 2x2 array of 4 elements).
const N: usize = 2;

/// Copies `slice` into an owned, fixed-size `[f64; N * N]` block.
///
/// Not called by `main`; useful while debugging to assert that a slice has
/// exactly one block's worth of elements.
///
/// # Panics
/// Panics with "Slice has incorrect length" unless `slice.len() == N * N`.
fn to_array(slice: &[f64]) -> [f64; N * N] {
    <[f64; N * N]>::try_from(slice).expect("Slice has incorrect length")
}
/// Element-wise sum `a + b` of two `N x N` row-major blocks.
///
/// Takes plain slices so callers can pass sub-slices of larger buffers; only
/// the first `N * N` elements of each slice are read.
///
/// # Panics
/// Panics if either slice holds fewer than `N * N` elements.
fn add_block(a: &[f64], b: &[f64]) -> [f64; N * N] {
    let (lhs, rhs) = (&a[..N * N], &b[..N * N]);
    let mut total = [0.0; N * N];
    for ((slot, &x), &y) in total.iter_mut().zip(lhs).zip(rhs) {
        *slot = x + y;
    }
    total
}
/// Element-wise difference `a - b` of two `N x N` row-major blocks.
///
/// Only the first `N * N` elements of each slice are read. Not called by `main`.
///
/// # Panics
/// Panics if either slice holds fewer than `N * N` elements.
fn sub_block(a: &[f64], b: &[f64]) -> [f64; N * N] {
    // Flat index `idx` walks the block row-major (row = idx / N, col = idx % N).
    std::array::from_fn(|idx| a[idx] - b[idx])
}
/// Naive triple-loop product `a * b` of two `N x N` row-major blocks.
///
/// Not called by `main`; presumably kept as a straightforward reference for
/// the Strassen kernel.
///
/// # Panics
/// Panics if either slice holds fewer than `N * N` elements.
fn mul_block(a: &[f64], b: &[f64]) -> [f64; N * N] {
    let mut product = [0.0; N * N];
    for (row, out_row) in product.chunks_exact_mut(N).enumerate() {
        for (col, cell) in out_row.iter_mut().enumerate() {
            // Dot product of row `row` of `a` with column `col` of `b`,
            // accumulated in increasing `k` order.
            *cell = (0..N).fold(0.0, |acc, k| acc + a[row * N + k] * b[k * N + col]);
        }
    }
    product
}
/// Multiplies two 2x2 row-major blocks with Strassen's seven-product formulas.
///
/// Both inputs and the result use the layout `[x00, x01, x10, x11]`.
/// Strassen's scheme needs 7 multiplications instead of the naive 8. With
/// `N == 2`, every "sub-matrix" of the textbook recursion is a single scalar,
/// so this function is the base case of that recursion.
///
/// Because the arithmetic is arranged differently from the naive triple loop
/// (see [`mul_block`]), floating-point results may differ in the last bits.
///
/// # Panics
/// Panics if either slice holds fewer than 4 elements.
fn strassen_kernel(a: &[f64], b: &[f64]) -> [f64; N * N] {
    // The indices below are written out for 2x2 blocks only. Refuse to compile
    // if the block size is changed without rewriting this kernel.
    const _: () = assert!(N == 2, "strassen_kernel is hard-coded for 2x2 blocks");

    let (a00, a01, a10, a11) = (a[0], a[1], a[2], a[3]);
    let (b00, b01, b10, b11) = (b[0], b[1], b[2], b[3]);

    // The seven Strassen products.
    let m1 = (a00 + a11) * (b00 + b11);
    let m2 = (a10 + a11) * b00;
    let m3 = a00 * (b01 - b11);
    let m4 = a11 * (b10 - b00);
    let m5 = (a00 + a01) * b11;
    let m6 = (a10 - a00) * (b00 + b01);
    let m7 = (a01 - a11) * (b10 + b11);

    // Recombine them into C = A * B.
    let c00 = m1 + m4 - m5 + m7;
    let c01 = m3 + m5;
    let c10 = m2 + m4;
    let c11 = m1 - m2 + m3 + m6;
    [c00, c01, c10, c11]
}
/// Times a blocked `M x K` by `K x N` matrix product whose `N x N` block
/// products are computed by [`strassen_kernel`].
///
/// Usage: `<program> <M> <K> <N>`. Each dimension must be a multiple of the
/// block-size constant `N`. With any other argument count the defaults
/// `2 2 2` are used. An argument that fails to parse also falls back to 2,
/// and a warning is printed on stderr.
///
/// `A` and `B` are filled row-major with `0, 1, 2, ...`. Block rows of `C` are
/// computed in parallel with rayon. The program prints the elapsed time and a
/// checksum (the sum of all entries of `C`).
fn main() {
    /// Copies the `N x N` block at block coordinates (`block_row`, `block_col`)
    /// out of a row-major matrix that has `cols` columns.
    fn load_block(mat: &[f64], cols: usize, block_row: usize, block_col: usize) -> [f64; N * N] {
        let mut blk = [0.0; N * N];
        for r in 0..N {
            let src = (block_row * N + r) * cols + block_col * N;
            blk[r * N..(r + 1) * N].copy_from_slice(&mat[src..src + N]);
        }
        blk
    }

    let args: Vec<String> = env::args().collect();
    let program = args.first().map(String::as_str).unwrap_or("strassen");

    // Default values apply if args aren't provided. A malformed value also
    // falls back to the default, but with a warning instead of silently.
    let parse_dim = |raw: &str, label: &str| -> usize {
        raw.parse::<usize>().unwrap_or_else(|_| {
            eprintln!("Invalid value {:?} for {}; using default 2", raw, label);
            2
        })
    };
    let (m, k_dim, n) = if args.len() == 4 {
        (
            parse_dim(&args[1], "M"),
            parse_dim(&args[2], "K"),
            parse_dim(&args[3], "N"),
        )
    } else {
        println!("Usage: {} <M> <K> <N>. Using defaults: 2 2 2", program);
        (2, 2, 2)
    };
    if m % N != 0 || k_dim % N != 0 || n % N != 0 {
        eprintln!("M, K, and N must be multiples of {}", N);
        return;
    }

    let a: Vec<f64> = (0..m * k_dim).map(|i| i as f64).collect();
    let b: Vec<f64> = (0..k_dim * n).map(|i| i as f64).collect();
    let mut c = vec![0.0; m * n];

    let start = std::time::Instant::now();
    // One task per block row of C. Each chunk holds N full rows (N * n
    // elements), so the tasks write disjoint memory and need no locking.
    // `.max(1)` keeps the chunk size non-zero when n == 0; C is empty then.
    c.par_chunks_mut((N * n).max(1))
        .enumerate()
        .for_each(|(bi, c_rows)| {
            for bj in 0..n / N {
                // C[bi][bj] = sum over bk of A[bi][bk] * B[bk][bj]
                let mut acc = [0.0; N * N];
                for bk in 0..k_dim / N {
                    let a_blk = load_block(&a, k_dim, bi, bk);
                    let b_blk = load_block(&b, n, bk, bj);
                    acc = add_block(&acc, &strassen_kernel(&a_blk, &b_blk));
                }
                for r in 0..N {
                    let dst = r * n + bj * N;
                    c_rows[dst..dst + N].copy_from_slice(&acc[r * N..(r + 1) * N]);
                }
            }
        });
    let duration = start.elapsed();

    println!("Time: {:.6}s", duration.as_secs_f64());
    // Printing a value derived from C keeps the timed result observable, so it
    // cannot be discarded as unused. It also serves as a quick sanity check.
    println!("Checksum: {}", c.iter().sum::<f64>());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment