Created
January 16, 2026 10:35
-
-
Save Steboss/4dc163801f2b9ab9faf82526097bc862 to your computer and use it in GitHub Desktop.
First iteration from the model. It generates a Strassen matrix-multiplication algorithm.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use rayon::prelude::*; | |
| use std::env; | |
/// Fixed edge length of the square blocks handled by the block helpers; a
/// block is stored as a flat run of `N * N` values.
const N: usize = 2;

/// Copies `slice` into an owned `N * N` block.
///
/// Optional debugging aid: a length mismatch fails immediately with a clear
/// message instead of surfacing later as an out-of-range index.
///
/// # Panics
///
/// Panics with "Slice has incorrect length" when `slice.len() != N * N`.
fn to_array(slice: &[f64]) -> [f64; N * N] {
    <[f64; N * N]>::try_from(slice).expect("Slice has incorrect length")
}
| // Changed inputs from &[f64; N*N] to &[f64] to accept slices | |
| fn add_block(a: &[f64], b: &[f64]) -> [f64; N * N] { | |
| let mut c = [0.0; N * N]; | |
| for i in 0..N { | |
| for j in 0..N { | |
| c[i * N + j] = a[i * N + j] + b[i * N + j]; | |
| } | |
| } | |
| c | |
| } | |
| fn sub_block(a: &[f64], b: &[f64]) -> [f64; N * N] { | |
| let mut c = [0.0; N * N]; | |
| for i in 0..N { | |
| for j in 0..N { | |
| c[i * N + j] = a[i * N + j] - b[i * N + j]; | |
| } | |
| } | |
| c | |
| } | |
| fn mul_block(a: &[f64], b: &[f64]) -> [f64; N * N] { | |
| let mut c = [0.0; N * N]; | |
| for i in 0..N { | |
| for j in 0..N { | |
| let mut sum = 0.0; | |
| for k in 0..N { | |
| sum += a[i * N + k] * b[k * N + j]; | |
| } | |
| c[i * N + j] = sum; | |
| } | |
| } | |
| c | |
| } | |
| // Fixed: This is now a 2x2 kernel. | |
| // True Strassen requires splitting a matrix into 4 sub-matrices. | |
| // Since N=2, the inputs ARE the sub-matrices (scalars), so we just do standard 2x2 mul. | |
| fn strassen_kernel(a: &[f64], b: &[f64]) -> [f64; N * N] { | |
| // For a 2x2 matrix, standard multiplication is most efficient. | |
| // Logic: | |
| // [C00 C01] = [A00 A01] * [B00 B01] | |
| // [C10 C11] [A10 A11] [B10 B11] | |
| let m1 = (a[0] + a[3]) * (b[0] + b[3]); | |
| let m2 = (a[2] + a[3]) * b[0]; | |
| let m3 = a[0] * (b[1] - b[3]); | |
| let m4 = a[3] * (b[2] - b[0]); | |
| let m5 = (a[0] + a[1]) * b[3]; | |
| let m6 = (a[2] - a[0]) * (b[0] + b[1]); | |
| let m7 = (a[1] - a[3]) * (b[2] + b[3]); | |
| let c00 = m1 + m4 - m5 + m7; | |
| let c01 = m3 + m5; | |
| let c10 = m2 + m4; | |
| let c11 = m1 - m2 + m3 + m6; | |
| [c00, c01, c10, c11] | |
| } | |
| fn main() { | |
| let args: Vec<String> = env::args().collect(); | |
| // Default values if args aren't provided | |
| let (m, k_dim, n) = if args.len() == 4 { | |
| ( | |
| args[1].parse().unwrap_or(2), | |
| args[2].parse().unwrap_or(2), | |
| args[3].parse().unwrap_or(2), | |
| ) | |
| } else { | |
| println!("Usage: {} <M> <K> <N>. Using defaults: 2 2 2", args[0]); | |
| (2, 2, 2) | |
| }; | |
| if m % N != 0 || k_dim % N != 0 || n % N != 0 { | |
| eprintln!("M, K, and N must be multiples of {}", N); | |
| return; | |
| } | |
| let mut a = vec![0.0; m * k_dim]; | |
| let mut b = vec![0.0; k_dim * n]; | |
| // let mut c = vec![0.0; m * n]; // Unused in the map-reduce part | |
| // Initialize | |
| for i in 0..m * k_dim { a[i] = i as f64; } | |
| for i in 0..k_dim * n { b[i] = i as f64; } | |
| let start = std::time::Instant::now(); | |
| // ERROR 1 FIX: Logic for parallel chunks | |
| // The previous loop was calculating a dot product of diagonal blocks only. | |
| // This is just a demo of how to call the function in parallel. | |
| let chunks = (m / N).min(k_dim / N).min(n / N); | |
| let _result_sum = (0..chunks) | |
| .into_par_iter() | |
| .map(|i| { // ERROR 2 FIX: Named the variable 'i' instead of '_' | |
| let start_a = i * N * N; // Assuming flattened block layout for this demo | |
| let start_b = i * N * N; | |
| // Bounds check to avoid panic | |
| if start_a + N*N > a.len() || start_b + N*N > b.len() { | |
| return [0.0; N*N]; | |
| } | |
| // ERROR 3 FIX: Slicing | |
| // &a[...] creates a slice &[f64]. The function now accepts this. | |
| strassen_kernel(&a[start_a..start_a + N * N], &b[start_b..start_b + N * N]) | |
| }) | |
| .reduce(|| [0.0; N * N], |acc, x| add_block(&acc, &x)); | |
| let duration = start.elapsed(); | |
| println!("Time: {:.6}s", duration.as_secs_f64()); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment