Created September 14, 2024 17:30
Save trilusa/3cfa105b94d8aebd0891736588aa1b12 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import matplotlib.pyplot as plt | |
def _draw_matrix(ax, M, left, fmt):
    """Draw 2-D array ``M`` as unit-sized heatmap cells starting at x = ``left`` and label each cell.

    Args:
        ax: Matplotlib Axes to draw on.
        M: 2-D NumPy array to display.
        left: x coordinate (in cell units) of the matrix's left edge.
        fmt: Format spec applied to every cell value, e.g. ``'.0f'``.
    """
    rows, cols = M.shape
    ax.imshow(M, aspect='equal', extent=[left, left + cols, 0, rows], cmap='Blues')
    for r in range(rows):
        for c in range(cols):
            # With imshow's default origin='upper', row r of M occupies
            # y in [rows - r - 1, rows - r], so its centre is rows - r - 0.5.
            ax.text(left + c + 0.5, rows - r - 0.5, f'{M[r, c]:{fmt}}',
                    ha='center', va='center', color='grey')


def _default_cell_format(*mats):
    """Return ``'.0f'`` if every entry of every array is a whole number, else ``'.2f'``."""
    all_whole = all(bool((m == m.round()).all()) for m in mats)
    return '.0f' if all_whole else '.2f'


def matmul_viz(A, x, *, fmt=None, cell_size=0.6):
    """Show ``A``, ``x`` and ``A @ x`` side by side as labelled heatmaps.

    The operands are drawn left to right as ``A × x = A @ x`` on a dark
    background. Every cell is annotated with its value, and the figure is
    displayed with ``plt.show()``. Returns ``None``.

    Args:
        A: 2-D matrix (torch tensor or array-like) of shape (m, k).
        x: Matrix of shape (k, n), or a 1-D vector of length k. A vector is
            drawn as a single column.
        fmt: Format spec for the cell labels. ``None`` (the default) picks
            ``'.0f'`` when every displayed value is a whole number and
            ``'.2f'`` otherwise, so fractional data is not rounded away.
        cell_size: Edge length of one matrix cell in inches. The figure size
            is derived from it so the figure fits the drawn content.

    Raises:
        RuntimeError: Propagated from torch when the inner dimensions of
            ``A`` and ``x`` do not match.
    """
    # Work on detached CPU float64 copies. imshow cannot consume tensors that
    # require grad or live on an accelerator, and a shared dtype lets e.g. an
    # integer one-hot matrix multiply a float vector.
    A = torch.as_tensor(A).detach().cpu().double()
    x = torch.as_tensor(x).detach().cpu().double()
    if x.ndim == 1:
        # Draw a 1-D vector as a single column so it lines up with A's columns.
        x = x.unsqueeze(1)
    Ax = A @ x
    A, x, Ax = A.numpy(), x.numpy(), Ax.numpy()

    rows_A, cols_A = A.shape
    rows_x, cols_x = x.shape
    rows_Ax, cols_Ax = Ax.shape

    if fmt is None:
        fmt = _default_cell_format(A, x, Ax)

    # Horizontal layout in cell units: A | gap | x | gap | Ax | right margin.
    x_left = cols_A + 1
    Ax_left = cols_A + cols_x + 2
    total_width = cols_A + cols_x + cols_Ax + 3
    total_height = max(rows_A, rows_x, rows_Ax)

    # Size the figure from the content so cells stay square without large
    # empty margins around the drawing.
    fig, ax = plt.subplots(figsize=(cell_size * total_width, cell_size * total_height))

    _draw_matrix(ax, A, 0, fmt)
    _draw_matrix(ax, x, x_left, fmt)
    _draw_matrix(ax, Ax, Ax_left, fmt)

    ax.set_xlim(0, total_width)
    ax.set_ylim(0, total_height)

    # Operator glyphs sit in the middle of the one-cell gaps between operands.
    ax.text(x_left - 0.5, rows_A / 2, r'$\times$', ha='center', va='center', fontsize=10, color='white')
    ax.text(Ax_left - 0.5, rows_A / 2, r'$=$', ha='center', va='center', fontsize=10, color='white')

    # Dark grey background (the VSCode dark-theme editor colour).
    ax.set_facecolor('#2D2D2D')
    fig.patch.set_facecolor('#2D2D2D')
    # Axes.grid documents `visible` as a bool; pass False rather than the string
    # 'off'. axis('off') then hides ticks, tick labels and spines.
    ax.grid(False)
    ax.axis('off')

    plt.tight_layout()
    plt.show()
# Demo: a 6x6 permutation matrix (one-hot rows) times a random 6x1 column vector.
# Row i of A has its 1 in column p[i], so (A @ x)[i] == x[p[i]]: the product is x
# with entries 1<->2 and 3<->4 swapped.
import torch.nn.functional as F  # provides F.one_hot used below

A = F.one_hot(torch.tensor([0, 2, 1, 4, 3, 5]), num_classes=6).float()  # 6x6 permutation matrix
x = torch.randn(6, 1)  # 6x1 column vector
matmul_viz(A, x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.