@Codys12
Created July 7, 2025 19:42
```python
import numpy as np

# ------------------- quantisation helpers (BitNet-style) ---------------------
def activation_quant(x):
    # 8-bit absmax quantisation along the last axis, then dequantise.
    scale = 127.0 / np.maximum(np.max(np.abs(x), axis=-1, keepdims=True), 1e-5)
    return np.round(x * scale).clip(-128, 127) / scale

def weight_quant(w):
    # Ternary {-1, 0, 1} absmean quantisation (one scale per tensor), then dequantise.
    scale = 1.0 / np.maximum(np.mean(np.abs(w)), 1e-5)
    return np.round(w * scale).clip(-1, 1) / scale

# ------------------- 3-layer fully-connected network --------------------------
np.random.seed(0)

batch, d_in, d_h1, d_h2, d_out = 2, 7, 6, 5, 4

# latent full-precision weights
W1_fp = np.random.randn(d_h1, d_in)
W2_fp = np.random.randn(d_h2, d_h1)
W3_fp = np.random.randn(d_out, d_h2)

# quantised weights used in forward path
Q1 = weight_quant(W1_fp)
Q2 = weight_quant(W2_fp)
Q3 = weight_quant(W3_fp)

# random input and target
x0      = np.random.randn(batch, d_in)
target  = np.random.randn(batch, d_out)

# ---------------- forward pass (uses only quantised weights) ------------------
a0 = activation_quant(x0)               # layer-0 activation
h1 = a0 @ Q1.T
a1 = activation_quant(h1)

h2 = a1 @ Q2.T
a2 = activation_quant(h2)

y  = a2 @ Q3.T                          # network output
loss = 0.5 * np.sum((y - target)**2)

# ---------------- backward pass: compute grads wrt quantised weights ----------
delta3 = y - target                     # shape (batch, d_out)
grad_Q3 = delta3.T @ a2                 # (d_out, d_h2)

delta2 = delta3 @ Q3                    # (batch, d_h2)
# identity-STE through activation_quant ⇒ its derivative is taken as 1
grad_Q2 = delta2.T @ a1                 # (d_h2, d_h1)

delta1 = delta2 @ Q2                    # (batch, d_h1)
grad_Q1 = delta1.T @ a0                 # (d_h1, d_in)

# ---------------- identity-STE: treat W_fp as the parameters -----------------
# The STE sets dQ/dW_fp to the identity, so the latent-weight gradients are
# the same arrays as grad_Q*; the sign check below holds by construction.
grad_W1_fp = grad_Q1.copy()
grad_W2_fp = grad_Q2.copy()
grad_W3_fp = grad_Q3.copy()

# ---------------- verify sign equality for every element ---------------------
def sign_eq(a, b):                       # element-wise sign comparison
    return np.all(np.sign(a) == np.sign(b))

print("Layer-1 sign equality:", sign_eq(grad_W1_fp, grad_Q1))
print("Layer-2 sign equality:", sign_eq(grad_W2_fp, grad_Q2))
print("Layer-3 sign equality:", sign_eq(grad_W3_fp, grad_Q3))

# show a few rows for illustration
print("\nSample signs (Layer-2, first 3 rows):")
print("grad_W2_fp signs:\n", np.sign(grad_W2_fp[:3]))
print("grad_Q2   signs:\n", np.sign(grad_Q2[:3]))
```

Multilayer check: 3-layer fully connected net

The script above builds a tiny 3-layer network that uses the same activation- and weight-quantisation functions you provided. It runs the forward pass with quantised values only, then computes two sets of gradients analytically:

  • grad_Qi: the quantised weight $Q_i$ itself is treated as the parameter, with no weight STE, and
  • grad_Wi_fp: the latent full-precision weight $W_i$ is treated as the parameter, with the identity-STE (Jacobian = 1).

In both cases the backward pass goes through activation_quant with the identity-STE, so the two differ only in how the weight quantiser is handled.

Running the script prints:

```text
Layer-1 sign equality: True
Layer-2 sign equality: True
Layer-3 sign equality: True

Sample signs (Layer-2, first 3 rows):
grad_W2_fp signs:
 [[-1.  1.  1. -1.  1.  1.]
 [-1.  1.  1. -1.  1.  1.]
 [-1.  1.  1. -1.  1.  1.]]
grad_Q2   signs:
 [[-1.  1.  1. -1.  1.  1.]
 [-1.  1.  1. -1.  1.  1.]
 [-1.  1.  1. -1.  1.  1.]]
```

For every single weight element in all three layers, the sign of the STE-based gradient equals the sign of the gradient you would obtain if you optimised the quantised weight itself.
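In the script, grad_Wi_fp is defined as a copy of grad_Qi, so the sign check cannot catch a mistake in the analytic backward pass itself. One independent sanity check, sketched below under stated assumptions, compares the last-layer gradient (same formula as grad_Q3) against central finite differences. It only works for the last layer: nothing is quantised after Q3, so the loss is a smooth quadratic in it. For earlier layers activation_quant is a rounding staircase, and finite differences would not recover the identity-STE gradient. The seed and shapes are illustrative and do not reproduce the script's random draws.

```python
import numpy as np

# Same quantisation helpers as in the script.
def activation_quant(x):
    scale = 127.0 / np.maximum(np.max(np.abs(x), axis=-1, keepdims=True), 1e-5)
    return np.round(x * scale).clip(-128, 127) / scale

def weight_quant(w):
    scale = 1.0 / np.maximum(np.mean(np.abs(w)), 1e-5)
    return np.round(w * scale).clip(-1, 1) / scale

rng = np.random.default_rng(0)                 # illustrative seed, not the script's draw
batch, d_h2, d_out = 2, 5, 4
a2 = activation_quant(rng.standard_normal((batch, d_h2)))  # quantised input to last layer
Q3 = weight_quant(rng.standard_normal((d_out, d_h2)))
target = rng.standard_normal((batch, d_out))

def loss_fn(Q):
    # Last layer only: linear in Q, squared-error loss, no quantisation after it.
    y = a2 @ Q.T
    return 0.5 * np.sum((y - target) ** 2)

# Analytic gradient, same formula as grad_Q3 = delta3.T @ a2 in the script.
grad_analytic = (a2 @ Q3.T - target).T @ a2

# Central finite differences over every entry of Q3 (exact for a quadratic,
# up to floating-point rounding).
eps = 1e-4
grad_fd = np.zeros_like(Q3)
for idx in np.ndindex(*Q3.shape):
    e = np.zeros_like(Q3)
    e[idx] = eps
    grad_fd[idx] = (loss_fn(Q3 + e) - loss_fn(Q3 - e)) / (2 * eps)

print("max |analytic - finite diff|:", np.max(np.abs(grad_analytic - grad_fd)))
print("allclose:", np.allclose(grad_analytic, grad_fd))
```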


Why this always holds (independent of depth)

For any layer $i$:

$$\frac{\partial \mathcal L}{\partial W^{(i)}_{\text{fp}}} = \frac{\partial \mathcal L}{\partial Q^{(i)}} \frac{\partial Q^{(i)}}{\partial W^{(i)}_{\text{fp}}} \overset{\text{identity-STE}}{=} \frac{\partial \mathcal L}{\partial Q^{(i)}}.$$

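Under the identity-STE, the local Jacobian $\partial Q^{(i)}/\partial W^{(i)}_{\text{fp}}$ is replaced by the identity: each entry of $Q^{(i)}$ is treated as having derivative 1 with respect to the matching entry of $W^{(i)}_{\text{fp}}$ and 0 with respect to every other entry. The two gradients are therefore identical, not merely equal in sign. Depth does not matter: all upstream derivatives are computed along the same quantised forward path, so the equality holds at every layer of the chain rule.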
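The sign argument does not strictly need the Jacobian to be exactly 1. If each entry of $Q^{(i)}$ depends only on the matching entry of $W^{(i)}_{\text{fp}}$ with a strictly positive derivative $J$, then $\operatorname{sign}(g \cdot J) = \operatorname{sign}(g)$ element-wise. A derivative of exactly 0 breaks this, because the latent gradient becomes 0 where the quantised one is not. The NumPy sketch below checks all three cases on a random stand-in gradient; the non-identity Jacobians are hypothetical and are not meant to model a specific STE variant.

```python
import numpy as np

def same_sign(a, b):
    """True if every element of a has the same sign as the matching element of b."""
    return bool(np.all(np.sign(a) == np.sign(b)))

rng = np.random.default_rng(1)
grad_Q = rng.standard_normal((5, 6))        # stand-in for dL/dQ of one layer

# Identity-STE: every diagonal Jacobian entry is 1.
identity_J = np.ones_like(grad_Q)
grad_W_identity = grad_Q * identity_J

# Hypothetical: any strictly positive per-element factor.
positive_J = rng.uniform(0.1, 2.0, size=grad_Q.shape)
grad_W_positive = grad_Q * positive_J

# Hypothetical contrast: a Jacobian with roughly 30% zero entries.
mask_J = (rng.uniform(size=grad_Q.shape) > 0.3).astype(float)
grad_W_masked = grad_Q * mask_J

print("identity Jacobian  :", same_sign(grad_W_identity, grad_Q))   # True
print("positive Jacobian  :", same_sign(grad_W_positive, grad_Q))   # True
print("Jacobian with zeros:", same_sign(grad_W_masked, grad_Q))
```

Only the identity case is what the script uses; the other two show that "strictly positive", not "non-negative", is the property that preserves signs.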

Hence SignSGD (or any sign-based optimiser) sees the same direction whether you send back signs computed w.r.t. the quantised weights or w.r.t. the latent weights under the identity-STE, even in deep networks.
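A single SignSGD step under the identity-STE can be sketched as follows. This is a minimal sketch assuming a one-layer linear model with squared loss, plain SignSGD ($W \leftarrow W - \eta\,\operatorname{sign}(g)$) and a hypothetical learning rate; signsgd_step is an illustrative helper, not a library function. The gradient is computed from the quantised weights only, and its signs move the latent weights.

```python
import numpy as np

def weight_quant(w):
    # Same ternary absmean quantiser as in the script.
    scale = 1.0 / np.maximum(np.mean(np.abs(w)), 1e-5)
    return np.round(w * scale).clip(-1, 1) / scale

def signsgd_step(w_latent, grad, lr):
    """Plain SignSGD (hypothetical helper): move each latent weight by lr against its gradient sign."""
    return w_latent - lr * np.sign(grad)

rng = np.random.default_rng(2)
x = rng.standard_normal((8, 6))          # hypothetical input batch
target = rng.standard_normal((8, 3))
W_fp = rng.standard_normal((3, 6))       # latent full-precision weights
lr = 1e-2                                # hypothetical learning rate

# Forward and backward use only the quantised weights.
Q = weight_quant(W_fp)
y = x @ Q.T
grad_Q = (y - target).T @ x              # dL/dQ for L = 0.5 * ||y - target||^2

# Identity-STE: dL/dW_fp is grad_Q itself, so its signs drive the latent update.
W_new = signsgd_step(W_fp, grad_Q, lr)
Q_new = weight_quant(W_new)

print("each latent weight moved by lr:", np.allclose(np.abs(W_new - W_fp), lr))
# sign(Q) equals the ternary code because the per-tensor scale is positive.
print("ternary codes that changed:", int(np.sum(np.sign(Q_new) != np.sign(Q))))
```

Because each step moves a latent weight by only lr, a ternary code typically changes only after several same-direction steps carry the weight across a rounding threshold; the latent weights are where those small steps accumulate.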

The implication of this equality is that you do not need the master weights on the device to train a BitNet model with SignSGD: the device can hold only the 2-bit ternary representation of the weights (plus one per-tensor scale) and stream 1-bit gradient signs to a central location, where they are accumulated.
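A device/server round trip for one layer can be sketched as below. This is a minimal sketch under stated assumptions: a single linear layer with squared loss, plain SignSGD with a hypothetical learning rate, ternary codes packed four per byte, and gradient signs packed with np.packbits (a zero gradient is sent as negative). The helpers pack_ternary, unpack_ternary, pack_signs and unpack_signs are hypothetical names for illustration only.

```python
import numpy as np

SHIFTS = np.array([0, 2, 4, 6], dtype=np.uint8)   # bit offsets of 4 codes in one byte

def pack_ternary(codes):
    """Pack int8 codes in {-1, 0, 1} into 2 bits each, 4 codes per byte."""
    flat = (codes.ravel() + 1).astype(np.uint8)            # {-1, 0, 1} -> {0, 1, 2}
    flat = np.pad(flat, (0, (-flat.size) % 4))             # pad to a multiple of 4
    return np.bitwise_or.reduce(flat.reshape(-1, 4) << SHIFTS, axis=1).astype(np.uint8)

def unpack_ternary(packed, shape):
    """Inverse of pack_ternary: recover int8 codes in {-1, 0, 1} with the given shape."""
    flat = ((packed[:, None] >> SHIFTS) & 3).ravel()[: int(np.prod(shape))]
    return flat.astype(np.int8).reshape(shape) - 1

def pack_signs(grad):
    """1 bit per entry: 1 where grad > 0, else 0 (zero gradients count as negative)."""
    return np.packbits((grad > 0).ravel())

def unpack_signs(packed, shape):
    """Inverse of pack_signs: recover +1/-1 signs with the given shape."""
    bits = np.unpackbits(packed, count=int(np.prod(shape)))
    return (bits.astype(np.int8) * 2 - 1).reshape(shape)  # {0, 1} -> {-1, +1}

rng = np.random.default_rng(3)
W_fp = rng.standard_normal((4, 7))       # master weights, held only on the server
x = rng.standard_normal((2, 7))          # hypothetical device-side data
target = rng.standard_normal((2, 4))
lr = 1e-2                                # hypothetical learning rate

# Server: ternary codes plus one per-tensor scale, computed as in weight_quant.
scale = 1.0 / np.maximum(np.mean(np.abs(W_fp)), 1e-5)
codes = np.round(W_fp * scale).clip(-1, 1).astype(np.int8)
weight_payload = pack_ternary(codes)

# Device: rebuild Q from the payload and the scale, then compute dL/dQ.
# It never sees W_fp.
Q = unpack_ternary(weight_payload, W_fp.shape) / scale
grad_Q = (x @ Q.T - target).T @ x
sign_payload = pack_signs(grad_Q)

# Server: unpack the signs and apply SignSGD to the master weights.
W_fp_new = W_fp - lr * unpack_signs(sign_payload, W_fp.shape)

print(weight_payload.nbytes, "bytes of weights,", sign_payload.nbytes, "bytes of signs")
# 7 bytes of weights, 4 bytes of signs
```

For the 4 × 7 layer in the sketch, 28 weights travel as 7 bytes and 28 gradient signs as 4 bytes, with the scale sent alongside the weight codes.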
