Last active
January 6, 2026 15:53
-
-
Save yaoyaoding/919b800e03907e2aa8e6a2a0b5f5e1b8 to your computer and use it in GitHub Desktop.
mlir dtor example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Steps to reproduce: | |
| $ pip install nvidia-cutlass-dsl | |
| $ python main.py | |
| ''' | |
| import ctypes | |
| import ctypes.util | |
| import gc | |
| import os | |
| import cutlass._mlir.ir as ir | |
| import cutlass._mlir.passmanager as pm | |
| import cutlass._mlir.execution_engine as ee | |
| from cutlass._mlir.dialects import arith, func, llvm | |
def _escape_mlir_string_bytes(data: bytes) -> str:
    r"""Render *data* as the body of a double-quoted MLIR string literal.

    Printable ASCII bytes are emitted verbatim, ``"`` and ``\`` are prefixed
    with a backslash, and every other byte (control characters, NUL,
    non-ASCII UTF-8 bytes) is written as a two-digit ``\XX`` hex escape.
    """
    pieces = []
    for byte in data:
        char = chr(byte)
        if char in '"\\':
            pieces.append("\\" + char)
        elif 0x20 <= byte <= 0x7E:
            pieces.append(char)
        else:
            # A raw newline or NUL inside the quotes would break or truncate
            # the literal, so spell such bytes out as hex escapes.
            pieces.append(f"\\{byte:02X}")
    return "".join(pieces)


def _define_global_cstr(module: ir.Module, *, symbol: str, text: str) -> tuple[str, ir.Type]:
    """Define a private constant C string global and return its LLVM array type.

    Args:
        module: Module whose body receives the new ``llvm.mlir.global`` op.
        symbol: Symbol name of the global (written after ``@`` unquoted).
        text: String contents; stored as UTF-8 followed by a NUL terminator.

    Returns:
        ``(symbol, array_ty)`` where ``array_ty`` is ``!llvm.array<N x i8>``
        and ``N`` is the UTF-8 byte length of ``text`` plus one.
    """
    payload = text.encode("utf-8")
    # Escape per byte so the literal holds exactly the bytes counted below.
    escaped = _escape_mlir_string_bytes(payload)
    with ir.InsertionPoint(module.body):
        op = ir.Operation.parse(
            f'llvm.mlir.global private constant @{symbol}("{escaped}\\00")'
        )
        # The parsed op is attached to the module explicitly.
        module.body.append(op)

    # The MLIR `llvm.mlir.global` string form uses an array of i8.
    # Include the NUL terminator.
    nbytes = len(payload) + 1
    array_ty = ir.Type.parse(f"!llvm.array<{nbytes} x i8>")
    return symbol, array_ty
def _find_libc_path():
    """Return an absolute path to an existing libc shared object, or None.

    The value from ``ctypes.util.find_library("c")`` is tried first, then a
    few hard-coded Linux locations. Only candidates that are absolute paths
    and exist on disk are accepted.
    """
    candidates = [
        ctypes.util.find_library("c"),
        "/lib/x86_64-linux-gnu/libc.so.6",
        "/usr/lib/x86_64-linux-gnu/libc.so.6",
        "/lib64/libc.so.6",
    ]
    for candidate in candidates:
        if candidate and os.path.isabs(candidate) and os.path.exists(candidate):
            return candidate
    return None


def run_cutlass_mlir_jit_with_dtor():
    """Build, lower, JIT-run and tear down a module that registers a dtor.

    The module contains:
      * ``add_logic``: adds two i32 arguments (C interface emitted),
      * a declaration of ``puts``,
      * ``__dtor_print``: calls ``puts`` on the global string ``"dtor"``,
      * an ``llvm.mlir.global_dtors`` op listing ``__dtor_print`` with
        priority 65535.

    The module is lowered with ``convert-to-llvm`` and printed, then
    ``add_logic(10, 20)`` is invoked through the execution engine and the
    result is printed. Finally the engine is deleted and a garbage
    collection is forced.
    """
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        i32 = ir.IntegerType.get_signless(32)
        ptr = llvm.PointerType.get()

        with ir.InsertionPoint(module.body):
            # A small regular function (same idea as `main.py`).
            f_type = ir.FunctionType.get([i32, i32], [i32])
            f_op = func.FuncOp("add_logic", f_type)
            f_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            block = f_op.add_entry_block()
            with ir.InsertionPoint(block):
                sum_op = arith.AddIOp(block.arguments[0], block.arguments[1])
                func.ReturnOp([sum_op.result])

            # --- Destructor (dtor) setup ---
            # Declare `puts` from the C runtime.
            puts_ty = ir.Type.parse("!llvm.func<i32 (!llvm.ptr)>")
            puts = llvm.func("puts", function_type=ir.TypeAttr.get(puts_ty))
            # NOTE(review): this sets the key "llvm.linkage" to a string,
            # while the dtor function below sets "linkage" to a
            # `#llvm.linkage<...>` attribute - confirm which is intended.
            puts.attributes["llvm.linkage"] = ir.StringAttr.get("external")

            # Define a global string and a dtor function that calls puts("dtor").
            dtor_msg_sym, dtor_msg_arr_ty = _define_global_cstr(
                module, symbol="__dtor_msg", text="dtor"
            )
            dtor_fn_name = "__dtor_print"
            dtor_fn_ty = ir.Type.parse("!llvm.func<!llvm.void ()>")
            dtor_fn = llvm.func(dtor_fn_name, function_type=ir.TypeAttr.get(dtor_fn_ty))
            dtor_fn.attributes["linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
            dtor_entry = ir.Block.create_at_start(dtor_fn.body)
            with ir.InsertionPoint(dtor_entry):
                msg_addr = llvm.AddressOfOp(ptr, dtor_msg_sym).result
                # Constant indices [0, 0] into the array-typed global.
                msg_ptr = llvm.getelementptr(
                    ptr,
                    msg_addr,
                    [],
                    raw_constant_indices=ir.DenseI32ArrayAttr.get([0, 0]),
                    elem_type=dtor_msg_arr_ty,
                )
                llvm.call(
                    result=i32,
                    callee="puts",
                    callee_operands=[msg_ptr],
                    op_bundle_sizes=[],
                    op_bundle_operands=[],
                )
                llvm.return_()

            # Register the dtor in the module. The op is created with empty
            # lists and its attributes are extended afterwards.
            global_dtors = llvm.mlir_global_dtors(dtors=[], priorities=[])
            global_dtors.attributes["dtors"] += [
                ir.FlatSymbolRefAttr.get(dtor_fn_name)
            ]
            global_dtors.attributes["priorities"] += [
                ir.IntegerAttr.get(i32, 65535)
            ]

        # Lower everything to LLVM.
        pipeline = "builtin.module(convert-to-llvm)"
        pass_manager = pm.PassManager.parse(pipeline)
        pass_manager.run(module.operation)
        print(module)

        # Hand libc to the engine when a concrete path can be found.
        libc_path = _find_libc_path()
        shared_libs = [libc_path] if libc_path else []
        engine = ee.ExecutionEngine(module, opt_level=2, shared_libs=shared_libs)

        # Both inputs and the result slot are passed by reference.
        a = ctypes.c_int32(10)
        b = ctypes.c_int32(20)
        result = ctypes.c_int32(0)
        engine.invoke(
            "add_logic", ctypes.byref(a), ctypes.byref(b), ctypes.byref(result)
        )
        print(f"Result: {result.value}")

        # Force teardown to make dtor timing obvious in a short script.
        del engine
        gc.collect()
if __name__ == "__main__":
    # Script entry point: run the reproduction only when executed directly.
    run_cutlass_mlir_jit_with_dtor()
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
LLVM IR: