Last active
November 15, 2025 00:26
-
-
Save stephenlb/3f779be0aa60b7f5ffa7c16ed9294653 to your computer and use it in GitHub Desktop.
Draw the Pytorch Image
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| import torch | |
| import math | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torchvision.io as io | |
| import torchvision.transforms as T | |
| from torch.utils.data import TensorDataset, DataLoader | |
| print("default device:") | |
| print(torch.get_default_device()) | |
| #torch.set_default_device('mps') | |
| #print("default device:") | |
| #print(torch.get_default_device()) | |
| #device = torch.device("cpu") | |
| #if torch.backends.mps.is_available(): | |
| # device = torch.device("mps") | |
| # print("Using MPS device") | |
| device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") | |
| print("Using device:", device) | |
def load_image_as_tensor(image_path):
    """Decode an image into per-pixel samples for a coordinate network.

    Every pixel becomes one sample: the input is its normalised
    ``[row, column]`` position and the target is its grayscale intensity.

    Args:
        image_path: Path of an image file that
            ``torchvision.io.decode_image`` can decode.

    Returns:
        A tuple ``(x, y, width, height)``. ``x`` is a list of
        ``[h / height, w / width]`` pairs in row-major order, ``y`` is a
        list of one-element lists holding the matching pixel's intensity
        scaled to ``[0, 1]``, and ``width`` / ``height`` are the image
        dimensions in pixels.
    """
    # read_file manages the file handle itself and returns the uint8 byte
    # tensor decode_image expects, avoiding a bytes -> list -> tensor copy.
    encoded = io.read_file(image_path)

    # Force three colour channels so grayscale and RGBA files also work.
    image_tensor = io.decode_image(encoded, mode=io.ImageReadMode.RGB)
    _, height, width = image_tensor.shape

    # float64 reproduces the Python-float arithmetic of a per-pixel loop.
    red, green, blue = image_tensor.to(torch.float64)
    # Weighted RGB sum (the common 0.299 / 0.587 / 0.114 luma weights),
    # truncated to a whole 0-255 level before scaling into [0, 1].
    gray = (0.299 * red + 0.587 * green + 0.114 * blue).trunc()
    y = (gray / 255.0).reshape(-1, 1).tolist()

    # Normalised coordinates; "ij" indexing keeps rows as the outer (slow)
    # axis so x lines up with the row-major flattening used for y.
    rows = torch.arange(height, dtype=torch.float64, device="cpu") / height
    cols = torch.arange(width, dtype=torch.float64, device="cpu") / width
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")
    x = torch.stack((row_grid, col_grid), dim=-1).reshape(-1, 2).tolist()

    return x, y, width, height
# Training data: normalised pixel coordinates (x) and grayscale targets (y).
x, y, width, height = load_image_as_tensor('pytorch.png')
x, y = (torch.tensor(samples, dtype=torch.float32) for samples in (x, y))

# Fully connected network mapping a 2-D coordinate to a single intensity:
# 2 -> 128 -> 256 -> 256 -> 256 -> 1, ReLU between layers, Tanh on the output.
hidden_widths = (128, 256, 256, 256)
layers = []
fan_in = 2
for fan_out in hidden_widths:
    layers.append(torch.nn.Linear(fan_in, fan_out))
    layers.append(torch.nn.ReLU())
    fan_in = fan_out
layers.append(torch.nn.Linear(fan_in, 1))
layers.append(torch.nn.Tanh())
model = torch.nn.Sequential(*layers).to(device)
# Serve the (coordinate, intensity) pairs in shuffled mini-batches.
dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset=dataset, batch_size=512, shuffle=True)

# Mean-squared error between predicted and target intensity.
loss_fn = torch.nn.MSELoss()

# AdamW is the active optimizer.
learning_rate = 1e-3
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

## Prepare the canvas
plt.figure(figsize=(10, 5))
plt.title('Draw Pytroch Image')
plt.ion()

# Bookkeeping: every batch loss so far and the number of batches processed.
losses, batches = [], 0
model.train()
# Fit the network to the image and redraw its current reconstruction every
# 100 batches.

# Upload the full coordinate grid once; every preview frame reuses it.
pixel_coords = x.to(device)

for epoch in range(200):
    ## Train
    for batch_x, batch_y in data_loader:
        optimizer.zero_grad()
        y_pred = model(batch_x.to(device))
        loss = loss_fn(y_pred, batch_y.to(device))
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        batches += 1

        if batches % 100 == 0:
            # Average over the losses actually recorded: early on there are
            # fewer than 500, so a fixed divisor would under-report the mean.
            recent = losses[-500:]
            sum_loss = sum(recent) / len(recent)
            print(epoch, sum_loss)

            # The preview is inference only; skip building an autograd graph
            # for the whole image.
            with torch.no_grad():
                out = model(pixel_coords).reshape(height, width).cpu().numpy()
            plt.clf()
            plt.imshow(out, cmap="gray")
            plt.pause(0.10)

# Keep the final frame on screen for a while before the script exits.
plt.pause(100.)
Author
stephenlb
commented
Nov 15, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment