Created November 25, 2025 18:20
TRT Reproduce
import numpy as np
import torch
import tensorrt as trt
from polygraphy.backend.trt import CreateConfig, Profile


def build_encoder_t_gather_engine(encoder_output_shape):
    """
    Build a TensorRT engine with two inputs:
        input_0: encoder_output, shape (B, T, H), dtype float32
        input_1: t, shape (B,), dtype int32
    Computes: out[b, :] = encoder_output[b, t[b], :]
    Returns the serialized engine.
    """
    B, T, H = encoder_output_shape
    # Make the batch and time dimensions dynamic
    B = -1
    T = -1

    min_encoder_time = 1
    opt_encoder_time = 3
    max_encoder_time = 10
    min_batch_size = 1
    opt_batch_size = 4
    max_batch_size = 64

    logger = trt.Logger(trt.Logger.WARNING)
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(logger)
    network = builder.create_network(explicit_batch_flag)

    # ----- Inputs -----
    # encoder_output: (B, T, H)
    encoder_output = network.add_input(
        name="encoder_output",
        dtype=trt.float32,
        shape=(B, T, H),
    )
    # t: (B,)
    t = network.add_input(
        name="t",
        dtype=trt.int32,
        shape=(B,),
    )
    # output: (B, 1, H)
    output = network.add_input(
        name="output",
        dtype=trt.float32,
        shape=(B, 1, H),
    )
    encoder_output.set_dimension_name(0, "batch_size")
    encoder_output.set_dimension_name(1, "time_steps")
    t.set_dimension_name(0, "batch_size")
    output.set_dimension_name(0, "batch_size")
    # The following line is the cause of the issue: the problem reproduces
    # even though the profiles are set with the correct shapes
    # and the output buffer is allocated with the correct shape.
    output.set_dimension_name(1, "time_steps")
    t_shuffle = network.add_shuffle(t)
    t_shuffle.reshape_dims = (-1, 1)  # (B,) -> (B, 1)
    t_expanded = t_shuffle.get_output(0)

    # ----- Gather: encoder_output[b, t[b], :] -----
    # axis = 1 because encoder_output is (B, T, H) and T is axis 1.
    gather_layer = network.add_gather(
        input=encoder_output,  # data
        indices=t_expanded,    # indices
        axis=1,                # time dimension
    )
    # Elementwise across the leading dim (B)
    gather_layer.num_elementwise_dims = 1
    gather_layer.name = "gather_layer"
    gathered = gather_layer.get_output(0)  # shape (B, 1, H)

    out = gathered
    out.name = "encoder_output_t"
    network.mark_output(out)

    config = CreateConfig(
        profiles=[
            Profile()
            .add(
                "encoder_output",
                min=(min_batch_size, min_encoder_time, H),
                opt=(opt_batch_size, opt_encoder_time, H),
                max=(max_batch_size, max_encoder_time, H),
            )
            .add(
                "t",
                min=(min_batch_size,),
                opt=(opt_batch_size,),
                max=(max_batch_size,),
            )
            .add(
                "output",
                min=(min_batch_size, 1, H),
                opt=(opt_batch_size, 1, H),
                max=(max_batch_size, 1, H),
            )
        ],
    )(builder, network)

    engine = builder.build_serialized_network(network, config)
    if engine is None:
        raise RuntimeError("Failed to build TensorRT engine.")
    return engine


def run_inference_with_torch(engine, encoder_output_np, t_np):
    """
    Run inference on the built engine and return the output.
        encoder_output_np: (B, T, H), float32
        t_np: (B,), int32
    Uses PyTorch CUDA tensors for allocations and device pointers.
    """
    assert torch.cuda.is_available(), "CUDA is required to run this script."
    stream = torch.cuda.Stream()
    runtime = trt.Runtime(trt.Logger(trt.Logger.VERBOSE))
    engine = runtime.deserialize_cuda_engine(bytes(engine))

    with torch.cuda.stream(stream):
        # Convert to torch tensors on GPU
        enc_torch = torch.from_numpy(encoder_output_np).to(
            device=stream.device, dtype=torch.float32
        )
        t_torch = torch.from_numpy(t_np).to(device=stream.device, dtype=torch.int32)

        # Create execution context
        context = engine.create_execution_context()
        assert context is not None, "Failed to create an execution context!"

        out_torch = torch.empty(
            (enc_torch.shape[0], 1, enc_torch.shape[2]),
            device=stream.device,
            dtype=torch.float32,
        )

        assert context.set_input_shape("encoder_output", enc_torch.shape)
        assert context.set_input_shape("t", t_torch.shape)
        assert context.set_input_shape("output", out_torch.shape)

        shape = context.get_tensor_shape("encoder_output_t")
        print(f"Output shape from context: {shape}")
        assert shape == list(out_torch.shape)
        # Set tensor addresses (TensorRT v3 execution API)
        assert context.set_tensor_address("encoder_output", int(enc_torch.data_ptr()))
        assert context.set_tensor_address("t", int(t_torch.data_ptr()))
        assert context.set_tensor_address("output", int(out_torch.data_ptr()))
        assert context.set_tensor_address("encoder_output_t", int(out_torch.data_ptr()))

        # Execute
        context.execute_async_v3(stream.cuda_stream)
        stream.synchronize()

    # Bring output back to host
    out_host = out_torch.cpu().numpy()
    return out_host


def main():
    # ----- Create a synthetic encoder_output: (B, T, H) = (3, 10, 5) -----
    encoder_output = (
        np.arange(50).reshape(1, 10, 5).repeat(3, axis=0).astype(np.float32)
    )
    if encoder_output.ndim != 3:
        raise ValueError(
            f"Expected encoder_output with 3 dims (B, T, H); got shape {encoder_output.shape}"
        )
    print(
        f"encoder_output shape: {encoder_output.shape}, dtype: {encoder_output.dtype}"
    )
    # ----- Create t: shape (B,), int32 -----
    # Example: select the first time step for every batch element
    t = np.array([0, 0, 0], dtype=np.int32)
    print(f"Using t indices (shape {t.shape}, dtype {t.dtype}):\n{t}\n")
    # ----- Build engine -----
    engine = build_encoder_t_gather_engine(encoder_output.shape)

    # ----- Run inference using PyTorch allocations -----
    output = run_inference_with_torch(engine, encoder_output, t)
    print(f"Output shape: {output.shape}, dtype: {output.dtype}")
    print("Output values:\n", output)


if __name__ == "__main__":
    main()
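
For reference, the gather the network is meant to implement (out[b, :] = encoder_output[b, t[b], :]) can be checked against plain NumPy. The snippet below is a sketch using the same synthetic inputs that main() builds; it is not part of the original reproduction, and the names expected and engine_output are illustrative only.

# Hypothetical NumPy reference check for out[b, :] = encoder_output[b, t[b], :]
import numpy as np

encoder_output = np.arange(50).reshape(1, 10, 5).repeat(3, axis=0).astype(np.float32)
t = np.array([0, 0, 0], dtype=np.int32)

B = encoder_output.shape[0]
# Advanced indexing picks one time step per batch element, then a singleton
# time axis is reinserted to match the engine's (B, 1, H) output layout.
expected = encoder_output[np.arange(B), t][:, None, :]
print(expected.shape)  # (3, 1, 5)

# A correctly built engine should return values matching `expected`, e.g.:
# np.testing.assert_allclose(engine_output, expected)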