-
-
Save banach-space/91a868f992a5747dd58b17a85f584ec8 to your computer and use it in GitHub Desktop.
| // The following matmul maps very nicely onto SME with SVL=128 bits. For f32, there are four 4x4 tiles | |
| // that can be assembled into an 8x8 matrix. | |
| // %mat_A_tr - transpose of A | |
| // Reference computation: a single named Linalg op that reads %mat_A_tr and %mat_B and writes its result to %mat_C. | |
| // NOTE(review): both inputs are 6x8 while the output is 8x8, so the first operand is presumably K x M (A already | |
| // transposed) rather than M x K - confirm that plain `linalg.matmul` is the intended op for this layout. | |
| func.func @matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) { | |
| linalg.matmul ins(%mat_A_tr, %mat_B: memref<6x8xf32>, memref<6x8xf32>) | |
| outs(%mat_C: memref<8x8xf32>) | |
| return | |
| } |
| //------------------------------------------------- | |
| // DESIRED LOWERING FOR SME - HAND WRITTEN | |
| // linalg.matmul as vector.contract | |
| // | |
| // Overview: | |
| // * element type - f32 | |
| // * number of vector.contract accumulators - 4 | |
| // * C is partitioned into 4 tiles | |
| // * A and B are partitioned into 2 halves | |
| // * A is transposed before entering the kernel | |
| //------------------------------------------------- | |
| // %mat_A_tr - transpose of A | |
| #map = affine_map<(d0, d1, d2) -> (d0, d2)> | |
| #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> | |
| #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> | |
| module attributes { transform.with_named_sequence } { | |
| // Computes the 8x8 result as four independent 4x4 `vector.contract` ops, one per tile of %mat_C. | |
| // %mat_A_tr holds A transposed (6x8), so each 6x4 slice of it is transposed back to 4x6 before the contraction. | |
| // Each contraction accumulates onto the tile of %mat_C that was read on entry, and the result is written back | |
| // to the same location. | |
| func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) { | |
| %c0 = arith.constant 0 : index | |
| %c4 = arith.constant 4 : index | |
| // Padding value for the transfer reads (never used - all reads are in bounds) | |
| %cst = arith.constant 0.000000e+00 : f32 | |
| // Tile 0 - rows 0..3, columns 0..3 of C | |
| %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 1 - rows 0..3, columns 4..7 of C | |
| %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 2 - rows 4..7, columns 0..3 of C | |
| %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 3 - rows 4..7, columns 4..7 of C | |
| %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile upper - A (columns 0..3 of A^T) | |
| %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile lower - A (columns 4..7 of A^T) | |
| %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile left - B (columns 0..3 of B) | |
| %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile right - B (columns 4..7 of B) | |
| %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Undo the transposition of A so that the LHS of each contraction is 4x6 | |
| %tile_A_hi_tr = vector.transpose %tile_A_hi, [1, 0] : vector<6x4xf32> to vector<4x6xf32> | |
| %tile_A_lo_tr = vector.transpose %tile_A_lo, [1, 0] : vector<6x4xf32> to vector<4x6xf32> | |
| // Compute upper half of C | |
| %tile0_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_left, %tile0_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> | |
| %tile1_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_hi_tr, %tile_B_right, %tile1_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> | |
| // Compute lower half of C | |
| %tile2_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_left, %tile2_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> | |
| %tile3_matmul = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %tile_A_lo_tr, %tile_B_right, %tile3_C : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> | |
| // Write the four accumulated tiles back to C | |
| vector.transfer_write %tile0_matmul, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| vector.transfer_write %tile1_matmul, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| vector.transfer_write %tile2_matmul, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| vector.transfer_write %tile3_matmul, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| return | |
| } | |
| } |
| //------------------------------------------------- | |
| // DESIRED LOWERING FOR SME - HAND WRITTEN | |
| // linalg.matmul as vector.outerproduct | |
| // | |
| // Overview: | |
| // * element type - f32 | |
| // * number of OP accumulators - 4 | |
| // * C is partitioned into 4 tiles | |
| // * A and B are partitioned into 2 halves | |
| // * A is transposed before entering the kernel | |
| //------------------------------------------------- | |
| // LEGEND: | |
| // Matrix C. Each tile is [SVL_s x SVL_s] (4x4xf32 in this example) | |
| // ----------------------------- | |
| // | | | | |
| // | | | | |
| // | tile0_C | tile1_C | | |
| // | | | | |
| // | | | | |
| // ----------------------------- | |
| // | | | | |
| // | | | | |
| // | tile2_C | tile3_C | | |
| // | | | | |
| // | | | | |
| // ----------------------------- | |
| // Matrix B. Each tile is [K x SVL_s] (6x4xf32 in this example) | |
| // ---------------------------------- | |
| // | | | | |
| // | | | | |
| // | tile_B_left | tile_B_right | | |
| // | | | | |
| // | | | | |
| // ---------------------------------| | |
| // Matrix A transpose. Each tile is [K x SVL_s] (6x4xf32 in this example) | |
| // ---------------------------------- | |
| // | | | | |
| // | | | | |
| // | tile_A_hi | tile_A_lo | |
| // | | | | |
| // | | | | |
| // ---------------------------------| | |
| // Columns of A are rows of A^T. | |
| // Row 3 of matrix B. Each half is SVL_s long (4xf32 in this example) | |
| // ----------------------------- | |
| // | B_row_3_0 | B_row_3_1 | | |
| // ----------------------------- | |
| // half 0 half 1 | |
| // %mat_A_tr - transpose of A | |
| module { | |
| // Computes the 8x8 result as four 4x4 accumulators (one per tile of %mat_C). Each accumulator is a chain of | |
| // K = 6 `vector.outerproduct` ops: step k combines row k of a 6x4 slice of %mat_A_tr (i.e. a column of A) with | |
| // row k of a 6x4 slice of %mat_B, starting from the tile of %mat_C that was read on entry. | |
| func.func @outerproduct_matmul(%mat_A_tr: memref<6x8xf32>, %mat_B: memref<6x8xf32>, %mat_C: memref<8x8xf32>) { | |
| %c0 = arith.constant 0 : index | |
| %c4 = arith.constant 4 : index | |
| // Padding value for the transfer reads (never used - all reads are in bounds) | |
| %cst = arith.constant 0.000000e+00 : f32 | |
| // Tile 0 | |
| %tile0_C = vector.transfer_read %mat_C[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 1 | |
| %tile1_C = vector.transfer_read %mat_C[%c0, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 2 | |
| %tile2_C = vector.transfer_read %mat_C[%c4, %c0], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile 3 | |
| %tile3_C = vector.transfer_read %mat_C[%c4, %c4], %cst {in_bounds = [true, true]} : memref<8x8xf32>, vector<4x4xf32> | |
| // Tile upper - A | |
| %tile_A_hi = vector.transfer_read %mat_A_tr[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile lower - A | |
| %tile_A_lo = vector.transfer_read %mat_A_tr[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile left - B | |
| %tile_B_left = vector.transfer_read %mat_B[%c0, %c0], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // Tile right - B | |
| %tile_B_right = vector.transfer_read %mat_B[%c0, %c4], %cst {in_bounds = [true, true]} : memref<6x8xf32>, vector<6x4xf32> | |
| // =================================================================================================================== | |
| // 1 COMPUTE UPPER HALF OF C | |
| // =================================================================================================================== | |
| // 1.1 Input tile "upper" from A and "left" from B - accumulate to ZA0.s (tile 0 of C) | |
| %A_col_0_0 = vector.extract %tile_A_hi[0] : vector<6x4xf32> | |
| %A_col_1_0 = vector.extract %tile_A_hi[1] : vector<6x4xf32> | |
| %A_col_2_0 = vector.extract %tile_A_hi[2] : vector<6x4xf32> | |
| %A_col_3_0 = vector.extract %tile_A_hi[3] : vector<6x4xf32> | |
| %A_col_4_0 = vector.extract %tile_A_hi[4] : vector<6x4xf32> | |
| %A_col_5_0 = vector.extract %tile_A_hi[5] : vector<6x4xf32> | |
| // One row of B per reduction step k = 0..5 (the row index must match the A column index it is paired with) | |
| %B_row_0_0 = vector.extract %tile_B_left[0] : vector<6x4xf32> | |
| %B_row_1_0 = vector.extract %tile_B_left[1] : vector<6x4xf32> | |
| %B_row_2_0 = vector.extract %tile_B_left[2] : vector<6x4xf32> | |
| %B_row_3_0 = vector.extract %tile_B_left[3] : vector<6x4xf32> | |
| %B_row_4_0 = vector.extract %tile_B_left[4] : vector<6x4xf32> | |
| %B_row_5_0 = vector.extract %tile_B_left[5] : vector<6x4xf32> | |
| %op_0 = vector.outerproduct %A_col_0_0, %B_row_0_0, %tile0_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_1 = vector.outerproduct %A_col_1_0, %B_row_1_0, %op_0 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_2 = vector.outerproduct %A_col_2_0, %B_row_2_0, %op_1 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_3 = vector.outerproduct %A_col_3_0, %B_row_3_0, %op_2 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_4 = vector.outerproduct %A_col_4_0, %B_row_4_0, %op_3 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_5 = vector.outerproduct %A_col_5_0, %B_row_5_0, %op_4 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| vector.transfer_write %op_5, %mat_C[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| // 1.2 Input tile "upper" from A and "right" from B - accumulate to ZA1.s (tile 1 of C) | |
| %B_row_0_1 = vector.extract %tile_B_right[0] : vector<6x4xf32> | |
| %B_row_1_1 = vector.extract %tile_B_right[1] : vector<6x4xf32> | |
| %B_row_2_1 = vector.extract %tile_B_right[2] : vector<6x4xf32> | |
| %B_row_3_1 = vector.extract %tile_B_right[3] : vector<6x4xf32> | |
| %B_row_4_1 = vector.extract %tile_B_right[4] : vector<6x4xf32> | |
| %B_row_5_1 = vector.extract %tile_B_right[5] : vector<6x4xf32> | |
| %op_6 = vector.outerproduct %A_col_0_0, %B_row_0_1, %tile1_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_7 = vector.outerproduct %A_col_1_0, %B_row_1_1, %op_6 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_8 = vector.outerproduct %A_col_2_0, %B_row_2_1, %op_7 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_9 = vector.outerproduct %A_col_3_0, %B_row_3_1, %op_8 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_10 = vector.outerproduct %A_col_4_0, %B_row_4_1, %op_9 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_11 = vector.outerproduct %A_col_5_0, %B_row_5_1, %op_10 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| vector.transfer_write %op_11, %mat_C[%c0, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| // =================================================================================================================== | |
| // 2 COMPUTE LOWER HALF OF C | |
| // =================================================================================================================== | |
| // 2.1 Input tile "lower" from A and "left" from B - accumulate to ZA2.s (tile 2 of C) | |
| %A_col_0_1 = vector.extract %tile_A_lo[0] : vector<6x4xf32> | |
| %A_col_1_1 = vector.extract %tile_A_lo[1] : vector<6x4xf32> | |
| %A_col_2_1 = vector.extract %tile_A_lo[2] : vector<6x4xf32> | |
| %A_col_3_1 = vector.extract %tile_A_lo[3] : vector<6x4xf32> | |
| %A_col_4_1 = vector.extract %tile_A_lo[4] : vector<6x4xf32> | |
| %A_col_5_1 = vector.extract %tile_A_lo[5] : vector<6x4xf32> | |
| %op_12 = vector.outerproduct %A_col_0_1, %B_row_0_0, %tile2_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_13 = vector.outerproduct %A_col_1_1, %B_row_1_0, %op_12 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_14 = vector.outerproduct %A_col_2_1, %B_row_2_0, %op_13 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_15 = vector.outerproduct %A_col_3_1, %B_row_3_0, %op_14 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_16 = vector.outerproduct %A_col_4_1, %B_row_4_0, %op_15 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_17 = vector.outerproduct %A_col_5_1, %B_row_5_0, %op_16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| vector.transfer_write %op_17, %mat_C[%c4, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| // 2.2 Input tile "lower" from A and "right" from B - accumulate to ZA3.s (tile 3 of C) | |
| %op_18 = vector.outerproduct %A_col_0_1, %B_row_0_1, %tile3_C {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_19 = vector.outerproduct %A_col_1_1, %B_row_1_1, %op_18 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_20 = vector.outerproduct %A_col_2_1, %B_row_2_1, %op_19 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_21 = vector.outerproduct %A_col_3_1, %B_row_3_1, %op_20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_22 = vector.outerproduct %A_col_4_1, %B_row_4_1, %op_21 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| %op_23 = vector.outerproduct %A_col_5_1, %B_row_5_1, %op_22 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32> | |
| vector.transfer_write %op_23, %mat_C[%c4, %c4] {in_bounds = [true, true]} : vector<4x4xf32>, memref<8x8xf32> | |
| return | |
| } | |
| } |
| // Lower `vector.contract` to `vector.outerproduct` | |
| transform.sequence failures(propagate) { | |
| ^bb1(%root: !transform.any_op): | |
| // Find every `vector.contract` nested under the payload root. | |
| %contracts = transform.structured.match ops{["vector.contract"]} in %root : (!transform.any_op) -> !transform.any_op | |
| // Patterns are applied to the closest isolated parent of each match; drop repeated parents first. | |
| %parents = transform.get_closest_isolated_parent %contracts : (!transform.any_op) -> !transform.any_op | |
| %unique_parents = transform.merge_handles %parents { deduplicate } : !transform.any_op | |
| // Rewrite the contractions using the "outerproduct" lowering strategy. | |
| transform.apply_patterns to %unique_parents { | |
| transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" | |
| } : !transform.any_op | |
| } |
| // Lower `linalg.matmul` to `vector.outerproduct` - use 4 accumulators. | |
| transform.sequence failures(propagate) { | |
| ^bb0(%root: !transform.any_op): | |
| // Find the `linalg.matmul` ops nested under the payload root. | |
| %matmul = transform.structured.match ops{["linalg.matmul"]} in %root : (!transform.any_op) -> !transform.any_op | |
| // Tile with sizes [4, 4]; this produces the tiled op plus two `scf.for` loops. | |
| %tiled_matmul, %tile_loops:2 = transform.structured.tile %matmul [4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">) | |
| // Vectorize the closest isolated parent of the tiled op. | |
| %parent = transform.get_closest_isolated_parent %tiled_matmul : (!transform.any_op) -> !transform.any_op | |
| %vectorized = transform.structured.vectorize %parent : (!transform.any_op) -> !transform.any_op | |
| // Unroll every `scf.for` found in the vectorized parent by a factor of 2. | |
| %for_loops = transform.structured.match ops{["scf.for"]} in %vectorized : (!transform.any_op) -> !transform.op<"scf.for"> | |
| transform.loop.unroll %for_loops { factor = 2 } : !transform.op<"scf.for"> | |
| } |
Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Couldn't find your name on discord, could you try requesting @danikhan632 on discord. Thank you
Hey there, wanted to ask about some questions regarding SME dialect lowering to matmul that I was hoping you could help with
Sure, feel free to DM me on Discord or Discourse!
Thank you so much, do you have a discord username that I can add?
Search for my surname or banach-space @ https://discourse.llvm.org/ ;-)
Couldn't find your name on discord, could you try requesting @danikhan632 on discord. Thank you
Sorry, I wasn't able to find you. Let's just use good old-fashioned e-mail: name.surname@arm.com :)
Sure, feel free to DM me on Discord or Discourse!