Created
February 27, 2026 21:28
-
-
Save pengzhangzhi/a99340534b0e4293156a1d2b37ba0016 to your computer and use it in GitHub Desktop.
PAPL Loss Pytorch Implementation Demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def papl_loss(
    logits: torch.Tensor,
    x0: torch.Tensor,
    mask: torch.Tensor,
    alpha: float,
    tau: float,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the PAPL objective from Eq. (7).

    Args:
        logits: Model logits over the vocabulary, shape [B, L, V].
        x0: Clean (ground-truth) token ids, shape [B, L].
        mask: Per-position mask, shape [B, L]; True (or nonzero) marks
            positions that were corrupted and should contribute to the loss.
            Bool, int, and 0/1 float masks are all accepted.
        alpha: Strength of the planner re-weighting term.
        tau: Temperature of the detached planner softmax.
        eps: Numerical floor used only in the planner-entropy metric.

    Returns:
        A ``(loss, metrics)`` pair: the scalar PAPL loss and a dict of floats
        with the loss value, mean masked count, mean planner entropy, and the
        mean probability assigned to the correct token on masked positions.
    """
    # `~mask` below requires a bool tensor; normalize once up front so that
    # int/float 0-1 masks also work (no-op for bool input).
    mask = mask.bool()
    log_probs = F.log_softmax(logits, dim=-1)  # [B, L, V]
    # Log-probability of the correct clean token at each position.
    target_log_probs = log_probs.gather(dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)  # [B, L]
    target_nll = -target_log_probs
    # Only masked positions contribute to the loss.
    masked_nll = target_nll * mask
    # Detached planner: w_i ∝ exp(log p(correct token)/tau), normalized over
    # masked positions. detach() stops gradients from flowing through the
    # planner weights themselves.
    detached_scores = (target_log_probs.detach() / tau).masked_fill(~mask, float("-inf"))
    planner_weights = F.softmax(detached_scores, dim=-1)
    # A row with no masked positions makes the softmax NaN everywhere;
    # zeroing the non-masked entries also clears those NaNs.
    planner_weights = torch.where(mask, planner_weights, torch.zeros_like(planner_weights))
    # Per-example number of masked positions, matching 1 / (L-k) in the paper.
    # clamp_min(1) guards the division for fully-unmasked rows.
    n_masked = mask.sum(dim=-1).clamp_min(1).float()  # [B]
    base_weight = (1.0 / n_masked).unsqueeze(-1)  # [B, 1]
    weights = base_weight * (1.0 + alpha * planner_weights)
    loss_per_example = (weights * masked_nll).sum(dim=-1)
    loss = loss_per_example.mean()
    with torch.no_grad():
        metrics = {
            "loss": float(loss.item()),
            "avg_n_masked": float(n_masked.mean().item()),
            "avg_planner_entropy": float(
                (-(planner_weights.clamp_min(eps) * planner_weights.clamp_min(eps).log()).sum(dim=-1)).mean().item()
            ),
            "avg_correct_prob_on_masked": float(target_log_probs.exp()[mask].mean().item()) if mask.any() else 0.0,
        }
    return loss, metrics
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code generation experiment is implemented in https://github.com/pengzhangzhi/Open-dLLM