Skip to content

Instantly share code, notes, and snippets.

@yaoyaoding
Last active January 6, 2026 15:53
Show Gist options
  • Select an option

  • Save yaoyaoding/919b800e03907e2aa8e6a2a0b5f5e1b8 to your computer and use it in GitHub Desktop.

Select an option

Save yaoyaoding/919b800e03907e2aa8e6a2a0b5f5e1b8 to your computer and use it in GitHub Desktop.
mlir dtor example
'''
Steps to reproduce:
$ pip install nvidia-cutlass-dsl
$ python main.py
'''
import ctypes
import ctypes.util
import gc
import os
import cutlass._mlir.ir as ir
import cutlass._mlir.passmanager as pm
import cutlass._mlir.execution_engine as ee
from cutlass._mlir.dialects import arith, func, llvm
def _escape_mlir_string(text: str) -> str:
    """Return *text* escaped for use inside a double-quoted MLIR string literal.

    Backslash and double quote keep their two-character escapes. Every other
    byte outside printable ASCII (newline, tab, control characters, the bytes
    of multi-byte UTF-8 sequences) is written as a two-hex-digit escape, so
    the literal never contains a raw character that terminates or breaks it.
    """
    pieces = []
    for byte in text.encode("utf-8"):
        if byte == 0x5C:  # backslash
            pieces.append("\\\\")
        elif byte == 0x22:  # double quote
            pieces.append('\\"')
        elif 0x20 <= byte < 0x7F:  # printable ASCII is emitted verbatim
            pieces.append(chr(byte))
        else:
            pieces.append(f"\\{byte:02X}")
    return "".join(pieces)


def _define_global_cstr(module: ir.Module, *, symbol: str, text: str) -> tuple[str, ir.Type]:
    """Define a private constant C string global and return its LLVM array type.

    Args:
        module: Module whose body receives the new ``llvm.mlir.global`` op.
        symbol: Symbol name of the global (used after ``@`` in the op text).
        text: String contents; a NUL terminator is appended automatically.

    Returns:
        A ``(symbol, array_type)`` pair where ``array_type`` is
        ``!llvm.array<N x i8>`` and ``N`` is the UTF-8 byte length of *text*
        plus one for the NUL terminator.
    """
    escaped = _escape_mlir_string(text)
    with ir.InsertionPoint(module.body):
        op = ir.Operation.parse(
            f'llvm.mlir.global private constant @{symbol}("{escaped}\\00")'
        )
    module.body.append(op)

    # The MLIR `llvm.mlir.global` string form uses an array of i8.
    # Count encoded bytes (not code points) and include the NUL terminator.
    nbytes = len(text.encode("utf-8")) + 1
    array_ty = ir.Type.parse(f"!llvm.array<{nbytes} x i8>")
    return symbol, array_ty
def _find_libc_path():
    """Return an absolute path to an existing libc shared object, or ``None``.

    The result of ``ctypes.util.find_library("c")`` is tried first, followed
    by a few conventional Linux locations. Only candidates that are absolute
    paths and exist on disk are accepted.
    """
    import platform  # function-scope import; the file-level imports stay as-is

    # Derive the multiarch directory from the running machine instead of
    # hard-coding x86_64; on x86_64 this yields the same paths as before.
    multiarch_dir = f"{platform.machine()}-linux-gnu"
    candidates = (
        ctypes.util.find_library("c"),
        f"/lib/{multiarch_dir}/libc.so.6",
        f"/usr/lib/{multiarch_dir}/libc.so.6",
        "/lib64/libc.so.6",
    )
    for candidate in candidates:
        if candidate and os.path.isabs(candidate) and os.path.exists(candidate):
            return candidate
    return None


def run_cutlass_mlir_jit_with_dtor():
    """Build, lower, JIT and run a small MLIR module that registers a global dtor.

    Steps, in order:
      1. Build a module with ``add_logic(i32, i32) -> i32``, an external
         ``puts`` declaration, a ``"dtor"`` string global, an internal
         ``__dtor_print`` function calling ``puts``, and an
         ``llvm.mlir.global_dtors`` entry for it (priority 65535).
      2. Run the ``convert-to-llvm`` pipeline and print the module.
      3. Create an execution engine (passing libc as a shared library when a
         path is found), invoke ``add_logic`` with 10 and 20, and print the
         result.
      4. Delete the engine and run the garbage collector so that engine
         teardown happens at a known point in the script.
    """
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        i32 = ir.IntegerType.get_signless(32)
        # NOTE(review): `void` is never read below. It is kept because parsing
        # an `!llvm` type here may be what loads the llvm dialect before
        # `llvm.PointerType.get()` -- confirm before removing.
        void = ir.Type.parse("!llvm.void")
        ptr = llvm.PointerType.get()

        with ir.InsertionPoint(module.body):
            # A small regular function (same idea as `main.py`).
            f_type = ir.FunctionType.get([i32, i32], [i32])
            f_op = func.FuncOp("add_logic", f_type)
            f_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            block = f_op.add_entry_block()
            with ir.InsertionPoint(block):
                sum_op = arith.AddIOp(block.arguments[0], block.arguments[1])
                func.ReturnOp([sum_op.result])

            # --- Destructor (dtor) setup ---
            # Declare `puts` from the C runtime.
            puts_ty = ir.Type.parse("!llvm.func<i32 (!llvm.ptr)>")
            puts = llvm.func("puts", function_type=ir.TypeAttr.get(puts_ty))
            puts.attributes["llvm.linkage"] = ir.StringAttr.get("external")

            # Define a global string and a dtor function that calls puts("dtor").
            dtor_msg_sym, dtor_msg_arr_ty = _define_global_cstr(
                module, symbol="__dtor_msg", text="dtor"
            )
            dtor_fn_name = "__dtor_print"
            dtor_fn_ty = ir.Type.parse("!llvm.func<!llvm.void ()>")
            dtor_fn = llvm.func(dtor_fn_name, function_type=ir.TypeAttr.get(dtor_fn_ty))
            dtor_fn.attributes["linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
            dtor_entry = ir.Block.create_at_start(dtor_fn.body)
            with ir.InsertionPoint(dtor_entry):
                msg_addr = llvm.AddressOfOp(ptr, dtor_msg_sym).result
                # GEP [0, 0] into the i8 array: pointer to the first character.
                msg_ptr = llvm.getelementptr(
                    ptr,
                    msg_addr,
                    [],
                    raw_constant_indices=ir.DenseI32ArrayAttr.get([0, 0]),
                    elem_type=dtor_msg_arr_ty,
                )
                llvm.call(
                    result=i32,
                    callee="puts",
                    callee_operands=[msg_ptr],
                    op_bundle_sizes=[],
                    op_bundle_operands=[],
                )
                llvm.return_()

            # Register the dtor in the module. The op is created with empty
            # arrays and the entries are appended to its attributes afterwards.
            global_dtors = llvm.mlir_global_dtors(dtors=[], priorities=[])
            global_dtors.attributes["dtors"] += [
                ir.FlatSymbolRefAttr.get(dtor_fn_name)
            ]
            global_dtors.attributes["priorities"] += [
                ir.IntegerAttr.get(i32, 65535)
            ]

        # Lower everything to LLVM.
        pipeline = "builtin.module(convert-to-llvm)"
        pass_manager = pm.PassManager.parse(pipeline)
        pass_manager.run(module.operation)
        print(module)

        # Hand libc to the engine when an on-disk copy is found; otherwise
        # pass no extra shared libraries.
        libc_path = _find_libc_path()
        shared_libs = [libc_path] if libc_path else []
        engine = ee.ExecutionEngine(module, opt_level=2, shared_libs=shared_libs)

        # Arguments and the result slot are all passed by reference, with the
        # result slot last.
        a = ctypes.c_int32(10)
        b = ctypes.c_int32(20)
        result = ctypes.c_int32(0)
        engine.invoke(
            "add_logic", ctypes.byref(a), ctypes.byref(b), ctypes.byref(result)
        )
        print(f"Result: {result.value}")

        # Force teardown to make dtor timing obvious in a short script.
        del engine
        gc.collect()
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_cutlass_mlir_jit_with_dtor()
@yaoyaoding
Copy link
Author

MLIR (LLVM dialect) printed by the script after lowering:

module {
  llvm.func @add_logic(%arg0: i32, %arg1: i32) -> i32 attributes {llvm.emit_c_interface} {
    %0 = llvm.add %arg0, %arg1 : i32
    llvm.return %0 : i32
  }
  llvm.func @_mlir_ciface_add_logic(%arg0: i32, %arg1: i32) -> i32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @add_logic(%arg0, %arg1) : (i32, i32) -> i32
    llvm.return %0 : i32
  }
  llvm.func @puts(!llvm.ptr) -> i32 attributes {llvm.linkage = "external"}
  llvm.mlir.global private constant @__dtor_msg("dtor\00") {addr_space = 0 : i32}
  llvm.func internal @__dtor_print() {
    %0 = llvm.mlir.addressof @__dtor_msg : !llvm.ptr
    %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<5 x i8>
    %2 = llvm.call @puts(%1) : (!llvm.ptr) -> i32
    llvm.return
  }
  llvm.mlir.global_dtors {dtors = [@__dtor_print], priorities = [65535 : i32]}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment