Skip to content

Instantly share code, notes, and snippets.

@trilusa
Created September 14, 2024 17:30
Show Gist options
  • Select an option

  • Save trilusa/3cfa105b94d8aebd0891736588aa1b12 to your computer and use it in GitHub Desktop.

Select an option

Save trilusa/3cfa105b94d8aebd0891736588aa1b12 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
def matmul_viz(A, x, fmt=None, cmap='Blues'):
    """Plot ``A``, ``x`` and ``A @ x`` side by side as labelled heatmaps.

    The operands are laid out left to right as ``A × x = Ax`` on a dark
    (VSCode-style) grey background, and every cell is annotated with its
    value. The figure is displayed with ``plt.show()``.

    Args:
        A: 2-D tensor of shape ``(m, k)``.
        x: tensor of shape ``(k, n)``, or a 1-D tensor of shape ``(k,)``,
            which is drawn as a ``(k, 1)`` column.
        fmt: format spec for the cell labels, e.g. ``'.3f'``. ``None``
            (default) uses ``'.0f'`` when every displayed value is a whole
            number and ``'.2f'`` otherwise, so fractional entries are not
            rounded to integers.
        cmap: matplotlib colormap name used for all three heatmaps.

    Returns:
        The ``(fig, ax)`` pair that was drawn.
    """
    # Promote a vector to a column so every operand unpacks as (rows, cols).
    if x.dim() == 1:
        x = x.unsqueeze(1)
    Ax = A @ x
    # matplotlib cannot convert tensors that require grad or live off-CPU.
    # Double precision keeps integer-valued entries exact for the fmt check.
    A, x, Ax = (t.detach().cpu().double() for t in (A, x, Ax))

    rows_A, cols_A = A.shape
    rows_x, cols_x = x.shape
    rows_Ax, cols_Ax = Ax.shape

    if fmt is None:
        whole = all(torch.equal(t, t.round()) for t in (A, x, Ax))
        fmt = '.0f' if whole else '.2f'

    # Horizontal layout: one empty unit between operands and one at the right.
    #   A: [0, cols_A]   x: [cols_A + 1, ...]   Ax: [cols_A + cols_x + 2, ...]
    x_left = cols_A + 1
    Ax_left = cols_A + cols_x + 2
    width = Ax_left + cols_Ax + 1
    height = max(rows_A, rows_x, rows_Ax)

    # Size the figure to its content (roughly 0.6 inch per cell).
    cell_inches = 0.6
    fig, ax = plt.subplots(
        figsize=(cell_inches * width, cell_inches * max(height, 1)))

    _draw_matrix(ax, A, 0, fmt, cmap)
    _draw_matrix(ax, x, x_left, fmt, cmap)
    _draw_matrix(ax, Ax, Ax_left, fmt, cmap)

    ax.set_xlim(0, width)
    ax.set_ylim(0, height)

    # Operator symbols sit in the gaps, vertically centred on A (and Ax).
    ax.text(cols_A + 0.5, rows_A / 2, r'$\times$',
            ha='center', va='center', fontsize=10, color='white')
    ax.text(cols_A + cols_x + 1.5, rows_A / 2, r'$=$',
            ha='center', va='center', fontsize=10, color='white')

    # Background colour: VSCode dark-mode grey.
    ax.set_facecolor('#2D2D2D')
    fig.patch.set_facecolor('#2D2D2D')
    # Hide ticks, spines and gridlines; only the heatmaps and labels remain.
    ax.axis('off')

    fig.tight_layout()
    plt.show()
    return fig, ax


def _draw_matrix(ax, M, left, fmt, cmap):
    """Draw 2-D tensor ``M`` as a heatmap over x in [left, left+cols], y in [0, rows], labelling each cell."""
    rows, cols = M.shape
    # origin='upper' places row 0 at the top of the extent; the labels below
    # rely on that, so it is set explicitly rather than taken from rcParams.
    ax.imshow(M.numpy(), aspect='equal', origin='upper',
              extent=[left, left + cols, 0, rows], cmap=cmap)
    for r in range(rows):
        for c in range(cols):
            # The cell centred at height r + 0.5 displays row rows - 1 - r.
            ax.text(left + c + 0.5, r + 0.5,
                    f'{M[rows - 1 - r, c].item():{fmt}}',
                    ha='center', va='center', color='grey')
# Demo, run only when executed as a script (not on import).
if __name__ == "__main__":
    # One-hot rows of a permutation form a 6x6 permutation matrix, so A @ x
    # reorders the entries of the random 6x1 column vector x.
    A = F.one_hot(torch.tensor([0, 2, 1, 4, 3, 5]), num_classes=6).float()  # 6x6 matrix
    x = torch.randn(6, 1)  # 6x1 column vector
    matmul_viz(A, x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment