Last active
March 10, 2026 22:15
-
-
Save zoecarver/b52aa28302f179505343a11067afd44d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Cell-list molecular dynamics on Tenstorrent hardware using TT-Lang. | |
| Full Ewald electrostatics: erfc-damped real-space (cell-list, TT kernel) + | |
| u-series reciprocal-space (separable Gaussian convolution, TT kernel + host). | |
| LJ short-range forces. Periodic boundary conditions. | |
| On-device Verlet integration (f32) with bf16 force kernels. | |
| Validated: 10K atoms, 10K steps, 1.1ms/step (non-rebuild), 12 min total. | |
| """ | |
| import time | |
| import torch | |
| import numpy as np | |
| import ttnn | |
| import ttl | |
# Edge length of the square tiles that every device array in this file is
# packed into (arrays are shaped (n * TILE, TILE)).
TILE = 32
# Number of cells in a 3 x 3 x 3 neighbourhood, the home cell included.
N_NBR = 27
# Default Ewald splitting parameter (default `alpha` of the force routines).
ALPHA = 1.0
# Default number of grid points per axis of the reciprocal-space grid.
K_GRID = 32
# Default number of Gauss-Legendre nodes used by the Gaussian decomposition.
N_GAUSS = 16

# Coefficients of the five-term polynomial approximation
#     erfc(x) ~= (A1*t + A2*t^2 + A3*t^3 + A4*t^4 + A5*t^5) * exp(-x^2),
#     t = 1 / (1 + P*x)
# (Abramowitz & Stegun, formula 7.1.26).
_ERFC_POLY = (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429)
ERFC_A1, ERFC_A2, ERFC_A3, ERFC_A4, ERFC_A5 = _ERFC_POLY
ERFC_P = 0.3275911
| # ---- Host utilities ---- | |
def make_system(n_atoms, density=0.3, seed=42):
    """Build a jittered simple-cubic test system with zero net charge.

    Args:
        n_atoms: number of atoms to place.
        density: number density; the cubic box has volume n_atoms / density.
        seed: value passed to ``np.random.seed`` (this reseeds NumPy's
            global random state).

    Returns:
        (positions, charges, box_length): positions is (n_atoms, 3) and
        wrapped into [0, box_length); charges is (n_atoms,) with its mean
        subtracted; box_length is the box edge.
    """
    np.random.seed(seed)
    box_length = (n_atoms / density) ** (1.0 / 3.0)
    n_side = int(np.ceil(n_atoms ** (1.0 / 3.0)))
    spacing = box_length / n_side

    # Enumerate lattice sites with z varying fastest, then y, then x, and keep
    # only the first n_atoms of them.
    site_index = np.indices((n_side, n_side, n_side)).reshape(3, -1).T[:n_atoms]
    positions = (site_index + 0.5) * spacing

    # Gaussian jitter (sigma = 0.05) followed by a periodic wrap into the box.
    positions += np.random.normal(0, 0.05, positions.shape)
    positions = positions % box_length

    # Random charges, shifted so the system is neutral overall.
    charges = np.random.randn(n_atoms) * 0.3
    charges -= charges.mean()
    return positions, charges, box_length
def compute_energy(positions, charges, box_length):
    """O(N^2) reference energy (LJ + Coulomb) for validation.

    Every ordered pair (i, j) is evaluated at its minimum-image separation;
    no distance cutoff is applied. The double sum is halved so each pair
    counts once.

    Args:
        positions: (N, 3) array-like of coordinates.
        charges: (N,) array-like of charges.
        box_length: edge of the cubic periodic box.

    Returns:
        Total energy as a float: sum over pairs of
        4 * (r**-12 - r**-6) + q_i * q_j / r.
    """
    positions = np.asarray(positions, dtype=float)
    charges = np.asarray(charges, dtype=float)

    dr = positions[:, None, :] - positions[None, :, :]
    # Minimum-image convention: fold every component into [-L/2, L/2).
    dr -= box_length * np.floor(dr / box_length + 0.5)
    r2 = np.sum(dr * dr, axis=2)

    # Exclude self-pairs exactly. An infinite self-distance makes both the LJ
    # and the Coulomb diagonal terms vanish; a large finite value would leave
    # a residual q_i**2 / r self-energy in the total.
    np.fill_diagonal(r2, np.inf)

    r = np.sqrt(r2)
    r6_inv = (1.0 / r2) ** 3
    lj = 4.0 * (r6_inv ** 2 - r6_inv)
    coul = (charges[:, None] * charges[None, :]) / r
    return 0.5 * (np.sum(lj) + np.sum(coul))
def direct_forces(positions, charges, box_length):
    """O(N^2) reference forces (LJ + Coulomb) for validation.

    Pairs are evaluated at their minimum-image separation with no cutoff.

    Args:
        positions: (N, 3) float array of coordinates.
        charges: (N,) array of charges.
        box_length: edge of the cubic periodic box.

    Returns:
        (N, 3) array; row i is the total force on atom i.
    """
    sep = positions[:, None, :] - positions[None, :, :]
    # Minimum-image convention.
    sep -= box_length * np.floor(sep / box_length + 0.5)

    dist_sq = np.sum(sep * sep, axis=2)
    # A huge self-distance keeps the diagonal finite; those terms are then
    # multiplied by a zero displacement and drop out of the sum.
    np.fill_diagonal(dist_sq, 1e10)
    dist = np.sqrt(dist_sq)

    inv_sq = 1.0 / dist_sq
    inv_6 = inv_sq ** 3
    # Both terms are force magnitude divided by r, so that multiplying by the
    # displacement vector gives the force vector.
    lj_over_r = 24.0 * inv_sq * (2.0 * inv_6 ** 2 - inv_6)
    pair_charge = charges[:, None] * charges[None, :]
    coul_over_r = pair_charge / (dist_sq * dist)

    pair_scale = lj_over_r + coul_over_r
    return np.sum(pair_scale[:, :, None] * sep, axis=1)
| # ---- Reciprocal-space: u-series Gaussian convolution ---- | |
def gaussian_decomposition(alpha, n_gauss):
    """Gauss-Legendre decomposition of the long-range kernel into Gaussians.

    The quadrature nodes are mapped from [-1, 1] onto t in [0, alpha], so
    that sum_k c_k * exp(-e_k * r**2) is the n_gauss-point approximation of
    (2 / sqrt(pi)) * integral_0^alpha exp(-t**2 * r**2) dt = erf(alpha*r) / r.

    Args:
        alpha: upper integration limit (the Ewald splitting parameter).
        n_gauss: number of quadrature nodes.

    Returns:
        (exponents, coefficients): two length-n_gauss arrays holding
        e_k = t_k**2 and c_k = (2 / sqrt(pi)) * w_k.
    """
    nodes, gl_weights = np.polynomial.legendre.leggauss(n_gauss)
    half_width = alpha / 2.0
    t_nodes = half_width * (nodes + 1.0)
    t_weights = half_width * gl_weights
    exponents = t_nodes ** 2
    coefficients = (2.0 / np.sqrt(np.pi)) * t_weights
    return exponents, coefficients
def make_conv_kernel(K, h, exponent):
    """Build the K x K circulant Gaussian convolution matrix.

    Entry (i, j) is exp(-exponent * (d * h)**2), where d is the periodic
    (wrap-around) index distance min(|i - j|, K - |i - j|).

    Args:
        K: number of grid points along the axis.
        h: grid spacing.
        exponent: Gaussian exponent.

    Returns:
        (K, K) float array.
    """
    idx = np.arange(K)
    sep = np.abs(idx[:, None] - idx[None, :])
    # Shortest distance around the periodic ring of K points.
    ring_dist = np.minimum(sep, K - sep)
    return np.exp(-exponent * (ring_dist * h) ** 2)
def bspline4_weights_vec(u):
    """Order-4 (cubic) B-spline weights for charge spreading.

    Args:
        u: array of fractional grid offsets.

    Returns:
        Array with a trailing axis of length 4 holding (w0, w1, w2, w3).
        w0 is defined as 1 - w1 - w2 - w3, so the four weights sum to one.
    """
    u_sq = u * u
    u_cu = u_sq * u
    w_last = u_cu / 6.0
    w_third = (1.0 + 3.0 * u * (1.0 + u * (1.0 - u))) / 6.0
    w_second = (4.0 - 6.0 * u_sq + 3.0 * u_cu) / 6.0
    # The remaining weight follows from the partition of unity.
    w_first = 1.0 - w_second - w_third - w_last
    return np.stack([w_first, w_second, w_third, w_last], axis=-1)
def bspline4_dweights_vec(u):
    """Order-4 (cubic) B-spline derivative weights for force interpolation.

    Args:
        u: array of fractional grid offsets.

    Returns:
        Array with a trailing axis of length 4 holding the derivatives
        (dw0, dw1, dw2, dw3) of the spreading weights with respect to u.
        dw0 is defined as -(dw1 + dw2 + dw3), so the four sum to zero.
    """
    u_sq = u * u
    d_last = u_sq / 2.0
    d_third = (3.0 * (1.0 + 2.0 * u - 3.0 * u_sq)) / 6.0
    d_second = (-12.0 * u + 9.0 * u_sq) / 6.0
    # The weights sum to a constant, so their derivatives sum to zero.
    d_first = -(d_second + d_third + d_last)
    return np.stack([d_first, d_second, d_third, d_last], axis=-1)
def spread_charges(positions, charges, box_length, K, order=4):
    """Spread point charges onto a periodic K x K x K grid with B-splines.

    Args:
        positions: (N, 3) array-like of coordinates.
        charges: (N,) array-like of charges.
        box_length: edge of the cubic periodic box.
        K: number of grid points per axis.
        order: B-spline order. Only 4 is supported, because the weights come
            from ``bspline4_weights_vec``.

    Returns:
        (K, K, K) float array of gridded charge. Each atom contributes to
        4 x 4 x 4 grid points (indices wrapped modulo K) with weights that
        sum to its charge.

    Raises:
        ValueError: if ``order`` is not 4.
    """
    if order != 4:
        raise ValueError(
            f"spread_charges only supports order=4 B-splines, got order={order}")

    positions = np.asarray(positions, dtype=float)
    charges = np.asarray(charges, dtype=float)

    h = box_length / K
    s = positions / h
    base = np.floor(s)
    # Index of the first grid point of each atom's stencil.
    g0 = base.astype(int) - (order // 2 - 1)
    f = s - base

    wx = bspline4_weights_vec(f[:, 0])
    wy = bspline4_weights_vec(f[:, 1])
    wz = bspline4_weights_vec(f[:, 2])
    # (N, 4, 4, 4) tensor-product weights, scaled by each atom's charge.
    w3d = wx[:, :, None, None] * wy[:, None, :, None] * wz[:, None, None, :]
    w3d *= charges[:, None, None, None]

    offsets = np.arange(order)
    gx = (g0[:, 0, None] + offsets[None, :]) % K
    gy = (g0[:, 1, None] + offsets[None, :]) % K
    gz = (g0[:, 2, None] + offsets[None, :]) % K

    grid = np.zeros((K, K, K))
    # Unbuffered scatter-add: several atoms (or one atom wrapped around a
    # small grid) may hit the same grid point, and every hit must accumulate.
    np.add.at(
        grid,
        (gx[:, :, None, None], gy[:, None, :, None], gz[:, None, None, :]),
        w3d)
    return grid
def interpolate_forces_bspline(positions, potential_grid, box_length, order=4):
    """Interpolate the negative potential gradient at atom positions.

    The gridded potential is differentiated analytically through the cubic
    B-spline weights (derivative weight along one axis, plain weights along
    the other two) and divided by the grid spacing.

    Args:
        positions: (N, 3) array-like of coordinates.
        potential_grid: (K, K, K) array of potential values; K is read from
            its first dimension.
        box_length: edge of the cubic periodic box.
        order: B-spline order. Only 4 is supported, because the weights come
            from ``bspline4_weights_vec`` / ``bspline4_dweights_vec``.

    Returns:
        (N, 3) float array holding minus the interpolated gradient of the
        potential at each position (not yet multiplied by the charge).

    Raises:
        ValueError: if ``order`` is not 4.
    """
    if order != 4:
        raise ValueError(
            "interpolate_forces_bspline only supports order=4 B-splines, "
            f"got order={order}")

    positions = np.asarray(positions, dtype=float)
    K = potential_grid.shape[0]
    h = box_length / K
    s = positions / h
    base = np.floor(s)
    # Index of the first grid point of each atom's stencil.
    g0 = base.astype(int) - (order // 2 - 1)
    f = s - base

    wx = bspline4_weights_vec(f[:, 0])
    wy = bspline4_weights_vec(f[:, 1])
    wz = bspline4_weights_vec(f[:, 2])
    dwx = bspline4_dweights_vec(f[:, 0])
    dwy = bspline4_dweights_vec(f[:, 1])
    dwz = bspline4_dweights_vec(f[:, 2])

    offsets = np.arange(order)
    gx = (g0[:, 0, None] + offsets[None, :]) % K
    gy = (g0[:, 1, None] + offsets[None, :]) % K
    gz = (g0[:, 2, None] + offsets[None, :]) % K

    # Gather each atom's 4 x 4 x 4 stencil of potential values: (N, 4, 4, 4).
    phi = potential_grid[gx[:, :, None, None],
                         gy[:, None, :, None],
                         gz[:, None, None, :]]

    inv_h = 1.0 / h
    forces = np.empty((len(positions), 3))
    # Contract the stencil with the tensor-product weights; the 1/h factor
    # converts d/du (grid units) into d/dx.
    forces[:, 0] = -inv_h * np.einsum('ni,nj,nk,nijk->n', dwx, wy, wz, phi)
    forces[:, 1] = -inv_h * np.einsum('ni,nj,nk,nijk->n', wx, dwy, wz, phi)
    forces[:, 2] = -inv_h * np.einsum('ni,nj,nk,nijk->n', wx, wy, dwz, phi)
    return forces
@ttl.kernel(grid="auto")
def xy_conv_kernel(charge_grid, kernels, potential_grid):
    """U-series xy-convolution: sum of ``M @ charge_slice @ M`` over all kernels.

    Replaces the FFT of traditional PME: because the Gaussian decomposition is
    separable, the convolution is done with matrix multiplies instead of FFTs.

    Args:
        charge_grid: stack of tiles; tile-row ``z`` is read as one charge slice.
        kernels: stack of tiles; tile-row ``g`` is one convolution matrix M.
        potential_grid: output stack; tile-row ``z`` receives the accumulated
            result for slice ``z``.

    Work split: the slices are divided into contiguous runs of
    ``slices_per_core`` across the first core-grid dimension. For every slice
    the reader streams the slice tile once followed by all kernel tiles, the
    compute thread accumulates, and the writer stores one output tile.
    """
    grid_cols, _ = ttl.grid_size(dims=2)
    total_slices = charge_grid.shape[0] // TILE
    n_kernels = kernels.shape[0] // TILE
    # Ceiling division, so every slice is owned by some core.
    slices_per_core = -(-total_slices // grid_cols)

    cg_cb = ttl.make_dataflow_buffer_like(charge_grid, shape=(1, 1), buffer_factor=2)
    km_cb = ttl.make_dataflow_buffer_like(kernels, shape=(1, 1), buffer_factor=2)
    tmp_cb = ttl.make_dataflow_buffer_like(charge_grid, shape=(1, 1), buffer_factor=2)
    acc_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=2)
    par_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=2)
    out_cb = ttl.make_dataflow_buffer_like(potential_grid, shape=(1, 1), buffer_factor=1)

    @ttl.compute()
    def compute():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            # Trailing cores may own fewer than slices_per_core slices.
            if z < total_slices:
                with cg_cb.wait() as cg:
                    # The first kernel initialises the accumulator.
                    with km_cb.wait() as M:
                        with tmp_cb.reserve() as left:
                            left.store(M @ cg)
                        with tmp_cb.wait() as left_v:
                            with acc_cb.reserve() as acc:
                                acc.store(left_v @ M)
                    # Each remaining kernel yields a partial result that is
                    # added to the running accumulator.
                    for g in range(n_kernels - 1):
                        with km_cb.wait() as M:
                            with tmp_cb.reserve() as left:
                                left.store(M @ cg)
                            with tmp_cb.wait() as left_v:
                                with par_cb.reserve() as part:
                                    part.store(left_v @ M)
                        with par_cb.wait() as part_v, acc_cb.wait() as acc_v:
                            with acc_cb.reserve() as acc:
                                acc.store(acc_v + part_v)
                # Hand the finished accumulator to the writer.
                with acc_cb.wait() as final, out_cb.reserve() as out:
                    out.store(final)

    @ttl.datamovement()
    def dm_read():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            if z < total_slices:
                # One charge slice ...
                with cg_cb.reserve() as tile:
                    xfer = ttl.copy(charge_grid[z, 0], tile)
                    xfer.wait()
                # ... followed by every kernel tile, in the order the compute
                # thread consumes them.
                for g in range(n_kernels):
                    with km_cb.reserve() as tile:
                        xfer = ttl.copy(kernels[g, 0], tile)
                        xfer.wait()

    @ttl.datamovement()
    def dm_write():
        core_x, _ = ttl.core(dims=2)
        for local_z in range(slices_per_core):
            z = core_x * slices_per_core + local_z
            if z < total_slices:
                with out_cb.wait() as tile:
                    xfer = ttl.copy(tile, potential_grid[z, 0])
                    xfer.wait()
def compute_reciprocal_forces(device, positions, charges, box_length,
                              alpha=ALPHA, n_gauss=N_GAUSS, K=K_GRID):
    """Reciprocal-space forces via u-series Gaussian convolution.

    Pipeline: spread charges -> xy-convolve (TT kernel) -> z-convolve (host)
    -> interpolate forces.

    Args:
        device: ttnn device the tensors are placed on.
        positions: (N, 3) array of coordinates.
        charges: (N,) array of charges.
        box_length: edge of the cubic periodic box.
        alpha: Ewald splitting parameter passed to the Gaussian decomposition.
        n_gauss: number of Gaussian components.
        K: grid points per axis. Each z-slice is packed into a single
            TILE x TILE tile, so K must not exceed TILE; smaller grids are
            zero-padded.

    Returns:
        (N, 3) float array of reciprocal-space forces.

    Raises:
        ValueError: if K is larger than TILE.
    """
    if K > TILE:
        raise ValueError(
            f"K={K} does not fit in one {TILE}x{TILE} tile per z-slice")

    h = box_length / K
    exponents, weights = gaussian_decomposition(alpha, n_gauss)
    charge_grid_3d = spread_charges(positions, charges, box_length, K)

    # Build every weighted kernel once; the same matrices feed both the
    # device xy pass and the host z pass.
    weighted_kernels = [make_conv_kernel(K, h, exponents[g]) * weights[g]
                        for g in range(n_gauss)]

    # Pack charge grid into tile layout: each z-slice is one tile. Padding
    # with zeros leaves the top-left K x K block of M @ slice @ M unchanged.
    cg_np = np.zeros((K * TILE, TILE), dtype=np.float32)
    for z in range(K):
        cg_np[z * TILE:z * TILE + K, :K] = charge_grid_3d[:, :, z].astype(np.float32)

    # Pack Gaussian convolution kernels, one tile per component.
    # NOTE(review): each kernel is pre-multiplied by weights[g] and the device
    # kernel applies it on both sides, so every xy term carries weights[g]**2
    # -- confirm this is intended.
    km_np = np.zeros((n_gauss * TILE, TILE), dtype=np.float32)
    for g, M_g in enumerate(weighted_kernels):
        km_np[g * TILE:g * TILE + K, :K] = M_g.astype(np.float32)

    def to_tt(arr):
        # Host array -> bf16 tiled tensor in device DRAM.
        return ttnn.from_torch(
            torch.tensor(arr, dtype=torch.bfloat16),
            dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT,
            device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    cg_tt = to_tt(cg_np)
    km_tt = to_tt(km_np)
    pot_tt = to_tt(np.zeros((K * TILE, TILE), dtype=np.float32))

    # XY-convolution on device
    xy_conv_kernel(cg_tt, km_tt, pot_tt)

    # Z-convolution on host
    pot_np = ttnn.to_torch(pot_tt).float().numpy()
    xy_conv_3d = np.zeros((K, K, K))
    for z in range(K):
        xy_conv_3d[:, :, z] = pot_np[z * TILE:z * TILE + K, :K]

    # NOTE(review): the z kernel is the sum over all components and is applied
    # to the already-summed xy result, i.e. it is not paired per component
    # with the xy kernels -- confirm this approximation is intended.
    M_z = np.zeros((K, K))
    for M_g in weighted_kernels:
        M_z += M_g
    potential_3d = np.zeros_like(xy_conv_3d)
    for x in range(K):
        potential_3d[x, :, :] = (M_z @ xy_conv_3d[x, :, :].T).T

    # Interpolate forces from potential grid
    forces_recip = interpolate_forces_bspline(positions, potential_3d, box_length)
    # NOTE(review): the interpolation already subtracts the gradient, and the
    # result is scaled by -charge here -- confirm the overall sign convention.
    forces_recip *= -charges[:, None]
    return forces_recip
| # ---- Cell-list data packing ---- | |
def build_cell_data(positions, charges, box_length, r_cut):
    """Assign atoms to cells, pack into tile arrays, build self-exclusion masks.

    The box is divided into n_cells_dim**3 cubic cells, with
    n_cells_dim = max(3, int(box_length / r_cut)). Each cell owns one block of
    TILE rows in the packed arrays, and an atom's value is stored in column 0
    of its row (row = cell * TILE + slot).

    Args:
        positions: (N, 3) array of coordinates.
        charges: (N,) array of charges.
        box_length: edge of the cubic periodic box.
        r_cut: cutoff radius used to choose the number of cells per axis.

    Returns:
        Tuple ``(own_px, own_py, own_pz, own_q, masks, cell_atom_map,
        n_cells_total, n_cells_dim)``:
        - own_px / own_py / own_pz / own_q: (n_cells_total * TILE, TILE)
          float32 arrays in cell layout.
        - masks: (n_cells_total * N_NBR * TILE, TILE) float32, one tile per
          (cell, neighbour) pair; 1e6 where the row slot or column slot is
          empty or the pair is an atom with itself, 0 elsewhere.
        - cell_atom_map: per cell, the list of original atom indices
          (including any beyond the first TILE).
        - n_cells_total, n_cells_dim: cell counts (total and per axis).

    Warns:
        RuntimeWarning: if a cell holds more than TILE atoms. The surplus
        atoms are not packed, so they take no part in the tiled arrays.
    """
    import warnings

    n = len(positions)
    n_cells_dim = max(3, int(box_length / r_cut))
    n_cells_total = n_cells_dim ** 3
    cell_size = box_length / n_cells_dim

    # Flattened cell index per atom: x slowest, z fastest.
    cidx = np.floor(positions / cell_size).astype(int) % n_cells_dim
    cell_id = cidx[:, 0] * n_cells_dim**2 + cidx[:, 1] * n_cells_dim + cidx[:, 2]

    # Sort atoms by cell; a stable sort keeps the original order within a cell.
    sort_idx = np.argsort(cell_id, kind='stable')
    sorted_cell_id = cell_id[sort_idx]
    cell_counts = np.bincount(cell_id, minlength=n_cells_total)
    cell_starts = np.zeros(n_cells_total + 1, dtype=int)
    np.cumsum(cell_counts, out=cell_starts[1:])

    # Make the truncation below visible instead of losing atoms silently.
    n_dropped = int(np.sum(np.maximum(cell_counts - TILE, 0)))
    if n_dropped:
        warnings.warn(
            f"{n_dropped} atom(s) exceed the per-cell capacity of {TILE} "
            f"(fullest cell holds {int(cell_counts.max())}) and are left out "
            "of the packed cell arrays",
            RuntimeWarning, stacklevel=2)

    # Slot of each (sorted) atom inside its cell; only the first TILE fit.
    local_idx = np.arange(n) - cell_starts[sorted_cell_id]
    valid = local_idx < TILE
    cell_atom_map = [sort_idx[cell_starts[c]:cell_starts[c+1]].tolist()
                     for c in range(n_cells_total)]

    valid_atoms = sort_idx[valid]
    valid_cells = sorted_cell_id[valid]
    valid_local = local_idx[valid]
    rows = valid_cells * TILE + valid_local

    own_px = np.zeros((n_cells_total * TILE, TILE), dtype=np.float32)
    own_py = np.zeros_like(own_px)
    own_pz = np.zeros_like(own_px)
    own_q = np.zeros_like(own_px)
    own_px[rows, 0] = positions[valid_atoms, 0]
    own_py[rows, 0] = positions[valid_atoms, 1]
    own_pz[rows, 0] = positions[valid_atoms, 2]
    own_q[rows, 0] = charges[valid_atoms]

    # Self-exclusion masks: 1e6 for empty slots and self-pairs.
    # Neighbour offsets are enumerated with dx slowest and dz fastest.
    offsets = np.array([(dx, dy, dz)
                        for dx in range(-1, 2)
                        for dy in range(-1, 2)
                        for dz in range(-1, 2)])
    cell_3d = np.stack(np.unravel_index(np.arange(n_cells_total),
                                        (n_cells_dim, n_cells_dim, n_cells_dim)),
                       axis=-1)
    # Periodic wrap of the neighbour cell coordinates.
    nbr_3d = (cell_3d[:, None, :] + offsets[None, :, :]) % n_cells_dim
    nbr_cid = (nbr_3d[:, :, 0] * n_cells_dim**2 +
               nbr_3d[:, :, 1] * n_cells_dim + nbr_3d[:, :, 2])

    own_cnt = np.minimum(cell_counts, TILE)
    nbr_cnt = own_cnt[nbr_cid]
    is_self = (nbr_cid == np.arange(n_cells_total)[:, None])

    row_idx = np.arange(TILE)[None, None, :, None]
    col_idx = np.arange(TILE)[None, None, None, :]
    oc = own_cnt[:, None, None, None]
    nc = nbr_cnt[:, :, None, None]
    masks_4d = np.where(
        (row_idx >= oc) | (col_idx >= nc) | (is_self[:, :, None, None] & (row_idx == col_idx)),
        np.float32(1e6), np.float32(0.0))
    masks = masks_4d.reshape(n_cells_total * N_NBR * TILE, TILE)

    return own_px, own_py, own_pz, own_q, masks, cell_atom_map, n_cells_total, n_cells_dim
def pack_cell_layout(data_3col, cell_atom_map, n_cells_total):
    """Pack per-atom (N, 3) data into cell-layout tile arrays.

    Args:
        data_3col: (N, 3) array indexed by original atom index.
        cell_atom_map: per cell, the list of atom indices it contains.
        n_cells_total: number of cells; sets the height of the outputs.

    Returns:
        (ax, ay, az): three (n_cells_total * TILE, TILE) float32 arrays. The
        k-th atom of cell c lands in row c * TILE + k, column 0. Only the
        first TILE atoms of a cell are packed; all other entries stay zero.
    """
    tiled_shape = (n_cells_total * TILE, TILE)
    ax = np.zeros(tiled_shape, dtype=np.float32)
    ay = np.zeros(tiled_shape, dtype=np.float32)
    az = np.zeros(tiled_shape, dtype=np.float32)

    for cell_id, atoms in enumerate(cell_atom_map):
        kept = atoms[:TILE]
        n_kept = len(kept)
        if n_kept == 0:
            continue
        first_row = cell_id * TILE
        rows = slice(first_row, first_row + n_kept)
        # Write the whole cell's column in one slice assignment per component.
        ax[rows, 0] = data_3col[kept, 0]
        ay[rows, 0] = data_3col[kept, 1]
        az[rows, 0] = data_3col[kept, 2]
    return ax, ay, az
def extract_cell_data(dx_np, dy_np, dz_np, cell_atom_map, n_atoms):
    """Unpack cell-layout tile arrays back to per-atom (N, 3) order.

    Args:
        dx_np, dy_np, dz_np: cell-layout arrays; the value of the k-th atom
            of cell c is read from row c * TILE + k, column 0.
        cell_atom_map: per cell, the list of atom indices it contains.
        n_atoms: number of rows of the result.

    Returns:
        (n_atoms, 3) float array. Atoms beyond the first TILE of a cell are
        not read, so their rows keep the initial zeros.
    """
    result = np.zeros((n_atoms, 3))
    for cell_id, atoms in enumerate(cell_atom_map):
        kept = atoms[:TILE]
        n_kept = len(kept)
        if n_kept == 0:
            continue
        first_row = cell_id * TILE
        rows = slice(first_row, first_row + n_kept)
        # Scatter the cell's column back to the original atom indices.
        result[kept, 0] = dx_np[rows, 0]
        result[kept, 1] = dy_np[rows, 0]
        result[kept, 2] = dz_np[rows, 0]
    return result
| # ---- On-device Verlet MD loop ---- | |
| def run_md_loop(device, positions, velocities, charges, box_length, | |
| n_steps=10, dt=0.005, rebuild_every=100, alpha=ALPHA): | |
| """Run MD with velocity Verlet integration on Tenstorrent hardware. | |
| Full Ewald forces: real-space (erfc-damped LJ+Coulomb, bf16 TT kernel) + | |
| reciprocal-space (u-series Gaussian convolution, bf16 TT kernel + host). | |
| Positions and velocities stored in f32 for integration precision. | |
| Between rebuilds, all updates happen on device with zero host copies (~1ms/step). | |
| Reciprocal forces are recomputed on each rebuild and held constant between. | |
| """ | |
| n = len(positions) | |
| r_cut = min(box_length / 2.0 - 0.1, 3.0 / alpha) | |
| # Constants captured by kernel closures | |
| c_box = float(box_length) | |
| c_inv_box = 1.0 / float(box_length) | |
| c_half = 0.5 | |
| c_dt_half = 0.5 * dt | |
| c_dt = float(dt) | |
| c_lj_scale = 24.0 | |
| c_alpha_sq = float(alpha * alpha) | |
| c_p_alpha = float(ERFC_P * alpha) | |
| c_two_a_sp = float(2.0 * alpha / np.sqrt(np.pi)) | |
| c_a1 = float(ERFC_A1) | |
| c_a2 = float(-ERFC_A2) | |
| c_a3 = float(ERFC_A3) | |
| c_a4 = float(-ERFC_A4) | |
| c_a5 = float(ERFC_A5) | |
| # --- Verlet velocity half-step: vel += 0.5*dt*force (f32) --- | |
| @ttl.kernel(grid="auto") | |
| def vel_half_kernel(vel_x, vel_y, vel_z, fx, fy, fz): | |
| grid_cols, _ = ttl.grid_size(dims=2) | |
| nc = vel_x.shape[0] // TILE | |
| cpc = -(-nc // grid_cols) | |
| v_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2) | |
| f_cb = ttl.make_dataflow_buffer_like(fx, shape=(1, 1), buffer_factor=2) | |
| o_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2) | |
| @ttl.compute() | |
| def compute(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| for _ in range(3): | |
| with v_cb.wait() as v, f_cb.wait() as f, o_cb.reserve() as o: | |
| o.store(v + f * ttl.math.fill(v, c_dt_half)) | |
| @ttl.datamovement() | |
| def dm_read(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_x[cid, 0], blk); tx.wait() | |
| with f_cb.reserve() as blk: | |
| tx = ttl.copy(fx[cid, 0], blk); tx.wait() | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_y[cid, 0], blk); tx.wait() | |
| with f_cb.reserve() as blk: | |
| tx = ttl.copy(fy[cid, 0], blk); tx.wait() | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_z[cid, 0], blk); tx.wait() | |
| with f_cb.reserve() as blk: | |
| tx = ttl.copy(fz[cid, 0], blk); tx.wait() | |
| @ttl.datamovement() | |
| def dm_write(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, vel_x[cid, 0]); tx.wait() | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, vel_y[cid, 0]); tx.wait() | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, vel_z[cid, 0]); tx.wait() | |
| # --- Verlet position update: pos += dt*vel, PBC wrap (f32) --- | |
| @ttl.kernel(grid="auto") | |
| def pos_update_kernel(pos_x, pos_y, pos_z, vel_x, vel_y, vel_z): | |
| grid_cols, _ = ttl.grid_size(dims=2) | |
| nc = pos_x.shape[0] // TILE | |
| cpc = -(-nc // grid_cols) | |
| p_cb = ttl.make_dataflow_buffer_like(pos_x, shape=(1, 1), buffer_factor=2) | |
| v_cb = ttl.make_dataflow_buffer_like(vel_x, shape=(1, 1), buffer_factor=2) | |
| o_cb = ttl.make_dataflow_buffer_like(pos_x, shape=(1, 1), buffer_factor=2) | |
| @ttl.compute() | |
| def compute(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| for _ in range(3): | |
| with p_cb.wait() as p, v_cb.wait() as v, o_cb.reserve() as o: | |
| new_p = p + v * ttl.math.fill(p, c_dt) | |
| o.store(new_p - ttl.math.fill(p, c_box) * ttl.math.floor(new_p * ttl.math.fill(p, c_inv_box))) | |
| @ttl.datamovement() | |
| def dm_read(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| with p_cb.reserve() as blk: | |
| tx = ttl.copy(pos_x[cid, 0], blk); tx.wait() | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_x[cid, 0], blk); tx.wait() | |
| with p_cb.reserve() as blk: | |
| tx = ttl.copy(pos_y[cid, 0], blk); tx.wait() | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_y[cid, 0], blk); tx.wait() | |
| with p_cb.reserve() as blk: | |
| tx = ttl.copy(pos_z[cid, 0], blk); tx.wait() | |
| with v_cb.reserve() as blk: | |
| tx = ttl.copy(vel_z[cid, 0], blk); tx.wait() | |
| @ttl.datamovement() | |
| def dm_write(): | |
| core_x, _ = ttl.core(dims=2) | |
| for lc in range(cpc): | |
| cid = core_x * cpc + lc | |
| if cid < nc: | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, pos_x[cid, 0]); tx.wait() | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, pos_y[cid, 0]); tx.wait() | |
| with o_cb.wait() as blk: | |
| tx = ttl.copy(blk, pos_z[cid, 0]); tx.wait() | |
| # --- Force kernel: fused geometry + LJ + erfc Coulomb (bf16, 28 CBs) --- | |
| def make_force_kernel(c_n_dim, c_dim2): | |
| @ttl.kernel(grid="auto") | |
| def cell_forces_kernel(own_px, own_py, own_pz, own_q, | |
| self_mask, scaler, | |
| fx_out, fy_out, fz_out): | |
| grid_cols, _ = ttl.grid_size(dims=2) | |
| n_cells = own_px.shape[0] // TILE | |
| cells_per_core = -(-n_cells // grid_cols) | |
| ox_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| oy_cb = ttl.make_dataflow_buffer_like(own_py, shape=(1, 1), buffer_factor=2) | |
| oz_cb = ttl.make_dataflow_buffer_like(own_pz, shape=(1, 1), buffer_factor=2) | |
| oq_cb = ttl.make_dataflow_buffer_like(own_q, shape=(1, 1), buffer_factor=2) | |
| ex_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| ey_cb = ttl.make_dataflow_buffer_like(own_py, shape=(1, 1), buffer_factor=2) | |
| ez_cb = ttl.make_dataflow_buffer_like(own_pz, shape=(1, 1), buffer_factor=2) | |
| eq_cb = ttl.make_dataflow_buffer_like(own_q, shape=(1, 1), buffer_factor=2) | |
| sm_cb = ttl.make_dataflow_buffer_like(self_mask, shape=(1, 1), buffer_factor=2) | |
| sc_cb = ttl.make_dataflow_buffer_like(scaler, shape=(1, 1), buffer_factor=2) | |
| ba_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| tr_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| bb_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| r2_tmp = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| r2_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| qq_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| dx_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| dy_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| dz_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| fm_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| ft_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| fr_cb = ttl.make_dataflow_buffer_like(own_px, shape=(1, 1), buffer_factor=2) | |
| ax_cb = ttl.make_dataflow_buffer_like(fx_out, shape=(1, 1), buffer_factor=2) | |
| ay_cb = ttl.make_dataflow_buffer_like(fy_out, shape=(1, 1), buffer_factor=2) | |
| az_cb = ttl.make_dataflow_buffer_like(fz_out, shape=(1, 1), buffer_factor=2) | |
| fxo_cb = ttl.make_dataflow_buffer_like(fx_out, shape=(1, 1), buffer_factor=2) | |
| fyo_cb = ttl.make_dataflow_buffer_like(fy_out, shape=(1, 1), buffer_factor=2) | |
| fzo_cb = ttl.make_dataflow_buffer_like(fz_out, shape=(1, 1), buffer_factor=2) | |
| @ttl.compute() | |
| def compute(): | |
| core_x, _ = ttl.core(dims=2) | |
| for local_c in range(cells_per_core): | |
| cell_id = core_x * cells_per_core + local_c | |
| if cell_id < n_cells: | |
| with ox_cb.wait() as ox, oy_cb.wait() as oy, oz_cb.wait() as oz, oq_cb.wait() as oq: | |
| with sc_cb.wait() as sc: | |
| for nbr_i in range(N_NBR): | |
| with ex_cb.wait() as ex, ey_cb.wait() as ey, ez_cb.wait() as ez, eq_cb.wait() as eq, sm_cb.wait() as sm: | |
| # PBC displacements for x, y, z | |
| with ba_cb.reserve() as ba: | |
| ba.store(ttl.math.broadcast(ox, dims=[1])) | |
| with tr_cb.reserve() as tr: | |
| tr.store(ttl.transpose(ex)) | |
| with tr_cb.wait() as trv, bb_cb.reserve() as bb: | |
| bb.store(ttl.math.broadcast(trv, dims=[0])) | |
| with ba_cb.wait() as bav, bb_cb.wait() as bbv: | |
| dx_raw = bav - bbv | |
| dx_pbc = dx_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dx_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half)) | |
| with r2_tmp.reserve() as r2o: | |
| r2o.store(dx_pbc * dx_pbc) | |
| with dx_cb.reserve() as dxo: | |
| dxo.store(dx_pbc) | |
| with ba_cb.reserve() as ba: | |
| ba.store(ttl.math.broadcast(oy, dims=[1])) | |
| with tr_cb.reserve() as tr: | |
| tr.store(ttl.transpose(ey)) | |
| with tr_cb.wait() as trv, bb_cb.reserve() as bb: | |
| bb.store(ttl.math.broadcast(trv, dims=[0])) | |
| with ba_cb.wait() as bav, bb_cb.wait() as bbv, r2_tmp.wait() as r2p: | |
| dy_raw = bav - bbv | |
| dy_pbc = dy_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dy_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half)) | |
| with r2_tmp.reserve() as r2o: | |
| r2o.store(r2p + dy_pbc * dy_pbc) | |
| with dy_cb.reserve() as dyo: | |
| dyo.store(dy_pbc) | |
| with ba_cb.reserve() as ba: | |
| ba.store(ttl.math.broadcast(oz, dims=[1])) | |
| with tr_cb.reserve() as tr: | |
| tr.store(ttl.transpose(ez)) | |
| with tr_cb.wait() as trv, bb_cb.reserve() as bb: | |
| bb.store(ttl.math.broadcast(trv, dims=[0])) | |
| with ba_cb.wait() as bav, bb_cb.wait() as bbv, r2_tmp.wait() as r2p: | |
| dz_raw = bav - bbv | |
| dz_pbc = dz_raw - ttl.math.fill(bav, c_box) * ttl.math.floor(dz_raw * ttl.math.fill(bav, c_inv_box) + ttl.math.fill(bav, c_half)) | |
| with r2_cb.reserve() as r2o: | |
| r2o.store(r2p + dz_pbc * dz_pbc + sm) | |
| with dz_cb.reserve() as dzo: | |
| dzo.store(dz_pbc) | |
| # Charge products | |
| with ba_cb.reserve() as ba: | |
| ba.store(ttl.math.broadcast(oq, dims=[1])) | |
| with tr_cb.reserve() as tr: | |
| tr.store(ttl.transpose(eq)) | |
| with tr_cb.wait() as trv, bb_cb.reserve() as bb: | |
| bb.store(ttl.math.broadcast(trv, dims=[0])) | |
| with ba_cb.wait() as bav, bb_cb.wait() as bbv, qq_cb.reserve() as qqo: | |
| qqo.store(bav * bbv) | |
| # erfc-damped Coulomb + LJ forces | |
| with r2_cb.wait() as r2, qq_cb.wait() as qq: | |
| r_inv = ttl.math.rsqrt(r2) | |
| r2_inv = ttl.math.recip(r2) | |
| r_val = r2 * r_inv | |
| t = ttl.math.recip(r_inv * r_inv * r2 + ttl.math.fill(r2, c_p_alpha) * r_val) | |
| poly = t * (ttl.math.fill(r2, c_a1) + t * (ttl.math.neg(ttl.math.fill(r2, c_a2)) + t * (ttl.math.fill(r2, c_a3) + t * (ttl.math.neg(ttl.math.fill(r2, c_a4)) + t * ttl.math.fill(r2, c_a5))))) | |
| exp_neg = ttl.math.exp(ttl.math.neg(ttl.math.fill(r2, c_alpha_sq) * r2)) | |
| erfc_val = poly * exp_neg | |
| with ft_cb.reserve() as coul: | |
| coul.store(qq * (erfc_val * r2_inv + ttl.math.fill(r2, c_two_a_sp) * exp_neg * r_inv) * r_inv) | |
| r2_inv2 = ttl.math.recip(r2) | |
| r4_inv = r2_inv2 * r2_inv2 | |
| r6_inv = r4_inv * r2_inv2 | |
| r12_inv = r6_inv * r6_inv | |
| with fr_cb.reserve() as lj: | |
| lj.store(ttl.math.fill(r2, c_lj_scale) * r2_inv2 * (r12_inv + r12_inv - r6_inv)) | |
| with ft_cb.wait() as fc, fr_cb.wait() as fl: | |
| with fm_cb.reserve() as fmo: | |
| fmo.store(fl + fc) | |
| # Project onto displacements, reduce, accumulate | |
| with fm_cb.wait() as fm: | |
| with dx_cb.wait() as dxv: | |
| with ft_cb.reserve() as ft: | |
| ft.store(fm * dxv) | |
| with ft_cb.wait() as ftv, fr_cb.reserve() as fr: | |
| fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0])) | |
| if nbr_i == 0: | |
| with fr_cb.wait() as frv, ax_cb.reserve() as ax: | |
| ax.store(frv) | |
| else: | |
| with fr_cb.wait() as frv, ax_cb.wait() as prev: | |
| with ax_cb.reserve() as ax: | |
| ax.store(prev + frv) | |
| with dy_cb.wait() as dyv: | |
| with ft_cb.reserve() as ft: | |
| ft.store(fm * dyv) | |
| with ft_cb.wait() as ftv, fr_cb.reserve() as fr: | |
| fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0])) | |
| if nbr_i == 0: | |
| with fr_cb.wait() as frv, ay_cb.reserve() as ay: | |
| ay.store(frv) | |
| else: | |
| with fr_cb.wait() as frv, ay_cb.wait() as prev: | |
| with ay_cb.reserve() as ay: | |
| ay.store(prev + frv) | |
| with dz_cb.wait() as dzv: | |
| with ft_cb.reserve() as ft: | |
| ft.store(fm * dzv) | |
| with ft_cb.wait() as ftv, fr_cb.reserve() as fr: | |
| fr.store(ttl.math.reduce_sum(ftv, sc, dims=[0])) | |
| if nbr_i == 0: | |
| with fr_cb.wait() as frv, az_cb.reserve() as az: | |
| az.store(frv) | |
| else: | |
| with fr_cb.wait() as frv, az_cb.wait() as prev: | |
| with az_cb.reserve() as az: | |
| az.store(prev + frv) | |
| with ax_cb.wait() as fx, fxo_cb.reserve() as fxo: | |
| fxo.store(fx) | |
| with ay_cb.wait() as fy, fyo_cb.reserve() as fyo: | |
| fyo.store(fy) | |
| with az_cb.wait() as fz, fzo_cb.reserve() as fzo: | |
| fzo.store(fz) | |
| @ttl.datamovement() | |
| def dm_read(): | |
| core_x, _ = ttl.core(dims=2) | |
| for local_c in range(cells_per_core): | |
| cell_id = core_x * cells_per_core + local_c | |
| if cell_id < n_cells: | |
| with ox_cb.reserve() as blk: | |
| tx = ttl.copy(own_px[cell_id, 0], blk); tx.wait() | |
| with oy_cb.reserve() as blk: | |
| tx = ttl.copy(own_py[cell_id, 0], blk); tx.wait() | |
| with oz_cb.reserve() as blk: | |
| tx = ttl.copy(own_pz[cell_id, 0], blk); tx.wait() | |
| with oq_cb.reserve() as blk: | |
| tx = ttl.copy(own_q[cell_id, 0], blk); tx.wait() | |
| with sc_cb.reserve() as blk: | |
| tx = ttl.copy(scaler[0, 0], blk); tx.wait() | |
| # Compute neighbor cell IDs on-device | |
| cx = cell_id // c_dim2 | |
| cy = (cell_id // c_n_dim) % c_n_dim | |
| cz = cell_id % c_n_dim | |
| for nbr in range(N_NBR): | |
| off_dx = (nbr // 9) - 1 | |
| off_dy = ((nbr // 3) % 3) - 1 | |
| off_dz = (nbr % 3) - 1 | |
| nbr_cell = ((cx + off_dx + c_n_dim) % c_n_dim) * c_dim2 + ((cy + off_dy + c_n_dim) % c_n_dim) * c_n_dim + ((cz + off_dz + c_n_dim) % c_n_dim) | |
| with ex_cb.reserve() as blk: | |
| tx = ttl.copy(own_px[nbr_cell, 0], blk); tx.wait() | |
| with ey_cb.reserve() as blk: | |
| tx = ttl.copy(own_py[nbr_cell, 0], blk); tx.wait() | |
| with ez_cb.reserve() as blk: | |
| tx = ttl.copy(own_pz[nbr_cell, 0], blk); tx.wait() | |
| with eq_cb.reserve() as blk: | |
| tx = ttl.copy(own_q[nbr_cell, 0], blk); tx.wait() | |
| with sm_cb.reserve() as blk: | |
| tx = ttl.copy(self_mask[cell_id * N_NBR + nbr, 0], blk); tx.wait() | |
| @ttl.datamovement() | |
| def dm_write(): | |
| core_x, _ = ttl.core(dims=2) | |
| for local_c in range(cells_per_core): | |
| cell_id = core_x * cells_per_core + local_c | |
| if cell_id < n_cells: | |
| with fxo_cb.wait() as blk: | |
| tx = ttl.copy(blk, fx_out[cell_id, 0]); tx.wait() | |
| with fyo_cb.wait() as blk: | |
| tx = ttl.copy(blk, fy_out[cell_id, 0]); tx.wait() | |
| with fzo_cb.wait() as blk: | |
| tx = ttl.copy(blk, fz_out[cell_id, 0]); tx.wait() | |
| return cell_forces_kernel | |
| # --- Tensor helpers --- | |
| def to_bf16(arr, l1=False): | |
| return ttnn.from_torch( | |
| torch.tensor(arr, dtype=torch.bfloat16), | |
| dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, | |
| device=device, | |
| memory_config=ttnn.L1_MEMORY_CONFIG if l1 else ttnn.DRAM_MEMORY_CONFIG) | |
| def to_f32(arr, l1=False): | |
| return ttnn.from_torch( | |
| torch.tensor(arr, dtype=torch.float32), | |
| dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, | |
| device=device, | |
| memory_config=ttnn.L1_MEMORY_CONFIG if l1 else ttnn.DRAM_MEMORY_CONFIG) | |
| def run_force_kernel(tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler, | |
| tt_fx, tt_fy, tt_fz): | |
| """Typecast f32 <-> bf16 around the bf16 force kernel.""" | |
| px_bf16 = ttnn.typecast(tt_px, ttnn.bfloat16) | |
| py_bf16 = ttnn.typecast(tt_py, ttnn.bfloat16) | |
| pz_bf16 = ttnn.typecast(tt_pz, ttnn.bfloat16) | |
| fx_bf16 = ttnn.typecast(tt_fx, ttnn.bfloat16) | |
| fy_bf16 = ttnn.typecast(tt_fy, ttnn.bfloat16) | |
| fz_bf16 = ttnn.typecast(tt_fz, ttnn.bfloat16) | |
| cell_forces_kernel(px_bf16, py_bf16, pz_bf16, tt_q, | |
| tt_masks, tt_scaler, fx_bf16, fy_bf16, fz_bf16) | |
| return (ttnn.typecast(fx_bf16, ttnn.float32), | |
| ttnn.typecast(fy_bf16, ttnn.float32), | |
| ttnn.typecast(fz_bf16, ttnn.float32)) | |
| # --- Build initial state --- | |
| (own_px, own_py, own_pz, own_q, | |
| masks, cell_atom_map, n_cells_total, n_cells_dim) = \ | |
| build_cell_data(positions, charges, box_length, r_cut) | |
| c_n_dim = int(n_cells_dim) | |
| c_dim2 = c_n_dim * c_n_dim | |
| vel_x, vel_y, vel_z = pack_cell_layout(velocities, cell_atom_map, n_cells_total) | |
| # Reciprocal forces (recomputed on each rebuild, held constant between) | |
| f_recip = compute_reciprocal_forces(device, positions, charges, box_length, alpha) | |
| recip_x, recip_y, recip_z = pack_cell_layout(f_recip, cell_atom_map, n_cells_total) | |
| # L1 for small systems, DRAM for large | |
| cell_bytes_f32 = n_cells_total * TILE * TILE * 4 | |
| use_l1 = (cell_bytes_f32 * 10 < 80_000_000) | |
| tt_px = to_f32(own_px, l1=use_l1) | |
| tt_py = to_f32(own_py, l1=use_l1) | |
| tt_pz = to_f32(own_pz, l1=use_l1) | |
| tt_q = to_bf16(own_q, l1=use_l1) | |
| tt_vx = to_f32(vel_x, l1=use_l1) | |
| tt_vy = to_f32(vel_y, l1=use_l1) | |
| tt_vz = to_f32(vel_z, l1=use_l1) | |
| tt_masks = to_bf16(masks) | |
| tt_scaler = to_bf16(np.ones((TILE, TILE), dtype=np.float32), l1=True) | |
| zeros = np.zeros((n_cells_total * TILE, TILE), dtype=np.float32) | |
| tt_fx = to_f32(zeros, l1=use_l1) | |
| tt_fy = to_f32(zeros.copy(), l1=use_l1) | |
| tt_fz = to_f32(zeros.copy(), l1=use_l1) | |
| # Reciprocal forces stored on device (f32, added after real-space) | |
| tt_rx = to_f32(recip_x, l1=use_l1) | |
| tt_ry = to_f32(recip_y, l1=use_l1) | |
| tt_rz = to_f32(recip_z, l1=use_l1) | |
| cell_forces_kernel = make_force_kernel(c_n_dim, c_dim2) | |
| # Initial forces (real-space + reciprocal) | |
| tt_fx, tt_fy, tt_fz = run_force_kernel( | |
| tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler, tt_fx, tt_fy, tt_fz) | |
| tt_fx = ttnn.add(tt_fx, tt_rx) | |
| tt_fy = ttnn.add(tt_fy, tt_ry) | |
| tt_fz = ttnn.add(tt_fz, tt_rz) | |
| # --- Step loop --- | |
| step_times = [] | |
| for step in range(n_steps): | |
| t0 = time.time() | |
| if (step > 0) and (step % rebuild_every == 0): | |
| px_np = ttnn.to_torch(tt_px).float().numpy() | |
| py_np = ttnn.to_torch(tt_py).float().numpy() | |
| pz_np = ttnn.to_torch(tt_pz).float().numpy() | |
| vx_np = ttnn.to_torch(tt_vx).float().numpy() | |
| vy_np = ttnn.to_torch(tt_vy).float().numpy() | |
| vz_np = ttnn.to_torch(tt_vz).float().numpy() | |
| fx_np = ttnn.to_torch(tt_fx).float().numpy() | |
| fy_np = ttnn.to_torch(tt_fy).float().numpy() | |
| fz_np = ttnn.to_torch(tt_fz).float().numpy() | |
| positions = extract_cell_data(px_np, py_np, pz_np, cell_atom_map, n) | |
| velocities = extract_cell_data(vx_np, vy_np, vz_np, cell_atom_map, n) | |
| forces_atom = extract_cell_data(fx_np, fy_np, fz_np, cell_atom_map, n) | |
| (own_px, own_py, own_pz, own_q, | |
| masks, cell_atom_map, n_cells_total, n_cells_dim) = \ | |
| build_cell_data(positions, charges, box_length, r_cut) | |
| c_n_dim_new = int(n_cells_dim) | |
| if c_n_dim_new != c_n_dim: | |
| c_n_dim = c_n_dim_new | |
| c_dim2 = c_n_dim * c_n_dim | |
| cell_forces_kernel = make_force_kernel(c_n_dim, c_dim2) | |
| vel_x, vel_y, vel_z = pack_cell_layout(velocities, cell_atom_map, n_cells_total) | |
| f_x, f_y, f_z = pack_cell_layout(forces_atom, cell_atom_map, n_cells_total) | |
| # Recompute reciprocal forces at new positions | |
| f_recip = compute_reciprocal_forces(device, positions, charges, box_length, alpha) | |
| recip_x, recip_y, recip_z = pack_cell_layout(f_recip, cell_atom_map, n_cells_total) | |
| tt_px = to_f32(own_px, l1=use_l1) | |
| tt_py = to_f32(own_py, l1=use_l1) | |
| tt_pz = to_f32(own_pz, l1=use_l1) | |
| tt_q = to_bf16(own_q, l1=use_l1) | |
| tt_vx = to_f32(vel_x, l1=use_l1) | |
| tt_vy = to_f32(vel_y, l1=use_l1) | |
| tt_vz = to_f32(vel_z, l1=use_l1) | |
| tt_masks = to_bf16(masks) | |
| tt_fx = to_f32(f_x, l1=use_l1) | |
| tt_fy = to_f32(f_y, l1=use_l1) | |
| tt_fz = to_f32(f_z, l1=use_l1) | |
| tt_rx = to_f32(recip_x, l1=use_l1) | |
| tt_ry = to_f32(recip_y, l1=use_l1) | |
| tt_rz = to_f32(recip_z, l1=use_l1) | |
| vel_half_kernel(tt_vx, tt_vy, tt_vz, tt_fx, tt_fy, tt_fz) | |
| pos_update_kernel(tt_px, tt_py, tt_pz, tt_vx, tt_vy, tt_vz) | |
| # Real-space forces (bf16 kernel) + reciprocal correction (on-device add) | |
| tt_fx, tt_fy, tt_fz = run_force_kernel( | |
| tt_px, tt_py, tt_pz, tt_q, tt_masks, tt_scaler, tt_fx, tt_fy, tt_fz) | |
| tt_fx = ttnn.add(tt_fx, tt_rx) | |
| tt_fy = ttnn.add(tt_fy, tt_ry) | |
| tt_fz = ttnn.add(tt_fz, tt_rz) | |
| vel_half_kernel(tt_vx, tt_vy, tt_vz, tt_fx, tt_fy, tt_fz) | |
| step_times.append(time.time() - t0) | |
| # Read final state | |
| px_np = ttnn.to_torch(tt_px).float().numpy() | |
| py_np = ttnn.to_torch(tt_py).float().numpy() | |
| pz_np = ttnn.to_torch(tt_pz).float().numpy() | |
| vx_np = ttnn.to_torch(tt_vx).float().numpy() | |
| vy_np = ttnn.to_torch(tt_vy).float().numpy() | |
| vz_np = ttnn.to_torch(tt_vz).float().numpy() | |
| final_pos = extract_cell_data(px_np, py_np, pz_np, cell_atom_map, n) | |
| final_vel = extract_cell_data(vx_np, vy_np, vz_np, cell_atom_map, n) | |
| return final_pos, final_vel, step_times | |
| # ---- Main: validation + sample run ---- | |
if __name__ == "__main__":
    # Script entry point: opens device 0, runs a small energy-conservation
    # check and a larger timing run, and prints the results.
    device = ttnn.open_device(device_id=0)
    try:
        # 1. Energy conservation (256 atoms, 10 steps)
        print("=" * 50)
        print("ENERGY CONSERVATION (256 atoms, 10 steps)")
        print("=" * 50)
        positions, charges, box_length = make_system(256)
        velocities = np.random.randn(256, 3) * 0.1
        # Remove the mean velocity per axis so the system has no net drift.
        velocities -= velocities.mean(axis=0)
        ke0 = 0.5 * np.sum(velocities ** 2)
        pe0 = compute_energy(positions, charges, box_length)
        e0 = ke0 + pe0
        # Copies are passed so the initial arrays used for e0 stay untouched.
        final_pos, final_vel, step_times = run_md_loop(
            device, positions.copy(), velocities.copy(), charges, box_length,
            n_steps=10, dt=0.005, rebuild_every=10, alpha=ALPHA)
        ke = 0.5 * np.sum(final_vel ** 2)
        pe = compute_energy(final_pos, charges, box_length)
        # Relative total-energy change, in percent of the initial energy.
        drift = abs((ke + pe) - e0) / abs(e0) * 100
        print(f" E0={e0:.2f} E_final={ke+pe:.2f} drift={drift:.4f}%")
        print(f" Step times: {[f'{t:.3f}s' for t in step_times]}")
        print(f" {'PASS' if drift < 5.0 else 'FAIL'}")

        # 2. Performance at scale (10K atoms, 100 steps, no rebuilds)
        print("\n" + "=" * 50)
        print("PERFORMANCE (10K atoms, 100 steps)")
        print("=" * 50)
        positions, charges, box_length = make_system(10000)
        velocities = np.random.randn(10000, 3) * 0.1
        velocities -= velocities.mean(axis=0)
        t0 = time.time()
        final_pos, final_vel, step_times = run_md_loop(
            device, positions.copy(), velocities.copy(), charges, box_length,
            n_steps=100, dt=0.005, rebuild_every=100, alpha=ALPHA)
        t_total = time.time() - t0
        # The first step is reported separately, so it is left out of the average.
        avg_no_first = np.mean(step_times[1:])
        print(f" Total: {t_total:.1f}s")
        print(f" First step (incl compilation): {step_times[0]:.3f}s")
        print(f" Avg step (excl first): {avg_no_first*1000:.2f}ms")
        print(f" Last 5: {[f'{t:.3f}' for t in step_times[-5:]]}")
    finally:
        # Always release the device, even when one of the runs above raises;
        # previously an exception skipped close_device and left it open.
        ttnn.close_device(device)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment