Skip to content

Instantly share code, notes, and snippets.

@stephenlb
Last active November 15, 2025 00:26
Show Gist options
  • Select an option

  • Save stephenlb/3f779be0aa60b7f5ffa7c16ed9294653 to your computer and use it in GitHub Desktop.

Select an option

Save stephenlb/3f779be0aa60b7f5ffa7c16ed9294653 to your computer and use it in GitHub Desktop.
Draw the PyTorch Image
# -*- coding: utf-8 -*-
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
import torchvision.io as io
import torchvision.transforms as T
from torch.utils.data import TensorDataset, DataLoader
print("default device:")
print(torch.get_default_device())
#torch.set_default_device('mps')
#print("default device:")
#print(torch.get_default_device())
#device = torch.device("cpu")
#if torch.backends.mps.is_available():
# device = torch.device("mps")
# print("Using MPS device")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)
def load_image_as_tensor(image_path):
    """Load an image and turn every pixel into a (position -> brightness) sample.

    Pixel (h, w) becomes one sample whose input is its normalized position
    ``[h / height, w / width]`` and whose target is its luma
    ``[int(0.299*R + 0.587*G + 0.114*B) / 255.0]`` (BT.601 weights). For
    single-channel (or gray + alpha) images the first channel is used as the
    luma directly. Samples are ordered row by row (h outer, w inner), so a
    prediction for all samples reshapes straight back to ``(height, width)``.

    Despite the name, ``x`` and ``y`` are returned as nested Python lists; the
    caller converts them to tensors.

    Args:
        image_path: Path of an image file that ``torchvision.io.decode_image``
            can decode.

    Returns:
        ``(x, y, width, height)``: ``x`` is a list of ``[row, col]`` floats,
        ``y`` a list of one-element ``[luma]`` lists, and ``width``/``height``
        are the image size in pixels.
    """
    # `io` is torchvision.io (imported at file top), not the stdlib io module.
    # read_file yields the raw file bytes as a uint8 tensor and manages the handle.
    image_tensor = io.decode_image(io.read_file(image_path))
    channels, height, width = image_tensor.shape

    # float64 keeps the arithmetic identical to plain Python float maths.
    pixels = image_tensor.to(torch.float64)
    if channels >= 3:
        # RGB(A): weighted sum of the colour planes; any alpha plane is ignored.
        luma = 0.299 * pixels[0] + 0.587 * pixels[1] + 0.114 * pixels[2]
    else:
        # Grayscale, optionally with alpha: the first plane already is the luma.
        luma = pixels[0]
    # floor() equals int() truncation here because pixel values are non-negative.
    y = (luma.floor() / 255.0).reshape(-1, 1)

    # Normalized coordinates on the CPU (float64 is unsupported on MPS), laid
    # out row-major so sample i in x matches sample i in y.
    rows = torch.arange(height, dtype=torch.float64, device="cpu") / height
    cols = torch.arange(width, dtype=torch.float64, device="cpu") / width
    grid_h, grid_w = torch.meshgrid(rows, cols, indexing="ij")
    x = torch.stack((grid_h, grid_w), dim=-1).reshape(-1, 2)

    return x.tolist(), y.tolist(), width, height
# Training set: one (row, col) -> brightness sample per pixel of pytorch.png.
x, y, width, height = load_image_as_tensor('pytorch.png')
x = torch.as_tensor(x, dtype=torch.float32)
y = torch.as_tensor(y, dtype=torch.float32)


def _build_mlp(widths):
    """Return Linear+ReLU stages between consecutive ``widths``, then a Linear(->1) + Tanh head."""
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [torch.nn.Linear(fan_in, fan_out), torch.nn.ReLU()]
    layers += [torch.nn.Linear(widths[-1], 1), torch.nn.Tanh()]
    return torch.nn.Sequential(*layers)


# Coordinate MLP: 2 inputs (row, col) -> 1 brightness output.
# NOTE(review): Tanh spans [-1, 1] while the brightness targets lie in [0, 1];
# Sigmoid would match the target range. Left unchanged here.
model = _build_mlp((2, 128, 256, 256, 256)).to(device)

# Shuffled mini-batches of 512 pixels each.
dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=512, shuffle=True)

# AdamW; SGD, RMSprop and Adagrad at the same learning rate were the
# alternatives considered.
learning_rate = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Live preview window; interactive mode lets plt.pause() redraw during training.
PREVIEW_TITLE = 'Draw PyTorch Image'
plt.figure(figsize=(10, 5))
plt.title(PREVIEW_TITLE)
plt.ion()

# Mean-squared error on brightness (BCEWithLogitsLoss and CrossEntropyLoss
# were the alternatives considered).
loss_fn = torch.nn.MSELoss()
losses = []  # per-batch training loss history
batches = 0  # batches processed across all epochs

# Full coordinate grid, copied to the device once and reused by every preview.
x_device = x.to(device)


def _show_prediction():
    """Redraw the figure with the model's current output for every pixel."""
    # Preview only: don't build an autograd graph over the whole image.
    with torch.no_grad():
        out = model(x_device).reshape(height, width).cpu().numpy()
    plt.clf()
    # clf() also clears the title, so restore it on every frame.
    plt.title(PREVIEW_TITLE)
    # Pin the colour scale to the 0..1 target range so frames are comparable
    # instead of each being auto-scaled to its own min/max.
    plt.imshow(out, cmap="gray", vmin=0.0, vmax=1.0)
    plt.pause(0.10)


model.train()
for epoch in range(200):
    for batch_x, batch_y in data_loader:
        optimizer.zero_grad()
        y_pred = model(batch_x.to(device))
        loss = loss_fn(y_pred, batch_y.to(device))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        batches += 1
        if batches % 100 == 0:
            # Mean over the last (up to) 500 batches; divide by the window's
            # actual size so reports before batch 500 are not under-stated.
            recent = losses[-500:]
            sum_loss = sum(recent) / len(recent)
            print(epoch, sum_loss)
            _show_prediction()

# Render the fully trained model, then keep the window up for 100 seconds.
_show_prediction()
plt.pause(100.)
@stephenlb
Copy link
Author

pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment