Last active
May 7, 2023 21:49
-
-
Save Samu31Nd/dd7bdbb11a38c2d84d6b3368e479df0f to your computer and use it in GitHub Desktop.
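Square Matrix Multiply with the recursive divide-and-conquer method (eight sub-products per level; Strassen's seven-product method is not implemented yet)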
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;
| public class Main { | |
| public static void main(String[] args) { | |
| Scanner scanner = new Scanner(System.in); | |
| System.out.print("Insert the length of the matrix: "); | |
| int n = scanner.nextInt(); | |
| // int [][] m1 = StrassenUtils.randomMatrix(n); | |
| // int [][] m2 = StrassenUtils.randomMatrix(n); | |
| int [][] m1 = {{2,0},{1,3}}; | |
| int [][] m2 = {{-1,-1},{5,6}}; | |
| System.out.println("\nMatrix 1 and 2:"); | |
| StrassenUtils.showMatrix(m1,m2); | |
| int [][] C = StrassenUtils.squareMatrixMultiplyRecursive(m1,m2); | |
| System.out.println("\nMatrix C: "); | |
| StrassenUtils.showMatrix(C); | |
| } | |
| } | |
| class StrassenUtils { | |
| static public int[][] randomMatrix(int n){ | |
| int [][] x = new int[n][n]; | |
| for (int p = 0; p < n; p++) | |
| for (int j = 0; j < n; j++) | |
| x[p][j] = (int) (new Random().nextFloat(51)); //from 0 to (51 - 1) | |
| return x; | |
| } | |
| static public void showMatrix(int[][]m){ | |
| for (int []a : m){ | |
| for (int b : a) System.out.print(b + "\t"); | |
| System.out.println(" "); | |
| } | |
| System.out.println(" "); | |
| } | |
| static public void showMatrix(int[][]m, int[][]n){ | |
| int length = m.length; | |
| for (int i = 0; i < length; i++){ | |
| for (int a : m[i]) System.out.print(a + "\t"); | |
| System.out.print("\t\t"); | |
| for (int a : n[i]) System.out.print(a + "\t"); | |
| System.out.println(" "); | |
| } | |
| } | |
| static public int[][] squareMatrixMultiply(int[][]A, int[][]B){ | |
| int n = A.length; | |
| int[][] C = new int[n][n]; | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++){ | |
| C[i][j] = 0; | |
| for (int k = 0; k < n; k++) C[i][j] = C[i][j] + A[i][k] * B[k][j]; | |
| } | |
| return C; | |
| } | |
| static public int[][] squareMatrixMultiplyRecursive(int[][]A, int[][]B){ | |
| int n = A.length; | |
| int[][] C = new int[n][n]; | |
| if(n==1) C[0][0] = A[0][0] * B[0][0]; | |
| else{ | |
| int [][] A11 = partition(A,1); | |
| int [][] A12 = partition(A,2); | |
| int [][] A21 = partition(A,3); | |
| int [][] A22 = partition(A,4); | |
| int [][] B11 = partition(B,1); | |
| int [][] B12 = partition(B,2); | |
| int [][] B21 = partition(B,3); | |
| int [][] B22 = partition(B,4); | |
| int [][] C11 = addition(squareMatrixMultiplyRecursive(A11,B11), squareMatrixMultiplyRecursive(A12,B21)); | |
| int [][] C12 = addition(squareMatrixMultiplyRecursive(A11,B12), squareMatrixMultiplyRecursive(A12,B22)); | |
| int [][] C21 = addition(squareMatrixMultiplyRecursive(A21,B11), squareMatrixMultiplyRecursive(A22,B21)); | |
| int [][] C22 = addition(squareMatrixMultiplyRecursive(A21,B12), squareMatrixMultiplyRecursive(A22,B22)); | |
| C = merge(C11,C12,C21,C22); | |
| } | |
| return C; | |
| } | |
| static private int[][] merge(int[][]C1, int[][]C2, int[][]C3, int[][]C4){ | |
| int n = C1.length*2; | |
| int [][] C = new int[n][n]; | |
| int n2 = n/2; | |
| for(int i = 0; i < n2; i++) | |
| for(int j = 0; j < n2; j++) C[i][j] = C1[i][j]; | |
| for(int i = 0; i < n2; i++) | |
| for(int j = 0; j < n2; j++) C[i+n2][j] = C2[i][j]; | |
| for(int i = 0; i < n2; i++) | |
| for(int j = 0; j < n2; j++) C[i][j+n2] = C3[i][j]; | |
| for(int i = 0; i < n2; i++) | |
| for(int j = 0; j < n2; j++) C[i+n2][j+n2] = C4[i][j]; | |
| return C; | |
| } | |
| static private int[][] partition(int[][]M, int mode){ | |
| int n = M.length/2; | |
| int[][] newM = new int[n][n]; | |
| switch (mode) { | |
| case 1 -> { | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++) | |
| newM[i][j] = M[i][j]; | |
| } | |
| case 2 -> { | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++) | |
| newM[i][j] = M[i + n][j]; | |
| } | |
| case 3 -> { | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++) | |
| newM[i][j] = M[i][j + n]; | |
| } | |
| case 4 -> { | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++) | |
| newM[i][j] = M[i + n][j + n]; | |
| } | |
| } | |
| return newM; | |
| } | |
| static private int[][] addition(int[][]A, int[][]B){ | |
| int n = A.length; | |
| int[][] C = new int[n][n]; | |
| for (int i = 0; i < n; i++) | |
| for (int j = 0; j < n; j++) | |
| C[i][j] = A[i][j] + B [i][j]; | |
| return C; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment