Skip to content

Instantly share code, notes, and snippets.

@functionstackx
Created August 7, 2024 21:31
Show Gist options
  • Select an option

  • Save functionstackx/afea0b3423dd2eecaba0b61c633c96de to your computer and use it in GitHub Desktop.

Select an option

Save functionstackx/afea0b3423dd2eecaba0b61c633c96de to your computer and use it in GitHub Desktop.
H100 Mem BW
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# Message sizes to benchmark, in bytes: every power of two from 2**0 (1 B) up to 2**34 (16 GB)
sizes = [1 << exponent for exponent in range(35)]
# How many timed copies are averaged for each message size
iterations = 1000
# Function to measure memory bandwidth
def measure_bandwidth(message_sizes=None, num_iterations=None):
    """Time same-device tensor copies on the GPU and report throughput.

    For each message size, a source and a destination buffer of exactly that
    many bytes are allocated on the ``'cuda'`` device, one untimed warm-up
    copy is issued, and then ``dst.copy_(src)`` is timed ``num_iterations``
    times with CUDA events. The mean elapsed time is converted to GB/s
    (1 GB = 1e9 bytes) and printed.

    NOTE(review): the reported figure divides the payload size by the copy
    time, i.e. it counts each byte once. A copy presumably both reads and
    writes every byte, so raw memory traffic would be about twice this
    number -- confirm which convention is wanted before comparing against
    vendor peak-bandwidth figures.

    Args:
        message_sizes: Iterable of copy sizes in bytes. Defaults to the
            module-level ``sizes`` list.
        num_iterations: Number of timed copies averaged per size. Defaults
            to the module-level ``iterations`` value.

    Returns:
        A list of ``(size_in_bytes, bandwidth_in_gb_per_s)`` tuples, one per
        message size, in the order the sizes were supplied.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """
    if message_sizes is None:
        message_sizes = sizes
    if num_iterations is None:
        num_iterations = iterations
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be at least 1, got {num_iterations}")

    results = []
    for size in message_sizes:
        # One byte per element, so the buffer holds exactly `size` bytes.
        # A float32 buffer cannot be smaller than 4 bytes, which made the
        # 1- and 2-byte measurements copy more data than they reported.
        src = torch.ones(size, dtype=torch.uint8, device='cuda')
        dst = torch.zeros(size, dtype=torch.uint8, device='cuda')

        # Warm-up copy, excluded from the timings.
        dst.copy_(src)
        torch.cuda.synchronize()

        # The two events are reused across iterations; each iteration
        # synchronizes and reads the timing before they are recorded again.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_elapsed_ms = 0.0
        for _ in range(num_iterations):
            start_event.record()
            dst.copy_(src)
            end_event.record()
            # Wait for the copy to complete before reading the event timing.
            torch.cuda.synchronize()
            total_elapsed_ms += start_event.elapsed_time(end_event)
        avg_elapsed_time_ms = total_elapsed_ms / num_iterations

        # bytes / (ms * 1e6) == bytes / s / 1e9 == GB/s
        bandwidth = size / (avg_elapsed_time_ms * 1e6)
        results.append((size, bandwidth))
        print(f"Size: {size / 1e6} MB, Bandwidth: {bandwidth:.2f} GB/s")

        # Drop this size's buffers before the next (larger) pair is
        # allocated, so both generations are not resident at the same time.
        del src, dst
    return results
# Run the measurement
if __name__ == "__main__":
    # Benchmark every configured message size on the GPU.
    bandwidth_results = measure_bandwidth()

    # Tabulate the (size, bandwidth) pairs and persist them as CSV.
    df = pd.DataFrame(bandwidth_results, columns=['Size (bytes)', 'Bandwidth (GB/s)'])
    df.to_csv('bandwidth_results.csv', index=False)

    # Human-readable x tick labels: the unit steps up at each power of 1024
    # and the value is truncated to an integer.
    tick_labels = []
    for nbytes in sizes:
        if nbytes < 1024:
            tick_labels.append(f"{int(nbytes)}B")
        elif nbytes < 1024**2:
            tick_labels.append(f"{int(nbytes / 1024)}KB")
        elif nbytes < 1024**3:
            tick_labels.append(f"{int(nbytes / 1024**2)}MB")
        else:
            tick_labels.append(f"{int(nbytes / 1024**3)}GB")

    size_column = df['Size (bytes)']
    bandwidth_column = df['Bandwidth (GB/s)']

    # Bandwidth against message size on a base-2 logarithmic x axis.
    plt.figure(figsize=(10, 6))
    plt.plot(size_column, bandwidth_column, marker='o')
    plt.xscale('log', base=2)
    plt.xticks(ticks=sizes, labels=tick_labels, rotation=45)

    # Y-axis ticks every 100 GB/s, extending past the highest measurement.
    y_axis_limit = max(bandwidth_column) + 100
    plt.yticks(ticks=np.arange(0, y_axis_limit, 100))

    plt.xlabel('Message Size')
    plt.ylabel('Bandwidth (GB/s)')
    plt.title('Memory Bandwidth vs Message Size')
    plt.grid(True)
    plt.tight_layout()

    # Write the figure to disk first, then display it.
    plt.savefig('bandwidth_plot.png')
    plt.show()
Size (bytes) Bandwidth (GB/s)
1 9.067511596005305e-05
2 0.00018463647315995806
4 0.0003685391403448731
8 0.00073772644566514
16 0.0014971404628611466
32 0.002989822644883229
64 0.005978054563879316
128 0.011931536829088691
256 0.023927666669290044
512 0.04798133529160767
1024 0.09543806059784696
2048 0.19311193877093613
4096 0.38321632248626125
8192 0.7680675922287223
16384 1.5398727818841034
32768 3.083965793934213
65536 6.133168825096853
131072 12.380382436548839
262144 24.559815085850147
524288 49.16916009169341
1048576 98.27402641108246
2097152 196.0535605940026
4194304 390.17768582305064
8388608 756.7470831803522
16777216 1207.1217915092502
33554432 1017.3119708659103
67108864 1154.3026186403176
134217728 1236.9465796967484
268435456 1283.5936519674017
536870912 1295.281562135945
1073741824 1311.8103389938299
2147483648 1321.7200408294004
4294967296 1327.8750549498875
8589934592 1330.2828550259078
17179869184 1330.6639859873212
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment