Skip to content

Instantly share code, notes, and snippets.

@daniel-e
Created June 9, 2022 06:51
Show Gist options
  • Select an option

  • Save daniel-e/8ae5611eb40ff70ee04ebf65dcc20081 to your computer and use it in GitHub Desktop.

Select an option

Save daniel-e/8ae5611eb40ff70ee04ebf65dcc20081 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0, Loss: 0.1344318389892578\n",
"Epoch: 1, Loss: 0.10877793282270432\n",
"Epoch: 2, Loss: 0.1007857620716095\n",
"Epoch: 3, Loss: 0.06776227056980133\n",
"Epoch: 4, Loss: 0.03420461341738701\n",
"Epoch: 5, Loss: 0.11107035726308823\n",
"Epoch: 6, Loss: 0.028283704072237015\n",
"Epoch: 7, Loss: 0.01456095464527607\n",
"Epoch: 8, Loss: 0.0457158200442791\n",
"Epoch: 9, Loss: 0.030669715255498886\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9742)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torchvision as tv\n",
"\n",
"# Configuration: everything a reader might want to tune.\n",
"DATA_ROOT = '/tmp/mnist'\n",
"EPOCHS = 10\n",
"BATCH_SIZE = 500\n",
"LR = 0.01\n",
"VAL_BATCH_SIZE = 1000\n",
"SEED = 0\n",
"\n",
"# Seed weight init and DataLoader shuffling so Restart & Run All is reproducible.\n",
"torch.manual_seed(SEED)\n",
"\n",
"t = tv.transforms.ToTensor()\n",
"\n",
"mnist_training = tv.datasets.MNIST(\n",
"    root=DATA_ROOT,\n",
"    train=True,\n",
"    download=True,\n",
"    transform=t\n",
")\n",
"\n",
"mnist_val = tv.datasets.MNIST(\n",
"    root=DATA_ROOT,\n",
"    train=False,\n",
"    download=True,\n",
"    transform=t\n",
")\n",
"\n",
"# Each image is flattened to a 28*28 vector -> one hidden layer -> 10 class scores.\n",
"# The last layer has no softmax: CrossEntropyLoss below expects raw logits.\n",
"model = torch.nn.Sequential(\n",
"    torch.nn.Linear(28*28, 128),\n",
"    torch.nn.ReLU(),\n",
"    torch.nn.Linear(128, 10)\n",
")\n",
"\n",
"opt = torch.optim.Adam(params=model.parameters(), lr=LR)\n",
"\n",
"loss_fn = torch.nn.CrossEntropyLoss()\n",
"\n",
"loader = torch.utils.data.DataLoader(\n",
"    mnist_training,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=True\n",
")\n",
"\n",
"# train()/eval() are no-ops for this model (no dropout/batch-norm layers), but\n",
"# they keep re-runs correct if such layers are added later.\n",
"model.train()\n",
"for epoch in range(EPOCHS):\n",
"    epoch_loss = 0.0\n",
"    for imgs, labels in loader:\n",
"        predictions = model(imgs.view(len(imgs), -1))\n",
"        loss = loss_fn(predictions, labels)\n",
"        opt.zero_grad()\n",
"        loss.backward()\n",
"        opt.step()\n",
"        # loss is a per-batch mean; weight by batch size so the epoch mean is\n",
"        # per-sample even if the last batch is smaller.\n",
"        epoch_loss += loss.item() * len(imgs)\n",
"    # Mean training loss over the epoch; the last-batch loss alone is very noisy.\n",
"    print(f\"Epoch: {epoch}, Loss: {epoch_loss / len(mnist_training)}\")\n",
"\n",
"# Evaluate in batches under no_grad: no autograd graph is built, and the code\n",
"# does not assume a fixed validation-set size.\n",
"model.eval()\n",
"val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=VAL_BATCH_SIZE)\n",
"correct = 0\n",
"with torch.no_grad():\n",
"    for images, labels in val_loader:\n",
"        predicted_labels = model(images.view(len(images), -1)).argmax(dim=1)\n",
"        correct += (predicted_labels == labels).sum()\n",
"\n",
"# Validation accuracy: fraction of correctly classified images.\n",
"correct / len(mnist_val)"
]
}
],
"metadata": {
"interpreter": {
"hash": "0b567db37d44c4d7108c6daa0f9e0ef92bd52d42f7ebb7e5d1452f34c4acbe1f"
},
"kernelspec": {
"display_name": "Python 3.8.10 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment