Skip to content

Instantly share code, notes, and snippets.

@pengzhangzhi
Created February 27, 2026 21:28
Show Gist options
  • Select an option

  • Save pengzhangzhi/a99340534b0e4293156a1d2b37ba0016 to your computer and use it in GitHub Desktop.

Select an option

Save pengzhangzhi/a99340534b0e4293156a1d2b37ba0016 to your computer and use it in GitHub Desktop.
PAPL Loss PyTorch Implementation Demo
def papl_loss(
    logits: torch.Tensor,
    x0: torch.Tensor,
    mask: torch.Tensor,
    alpha: float,
    tau: float,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PAPL objective from Eq. (7).

    Args:
        logits: [B, L, V] unnormalized model outputs.
        x0: [B, L] clean (target) token ids, used as gather indices.
        mask: [B, L] boolean mask; True marks positions that contribute
            to the loss. (assumes a bool tensor — `~mask` is applied)
        alpha: strength of the planner re-weighting term.
        tau: temperature for the detached planner softmax.
        eps: floor used only when computing the entropy metric.

    Returns:
        A tuple of (scalar loss tensor, dict of float metrics).
    """
    # Per-position log-likelihood of the correct clean token.
    logp = F.log_softmax(logits, dim=-1)  # [B, L, V]
    logp_tok = logp.gather(dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)  # [B, L]
    nll = -logp_tok
    # Zero the contribution of unmasked positions.
    nll_masked = nll * mask

    # Detached planner: w_i ∝ exp(log p(correct token)/tau) over masked slots.
    scores = (logp_tok.detach() / tau).masked_fill(~mask, float("-inf"))
    w_plan = F.softmax(scores, dim=-1)
    # A row with no masked position softmaxes all -inf to NaN; zero it out.
    w_plan = torch.where(mask, w_plan, torch.zeros_like(w_plan))

    # Uniform base weight 1/(L-k): one over the per-example masked count
    # (clamped so empty-mask rows do not divide by zero).
    k = mask.sum(dim=-1).clamp_min(1).float()  # [B]
    uniform = (1.0 / k).unsqueeze(-1)  # [B, 1]
    combined = uniform * (1.0 + alpha * w_plan)

    per_example = (combined * nll_masked).sum(dim=-1)
    loss = per_example.mean()

    with torch.no_grad():
        # eps floor keeps 0 * log(0) out of the entropy metric.
        p = w_plan.clamp_min(eps)
        metrics = {
            "loss": float(loss.item()),
            "avg_n_masked": float(k.mean().item()),
            "avg_planner_entropy": float(
                (-(p * p.log()).sum(dim=-1)).mean().item()
            ),
            "avg_correct_prob_on_masked": (
                float(logp_tok.exp()[mask].mean().item()) if mask.any() else 0.0
            ),
        }
    return loss, metrics
@pengzhangzhi
Copy link
Author

The code generation experiment is implemented in https://github.com/pengzhangzhi/Open-dLLM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment