Skip to content

Instantly share code, notes, and snippets.

@mdmitry1
Last active January 17, 2026 19:12
Show Gist options
  • Select an option

  • Save mdmitry1/21a76345acf2ea8b823e34e2a9f684eb to your computer and use it in GitHub Desktop.

Select an option

Save mdmitry1/21a76345acf2ea8b823e34e2a9f684eb to your computer and use it in GitHub Desktop.
#!/usr/bin/python3.14
# -*- coding: utf-8 -*-
import torch
import math
from hashlib import sha256
from matplotlib import pyplot as plt
def main(
    timeout: float = 5000,
    size: int = 100000,
    *,
    iterations: int = 60000,
    learning_rate: float = 1e-6,
    device: str | torch.device = "cpu",
) -> str:
    """Fit a cubic polynomial to sin(x) on [0, pi/2] by gradient descent and plot it.

    The model is ``y = a + b*x + c*x**2 + d*x**3``. The four coefficients are
    initialised from ``torch.randn`` after ``torch.manual_seed(42)``. They are
    updated with hand-written gradients of the summed squared error; autograd
    is not used. The RMSE is printed every 1000 iterations. Afterwards the
    fitted curve is plotted against sin(x), and the maximum absolute error is
    printed once ``plt.show()`` returns.

    Args:
        timeout: Milliseconds after which the plot window is closed
            automatically. Pass ``math.inf`` to keep it open until it is
            closed by hand.
        size: Number of evenly spaced sample points in [0, pi/2]; must be >= 1.
        iterations: Number of gradient-descent steps.
        learning_rate: Step size. Gradients are *sums* over all ``size``
            samples (not means), so the effective step grows with ``size``.
            A much larger ``size`` may therefore need a smaller
            ``learning_rate``.
        device: Torch device to run on, e.g. ``"cpu"`` or ``"cuda:0"``.

    Returns:
        The SHA-256 hex digest of the printed ``Result: ...`` line. Callers can
        compare it to check that the fitted coefficients are reproducible.

    Raises:
        ValueError: If ``size`` is less than 1.
    """
    if size < 1:
        # An empty sample grid would otherwise fail later with a
        # ZeroDivisionError in the RMSE report.
        raise ValueError(f"size must be >= 1, got {size}")
    dtype = torch.float
    device = torch.device(device)
    # Evenly spaced (deterministic) inputs and their exact sin(x) targets.
    x = torch.linspace(0, math.pi / 2, size, device=device, dtype=dtype)
    y = torch.sin(x)
    # The powers of x never change during training. Compute them once instead
    # of twice per iteration; the values are identical to recomputing them.
    x2 = x ** 2
    x3 = x ** 3
    # Randomly initialize weights (seeded so the result is reproducible).
    torch.manual_seed(42)
    a = torch.randn((), device=device, dtype=dtype)
    b = torch.randn((), device=device, dtype=dtype)
    c = torch.randn((), device=device, dtype=dtype)
    d = torch.randn((), device=device, dtype=dtype)
    print("Iteration RMSE")
    for t in range(iterations):
        # Forward pass: compute predicted y and its deviation from the target.
        y_pred = a + b * x + c * x2 + d * x3
        residual = y_pred - y
        if t % 1000 == 999:
            # The loss is only needed for reporting. Computing it here saves a
            # full reduction and an .item() call on every other iteration.
            loss = residual.pow(2).sum().item()
            print(t + 1, math.sqrt(loss / size))
        # Backprop: gradients of sum(residual**2) with respect to a, b, c, d.
        grad_y_pred = 2.0 * residual
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x2).sum()
        grad_d = (grad_y_pred * x3).sum()
        # Update weights using gradient descent.
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
        c -= learning_rate * grad_c
        d -= learning_rate * grad_d
    result = f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3'
    print(result)
    # Evaluate the fitted polynomial once and reuse it for the plot and error.
    y_fit = a + b * x + c * x2 + d * x3
    err = y_fit - y
    # matplotlib needs host tensors; .cpu() is a no-op when already on the CPU.
    plt.plot(x.cpu(), y.cpu(), x.cpu(), y_fit.cpu())
    plt.grid()
    plt.gcf().canvas.manager.set_window_title('SIN(X) PyTorch approximation')
    plt.title("Plot Example")
    plt.legend(['sin','sin predicted'])
    plt.xlabel('x')
    plt.ylabel('sin(x) and pytorch prediction')
    if math.isfinite(float(timeout)):
        # Close the figure after `timeout` ms so plt.show() returns by itself.
        timer = plt.gcf().canvas.new_timer(interval=timeout, callbacks=[(plt.close, [], {})])
        timer.start()
    plt.show()
    print(f"Maximum error = {err.abs().max().item():.6f}")
    return sha256(result.encode()).hexdigest()
if __name__ == "__main__":
    # Run interactively: an infinite timeout means no auto-close timer, so the
    # plot window stays open until the user closes it.
    digest = main(math.inf)
    print(digest)
import sys
from pytorch_ex import main
def test_pytorch_ex(monkeypatch):
    """Run the full fit with default arguments and compare its result digest.

    ``sys.argv`` is temporarily replaced with a bare program name while
    ``main()`` runs. ``main()`` uses its default 5000 ms timeout, so any plot
    window closes by itself.
    """
    # NOTE(review): this digest pins the exact float32 coefficients printed by
    # main(); it may differ across PyTorch versions or hardware — confirm.
    expected_digest = '92138878ecd663f9ff1e8313be1ffd1fb31d87e57ec64a0883672fe30fa5e2e3'
    with monkeypatch.context() as patch:
        patch.setattr(sys, 'argv', ['pytorch_ex'])
        # Start main()'s progress output on a fresh line under pytest -s.
        print("")
        digest = main()
    assert digest == expected_digest
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment