Skip to content

Instantly share code, notes, and snippets.

@banach-space
Last active August 14, 2025 09:11
Show Gist options
  • Select an option

  • Save banach-space/c66f47626fd42467ed741b0d6cc2352e to your computer and use it in GitHub Desktop.

Select an option

Save banach-space/c66f47626fd42467ed741b0d6cc2352e to your computer and use it in GitHub Desktop.
New tiling implementation for linalg.unpack
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE: --lower-vector-mask |\
// DEFINE: mlir-opt \
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-runner %t -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_runner_utils
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
/// End-to-end test for linalg.unpack where one of the inner tile sizes is
/// dynamic.
//=============================================================================
// WIP NOTES
//=============================================================================
// This is a hand-modified version of:
// * mlir/test/Integration/Dialect/Linalg/CPU/unpack-dynamic-inner-tile.mlir
//
// This implementation works correctly.
//=============================================================================
// Wrapper for `linalg.unpack`
//=============================================================================
// Wrapper for `linalg.unpack`: unpacks the packed 9x?x8x8 source into a
// 72x67 destination and prints the result via the runtime memref printer.
// (The dead `%c8 = arith.constant 8 : index` from the original was removed —
// it was never used.)
func.func @unpack(%src: tensor<9x?x8x8xi32>) {
  %dest = tensor.empty() : tensor<72x67xi32>
  %unpack = linalg.unpack %src
    outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 8] into %dest : tensor<9x?x8x8xi32> -> tensor<72x67xi32>
  // Erase the static shape so the value can be passed to the unranked printer.
  %unpack_cast = tensor.cast %unpack : tensor<72x67xi32> to tensor<*xi32>
  call @printMemrefI32(%unpack_cast) : (tensor<*xi32>) -> ()
  return
}
//=============================================================================
// MAIN entry point
//=============================================================================
// MAIN entry point: builds a constant-filled 9x9x8x8 source, then casts away
// the second outer dimension so that @unpack exercises the dynamic
// (tensor<9x?x8x8xi32>) path.
// (The dead `%c64` / `%A_alloc = tensor.empty(%c64)` pair from the original
// was removed — the allocated tensor was never used.)
func.func @main() {
  %A = arith.constant dense<123> : tensor<9x9x8x8xi32>
  %A_sc = tensor.cast %A : tensor<9x9x8x8xi32> to tensor<9x?x8x8xi32>
  func.call @unpack(%A_sc) : (tensor<9x?x8x8xi32>) -> ()
  return
}
//=============================================================================
// TD sequence
//=============================================================================
// Transform-dialect schedule driving the lowering of `linalg.unpack`:
// tile -> decompose -> vectorize -> bufferize -> canonicalize.
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consume}) {
%pack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
// 1. Tile so that we can decompose linalg.unpack into simpler
// Ops (see step 2). The first tile size (8) is passed as a transform
// parameter; the second dimension is tiled with a static size of 1.
%c8 = transform.param.constant 8 : i64 -> !transform.param<i64>
%tiled_pack_op_p, %loops:2 = transform.structured.tile_using_for %pack tile_sizes [%c8, 1]
: (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// 2. Decompose the tiled unpack Op into tensor.extract_slice + tensor.insert_slice:
%func_op = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.linalg.decompose_pack_unpack
transform.apply_patterns.canonicalization
} : !transform.op<"func.func">
// 3. Vectorize tensor.insert_slice - NOTE VECTOR SIZES
// Vector sizes [8, 1] match the tile sizes chosen in step 1.
%slice = transform.structured.match ops{["tensor.insert_slice"]} in %func_op : (!transform.op<"func.func">) -> !transform.any_op
transform.structured.vectorize %slice vector_sizes [8, 1] : !transform.any_op
// 4. Bufferize before lowering to LLVM
%bufferize = transform.bufferization.one_shot_bufferize %module
{bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
// 5. Canonicalize the bufferized functions.
%func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op_bufferized {
transform.apply_patterns.canonicalization
} : !transform.op<"func.func">
transform.yield
}
}
// External declaration; resolved at run time via the shared runner-utils
// libraries listed on the RUN line (-shared-libs=...).
func.func private @printMemrefI32(%ptr : tensor<*xi32>)
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE: --lower-vector-mask |\
// DEFINE: mlir-opt \
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-runner %t -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_runner_utils
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
//=============================================================================
// WIP NOTES
//=============================================================================
// This is a hand-modified version of:
// * mlir/test/Integration/Dialect/Linalg/CPU/unpack-dynamic-inner-tile.mlir
//
// This implementation seems correct, but the generated output is not:
//
// [123, 123, 123, 123, 123, 123, ........., 123, 123, 123]
// [  0,   0,   0,   0,   0, 123, ........., 123, 123, 123]
// [  0,   0,   0,   0,   0, 123, ........., 123, 123, 123]
// ...
// [  0,   0,   0,   0,   0, 123, ........., 123, 123, 123]
//
// It should be this instead:
// [123, 123, 123, 123, 123, 123, ........., 123, 123, 123]
// [123, 123, 123, 123, 123, 123, ........., 123, 123, 123]
// [123, 123, 123, 123, 123, 123, ........., 123, 123, 123]
// ...
// [123, 123, 123, 123, 123, 123, ........., 123, 123, 123]
//
// WHY? Looks like something goes wrong when vectorising linalg.unpack (look at
// the generated `vector.transfer_write`). See also the notes near `scf.for`.
//=============================================================================
// Helper methods
//=============================================================================
// Indirect calls to `@printMemref` can help bufferization (at least in some
// cases)
// Prints a 72x67 i32 tensor through the runtime's unranked-memref printer.
// Keeping the print behind an indirect call can help bufferization (at least
// in some cases).
func.func @print(%dest: tensor<72x67xi32>) -> () {
  %unranked = tensor.cast %dest : tensor<72x67xi32> to tensor<*xi32>
  func.call @printMemrefI32(%unranked) : (tensor<*xi32>) -> ()
  return
}
//=============================================================================
// Wrapper for `linalg.unpack`
//=============================================================================
// Wrapper for `linalg.unpack`: delegates to the manually tiled
// implementation (@unpack_manually_tiled) and prints the unpacked result.
func.func @unpack(%src: tensor<9x?x8x8xi32>) {
%dest = tensor.empty() : tensor<72x67xi32>
%unpack = call @unpack_manually_tiled(%dest, %src) : (tensor<72x67xi32>, tensor<9x?x8x8xi32>) -> tensor<72x67xi32>
// DEBUGGING - print the unpacked tensor
%unpack_cast = tensor.cast %unpack : tensor<72x67xi32> to tensor<*xi32>
call @printMemrefI32(%unpack_cast) : (tensor<*xi32>) -> ()
return
}
//=============================================================================
// Manually tiled linalg.unpack
//=============================================================================
// Clamp a tile extent at the destination boundary:
//   #map0: rows    -> min(72 - d0, 8)
//   #map1: columns -> min(67 - d0, 8)  (the last column tile is 3 wide)
#map0 = affine_map<(d0) -> (72 - d0, 8)>
#map1 = affine_map<(d0) -> (67 - d0, 8)>

// Manually tiled equivalent of unpacking %src (9x?x8x8, inner tiles 8x8) into
// %dest (72x67): iterate over the grid of packed tiles, unpack each tile on
// its own and insert it into the destination at (%idx_0 * 8, %idx_1 * 8).
// (Dead values from the original were removed: %ts0_x_ts1, %c72, %c67, and
// the misleadingly named `%c2 = arith.constant 9` — none were used.)
func.func @unpack_manually_tiled(%dest: tensor<72x67xi32>, %src: tensor<9x?x8x8xi32>) -> tensor<72x67xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c9 = arith.constant 9 : index
  // Inner tile sizes of the packed layout (both static 8 here).
  %tile_size_0 = arith.constant 8 : index
  %tile_size_1 = arith.constant 8 : index
  // Tile count along the dynamic (second) dimension of the source.
  %dim_1 = tensor.dim %src, %c1 : tensor<9x?x8x8xi32>
  // DEBUGGING - print the init state of the destination tensor
  // func.call @print(%dest) : (tensor<72x67xi32>) -> ()
  %0 = scf.for %idx_0 = %c0 to %c9 step %c1 iter_args(%arg0 = %dest) -> (tensor<72x67xi32>) {
    // DEBUGGING - only process the final "output" column. This helps to see that two things happen:
    // * In row = %out_row (i.e. the current row), only the trailing 3 columns are written to (CORRECT!)
    // * In row = %out_row + 1, 5 leading columns are populated with 0s (INCORRECT!)
    // This implies issues with masking or `vector.transfer_write` lowering.
    // Uncomment this line (and comment out the one below) to experiment:
    // %1 = scf.for %idx_1 = %c8 to %dim_1 step %c1 iter_args(%arg1 = %arg0) -> (tensor<72x67xi32>) {
    %1 = scf.for %idx_1 = %c0 to %dim_1 step %c1 iter_args(%arg1 = %arg0) -> (tensor<72x67xi32>) {
      // The two inner loops execute exactly once each (step == bound == 8);
      // they mirror the inner-tile dimensions of the packed layout.
      %2 = scf.for %idx_2 = %c0 to %c8 step %c8 iter_args(%arg2 = %arg1) -> (tensor<72x67xi32>) {
        %3 = scf.for %idx_3 = %c0 to %c8 step %c8 iter_args(%arg3 = %arg2) -> (tensor<72x67xi32>) {
          // Extract one packed 8x8 tile.
          %slice_in = tensor.extract_slice %src[%idx_0, %idx_1, %idx_2, %idx_3] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<9x?x8x8xi32> to tensor<1x1x8x8xi32>
          // Destination offset of this tile:
          // (0, 0), (0, 8), (0, 16), ..., (8, 0), (8, 8), ...
          %out_row = arith.muli %idx_0, %tile_size_0 : index
          %out_col = arith.muli %idx_1, %tile_size_1 : index
          // Clamp the extents at the destination boundary (tensor<8x8xi32>
          // for most tiles, 8x3 for the last column).
          %out_dim_0 = affine.min #map0(%out_row)
          %out_dim_1 = affine.min #map1(%out_col)
          %init = tensor.empty(%out_dim_0, %out_dim_1) : tensor<?x?xi32>
          // DEBUGGING - print the indices
          // vector.print str "OUT DIM 0: "
          // vector.print %out_dim_0 : index
          // vector.print str "OUT DIM 1: "
          // vector.print %out_dim_1 : index
          // Unpack the single tile into a (possibly partial) 2-D tensor.
          %unpack = linalg.unpack %slice_in
            outer_dims_perm = [0, 1]
            inner_dims_pos = [0, 1]
            inner_tiles = [8, 8] into %init : tensor<1x1x8x8xi32> -> tensor<?x?xi32>
          %slice_out = tensor.insert_slice %unpack
            into %arg3[%out_row, %out_col] [%out_dim_0, %out_dim_1] [1, 1]
            : tensor<?x?xi32> into tensor<72x67xi32>
          // DEBUGGING - print the unpacked slice
          %unpack_cast = tensor.cast %unpack : tensor<?x?xi32> to tensor<*xi32>
          func.call @printMemrefI32(%unpack_cast) : (tensor<*xi32>) -> ()
          scf.yield %slice_out : tensor<72x67xi32>
        }
        scf.yield %3 : tensor<72x67xi32>
      }
      scf.yield %2 : tensor<72x67xi32>
    }
    scf.yield %1 : tensor<72x67xi32>
  }
  return %0 : tensor<72x67xi32>
}
//=============================================================================
// MAIN entry point
//=============================================================================
// MAIN entry point: builds a constant-filled 9x9x8x8 source, casts away the
// second outer dimension to make it dynamic, and runs the unpack wrapper.
func.func @main() {
// INITIALIZE THE TENSOR TO TILE
%A = arith.constant dense<123> : tensor<9x9x8x8xi32>
%A_sc = tensor.cast %A: tensor<9x9x8x8xi32> to tensor<9x?x8x8xi32>
func.call @unpack(%A_sc) : (tensor<9x?x8x8xi32>) -> ()
return
}
//=============================================================================
// TD sequence
//=============================================================================
// Transform-dialect schedule for the manually tiled version: the payload IR is
// already tiled (see @unpack_manually_tiled), so this only decomposes,
// vectorizes, bufferizes and canonicalizes.
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consume}) {
%pack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
// 1. DO NOT TILE! (the payload was tiled manually)
// 2. Decompose the tiled unpack Op into tensor.extract_slice + tensor.insert_slice:
%func_op = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.linalg.decompose_pack_unpack
transform.apply_patterns.canonicalization
} : !transform.op<"func.func">
// 3. Vectorize tensor.insert_slice - NOTE VECTOR SIZES
// Vector sizes [8, 8] match the inner tiles in the payload IR (no extra
// tiling step was applied here, unlike in the other file).
%slice = transform.structured.match ops{["tensor.insert_slice"]} in %func_op : (!transform.op<"func.func">) -> !transform.any_op
transform.structured.vectorize %slice vector_sizes [8, 8] : !transform.any_op
// 4. Bufferize before lowering to LLVM
%bufferize = transform.bufferization.one_shot_bufferize %module
{bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
// 5. Canonicalize the bufferized functions.
%func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op_bufferized {
transform.apply_patterns.canonicalization
} : !transform.op<"func.func">
transform.yield
}
}
// External declaration; resolved at run time via the shared runner-utils
// libraries listed on the RUN line (-shared-libs=...).
func.func private @printMemrefI32(%ptr : tensor<*xi32>)
@banach-space
Copy link
Author

This Gist demonstrates alternative tiling for linalg.unpack.

Both files use LIT substitution variables that are defined inside the MLIR test suite. Copy them into e.g. "mlir/test/Integration/Dialect/Linalg/CPU/" and then run the following to get a full invocation command:

bin/llvm-lit -va <LLVM-PROJECT-ROOT-DIR>/mlir/test/Integration/Dialect/Linalg/CPU/unpack-dynamic-inner-tile-auto.mlir

NOTE This is WIP - while these files work for me locally, there are potentially multiple issues that I have missed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment