Skip to content

Instantly share code, notes, and snippets.

@knwng
Last active August 19, 2025 07:36
Show Gist options
  • Select an option

  • Save knwng/abc22eb8a73070dfb8bd32029bb57d7a to your computer and use it in GitHub Desktop.

Select an option

Save knwng/abc22eb8a73070dfb8bd32029bb57d7a to your computer and use it in GitHub Desktop.
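Usage: set AMD_INSERT_TTGIR, AMD_INSERT_LLVM_IR, or AMD_INSERT_AMDGCN to "<kernel_name>:<filename>" to replace that kernel's IR at the corresponding compilation stage with the contents of <filename>, or to just "<filename>" to replace the IR of every kernel compiled. Example: export AMD_INSERT_AMDGCN="my_kernel:/tmp/my_kernel.amdgcn"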
# At the end of make_ttgir, e.g., https://github.com/triton-lang/triton/blob/6ca2dda9bdd331d007d6fab342db5a85f9b23c7d/third_party/amd/backend/compiler.py#L258
import os
if "AMD_INSERT_TTGIR" in os.environ.keys():
fn = os.environ['AMD_INSERT_TTGIR']
if ':' in fn:
kernel_name, insert_module_path = fn.split(':')
print(f"Replace kernel {kernel_name}'s ttgir with {insert_module_path}")
if not mod.has_function(kernel_name):
return mod
else:
insert_module_path = fn
print(f"Replace kernel's ttgir with {insert_module_path}")
ctx = mod.context
mod = ir.parse_mlir_module(insert_module_path, ctx)
mod.context = ctx
# At the end of make_llir, e.g., https://github.com/triton-lang/triton/blob/6ca2dda9bdd331d007d6fab342db5a85f9b23c7d/third_party/amd/backend/compiler.py#L406
import os
if "AMD_INSERT_LLVM_IR" in os.environ.keys():
if ':' in os.environ["AMD_INSERT_LLVM_IR"]:
kernel_name, insert_module_path = os.environ["AMD_INSERT_LLVM_IR"].split(':')
if kernel_name == fns[0].name:
print(f"Replace kernel {kernel_name}'s llir with {insert_module_path}")
else:
return str(llvm_mod)
else:
insert_module_path = os.environ["AMD_INSERT_LLVM_IR"]
print(f"Replace kernel's llir with {insert_module_path}")
if not os.path.exists(insert_module_path):
raise RuntimeError(f'cannot find llvm ir file to insert. Given: `{insert_module_path}`')
with open(insert_module_path, "r") as file:
return file.read()
# At the end of make_amdgcn, e.g. https://github.com/triton-lang/triton/blob/6ca2dda9bdd331d007d6fab342db5a85f9b23c7d/third_party/amd/backend/compiler.py#L428
import os
if "AMD_INSERT_AMDGCN" in os.environ.keys():
if ':' in os.environ["AMD_INSERT_AMDGCN"]:
kernel_name, insert_module_path = os.environ["AMD_INSERT_AMDGCN"].split(':')
if kernel_name == metadata["name"]:
print(f"Replace kernel {kernel_name}'s amdgcn with {insert_module_path}")
else:
return amdgcn
else:
insert_module_path = os.environ["AMD_INSERT_AMDGCN"]
print(f"Replace kernel's amdgcn with {insert_module_path}")
if not os.path.exists(insert_module_path):
raise RuntimeError(f'cannot find amdgcn file to insert. Given: `{insert_module_path}`')
with open(insert_module_path, "r") as file:
return file.read()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment