Skip to content

Instantly share code, notes, and snippets.

@zoecarver
Last active March 10, 2026 22:15
Show Gist options
  • Select an option

  • Save zoecarver/b52aa28302f179505343a11067afd44d to your computer and use it in GitHub Desktop.

Select an option

Save zoecarver/b52aa28302f179505343a11067afd44d to your computer and use it in GitHub Desktop.
"""
Cell-list molecular dynamics on Tenstorrent hardware using TT-Lang.
Full Ewald electrostatics: erfc-damped real-space (cell-list, TT kernel) +
u-series reciprocal-space (separable Gaussian convolution, TT kernel + host).
LJ short-range forces. Periodic boundary conditions.
On-device Verlet integration (f32) with bf16 force kernels.
Validated: 10K atoms, 10K steps, 1.1ms/step (non-rebuild), 12 min total.
"""
import time
import torch
import numpy as np
import ttnn
import ttl
# Edge length of a Tenstorrent tile; device tensors are processed as TILE x TILE tiles.
TILE = 32
# Neighbor cells visited per cell: the 3 x 3 x 3 stencil including the cell itself.
N_NBR = 3 ** 3
# Ewald splitting parameter: real space uses erfc(ALPHA*r)/r, reciprocal space erf(ALPHA*r)/r.
ALPHA = 1.0
# Reciprocal-space grid points per dimension. compute_reciprocal_forces packs each
# K x K z-slice into exactly one tile, so the grid size must equal TILE.
K_GRID = TILE
# Gauss-Legendre nodes used to decompose erf(ALPHA*r)/r into Gaussians.
N_GAUSS = 16
# Rational approximation (Abramowitz & Stegun 7.1.26, |error| <= 1.5e-7):
#   erfc(x) ~= t*(A1 + t*(A2 + t*(A3 + t*(A4 + t*A5)))) * exp(-x^2),  t = 1/(1 + P*x)
ERFC_A1 = 0.254829592
ERFC_A2 = -0.284496736
ERFC_A3 = 1.421413741
ERFC_A4 = -1.453152027
ERFC_A5 = 1.061405429
ERFC_P = 0.3275911
# ---- Host utilities ----
def make_system(n_atoms, density=0.3, seed=42):
    """Create a jittered simple-cubic lattice system with zero-net-charge charges.

    The cubic box edge is chosen so that n_atoms / box_length**3 == density.
    Sites of an n_side**3 lattice (n_side = smallest integer with
    n_side**3 >= n_atoms) are filled in x-major order (x slowest, z fastest),
    perturbed by N(0, 0.05**2) noise and wrapped into [0, box_length).

    NOTE: this seeds NumPy's *global* RNG, so later ``np.random`` calls made by
    the caller (e.g. drawing velocities) are reproducible too.

    Args:
        n_atoms: number of atoms (>= 1).
        density: number density (> 0).
        seed: seed for ``np.random.seed``.

    Returns:
        (positions, charges, box_length): an (n_atoms, 3) float64 array, an
        (n_atoms,) float64 array with mean exactly subtracted, and the box edge.

    Raises:
        ValueError: if n_atoms < 1 or density <= 0.
    """
    if n_atoms < 1:
        raise ValueError(f"n_atoms must be >= 1, got {n_atoms}")
    if density <= 0:
        raise ValueError(f"density must be > 0, got {density}")

    def _ceil_cbrt(n):
        """Smallest integer k with k**3 >= n, computed exactly (no float rounding)."""
        k = int(round(n ** (1.0 / 3.0)))
        while k ** 3 < n:
            k += 1
        while k > 0 and (k - 1) ** 3 >= n:
            k -= 1
        return k

    np.random.seed(seed)
    box_length = (n_atoms / density) ** (1.0 / 3.0)
    # Exact integer ceil-cbrt: the float form ceil(27 ** (1/3)) == 4 would leave
    # a third of the box empty for perfect cubes.
    n_side = _ceil_cbrt(n_atoms)
    spacing = box_length / n_side
    # Lattice indices in x-major order (ix slowest, iz fastest), truncated to n_atoms.
    sites = np.indices((n_side, n_side, n_side)).reshape(3, -1).T[:n_atoms]
    positions = (sites + 0.5) * spacing
    positions += np.random.normal(0, 0.05, positions.shape)
    positions = positions % box_length
    charges = np.random.randn(n_atoms) * 0.3
    charges -= charges.mean()
    return positions, charges, box_length
def compute_energy(positions, charges, box_length):
    """O(N^2) reference potential energy (LJ + bare Coulomb, no cutoff).

    Pair separations use the minimum-image convention in a cubic box. Each
    unordered pair is counted once (the full matrix sum is halved). Self-pairs
    are excluded exactly.

    Args:
        positions: (N, 3) array of coordinates.
        charges: (N,) array of charges.
        box_length: edge of the cubic periodic box.

    Returns:
        Total energy 0.5 * sum_{i != j} [4 (r^-12 - r^-6) + q_i q_j / r].
    """
    dr = positions[:, None, :] - positions[None, :, :]
    dr -= box_length * np.floor(dr / box_length + 0.5)  # minimum image
    r2 = np.sum(dr * dr, axis=2)
    # inf (not a large finite number) so every 1/r^k term of a self-pair is exactly 0;
    # a finite placeholder such as 1e10 still leaked q_i**2 / 1e5 into the Coulomb sum.
    np.fill_diagonal(r2, np.inf)
    r = np.sqrt(r2)
    r6_inv = (1.0 / r2) ** 3
    lj = 4.0 * (r6_inv ** 2 - r6_inv)
    coul = (charges[:, None] * charges[None, :]) / r
    return 0.5 * (np.sum(lj) + np.sum(coul))
def direct_forces(positions, charges, box_length):
    """Brute-force O(N^2) reference forces (LJ + bare Coulomb, no cutoff).

    For every ordered pair (i, j) with minimum-image displacement
    d_ij = x_i - x_j, the force on i is

        F_i = sum_j [24 r^-2 (2 r^-12 - r^-6) + q_i q_j r^-3] * d_ij.

    Self-pairs get a placeholder squared distance of 1e10; their displacement
    is zero, so they contribute nothing.

    Returns:
        (N, 3) array of forces.
    """
    disp = positions[:, None, :] - positions[None, :, :]
    disp -= box_length * np.floor(disp / box_length + 0.5)
    dist_sq = (disp * disp).sum(axis=2)
    np.fill_diagonal(dist_sq, 1e10)
    dist = np.sqrt(dist_sq)
    inv_sq = 1.0 / dist_sq
    inv_6 = inv_sq ** 3
    lj_scale = 24.0 * inv_sq * (2.0 * inv_6 ** 2 - inv_6)
    coulomb_scale = np.outer(charges, charges) / (dist_sq * dist)
    pair_scale = lj_scale + coulomb_scale
    return (pair_scale[:, :, None] * disp).sum(axis=1)
# ---- Reciprocal-space: u-series Gaussian convolution ----
def gaussian_decomposition(alpha, n_gauss):
    """Decompose the smooth Ewald kernel erf(alpha*r)/r into a sum of Gaussians.

    Uses the identity

        erf(alpha*r)/r = (2/sqrt(pi)) * integral_0^alpha exp(-t^2 r^2) dt

    and approximates the integral with n_gauss-point Gauss-Legendre quadrature
    mapped from [-1, 1] onto [0, alpha]:

        erf(alpha*r)/r ~= sum_g weights[g] * exp(-exponents[g] * r^2).

    Only the long-range erf part is represented here; the short-range
    erfc(alpha*r)/r part of 1/r is handled in real space.

    Args:
        alpha: Ewald splitting parameter (upper limit of the t-integral).
        n_gauss: number of quadrature nodes (positive int).

    Returns:
        (exponents, weights): float64 arrays of length n_gauss, with
        exponents = t_g**2 and weights = (2/sqrt(pi)) * quadrature weight.
    """
    nodes, gl_weights = np.polynomial.legendre.leggauss(n_gauss)
    # Affine map of the nodes/weights from [-1, 1] to [0, alpha].
    t = alpha / 2.0 * (nodes + 1.0)
    w = alpha / 2.0 * gl_weights
    return t ** 2, (2.0 / np.sqrt(np.pi)) * w
def make_conv_kernel(K, h, exponent):
    """Build the K x K circulant Gaussian convolution matrix on a periodic 1-D grid.

    Entry (i, j) is exp(-exponent * (d * h)**2), where d = min(|i-j|, K-|i-j|)
    is the minimum-image grid distance. The matrix is symmetric, so
    left- and right-multiplication both convolve along the chosen axis.

    Args:
        K: number of grid points.
        h: grid spacing.
        exponent: Gaussian exponent (coefficient of r^2).

    Returns:
        (K, K) float64 array.
    """
    idx = np.arange(K)
    sep = np.abs(idx[:, None] - idx[None, :])
    d = np.minimum(sep, K - sep)  # periodic (wrap-around) distance in grid units
    return np.exp(-exponent * (d * h) ** 2).astype(np.float64)
def bspline4_weights_vec(u):
    """Cubic (order-4) B-spline weights for charge spreading.

    For fractional offsets ``u`` in [0, 1) this returns an array of shape
    ``u.shape + (4,)``. Column k is the weight of grid point floor(s) - 1 + k.
    The four weights sum to 1, and the first one is derived from that
    identity.
    """
    sq = u * u
    cube = sq * u
    far_right = cube / 6.0
    near_right = (1.0 + 3.0 * u * (1.0 + u * (1.0 - u))) / 6.0
    near_left = (4.0 - 6.0 * sq + 3.0 * cube) / 6.0
    far_left = 1.0 - near_left - near_right - far_right
    columns = (far_left, near_left, near_right, far_right)
    return np.stack(columns, axis=-1)
def bspline4_dweights_vec(u):
    """Derivatives d/du of the cubic (order-4) B-spline weights.

    This is the companion of the charge-spreading weights and is used for
    force interpolation. It returns ``u.shape + (4,)``, with column k
    matching grid point floor(s) - 1 + k. The derivatives sum to zero, and
    the first column is obtained from that identity.
    """
    sq = u * u
    d_far_right = sq / 2.0
    d_near_right = (3.0 * (1.0 + 2.0 * u - 3.0 * sq)) / 6.0
    d_near_left = (-12.0 * u + 9.0 * sq) / 6.0
    d_far_left = -(d_near_left + d_near_right + d_far_right)
    columns = (d_far_left, d_near_left, d_near_right, d_far_right)
    return np.stack(columns, axis=-1)
def spread_charges(positions, charges, box_length, K, order=4):
    """Spread point charges onto a periodic K x K x K grid with order-4 B-splines.

    Atom i deposits charges[i] * Wx[a] * Wy[b] * Wz[c] onto grid point
    (floor(s_i) - 1 + (a, b, c)) mod K for a, b, c in 0..3, where
    s_i = positions[i] / h and h = box_length / K. Each 1-D weight stencil
    sums to 1, so the grid total equals sum(charges).

    Args:
        positions: (N, 3) atom positions.
        charges: (N,) atom charges.
        box_length: edge of the cubic periodic box.
        K: grid points per dimension.
        order: B-spline order; only 4 is implemented.

    Returns:
        (K, K, K) float64 charge grid.

    Raises:
        ValueError: if ``order`` is not 4. The weight helpers are order-4 only,
            so other values would silently mismatch stencil and weights
            (dropping charge for order < 4).
    """
    if order != 4:
        raise ValueError(f"only order=4 B-splines are implemented, got order={order}")
    h = box_length / K
    s = positions / h
    base = np.floor(s)
    g0 = base.astype(int) - (order // 2 - 1)
    frac = s - base
    wx = bspline4_weights_vec(frac[:, 0])
    wy = bspline4_weights_vec(frac[:, 1])
    wz = bspline4_weights_vec(frac[:, 2])
    w3d = wx[:, :, None, None] * wy[:, None, :, None] * wz[:, None, None, :]
    w3d *= charges[:, None, None, None]
    offsets = np.arange(order)
    gx = (g0[:, 0, None] + offsets[None, :]) % K
    gy = (g0[:, 1, None] + offsets[None, :]) % K
    gz = (g0[:, 2, None] + offsets[None, :]) % K
    # Linear grid index of every (atom, a, b, c) stencil point: shape (N, 4, 4, 4).
    flat = (gx[:, :, None, None] * K + gy[:, None, :, None]) * K + gz[:, None, None, :]
    # One scatter-add for the whole stencil. Laying the data out as (a, b, c, atom)
    # keeps the accumulation order of a per-stencil-point loop over all atoms.
    grid = np.bincount(flat.transpose(1, 2, 3, 0).ravel(),
                       weights=w3d.transpose(1, 2, 3, 0).ravel(),
                       minlength=K ** 3)
    return grid.reshape(K, K, K)
def interpolate_forces_bspline(positions, potential_grid, box_length, order=4):
    """Interpolate -grad(phi) at each atom from a periodic potential grid.

    Uses the same order-4 B-spline stencil as charge spreading. In each
    dimension, atom i touches grid points floor(s_i) - 1 .. floor(s_i) + 2
    (wrapped modulo K), where s_i = positions[i] / h and
    h = box_length / K. The x-derivative is

        d(phi)/dx = sum_{a,b,c} dW_a/du * W_b * W_c * phi[a, b, c] / h,

    and y and z are analogous.

    Despite the name, this returns the field E = -grad(phi), not a force.
    The physical force on an atom with charge q is q * E.

    Args:
        positions: (N, 3) atom positions.
        potential_grid: (K, K, K) potential on the periodic grid.
        box_length: edge of the cubic periodic box.
        order: B-spline order; only 4 is implemented.

    Returns:
        (N, 3) float64 array of (-d(phi)/dx, -d(phi)/dy, -d(phi)/dz).

    Raises:
        ValueError: if ``order`` is not 4. The weight helpers are order-4 only,
            so other values would silently mix stencil sizes and weights.
    """
    if order != 4:
        raise ValueError(f"only order=4 B-splines are implemented, got order={order}")
    K = potential_grid.shape[0]
    h = box_length / K
    s = positions / h
    base = np.floor(s)
    g0 = base.astype(int) - (order // 2 - 1)
    frac = s - base
    w = [bspline4_weights_vec(frac[:, d]) for d in range(3)]
    dw = [bspline4_dweights_vec(frac[:, d]) for d in range(3)]
    offsets = np.arange(order)
    gx, gy, gz = ((g0[:, d, None] + offsets[None, :]) % K for d in range(3))
    # Potential at every stencil point of every atom: shape (N, 4, 4, 4).
    phi = potential_grid[gx[:, :, None, None], gy[:, None, :, None], gz[:, None, None, :]]
    inv_h = 1.0 / h
    forces = np.empty((len(positions), 3))
    forces[:, 0] = -inv_h * np.einsum("ni,nj,nk,nijk->n", dw[0], w[1], w[2], phi)
    forces[:, 1] = -inv_h * np.einsum("ni,nj,nk,nijk->n", w[0], dw[1], w[2], phi)
    forces[:, 2] = -inv_h * np.einsum("ni,nj,nk,nijk->n", w[0], w[1], dw[2], phi)
    return forces
@ttl.kernel(grid="auto")
def xy_conv_kernel(charge_grid, kernels, potential_grid):
    """U-series xy-convolution: M @ charge_slice @ M for each Gaussian component.
    Replaces the FFT in traditional PME. Separable Gaussian decomposition
    means we can convolve with matrix multiplies instead of FFTs.

    Layout: charge_grid and potential_grid are stacks of single tiles along
    dim 0, where tile z (``charge_grid[z, 0]``) is one z-slice. ``kernels`` is a
    stack of single tiles (``kernels[g, 0]``), one per convolution matrix M_g.

    For every slice the device writes

        potential_grid[z] = sum_g (M_g @ charge_grid[z]) @ M_g,

    i.e. the contributions of ALL kernel tiles are summed into one output
    slice. A caller that needs the xy-result of each component separately must
    pass a single-tile ``kernels`` tensor.

    Work split: z-slices are assigned in contiguous chunks of
    ceil(total_slices / grid_cols) per core column. Surplus iterations on the
    last core are skipped by the ``z < total_slices`` guard.
    """
    grid_cols, _ = ttl.grid_size(dims=2)
    total_slices = charge_grid.shape[0] // TILE
    n_kernels = kernels.shape[0] // TILE
    slices_per_core = -(-total_slices // grid_cols)  # ceil division
    cg_cb = ttl.make_dataflow_buffer_like(charge_grid, shape=(1, 1), buffer_factor=2)
    km_cb = ttl.make_dataflow_buffer_like(kernels, shape=(1, 1), buffer_factor=2)
    tmp_cb = ttl.make_dataflow_buffer_like(charge_grid, shape=(1, 1), buffer_factor=2)
    acc_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=2)
    par_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=2)
    out_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=1)

    @ttl.compute()
    def compute():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            if z < total_slices:
                with cg_cb.wait() as cg:
                    # First kernel tile initialises the accumulator: acc = (M @ cg) @ M.
                    with km_cb.wait() as M:
                        with tmp_cb.reserve() as t:
                            t.store(M @ cg)
                        with tmp_cb.wait() as tv:
                            with acc_cb.reserve() as a:
                                a.store(tv @ M)
                    # Remaining kernel tiles: acc += (M @ cg) @ M.
                    for g in range(n_kernels - 1):
                        with km_cb.wait() as M:
                            with tmp_cb.reserve() as t:
                                t.store(M @ cg)
                            with tmp_cb.wait() as tv:
                                with par_cb.reserve() as p:
                                    p.store(tv @ M)
                        with par_cb.wait() as pv, acc_cb.wait() as av:
                            with acc_cb.reserve() as a:
                                a.store(av + pv)
                    # Hand the finished slice to the writer.
                    with acc_cb.wait() as final, out_cb.reserve() as o:
                        o.store(final)

    @ttl.datamovement()
    def dm_read():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            if z < total_slices:
                with cg_cb.reserve() as blk:
                    tx = ttl.copy(charge_grid[z, 0], blk); tx.wait()
                # Every kernel tile is streamed again for each slice, in the
                # order the compute loop consumes them.
                for g in range(n_kernels):
                    with km_cb.reserve() as blk:
                        tx = ttl.copy(kernels[g, 0], blk); tx.wait()

    @ttl.datamovement()
    def dm_write():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            if z < total_slices:
                with out_cb.wait() as blk:
                    tx = ttl.copy(blk, potential_grid[z, 0]); tx.wait()
