Simplifying the Contrastive Flow Matching Objective
We show that the CFM (Contrastive Flow Matching) objective is fundamentally no different from the FM objective: it amounts to an affine transformation of the target velocity, which can be applied post hoc.
Assumptions:
- $\alpha_t = 1 - t,\ \sigma_t = t \Rightarrow v = -x_i + \epsilon_i,\ \tilde v = -\tilde x + \tilde\epsilon$
- $(\tilde x, \tilde\epsilon)$ is drawn i.i.d. from the dataset, independently of $(x_i, \epsilon_i)$
- The sampled noise has zero mean: $\mathbb{E}[\tilde\epsilon] = 0$
- Let $\mu_x := \mathbb{E}[\tilde x]$ (the empirical average $\bar x$ in practice)
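For reference, here is the setup these assumptions describe, written out explicitly. This is my reading of the objective being analyzed (with contrastive weight $\lambda$, matching `lam` in the code below), not a quotation from the CFM paper:

$$x_t = (1 - t)\,x + t\,\epsilon, \qquad v = \frac{dx_t}{dt} = -x + \epsilon,$$

$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \big\|v_\theta(x_t, t) - v\big\|^2 \;-\; \lambda\,\big\|v_\theta(x_t, t) - \tilde v\big\|^2, \qquad \tilde v = -\tilde x + \tilde\epsilon.$$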
Lemma (Averaging targets)
The minimizer of an expected squared error depends only on the mean of the random target. Concretely, for any random target $Y$ and any (non-random) prediction $f$,

$$\mathbb{E}\,\|f - Y\|^2 = \|f - \mathbb{E}[Y]\|^2 + \mathbb{E}\,\|Y - \mathbb{E}[Y]\|^2,$$

so replacing $Y$ by its mean changes the loss only by an additive constant that does not depend on $f$.
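Applying the lemma to the contrastive term of the expected CFM loss (a sketch of the step the conclusion below relies on; $v_\theta$ denotes the model's prediction):

$$\mathbb{E}_{\tilde x,\tilde\epsilon}\Big[\|v_\theta - v\|^2 - \lambda\|v_\theta - \tilde v\|^2\Big] = \|v_\theta - v\|^2 - \lambda\,\|v_\theta - \mathbb{E}[\tilde v]\|^2 + \mathrm{const},$$

and since $\mathbb{E}[\tilde v] = -\mathbb{E}[\tilde x] + \mathbb{E}[\tilde\epsilon] = -\mu_x$, completing the square (for $\lambda < 1$) gives

$$\|v_\theta - v\|^2 - \lambda\,\|v_\theta + \mu_x\|^2 = (1-\lambda)\,\Big\|v_\theta - \frac{v + \lambda\mu_x}{1-\lambda}\Big\|^2 + \mathrm{const}.$$

So, up to a constant and the factor $(1-\lambda)$, the expected CFM loss is a plain FM regression onto the affinely shifted target $(v + \lambda\mu_x)/(1-\lambda)$.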
Notice that $\mu_x$, the average of the target distribution, is extremely simple to learn. Furthermore, it is common practice to center VAE latents, because trained VAE latents are not empirically $N(0, I_n)$.
Thus one can expect CFM to have no effect beyond a fixed affine rescaling of the velocity field.
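As a concrete illustration of how cheap this statistic is, estimating $\mu_x$ and centering the latents is a one-liner. A minimal sketch; the tensor names here are placeholders, not from any specific codebase:

```python
import torch

# hypothetical stand-in for a batch of pre-computed VAE latents, shape (n, d)
latents = torch.randn(10000, 4) * 0.7 + torch.tensor([0.3, -1.2, 0.5, 0.0])

mu_x = latents.mean(dim=0)      # empirical mean of the target distribution
centered = latents - mu_x       # "centered" latents: now approximately zero-mean

print(mu_x)                     # the statistic CFM implicitly bakes into its target
print(centered.mean(dim=0))     # ~0
```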
That is, CFM is essentially 'identical' to FM in the sense that it merely remaps the regression target $v \Rightarrow (v + \lambda \mu_x)/(1 - \lambda)$. If you doubt the math, here is a simple follow-up example showing that the two objectives are indeed equivalent.
```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

d = 5        # dimension of v
n = 10000    # number of samples
lam = 0.3

# dataset: x_i and eps_i
center = torch.randn(d)
x = torch.randn(n, d) * 0.8 + center
eps = torch.randn(n, d)

# contrastive pool: x_tilde, eps_tilde. They are shifted by 1 step, which is the same as independent sampling.
# if you are paranoid, just resample, but use a larger n.
x_tilde = x.roll(1, dims=0)
eps_tilde = eps.roll(1, dims=0)

# empirical averages
mu_x = x_tilde.mean(0)  # dataset mean
# eps mean is ~0, so we skip it

# -----------------------------
# Targets
# -----------------------------
# Original v, v_tilde
v = -x + eps
v_tilde = -x_tilde + eps_tilde

# Expected target y_i (simplified: an affine transformation of v)
y = (v + lam * mu_x) / (1 - lam)

# -----------------------------
# Model parameter: a single learnable vector
# -----------------------------
param = nn.Parameter(torch.randn(d))

# -----------------------------
# Optimizers
# -----------------------------
opt1 = optim.SGD([param], lr=0.1)

# Train with the simplified objective
for step in range(300):
    opt1.zero_grad()
    loss = ((param - y) ** 2).mean()  # simplified form
    loss.backward()
    opt1.step()

print("Optimized parameter (simplified):", param.data)

# -----------------------------
# Now check with the original objective
# -----------------------------
param2 = nn.Parameter(torch.randn(d))
opt2 = optim.SGD([param2], lr=0.1)

for step in range(300):
    opt2.zero_grad()
    loss = ((param2 - v) ** 2).mean() - lam * ((param2 - v_tilde) ** 2).mean()
    loss.backward()
    opt2.step()

print("Optimized parameter (original): ", param2.data)

assert torch.allclose(param2.data, param.data, atol=1e-1)

# Optimized parameter (simplified): tensor([-1.5306, 0.3296, 2.1793, -0.5896, 1.1025])
# Optimized parameter (original):  tensor([-1.5314, 0.3228, 2.1817, -0.5823, 1.0969])
```
As the assert confirms, the two objectives are indeed identical: both optimizations converge to the same parameter, up to finite-sample noise.
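Since the two objectives differ only by this fixed affine map, the CFM-style prediction can also be recovered from a plain FM model after training, which is what "applied post hoc" means above. A minimal sketch, assuming an already-trained FM velocity predictor; the names `v_fm`, `lam`, and `mu_x` are placeholders, not from any particular codebase:

```python
import torch

def cfm_equivalent_velocity(v_fm, lam: float, mu_x: torch.Tensor):
    """Wrap an FM velocity predictor so it outputs the CFM target (v + lam * mu_x) / (1 - lam)."""
    def wrapped(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return (v_fm(x_t, t) + lam * mu_x) / (1.0 - lam)
    return wrapped

# usage sketch with a dummy predictor standing in for a trained FM model
d = 5
mu_x = torch.zeros(d)                 # dataset mean of x (estimated from data in practice)
v_fm = lambda x_t, t: -x_t            # placeholder velocity predictor
v_cfm = cfm_equivalent_velocity(v_fm, lam=0.3, mu_x=mu_x)
print(v_cfm(torch.randn(2, d), torch.rand(2)))
```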