Skip to content

Instantly share code, notes, and snippets.

@brusangues
Created March 3, 2026 12:24
Show Gist options
  • Select an option

  • Save brusangues/a01126e8fc4d7e9dcbca8688cfaebdba to your computer and use it in GitHub Desktop.

Select an option

Save brusangues/a01126e8fc4d7e9dcbca8688cfaebdba to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "bee24b83",
"metadata": {},
"source": [
"# Construindo uma rede neural apenas com Numpy"
]
},
{
"cell_type": "markdown",
"id": "3879c3b2",
"metadata": {},
"source": [
"https://www.youtube.com/watch?v=w8yWXqWQYmU \n",
"\n",
"https://www.kaggle.com/code/wwsalmon/simple-mnist-nn-from-scratch-numpy-no-tf-keras "
]
},
{
"cell_type": "markdown",
"id": "42a04b14",
"metadata": {},
"source": [
"# 1. Dataset MNIST Dígitos"
]
},
{
"cell_type": "markdown",
"id": "c304ddec",
"metadata": {},
"source": [
"## 1.1. Leitura dos dados"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "25e3e55e",
"metadata": {},
"outputs": [],
"source": [
"# import pandas as pd\n",
"# from sklearn.datasets import fetch_openml\n",
"\n",
"# mnist = fetch_openml('mnist_784', version=1)\n",
"# X, y = mnist['data'], mnist['target']\n",
"# X[\"target\"] = y\n",
"# df = pd.DataFrame(X)\n",
"# print(X.shape, df.shape)\n",
"# df"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d98f0cae",
"metadata": {},
"outputs": [],
"source": [
"# len(df)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "85dd10a2",
"metadata": {},
"outputs": [],
"source": [
"# # Salvando dados\n",
"# step = 5_000\n",
"# len_df = len(df)\n",
"# for i in range(0, len_df, step):\n",
"# print(i, i+step)\n",
"# dfi = df.iloc[i:i+step]\n",
"# dfi.to_csv(f\"mnist/{i}.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "65f3601e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 5000\n",
"5000 10000\n",
"10000 15000\n",
"15000 20000\n",
"20000 25000\n",
"25000 30000\n",
"30000 35000\n",
"35000 40000\n",
"40000 45000\n",
"45000 50000\n",
"50000 55000\n",
"55000 60000\n",
"60000 65000\n",
"65000 70000\n"
]
}
],
"source": [
"# Lendo dados\n",
"import pandas as pd\n",
"\n",
"dfs = []\n",
"step = 5_000\n",
"len_df = 70000\n",
"for i in range(0, len_df, step):\n",
" print(i, i+step)\n",
" dfi = pd.read_csv(f\"mnist/{i}.csv\")\n",
" dfs.append(dfi)\n",
"df_ = pd.concat(dfs, ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7ac4423a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 70000 entries, 0 to 69999\n",
"Columns: 785 entries, pixel1 to target\n",
"dtypes: int64(785)\n",
"memory usage: 419.2 MB\n"
]
}
],
"source": [
"df_.info()"
]
},
{
"cell_type": "markdown",
"id": "c47cd01c",
"metadata": {},
"source": [
"## 1.2. Visualização"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c7456014",
"metadata": {},
"outputs": [],
"source": [
"X = df_.filter(regex=\"pixel\")\n",
"y = df_[\"target\"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8c8b02ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 70000 entries, 0 to 69999\n",
"Columns: 784 entries, pixel1 to pixel784\n",
"dtypes: int16(784)\n",
"memory usage: 104.7 MB\n"
]
}
],
"source": [
"X = X.astype(\"int16\")\n",
"X.info()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ad1e952b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 70000.000000\n",
"mean 13.230286\n",
"std 50.548067\n",
"min 0.000000\n",
"25% 0.000000\n",
"50% 0.000000\n",
"75% 0.000000\n",
"max 255.000000\n",
"Name: pixel100, dtype: float64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.pixel100.describe()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a0016645",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 70000.000000\n",
"mean 13.230286\n",
"std 50.548067\n",
"min 0.000000\n",
"25% 0.000000\n",
"50% 0.000000\n",
"75% 0.000000\n",
"max 255.000000\n",
"Name: pixel100, dtype: float64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.pixel100.describe()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5ea126bc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Id: 56982, Label: 5')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAI2JJREFUeJzt3Qt0FOX5x/En3MI1wRDIhZtcxXItiIgiIFACKgqiFbUtWA8UBBVQ0PiXi1YNYkUqB8GeWoIHREEFKrVBCBK0ggoa8QaHYJBgICI2CYRwEeZ/nvec3WZDAs6yybvZ/X7OGdednXd3dnaYX+Z933knwnEcRwAAqGTVKvsDAQAggAAA1nAGBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAgiVasyYMXLppZey1S2bPXu2REREyI8//hiw9+S3hVsEEC5aamqqOZht3769wremhpd+Vulp/PjxZS6/ceNGGTBggERHR0uDBg2kR48e8vrrr/ssc+zYMZk8ebI0a9ZMIiMj5fLLL5dFixaV+X47duyQG2+8UeLj46V+/frSpUsXeeGFF+TMmTPeZY4cOSLPPvus9O3bVxo3biwNGzaUq6666pzPdat///7SqVMnCVVuf1tUfTVsrwDgVrdu3eTBBx/0mde+fftzlluyZIncc8898pvf/EaefvppqV69uuzevVtycnK8y2hwJCUlmfCcOHGitGvXTtavXy/33nuv/Pe//5VHH33UJ3yuvvpqs8zDDz8sdevWlX//+9/ywAMPyN69e+Wvf/2rWW7r1q3yf//3f3L99dfLY489JjVq1JA333xTRo0aJV9//bU8/vjj/OgX+dsiROhgpMDFWLJkiQ5o63zyyScXXHb06NFOy5Yt/f4sLXvDDTdccLns7GynTp06zv3333/e5VauXGnW/eWXX/aZP3LkSKd27dpOXl6ed97YsWOdWrVqOUeOHPFZtm/fvk5UVJT3+bfffuvs27fPZ5mzZ886AwYMcCIjI51jx445/ujXr5/TsWNHJxBmzZplvvfhw4edQKms3xahgyo4VJg1a9aYKqPatWubx9WrV5e53MGDB2XXrl1y+vTpX/zep06dkqKionJfX7x4sTm7eeKJJ7zVbGUN/P7++++bRz07KUmfnzhxQtauXeudV1hYaL6LVqmVlJCQIHXq1PE+b9WqlbRs2dJnGa1KGj58uJw8eVK+/fZbqSg7d+40bTGtW7c266pVhX/84x9NtWBZtA3ot7/9rURFRUmjRo3M2Zx+79KWLVtmqi/1e8bExJjtU/JMsjwV8dsidBBAqBDvvvuujBw50hx4U1JSzMH37rvvLrOdKDk52bS7fP/997/ovTdt2mSqv7QNRtsNPFVfpdt+OnToIO+8845p29H2Hz3AzpgxQ86ePetdTgNBq+Zq1arlU17f31PtVrINRkPoT3/6k3zzzTfy3XffmaB76623zHe4kEOHDpnH2NhYqSgbNmwwAafbesGCBSYoXnvtNVMdWFYAa/ho4OhvpMtoe9a4ceN8lnnqqafkD3/4g6l6nDdvnmkvS09PN21c+fn5512fivhtEUJsn4IhNKvgunXr5iQkJDj5+fneee+++65ZrnQ1jVbd6HytNruQYcOGOc8884yzZs0aU2127bXXmrLTp0/3WU6rxC655BJT5TVjxgznjTfecO68806z7COPPOJd7rnnnjPz3n//fZ/yuozOv/HGG73zfv75Z2fSpElOzZo1zWs6Va9e3Vm0aNEF11ur7Zo0aWLW11+/pAru+PHj58xbsWKFWdctW7acUwV30003+Sx77733mvmff/65ea5Vifodn3rqKZ/lvvjiC6dGjRo+88uqgquI3xahgwBCwAMoNzf3nAO9x69+9auLaicoTdtWkpKSzMEwJyfHO79atWpmHebMmeOz/JAhQ0zbUGFhoXl+8OBBJzo62mnXrp0JSD1QvvTSSybAtPzAgQN9yj///PMmlJYuXeq8/vrrzvDhw81nr169utx1PHPmjPlcbT/KzMz0+7u6bQMqLi42bTz6nfS7zJ8//5wAWr9+vU+Zb775xsxPSUkxz+fNm+dEREQ4e/bsMe9Vcrr88sudQYMGBawN6Jf+tggdVMEh4LRqSmmVTWmXXXZZQD9Lq/imTJkiP//8s2zevNk739Mmc8cdd/gsr8+Li4vls88+M8+1jeSf//ynqYobPHiwab+ZNm2aqb5SWhXkMWfOHHnmmWdkxYoVpkpKq6+0XatPnz6mB52uQ1nuu+8+SUtLk7///e/StWtXqUg//fSTaceJi4sz20C7get3UgUFBecsX/o3atOmjVSrVk327dtnnu/Zs8dU3ely+l4lJ62G/OGHHyrsu5T32yJ00A0bVV7z5s29B1+PxMREc/DUA3FJTZo0MY/axdpD2zK03eSLL74wjd8aErm5ued0AX7xxRfNNUUlQ0nddNNNMnXqVHPQbtu2rc9r2uVay2l4/f73v5eKpqH44YcfmhDVLs26rtrmNWTIEJ+2r/Md9EvSMjpPu5trW1lppbdFZfy2CB0EEALO0wNMA6A0vQ4n0Dy9yvSvcg/tsaWfr43f2iPMwxMsJZdVenDVA3bJTgxq0KBB3nl5eXk+F5x6eHp4lT4DWrhwoRlxQBvt9bqhiqahqp0DNPRmzpzpnV/W71DyNc8ZksrKyjKh4xmtQs+I9AxIl7FxPU5Zvy1CB1VwCDjtlqwH86VLl/pU+2gPLb0Q09+uuvpXcOkA0DJ6dqG92K677jrv/Ntvv908vvzyy955emDVi1O1G7EGVHkOHz5sqtp0lIOSAaQHYP0OJbs06/qsXLnS9LLTg7WHjnpw//33y1133WV6jlUGzxlK6d5u8+fPL7eMhmRJnqrHoUOHmsdbbrnFvK+GWun31eflde+uyN8WoYMzIFQI7dZ7ww03mPYRvQ5FDzB6cOvYsaO5Jqd0V10Nq+zs7POOE6dtNU8++aTceuut5i9yfc9XX31VvvzySzPSgbbneNx8880ycOBAsx56rYtWq+l1SR988IG89NJLZsgdj379+knv3r1N9Zl2lf7b3/5m1nHdunWmPcTjkUcekd/97nfSq1cv01VZ21i0PUi7aut61axZ0yz38ccfmzYi7fat67B8+XKf76GjKZQ8K9MqLl2HX9LOoeGon1Wabg8NO61OnDt3rjl4N23a1HSH1+1aHn1NqxC1ik5HcNDrfe68805vW5WGqn6e/kZaxajd6TVstZy2f+l2eOihh8p9/4r4bRFCbPeCQOiOhPDmm2+anlLaFVp7v7311lsX1VV3+/btpqtu06ZNTY+y+vXrO3369DGjGZTl6NGjzgMPPODEx8eb5Tt37uwsW7bsnOWmTJnitG7d2qxn48aNTXftvXv3lvmeaWlppjdabGys9z0XL15c5vYob9LXS66jzhs1apRzIfq55b2np7fegQMHnBEjRjgNGzY0vftuu+02b69E7flWuhfc119/7dx6661OgwYNTLd17WauvedK099St3W9evXM1KFDB2fixInO7t27vctU5m+L0BCh/7EdgkC40gtldXDTzz//XDp37mx7dYBKRRsQYNF7771nRisgfBCOOAMCAFjBGRAAwAoCCABgBQEEALCCAAIAWBF0F6Lq1eo6XIpe7FZ6XCoAQPDTq3uOHj1qxmQseTF30AeQho9nAEIAQNWld83VG0JWmSo4PfMBAFR9FzqeV1gA6SCHOvaT3pdex87S8bF+CardACA0XOh4XiEBpCMB6/1RZs2aJZ9++qkZ2DApKalCb14FAKhiKmKAuSuvvNIMVFjylsSJiYne2/yeT0FBwXkHcmRiG7APsA+wD0iV2AZ6PD+fgJ8BnTp1ygxPX/I+KtoLQp/rcO+l6a2QCwsLfSYAQOgLeADpvVf0xlKlb4Wsz/VeK6Xp/Vqio6O9Ez3gACA8WO8Fpzes0rtmeibttgcACH0Bvw4oNjbW3MI3Ly/PZ74+L+uuhnpnypJ3pwQAhIeAnwHp/dt79Ogh6enpPqMb6HO97TEAABU2EoJ2wR49erRcccUVcuWVV8r8+fOlqKhI7r77brY6AKDiAuj222+Xw4cPy8yZM03Hg27duklaWto5HRMAAOEr6O6Iqt2wtTccAKBq045lUVFRwdsLDgAQngggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAEEAAgPDBGRAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsqGHnY4Ffrl69eq4315w5c/zaxN26dXNdJjc313WZ2267zXWZiIgI12WeffZZ8UezZs1cl/n1r3/tuswbb7zhuszy5ctdl9m1a5frMqh4nAEBAKwggAAAoRFAs2fPNlUFJacOHToE+mMAAFVchbQBdezYUTZu3Pi/D6lBUxMAwFeFJIMGTnx8fEW8NQAgRFRIG9CePXskMTFRWrduLXfddZfs37+/3GVPnjwphYWFPhMAIPQFPIB69eolqampkpaWJosWLZLs7Gy59tpr5ejRo2Uun5KSItHR0d6pefPmgV4lAEA4BNDQoUPNNQ5dunSRpKQkeeeddyQ/P19WrlxZ5vLJyclSUFDgnXJycgK9SgCAIFThvQMaNmwo7du3l6ysrDJfj4yMNBMAILxU+HVAx44dk71790pCQkJFfxQAIJwD6KGHHpKMjAzZt2+ffPjhhzJixAipXr263HHHHYH+KABAFRbwKrgDBw6YsDly5Ig0btxY+vTpI9u2bTP/DwCAR4TjOI4EEe2Grb3hAI8pU6a43hh/+ctfQm4D+jMYaZD98w6Ikhe5/1LaIQqVTzuWRUVFlfs6Y8EBAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAIIAAAOGDMyAAgBUEEADACgIIAGAFAQQACM0b0gEXS++u69ayZcv8+qxhw4a5LlNcXOy6zD/+8Q/XZXJzc6WydO/e3XWZtm3b+jVYpVuhONBsuOIMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOI7jSBApLCyU6Oho26uBKq5WrVp+ldu6davrMk888YTrMmvXrnVdBqhqdLTzqKiocl/nDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArKhh52OBitW2bVu/ynXr1s11meLiYr8+Cwh3nAEBAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUMRoqQNH/+fNurAOACOAMCAFhBAAEAqkYAbdmyRYYNGyaJiYkSEREha9as8XndcRyZOXOmJCQkSJ06dWTQoEGyZ8+eQK4zACAcA6ioqEi6du0qCxcuLPP1uXPnygsvvCCLFy+Wjz76SOrVqydJSUly4sSJQKwvACBcOyEMHTrUTGXRsx9t/H3sscfk5ptvNvNeeeUViYuLM2dKo0aNuvg1BgCEhIC2AWVnZ8uhQ4dMtZtHdHS09OrVS7Zu3VpmmZMnT0phYaHPBAAIfQENIA0fpWc8Jelzz2ulpaSkmJDyTM2bNw/kKgEAgpT1XnDJyclSUFDgnXJycmyvEgCgqgVQfHy8eczLy/OZr889r5UWGRkpUVFRPhMAIPQFNIBatWplgiY9Pd07T9t0tDdc7969A/lRAIBw6wV37NgxycrK8ul4kJmZKTExMdKiRQuZPHmyPPnkk9KuXTsTSDNmzDDXDA0fPjzQ6w4ACKcA2r59u1x33XXe51OnTjWPo0ePltTUVJk+fbq5VmjcuHGSn58vffr0kbS0NKldu3Zg1xwAUKVFOHrxThDRKjvtDQdcDK329Uf37t1dlyl52cEvlZGR4boMUNVox7Lztetb7wUHAAhPBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAVI3bMQChrLi42HWZ77//3nUZvX+WWwkJCa7LfPXVV67LAJWFMyAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsILBSBH0EhMTXZdp3769X59Vr14912UyMjJcl4mPj3dd5vDhw67LZGZmij+mT5/uuszOnTv9+iyEL86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKBiNF0Ktbt67rMlFRUVJZ/BlY9Omnn3ZdpkYN9/9cH3zwQfHH0qVLXZe5++67K22wVIQGzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoGI0XQGzVqlOsyERERfn3WqVOnXJeZNm2a6zILFiyQyrBmzRq/ym3cuNF1mQ0bNrgu06lTJ9dl8vLyXJdBcOIMCABgBQEEAKgaAbRlyxYZNmyYJCYmmmqO0qf4Y8aMMfNLTkOGDAnkOgMAwjGAioqKpGvXrrJw4cJyl9HAOXjwoHdasWLFxa4nACDcOyEMHTrUTOcTGRnp110iAQDho0LagDZv3ixNmjSRyy67TCZMmCBHjhwpd9mTJ09KYWGhzwQACH0BDyCtfnvllVckPT1dnnnmGcnIyDBnTGfOnClz+ZSUFImOjvZOzZs3D/QqAQDC4TqgktdsdO7cWbp06SJt2rQxZ0UDBw48Z/nk5GSZOnWq97meARFCABD6KrwbduvWrSU2NlaysrLKbS+KiorymQAAoa/CA+jAgQOmDSghIaGiPwoAEMpVcMeOHfM5m8nOzpbMzEyJiYkx0+OPPy4jR440veD27t0r06dPl7Zt20pSUlKg1x0AEE4BtH37drnuuuu8zz3tN6NHj5ZFixbJzp07ZenSpZKfn28uVh08eLD8+c9/NlVtAAD4HUD9+/cXx3HKfX39+vVu3xI4r/Ptb4Eso/bs2RO0A4v646OPPvKr3HPPPee6zIwZM1yXGTBggOsyXNgeOhgLDgBgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOP4OG1xB9Jbc0dHRtlcDQaRfv36uy8yePduvz5o2bZr4c4uSUFOzZk3XZfbt2+e6TG5urusyPXv2dF0GdhQUFJz3LtecAQEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQxGCiAgFixY4LpMUlKS6zJdu3Z1Xaa4uNh1GVw8BiMFAAQlquAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVNex8LMJVXFyc6zKNGzd2XWbv3r3iDwat9F/79u1dl2nTpo3rMq1bt3Zd5quvvnJdBhWPMyAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsILBSOG3xMRE12W2bNniuszq1atdl5k5c6brMvifoUOHut4cvXr1cl3m1KlTrsucPHnSdRkEJ86AAABWEEAAgOAPoJSUFOnZs6c0aNBAmjRpIsOHD5fdu3f7LHPixAmZOHGiNGrUSOrXry8jR46UvLy8QK83ACCcAigjI8OEy7Zt22TDhg1y+vRpGTx4sBQVFXmXmTJlirz99tuyatUqs3xubq7ccsstFbHuAIBw6YSQlpbm8zw1NdWcCe3YsUP69u0rBQUF8vLLL8urr74qAwYMMMssWbJELr/8chNaV111VWDXHgAQnm1AGjgqJibGPGoQ6VnRoEGDvMt06NBBWrRoIVu3bi23R0thYaHPBAAIfX4H0NmzZ2Xy5MlyzTXXSKdOncy8Q4cOSa1ataRhw4Y+y8bFxZnXymtXio6O9k7Nmzf3d5UAAOEQQNoW9OWXX8prr712USuQnJxszqQ8U05OzkW9HwAghC9EnTRpkqxbt85cVNisWTPv/Pj4eHNhWX5+vs9ZkPaC09fKEhkZaSYAQHhxdQbkOI4JH70yfdOmTdKqVSuf13v06CE1a9aU9PR07zztpr1//37p3bt34NYaABBeZ0Ba7aY93NauXWuuBfK062jbTZ06dczjPffcI1OnTjUdE6KiouS+++4z4UMPOACA3wG0aNEi89i/f3+f+drVesyYMeb/n3/+ealWrZq5AFV7uCUlJcmLL77o5mMAAGGghtsquAupXbu2LFy40EwIbXoNmFv+9HL86aefXJcpLi52XSYUef4wdEt7p7qltSJuzZs3z3WZrKws12UQnBgLDgBgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFZEOL9kiOtKVFhYaO4rhND01VdfuS7TokUL12X0vlX++OSTT1yXOXLkiOsy5d0h+Hyuvvpq12UGDhwo/oiLi3Nd5sCBA67L6O1a3Nq1a5frMrCjoKDA3BeuPJwBAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVDEaKSnXFFVe4LvOvf/3LdZnY2FgJNREREa7L/Pzzz3591vfff++6zPjx412XWb9+vesyqDoYjBQAEJSoggMAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYwGCmCXvfu3SttkMuYmBjXZc6cOeO6zKpVqyplMNL58+eLPzIzM12XOXXqlF+fhdDFYKQAgKBEFRwAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCwUgBABWCwUgBAEGJKjgAQPAHUEpKivTs2VMaNGggTZo0keHDh8vu3bt9lunfv7+5b0nJafz48YFebwBAOAVQRkaGTJw4UbZt2yYbNmyQ06dPy+DBg6WoqMhnubFjx8rBgwe909y5cwO93gCAKq6Gm4XT0tJ8nqemppozoR07dkjfvn298+vWrSvx8fGBW0sAQMipdrE9HMq6jfHy5cslNjZWOnXqJMnJyXL8+PFy3+PkyZNSWFjoMwEAwoDjpzNnzjg33HCDc8011/jMf+mll5y0tDRn586dzrJly5ymTZs6I0aMKPd9Zs2a5ehqMLEN2AfYB9gHJKS2QUFBwXlzxO8AGj9+vNOyZUsnJyfnvMulp6ebFcnKyirz9RMnTpiV9Ez6frY3GhPbgH2AfYB9QCo8gFy1AXlMmjRJ1q1bJ1u2bJFmzZqdd9levXqZx6ysLGnTps05r0dGRpoJABBeXAWQnjHdd999snr1atm8ebO0atXqgmUyMzPNY0JCgv9rCQAI7wDSLtivvvqqrF271lwLdOjQITM/Ojpa6tSpI3v37jWvX3/99dKoUSPZuXOnTJkyxfSQ69KlS0V9BwBAVeSm3ae8er4lS5aY1/fv3+/07dvXiYmJcSIjI522bds606ZNu2A9YEm6LHWv1L+zD7APsA9U/X3gQsd+BiMFAFQIBiMFAAQlBiMFAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwIugCyHEc26sAAKiE43nQBdDRo0dtrwIAoBKO5xFOkJ1ynD17VnJzc6VBgwYSERHh81phYaE0b95ccnJyJCoqSsIV24HtwP7Av4tgPj5orGj4JCYmSrVq5Z/n1JAgoyvbrFmz8y6jGzWcA8iD7cB2YH/g30WwHh+io6MvuEzQVcEBAMIDAQQAsKJKBVBkZKTMmjXLPIYztgPbgf2BfxehcHwIuk4IAIDwUKXOgAAAoYMAAgBYQQABAKwggAAAVhBAAAArqkwALVy4UC699FKpXbu29OrVSz7++GPbq1TpZs+ebYYnKjl16NBBQt2WLVtk2LBhZlgP/c5r1qzxeV07cs6cOVMSEhKkTp06MmjQINmzZ4+E23YYM2bMOfvHkCFDJJSkpKRIz549zVBdTZo0keHDh8vu3bt9ljlx4oRMnDhRGjVqJPXr15eRI0dKXl6ehNt26N+//zn7w/jx4yWYVIkAev3112Xq1Kmmb/unn34qXbt2laSkJPnhhx8k3HTs2FEOHjzonT744AMJdUVFReY31z9CyjJ37lx54YUXZPHixfLRRx9JvXr1zP6hB6Jw2g5KA6fk/rFixQoJJRkZGSZctm3bJhs2bJDTp0/L4MGDzbbxmDJlirz99tuyatUqs7yOLXnLLbdIuG0HNXbsWJ/9Qf+tBBWnCrjyyiudiRMnep+fOXPGSUxMdFJSUpxwMmvWLKdr165OONNddvXq1d7nZ8+edeLj451nn33WOy8/P9+JjIx0VqxY4YTLdlCjR492br75Ziec/PDDD2ZbZGRkeH/7mjVrOqtWrfIu880335hltm7d6oTLdlD9+vVzHnjgASeYBf0Z0KlTp2THjh2mWqXkgKX6fOvWrRJutGpJq2Bat24td911l+zfv1/CWXZ2thw6dMhn/9BBELWaNhz3j82bN5sqmcsuu0wmTJggR44ckVBWUFBgHmNiYsyjHiv0bKDk/qDV1C1atAjp/aGg1HbwWL58ucTGxkqnTp0kOTlZjh8/LsEk6EbDLu3HH3+UM2fOSFxcnM98fb5r1y4JJ3pQTU1NNQcXPZ1+/PHH5dprr5Uvv/zS1AWHIw0fVdb+4XktXGj1m1Y1tWrVSvbu3SuPPvqoDB061Bx4q1evLqFGb90yefJkueaaa8wBVulvXqtWLWnYsGHY7A9ny9gO6s4775SWLVuaP1h37twpDz/8sGkneuuttyRYBH0A4X/0YOLRpUsXE0i6g61cuVLuueceNlWYGzVqlPf/O3fubPaRNm3amLOigQMHSqjRNhD94ysc2kH92Q7jxo3z2R+0k47uB/rHie4XwSDoq+D09FH/eivdi0Wfx8fHSzjTv/Lat28vWVlZEq48+wD7x7m0mlb//YTi/jFp0iRZt26dvPfeez73D9P9Qavt8/Pzw+J4Mamc7VAW/YNVBdP+EPQBpKfTPXr0kPT0dJ9TTn3eu3dvCWfHjh0zf83oXzbhSqub9MBScv/QO0Jqb7hw3z8OHDhg2oBCaf/Q/hd60F29erVs2rTJ/P4l6bGiZs2aPvuDVjtpW2ko7Q/OBbZDWTIzM81jUO0PThXw2muvmV5Nqampztdff+2MGzfOadiwoXPo0CEnnDz44IPO5s2bnezsbOc///mPM2jQICc2Ntb0gAllR48edT777DMz6S47b9488//fffedeX3OnDlmf1i7dq2zc+dO0xOsVatWTnFxsRMu20Ffe+ihh0xPL90/Nm7c6HTv3t1p166dc+LECSdUTJgwwYmOjjb/Dg4ePOidjh8/7l1m/PjxTosWLZxNmzY527dvd3r37m2mUDLhAtshKyvLeeKJJ8z31/1B/220bt3a6du3rxNMqkQAqQULFpidqlatWqZb9rZt25xwc/vttzsJCQlmGzRt2tQ81x0t1L333nvmgFt60m7Hnq7YM2bMcOLi4swfKgMHDnR2797thNN20APP4MGDncaNG5tuyC1btnTGjh0bcn+klfX9dVqyZIl3Gf3D495773UuueQSp27dus6IESPMwTmctsP+/ftN2MTExJh/E23btnWmTZvmFBQUOMGE+wEBAKwI+jYgAEBoIoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQAIIABA+OAMCAAgNvw/zJG9Fliu2m4AAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"id = X.sample(1).index[0]\n",
"plt.imshow(X.iloc[id].values.reshape(28, 28), cmap='gray')\n",
"plt.title(f'Id: {id}, Label: {y[id]}')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "58ba3512",
"metadata": {},
"outputs": [],
"source": [
"y = y.astype(\"int16\")"
]
},
{
"cell_type": "markdown",
"id": "7d25ace7",
"metadata": {},
"source": [
"## 1.3. Train test split"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "18929619",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(69000, 784)\n",
"(1000, 784)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X.values, y.values, test_size=1000, random_state=42, shuffle=True)\n",
"print(X_train.shape)\n",
"print(X_test.shape)"
]
},
{
"cell_type": "markdown",
"id": "180bf32b",
"metadata": {},
"source": [
"## 1.4. Scaling"
]
},
{
"cell_type": "markdown",
"id": "af330d6f",
"metadata": {},
"source": [
"A escala dos valores de entrada para a rede é bastante importante para que não haja nenhum problema no treinamento. \n",
"Aqui poderíamos aplicar o standard scaler, mas como sabemos o intervalo de valores que os pixels podem assumir, \n",
"vamos fazer um simples min max."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "040c3575",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "99be654f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(69000, 784)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ea8b393d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784,)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Standard Scaler\n",
"# Calculando média e desvio padrão no conjunto de treino\n",
"means = X_train.mean(axis=0)\n",
"stds = X_train.std(axis=0)\n",
"stds.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "7967eb25",
"metadata": {},
"outputs": [],
"source": [
"# Min max scaler\n",
"subtract = [0.0] * X_train.shape[1]\n",
"divide = [255.0] * X_train.shape[1]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e144fa81",
"metadata": {},
"outputs": [],
"source": [
"def apply_scaling(X_train, subtract, divide):\n",
" cols_scaled = []\n",
" for i in range(X_train.shape[1]):\n",
" col = X_train[:, i]\n",
" if divide[i] == 0:\n",
" # print(f\"Col {i}: Std == 0\")\n",
" col_scaled = col\n",
" else:\n",
" # print(f\"Col {i}: Std != 0\")\n",
" col_scaled = (col - subtract[i]) / divide[i]\n",
" col_scaled = col_scaled.reshape(-1, 1)\n",
" # print(col_scaled.shape)\n",
" cols_scaled.append(col_scaled)\n",
" X_train_scaled = np.hstack(cols_scaled)\n",
" print(X_train_scaled.shape)\n",
" return X_train_scaled"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3e2a5ec1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(69000, 784)\n",
"(1000, 784)\n"
]
}
],
"source": [
"# Min max scaler\n",
"X_train_scaled = apply_scaling(X_train, [0.0] * X_train.shape[1], [255.0] * X_train.shape[1])\n",
"X_test_scaled = apply_scaling(X_test, [0.0] * X_test.shape[1], [255.0] * X_test.shape[1])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2077e5e7",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>69000.0000</td>\n",
" <td>69000.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>12.9836</td>\n",
" <td>11.5725</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>49.7809</td>\n",
" <td>47.1865</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>255.0000</td>\n",
" <td>255.0000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1\n",
"count 69000.0000 69000.0000\n",
"mean 12.9836 11.5725\n",
"std 49.7809 47.1865\n",
"min 0.0000 0.0000\n",
"25% 0.0000 0.0000\n",
"50% 0.0000 0.0000\n",
"75% 0.0000 0.0000\n",
"max 255.0000 255.0000"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(X_train[:, 100:102]).describe().round(4)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "9823c865",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>69000.0000</td>\n",
" <td>69000.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.0509</td>\n",
" <td>0.0454</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.1952</td>\n",
" <td>0.1850</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.0000</td>\n",
" <td>0.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.0000</td>\n",
" <td>1.0000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1\n",
"count 69000.0000 69000.0000\n",
"mean 0.0509 0.0454\n",
"std 0.1952 0.1850\n",
"min 0.0000 0.0000\n",
"25% 0.0000 0.0000\n",
"50% 0.0000 0.0000\n",
"75% 0.0000 0.0000\n",
"max 1.0000 1.0000"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(X_train_scaled[:, 100:102]).describe().round(4)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "15aded6d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(784, 69000)\n",
"(784, 1000)\n"
]
}
],
"source": [
"# Aqui vamos transpor os dados para facilitar as multiplicações de matrizes\n",
"X_train_scaled, X_test_scaled, y_train, y_test = X_train_scaled.T, X_test_scaled.T, y_train.T, y_test.T\n",
"print(X_train_scaled.shape)\n",
"print(X_test_scaled.shape)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "096d2430",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784,)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_scaled[:, 0].shape"
]
},
{
"cell_type": "markdown",
"id": "0b41f758",
"metadata": {},
"source": [
"# 2. Funções de ativação"
]
},
{
"cell_type": "markdown",
"id": "fb84d8e4",
"metadata": {},
"source": [
"A escolha das funções de ativação das hidden layers são arbitrárias, contanto que as derivadas sejam calculáveis. \n",
"Mas a camada de saída deve passar por uma função que corresponda à tarefa. \n",
"No caso da Classificação Multiclass, vamos usar a softmax pra transformar os valores de saída em uma distribuição \n",
"de probabilidade cuja soma é 1."
]
},
{
"cell_type": "markdown",
"id": "97f7a916",
"metadata": {},
"source": [
"![alt text](image-1.png)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "c1bcba9b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "8032439b",
"metadata": {},
"outputs": [],
"source": [
"# Funções de ativação\n",
"\n",
"def sigmoid(z):\n",
" return 1 / (1 + np.exp(-z))\n",
"\n",
"def step(z):\n",
" return np.where(z >= 0, 1, 0)\n",
"\n",
"def relu(z):\n",
" return np.maximum(0, z)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "51a78688",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.1, 0.2, 0.3]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[0.30060961, 0.33222499, 0.3671654 ]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"np.float64(1.0000000000000002)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Softmax simples - erro de arredondamento\n",
"def softmax(z):\n",
" return np.exp(z) / np.sum(np.exp(z))\n",
"\n",
"x = np.array([[0.1, 0.2, 0.3]])\n",
"display(x)\n",
"display(softmax(x))\n",
"softmax(x).sum()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "a63b144d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.1, 0.2, 0.3]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[0.30060961, 0.33222499, 0.3671654 ]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"np.float64(1.0)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Softmax com fator de estabilidade numérica\n",
"def softmax(x, axis=0):\n",
" # Subtract the maximum value for numerical stability\n",
" # keepdims=True ensures the shape is maintained for correct broadcasting\n",
" x_max = np.max(x, axis=axis, keepdims=True)\n",
" e_x = np.exp(x - x_max)\n",
" \n",
" # Divide by the sum to obtain probabilities that sum to 1\n",
" return e_x / e_x.sum(axis=axis, keepdims=True)\n",
"\n",
"x = np.array([[0.1, 0.2, 0.3]])\n",
"display(x)\n",
"output = softmax(x, axis=1)\n",
"display(output)\n",
"output.sum()"
]
},
{
"cell_type": "markdown",
"id": "5797a04a",
"metadata": {},
"source": [
"# 3. Inicialização dos parâmetros treináveis"
]
},
{
"cell_type": "markdown",
"id": "9c609550",
"metadata": {},
"source": [
"Aqui os parâmetros de pesos e biases são criados. \n",
"A forma com que esses parâmetros são inicializados é muito importante para o treinamento. \n",
"A inicialização com valores aleatórios é bem simples, mas existem formas melhores."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "376cfdc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"W1.shape=(12, 784)\n",
"b1.shape=(12, 1)\n",
"W2.shape=(10, 12)\n",
"b2.shape=(10, 1)\n"
]
}
],
"source": [
"# https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html\n",
"\n",
"def init_params_standard_normal(input_len=784, n_hidden = 12, n_output = 10, sigma=1, mu=0):\n",
" W1 = np.random.randn(n_hidden, input_len) * sigma + mu\n",
" b1 = np.random.randn(n_hidden, 1) * sigma + mu\n",
" W2 = np.random.randn(n_output, n_hidden) * sigma + mu\n",
" b2 = np.random.randn(n_output, 1) * sigma + mu\n",
" print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n",
" return W1, b1, W2, b2\n",
"\n",
"W1, b1, W2, b2 = init_params_standard_normal()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "34f15d4c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" <td>12.0000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.2374</td>\n",
" <td>-0.0184</td>\n",
" <td>-0.2124</td>\n",
" <td>0.0696</td>\n",
" <td>-0.6559</td>\n",
" <td>0.1702</td>\n",
" <td>0.0769</td>\n",
" <td>-0.1195</td>\n",
" <td>0.2435</td>\n",
" <td>-0.2746</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>1.0505</td>\n",
" <td>0.9126</td>\n",
" <td>0.9761</td>\n",
" <td>1.2073</td>\n",
" <td>0.7606</td>\n",
" <td>1.0797</td>\n",
" <td>0.7354</td>\n",
" <td>0.5563</td>\n",
" <td>0.7472</td>\n",
" <td>0.8609</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>-1.3007</td>\n",
" <td>-1.7274</td>\n",
" <td>-2.2391</td>\n",
" <td>-3.0794</td>\n",
" <td>-2.0968</td>\n",
" <td>-2.0855</td>\n",
" <td>-1.3020</td>\n",
" <td>-1.2476</td>\n",
" <td>-0.8955</td>\n",
" <td>-1.9726</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>-0.3717</td>\n",
" <td>-0.4912</td>\n",
" <td>-0.6852</td>\n",
" <td>-0.2138</td>\n",
" <td>-0.8213</td>\n",
" <td>-0.3437</td>\n",
" <td>-0.2522</td>\n",
" <td>-0.3849</td>\n",
" <td>-0.3035</td>\n",
" <td>-0.8565</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.0006</td>\n",
" <td>-0.0520</td>\n",
" <td>-0.1960</td>\n",
" <td>0.3400</td>\n",
" <td>-0.5496</td>\n",
" <td>0.1192</td>\n",
" <td>0.2696</td>\n",
" <td>-0.1851</td>\n",
" <td>0.0482</td>\n",
" <td>-0.1642</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.6263</td>\n",
" <td>0.6239</td>\n",
" <td>0.6728</td>\n",
" <td>0.8600</td>\n",
" <td>-0.0104</td>\n",
" <td>0.5433</td>\n",
" <td>0.6566</td>\n",
" <td>0.1980</td>\n",
" <td>0.8000</td>\n",
" <td>0.3913</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>2.4439</td>\n",
" <td>1.1715</td>\n",
" <td>1.0374</td>\n",
" <td>1.2599</td>\n",
" <td>0.2005</td>\n",
" <td>2.2942</td>\n",
" <td>0.9208</td>\n",
" <td>0.9672</td>\n",
" <td>1.7092</td>\n",
" <td>0.8682</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 \\\n",
"count 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 12.0000 \n",
"mean 0.2374 -0.0184 -0.2124 0.0696 -0.6559 0.1702 0.0769 -0.1195 \n",
"std 1.0505 0.9126 0.9761 1.2073 0.7606 1.0797 0.7354 0.5563 \n",
"min -1.3007 -1.7274 -2.2391 -3.0794 -2.0968 -2.0855 -1.3020 -1.2476 \n",
"25% -0.3717 -0.4912 -0.6852 -0.2138 -0.8213 -0.3437 -0.2522 -0.3849 \n",
"50% 0.0006 -0.0520 -0.1960 0.3400 -0.5496 0.1192 0.2696 -0.1851 \n",
"75% 0.6263 0.6239 0.6728 0.8600 -0.0104 0.5433 0.6566 0.1980 \n",
"max 2.4439 1.1715 1.0374 1.2599 0.2005 2.2942 0.9208 0.9672 \n",
"\n",
" 8 9 \n",
"count 12.0000 12.0000 \n",
"mean 0.2435 -0.2746 \n",
"std 0.7472 0.8609 \n",
"min -0.8955 -1.9726 \n",
"25% -0.3035 -0.8565 \n",
"50% 0.0482 -0.1642 \n",
"75% 0.8000 0.3913 \n",
"max 1.7092 0.8682 "
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(W1[:, 0:10]).describe().round(4)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "e376e808",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"W1.shape=(12, 784)\n",
"b1.shape=(12, 1)\n",
"W2.shape=(10, 12)\n",
"b2.shape=(10, 1)\n"
]
}
],
"source": [
"def init_params_uniform(input_len=784, n_hidden = 12, n_output = 10, low=-0.5, high=0.5):\n",
" W1 = np.random.uniform(low, high, (n_hidden, input_len))\n",
" b1 = np.random.uniform(low, high, (n_hidden, 1))\n",
" W2 = np.random.uniform(low, high, (n_output, n_hidden))\n",
" b2 = np.random.uniform(low, high, (n_output, 1))\n",
" print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n",
" return W1, b1, W2, b2\n",
"\n",
"W1, b1, W2, b2 = init_params_uniform()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "4420e86b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"W1.shape=(12, 784)\n",
"b1.shape=(12, 1)\n",
"W2.shape=(10, 12)\n",
"b2.shape=(10, 1)\n"
]
}
],
"source": [
"def init_params_glorot(input_len=784, n_hidden=12, n_output=10):\n",
" # Glorot uniform initialization\n",
" limit_w1 = np.sqrt(6.0 / (input_len + n_hidden))\n",
" W1 = np.random.uniform(-limit_w1, limit_w1, (n_hidden, input_len))\n",
" b1 = np.zeros((n_hidden, 1))\n",
" \n",
" limit_w2 = np.sqrt(6.0 / (n_hidden + n_output))\n",
" W2 = np.random.uniform(-limit_w2, limit_w2, (n_output, n_hidden))\n",
" b2 = np.zeros((n_output, 1))\n",
" \n",
" print(f\"{W1.shape=}\\n{b1.shape=}\\n{W2.shape=}\\n{b2.shape=}\")\n",
" return W1, b1, W2, b2\n",
"\n",
"W1, b1, W2, b2 = init_params_glorot()"
]
},
{
"cell_type": "markdown",
"id": "914e78f1",
"metadata": {},
"source": [
"# 4. Forward pass"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "1a23d3eb",
"metadata": {},
"outputs": [],
"source": [
"ACTIVATION_FUNCTION = relu\n",
"OUTPUT_FUNCTION = softmax"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "062666d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Z1.shape=(12, 2)\n",
"A1.shape=(12, 2)\n",
"Z2.shape=(10, 2)\n",
"A2.shape=(10, 2)\n"
]
}
],
"source": [
"def forward(X, W1, b1, W2, b2, verbose=False):\n",
" Z1 = W1 @ X + b1\n",
" A1 = ACTIVATION_FUNCTION(Z1)\n",
" Z2 = W2 @ A1 + b2\n",
" A2 = OUTPUT_FUNCTION(Z2, axis=0)\n",
" if verbose:\n",
" print(f\"{Z1.shape=}\\n{A1.shape=}\\n{Z2.shape=}\\n{A2.shape=}\")\n",
" return Z1, A1, Z2, A2\n",
"Z1, A1, Z2, A2 = forward(X_train_scaled[:,0:2], W1, b1, W2, b2, True)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "8fb9a23d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.07478467 0.12267795]\n",
" [0.11597149 0.10304144]\n",
" [0.1204467 0.09488816]\n",
" [0.08310441 0.06405648]\n",
" [0.14265181 0.13316973]\n",
" [0.07147575 0.10998027]\n",
" [0.10330583 0.096492 ]\n",
" [0.11945251 0.09841706]\n",
" [0.08345286 0.09008889]\n",
" [0.08535398 0.08718803]]\n"
]
}
],
"source": [
"# A2 é a saída da rede\n",
"print(A2)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "c1ca11cc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([4, 4])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# O argmax equivale à classe predita\n",
"A2.argmax(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "1128c475",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1., 1.])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pelo softmax, soma das probas de cada saída é 1\n",
"A2.sum(axis=0)"
]
},
{
"cell_type": "markdown",
"id": "3b011ec0",
"metadata": {},
"source": [
"# 5. One hot"
]
},
{
"cell_type": "markdown",
"id": "004504d4",
"metadata": {},
"source": [
"Transforma o vetor de labels (números entre 0 e 9) em uma matriz onde cada coluna equivale à uma label."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "00e7abe4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0, 1, 2, ..., 68997, 68998, 68999], shape=(69000,))"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.arange(y_train.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "66d28776",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y_train.shape=(69000,)\n",
"y_train_one_hot.shape=(10, 69000)\n"
]
}
],
"source": [
"def one_hot(y, y_max=None):\n",
" if y_max is None:\n",
" y_max = y.max()\n",
" one_hot = np.zeros((y.shape[0], y_max + 1))\n",
" indexes = np.arange(y.shape[0])\n",
" one_hot[indexes, y] = 1\n",
" return one_hot.T\n",
"\n",
"y_train_one_hot = one_hot(y_train)\n",
"print(f\"{y_train.shape=}\\n{y_train_one_hot.shape=}\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "16f692f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5 3 9]\n"
]
},
{
"data": {
"text/plain": [
"array([[0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 1.]])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(y_train[5:8])\n",
"y_train_one_hot[:, 5:8]"
]
},
{
"cell_type": "markdown",
"id": "fa715e8c",
"metadata": {},
"source": [
"# 6. Loss Function"
]
},
{
"cell_type": "markdown",
"id": "cd79d70d",
"metadata": {},
"source": [
"A função de custo da classificação multiclasse costuma ser a Cross Entropy, as vezes chamada de \n",
"Categorical Cross Entropy."
]
},
{
"cell_type": "markdown",
"id": "ff7beba3",
"metadata": {},
"source": [
"![alt text](image.png)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "3082220d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Z1.shape=(12, 5)\n",
"A1.shape=(12, 5)\n",
"Z2.shape=(10, 5)\n",
"A2.shape=(10, 5)\n"
]
},
{
"data": {
"text/plain": [
"array([[0.075, 0.123, 0.106, 0.084, 0.121],\n",
" [0.116, 0.103, 0.09 , 0.101, 0.105],\n",
" [0.12 , 0.095, 0.11 , 0.125, 0.097],\n",
" [0.083, 0.064, 0.088, 0.104, 0.071],\n",
" [0.143, 0.133, 0.139, 0.127, 0.16 ],\n",
" [0.071, 0.11 , 0.081, 0.066, 0.063],\n",
" [0.103, 0.096, 0.088, 0.094, 0.077],\n",
" [0.119, 0.098, 0.106, 0.142, 0.136],\n",
" [0.083, 0.09 , 0.086, 0.06 , 0.093],\n",
" [0.085, 0.087, 0.107, 0.097, 0.077]])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Z1, A1, Z2, A2 = forward(X_train_scaled[:,0:5], W1, b1, W2, b2, True)\n",
"A2.round(3)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "fc761701",
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy loss\n",
"# Epsilon é um valor pequeno para evitar log(0)\n",
"def cross_entropy_loss(A2, Y, y_max=None, epsilon=1e-15):\n",
" m = Y.shape[0]\n",
" one_hot_Y = one_hot(Y, y_max)\n",
" logprobs = np.log(A2 + epsilon) * one_hot_Y\n",
" losses = -np.sum(logprobs)\n",
" loss = losses / m # O valor final é a média\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "7c1c8765",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.314"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_entropy_loss(A2, y_train[0:5], y_max=9).round(4).item()"
]
},
{
"cell_type": "markdown",
"id": "56e5ed10",
"metadata": {},
"source": [
"# 7. Backward pass"
]
},
{
"cell_type": "markdown",
"id": "0fb18890",
"metadata": {},
"source": [
"Aqui vamos escrever as respectivas derivadas"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "4a811c61",
"metadata": {},
"outputs": [],
"source": [
"def relu_derivative(z):\n",
" return np.where(z > 0, 1, 0)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "aed088dc",
"metadata": {},
"outputs": [],
"source": [
"# Z1 = W1 @ X + b1\n",
"# A1 = relu(Z1)\n",
"# Z2 = W2 @ A1 + b2\n",
"# A2 = softmax(Z2)\n",
"# loss = cross_entropy_loss(A2, Y)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "86e6f662",
"metadata": {},
"outputs": [],
"source": [
"def backward(Z1, A1, Z2, A2, W1, W2, X, Y, verbose=False):\n",
" m = Y.shape[0]\n",
" one_hot_Y = one_hot(Y)\n",
"\n",
" # 1. Derivada da função de perda em relação aos pesos da camada de saída\n",
" dZ2 = A2 - one_hot_Y # Derivada da cross-entropy combinada com softmax simplifica para A2 - Y_one_hot\n",
" dW2 = dZ2 @ A1.T / m\n",
" db2 = np.sum(dZ2, axis=1, keepdims=True) / m\n",
"\n",
" # 2. Derivada da função de perda em relação aos pesos da camada oculta\n",
" dZ1 = (W2.T @ dZ2) * relu_derivative(Z1)\n",
" dW1 = dZ1 @ X.T / m\n",
" db1 = np.sum(dZ1, axis=1, keepdims=True) / m\n",
"\n",
" if verbose:\n",
" print(f\"{dW1.shape=}\\n{db1.shape=}\\n{dW2.shape=}\\n{db2.shape=}\")\n",
" return dW1, db1, dW2, db2\n"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "0aa97361",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dW1.shape=(12, 784)\n",
"db1.shape=(12, 1)\n",
"dW2.shape=(10, 12)\n",
"db2.shape=(10, 1)\n"
]
}
],
"source": [
"_ = backward(Z1, A1, Z2, A2, W1, W2, X_train_scaled[:,0:5], y_train[0:5], True)"
]
},
{
"cell_type": "markdown",
"id": "cb6b46be",
"metadata": {},
"source": [
"# 8. Atualização dos parâmetros treináveis"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "701a2f02",
"metadata": {},
"outputs": [],
"source": [
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
" W1 = W1 - dW1 * alpha\n",
" b1 = b1 - db1 * alpha\n",
" W2 = W2 - dW2 * alpha\n",
" b2 = b2 - db2 * alpha\n",
" return W1, b1, W2, b2"
]
},
{
"cell_type": "markdown",
"id": "6cf7991d",
"metadata": {},
"source": [
"# 9. Gradient Descent\n",
"(aqui é o algoritmo original mais básico mesmo)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "0ae7c1bc",
"metadata": {},
"outputs": [],
"source": [
"def get_predictions(A2):\n",
" return np.argmax(A2, 0)\n",
"\n",
"def get_accuracy(predictions, Y):\n",
" # print(predictions, Y)\n",
" return np.sum(predictions == Y) / Y.size"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "4f199e35",
"metadata": {},
"outputs": [],
"source": [
"def make_predictions(X, W1, b1, W2, b2):\n",
" _, _, _, A2 = forward(X, W1, b1, W2, b2)\n",
" predictions = get_predictions(A2)\n",
" return predictions"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "98da370d",
"metadata": {},
"outputs": [],
"source": [
"def gradient_descent(X_train, y_train, X_test, y_test, n_iter, alpha, params):\n",
" W1, b1, W2, b2 = params\n",
" for i in range(n_iter):\n",
" Z1, A1, Z2, A2 = forward(X_train, W1, b1, W2, b2)\n",
" dW1, db1, dW2, db2 = backward(Z1, A1, Z2, A2, W1, W2, X_train, y_train)\n",
" W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n",
" if i % 10 == 0:\n",
" # Calculando métricas de treino\n",
" train_acc = get_accuracy(get_predictions(A2), y_train).round(3).item()\n",
" train_loss = cross_entropy_loss(A2, y_train).round(3).item()\n",
"\n",
" # Calculando métricas de teste\n",
" A2_test = make_predictions(X_test, W1, b1, W2, b2)\n",
" test_acc = get_accuracy(A2_test, y_test).round(3).item()\n",
" test_loss = cross_entropy_loss(A2_test, y_test).round(3).item()\n",
"\n",
" print(f\"It {i:>4} - TRAIN loss: {str(train_loss):<5} acc: {str(train_acc):<5} - TEST loss: {str(test_loss):<5} acc: {str(test_acc):<5}\")\n",
" return W1, b1, W2, b2"
]
},
{
"cell_type": "markdown",
"id": "df165e55",
"metadata": {},
"source": [
"# 10. Treinamento"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "c5d3727a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"W1.shape=(12, 784)\n",
"b1.shape=(12, 1)\n",
"W2.shape=(10, 12)\n",
"b2.shape=(10, 1)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"It 0 - TRAIN loss: 2.327 acc: 0.122 - TEST loss: 2.376 acc: 0.136\n",
"It 10 - TRAIN loss: 2.212 acc: 0.251 - TEST loss: 5.614 acc: 0.271\n",
"It 20 - TRAIN loss: 2.081 acc: 0.307 - TEST loss: 8.208 acc: 0.305\n",
"It 30 - TRAIN loss: 1.921 acc: 0.345 - TEST loss: 7.35 acc: 0.344\n",
"It 40 - TRAIN loss: 1.744 acc: 0.431 - TEST loss: 5.181 acc: 0.442\n",
"It 50 - TRAIN loss: 1.557 acc: 0.525 - TEST loss: 4.0 acc: 0.535\n",
"It 60 - TRAIN loss: 1.371 acc: 0.6 - TEST loss: 3.319 acc: 0.602\n",
"It 70 - TRAIN loss: 1.207 acc: 0.655 - TEST loss: 2.876 acc: 0.651\n",
"It 80 - TRAIN loss: 1.075 acc: 0.697 - TEST loss: 2.571 acc: 0.689\n",
"It 90 - TRAIN loss: 0.971 acc: 0.727 - TEST loss: 2.454 acc: 0.73 \n"
]
}
],
"source": [
"params = init_params_standard_normal(sigma=0.1, mu=0)\n",
"# params = init_params_uniform(low=-0.5, high=0.5)\n",
"# params = init_params_glorot()\n",
"W1, b1, W2, b2 = gradient_descent(X_train_scaled, y_train, X_test_scaled, y_test, 100, 0.1, params)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "b498f13d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.75"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_accuracy(make_predictions(X_test_scaled, W1, b1, W2, b2), y_test).item()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "7e57815a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction: [8]\n",
"Label: 3\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGcNJREFUeJzt3XuMFdXhB/CzqCyo7NIV2YeAgqg0IjS1iESlWAmIjRU1jViTYmM0UDQV6iPbqGhbs61tWrWl2j8aqa1vUyASS6KrQGpBI5RS+6AuQYHCYjVll0dBuswvM7/fblkB/d3LLufuvZ9PMrnce+fsHM6ene89M+fOlCVJkgQAOMp6He0NAoAAAiAaIyAAohBAAEQhgACIQgABEIUAAiAKAQRAFMeGArN///6wZcuW0K9fv1BWVha7OgDkKL2+wY4dO0JdXV3o1atXzwmgNHwGDx4cuxoAHKFNmzaFQYMG9ZxDcOnIB4Ce75P2590WQPPmzQunnXZa6NOnTxg7dmx44403/l/lHHYDKA6ftD/vlgB65plnwpw5c8LcuXPD6tWrw+jRo8PkyZPDe++91x2bA6AnSrrBeeedl8yaNavjeVtbW1JXV5c0NDR8YtmWlpb06twWbaAP6AP6QOjZbZDuzz9Ol4+APvzww7Bq1aowceLEjtfSWRDp8xUrVhy0/t69e0Nra2unBYDi1+UB9P7774e2trZQXV3d6fX0eXNz80HrNzQ0hMrKyo7FDDiA0hB9Flx9fX1oaWnpWNJpewAUvy7/HtCAAQPCMcccE7Zt29bp9fR5TU3NQeuXl5dnCwClpctHQL179w7nnntuaGxs7HR1g/T5uHHjunpzAPRQ3XIlhHQK9vTp08PnPve5cN5554UHH3ww7Nq1K3zta1/rjs0B0AN1SwBdc8014Z///Ge45557sokHn/nMZ8KSJUsOmpgAQOkqS+dihwKSTsNOZ8MB0LOlE8sqKioKdxYcAKVJAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBIAAAqB0GAEBEIUAAiCKY+NsFjgaxo8fn1e5nTt35lxm7NixOZeZOXNmzmXOPvvsnMtcccUVIR+LFy/Oqxz/P0ZAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiCKsiRJklBAWltbQ2VlZexqQMH58pe/nHOZefPm5bWtffv25Vymuro6FKrt27fnVW7GjBk5l3n++efz2lYxamlpCRUVFYd93wgIgCgEEADFEUD33ntvKCsr67SMGDGiqzcDQA/XLTekS28Y9fLLL/93I8e67x0AnXVLMqSBU1NT0x0/GoAi0S3ngN5+++1QV1cXhg0bFq677rqwcePGw667d+/ebObbgQsAxa/LAyi9L/z8+fPDkiVLwiOPPBI2bNgQLrroorBjx45Drt/Q0JBNu25fBg8e3NVVAqAUAmjKlCnZ9xVGjRoVJk+eHF588cVsDv6zzz57yPXr6+uzueLty6ZNm7q6SgAUoG6fHdC/f/9w5plnhqampkO+X15eni0AlJZu/x7Qzp07w/r160NtbW13bwqAUg6g2267LSxbtiy888474fe//3248sorwzHHHBOuvfbart4UAD1Ylx+C27x5cxY2H3zwQTj55JPDhRdeGFauXJn9GwDauRgp9BBr1qzJuczIkSO7pS6lIp+LmF5zzTU5l2lsbAzFyMVIAShILkYKQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAxXlDOoCeKr2hZq6GDx+ec5nGIr0Y6ScxAgIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKJwNWzoIWbNmpVzmSeffDKvbZ1yyil5lSs2ra2tOZfZuHFjt9SlGBkBARCFAAJAAAFQOoyAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCxUihh3jttddyLrN8+fK8tnXttdfmVa7YvPvuuzmX+e1vf9stdSlGRkAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoXI6UoTZ06Na9yCxcuDEfD+eefn3OZL33pSzmXGT16dM5l+K9FixZpjm5kBARAFAIIgJ4RQOn9RS6//PJQV1cXysrKDjpkkSRJuOeee0JtbW3o27dvmDhxYnj77be7ss4AlGIA7dq1KzuuPG/evEO+/8ADD4SHH344PProo+H1118PJ5xwQpg8eXLYs2dPV9QXgFKdhDBlypRsOZR09PPggw+Gu+66K1xxxRXZa48//niorq7ORkrTpk078hoDUBS69BzQhg0bQnNzc3bYrV1lZWUYO3ZsWLFixSHL7N27N7S2tnZaACh+XRpAafik0hHPgdLn7e99VENDQxZS7cvgwYO7skoAFKjos+Dq6+tDS0tLx7Jp06bYVQKgpwVQTU1N9rht27ZOr6fP29/7qPLy8lBRUdFpAaD4dWkADR06NAuaxsbGjtfSczrpbLhx48Z15aYAKLVZcDt37gxNTU2dJh6sWbMmVFVVhSFDhoRbb701fPe73w1nnHFGFkh333139p2hfC+NAkBxyjmA3nzzzXDxxRd3PJ8zZ072OH369DB//vxwxx13ZN8Vuummm8L27dvDhRdeGJYsWRL69OnTtTUHoEcrS9Iv7xSQ9JBdOhuO4jRhwoScy8yYMeOoXLgzdbS+MJ2e+zwaZfhfd955Z15N8dBDD+Vc5j//+Y9m/z/pxLKPO68ffRYcAKVJAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAnnE7BjgS7bfvyMVll1121Bq9d+/eR2U7ZWVlOZcpsAvXd4nVq1fnXGb27Nk5l0lvipkPV7buXkZAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKFyMFopk7d27OZV577bVuqQtHnxEQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCxUjJ28iRI3Muc8455+RcpqysLBSbXr1y/+y3f//+bqkLxGIEBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACicDFS8vbWW2/lXOaZZ57Jucy0adNyLtO7d++Qj7///e/haMjnAqtJkuRc5qyzzgr5OPnkk3Mu86c//SnnMn/+859zLkPxMAICIAoBBEDPCKDly5eHyy+/PNTV1WWHERYuXNjp/euvvz57/cDl0ksv7co6A1CKAbRr164wevToMG/evMOukwbO1q1bO5annnrqSOsJQKlPQpgyZUq2fJzy8vJQU1NzJPUCoMh1yzmgpUuXhoEDB2YzcGbOnBk++OCDw667d+/e0Nra2mkBoPh1eQClh98ef/zx0NjYGL7//e+HZcuWZSOmtra2Q67f0NAQKisrO5bBgwd3dZUAKIXvAR34nY1zzjknjBo1Kpx++unZqOiSSy45aP36+vowZ86cjufpCEgIARS/bp+GPWzYsDBgwIDQ1NR02PNFFRUVnRYAil+3B9DmzZuzc0C1tbXdvSkAivkQ3M6dOzuNZjZs2BDWrFkTqqqqsuW+++4LV199dTYLbv369eGOO+4Iw4cPD5MnT+7qugNQSgH05ptvhosvvrjjefv5m+nTp4dHHnkkrF27Nvzyl78M27dvz76sOmnSpPCd73wnO9QGAO3KknyucNiN0kkI6Ww4aJd+8TlXffv2zasBV65cWbANf+yxuc8ZeuKJJ/LaVnoU42jIp37ph116hpaWlo89r+9acABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQHHckhu62h//+EeNGkI4//zzC/aq1vl6+umnY1eBiIyAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAULkZ6lLzzzjs5l7n//vtzLvPiiy/mXOYf//hHKGQjRozIuUyfPn3y2taePXtyLjN9+vScy5SVleVc5qtf/WooZA899FDOZZYuXdotdaFnMAICIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFGUJUmShALS2toaKisrQ7Fpa2vLuUw+v5rVq1fnXGbdunUhHz/84Q9zLvP+++/nXGbhwoU5l6murg75WL9+fc5lLrrooqNyMdJ8+sP+/ftDPjZv3pxzmQsuuCDnMlu3bs25DD1HS0tLqKioOOz7RkAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIIpj42yW7nL22WfnXGbHjh15bWvVqlU5l3n++edzLrNnz56cy9TV1eVc5kjKFapf/epXeZW74YYburwu8FFGQABEIYAAKPwAamhoCGPGjAn9+vULAwcODFOnTj3oXjLp4ZJZs2aFk046KZx44onh6quvDtu2bevqegNQSgG0bNmyLFxWrlwZXnrppbBv374wadKksGvXro51Zs+eHV544YXw3HPPZetv2bIlXHXVVd1RdwBKZRLCkiVLOj2fP39+NhJKT0aPHz8+u/vdL37xi/Dkk0+GL3zhC9k6jz32WPj0pz+dhdb555/ftbUHoDTPAaWBk6qqqsoe0yBKR0UTJ07sWGfEiBFhyJAhYcWKFYf8GXv37s1uw33gAkDxyzuA0nvN33rrrdl94EeOHJm91tzcHHr37h369+/fad3q6ursvcOdV6qsrOxYBg8enG+VACiFAErPBb311lvh6aefPqIK1NfXZyOp9mXTpk1H9PMAKOIvot58881h8eLFYfny5WHQoEEdr9fU1IQPP/wwbN++vdMoKJ0Fl753KOXl5dkCQGnJaQSUJEkWPgsWLAivvPJKGDp0aKf3zz333HDccceFxsbGjtfSadobN24M48aN67paA1BaI6D0sFs6w23RokXZd4Haz+uk52769u2bPaaX8JgzZ042MaGioiLccsstWfiYAQdA3gH0yCOPZI8TJkzo9Ho61fr666/P/v3jH/849OrVK/sCajrDbfLkyeFnP/tZLpsBoASUJelxtQKSTsNOR1LFpq2tLecyBfaroQs1NTXlXCafCT/3339/yEf6dQo4UunEsvRI2OG4FhwAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAAdBz7ohK7rZs2ZJzmdraWk19lK1fvz7nMrt37865zPPPP3/UrmwNhcoICIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABE4WKkR8mll16ac5mf/vSnOZfp06dPzmXGjBkT8nHXXXflXOZf//pXOBqmTZuWV7lrr7025zLNzc15bQtKnREQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIiiLEmSJBSQ1tbWUFlZGbsaAByhlpaWUFFRcdj3jYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAAo/gBoaGsKYMWNCv379wsCBA8PUqVPDunXrOq0zYcKEUFZW1mmZMWNGV9cbgFIKoGXLloVZs2aFlStXhpdeeins27cvTJo0KezatavTejfeeGPYunVrx/LAAw90db0B6OGOzWXlJUuWdHo+f/78bCS0atWqMH78+I7Xjz/++FBTU9N1tQSg6PQ60tutpqqqqjq9/sQTT4QBAwaEkSNHhvr6+rB79+7D/oy9e/dmt+E+cAGgBCR5amtrS774xS8mF1xwQafXf/7znydLlixJ1q5dm/z6179OTjnllOTKK6887M+ZO3duklbDog30AX1AHwhF1QYtLS0fmyN5B9CMGTOSU089Ndm0adPHrtfY2JhVpKmp6ZDv79mzJ6tk+5L+vNiNZtEG+oA+oA+Ebg+gnM4Btbv55pvD4sWLw/Lly8OgQYM+dt2xY8dmj01NTeH0008/6P3y8vJsAaC05BRA6YjplltuCQsWLAhLly4NQ4cO/cQya9asyR5ra2vzryUApR1A6RTsJ598MixatCj7LlBzc3P2emVlZejbt29Yv3599v5ll10WTjrppLB27dowe/bsbIbcqFGjuuv/AEBPlMt5n8Md53vsscey9zdu3JiMHz8+qaqqSsrLy5Phw4cnt99++yceBzxQuq5jr46/6wP6gD4QenwbfNK+v+z/gqVgpNOw0xEVAD1b+lWdioqKw77vWnAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARFFwAZQkSewqAHAU9ucFF0A7duyIXQUAjsL+vCwpsCHH/v37w5YtW0K/fv1CWVlZp/daW1vD4MGDw6ZNm0JFRUUoVdpBO+gP/i4Kef+QxkoaPnV1daFXr8OPc44NBSat7KBBgz52nbRRSzmA2mkH7aA/+Lso1P1DZWXlJ65TcIfgACgNAgiAKHpUAJWXl4e5c+dmj6VMO2gH/cHfRTHsHwpuEgIApaFHjYAAKB4CCIAoBBAAUQggAKLoMQE0b968cNppp4U+ffqEsWPHhjfeeCOUmnvvvTe7OsSBy4gRI0KxW758ebj88suzb1Wn/+eFCxd2ej+dR3PPPfeE2tra0Ldv3zBx4sTw9ttvh1Jrh+uvv/6g/nHppZeGYtLQ0BDGjBmTXSll4MCBYerUqWHdunWd1tmzZ0+YNWtWOOmkk8KJJ54Yrr766rBt27ZQau0wYcKEg/rDjBkzQiHpEQH0zDPPhDlz5mRTC1evXh1Gjx4dJk+eHN57771Qas4+++ywdevWjuV3v/tdKHa7du3Kfufph5BDeeCBB8LDDz8cHn300fD666+HE044Iesf6Y6olNohlQbOgf3jqaeeCsVk2bJlWbisXLkyvPTSS2Hfvn1h0qRJWdu0mz17dnjhhRfCc889l62fXtrrqquuCqXWDqkbb7yxU39I/1YKStIDnHfeecmsWbM6nre1tSV1dXVJQ0NDUkrmzp2bjB49OillaZddsGBBx/P9+/cnNTU1yQ9+8IOO17Zv356Ul5cnTz31VFIq7ZCaPn16csUVVySl5L333svaYtmyZR2/++OOOy557rnnOtb561//mq2zYsWKpFTaIfX5z38++cY3vpEUsoIfAX344Ydh1apV2WGVA68Xlz5fsWJFKDXpoaX0EMywYcPCddddFzZu3BhK2YYNG0Jzc3On/pFegyo9TFuK/WPp0qXZIZmzzjorzJw5M3zwwQehmLW0tGSPVVVV2WO6r0hHAwf2h/Qw9ZAhQ4q6P7R8pB3aPfHEE2HAgAFh5MiRob6+PuzevTsUkoK7GOlHvf/++6GtrS1UV1d3ej19/re//S2UknSnOn/+/Gznkg6n77vvvnDRRReFt956KzsWXIrS8Ekdqn+0v1cq0sNv6aGmoUOHhvXr14dvfetbYcqUKdmO95hjjgnFJr1y/q233houuOCCbAebSn/nvXv3Dv379y+Z/rD/EO2Q+spXvhJOPfXU7APr2rVrw5133pmdJ/rNb34TCkXBBxD/le5M2o0aNSoLpLSDPfvss+GGG27QVCVu2rRpHf8+55xzsj5y+umnZ6OiSy65JBSb9BxI+uGrFM6D5tMON910U6f+kE7SSftB+uEk7ReFoOAPwaXDx/TT20dnsaTPa2pqQilLP+WdeeaZoampKZSq9j6gfxwsPUyb/v0UY/+4+eabw+LFi8Orr77a6fYtaX9ID9tv3769JPYXNx+mHQ4l/cCaKqT+UPABlA6nzz333NDY2NhpyJk+HzduXChlO3fuzD7NpJ9sSlV6uCndsRzYP9IbcqWz4Uq9f2zevDk7B1RM/SOdf5HudBcsWBBeeeWV7Pd/oHRfcdxxx3XqD+lhp/RcaTH1h+QT2uFQ1qxZkz0WVH9IeoCnn346m9U0f/785C9/+Uty0003Jf3790+am5uTUvLNb34zWbp0abJhw4bktddeSyZOnJgMGDAgmwFTzHbs2JH84Q9/yJa0y/7oRz/K/v3uu+9m73/ve9/L+sOiRYuStWvXZjPBhg4dmvz73/9OSqUd0vduu+22bKZX2j9efvnl5LOf/WxyxhlnJHv27EmKxcyZM5PKysrs72Dr1q0dy+7duzvWmTFjRjJkyJDklVdeSd58881k3Lhx2VJMZn5COzQ1NSXf/va3s/9/2h/Sv41hw4Yl48ePTwpJjwig1E9+8pOsU/Xu3Tublr1y5cqk1FxzzTVJbW1t1gannHJK9jztaMXu1VdfzXa4H13SacftU7HvvvvupLq6OvugcskllyTr1q1LSqkd0h3PpEmTkpNPPjmbhnzqqacmN954Y9F9SDvU/z9dHnvssY510g8eX//615NPfepTyfHHH59ceeWV2c65lNph48aNWdhUVVVlfxPDhw9Pbr/99qSlpSUpJG7HAEAUBX8OCIDiJIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgxPA/j5nHnP7r1T8AAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def test_prediction(index, W1, b1, W2, b2, X_train):\n",
" current_image = X_train[:, index, None]\n",
" prediction = make_predictions(X_train[:, index, None], W1, b1, W2, b2)\n",
" label = y_train[index]\n",
" print(\"Prediction: \", prediction)\n",
" print(\"Label: \", label)\n",
" \n",
" current_image = current_image.reshape((28, 28)) * 255\n",
" plt.gray()\n",
" plt.imshow(current_image, interpolation='nearest')\n",
" plt.show()\n",
"\n",
"i = np.random.randint(0, X_train_scaled.shape[1])\n",
"test_prediction(i, W1, b1, W2, b2, X_train_scaled)"
]
},
{
"cell_type": "markdown",
"id": "55d784fb",
"metadata": {},
"source": [
"# Extra: sklearn MLPClassifier"
]
},
{
"cell_type": "markdown",
"id": "8a6e516a",
"metadata": {},
"source": [
"https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "aa6cfd27",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 1, loss = 2.35732861\n",
"Iteration 2, loss = 2.30002841\n",
"Iteration 3, loss = 2.26468002\n",
"Iteration 4, loss = 2.23664851\n",
"Iteration 5, loss = 2.21013676\n",
"Iteration 6, loss = 2.18399348\n",
"Iteration 7, loss = 2.15824992\n",
"Iteration 8, loss = 2.13285909\n",
"Iteration 9, loss = 2.10772348\n",
"Iteration 10, loss = 2.08276116\n",
"Iteration 11, loss = 2.05785163\n",
"Iteration 12, loss = 2.03291732\n",
"Iteration 13, loss = 2.00782073\n",
"Iteration 14, loss = 1.98242809\n",
"Iteration 15, loss = 1.95666368\n",
"Iteration 16, loss = 1.93047865\n",
"Iteration 17, loss = 1.90384422\n",
"Iteration 18, loss = 1.87670576\n",
"Iteration 19, loss = 1.84903195\n",
"Iteration 20, loss = 1.82080313\n",
"Iteration 21, loss = 1.79196680\n",
"Iteration 22, loss = 1.76255879\n",
"Iteration 23, loss = 1.73259370\n",
"Iteration 24, loss = 1.70214221\n",
"Iteration 25, loss = 1.67141160\n",
"Iteration 26, loss = 1.64054857\n",
"Iteration 27, loss = 1.60967903\n",
"Iteration 28, loss = 1.57901073\n",
"Iteration 29, loss = 1.54863661\n",
"Iteration 30, loss = 1.51861898\n",
"Iteration 31, loss = 1.48902494\n",
"Iteration 32, loss = 1.45991991\n",
"Iteration 33, loss = 1.43139716\n",
"Iteration 34, loss = 1.40352680\n",
"Iteration 35, loss = 1.37630772\n",
"Iteration 36, loss = 1.34976963\n",
"Iteration 37, loss = 1.32394366\n",
"Iteration 38, loss = 1.29879471\n",
"Iteration 39, loss = 1.27433017\n",
"Iteration 40, loss = 1.25054599\n",
"Iteration 41, loss = 1.22741673\n",
"Iteration 42, loss = 1.20491746\n",
"Iteration 43, loss = 1.18304326\n",
"Iteration 44, loss = 1.16177506\n",
"Iteration 45, loss = 1.14109908\n",
"Iteration 46, loss = 1.12100476\n",
"Iteration 47, loss = 1.10147927\n",
"Iteration 48, loss = 1.08251611\n",
"Iteration 49, loss = 1.06409965\n",
"Iteration 50, loss = 1.04621780\n",
"Iteration 51, loss = 1.02886320\n",
"Iteration 52, loss = 1.01202833\n",
"Iteration 53, loss = 0.99570439\n",
"Iteration 54, loss = 0.97988206\n",
"Iteration 55, loss = 0.96455341\n",
"Iteration 56, loss = 0.94970336\n",
"Iteration 57, loss = 0.93532258\n",
"Iteration 58, loss = 0.92140143\n",
"Iteration 59, loss = 0.90793103\n",
"Iteration 60, loss = 0.89489579\n",
"Iteration 61, loss = 0.88228555\n",
"Iteration 62, loss = 0.87008118\n",
"Iteration 63, loss = 0.85827249\n",
"Iteration 64, loss = 0.84684464\n",
"Iteration 65, loss = 0.83578889\n",
"Iteration 66, loss = 0.82508825\n",
"Iteration 67, loss = 0.81473054\n",
"Iteration 68, loss = 0.80470465\n",
"Iteration 69, loss = 0.79499843\n",
"Iteration 70, loss = 0.78559988\n",
"Iteration 71, loss = 0.77649974\n",
"Iteration 72, loss = 0.76768648\n",
"Iteration 73, loss = 0.75914704\n",
"Iteration 74, loss = 0.75087207\n",
"Iteration 75, loss = 0.74285110\n",
"Iteration 76, loss = 0.73507294\n",
"Iteration 77, loss = 0.72752922\n",
"Iteration 78, loss = 0.72021077\n",
"Iteration 79, loss = 0.71310959\n",
"Iteration 80, loss = 0.70621668\n",
"Iteration 81, loss = 0.69952432\n",
"Iteration 82, loss = 0.69302505\n",
"Iteration 83, loss = 0.68671011\n",
"Iteration 84, loss = 0.68057239\n",
"Iteration 85, loss = 0.67460546\n",
"Iteration 86, loss = 0.66880347\n",
"Iteration 87, loss = 0.66316041\n",
"Iteration 88, loss = 0.65766937\n",
"Iteration 89, loss = 0.65232467\n",
"Iteration 90, loss = 0.64712191\n",
"Iteration 91, loss = 0.64205608\n",
"Iteration 92, loss = 0.63712215\n",
"Iteration 93, loss = 0.63231484\n",
"Iteration 94, loss = 0.62762900\n",
"Iteration 95, loss = 0.62306156\n",
"Iteration 96, loss = 0.61860839\n",
"Iteration 97, loss = 0.61426516\n",
"Iteration 98, loss = 0.61002817\n",
"Iteration 99, loss = 0.60589379\n",
"Iteration 100, loss = 0.60185840\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Bruno\\miniconda3\\envs\\cat\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:785: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: #000;\n",
" --sklearn-color-text-muted: #666;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"}\n",
"\n",
"#sk-container-id-1.light {\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: black;\n",
" --sklearn-color-background: white;\n",
" --sklearn-color-border-box: black;\n",
" --sklearn-color-icon: #696969;\n",
"}\n",
"\n",
"#sk-container-id-1.dark {\n",
" --sklearn-color-text-on-default-background: white;\n",
" --sklearn-color-background: #111;\n",
" --sklearn-color-border-box: white;\n",
" --sklearn-color-icon: #878787;\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: flex;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
" align-items: center;\n",
" justify-content: center;\n",
" gap: 0.5em;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label .caption {\n",
" font-size: 0.6rem;\n",
" font-weight: lighter;\n",
" color: var(--sklearn-color-text-muted);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" display: none;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" display: block;\n",
" width: 100%;\n",
" overflow: visible;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 0.5em;\n",
" text-align: center;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-3) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" border: var(--sklearn-color-fitted-level-0) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-0);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" border: var(--sklearn-color-fitted-level-0) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-0);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".estimator-table {\n",
" font-family: monospace;\n",
"}\n",
"\n",
".estimator-table summary {\n",
" padding: .5rem;\n",
" cursor: pointer;\n",
"}\n",
"\n",
".estimator-table summary::marker {\n",
" font-size: 0.7rem;\n",
"}\n",
"\n",
".estimator-table details[open] {\n",
" padding-left: 0.1rem;\n",
" padding-right: 0.1rem;\n",
" padding-bottom: 0.3rem;\n",
"}\n",
"\n",
".estimator-table .parameters-table {\n",
" margin-left: auto !important;\n",
" margin-right: auto !important;\n",
" margin-top: 0;\n",
"}\n",
"\n",
".estimator-table .parameters-table tr:nth-child(odd) {\n",
" background-color: #fff;\n",
"}\n",
"\n",
".estimator-table .parameters-table tr:nth-child(even) {\n",
" background-color: #f6f6f6;\n",
"}\n",
"\n",
".estimator-table .parameters-table tr:hover {\n",
" background-color: #e0e0e0;\n",
"}\n",
"\n",
".estimator-table table td {\n",
" border: 1px solid rgba(106, 105, 104, 0.232);\n",
"}\n",
"\n",
"/*\n",
" `table td`is set in notebook with right text-align.\n",
" We need to overwrite it.\n",
"*/\n",
".estimator-table table td.param {\n",
" text-align: left;\n",
" position: relative;\n",
" padding: 0;\n",
"}\n",
"\n",
".user-set td {\n",
" color:rgb(255, 94, 0);\n",
" text-align: left !important;\n",
"}\n",
"\n",
".user-set td.value {\n",
" color:rgb(255, 94, 0);\n",
" background-color: transparent;\n",
"}\n",
"\n",
".default td {\n",
" color: black;\n",
" text-align: left !important;\n",
"}\n",
"\n",
".user-set td i,\n",
".default td i {\n",
" color: black;\n",
"}\n",
"\n",
"/*\n",
" Styles for parameter documentation links\n",
" We need styling for visited so jupyter doesn't overwrite it\n",
"*/\n",
"a.param-doc-link,\n",
"a.param-doc-link:link,\n",
"a.param-doc-link:visited {\n",
" text-decoration: underline dashed;\n",
" text-underline-offset: .3em;\n",
" color: inherit;\n",
" display: block;\n",
" padding: .5em;\n",
"}\n",
"\n",
"/* \"hack\" to make the entire area of the cell containing the link clickable */\n",
"a.param-doc-link::before {\n",
" position: absolute;\n",
" content: \"\";\n",
" inset: 0;\n",
"}\n",
"\n",
".param-doc-description {\n",
" display: none;\n",
" position: absolute;\n",
" z-index: 9999;\n",
" left: 0;\n",
" padding: .5ex;\n",
" margin-left: 1.5em;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: .3em .3em .4em #999;\n",
" width: max-content;\n",
" text-align: left;\n",
" max-height: 10em;\n",
" overflow-y: auto;\n",
"\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: thin solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
"/* Fitted state for parameter tooltips */\n",
".fitted .param-doc-description {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: thin solid var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".param-doc-link:hover .param-doc-description {\n",
" display: block;\n",
"}\n",
"\n",
".copy-paste-icon {\n",
" background-image: url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA0NDggNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZD0iTTIwOCAwTDMzMi4xIDBjMTIuNyAwIDI0LjkgNS4xIDMzLjkgMTQuMWw2Ny45IDY3LjljOSA5IDE0LjEgMjEuMiAxNC4xIDMzLjlMNDQ4IDMzNmMwIDI2LjUtMjEuNSA0OC00OCA0OGwtMTkyIDBjLTI2LjUgMC00OC0yMS41LTQ4LTQ4bDAtMjg4YzAtMjYuNSAyMS41LTQ4IDQ4LTQ4ek00OCAxMjhsODAgMCAwIDY0LTY0IDAgMCAyNTYgMTkyIDAgMC0zMiA2NCAwIDAgNDhjMCAyNi41LTIxLjUgNDgtNDggNDhMNDggNTEyYy0yNi41IDAtNDgtMjEuNS00OC00OEwwIDE3NmMwLTI2LjUgMjEuNS00OCA0OC00OHoiLz48L3N2Zz4=);\n",
" background-repeat: no-repeat;\n",
" background-size: 14px 14px;\n",
" background-position: 0;\n",
" display: inline-block;\n",
" width: 14px;\n",
" height: 14px;\n",
" cursor: pointer;\n",
"}\n",
"</style><body><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MLPClassifier(alpha=0.0, batch_size=69000, hidden_layer_sizes=12,\n",
" learning_rate_init=0.1, max_iter=100, momentum=0.0,\n",
" random_state=42, solver=&#x27;sgd&#x27;, verbose=50)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>MLPClassifier</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html\">?<span>Documentation for MLPClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\" data-param-prefix=\"\">\n",
" <div class=\"estimator-table\">\n",
" <details>\n",
" <summary>Parameters</summary>\n",
" <table class=\"parameters-table\">\n",
" <tbody>\n",
" \n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('hidden_layer_sizes',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=hidden_layer_sizes,-array-like%20of%20shape%28n_layers%20-%202%2C%29%2C%20default%3D%28100%2C%29\">\n",
" hidden_layer_sizes\n",
" <span class=\"param-doc-description\">hidden_layer_sizes: array-like of shape(n_layers - 2,), default=(100,)<br><br>The ith element represents the number of neurons in the ith<br>hidden layer.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">12</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('activation',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=activation,-%7B%27identity%27%2C%20%27logistic%27%2C%20%27tanh%27%2C%20%27relu%27%7D%2C%20default%3D%27relu%27\">\n",
" activation\n",
" <span class=\"param-doc-description\">activation: {'identity', 'logistic', 'tanh', 'relu'}, default='relu'<br><br>Activation function for the hidden layer.<br><br>- 'identity', no-op activation, useful to implement linear bottleneck,<br> returns f(x) = x<br><br>- 'logistic', the logistic sigmoid function,<br> returns f(x) = 1 / (1 + exp(-x)).<br><br>- 'tanh', the hyperbolic tan function,<br> returns f(x) = tanh(x).<br><br>- 'relu', the rectified linear unit function,<br> returns f(x) = max(0, x)</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">&#x27;relu&#x27;</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('solver',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=solver,-%7B%27lbfgs%27%2C%20%27sgd%27%2C%20%27adam%27%7D%2C%20default%3D%27adam%27\">\n",
" solver\n",
" <span class=\"param-doc-description\">solver: {'lbfgs', 'sgd', 'adam'}, default='adam'<br><br>The solver for weight optimization.<br><br>- 'lbfgs' is an optimizer in the family of quasi-Newton methods.<br><br>- 'sgd' refers to stochastic gradient descent.<br><br>- 'adam' refers to a stochastic gradient-based optimizer proposed<br> by Kingma, Diederik, and Jimmy Ba<br><br>For a comparison between Adam optimizer and SGD, see<br>:ref:`sphx_glr_auto_examples_neural_networks_plot_mlp_training_curves.py`.<br><br>Note: The default solver 'adam' works pretty well on relatively<br>large datasets (with thousands of training samples or more) in terms of<br>both training time and validation score.<br>For small datasets, however, 'lbfgs' can converge faster and perform<br>better.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">&#x27;sgd&#x27;</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('alpha',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=alpha,-float%2C%20default%3D0.0001\">\n",
" alpha\n",
" <span class=\"param-doc-description\">alpha: float, default=0.0001<br><br>Strength of the L2 regularization term. The L2 regularization term<br>is divided by the sample size when added to the loss.<br><br>For an example usage and visualization of varying regularization, see<br>:ref:`sphx_glr_auto_examples_neural_networks_plot_mlp_alpha.py`.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.0</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('batch_size',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=batch_size,-int%2C%20default%3D%27auto%27\">\n",
" batch_size\n",
" <span class=\"param-doc-description\">batch_size: int, default='auto'<br><br>Size of minibatches for stochastic optimizers.<br>If the solver is 'lbfgs', the classifier will not use minibatch.<br>When set to \"auto\", `batch_size=min(200, n_samples)`.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">69000</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('learning_rate',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=learning_rate,-%7B%27constant%27%2C%20%27invscaling%27%2C%20%27adaptive%27%7D%2C%20default%3D%27constant%27\">\n",
" learning_rate\n",
" <span class=\"param-doc-description\">learning_rate: {'constant', 'invscaling', 'adaptive'}, default='constant'<br><br>Learning rate schedule for weight updates.<br><br>- 'constant' is a constant learning rate given by<br> 'learning_rate_init'.<br><br>- 'invscaling' gradually decreases the learning rate at each<br> time step 't' using an inverse scaling exponent of 'power_t'.<br> effective_learning_rate = learning_rate_init / pow(t, power_t)<br><br>- 'adaptive' keeps the learning rate constant to<br> 'learning_rate_init' as long as training loss keeps decreasing.<br> Each time two consecutive epochs fail to decrease training loss by at<br> least tol, or fail to increase validation score by at least tol if<br> 'early_stopping' is on, the current learning rate is divided by 5.<br><br>Only used when ``solver='sgd'``.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">&#x27;constant&#x27;</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('learning_rate_init',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=learning_rate_init,-float%2C%20default%3D0.001\">\n",
" learning_rate_init\n",
" <span class=\"param-doc-description\">learning_rate_init: float, default=0.001<br><br>The initial learning rate used. It controls the step-size<br>in updating the weights. Only used when solver='sgd' or 'adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.1</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('power_t',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=power_t,-float%2C%20default%3D0.5\">\n",
" power_t\n",
" <span class=\"param-doc-description\">power_t: float, default=0.5<br><br>The exponent for inverse scaling learning rate.<br>It is used in updating effective learning rate when the learning_rate<br>is set to 'invscaling'. Only used when solver='sgd'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.5</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('max_iter',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=max_iter,-int%2C%20default%3D200\">\n",
" max_iter\n",
" <span class=\"param-doc-description\">max_iter: int, default=200<br><br>Maximum number of iterations. The solver iterates until convergence<br>(determined by 'tol') or this number of iterations. For stochastic<br>solvers ('sgd', 'adam'), note that this determines the number of epochs<br>(how many times each data point will be used), not the number of<br>gradient steps.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">100</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('shuffle',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=shuffle,-bool%2C%20default%3DTrue\">\n",
" shuffle\n",
" <span class=\"param-doc-description\">shuffle: bool, default=True<br><br>Whether to shuffle samples in each iteration. Only used when<br>solver='sgd' or 'adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">True</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('random_state',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=random_state,-int%2C%20RandomState%20instance%2C%20default%3DNone\">\n",
" random_state\n",
" <span class=\"param-doc-description\">random_state: int, RandomState instance, default=None<br><br>Determines random number generation for weights and bias<br>initialization, train-test split if early stopping is used, and batch<br>sampling when solver='sgd' or 'adam'.<br>Pass an int for reproducible results across multiple function calls.<br>See :term:`Glossary <random_state>`.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">42</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('tol',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=tol,-float%2C%20default%3D1e-4\">\n",
" tol\n",
" <span class=\"param-doc-description\">tol: float, default=1e-4<br><br>Tolerance for the optimization. When the loss or score is not improving<br>by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,<br>unless ``learning_rate`` is set to 'adaptive', convergence is<br>considered to be reached and training stops.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.0001</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('verbose',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=verbose,-bool%2C%20default%3DFalse\">\n",
" verbose\n",
" <span class=\"param-doc-description\">verbose: bool, default=False<br><br>Whether to print progress messages to stdout.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">50</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('warm_start',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=warm_start,-bool%2C%20default%3DFalse\">\n",
" warm_start\n",
" <span class=\"param-doc-description\">warm_start: bool, default=False<br><br>When set to True, reuse the solution of the previous<br>call to fit as initialization, otherwise, just erase the<br>previous solution. See :term:`the Glossary <warm_start>`.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">False</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"user-set\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('momentum',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=momentum,-float%2C%20default%3D0.9\">\n",
" momentum\n",
" <span class=\"param-doc-description\">momentum: float, default=0.9<br><br>Momentum for gradient descent update. Should be between 0 and 1. Only<br>used when solver='sgd'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.0</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('nesterovs_momentum',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=nesterovs_momentum,-bool%2C%20default%3DTrue\">\n",
" nesterovs_momentum\n",
" <span class=\"param-doc-description\">nesterovs_momentum: bool, default=True<br><br>Whether to use Nesterov's momentum. Only used when solver='sgd' and<br>momentum > 0.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">True</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('early_stopping',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=early_stopping,-bool%2C%20default%3DFalse\">\n",
" early_stopping\n",
" <span class=\"param-doc-description\">early_stopping: bool, default=False<br><br>Whether to use early stopping to terminate training when validation<br>score is not improving. If set to True, it will automatically set<br>aside ``validation_fraction`` of training data as validation and<br>terminate training when validation score is not improving by at least<br>``tol`` for ``n_iter_no_change`` consecutive epochs. The split is<br>stratified, except in a multilabel setting.<br>If early stopping is False, then the training stops when the training<br>loss does not improve by more than ``tol`` for ``n_iter_no_change``<br>consecutive passes over the training set.<br>Only effective when solver='sgd' or 'adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">False</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('validation_fraction',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=validation_fraction,-float%2C%20default%3D0.1\">\n",
" validation_fraction\n",
" <span class=\"param-doc-description\">validation_fraction: float, default=0.1<br><br>The proportion of training data to set aside as validation set for<br>early stopping. Must be between 0 and 1.<br>Only used if early_stopping is True.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.1</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('beta_1',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=beta_1,-float%2C%20default%3D0.9\">\n",
" beta_1\n",
" <span class=\"param-doc-description\">beta_1: float, default=0.9<br><br>Exponential decay rate for estimates of first moment vector in adam,<br>should be in [0, 1). Only used when solver='adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.9</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('beta_2',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=beta_2,-float%2C%20default%3D0.999\">\n",
" beta_2\n",
" <span class=\"param-doc-description\">beta_2: float, default=0.999<br><br>Exponential decay rate for estimates of second moment vector in adam,<br>should be in [0, 1). Only used when solver='adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">0.999</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('epsilon',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=epsilon,-float%2C%20default%3D1e-8\">\n",
" epsilon\n",
" <span class=\"param-doc-description\">epsilon: float, default=1e-8<br><br>Value for numerical stability in adam. Only used when solver='adam'.</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">1e-08</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('n_iter_no_change',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=n_iter_no_change,-int%2C%20default%3D10\">\n",
" n_iter_no_change\n",
" <span class=\"param-doc-description\">n_iter_no_change: int, default=10<br><br>Maximum number of epochs to not meet ``tol`` improvement.<br>Only effective when solver='sgd' or 'adam'.<br><br>.. versionadded:: 0.20</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">10</td>\n",
" </tr>\n",
" \n",
"\n",
" <tr class=\"default\">\n",
" <td><i class=\"copy-paste-icon\"\n",
" onclick=\"copyToClipboard('max_fun',\n",
" this.parentElement.nextElementSibling)\"\n",
" ></i></td>\n",
" <td class=\"param\">\n",
" <a class=\"param-doc-link\"\n",
" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.8/modules/generated/sklearn.neural_network.MLPClassifier.html#:~:text=max_fun,-int%2C%20default%3D15000\">\n",
" max_fun\n",
" <span class=\"param-doc-description\">max_fun: int, default=15000<br><br>Only used when solver='lbfgs'. Maximum number of loss function calls.<br>The solver iterates until convergence (determined by 'tol'), number<br>of iterations reaches max_iter, or this number of loss function calls.<br>Note that number of loss function calls will be greater than or equal<br>to the number of iterations for the `MLPClassifier`.<br><br>.. versionadded:: 0.22</span>\n",
" </a>\n",
" </td>\n",
" <td class=\"value\">15000</td>\n",
" </tr>\n",
" \n",
" </tbody>\n",
" </table>\n",
" </details>\n",
" </div>\n",
" </div></div></div></div></div><script>function copyToClipboard(text, element) {\n",
" // Get the parameter prefix from the closest toggleable content\n",
" const toggleableContent = element.closest('.sk-toggleable__content');\n",
" const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n",
" const fullParamName = paramPrefix ? `${paramPrefix}${text}` : text;\n",
"\n",
" const originalStyle = element.style;\n",
" const computedStyle = window.getComputedStyle(element);\n",
" const originalWidth = computedStyle.width;\n",
" const originalHTML = element.innerHTML.replace('Copied!', '');\n",
"\n",
" navigator.clipboard.writeText(fullParamName)\n",
" .then(() => {\n",
" element.style.width = originalWidth;\n",
" element.style.color = 'green';\n",
" element.innerHTML = \"Copied!\";\n",
"\n",
" setTimeout(() => {\n",
" element.innerHTML = originalHTML;\n",
" element.style = originalStyle;\n",
" }, 2000);\n",
" })\n",
" .catch(err => {\n",
" console.error('Failed to copy:', err);\n",
" element.style.color = 'red';\n",
" element.innerHTML = \"Failed!\";\n",
" setTimeout(() => {\n",
" element.innerHTML = originalHTML;\n",
" element.style = originalStyle;\n",
" }, 2000);\n",
" });\n",
" return false;\n",
"}\n",
"\n",
"document.querySelectorAll('.copy-paste-icon').forEach(function(element) {\n",
" const toggleableContent = element.closest('.sk-toggleable__content');\n",
" const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n",
" const paramName = element.parentElement.nextElementSibling\n",
" .textContent.trim().split(' ')[0];\n",
" const fullParamName = paramPrefix ? `${paramPrefix}${paramName}` : paramName;\n",
"\n",
" element.setAttribute('title', fullParamName);\n",
"});\n",
"\n",
"\n",
"/**\n",
" * Adapted from Skrub\n",
" * https://github.com/skrub-data/skrub/blob/403466d1d5d4dc76a7ef569b3f8228db59a31dc3/skrub/_reporting/_data/templates/report.js#L789\n",
" * @returns \"light\" or \"dark\"\n",
" */\n",
"function detectTheme(element) {\n",
" const body = document.querySelector('body');\n",
"\n",
" // Check VSCode theme\n",
" const themeKindAttr = body.getAttribute('data-vscode-theme-kind');\n",
" const themeNameAttr = body.getAttribute('data-vscode-theme-name');\n",
"\n",
" if (themeKindAttr && themeNameAttr) {\n",
" const themeKind = themeKindAttr.toLowerCase();\n",
" const themeName = themeNameAttr.toLowerCase();\n",
"\n",
" if (themeKind.includes(\"dark\") || themeName.includes(\"dark\")) {\n",
" return \"dark\";\n",
" }\n",
" if (themeKind.includes(\"light\") || themeName.includes(\"light\")) {\n",
" return \"light\";\n",
" }\n",
" }\n",
"\n",
" // Check Jupyter theme\n",
" if (body.getAttribute('data-jp-theme-light') === 'false') {\n",
" return 'dark';\n",
" } else if (body.getAttribute('data-jp-theme-light') === 'true') {\n",
" return 'light';\n",
" }\n",
"\n",
" // Guess based on a parent element's color\n",
" const color = window.getComputedStyle(element.parentNode, null).getPropertyValue('color');\n",
" const match = color.match(/^rgb\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)\\s*$/i);\n",
" if (match) {\n",
" const [r, g, b] = [\n",
" parseFloat(match[1]),\n",
" parseFloat(match[2]),\n",
" parseFloat(match[3])\n",
" ];\n",
"\n",
" // https://en.wikipedia.org/wiki/HSL_and_HSV#Lightness\n",
" const luma = 0.299 * r + 0.587 * g + 0.114 * b;\n",
"\n",
" if (luma > 180) {\n",
" // If the text is very bright we have a dark theme\n",
" return 'dark';\n",
" }\n",
" if (luma < 75) {\n",
" // If the text is very dark we have a light theme\n",
" return 'light';\n",
" }\n",
" // Otherwise fall back to the next heuristic.\n",
" }\n",
"\n",
" // Fallback to system preference\n",
" return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';\n",
"}\n",
"\n",
"\n",
"function forceTheme(elementId) {\n",
" const estimatorElement = document.querySelector(`#${elementId}`);\n",
" if (estimatorElement === null) {\n",
" console.error(`Element with id ${elementId} not found.`);\n",
" } else {\n",
" const theme = detectTheme(estimatorElement);\n",
" estimatorElement.classList.add(theme);\n",
" }\n",
"}\n",
"\n",
"forceTheme('sk-container-id-1');</script></body>"
],
"text/plain": [
"MLPClassifier(alpha=0.0, batch_size=69000, hidden_layer_sizes=12,\n",
" learning_rate_init=0.1, max_iter=100, momentum=0.0,\n",
" random_state=42, solver='sgd', verbose=50)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import mlp classifier\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"model_sklearn = MLPClassifier(\n",
" learning_rate=\"constant\",\n",
" learning_rate_init=0.1,\n",
" hidden_layer_sizes=(12), \n",
" activation=\"relu\", \n",
" solver=\"sgd\", \n",
" max_iter=100, \n",
" random_state=42,\n",
" verbose=50,\n",
" batch_size=X_train_scaled.shape[1],\n",
" alpha=0.0,\n",
" momentum=0.0,\n",
")\n",
"\n",
"model_sklearn.fit(X_train_scaled.T, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "95b71894",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.853"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_sklearn.score(X_test_scaled.T, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "6f553d08",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([8, 4, 8, 7, 7], dtype=int16)"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_sklearn.predict(X_test_scaled.T[:5])"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "0433b554",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'softmax'"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_sklearn.out_activation_"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "1df2899b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_sklearn.n_layers_"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "6394dfe4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parâmetros numpy\n",
"(12, 784)\n",
"(10, 12)\n",
"(12, 1)\n",
"(10, 1)\n",
"Parâmetros sklearn\n",
"(784, 12)\n",
"(12, 10)\n",
"(12,)\n",
"(10,)\n"
]
}
],
"source": [
"# Verificando parâmetros treináveis\n",
"print(\"Parâmetros numpy\")\n",
"for w in [W1, W2, b1, b2]:\n",
" print(w.shape)\n",
"\n",
"print(\"Parâmetros sklearn\")\n",
"for w in model_sklearn.coefs_ + model_sklearn.intercepts_:\n",
" print(w.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cat",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment