#!/usr/bin/env python
# coding: utf-8
# %% [markdown]
# # Bilevel Optimization for Hyperparameter Tuning
#
# ## The Problem: Ridge Regression
#
# We consider the problem of fitting a linear model on noisy training data.
# In order to prevent overfitting, we use regularization to control
# the complexity of the model. The strength of this regularization is
# a hyperparameter that needs to be tuned on external validation data.
#
# * **Outer Problem (Variable: $\lambda$):** Minimize the loss on a
# **validation set**.
# $$ \min_\lambda \mathcal{L}_{\text{val}}(w^*(\lambda)) =
# \frac{1}{n_{\text{val}}} \| X_{\text{val}} w^*(\lambda) - y_{\text{val}} \|^2 $$
#
# * **Inner Problem (Variables: $w$):** Find the optimal weights $w$ that
# minimize the **regularized training loss**.
# $$ w^*(\lambda) = \arg\min_w \mathcal{L}_{\text{train}}(w, \lambda) =
# \frac{1}{n_{\text{train}}} \| X_{\text{train}} w - y_{\text{train}} \|^2
# + \frac{\lambda}{2} \|w\|^2 $$
#
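# Because the inner problem is a ridge regression, it admits a closed-form
# solution. With the $1/n_{\text{train}}$ and $\lambda/2$ scaling used in the
# code below, setting the gradient of the training loss to zero gives
# $$ \frac{2}{n_{\text{train}}} X_{\text{train}}^\top (X_{\text{train}} w - y_{\text{train}}) + \lambda w = 0
# \;\Longrightarrow\;
# w^*(\lambda) = \Big( \tfrac{1}{n_{\text{train}}} X_{\text{train}}^\top X_{\text{train}} + \tfrac{\lambda}{2} I \Big)^{-1}
# \tfrac{1}{n_{\text{train}}} X_{\text{train}}^\top y_{\text{train}}. $$
# This is exactly the linear system assembled in `solve_inner_closed_form`
# below.
#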
# %%
# ## 1. Setup and Data Generation
import time  # To measure runtime
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(27)
# Define problem dimensions
n_features = 25
n_train = 100
n_val = 50
# Generate a "true" weight vector
w_true = torch.randn(n_features, 1) / math.sqrt(n_features)
# Generate training data (X_train, y_train)
X_train = torch.randn(n_train, n_features)
y_train = X_train @ w_true + 2 * torch.randn(n_train, 1)  # Add some noise
# Generate validation data (X_val, y_val)
X_val = torch.randn(n_val, n_features)
y_val = X_val @ w_true + 0.3 * torch.randn(n_val, 1)  # Add some noise
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"y_val shape: {y_val.shape}")
# %%
# ## 2. Define Loss Functions and Inner Solver
def train_loss(w, lmbd):
    """Computes the inner objective (training loss)."""
    mse = ((X_train @ w - y_train) ** 2).mean()
    reg = lmbd / 2 * (w ** 2).sum()
    return mse + reg
def val_loss(w):
    """Computes the outer objective (validation loss)."""
    return ((X_val @ w - y_val) ** 2).mean()
# Pre-compute parts of the closed-form solution
XTX = X_train.T @ X_train / n_train
XTy = X_train.T @ y_train / n_train
Id = torch.eye(n_features)
def solve_inner_closed_form(lmbd):
    """Solves the inner problem exactly using the closed-form solution."""
    A = XTX + lmbd / 2 * Id
    b = XTy
    w_star = torch.linalg.solve(A, b)
    return w_star
def value_function(lmbd):
    """Computes the outer loss given lambda by solving the inner problem."""
    with torch.no_grad():
        w_star = solve_inner_closed_form(lmbd)
    return val_loss(w_star)
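# %% [markdown]
# As a quick sanity check (an added sketch, not part of the original gist), we
# can verify that the closed-form solution indeed zeroes the gradient of the
# training loss. The names `lmbd_check` and `w_check` are introduced here for
# illustration only.
# %%
# Sanity check: the training-loss gradient should vanish at w*(lambda)
lmbd_check = torch.tensor(1.0)
w_check = solve_inner_closed_form(lmbd_check).clone().requires_grad_(True)
grad_at_opt = torch.autograd.grad(train_loss(w_check, lmbd_check), w_check)[0]
print(f"||grad of train_loss at w*(1.0)||: {grad_at_opt.norm().item():.2e}")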
# %% [markdown]
# ## Visualizing the Hyperparameter Tuning Task
#
# Before diving into the optimization methods, let's understand
# the problem visually. Since we have 25 features, we cannot plot the
# regression function directly. However, we can project the data onto the
# direction of the true weights to get an intuition of how the model fits.
#
# Our goal is to find a $\lambda$ that makes the Ridge Regression model
# generalize well to unseen (validation) data.
#
# We will show how different $\lambda$ values affect the model's fit:
# * **Small $\lambda$ (e.g., 0.1):** Less regularization, potentially leading
# to overfitting to the training data.
# * **Large $\lambda$ (e.g., 10):** More regularization, potentially leading
# to underfitting.
# * **Optimal $\lambda$:** A balance that performs well on both training
# and validation data.
# %%
# ## Visualizing the Hyperparameter Tuning Task: Implementation
# Define a range of projection values along the true weight direction
x_plot = torch.linspace(
    (X_train @ w_true).min(), (X_train @ w_true).max(), 4
).reshape(-1, 1)
# Build inputs lying along the w_true direction whose projection onto w_true
# equals x_plot (hence the division by the squared norm)
X_plot = x_plot @ w_true.T / torch.linalg.vector_norm(w_true) ** 2
plt.figure(figsize=(12, 7))
# Scatter training and validation data (projected onto the w_true direction)
plt.scatter(X_train @ w_true, y_train, label="Training Data", alpha=0.6)
plt.scatter(
    X_val @ w_true, y_val, label="Validation Data", alpha=0.6, marker='x'
)
# Plot fits for different lambda values
for lambda_val in [0.1, 1, 10]:
    # Solve inner problem for this lambda
    w_fit = solve_inner_closed_form(torch.tensor(lambda_val)).detach()
    # Predict over the plotting range
    y_pred_plot = X_plot @ w_fit
    plt.plot(x_plot, y_pred_plot, label=rf"Ridge Fit ($\lambda={lambda_val}$)")
y_pred_plot = X_plot @ w_true
plt.plot(x_plot, y_pred_plot, "k--", label="True Model")
plt.xlabel("Projection onto the true weight direction", fontsize=12)
plt.ylabel("Target Value", fontsize=12)
plt.title(r"Ridge Regression Fits with Different $\lambda$", fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, linestyle=':')
plt.show()
# %% [markdown]
# **Observation:**
# * With a small $\lambda=0.1$, the model tries to fit the training data
# almost perfectly, potentially showing high variance.
# * With a large $\lambda=10$, the model is heavily regularized,
# leading to a flatter line (high bias/underfitting).
# * A moderate $\lambda=1$ provides a better balance.
#
# Our goal in bilevel optimization is to *automatically find* this optimal
# $\lambda$ without relying on manual trial-and-error (like grid search).
# %% [markdown]
# ## 3. Method 0: Grid Search (Baseline)
#
# This is the "brute force" approach. We test a grid of $\lambda$ values,
# solve the inner problem for each, and check the validation loss.
# This is our "ground truth".
# %%
# ## 3. Method 0: Implementation
print("--- Starting Grid Search ---")
start_time_grid = time.time()
lambda_grid = torch.logspace(-4, 2, 100)  # 100 points from 1e-4 to 1e2
with torch.no_grad():
    grid_losses = np.array([
        value_function(lambda_val).item() for lambda_val in lambda_grid
    ])
idx_best_lambda = grid_losses.argmin()
best_lambda_grid = lambda_grid[idx_best_lambda].item()
best_loss_grid = grid_losses[idx_best_lambda]
end_time_grid = time.time()
runtime_grid = end_time_grid - start_time_grid
print(f"Best Lambda (Grid Search): {best_lambda_grid:.4f}")
print(f"Best Validation Loss: {best_loss_grid:.4f}")
print(f"Runtime: {runtime_grid:.4f}s")
print("--- Finished Grid Search ---")
# Plot the loss landscape
plt.figure(figsize=(10, 6))
plt.semilogx(lambda_grid, grid_losses)
plt.xlabel(r"Lambda ($\lambda$)")
plt.ylabel("Validation Loss")
plt.title("Validation Loss vs. Lambda (Grid Search)")
plt.axvline(
    best_lambda_grid, color='r', linestyle='--',
    label=rf"Best $\lambda$ ({best_lambda_grid:.4f})"
)
plt.legend()
plt.grid(which="both", linestyle=':')
plt.savefig("../images/ridge_regression_grid_search.pdf")
plt.show()
# %% [markdown]
# ## 4. Method 1: Gradient Descent on the Value-function
#
# We run gradient descent (Adam) directly on the value function, using the
# *exact* hypergradient obtained by differentiating with autograd through the
# closed-form solution `solve_inner_closed_form`.
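# %% [markdown]
# As a quick check (an added sketch, not part of the original gist), we can
# compare the autograd hypergradient obtained through `torch.linalg.solve`
# with a central finite-difference estimate of the value function. The names
# `lmbd0`, `grad_autograd` and `grad_fd` are introduced here for illustration.
# %%
# Autograd hypergradient at lambda = 1, through the closed-form solve
lmbd0 = torch.tensor([1.0], requires_grad=True)
grad_autograd = torch.autograd.grad(
    val_loss(solve_inner_closed_form(lmbd0)), lmbd0
)[0].item()
# Central finite-difference estimate of d(value_function)/d(lambda)
eps = 1e-3
grad_fd = (
    value_function(torch.tensor(1.0 + eps))
    - value_function(torch.tensor(1.0 - eps))
).item() / (2 * eps)
print(f"Autograd hypergradient at lambda=1: {grad_autograd:+.6f}")
print(f"Finite-difference estimate:         {grad_fd:+.6f}")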
# %%
# ## 4. Method 1: Implementation
print("\n--- Starting GD Optimization ---")
start_time_gd = time.time()
log_lambda_gd = torch.zeros(1, requires_grad=True)
optimizer_gd = torch.optim.Adam([log_lambda_gd], lr=0.1)
losses_gd = []
n_outer_steps = 100
for step in range(n_outer_steps):
    optimizer_gd.zero_grad()
    lmbd = torch.exp(log_lambda_gd)
    # 1. Solve inner problem (getting w*)
    w_star = solve_inner_closed_form(lmbd)
    # 2. Compute outer loss
    outer_loss = val_loss(w_star)
    # 3. Compute outer gradient (d_loss / d_lambda)
    outer_loss.backward()
    # 4. Update lambda
    optimizer_gd.step()
    losses_gd.append(value_function(lmbd).item())
final_lmbd_gd = torch.exp(log_lambda_gd).item()
end_time_gd = time.time()
runtime_gd = end_time_gd - start_time_gd
print(f"Final Lambda (GD): {final_lmbd_gd:.4f}")
print(f"Runtime: {runtime_gd:.4f}s")
print("--- Finished GD ---")
plt.figure(figsize=(10, 6))
plt.plot(losses_gd, label=rf"GD (Final $\lambda \approx$ {final_lmbd_gd:.4f})",
         linewidth=2.5, color='black')
plt.axhline(best_loss_grid, color='red', linestyle=':', linewidth=2,
            label=rf"Grid Search (at $\lambda \approx$ {best_lambda_grid:.4f})")
plt.xlabel("Outer Optimization Step", fontsize=12)
plt.ylabel("Validation Loss (Outer Objective)", fontsize=12)
plt.title("GD on the Value Function: Optimization Trajectory", fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, linestyle=':')
plt.show()
# %% [markdown]
# ## 5. Method 2: Implicit Function Theorem
#
# Instead of differentiating through the inner solver, we compute the
# hypergradient with the implicit function theorem (IFT): we first solve the
# inner problem approximately (K steps of gradient descent), then solve a
# linear system involving the Hessian of the training loss at that point.
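#
# The hypergradient used below follows from differentiating the inner
# optimality condition
# $\nabla_w \mathcal{L}_{\text{train}}(w^*(\lambda), \lambda) = 0$
# with respect to $\lambda$:
# $$ \frac{d w^*}{d \lambda} = - \big( \nabla^2_{ww} \mathcal{L}_{\text{train}} \big)^{-1}
# \nabla^2_{w \lambda} \mathcal{L}_{\text{train}}, \qquad
# \frac{d \mathcal{L}_{\text{val}}}{d \lambda} =
# - \big( \nabla^2_{w \lambda} \mathcal{L}_{\text{train}} \big)^\top
# \big( \nabla^2_{ww} \mathcal{L}_{\text{train}} \big)^{-1} \nabla_w \mathcal{L}_{\text{val}}. $$
# The linear system can be solved directly (the Hessian of the ridge training
# loss is available in closed form) or with conjugate gradient using only
# Hessian-vector products, as in the two helpers below.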
# %%
# ## 5. Method 2: Implementation
def solve_inner_gd(lmbd, k_inner_steps=50, inner_lr=0.1):
    """Solves the inner problem approximately by unrolling K steps of GD."""
    w = torch.zeros(n_features, 1, requires_grad=True)
    for k in range(k_inner_steps):
        inner_loss = train_loss(w, lmbd)
        # create_graph=True allows us to backprop through this step
        grad_w = torch.autograd.grad(inner_loss, w, create_graph=True)[0]
        # Manually update w to maintain the graph
        w = w - inner_lr * grad_w
    return w
def get_hypergradient_ift(lmbd, k_inner_steps=20):
    """Computes d(val_loss)/d(lambda) using IFT and a direct linear solve."""
    # 1. Run K steps of the inner solver and treat the result as a fixed
    #    point for differentiation (detached from the inner trajectory)
    w_k = solve_inner_gd(lmbd, k_inner_steps)
    w = w_k.detach().requires_grad_(True)
    # 2. Compute gradient of outer loss w.r.t. weights (rhs of linear system)
    l_val = val_loss(w)
    grad_val_w = torch.autograd.grad(l_val, w)[0]
    # 3. Solve linear system H * z = grad_val_w, where the Hessian of the
    #    training loss is H = (2/n) X^T X + lmbd * I = 2 * (XTX + lmbd/2 * Id)
    z = torch.linalg.solve(2 * (XTX + lmbd / 2 * Id), grad_val_w).detach()
    # 4. Compute final hypergradient
    #    Formula: - (d^2 L_train / d lambda d w) * z
    #    Equivalent to: - d(grad_train_w . z) / d lambda
    l_train = train_loss(w, lmbd)
    grad_train_w = torch.autograd.grad(l_train, w, create_graph=True)[0]
    prod = torch.sum(grad_train_w * z)
    grad_lambda = -torch.autograd.grad(prod, lmbd)[0]
    return grad_lambda
def get_hypergradient_ift_cg(lmbd, k_inner_steps=20, n_cg_steps=10):
    """Computes d(val_loss)/d(lambda) using IFT and SciPy conjugate gradient."""
    # 1. Run K steps of the inner solver and treat the result as a fixed
    #    point for differentiation
    w_k = solve_inner_gd(lmbd, k_inner_steps)
    w = w_k.detach().requires_grad_(True)
    # 2. Compute gradient of outer loss w.r.t. weights (rhs of linear system)
    l_val = val_loss(w)
    grad_val_w = torch.autograd.grad(l_val, w)[0]
    # 3. Define Hessian-vector product (HVP) for the inner loss
    #    We need H = d^2 L_train / dw^2, computed implicitly:
    #    H*v = d(grad_train_w . v)/dw
    l_train = train_loss(w, lmbd)
    grad_train_w = torch.autograd.grad(l_train, w, create_graph=True)[0]
    def hvp(v_tensor):
        prod = torch.sum(grad_train_w * v_tensor)
        return torch.autograd.grad(prod, w, retain_graph=True)[0]
    # 4. Solve linear system H * z = grad_val_w using SciPy CG
    #    Wrap the HVP in a LinearOperator; CG only needs H*v products
    num_params = w.numel()
    def matvec_numpy(v_np):
        v_torch = torch.from_numpy(v_np).float().reshape(w.shape)
        return hvp(v_torch).detach().numpy().flatten()
    from scipy.sparse.linalg import LinearOperator, cg
    A = LinearOperator((num_params, num_params), matvec=matvec_numpy)
    b = grad_val_w.detach().numpy().flatten()
    z_np, info = cg(A, b, maxiter=n_cg_steps)
    z = torch.from_numpy(z_np).float().reshape(w.shape)
    # 5. Compute final hypergradient
    #    Formula: - (d^2 L_train / d lambda d w) * z
    #    Equivalent to: - d(grad_train_w . z) / d lambda
    prod = torch.sum(grad_train_w * z)
    grad_lambda = -torch.autograd.grad(prod, lmbd)[0]
    return grad_lambda
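# (Added sketch, not part of the original gist) quick consistency check:
# the direct-solve and CG hypergradients should roughly agree for a fixed
# lambda. The name `lmbd_test` is introduced here for illustration only.
lmbd_test = torch.tensor([1.0], requires_grad=True)
print("Hypergradient (direct solve):",
      get_hypergradient_ift(lmbd_test, k_inner_steps=20).item())
print("Hypergradient (CG, HVP only):",
      get_hypergradient_ift_cg(lmbd_test, k_inner_steps=20).item())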
# We will test several values for K
K_values = [5, 10, 20]
results_ift = {}
runtimes_ift = {}
print("\n--- Starting IFT Optimization ---")
for K in K_values:
    print(f"\n Running with K={K}...")
    start_time_ift = time.time()
    log_lambda_ift = torch.zeros(1, requires_grad=True)
    optimizer_ift = torch.optim.Adam([log_lambda_ift], lr=0.1)
    losses_ift = []
    for step in range(n_outer_steps):
        optimizer_ift.zero_grad()
        lmbd = torch.exp(log_lambda_ift)
        # 1. Compute hypergradient using the explicit IFT helper
        grad_lambda = get_hypergradient_ift(lmbd, k_inner_steps=K)
        # 2. Apply chain rule for log_lambda:
        #    dL/d(log) = dL/d(lam) * d(lam)/d(log) and d(lam)/d(log) = lam
        log_lambda_ift.grad = grad_lambda * lmbd.detach()
        optimizer_ift.step()
        losses_ift.append(value_function(lmbd).item())
    end_time_ift = time.time()
    # Store results for this K
    runtime = end_time_ift - start_time_ift
    runtimes_ift[K] = runtime
    results_ift[K] = {
        'losses': losses_ift,
        'final': torch.exp(log_lambda_ift).item()
    }
    print(f" Final Lambda (K={K}): {results_ift[K]['final']:.4f}")
    print(f" Runtime (K={K}): {runtime:.4f}s")
print("--- Finished IFT ---")
plt.figure(figsize=(10, 6))
for i, K in enumerate(K_values):
    res = results_ift[K]
    plt.plot(res['losses'], f'C{i}', label=f"IFT K={K} "
             rf"(Final $\lambda \approx$ {res['final']:.4f})")
plt.axhline(best_loss_grid, color='red', linestyle=':', linewidth=2,
            label=rf"Grid Search ($\lambda \approx$ {best_lambda_grid:.4f})")
plt.xlabel("Outer Optimization Step", fontsize=12)
plt.ylabel("Validation Loss (Outer Objective)", fontsize=12)
plt.title("IFT Optimization Trajectory", fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, linestyle=':')
plt.show()
# %% [markdown]
# ## 6. Method 3: Unrolling (Varying K)
#
# We now test the unrolling method. We run the *entire* outer optimization
# process for different values of `K` (the number of inner steps),
# backpropagating through the `K` inner gradient-descent iterations.
#
# This shows how the *accuracy* of the final $\lambda$ and the *runtime* of
# the optimization depend on `K`. As a warm-up, the next cell compares the
# unrolled hypergradient with the exact one for a fixed $\lambda$.
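# %% [markdown]
# The following cell is an added sketch (not part of the original gist): for a
# fixed $\lambda$, it compares the hypergradient obtained by backpropagating
# through `solve_inner_gd` (truncated to `K` steps) with the exact
# hypergradient obtained through `solve_inner_closed_form`. The names
# `lmbd_test`, `g_exact` and `g_unrolled` are introduced here for illustration.
# %%
# Exact hypergradient: differentiate through the closed-form solve
lmbd_test = torch.tensor([1.0], requires_grad=True)
g_exact = torch.autograd.grad(
    val_loss(solve_inner_closed_form(lmbd_test)), lmbd_test
)[0].item()
print(f"Exact hypergradient at lambda=1: {g_exact:+.5f}")
# Unrolled hypergradients: the truncation bias should shrink as K grows
for K_test in [5, 10, 20]:
    lmbd_test = torch.tensor([1.0], requires_grad=True)
    g_unrolled = torch.autograd.grad(
        val_loss(solve_inner_gd(lmbd_test, k_inner_steps=K_test)), lmbd_test
    )[0].item()
    print(f"  Unrolled hypergradient (K={K_test:2d}): {g_unrolled:+.5f}")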
# %%
# ## 6. Method 3: Implementation
# We will test several values for K
results_unroll = {}
runtimes_unroll = {}
print(f"\n--- Starting Unrolling Optimization (Varying K={K_values}) ---")
for K in K_values:
    print(f"\n Running with K={K}...")
    start_time_unroll = time.time()
    log_lambda_unroll = torch.zeros(1, requires_grad=True)
    optimizer_unroll = torch.optim.Adam([log_lambda_unroll], lr=0.1)
    losses_unroll = []
    for step in range(n_outer_steps):
        optimizer_unroll.zero_grad()
        lmbd = torch.exp(log_lambda_unroll)
        # 1. Solve inner problem (getting w_K)
        w_k = solve_inner_gd(lmbd, k_inner_steps=K)
        # 2. Compute outer loss
        outer_loss = val_loss(w_k)
        # 3. Compute outer gradient (backprops through K steps)
        outer_loss.backward()
        # 4. Update lambda
        optimizer_unroll.step()
        losses_unroll.append(value_function(lmbd).item())
    end_time_unroll = time.time()
    # Store results for this K
    runtime = end_time_unroll - start_time_unroll
    runtimes_unroll[K] = runtime
    results_unroll[K] = {
        'losses': losses_unroll,
        'final': torch.exp(log_lambda_unroll).item()
    }
    print(f" Final Lambda (K={K}): {results_unroll[K]['final']:.4f}")
    print(f" Runtime (K={K}): {runtime:.4f}s")
print("\n--- Finished Unrolling ---")
plt.figure(figsize=(10, 6))
for i, K in enumerate(K_values):
    res = results_unroll[K]
    plt.plot(res['losses'], f'C{i}', label=f"Unrolling K={K} "
             rf"(Final $\lambda \approx$ {res['final']:.4f})")
plt.axhline(best_loss_grid, color='red', linestyle=':', linewidth=2,
            label=rf"Grid Search ($\lambda \approx$ {best_lambda_grid:.4f})")
plt.xlabel("Outer Optimization Step", fontsize=12)
plt.ylabel("Validation Loss (Outer Objective)", fontsize=12)
plt.title("Unrolling Optimization Trajectory", fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, linestyle=':')
plt.show()
# %% [markdown]
# ## 7. Comparison and Takeaways
# %%
# ## 7.1. Plot: Optimization Trajectories
plt.figure(figsize=(14, 8))
# 1. Plot the "ground truth" from grid search
plt.axhline(best_loss_grid, color='red', linestyle=':', linewidth=2,
            label=rf"Grid Search Best Loss (at $\lambda \approx$ {best_lambda_grid:.4f})")
# 2. Plot the GD (exact gradient) optimization
plt.plot(losses_gd, label=rf"GD (Final $\lambda \approx$ {final_lmbd_gd:.4f})",
         linewidth=2.5, color='black')
# 3. Plot the unrolling and IFT results for K=10
K = 10
res = results_unroll[K]
plt.plot(res['losses'], label=f"Unrolling K={K} "
         rf"(Final $\lambda \approx$ {res['final']:.4f})")
res = results_ift[K]
plt.plot(res['losses'], label=f"IFT K={K} "
         rf"(Final $\lambda \approx$ {res['final']:.4f})")
plt.xlabel("Outer Optimization Step", fontsize=12)
plt.ylabel("Validation Loss (Outer Objective)", fontsize=12)
plt.title("Bilevel Optimization Method Comparison", fontsize=16)
plt.legend(fontsize=12, loc='upper right')
plt.grid(True, linestyle=':')
plt.show()
# %%
# ## 7.2. Plot: Runtime vs. K for Unrolling and IFT
K_list = sorted(runtimes_unroll.keys())
K_runtimes_unroll = [runtimes_unroll[k] for k in K_list]
K_runtimes_ift = [runtimes_ift[k] for k in K_list]
final_lambdas_unroll = [results_unroll[k]['final'] for k in K_list]
final_losses_unroll = [results_unroll[k]['losses'][-1] for k in K_list]
final_lambdas_ift = [results_ift[k]['final'] for k in K_list]
final_losses_ift = [results_ift[k]['losses'][-1] for k in K_list]
fig, ax1 = plt.subplots(figsize=(10, 6))
# Plot runtime
color = 'tab:blue'
ax1.set_xlabel("K (Inner Steps)", fontsize=12)
ax1.set_ylabel("Total Runtime (seconds)", color=color, fontsize=12)
ax1.plot(K_list, K_runtimes_unroll, 'o-', color=color, label="Runtime (unrolled)")
ax1.plot(K_list, K_runtimes_ift, 'o-', color='tab:cyan', label="Runtime (IFT)")
ax1.tick_params(axis='y', labelcolor=color)
ax1.set_xticks(K_list)
# Plot accuracy (final loss) on a second y-axis
ax2 = ax1.twinx()
color = 'tab:green'
ax2.set_ylabel("Final Validation Loss", color=color, fontsize=12)
ax2.plot(
    K_list, final_losses_unroll, 's--', color=color,
    label="Final Validation Loss (unrolled)"
)
ax2.plot(
    K_list, final_losses_ift, 's--', color='tab:olive',
    label="Final Validation Loss (IFT)"
)
# Add a line for the grid-search best loss
ax2.axhline(best_loss_grid, color='black', linestyle=':', label="Grid Search Best Loss")
ax2.tick_params(axis='y', labelcolor=color)
fig.suptitle("Runtime vs. Accuracy Tradeoff (Unrolling and IFT)", fontsize=16)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
# Add a single legend
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='center right')
plt.show()
# %%
# ## 7.3. Summary Table and Key Takeaways
print("\n--- 🏁 Final Results Summary ---")
print("\nMethod 0: Grid Search")
print(f" Best Lambda: {best_lambda_grid:.4f}")
print(f" Best Loss: {best_loss_grid:.4f}")
print(f" Runtime: {runtime_grid:.4f}s")
print("\nMethod 1: GD (Exact Gradient)")
print(f" Final Lambda: {final_lmbd_gd:.4f}")
print(f" Final Loss: {losses_gd[-1]:.4f}")
print(f" Runtime: {runtime_gd:.4f}s")
print("\nMethod 2: IFT")
print("  K | Final Lambda | Final Loss | Runtime (s)")
print("---------------------------------------------")
for K in K_values:
    res = results_ift[K]
    print(f" {K:2d} | {res['final']:12.4f} | {res['losses'][-1]:10.4f} | "
          f"{runtimes_ift[K]:.4f}")
print("\nMethod 3: Unrolling")
print("  K | Final Lambda | Final Loss | Runtime (s)")
print("---------------------------------------------")
for K in K_values:
    res = results_unroll[K]
    print(f" {K:2d} | {res['final']:12.4f} | {res['losses'][-1]:10.4f} | "
          f"{runtimes_unroll[K]:.4f}")
# %% [markdown]
# ### Key Takeaways
#
# 1. **Grid Search:**
#    * **Pro:** Simple, guaranteed to find the optimum *on the grid*. Provides a reliable "ground truth" (Best $\lambda \approx 0.16$, Best Loss $\approx 0.02$).
#    * **Con:** Extremely inefficient. Its cost grows exponentially with the number of hyperparameters. Not feasible for more than 1 or 2.
#
# 2. **IFT (Implicit Function Theorem):**
#    * **Pro:** **Fastest and most accurate.** Converges to the correct $\lambda$ and minimum validation loss very quickly. It is also memory-efficient, since it only needs the final inner iterate and a linear solve, not the whole inner trajectory.
#    * **Con:** It assumes the inner problem is solved (almost) exactly, and it requires solving a linear system with the Hessian of the inner loss, which for large models is only tractable through Hessian-vector products (e.g., with conjugate gradient).
#
# 3. **Unrolling (Iterative Differentiation):**
#    * **This is the most important tradeoff.**
#    * **Accuracy (Bias):** Look at the plots.
#      * With the smallest `K` tested (`K=5`), the truncated gradient is the most biased, and the final $\lambda$ and validation loss are the furthest from the grid-search optimum.
#      * As `K` increases (`K=10`, `K=20`), the final $\lambda$ gets closer to the true value, and the final validation loss decreases, approaching the IFT/Grid Search optimum.
#    * **Runtime (Cost):**
#      * The runtime plot clearly shows the cost **scales linearly with K**. This is because each outer loop step involves backpropagating through `K` inner optimization steps.
#      * This is the core tradeoff: **Higher `K` provides a more accurate (less biased) gradient, but costs more in both time and memory** (since PyTorch must store the computation graph for all `K` steps).
#
# **Final Conclusion:** Bilevel optimization is a tradeoff. Exact hypergradients (through a closed-form solution or a fully solved inner problem) are ideal but rarely available in complex scenarios. Unrolling is the general-purpose tool, but we must carefully choose a `K` large enough to obtain a good gradient approximation, without incurring excessive computational cost. This balancing act is a key challenge in applications like meta-learning.
# %%