def compute_reciprocal_forces(device, positions, charges, box_length,
                              alpha=ALPHA, n_gauss=N_GAUSS, K=K_GRID):
    """Reciprocal-space (long-range) Ewald forces via u-series Gaussian convolution.

    The smooth kernel erf(alpha*r)/r is approximated as sum_g w_g exp(-e_g r^2)
    (see gaussian_decomposition). Each Gaussian is separable, so the grid
    potential is

        phi = sum_g w_g * (M_g along x)(M_g along y)(M_g along z) rho,

    where M_g is the circulant 1-D Gaussian matrix of component g. The x, y
    and z convolutions of one term must use the SAME component g, and w_g is
    applied exactly once.

    Pipeline:
      1. Spread the charges onto the grid (host).
      2. For each component, xy-convolve every z-slice on the device
         (xy_conv_kernel with a single kernel tile, so nothing is summed
         across components on the device).
      3. z-convolve on the host, weight by w_g and accumulate.
      4. Interpolate E = -grad(phi) at each atom and return F = q * E.

    Args:
        device: ttnn device for the xy-convolution.
        positions: (N, 3) atom positions.
        charges: (N,) atom charges.
        box_length: edge of the cubic periodic box.
        alpha: Ewald splitting parameter.
        n_gauss: number of Gaussian components.
        K: grid points per dimension; must equal TILE.

    Returns:
        (N, 3) float64 array of reciprocal-space forces.

    Raises:
        ValueError: if K != TILE. Each K x K z-slice is packed into exactly one
            tile, and zero-padding would break the periodic (circulant) kernels.
    """
    if K != TILE:
        raise ValueError(f"K must equal TILE ({TILE}) so each z-slice fills one tile, got K={K}")
    h = box_length / K
    exponents, weights = gaussian_decomposition(alpha, n_gauss)
    charge_grid_3d = spread_charges(positions, charges, box_length, K)

    def to_tt(arr):
        return ttnn.from_torch(
            torch.tensor(arr, dtype=torch.bfloat16),
            dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT,
            device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    # Pack the charge grid into tile layout: z-slice z occupies rows
    # [z*TILE, (z+1)*TILE), with x along rows and y along columns.
    cg_np = charge_grid_3d.transpose(2, 0, 1).reshape(K * TILE, TILE).astype(np.float32)
    cg_tt = to_tt(cg_np)

    potential_3d = np.zeros((K, K, K))
    for g in range(n_gauss):
        # Unweighted Gaussian (entries in (0, 1], friendlier to bf16); w_g is applied once below.
        M_g = make_conv_kernel(K, h, exponents[g])
        km_tt = to_tt(M_g.astype(np.float32))
        pot_tt = to_tt(np.zeros((K * TILE, TILE), dtype=np.float32))
        # XY-convolution on device: pot[z] = M_g @ rho[:, :, z] @ M_g.
        xy_conv_kernel(cg_tt, km_tt, pot_tt)
        pot_np = ttnn.to_torch(pot_tt).float().numpy()
        # Unpack back to [x, y, z] order.
        xy_conv_3d = pot_np[:K * TILE, :TILE].reshape(K, TILE, TILE).transpose(1, 2, 0)
        # Z-convolution on host with the same component's kernel (M_g is symmetric),
        # contracting the trailing z axis.
        potential_3d += weights[g] * (xy_conv_3d @ M_g)

    # interpolate_forces_bspline returns E = -grad(phi); the force on charge q is q * E.
    forces_recip = interpolate_forces_bspline(positions, potential_3d, box_length)
    forces_recip *= charges[:, None]
    return forces_recip
# ---- Cell-list data packing ----
def build_cell_data(positions, charges, box_length, r_cut):
    """Assign atoms to cells, pack them into tile arrays and build pair-exclusion masks.

    The box is split into n_cells_dim**3 cubic cells with
    n_cells_dim = max(3, int(box_length / r_cut)). Cells are therefore at least
    r_cut wide, unless the 3-cells-per-dimension minimum applies. Cell c owns
    tile rows [c*TILE, (c+1)*TILE). Its k-th atom, in stable index order, is
    stored at row c*TILE + k, column 0.

    Args:
        positions: (N, 3) atom positions.
        charges: (N,) atom charges.
        box_length: edge of the cubic periodic box.
        r_cut: target minimum cell width.

    Returns:
        own_px, own_py, own_pz, own_q: (n_cells_total*TILE, TILE) float32
            arrays, zero except column 0 of occupied rows.
        masks: (n_cells_total*N_NBR*TILE, TILE) float32. Tile (c*N_NBR + nbr)
            holds 1e6 at (row, col) entries that must not interact (empty own
            slot, empty neighbor slot, or an atom paired with itself) and 0
            elsewhere. The force kernel adds this to r^2. Neighbor order is
            dx-major over offsets -1..1, then dy, then dz.
        cell_atom_map: list of per-cell lists of atom indices.
        n_cells_total, n_cells_dim: total cells and cells per dimension.

    Raises:
        ValueError: if any cell holds more than TILE atoms. Those atoms could
            not be represented on the device and would be lost: after the next
            rebuild they would reappear at the origin with zero velocity.
    """
    n = len(positions)
    n_cells_dim = max(3, int(box_length / r_cut))
    n_cells_total = n_cells_dim ** 3
    cell_size = box_length / n_cells_dim
    # The modulo also folds positions that round to exactly box_length back into cell 0.
    cidx = np.floor(positions / cell_size).astype(int) % n_cells_dim
    cell_id = cidx[:, 0] * n_cells_dim**2 + cidx[:, 1] * n_cells_dim + cidx[:, 2]
    cell_counts = np.bincount(cell_id, minlength=n_cells_total)
    if cell_counts.max() > TILE:
        n_over = int(np.count_nonzero(cell_counts > TILE))
        raise ValueError(
            f"{n_over} cell(s) hold more than TILE={TILE} atoms "
            f"(max {int(cell_counts.max())}); the excess atoms would be dropped. "
            "Use a smaller r_cut or a lower density.")
    sort_idx = np.argsort(cell_id, kind='stable')
    sorted_cell_id = cell_id[sort_idx]
    cell_starts = np.zeros(n_cells_total + 1, dtype=int)
    np.cumsum(cell_counts, out=cell_starts[1:])
    # Slot of each sorted atom inside its cell (0..TILE-1 after the capacity check).
    local_idx = np.arange(n) - cell_starts[sorted_cell_id]
    cell_atom_map = [sort_idx[cell_starts[c]:cell_starts[c+1]].tolist()
                     for c in range(n_cells_total)]
    rows = sorted_cell_id * TILE + local_idx
    own_px = np.zeros((n_cells_total * TILE, TILE), dtype=np.float32)
    own_py = np.zeros_like(own_px)
    own_pz = np.zeros_like(own_px)
    own_q = np.zeros_like(own_px)
    own_px[rows, 0] = positions[sort_idx, 0]
    own_py[rows, 0] = positions[sort_idx, 1]
    own_pz[rows, 0] = positions[sort_idx, 2]
    own_q[rows, 0] = charges[sort_idx]
    # Exclusion masks: 1e6 for empty slots and self-pairs.
    # Offset order (dx outer, dz inner) matches the neighbor walk on the device.
    offsets = np.array([(dx, dy, dz)
                        for dx in range(-1, 2) for dy in range(-1, 2) for dz in range(-1, 2)])
    cell_3d = np.stack(np.unravel_index(np.arange(n_cells_total),
                                        (n_cells_dim, n_cells_dim, n_cells_dim)), axis=-1)
    nbr_3d = (cell_3d[:, None, :] + offsets[None, :, :]) % n_cells_dim
    nbr_cid = (nbr_3d[:, :, 0] * n_cells_dim**2 +
               nbr_3d[:, :, 1] * n_cells_dim + nbr_3d[:, :, 2])
    own_cnt = np.minimum(cell_counts, TILE)
    nbr_cnt = own_cnt[nbr_cid]
    is_self = (nbr_cid == np.arange(n_cells_total)[:, None])
    row_idx = np.arange(TILE)[None, None, :, None]
    col_idx = np.arange(TILE)[None, None, None, :]
    oc = own_cnt[:, None, None, None]
    nc = nbr_cnt[:, :, None, None]
    masks_4d = np.where(
        (row_idx >= oc) | (col_idx >= nc) | (is_self[:, :, None, None] & (row_idx == col_idx)),
        np.float32(1e6), np.float32(0.0))
    masks = masks_4d.reshape(n_cells_total * N_NBR * TILE, TILE)
    return own_px, own_py, own_pz, own_q, masks, cell_atom_map, n_cells_total, n_cells_dim
def pack_cell_layout(data_3col, cell_atom_map, n_cells_total):
    """Scatter per-atom (N, 3) values into three cell-layout tile arrays.

    Cell c owns rows [c*TILE, (c+1)*TILE). The k-th atom listed for that cell
    is written to row c*TILE + k, column 0. Atoms past the first TILE of a cell
    are skipped, and every other entry stays 0.

    Returns:
        (ax, ay, az): float32 arrays of shape (n_cells_total*TILE, TILE).
    """
    rows = []
    atom_ids = []
    for cid, members in enumerate(cell_atom_map):
        kept = members[:TILE]
        first_row = cid * TILE
        rows.extend(range(first_row, first_row + len(kept)))
        atom_ids.extend(kept)
    packed = tuple(np.zeros((n_cells_total * TILE, TILE), dtype=np.float32) for _ in range(3))
    for axis, arr in enumerate(packed):
        arr[rows, 0] = data_3col[atom_ids, axis]
    return packed
def extract_cell_data(dx_np, dy_np, dz_np, cell_atom_map, n_atoms):
    """Gather cell-layout tile arrays back into per-atom (N, 3) order.

    This is the inverse of packing. Row c*TILE + k, column 0 of each input
    belongs to the k-th atom of cell c. Atoms that were not stored (past TILE
    in their cell, or absent from the map) keep the value 0.

    Returns:
        (n_atoms, 3) float64 array.
    """
    rows = []
    atom_ids = []
    for cid, members in enumerate(cell_atom_map):
        kept = members[:TILE]
        first_row = cid * TILE
        rows.extend(range(first_row, first_row + len(kept)))
        atom_ids.extend(kept)
    result = np.zeros((n_atoms, 3))
    for axis, src in enumerate((dx_np, dy_np, dz_np)):
        result[atom_ids, axis] = src[rows, 0]
    return result
# ---- On-device Verlet MD loop ----
def run_md_loop(device, positions, velocities, charges, box_length,
                n_steps=10, dt=0.005, rebuild_every=100, alpha=ALPHA):
    """Run MD with velocity Verlet integration on Tenstorrent hardware.
    Full Ewald forces: real-space (erfc-damped LJ+Coulomb, bf16 TT kernel) +
    reciprocal-space (u-series Gaussian convolution, bf16 TT kernel + host).
    Positions and velocities stored in f32 for integration precision.
    Between rebuilds, all updates happen on device with zero host copies (~1ms/step).
    Reciprocal forces are recomputed on each rebuild and held constant between.

    Between rebuilds every atom stays in the cell it was assigned to at the
    last rebuild. The real-space kernel visits the 27 surrounding cells and
    applies no explicit distance cutoff.

    Args:
        device: ttnn device; every tensor below is created on it.
        positions: (N, 3) host positions.
        velocities: (N, 3) host velocities.
        charges: (N,) host charges.
        box_length: edge of the cubic periodic box.
        n_steps: number of Verlet steps.
        dt: timestep.
        rebuild_every: period (in steps) of the host-side cell rebuild, which
            also recomputes the reciprocal forces. Step 0 never rebuilds. Must
            be nonzero (it is used as a modulus).
        alpha: Ewald splitting parameter.

    Returns:
        (final_pos, final_vel, step_times): (N, 3) float64 positions and
        velocities read back from the device, and a list of per-step
        wall-clock times in seconds.
        NOTE(review): whether ttnn/ttl launches block until the device
        finishes is not visible here, so step_times may measure launch time
        only. Confirm before quoting them as kernel timings.
    """
    n = len(positions)
    # Cells are at least r_cut wide (see build_cell_data), so the 27-cell
    # stencil covers every pair closer than r_cut.
    r_cut = min(box_length / 2.0 - 0.1, 3.0 / alpha)
    # Constants captured by kernel closures
    c_box = float(box_length)
    c_inv_box = 1.0 / float(box_length)
    c_half = 0.5
    c_dt_half = 0.5 * dt
    c_dt = float(dt)
    c_lj_scale = 24.0
    c_alpha_sq = float(alpha * alpha)
    c_p_alpha = float(ERFC_P * alpha)
    c_two_a_sp = float(2.0 * alpha / np.sqrt(np.pi))
    # The negative coefficients A2/A4 are captured as positive values and
    # negated inside the kernel with ttl.math.neg.
    c_a1 = float(ERFC_A1)
    c_a2 = float(-ERFC_A2)
    c_a3 = float(ERFC_A3)
    c_a4 = float(-ERFC_A4)
    c_a5 = float(ERFC_A5)

    # --- Verlet velocity half-step: vel += 0.5*dt*force (f32) ---
    @ttl.kernel(grid="auto")
    def vel_half_kernel(vel_x, vel_y, vel_z, fx, fy, fz):
        grid_cols, _ = ttl.grid_size(dims=2)
        nc = vel_x.shape[0] // TILE
        cpc = -(-nc // grid_cols)  # cells per core (ceil division)
        v_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2)
        f_cb = ttl.make_dataflow_buffer_like(fx, shape=(1, 1), buffer_factor=2)
        o_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2)

        @ttl.compute()
        def compute():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    # Three (v, f) tile pairs per cell, in x, y, z order (see dm_read).
                    for _ in range(3):
                        with v_cb.wait() as v, f_cb.wait() as f, o_cb.reserve() as o:
                            o.store(v + f * ttl.math.fill(v, c_dt_half))

        @ttl.datamovement()
        def dm_read():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_x[cid, 0], blk); tx.wait()
                    with f_cb.reserve() as blk:
                        tx = ttl.copy(fx[cid, 0], blk); tx.wait()
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_y[cid, 0], blk); tx.wait()
                    with f_cb.reserve() as blk:
                        tx = ttl.copy(fy[cid, 0], blk); tx.wait()
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_z[cid, 0], blk); tx.wait()
                    with f_cb.reserve() as blk:
                        tx = ttl.copy(fz[cid, 0], blk); tx.wait()

        @ttl.datamovement()
        def dm_write():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    # Results are written back in place over the velocity tensors.
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, vel_x[cid, 0]); tx.wait()
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, vel_y[cid, 0]); tx.wait()
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, vel_z[cid, 0]); tx.wait()

    # --- Verlet position update: pos += dt*vel, PBC wrap (f32) ---
    @ttl.kernel(grid="auto")
    def pos_update_kernel(pos_x, pos_y, pos_z, vel_x, vel_y, vel_z):
        grid_cols, _ = ttl.grid_size(dims=2)
        nc = pos_x.shape[0] // TILE
        cpc = -(-nc // grid_cols)  # cells per core (ceil division)
        p_cb = ttl.make_dataflow_buffer_like(pos_x, shape=(1, 1), buffer_factor=2)
        v_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2)
        o_cb = ttl.make_dataflow_buffer_like(pos_x, shape=(1, 1), buffer_factor=2)

        @ttl.compute()
        def compute():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    for _ in range(3):
                        with p_cb.wait() as p, v_cb.wait() as v, o_cb.reserve() as o:
                            new_p = p + v * ttl.math.fill(p, c_dt)
                            # Wrap into [0, L): p - L * floor(p / L).
                            o.store(new_p - ttl.math.fill(p, c_box) * ttl.math.floor(new_p * ttl.math.fill(p, c_inv_box)))

        @ttl.datamovement()
        def dm_read():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    with p_cb.reserve() as blk:
                        tx = ttl.copy(pos_x[cid, 0], blk); tx.wait()
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_x[cid, 0], blk); tx.wait()
                    with p_cb.reserve() as blk:
                        tx = ttl.copy(pos_y[cid, 0], blk); tx.wait()
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_y[cid, 0], blk); tx.wait()
                    with p_cb.reserve() as blk:
                        tx = ttl.copy(pos_z[cid, 0], blk); tx.wait()
                    with v_cb.reserve() as blk:
                        tx = ttl.copy(vel_z[cid, 0], blk); tx.wait()

        @ttl.datamovement()
        def dm_write():
            core_x, _ = ttl.core(dims=2)
            for lc in range(cpc):
                cid = core_x * cpc + lc
                if cid < nc:
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, pos_x[cid, 0]); tx.wait()
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, pos_y[cid, 0]); tx.wait()
                    with o_cb.wait() as blk:
                        tx = ttl.copy(blk, pos_z[cid, 0]); tx.wait()

    # --- Force kernel: fused geometry + LJ + erfc Coulomb (bf16, 28 CBs) ---
    # Factory so the kernel can be rebuilt if the cell grid dimensions change
    # (c_n_dim / c_dim2 are captured by the on-device neighbor computation).
    def make_force_kernel(c_n_dim, c_dim2):
        @ttl.kernel(grid="auto")
        def cell_forces_kernel(own_px, own_py, own_pz, own_q,
                               self_mask, scaler,
                               fx_out, fy_out, fz_out):
            grid_cols, _ = ttl.grid_size(dims=2)
            n_cells = own_px.shape[0] // TILE
            cells_per_core = -(-n_cells // grid_cols)
            # Own-cell inputs (column 0 = atom data).
            ox_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            oy_cb = ttl.make_dataflow_buffer_like(own_py, shape=(1, 1), buffer_factor=2)
            oz_cb = ttl.make_dataflow_buffer_like(own_pz, shape=(1, 1), buffer_factor=2)
            oq_cb = ttl.make_dataflow_buffer_like(own_q, shape=(1, 1), buffer_factor=2)
            # Neighbor-cell inputs, exclusion mask and reduction scaler.
            ex_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            ey_cb = ttl.make_dataflow_buffer_like(own_py, shape=(1, 1), buffer_factor=2)
            ez_cb = ttl.make_dataflow_buffer_like(own_pz, shape=(1, 1), buffer_factor=2)
            eq_cb = ttl.make_dataflow_buffer_like(own_q, shape=(1, 1), buffer_factor=2)
            sm_cb = ttl.make_dataflow_buffer_like(self_mask, shape=(1, 1), buffer_factor=2)
            sc_cb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2)
            # Intermediates.
            ba_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            tr_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            bb_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            r2_tmp = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            r2_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            qq_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            dx_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            dy_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            dz_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            fm_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            ft_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            fr_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2)
            # Per-cell force accumulators and outputs.
            ax_cb = ttl.make_dataflow_buffer_like(fx_out, shape=(1, 1), buffer_factor=2)
            ay_cb = ttl.make_dataflow_buffer_like(fy_out, shape=(1, 1), buffer_factor=2)
            az_cb = ttl.make_dataflow_buffer_like(fz_out, shape=(1, 1), buffer_factor=2)
            fxo_cb = ttl.make_dataflow_buffer_like(fx_out, shape=(1, 1), buffer_factor=2)
            fyo_cb = ttl.make_dataflow_buffer_like(fy_out, shape=(1, 1), buffer_factor=2)
            fzo_cb = ttl.make_dataflow_buffer_like(fz_out, shape=(1, 1), buffer_factor=2)

            @ttl.compute()
            def compute():
                core_x, _ = ttl.core(dims=2)
                for local_c in range(cells_per_core):
                    cell_id = core_x * cells_per_core + local_c
                    if cell_id < n_cells:
                        with ox_cb.wait() as ox, oy_cb.wait() as oy, oz_cb.wait() as oz, oq_cb.wait() as oq:
                            with sc_cb.wait() as sc:
                                for nbr_i in range(N_NBR):
                                    with ex_cb.wait() as ex, ey_cb.wait() as ey, ez_cb.wait() as ez, eq_cb.wait() as eq, sm_cb.wait() as sm:
                                        # PBC displacements for x, y, z
                                        # Pairwise (own - neighbor) tile: own values (column 0)
                                        # are broadcast across columns; neighbor values are
                                        # transposed into row 0 and broadcast down the rows.
                                        # NOTE(review): relies on ttl broadcast dims semantics.
                                        with ba_cb.reserve() as ba:
                                            ba.store(ttl.math.broadcast(ox, dims=[1]))
                                        with tr_cb.reserve() as tr:
                                            tr.store(ttl.transpose(ex))
                                        with tr_cb.wait() as trv, bb_cb.reserve() as bb:
                                            bb.store(ttl.math.broadcast(trv, dims=[0]))
                                        with ba_cb.wait() as bav, bb_cb.wait() as bbv:
                                            dx_raw = bav - bbv
                                            # Minimum image: d - L * floor(d / L + 0.5).
                                            dx_pbc = dx_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dx_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half))
                                            with r2_tmp.reserve() as r2o:
                                                r2o.store(dx_pbc * dx_pbc)
                                            with dx_cb.reserve() as dxo:
                                                dxo.store(dx_pbc)
                                        with ba_cb.reserve() as ba:
                                            ba.store(ttl.math.broadcast(oy, dims=[1]))
                                        with tr_cb.reserve() as tr:
                                            tr.store(ttl.transpose(ey))
                                        with tr_cb.wait() as trv, bb_cb.reserve() as bb:
                                            bb.store(ttl.math.broadcast(trv, dims=[0]))
                                        with ba_cb.wait() as bav, bb_cb.wait() as bbv, r2_tmp.wait() as r2p:
                                            dy_raw = bav - bbv
                                            dy_pbc = dy_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dy_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half))
                                            with r2_tmp.reserve() as r2o:
                                                r2o.store(r2p + dy_pbc * dy_pbc)
                                            with dy_cb.reserve() as dyo:
                                                dyo.store(dy_pbc)
                                        with ba_cb.reserve() as ba:
                                            ba.store(ttl.math.broadcast(oz, dims=[1]))
                                        with tr_cb.reserve() as tr:
                                            tr.store(ttl.transpose(ez))
                                        with tr_cb.wait() as trv, bb_cb.reserve() as bb:
                                            bb.store(ttl.math.broadcast(trv, dims=[0]))
                                        with ba_cb.wait() as bav, bb_cb.wait() as bbv, r2_tmp.wait() as r2p:
                                            dz_raw = bav - bbv
                                            dz_pbc = dz_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dz_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half))
                                            # "+ sm" adds the 1e6 exclusion mask so that self-pairs and
                                            # empty slots end up at a huge distance (~zero force).
                                            with r2_cb.reserve() as r2o:
                                                r2o.store(r2p + dz_pbc * dz_pbc + sm)
                                            with dz_cb.reserve() as dzo:
                                                dzo.store(dz_pbc)
                                        # Charge products
                                        with ba_cb.reserve() as ba:
                                            ba.store(ttl.math.broadcast(oq, dims=[1]))
                                        with tr_cb.reserve() as tr:
                                            tr.store(ttl.transpose(eq))
                                        with tr_cb.wait() as trv, bb_cb.reserve() as bb:
                                            bb.store(ttl.math.broadcast(trv, dims=[0]))
                                        with ba_cb.wait() as bav, bb_cb.wait() as bbv, qq_cb.reserve() as qqo:
                                            qqo.store(bav * bbv)
                                        # erfc-damped Coulomb + LJ forces
                                        with r2_cb.wait() as r2, qq_cb.wait() as qq:
                                            r_inv = ttl.math.rsqrt(r2)
                                            r2_inv = ttl.math.recip(r2)
                                            r_val = r2 * r_inv
                                            # r_inv * r_inv * r2 is ~1, so t = 1 / (1 + p*alpha*r).
                                            t = ttl.math.recip(r_inv * r_inv * r2 + ttl.math.fill(r2, c_p_alpha) * r_val)
                                            poly = t * (ttl.math.fill(r2, c_a1) + t * (ttl.math.neg(ttl.math.fill(r2, c_a2)) + t * (ttl.math.fill(r2, c_a3) + t * (ttl.math.neg(ttl.math.fill(r2, c_a4)) + t * ttl.math.fill(r2, c_a5)))))
                                            exp_neg = ttl.math.exp(ttl.math.neg(ttl.math.fill(r2, c_alpha_sq) * r2))
                                            erfc_val = poly * exp_neg
                                            # |F|/r for erfc(alpha r)/r:
                                            # qq * (erfc/r^2 + 2 alpha/sqrt(pi) * exp(-alpha^2 r^2)/r) / r
                                            with ft_cb.reserve() as coul:
                                                coul.store(qq * (erfc_val * r2_inv + ttl.math.fill(r2, c_two_a_sp) * exp_neg * r_inv) * r_inv)
                                            r2_inv2 = ttl.math.recip(r2)
                                            r4_inv = r2_inv2 * r2_inv2
                                            r6_inv = r4_inv * r2_inv2
                                            r12_inv = r6_inv * r6_inv
                                            # |F|/r for LJ: 24 / r^2 * (2 / r^12 - 1 / r^6)
                                            with fr_cb.reserve() as lj:
                                                lj.store(ttl.math.fill(r2, c_lj_scale) * r2_inv2 * (r12_inv + r12_inv - r6_inv))
                                        with ft_cb.wait() as fc, fr_cb.wait() as fl:
                                            with fm_cb.reserve() as fmo:
                                                fmo.store(fl + fc)
                                        # Project onto displacements, reduce, accumulate
                                        # Per component: (|F|/r) * d, summed over the pair tile with the
                                        # scaler (all ones). The first neighbor initialises the
                                        # accumulator; later neighbors add to it.
                                        with fm_cb.wait() as fm:
                                            with dx_cb.wait() as dxv:
                                                with ft_cb.reserve() as ft:
                                                    ft.store(fm * dxv)
                                            with ft_cb.wait() as ftv, fr_cb.reserve() as fr:
                                                fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0]))
                                            if nbr_i == 0:
                                                with fr_cb.wait() as frv, ax_cb.reserve() as ax:
                                                    ax.store(frv)
                                            else:
                                                with fr_cb.wait() as frv, ax_cb.wait() as prev:
                                                    with ax_cb.reserve() as ax:
                                                        ax.store(prev + frv)
                                            with dy_cb.wait() as dyv:
                                                with ft_cb.reserve() as ft:
                                                    ft.store(fm * dyv)
                                            with ft_cb.wait() as ftv, fr_cb.reserve() as fr:
                                                fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0]))
                                            if nbr_i == 0:
                                                with fr_cb.wait() as frv, ay_cb.reserve() as ay:
                                                    ay.store(frv)
                                            else:
                                                with fr_cb.wait() as frv, ay_cb.wait() as prev:
                                                    with ay_cb.reserve() as ay:
                                                        ay.store(prev + frv)
                                            with dz_cb.wait() as dzv:
                                                with ft_cb.reserve() as ft:
                                                    ft.store(fm * dzv)
                                            with ft_cb.wait() as ftv, fr_cb.reserve() as fr:
                                                fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0]))
                                            if nbr_i == 0:
                                                with fr_cb.wait() as frv, az_cb.reserve() as az:
                                                    az.store(frv)
                                            else:
                                                with fr_cb.wait() as frv, az_cb.wait() as prev:
                                                    with az_cb.reserve() as az:
                                                        az.store(prev + frv)
                                # All 27 neighbors done: forward the accumulated forces to the writer.
                                with ax_cb.wait() as fx, fxo_cb.reserve() as fxo:
                                    fxo.store(fx)
                                with ay_cb.wait() as fy, fyo_cb.reserve() as fyo:
                                    fyo.store(fy)
                                with az_cb.wait() as fz, fzo_cb.reserve() as fzo:
                                    fzo.store(fz)

            @ttl.datamovement()
            def dm_read():
                core_x, _ = ttl.core(dims=2)
                for local_c in range(cells_per_core):
                    cell_id = core_x * cells_per_core + local_c
                    if cell_id < n_cells:
                        with ox_cb.reserve() as blk:
                            tx = ttl.copy(own_px[cell_id, 0], blk); tx.wait()
                        with oy_cb.reserve() as blk:
                            tx = ttl.copy(own_py[cell_id, 0], blk); tx.wait()
                        with oz_cb.reserve() as blk:
                            tx = ttl.copy(own_pz[cell_id, 0], blk); tx.wait()
                        with oq_cb.reserve() as blk:
                            tx = ttl.copy(own_q[cell_id, 0], blk); tx.wait()
                        with sc_cb.reserve() as blk:
                            tx = ttl.copy(scaler[0, 0], blk); tx.wait()
                        # Compute neighbor cell IDs on-device
                        cx = cell_id // c_dim2
                        cy = (cell_id // c_n_dim) % c_n_dim
                        cz = cell_id % c_n_dim
                        # Offsets are walked dx-major, then dy, then dz. This matches the
                        # order of the mask tiles built on the host (self_mask[cell*27 + nbr]).
                        for nbr in range(N_NBR):
                            off_dx = (nbr // 9) - 1
                            off_dy = ((nbr // 3) % 3) - 1
                            off_dz = (nbr % 3) - 1
                            nbr_cell = ((cx + off_dx + c_n_dim) % c_n_dim) * c_dim2 + ((cy + off_dy + c_n_dim) % c_n_dim) * c_n_dim + ((cz + off_dz + c_n_dim) % c_n_dim)
                            with ex_cb.reserve() as blk:
                                tx = ttl.copy(own_px[nbr_cell, 0], blk); tx.wait()
                            with ey_cb.reserve() as blk:
                                tx = ttl.copy(own_py[nbr_cell, 0], blk); tx.wait()
                            with ez_cb.reserve() as blk:
                                tx = ttl.copy(own_pz[nbr_cell, 0], blk); tx.wait()
                            with eq_cb.reserve() as blk:
                                tx = ttl.copy(own_q[nbr_cell, 0], blk); tx.wait()
                            with sm_cb.reserve() as blk:
                                tx = ttl.copy(self_mask[cell_id * N_NBR + nbr, 0], blk); tx.wait()

            @ttl.datamovement()
            def dm_write():
                core_x, _ = ttl.core(dims=2)
                for local_c in range(cells_per_core):
                    cell_id = core_x * cells_per_core + local_c
                    if cell_id < n_cells:
                        with fxo_cb.wait() as blk:
                            tx = ttl.copy(blk, fx_out[cell_id, 0]); tx.wait()
                        with fyo_cb.wait() as blk:
                            tx = ttl.copy(blk, fy_out[cell_id, 0]); tx.wait()
                        with fzo_cb.wait() as blk:
                            tx = ttl.copy(blk, fz_out[cell_id, 0]); tx.wait()
        return cell_forces_kernel

    # --- Tensor helpers ---
    # Host numpy array -> tile-layout device tensor in L1 or DRAM.
    def to_bf16(arr, l1=False):
        return ttnn.from_torch(
            torch.tensor(arr, dtype=torch.bfloat16),
            dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG if l1 else ttnn.DRAM_MEMORY_CONFIG)

    def to_f32(arr, l1=False):
        return ttnn.from_torch(
            torch.tensor(arr, dtype=torch.float32),
            dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG if l1 else ttnn.DRAM_MEMORY_CONFIG)

    def run_force_kernel(tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler,
                         tt_fx, tt_fy, tt_fz):
        """Typecast f32 <-> bf16 around the bf16 force kernel."""
        # cell_forces_kernel is read from the enclosing scope at call time, so a
        # kernel rebuilt during a cell rebuild is picked up automatically.
        px_bf16 = ttnn.typecast(tt_px, ttnn.bfloat16)
        py_bf16 = ttnn.typecast(tt_py, ttnn.bfloat16)
        pz_bf16 = ttnn.typecast(tt_pz, ttnn.bfloat16)
        fx_bf16 = ttnn.typecast(tt_fx, ttnn.bfloat16)
        fy_bf16 = ttnn.typecast(tt_fy, ttnn.bfloat16)
        fz_bf16 = ttnn.typecast(tt_fz, ttnn.bfloat16)
        cell_forces_kernel(px_bf16, py_bf16, pz_bf16, tt_q,
                           tt_masks, tt_scaler, fx_bf16, fy_bf16, fz_bf16)
        return (ttnn.typecast(fx_bf16, ttnn.float32),
                ttnn.typecast(fy_bf16, ttnn.float32),
                ttnn.typecast(fz_bf16, ttnn.float32))

    # --- Build initial state ---
    (own_px, own_py, own_pz, own_q,
     masks, cell_atom_map, n_cells_total, n_cells_dim) = \
        build_cell_data(positions, charges, box_length, r_cut)
    c_n_dim = int(n_cells_dim)
    c_dim2 = c_n_dim * c_n_dim
    vel_x, vel_y, vel_z = pack_cell_layout(velocities, cell_atom_map, n_cells_total)
    # Reciprocal forces (recomputed on each rebuild, held constant between)
    f_recip = compute_reciprocal_forces(device, positions, charges, box_length, alpha)
    recip_x, recip_y, recip_z = pack_cell_layout(f_recip, cell_atom_map, n_cells_total)
    # L1 for small systems, DRAM for large
    # Heuristic: roughly ten f32 per-cell arrays must fit under ~80 MB.
    cell_bytes_f32 = n_cells_total * TILE * TILE * 4
    use_l1 = (cell_bytes_f32 * 10 < 80_000_000)
    tt_px = to_f32(own_px, l1=use_l1)
    tt_py = to_f32(own_py, l1=use_l1)
    tt_pz = to_f32(own_pz, l1=use_l1)
    tt_q = to_bf16(own_q, l1=use_l1)
    tt_vx = to_f32(vel_x, l1=use_l1)
    tt_vy = to_f32(vel_y, l1=use_l1)
    tt_vz = to_f32(vel_z, l1=use_l1)
    tt_masks = to_bf16(masks)
    # Tile of ones used as the scaler operand of reduce_sum.
    tt_scaler = to_bf16(np.ones((TILE, TILE), dtype=np.float32), l1=True)
    zeros = np.zeros((n_cells_total * TILE, TILE), dtype=np.float32)
    tt_fx = to_f32(zeros, l1=use_l1)
    tt_fy = to_f32(zeros.copy(), l1=use_l1)
    tt_fz = to_f32(zeros.copy(), l1=use_l1)
    # Reciprocal forces stored on device (f32, added after real-space)
    tt_rx = to_f32(recip_x, l1=use_l1)
    tt_ry = to_f32(recip_y, l1=use_l1)
    tt_rz = to_f32(recip_z, l1=use_l1)
    cell_forces_kernel = make_force_kernel(c_n_dim, c_dim2)
    # Initial forces (real-space + reciprocal)
    tt_fx, tt_fy, tt_fz = run_force_kernel(
        tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler, tt_fx, tt_fy, tt_fz)
    tt_fx = ttnn.add(tt_fx, tt_rx)
    tt_fy = ttnn.add(tt_fy, tt_ry)
    tt_fz = ttnn.add(tt_fz, tt_rz)
    # --- Step loop ---
    step_times = []
    for step in range(n_steps):
        t0 = time.time()
        if (step > 0) and (step % rebuild_every == 0):
            # Host round-trip: pull state back in per-atom order, re-bin the atoms
            # into cells, and re-upload everything in the new cell layout.
            px_np = ttnn.to_torch(tt_px).float().numpy()
            py_np = ttnn.to_torch(tt_py).float().numpy()
            pz_np = ttnn.to_torch(tt_pz).float().numpy()
            vx_np = ttnn.to_torch(tt_vx).float().numpy()
            vy_np = ttnn.to_torch(tt_vy).float().numpy()
            vz_np = ttnn.to_torch(tt_vz).float().numpy()
            fx_np = ttnn.to_torch(tt_fx).float().numpy()
            fy_np = ttnn.to_torch(tt_fy).float().numpy()
            fz_np = ttnn.to_torch(tt_fz).float().numpy()
            positions = extract_cell_data(px_np, py_np, pz_np, cell_atom_map, n)
            velocities = extract_cell_data(vx_np, vy_np, vz_np, cell_atom_map, n)
            forces_atom = extract_cell_data(fx_np, fy_np, fz_np, cell_atom_map, n)
            (own_px, own_py, own_pz, own_q,
             masks, cell_atom_map, n_cells_total, n_cells_dim) = \
                build_cell_data(positions, charges, box_length, r_cut)
            c_n_dim_new = int(n_cells_dim)
            if c_n_dim_new != c_n_dim:
                c_n_dim = c_n_dim_new
                c_dim2 = c_n_dim * c_n_dim
                cell_forces_kernel = make_force_kernel(c_n_dim, c_dim2)
            vel_x, vel_y, vel_z = pack_cell_layout(velocities, cell_atom_map, n_cells_total)
            # Forces from the previous step (real + old reciprocal) are re-laid-out
            # and used for this step's first half-kick.
            f_x, f_y, f_z = pack_cell_layout(forces_atom, cell_atom_map, n_cells_total)
            # Recompute reciprocal forces at new positions
            f_recip = compute_reciprocal_forces(device, positions, charges, box_length, alpha)
            recip_x, recip_y, recip_z = pack_cell_layout(f_recip, cell_atom_map, n_cells_total)
            tt_px = to_f32(own_px, l1=use_l1)
            tt_py = to_f32(own_py, l1=use_l1)
            tt_pz = to_f32(own_pz, l1=use_l1)
            tt_q = to_bf16(own_q, l1=use_l1)
            tt_vx = to_f32(vel_x, l1=use_l1)
            tt_vy = to_f32(vel_y, l1=use_l1)
            tt_vz = to_f32(vel_z, l1=use_l1)
            tt_masks = to_bf16(masks)
            tt_fx = to_f32(f_x, l1=use_l1)
            tt_fy = to_f32(f_y, l1=use_l1)
            tt_fz = to_f32(f_z, l1=use_l1)
            tt_rx = to_f32(recip_x, l1=use_l1)
            tt_ry = to_f32(recip_y, l1=use_l1)
            tt_rz = to_f32(recip_z, l1=use_l1)
        # Velocity Verlet: half-kick, drift, new forces, half-kick.
        vel_half_kernel(tt_vx, tt_vy, tt_vz, tt_fx, tt_fy, tt_fz)
        pos_update_kernel(tt_px, tt_py, tt_pz, tt_vx, tt_vy, tt_vz)
        # Real-space forces (bf16 kernel) + reciprocal correction (on-device add)
        tt_fx, tt_fy, tt_fz = run_force_kernel(
            tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler, tt_fx, tt_fy, tt_fz)
        tt_fx = ttnn.add(tt_fx, tt_rx)
        tt_fy = ttnn.add(tt_fy, tt_ry)
        tt_fz = ttnn.add(tt_fz, tt_rz)
        vel_half_kernel(tt_vx, tt_vy, tt_vz, tt_fx, tt_fy, tt_fz)
        step_times.append(time.time() - t0)
    # Read final state
    px_np = ttnn.to_torch(tt_px).float().numpy()
    py_np = ttnn.to_torch(tt_py).float().numpy()
    pz_np = ttnn.to_torch(tt_pz).float().numpy()
    vx_np = ttnn.to_torch(tt_vx).float().numpy()
    vy_np = ttnn.to_torch(tt_vy).float().numpy()
    vz_np = ttnn.to_torch(tt_vz).float().numpy()
    final_pos = extract_cell_data(px_np, py_np, pz_np, cell_atom_map, n)
    final_vel = extract_cell_data(vx_np, vy_np, vz_np, cell_atom_map, n)
    return final_pos, final_vel, step_times
# ---- Main: validation + sample run ----
if __name__ == "__main__":
    device = ttnn.open_device(device_id=0)
    # try/finally so the device is always released, even if a run raises.
    try:
        # 1. Energy conservation (256 atoms, 10 steps)
        # Note: with n_steps == rebuild_every no rebuild happens, so the
        # reciprocal forces stay at their initial values for the whole run.
        print("=" * 50)
        print("ENERGY CONSERVATION (256 atoms, 10 steps)")
        print("=" * 50)
        positions, charges, box_length = make_system(256)
        velocities = np.random.randn(256, 3) * 0.1
        velocities -= velocities.mean(axis=0)  # remove centre-of-mass drift
        ke0 = 0.5 * np.sum(velocities ** 2)
        pe0 = compute_energy(positions, charges, box_length)
        e0 = ke0 + pe0
        final_pos, final_vel, step_times = run_md_loop(
            device, positions.copy(), velocities.copy(), charges, box_length,
            n_steps=10, dt=0.005, rebuild_every=10, alpha=ALPHA)
        ke = 0.5 * np.sum(final_vel ** 2)
        pe = compute_energy(final_pos, charges, box_length)
        drift = abs((ke + pe) - e0) / abs(e0) * 100
        print(f" E0={e0:.2f} E_final={ke+pe:.2f} drift={drift:.4f}%")
        print(f" Step times: {[f'{t:.3f}s' for t in step_times]}")
        print(f" {'PASS' if drift < 5.0 else 'FAIL'}")
        # 2. Performance at scale (10K atoms, 100 steps, no rebuilds)
        print("\n" + "=" * 50)
        print("PERFORMANCE (10K atoms, 100 steps)")
        print("=" * 50)
        positions, charges, box_length = make_system(10000)
        velocities = np.random.randn(10000, 3) * 0.1
        velocities -= velocities.mean(axis=0)
        t0 = time.time()
        final_pos, final_vel, step_times = run_md_loop(
            device, positions.copy(), velocities.copy(), charges, box_length,
            n_steps=100, dt=0.005, rebuild_every=100, alpha=ALPHA)
        t_total = time.time() - t0
        # The first step includes kernel compilation, so it is excluded from the average.
        avg_no_first = np.mean(step_times[1:])
        print(f" Total: {t_total:.1f}s")
        print(f" First step (incl compilation): {step_times[0]:.3f}s")
        print(f" Avg step (excl first): {avg_no_first*1000:.2f}ms")
        print(f" Last 5: {[f'{t:.3f}' for t in step_times[-5:]]}")
    finally:
        ttnn.close_device(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment