@MahmoudAshraf97
Created November 25, 2025 18:20
TRT Reproduce
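Minimal reproduction of a TensorRT dynamic-shape issue triggered by set_dimension_name (see the commented line in build_encoder_t_gather_engine below).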
import numpy as np
import torch
import tensorrt as trt
from polygraphy.backend.trt import CreateConfig, Profile


def build_encoder_t_gather_engine(encoder_output_shape):
    """
    Build a TensorRT engine that takes:
        input_0: encoder_output, shape (B, T, H), dtype float32
        input_1: t, shape (B,), dtype int32
    and computes: out[b, :] = encoder_output[b, t[b], :]
    Returns the serialized engine.
    """
    B, T, H = encoder_output_shape
    # B and T are dynamic dimensions; only H is fixed at build time.
    B = -1
    T = -1
    min_encoder_time = 1
    opt_encoder_time = 3
    max_encoder_time = 10
    min_batch_size = 1
    opt_batch_size = 4
    max_batch_size = 64
    logger = trt.Logger(trt.Logger.WARNING)
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(logger)
    network = builder.create_network(explicit_batch_flag)

    # ----- Inputs -----
    # encoder_output: (B, T, H)
    encoder_output = network.add_input(
        name="encoder_output",
        dtype=trt.float32,
        shape=(B, T, H),
    )
    # t: (B,)
    t = network.add_input(
        name="t",
        dtype=trt.int32,
        shape=(B,),
    )
    # "output": (B, 1, H), an extra network input that this repro uses to
    # carry dimension names alongside the real output buffer.
    output = network.add_input(
        name="output",
        dtype=trt.float32,
        shape=(B, 1, H),
    )
    encoder_output.set_dimension_name(0, "batch_size")
    encoder_output.set_dimension_name(1, "time_steps")
    t.set_dimension_name(0, "batch_size")
    output.set_dimension_name(0, "batch_size")
    # The following line is the cause of the issue: even though the profiles
    # are set with the correct shapes and the output buffer is allocated with
    # the correct shape, naming this dimension still breaks execution.
    output.set_dimension_name(1, "time_steps")
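    # Note: TensorRT treats input dimensions that share a name as equal at
    # runtime, so "time_steps" on dim 1 of "output" presumably ties it to
    # dim 1 of encoder_output (T), conflicting with the declared (B, 1, H).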

    t_shuffle = network.add_shuffle(t)
    t_shuffle.reshape_dims = (-1, 1)  # (B,) -> (B, 1)
    t_expanded = t_shuffle.get_output(0)

    # ----- Gather: encoder_output[b, t[b], :] -----
    # axis = 1 because encoder_output is (B, T, H) and T is axis 1.
    gather_layer = network.add_gather(
        input=encoder_output,  # data
        indices=t_expanded,  # indices
        axis=1,  # time dimension
    )
    # Elementwise across the leading dim (B)
    gather_layer.num_elementwise_dims = 1
    gather_layer.name = "gather_layer"
    gathered = gather_layer.get_output(0)  # shape (B, 1, H)
    out = gathered
    out.name = "encoder_output_t"
    network.mark_output(out)

    config = CreateConfig(
        profiles=[
            Profile()
            .add(
                "encoder_output",
                min=(min_batch_size, min_encoder_time, H),
                opt=(opt_batch_size, opt_encoder_time, H),
                max=(max_batch_size, max_encoder_time, H),
            )
            .add(
                "t",
                min=(min_batch_size,),
                opt=(opt_batch_size,),
                max=(max_batch_size,),
            )
            .add(
                "output",
                min=(min_batch_size, 1, H),
                opt=(opt_batch_size, 1, H),
                max=(max_batch_size, 1, H),
            )
        ],
    )(builder, network)
    engine = builder.build_serialized_network(network, config)
    if engine is None:
        raise RuntimeError("Failed to build TensorRT engine.")
    return engine
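

# Reference implementation: the lookup the engine is meant to perform,
# written directly in NumPy as a sanity baseline (a sketch; reference_gather
# is a helper defined only here, not a TensorRT API).
def reference_gather(encoder_output_np, t_np):
    # out[b, 0, :] = encoder_output_np[b, t_np[b], :]
    b_idx = np.arange(encoder_output_np.shape[0])
    return encoder_output_np[b_idx, t_np, :][:, None, :]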


def run_inference_with_torch(engine, encoder_output_np, t_np):
    """
    Run inference on the built engine and return the output.
        encoder_output_np: (B, T, H), float32
        t_np: (B,), int32
    Uses PyTorch CUDA tensors for allocations and device pointers.
    """
    assert torch.cuda.is_available(), "CUDA is required to run this script."
    stream = torch.cuda.Stream()
    runtime = trt.Runtime(trt.Logger(trt.Logger.VERBOSE))
    engine = runtime.deserialize_cuda_engine(bytes(engine))
    with torch.cuda.stream(stream):
        # Convert to torch tensors on GPU
        enc_torch = torch.from_numpy(encoder_output_np).to(
            device=stream.device, dtype=torch.float32
        )
        t_torch = torch.from_numpy(t_np).to(device=stream.device, dtype=torch.int32)
        # Create execution context
        context = engine.create_execution_context()
        assert context is not None, "Failed to create an execution context!"
        out_torch = torch.empty(
            (enc_torch.shape[0], 1, enc_torch.shape[2]),
            device=stream.device,
            dtype=torch.float32,
        )
        assert context.set_input_shape("encoder_output", enc_torch.shape)
        assert context.set_input_shape("t", t_torch.shape)
        assert context.set_input_shape("output", out_torch.shape)
        shape = context.get_tensor_shape("encoder_output_t")
        print(f"Output shape from context: {shape}")
        assert list(shape) == list(out_torch.shape)
        # Bind device pointers to the engine's I/O tensor names
        assert context.set_tensor_address("encoder_output", int(enc_torch.data_ptr()))
        assert context.set_tensor_address("t", int(t_torch.data_ptr()))
        assert context.set_tensor_address("output", int(out_torch.data_ptr()))
        assert context.set_tensor_address("encoder_output_t", int(out_torch.data_ptr()))
        # Execute
        context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()
    # Bring output back to host
    out_host = out_torch.cpu().numpy()
    return out_host


def main():
    # ----- Create a synthetic encoder_output: (B, T, H) -----
    encoder_output = (
        np.arange(50).reshape(1, 10, 5).repeat(3, axis=0).astype(np.float32)
    )
    if encoder_output.ndim != 3:
        raise ValueError(
            f"Expected encoder_output with 3 dims (B, T, H); got shape {encoder_output.shape}"
        )
    print(
        f"encoder_output shape: {encoder_output.shape}, dtype: {encoder_output.dtype}"
    )
    # ----- Create t: shape (B,), int32 -----
    # Example: choose time step 0 for every batch element
    t = np.array([0, 0, 0], dtype=np.int32)
    print(f"Using t indices (shape {t.shape}, dtype {t.dtype}):\n{t}\n")
    # ----- Build engine -----
    engine = build_encoder_t_gather_engine(encoder_output.shape)
    # ----- Run inference using PyTorch allocations -----
    output = run_inference_with_torch(engine, encoder_output, t)
    print(f"Output shape: {output.shape}, dtype: {output.dtype}")
    print("Output values:\n", output)


if __name__ == "__main__":
    main()