Skip to content

Instantly share code, notes, and snippets.

@yushangdi
Last active June 21, 2022 00:51
Show Gist options
  • Select an option

  • Save yushangdi/2eacd3175c6d8c36cc9b47dd68f81f14 to your computer and use it in GitHub Desktop.

Select an option

Save yushangdi/2eacd3175c6d8c36cc9b47dd68f81f14 to your computer and use it in GitHub Desktop.
init_partitioner
import os
import importlib
import pickle
import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.profiler import profile, ProfilerActivity
from torch.fx.partitioner.partitioner import CapabilityBasedPartitioner
from torch.fx.partitioner.nvfuser_operator_support import NvFuserOperatorSupport
# Device on which both the traced module and its synthetic inputs are placed.
device = "cuda"

# Root directory of the dumped TorchBench graphs. The process changes into it
# so that the relative test-case paths below resolve against this root.
graphs_dir = "./torchbenchmark/"
os.chdir(graphs_dir)

# Slash-separated paths (relative to `graphs_dir`) of the graphs to benchmark.
test_cases = ["torch_bench_graphs/resnext50_32x4d/resnext50_32x4d_backward_0"]
def get_fused_graph(traced_graph):
    """Group nvFuser-supported nodes of ``traced_graph`` into fused submodules.

    A ``CapabilityBasedPartitioner`` driven by ``NvFuserOperatorSupport`` first
    proposes candidate nodes, then partitions them, then fuses each partition.

    Args:
        traced_graph: the module produced by ``symbolic_trace``.

    Returns:
        Whatever ``partitioner.fuse_partitions`` returns. The original author
        notes that fusion modifies ``traced_graph`` in place.
    """
    partitioner = CapabilityBasedPartitioner(traced_graph, NvFuserOperatorSupport())
    return partitioner.fuse_partitions(
        partitioner.partition(partitioner.get_candidates())
    )
# Trace each test-case graph, fuse its nvFuser-supported partitions, script the
# fused submodules, and run the result on randomly generated inputs.
for case_dir in test_cases[:1]:  # NOTE: only the first test case is run.
    # `case_dir` is a slash-separated path relative to the current directory.
    # The same components joined with '.' form its importable module name.
    path = case_dir.split('/')
    model_name = path[-1]
    module_path = '.'.join(path)
    input_data_path = f'{case_dir}/{model_name}.input'

    # Kept distinct from the fused submodules handled further down.
    graph_module = importlib.import_module(module_path)
    m = graph_module.FxModule()

    # SECURITY NOTE: pickle.load can execute arbitrary code. Only point this at
    # trusted, locally generated benchmark metadata.
    with open(input_data_path, 'rb') as f:
        inputs_meta = pickle.load(f)

    # Each metadata record unpacks to (type, shape, stride, dtype). Only shape
    # and dtype are used.
    # TODO(review): the recorded stride is ignored, so every input is
    # contiguous. Confirm that this is acceptable for the graphs under test.
    inputs = []
    for _input_type, shape, _stride, dtype in inputs_meta:
        if dtype.is_floating_point or dtype.is_complex:
            tensor = torch.rand(shape, dtype=dtype, device=device)
        else:
            # torch.rand only produces floating-point/complex tensors, so
            # integer and bool inputs use randint. Its upper bound is
            # exclusive, so randint(0, 1, ...) fills the tensor with zeros.
            tensor = torch.randint(0, 1, shape, dtype=dtype, device=device)
        inputs.append(tensor)

    m.to(device)
    traced_graph = symbolic_trace(m)
    fused_graph = get_fused_graph(traced_graph)

    # Replace every fused submodule with its TorchScript version.
    # Why filter on op == "call_module": call_function nodes whose names merely
    # contain "fused_" (e.g. an aten `_fused_dropout` op) have no matching
    # attribute on `fused_graph`. `node.target` is the attribute path actually
    # used by a call_module node.
    num_fused_group = 0
    for node in fused_graph.graph.nodes:
        if node.op == "call_module" and "fused_" in node.target:
            submodule = getattr(fused_graph, node.target)
            setattr(fused_graph, node.target, torch.jit.script(submodule))
            num_fused_group += 1

    result = fused_graph(*inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment